├── .github
├── CODEOWNERS
└── workflows
│ ├── api-docs.yml
│ ├── release.yml
│ └── tests.yml
├── .gitignore
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── canals
├── __about__.py
├── __init__.py
├── component
│ ├── __init__.py
│ ├── component.py
│ ├── connection.py
│ ├── descriptions.py
│ ├── sockets.py
│ └── types.py
├── errors.py
├── pipeline
│ ├── __init__.py
│ ├── descriptions.py
│ ├── draw
│ │ ├── __init__.py
│ │ ├── draw.py
│ │ ├── graphviz.py
│ │ └── mermaid.py
│ ├── pipeline.py
│ └── validation.py
├── serialization.py
├── testing
│ ├── __init__.py
│ └── factory.py
└── type_utils.py
├── docs
├── api-docs
│ ├── canals.md
│ ├── component.md
│ ├── pipeline.md
│ └── testing.md
├── concepts
│ ├── components.md
│ ├── concepts.md
│ └── pipelines.md
└── index.md
├── images
├── canals-logo-dark.png
└── canals-logo-light.png
├── mkdocs.yml
├── pyproject.toml
├── sample_components
├── __init__.py
├── accumulate.py
├── add_value.py
├── concatenate.py
├── double.py
├── fstring.py
├── greet.py
├── hello.py
├── joiner.py
├── merge_loop.py
├── parity.py
├── remainder.py
├── repeat.py
├── self_loop.py
├── subtract.py
├── sum.py
├── text_splitter.py
└── threshold.py
└── test
├── __init__.py
├── component
├── test_component.py
└── test_connection.py
├── conftest.py
├── pipeline
├── __init__.py
├── integration
│ ├── __init__.py
│ ├── test_complex_pipeline.py
│ ├── test_default_value.py
│ ├── test_distinct_loops_pipeline.py
│ ├── test_double_loop_pipeline.py
│ ├── test_dynamic_inputs_pipeline.py
│ ├── test_fixed_decision_and_merge_pipeline.py
│ ├── test_fixed_decision_pipeline.py
│ ├── test_fixed_merging_pipeline.py
│ ├── test_joiners.py
│ ├── test_linear_pipeline.py
│ ├── test_looping_and_merge_pipeline.py
│ ├── test_looping_pipeline.py
│ ├── test_mutable_inputs.py
│ ├── test_parallel_branches_pipeline.py
│ ├── test_self_loop.py
│ ├── test_variable_decision_and_merge_pipeline.py
│ ├── test_variable_decision_pipeline.py
│ └── test_variable_merging_pipeline.py
└── unit
│ ├── __init__.py
│ ├── test_connections.py
│ ├── test_draw.py
│ ├── test_pipeline.py
│ └── test_validation_pipeline_io.py
├── sample_components
├── __init__.py
├── test_accumulate.py
├── test_add_value.py
├── test_concatenate.py
├── test_double.py
├── test_fstring.py
├── test_greet.py
├── test_merge_loop.py
├── test_parity.py
├── test_remainder.py
├── test_repeat.py
├── test_subtract.py
├── test_sum.py
└── test_threshold.py
├── test_files
├── mermaid_mock
│ └── test_response.png
└── pipeline_draw
│ └── pygraphviz.jpg
├── test_serialization.py
├── test_utils.py
└── testing
└── test_factory.py
/.github/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # See https://help.github.com/articles/about-codeowners/ for syntax
2 |
3 | # Core Engineering will be the default owners for everything
4 | # in the repo. Unless a later match takes precedence,
5 | # @deepset-ai/core-engineering will be requested for review
6 | # when someone opens a pull request.
7 | * @deepset-ai/core-engineering
--------------------------------------------------------------------------------
/.github/workflows/api-docs.yml:
--------------------------------------------------------------------------------
1 | name: API Docs
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 |
8 | permissions:
9 | contents: write
10 |
11 | jobs:
12 | deploy:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v3
16 | - uses: actions/setup-python@v4
17 | with:
18 | python-version: 3.x
19 |
20 | - uses: actions/cache@v2
21 | with:
22 | key: ${{ github.ref }}
23 | path: .cache
24 |
25 | - run: pip install mkdocs-material mkdocstrings[python] mkdocs-mermaid2-plugin
26 |
27 | - run: mkdocs gh-deploy --force
28 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | name: Release
2 |
3 | on:
4 | push:
5 | tags:
6 | - "v[0-9].[0-9]+.[0-9]+*"
7 |
8 | jobs:
9 | release-on-pypi:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Checkout
14 | uses: actions/checkout@v3
15 |
16 | - name: Install Hatch
17 | run: pip install hatch
18 |
19 | - name: Build
20 | run: hatch build
21 |
22 | - name: Publish on PyPi
23 | env:
24 | HATCH_INDEX_USER: __token__
25 | HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }}
26 | run: hatch publish -y
27 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | paths:
9 | - "**.py"
10 | - "**/pyproject.toml"
11 |
12 | env:
13 | COVERALLS_NOISY: true
14 |
15 | jobs:
16 | mypy:
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Checkout
20 | uses: actions/checkout@v3
21 |
22 | - name: Setup Python
23 | uses: actions/setup-python@v4
24 | with:
25 | python-version: '3.8'
26 |
27 | - name: Install Canals
28 | run: |
29 | sudo apt install graphviz libgraphviz-dev
30 | pip install .[dev] pygraphviz
31 |
32 | - name: Mypy
33 | run: |
34 | mkdir .mypy_cache/
35 | mypy --install-types --non-interactive --ignore-missing-imports canals/ sample_components/
36 |
37 | pylint:
38 | runs-on: ubuntu-latest
39 | steps:
40 | - name: Checkout
41 | uses: actions/checkout@v3
42 |
43 | - name: Setup Python
44 | uses: actions/setup-python@v4
45 | with:
46 | python-version: '3.8'
47 |
48 | - name: Install Canals
49 | run: |
50 | sudo apt install graphviz libgraphviz-dev
51 | pip install .[dev] pygraphviz
52 |
53 | - name: Pylint
54 | run: pylint -ry -j 0 canals/ sample_components/
55 |
56 | black:
57 | runs-on: ubuntu-latest
58 | steps:
59 | - name: Checkout
60 | uses: actions/checkout@v3
61 |
62 | - name: Setup Python
63 | uses: actions/setup-python@v4
64 | with:
65 | python-version: '3.8'
66 |
67 | - name: Install Canals
68 | run: pip install .[dev]
69 |
70 | - name: Check status
71 | run: black canals/ --check
72 |
73 | tests:
74 | name: Unit / Python ${{ matrix.version }} / ${{ matrix.os }}
75 | strategy:
76 | fail-fast: false
77 | matrix:
78 | version:
79 | - "3.8"
80 | - "3.9"
81 | - "3.10"
82 | - "3.11"
83 | os:
84 | - ubuntu-latest
85 | - windows-latest
86 | - macos-latest
87 | runs-on: ${{ matrix.os }}
88 | steps:
89 | - name: Checkout
90 | uses: actions/checkout@v3
91 |
92 | - name: Setup Python
93 | uses: actions/setup-python@v4
94 | with:
95 | python-version: ${{ matrix.version }}
96 |
97 | - name: Install Canals (Ubuntu)
98 | if: matrix.os == 'ubuntu-latest'
99 | run: |
100 | sudo apt install graphviz libgraphviz-dev
101 | pip install .[dev] pygraphviz
102 |
103 | - name: Install Canals (MacOS)
104 | if: matrix.os == 'macos-latest'
105 | run: |
106 | # brew only offers graphviz 8, which seems to be incompatible with pygraphviz :(
107 | # brew install graphviz@2.49.0
108 | pip install .[dev]
109 |
110 | - name: Install Canals (Windows)
111 | if: matrix.os == 'windows-latest'
112 | run: |
113 | # Doesn't seem to work in CI :(
114 | # choco install graphviz
115 | # python -m pip install --global-option=build_ext `
116 | # --global-option="-IC:\Program Files\Graphviz\include" `
117 | # --global-option="-LC:\Program Files\Graphviz\lib" `
118 | # pygraphviz
119 | pip install .[dev]
120 |
121 | - name: Run
122 | run: pytest --cov-report xml:coverage.xml --cov="canals" test/
123 |
124 | - name: Coverage
125 | if: matrix.os == 'ubuntu-latest' && matrix.version == 3.11
126 | uses: coverallsapp/github-action@v2
127 | with:
128 | path-to-lcov: coverage.xml
129 | parallel: false
130 | debug: true
131 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 |
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
133 | # Canals
134 | drafts/
135 | .canals_debug/
136 | test/**/*.png
137 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | fail_fast: true
2 |
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.2.0
6 | hooks:
7 | - id: check-ast # checks Python syntax
8 | - id: check-json # checks JSON syntax
9 | # - id: check-yaml # checks YAML syntax
10 | - id: check-toml # checks TOML syntax
11 | - id: end-of-file-fixer # checks there is a newline at the end of the file
12 | - id: trailing-whitespace # trims trailing whitespace
13 | - id: check-merge-conflict # checks for no merge conflict strings
14 | - id: check-shebang-scripts-are-executable # checks that scripts with a shebang have executable permissions
15 | - id: mixed-line-ending # normalizes line endings
16 | #- id: no-commit-to-branch # prevents committing to main
17 |
18 | - repo: https://github.com/psf/black
19 | rev: 22.6.0 # IMPORTANT: keep this aligned with the black version in pyproject.toml
20 | hooks:
21 | - id: black-jupyter
22 |
23 | - repo: https://github.com/pre-commit/mirrors-mypy
24 | rev: 'v1.1.1'
25 | hooks:
26 | - id: mypy
27 | exclude: ^test/
28 | args: [--ignore-missing-imports]
29 | additional_dependencies: ['types-requests']
30 |
31 | - repo: https://github.com/pycqa/pylint
32 | rev: 'v2.17.0'
33 | hooks:
34 | - id: pylint
35 | exclude: ^test/
36 | args: [
37 | "--disable=import-error" # FIXME
38 | ]
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | > ⚠️ The project was merged into Haystack, namely into the [`core`](https://github.com/deepset-ai/haystack/tree/main/haystack/core) package.
2 |
3 | # Canals
4 |
5 |
6 |
7 |
8 |
9 |
10 | [](https://pypi.org/project/canals)
11 | [](https://pypi.org/project/canals)
12 |
13 |
14 |
15 | [](https://coveralls.io/github/deepset-ai/canals?branch=main)
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | Canals is a **component orchestration engine**. Components are Python objects that can execute a task, like reading a file, performing calculations, or making API calls. Canals connects these objects together: it builds a graph of components and takes care of managing their execution order, making sure that each object receives the input it expects from the other components of the pipeline.
33 |
34 | Canals powers version 2.0 of the [Haystack framework](https://github.com/deepset-ai/haystack).
35 |
36 | ## Installation
37 |
38 | To install Canals, run:
39 |
40 | ```console
41 | pip install canals
42 | ```
43 |
44 | To be able to draw pipelines (`Pipeline.draw()` method), please make sure you have either an internet connection (to reach the Mermaid graph renderer at `https://mermaid.ink`) or [graphviz](https://graphviz.org/download/) (version 2.49.0) installed. If you
45 | plan to use Mermaid, there are no additional steps to take, while for graphviz
46 | you need to run:
47 |
48 | ### GraphViz
49 | ```console
50 | sudo apt install graphviz # You may need `libgraphviz-dev` too
51 | pip install pygraphviz
52 | ```
53 |
--------------------------------------------------------------------------------
/canals/__about__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | __version__ = "0.11.0"
5 |
--------------------------------------------------------------------------------
/canals/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.__about__ import __version__
5 |
6 | from canals.component import component, Component
7 | from canals.pipeline.pipeline import Pipeline
8 |
9 | __all__ = ["component", "Component", "Pipeline"]
10 |
--------------------------------------------------------------------------------
/canals/component/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.component.component import component, Component
5 | from canals.component.sockets import InputSocket, OutputSocket
6 |
7 | __all__ = ["component", "Component", "InputSocket", "OutputSocket"]
8 |
--------------------------------------------------------------------------------
/canals/component/component.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | """
5 | Attributes:
6 |
7 | component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.
8 |
9 | All components must follow the contract below. This docstring is the source of truth for the component contract.
10 |
11 |
12 |
13 | `@component` decorator
14 |
15 | All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.
16 |
17 |
18 |
19 | `__init__(self, **kwargs)`
20 |
21 | Optional method.
22 |
23 | Components may have an `__init__` method where they define:
24 |
25 | - `self.init_parameters = {same parameters that the __init__ method received}`:
26 | In this dictionary you can store any state the component wishes to be persisted when it is saved.
27 | These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
28 | Note that by default the `@component` decorator saves the arguments automatically.
29 | However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
30 | Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
31 |
32 | Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
33 | dictionaries containing only such values. Anything else (objects, functions, etc) will raise an exception at init
34 | time. If there's the need for such values, consider serializing them to a string.
35 |
36 | _(TODO explain how to use classes and functions in init. In the meantime see `test/sample_components/test_accumulate.py`)_
37 |
38 | The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
39 | validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc...) refer to
40 | the `warm_up()` method.
41 |
42 |
43 |
44 | `warm_up(self)`
45 |
46 | Optional method.
47 |
48 | This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
49 | because Pipeline will not keep track of which components it called `warm_up()` on.
50 |
51 |
52 |
53 | `run(self, ...)`
54 |
55 | Mandatory method.
56 |
57 | This is the method where the main functionality of the component should be carried out. It's called by
58 | `Pipeline.run()`.
59 |
60 | When the component should run, Pipeline will call this method with one value for each of its inputs: the
61 | parameters declared in the `run()` signature, plus any input declared with `component.set_input_types()`:
62 |
63 | - all the input values coming from other components connected to it,
64 | - if any is missing, the default value declared for that parameter in the `run()` signature, if it exists.
65 |
66 | `run()` must return a dictionary whose keys are the output names declared with
67 | `@component.output_types` (or with `component.set_output_types()`).
68 |
69 | """
70 |
71 | import logging
72 | import inspect
73 | from typing import Protocol, runtime_checkable, Any
74 | from types import new_class
75 | from copy import deepcopy
76 |
77 | from canals.component.sockets import InputSocket, OutputSocket
78 | from canals.errors import ComponentError
79 |
80 | logger = logging.getLogger(__name__)
81 |
82 |
@runtime_checkable
class Component(Protocol):
    """
    Structural type describing what a Canals component looks like.

    This protocol is mostly meant for type checking tools, but since it is
    decorated with `@runtime_checkable` it can also be used at runtime.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    Being runtime checkable, the protocol can be used to check component instances:

        isinstance(MyComponent(), Component)

    Note that such a runtime check only verifies that a `run` member exists,
    not its signature.
    """

    def run(self, *args: Any, **kwargs: Any):
        """
        Execute the component. The protocol accepts any positional and keyword arguments.
        """
111 |
112 |
class ComponentMeta(type):
    """
    Metaclass that sets up the input and output sockets of every new Component instance.
    """

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and
        runs before __new__ and __init__.

        After the regular instantiation it populates, on the new instance:

        - `__canals_output__`, unless `component.set_output_types()` already did so in the constructor;
        - `__canals_input__`, with one `InputSocket` for each positional-or-keyword parameter of
          `run()` (excluding `self`). Sockets already declared through `component.set_input_types()`
          are kept, but parameters explicitly listed in `run()` take precedence over them.
        """
        # This will call __new__ then __init__, giving us back the Component instance
        instance = super().__call__(*args, **kwargs)

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets

        # If `component.set_output_types()` was called in the component constructor,
        # `__canals_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__canals_output__"):
            # If the `run` method was decorated, it has a `_output_types_cache` field assigned
            # that stores the output specification.
            # We deepcopy the content of the cache to transfer ownership from the class method
            # to the actual instance, so that different instances of the same class won't share this data.
            instance.__canals_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {}))

        # Create the sockets dictionary if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__canals_input__"):
            instance.__canals_input__ = {}

        # Parameters explicitly declared in `run()` take precedence over those set with set_input_types().
        run_parameters = list(inspect.signature(cls.run).parameters.values())
        for parameter in run_parameters[1:]:  # First is 'self' and it doesn't matter.
            # Only positional-or-keyword parameters become sockets; every other kind
            # (e.g. `*args`, `**kwargs`) is ignored.
            if parameter.kind != inspect.Parameter.POSITIONAL_OR_KEYWORD:
                continue
            instance.__canals_input__[parameter.name] = InputSocket(
                name=parameter.name,
                type=parameter.annotation,
                # `Parameter.empty` is a sentinel: compare by identity, because `==` would invoke
                # the default value's own `__eq__`, which is not guaranteed to return a bool.
                is_mandatory=parameter.default is inspect.Parameter.empty,
            )
        return instance
149 |
150 |
151 | class _Component:
152 | """
153 | See module's docstring.
154 |
155 | Args:
156 | class_: the class that Canals should use as a component.
157 | serializable: whether to check, at init time, if the component can be saved with
158 | `save_pipelines()`.
159 |
160 | Returns:
161 | A class that can be recognized as a component.
162 |
163 | Raises:
164 | ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
165 | """
166 |
167 | def __init__(self):
168 | self.registry = {}
169 |
170 | def set_input_types(self, instance, **types):
171 | """
172 | Method that specifies the input types when 'kwargs' is passed to the run method.
173 |
174 | Use as:
175 |
176 | ```python
177 | @component
178 | class MyComponent:
179 |
180 | def __init__(self, value: int):
181 | component.set_input_types(value_1=str, value_2=str)
182 | ...
183 |
184 | @component.output_types(output_1=int, output_2=str)
185 | def run(self, **kwargs):
186 | return {"output_1": kwargs["value_1"], "output_2": ""}
187 | ```
188 |
189 | Note that if the `run()` method also specifies some parameters, those will take precedence.
190 |
191 | For example:
192 |
193 | ```python
194 | @component
195 | class MyComponent:
196 |
197 | def __init__(self, value: int):
198 | component.set_input_types(value_1=str, value_2=str)
199 | ...
200 |
201 | @component.output_types(output_1=int, output_2=str)
202 | def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
203 | return {"output_1": kwargs["value_1"], "output_2": ""}
204 | ```
205 |
206 | would add a mandatory `value_0` parameters, make the `value_1`
207 | parameter optional with a default None, and keep the `value_2`
208 | parameter mandatory as specified in `set_input_types`.
209 |
210 | """
211 | instance.__canals_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}
212 |
213 | def set_output_types(self, instance, **types):
214 | """
215 | Method that specifies the output types when the 'run' method is not decorated
216 | with 'component.output_types'.
217 |
218 | Use as:
219 |
220 | ```python
221 | @component
222 | class MyComponent:
223 |
224 | def __init__(self, value: int):
225 | component.set_output_types(output_1=int, output_2=str)
226 | ...
227 |
228 | # no decorators here
229 | def run(self, value: int):
230 | return {"output_1": 1, "output_2": "2"}
231 | ```
232 | """
233 | instance.__canals_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}
234 |
235 | def output_types(self, **types):
236 | """
237 | Decorator factory that specifies the output types of a component.
238 |
239 | Use as:
240 |
241 | ```python
242 | @component
243 | class MyComponent:
244 | @component.output_types(output_1=int, output_2=str)
245 | def run(self, value: int):
246 | return {"output_1": 1, "output_2": "2"}
247 | ```
248 | """
249 |
250 | def output_types_decorator(run_method):
251 | """
252 | This happens at class creation time, and since we don't have the decorated
253 | class available here, we temporarily store the output types as an attribute of
254 | the decorated method. The ComponentMeta metaclass will use this data to create
255 | sockets at instance creation time.
256 | """
257 | setattr(
258 | run_method,
259 | "_output_types_cache",
260 | {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
261 | )
262 | return run_method
263 |
264 | return output_types_decorator
265 |
266 | def _component(self, class_):
267 | """
268 | Decorator validating the structure of the component and registering it in the components registry.
269 | """
270 | logger.debug("Registering %s as a component", class_)
271 |
272 | # Check for required methods and fail as soon as possible
273 | if not hasattr(class_, "run"):
274 | raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.")
275 |
276 | def copy_class_namespace(namespace):
277 | """
278 | This is the callback that `typing.new_class` will use
279 | to populate the newly created class. We just copy
280 | the whole namespace from the decorated class.
281 | """
282 | for key, val in dict(class_.__dict__).items():
283 | namespace[key] = val
284 |
285 | # Recreate the decorated component class so it uses our metaclass
286 | class_ = new_class(class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)
287 |
288 | # Save the component in the class registry (for deserialization)
289 | class_path = f"{class_.__module__}.{class_.__name__}"
290 | if class_path in self.registry:
291 | # Corner case, but it may occur easily in notebooks when re-running cells.
292 | logger.debug(
293 | "Component %s is already registered. Previous imported from '%s', new imported from '%s'",
294 | class_path,
295 | self.registry[class_path],
296 | class_,
297 | )
298 | self.registry[class_path] = class_
299 | logger.debug("Registered Component %s", class_)
300 |
301 | return class_
302 |
303 | def __call__(self, class_):
304 | return self._component(class_)
305 |
306 |
307 | component = _Component()
308 |
--------------------------------------------------------------------------------
/canals/component/connection.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Optional, List, Tuple
3 | from dataclasses import dataclass
4 |
5 | from canals.component.sockets import InputSocket, OutputSocket
6 | from canals.type_utils import _type_name, _types_are_compatible
7 | from canals.errors import PipelineConnectError
8 |
9 |
@dataclass
class Connection:
    """
    A link from an output socket of one component to an input socket of another.

    Any of the four fields may be `None`: `__repr__` shows a missing sender as "input needed" and a
    missing receiver as "output". When all four are set, creating the `Connection` registers it on
    both sockets (`sender_socket.receivers` and `receiver_socket.senders`).
    """

    sender: Optional[str]
    sender_socket: Optional[OutputSocket]
    receiver: Optional[str]
    receiver_socket: Optional[InputSocket]

    def __post_init__(self):
        """
        Validates a fully specified connection and records it on both sockets.

        :raises PipelineConnectError: if the receiving socket already has a sender and is not variadic.
        """
        if self.sender and self.sender_socket and self.receiver and self.receiver_socket:
            # Make sure the receiving socket isn't already connected, unless it's variadic. Sending sockets can be
            # connected as many times as needed, so they don't need this check
            if self.receiver_socket.senders and not self.receiver_socket.is_variadic:
                raise PipelineConnectError(
                    f"Cannot connect '{self.sender}.{self.sender_socket.name}' with '{self.receiver}.{self.receiver_socket.name}': "
                    f"{self.receiver}.{self.receiver_socket.name} is already connected to {self.receiver_socket.senders}.\n"
                )

            self.sender_socket.receivers.append(self.receiver)
            self.receiver_socket.senders.append(self.sender)

    def __repr__(self):
        if self.sender and self.sender_socket:
            sender_repr = f"{self.sender}.{self.sender_socket.name} ({_type_name(self.sender_socket.type)})"
        else:
            sender_repr = "input needed"

        if self.receiver and self.receiver_socket:
            receiver_repr = f"({_type_name(self.receiver_socket.type)}) {self.receiver}.{self.receiver_socket.name}"
        else:
            receiver_repr = "output"

        return f"{sender_repr} --> {receiver_repr}"

    def __hash__(self):
        """
        Connection is used as a dictionary key in Pipeline, it must be hashable
        """
        return hash(
            "-".join(
                [
                    self.sender if self.sender else "input",
                    self.sender_socket.name if self.sender_socket else "",
                    self.receiver if self.receiver else "output",
                    self.receiver_socket.name if self.receiver_socket else "",
                ]
            )
        )

    @property
    def is_mandatory(self) -> bool:
        """
        Returns True if the connection goes to a mandatory input socket, False otherwise
        """
        if self.receiver_socket:
            return self.receiver_socket.is_mandatory
        return False

    @staticmethod
    def from_list_of_sockets(
        sender_node: str, sender_sockets: List[OutputSocket], receiver_node: str, receiver_sockets: List[InputSocket]
    ) -> "Connection":
        """
        Find one single possible connection between two lists of sockets.

        Candidates are all (output, input) socket pairs whose types are compatible. If there is exactly one
        candidate it is used. If there are several, the single candidate whose two sockets share the same
        name is used.

        :raises PipelineConnectError: if no candidate exists, or if several exist and matching by name
            doesn't single out exactly one. The returned `Connection` may also raise it on creation
            (see `__post_init__`).
        """
        # List all sender/receiver combinations of sockets that match by type
        possible_connections = [
            (sender_sock, receiver_sock)
            for sender_sock, receiver_sock in itertools.product(sender_sockets, receiver_sockets)
            if _types_are_compatible(sender_sock.type, receiver_sock.type)
        ]

        # No connections seem to be possible
        if not possible_connections:
            connections_status_str = _connections_status(
                sender_node=sender_node,
                sender_sockets=sender_sockets,
                receiver_node=receiver_node,
                receiver_sockets=receiver_sockets,
            )

            # Both sockets were specified: explain why the types don't match
            if len(sender_sockets) == 1 and len(receiver_sockets) == 1:
                raise PipelineConnectError(
                    f"Cannot connect '{sender_node}.{sender_sockets[0].name}' with '{receiver_node}.{receiver_sockets[0].name}': "
                    f"their declared input and output types do not match.\n{connections_status_str}"
                )

            # Not both sockets were specified: explain there's no possible match on any pair
            raise PipelineConnectError(
                f"Cannot connect '{sender_node}' with '{receiver_node}': "
                f"no matching connections available.\n{connections_status_str}"
            )

        if len(possible_connections) == 1:
            match = possible_connections[0]
        else:
            # There's more than one possible connection: try to match by name
            name_matches = [
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
            ]
            if len(name_matches) != 1:
                # TODO allow for multiple connections at once if there is no ambiguity?
                # TODO give priority to sockets that have no default values?
                connections_status_str = _connections_status(
                    sender_node=sender_node,
                    sender_sockets=sender_sockets,
                    receiver_node=receiver_node,
                    receiver_sockets=receiver_sockets,
                )
                raise PipelineConnectError(
                    f"Cannot connect '{sender_node}' with '{receiver_node}': more than one connection is possible "
                    "between these components. Please specify the connection name, like: "
                    f"pipeline.connect('{sender_node}.{possible_connections[0][0].name}', "
                    f"'{receiver_node}.{possible_connections[0][1].name}').\n{connections_status_str}"
                )
            # Use the pair selected by name, not just the first type-compatible pair.
            match = name_matches[0]

        return Connection(sender_node, match[0], receiver_node, match[1])
133 |
134 |
def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
):
    """
    Describes every candidate socket on both sides of a connection attempt, for error messages.

    Sender sockets are listed with their type. Receiver sockets also show whether they are still
    "available" or which components already send to them.
    """

    def describe_receiver(socket):
        status = f"sent by {','.join(socket.senders)}" if socket.senders else "available"
        return f" - {socket.name}: {_type_name(socket.type)} ({status})"

    sender_sockets_list = "\n".join(f" - {socket.name}: {_type_name(socket.type)}" for socket in sender_sockets)
    receiver_sockets_list = "\n".join(describe_receiver(socket) for socket in receiver_sockets)
    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
158 |
159 |
def parse_connect_string(connection: str) -> Tuple[str, Optional[str]]:
    """
    Splits a connect_to/from string into its component name and optional socket name.

    Only the first "." separates the two, so "comp.sock" gives ("comp", "sock"),
    "a.b.c" gives ("a", "b.c") and "comp" gives ("comp", None).
    """
    component_name, separator, socket_name = connection.partition(".")
    if not separator:
        return connection, None
    return (component_name, socket_name)
168 |
--------------------------------------------------------------------------------
/canals/component/descriptions.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 |
def find_component_inputs(component: Any) -> Dict[str, Dict[str, Any]]:
    """
    Returns a mapping of input names to their expected types and properties for a given component.

    :param component: The target component to introspect.
    :return: A dictionary where keys are input names, with each key's value being another dictionary
        containing 'type' (the data type expected), 'is_mandatory' (a boolean indicating whether the input
        must be provided) and 'is_variadic' (a boolean indicating whether the input accepts values from
        more than one sender).

    :raise: Throws a ValueError if the component has no `__canals_input__` attribute, i.e. its class is not
        appropriately decorated with @component.
    """
    if not hasattr(component, "__canals_input__"):
        raise ValueError(
            f"Component {component} does not have defined inputs or is improperly decorated. "
            "Ensure it is a valid @component with declared inputs."
        )

    return {
        name: {"type": socket.type, "is_mandatory": socket.is_mandatory, "is_variadic": socket.is_variadic}
        for name, socket in component.__canals_input__.items()
    }
24 |
25 |
def find_component_outputs(component: Any) -> Dict[str, Dict[str, Any]]:
    """
    Describe the outputs declared by a component instance.

    :param component: The component being examined for its outputs.
    :return: Output names mapped to a dictionary with a single 'type' key holding the declared type.

    :raise: Throws a ValueError if the component has no `__canals_output__` attribute, i.e. its class is not
        appropriately decorated with @component.
    """
    try:
        declared_outputs = component.__canals_output__
    except AttributeError:
        raise ValueError(
            f"The specified component {component} does not have defined outputs or is not properly decorated. "
            "Check that it is a valid @component with outputs specified."
        ) from None

    return {output_name: {"type": output_socket.type} for output_name, output_socket in declared_outputs.items()}
43 |
--------------------------------------------------------------------------------
/canals/component/sockets.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import get_args, List, Type
5 | import logging
6 | from dataclasses import dataclass, field
7 |
8 | from canals.component.types import CANALS_VARIADIC_ANNOTATION
9 |
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
@dataclass
class InputSocket:
    """
    Describes one input of a component.

    Variadic inputs (declared with `Variadic[...]`) are detected at creation time. `is_variadic` is set
    and `type` is replaced with the inner item type, so connections are type-checked against that type.
    """

    # Name of the input: a `run()` parameter name or a `set_input_types` keyword.
    name: str
    # Declared type. For variadic inputs, `__post_init__` replaces it with the type of the single items.
    type: Type
    # Whether a value must be provided for this input.
    is_mandatory: bool = True
    # Computed in `__post_init__`; not accepted by the constructor.
    is_variadic: bool = field(init=False)
    # Names of the components connected to this input (appended to by `Connection`).
    senders: List[str] = field(default_factory=list)

    def __post_init__(self):
        try:
            # `Annotated[...]` types expose their metadata as a tuple in `__metadata__`. Other types don't
            # have this attribute at all, which means they can't be variadic.
            self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION
        except AttributeError:
            self.is_variadic = False

        if self.is_variadic:
            # We need to "unpack" the type inside the Variadic annotation, otherwise the pipeline
            # connection api would try to match `Annotated[Iterable[type], CANALS_VARIADIC_ANNOTATION]`.
            #
            # Variadic is an annotation of one single type, and a pipeline always passes a list of items
            # to a variadic input, so `Variadic[int]` is `Annotated[Iterable[int], CANALS_VARIADIC_ANNOTATION]`.
            # To get the inner type `int` we call `get_args` twice: the first time to get `Iterable[int]`
            # out of the annotation, the second time to get `int` out of `Iterable[int]`.
            self.type = get_args(get_args(self.type)[0])[0]
44 |
45 |
@dataclass
class OutputSocket:
    """
    Describes one output of a component.
    """

    # Name of the output, i.e. the key under which `run()` returns the value.
    name: str
    # Declared type of the output value (annotated like `InputSocket.type` for consistency).
    type: Type
    # Names of the components connected to this output (appended to by `Connection`).
    receivers: List[str] = field(default_factory=list)
51 |
--------------------------------------------------------------------------------
/canals/component/types.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar, Iterable
2 | from typing_extensions import TypeAlias, Annotated # Python 3.8 compatibility
3 |
4 | CANALS_VARIADIC_ANNOTATION = "__canals__variadic_t"
5 |
6 | # # Generic type variable used in the Variadic container
7 | T = TypeVar("T")
8 |
9 |
10 | # Variadic is a custom annotation type we use to mark input types.
11 | # This type doesn't do anything else than "marking" the contained
12 | # type so it can be used in the `InputSocket` creation where we
13 | # check that its annotation equals to CANALS_VARIADIC_ANNOTATION
14 | Variadic: TypeAlias = Annotated[Iterable[T], CANALS_VARIADIC_ANNOTATION]
15 |
--------------------------------------------------------------------------------
/canals/errors.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
class PipelineError(Exception):
    """Base class of the `Pipeline*` error hierarchy (note: `PipelineRuntimeError` does not derive from it)."""


class PipelineRuntimeError(Exception):
    """Derives directly from `Exception`, so `except PipelineError` does not catch it.

    Raising sites are not in this module; presumably failures while running a pipeline (TODO confirm).
    """


class PipelineConnectError(PipelineError):
    """Raised when two component sockets cannot be connected."""


class PipelineValidationError(PipelineError):
    """Raised when a pipeline or its input fails validation (e.g. a pipeline with no inputs)."""


class PipelineDrawingError(PipelineError):
    """Raised when a pipeline diagram cannot be rendered (e.g. mermaid.ink cannot be reached)."""


class PipelineMaxLoops(PipelineError):
    """`PipelineError` subclass; presumably signals an exceeded loop limit (raising site not in this module)."""


class PipelineUnmarshalError(PipelineError):
    """`PipelineError` subclass; presumably raised when a pipeline cannot be loaded (raising site not in this module)."""


class ComponentError(Exception):
    """Raised when a class does not respect the component contract (e.g. it has no `run()` method)."""


class ComponentDeserializationError(Exception):
    """Component-level deserialization error (raising site not in this module)."""


class DeserializationError(Exception):
    """Generic deserialization error (raising site not in this module)."""


class SerializationError(Exception):
    """Generic serialization error (raising site not in this module)."""
46 |
--------------------------------------------------------------------------------
/canals/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.pipeline.pipeline import Pipeline
5 | from canals.errors import (
6 | PipelineError,
7 | PipelineRuntimeError,
8 | PipelineValidationError,
9 | PipelineConnectError,
10 | PipelineMaxLoops,
11 | )
12 |
--------------------------------------------------------------------------------
/canals/pipeline/descriptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List, Dict
5 | import logging
6 |
7 | import networkx # type:ignore
8 |
9 | from canals.type_utils import _type_name
10 | from canals.component.sockets import InputSocket, OutputSocket
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSocket]]:
    """
    Collect, for each component, the input sockets that can receive data from the pipeline's input.

    A socket is included if it has no senders, or if it is variadic, in which case it is included even
    when it already has senders. Note that this method returns *ALL* such sockets, including the ones
    with default values. Every node of the graph appears as a key, mapped to an empty list when none of
    its sockets qualify.
    """
    return {
        name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic]
        for name, data in graph.nodes(data=True)
    }
25 |
26 |
def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]:
    """
    Collect the output sockets that have no receivers: they define the pipeline output.

    Every node of the graph appears as a key. Nodes whose outputs are all connected map to an empty list.
    """
    unconnected = {}
    for node_name, node_data in graph.nodes(data=True):
        node_sockets = node_data.get("output_sockets", {})
        unconnected[node_name] = [sock for sock in node_sockets.values() if not sock.receivers]
    return unconnected
35 |
36 |
def describe_pipeline_inputs(graph: networkx.MultiDiGraph):
    """
    Describes the inputs the pipeline accepts, grouped by component.

    Components for which `find_pipeline_inputs` reports no sockets are omitted. For the others, each
    socket name maps to a dictionary with its 'type' and 'is_mandatory' flag.
    """
    described = {}
    for comp, open_sockets in find_pipeline_inputs(graph).items():
        if not open_sockets:
            continue
        described[comp] = {
            socket.name: {"type": socket.type, "is_mandatory": socket.is_mandatory} for socket in open_sockets
        }
    return described
47 |
48 |
def describe_pipeline_inputs_as_string(graph: networkx.MultiDiGraph):
    """
    Returns a human readable summary of the pipeline inputs, used in error messages.

    Outputs a header line, then one "- <component>:" line per component, each followed by one
    line per input showing its name and type.
    """
    lines = ["This pipeline expects the following inputs:\n"]
    for comp, sockets in describe_pipeline_inputs(graph).items():
        if not sockets:
            continue
        lines.append(f"- {comp}:\n")
        lines.extend(f" - {name}: {_type_name(socket['type'])}\n" for name, socket in sockets.items())
    return "".join(lines)
61 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.pipeline.draw.draw import _draw, _convert, _convert_for_debug, RenderingEngines
5 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/draw.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Literal, Optional, Dict, get_args, Any
5 |
6 | import logging
7 | from pathlib import Path
8 |
9 | import networkx # type:ignore
10 |
11 | from canals.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
12 | from canals.pipeline.draw.graphviz import _to_agraph
13 | from canals.pipeline.draw.mermaid import _to_mermaid_image, _to_mermaid_text
14 | from canals.type_utils import _type_name
15 |
logger = logging.getLogger(__name__)
RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"]


def _draw(
    graph: networkx.MultiDiGraph,
    path: Path,
    engine: RenderingEngines = "mermaid-image",
    style_map: Optional[Dict[str, str]] = None,
) -> None:
    """
    Renders the pipeline graph with the requested engine and saves it to `path`.

    :param graph: the pipeline graph to draw.
    :param path: destination file.
    :param engine: one of `RenderingEngines`. `graphviz` lets the PyGraphViz graph write the file itself,
        `mermaid-image` writes the binary image data, `mermaid-text` writes the Mermaid source as UTF-8.
    :param style_map: optional mapping of node name to style, forwarded to `_convert`.
    :raises ValueError: if the engine is unknown.
    """
    rendered = _convert(graph=graph, engine=engine, style_map=style_map)

    # Mermaid engines produce data that we write ourselves: (open mode, encoding) per engine.
    file_modes = {"mermaid-image": ("wb", None), "mermaid-text": ("w", "utf-8")}

    if engine == "graphviz":
        rendered.draw(path)
    elif engine in file_modes:
        mode, encoding = file_modes[engine]
        with open(path, mode, encoding=encoding) as out_file:
            out_file.write(rendered)
    else:
        raise ValueError(f"Unknown rendering engine '{engine}'. Choose one from: {get_args(RenderingEngines)}.")

    logger.debug("Pipeline diagram saved at %s", path)
46 |
47 |
def _convert_for_debug(
    graph: networkx.MultiDiGraph,
) -> Any:
    """
    Returns the Mermaid text of the pipeline graph, prepared for drawing with no custom styles.
    The caller decides what to do with the text (e.g. save it so Mermaid can render it later).
    """
    return _to_mermaid_text(graph=_prepare_for_drawing(graph=graph, style_map={}))
56 |
57 |
def _convert(
    graph: networkx.MultiDiGraph,
    engine: RenderingEngines = "mermaid-image",
    style_map: Optional[Dict[str, str]] = None,
) -> Any:
    """
    Prepares the pipeline graph for drawing and converts it with the requested engine.

    :return: a PyGraphViz graph for `graphviz`, the image bytes for `mermaid-image`,
        the Mermaid source for `mermaid-text`.
    :raises ValueError: if the engine is unknown (raised after the graph has been prepared).
    """
    prepared = _prepare_for_drawing(graph=graph, style_map=style_map or {})

    renderers = (
        ("graphviz", _to_agraph),
        ("mermaid-image", _to_mermaid_image),
        ("mermaid-text", _to_mermaid_text),
    )
    for engine_name, renderer in renderers:
        if engine == engine_name:
            return renderer(graph=prepared)

    raise ValueError(f"Unknown rendering engine '{engine}'. Choose one from: {get_args(RenderingEngines)}.")
78 |
79 |
def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]) -> networkx.MultiDiGraph:
    """
    Prepares the graph to be drawn: adds explicit input and output nodes, labels the edges, applies the styles, etc.

    Works on a copy of `graph`, so drawing a pipeline never adds the "input"/"output" nodes or the
    drawing attributes to the caller's graph.

    :param graph: the pipeline graph. Its edges must carry `from_socket` and `to_socket` attributes.
    :param style_map: mapping of node name to the value stored in that node's "style" attribute.
    :return: the prepared copy of the graph.
    """
    # Never mutate the caller's graph: the nodes and attributes below only matter for drawing.
    graph = graph.copy()

    # Apply the styles
    if style_map:
        for node, style in style_map.items():
            graph.nodes[node]["style"] = style

    # Label the edges with the names of the sockets they connect
    for inp, outp, key, data in graph.edges(keys=True, data=True):
        data["label"] = f"{data['from_socket'].name} -> {data['to_socket'].name}"
        graph.add_edge(inp, outp, key=key, **data)

    # Draw the inputs
    graph.add_node("input")
    for node, in_sockets in find_pipeline_inputs(graph).items():
        for in_socket in in_sockets:
            if not in_socket.senders and in_socket.is_mandatory:
                # If this socket has no sender it could be a socket that receives input
                # directly when running the Pipeline. We can't know that for sure, in doubt
                # we draw it as receiving input directly.
                graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type))

    # Draw the outputs
    graph.add_node("output")
    for node, out_sockets in find_pipeline_outputs(graph).items():
        for out_socket in out_sockets:
            graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type))

    return graph
111 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/graphviz.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 |
6 | import networkx # type:ignore
7 |
8 | from networkx.drawing.nx_agraph import to_agraph as nx_to_agraph # type:ignore
9 |
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | # pyright: reportMissingImports=false
14 | # pylint: disable=unused-import,import-outside-toplevel
15 |
16 |
def _to_agraph(graph: networkx.MultiDiGraph):
    """
    Renders a pipeline graph using PyGraphViz. You need to install it and all its system dependencies for it to work.

    Edges leaving the "input" node and entering the "output" node are drawn dashed, both of those nodes use
    the "plain" shape, and the result is laid out with `dot`. Note that `graph` itself is modified.

    :return: the laid-out PyGraphViz graph.
    :raises ImportError: if pygraphviz cannot be imported.
    """
    try:
        import pygraphviz  # type: ignore
    except ImportError as exc:  # ModuleNotFoundError is a subclass of ImportError
        raise ImportError(
            "Can't use 'pygraphviz' to draw this pipeline: pygraphviz could not be imported. "
            "Make sure pygraphviz is installed and all its system dependencies are setup correctly."
        ) from exc

    def _make_dashed(edges):
        for source, target, edge_key, attrs in edges:
            attrs["style"] = "dashed"
            graph.add_edge(source, target, key=edge_key, **attrs)

    _make_dashed(graph.out_edges("input", keys=True, data=True))
    _make_dashed(graph.in_edges("output", keys=True, data=True))

    for boundary_node in ("input", "output"):
        graph.nodes[boundary_node]["shape"] = "plain"

    agraph = nx_to_agraph(graph)
    agraph.layout("dot")
    return agraph
42 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/mermaid.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 | import base64
6 |
7 | import requests
8 | import networkx # type:ignore
9 |
10 | from canals.errors import PipelineDrawingError
11 | from canals.type_utils import _type_name
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
# Mermaid document skeleton filled by `_to_mermaid_text` with `str.format`: `{connections}` is the only
# placeholder, so literal braces are doubled.
MERMAID_STYLED_TEMPLATE = """
%%{{ init: {{'theme': 'neutral' }} }}%%

graph TD;

{connections}

classDef component text-align:center;
"""


def _to_mermaid_image(graph: networkx.MultiDiGraph):
    """
    Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.

    :param graph: the graph to render, as accepted by `_to_mermaid_text`.
    :return: the raw content of the mermaid.ink response (the image bytes).
    :raises PipelineDrawingError: if mermaid.ink can't be reached or answers with an HTTP error status.
    """
    graph_styled = _to_mermaid_text(graph=graph)

    # Component and type names may contain non-ASCII characters: encode the Mermaid text as UTF-8.
    graphbytes = graph_styled.encode("utf-8")
    base64_string = base64.b64encode(graphbytes).decode("ascii")
    url = "https://mermaid.ink/img/" + base64_string

    logger.debug("Rendering graph at %s", url)
    try:
        resp = requests.get(url, timeout=10)
    except Exception as exc:  # pylint: disable=broad-except
        logger.warning("Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ (%s)", exc)
        logger.info("Exact URL requested: %s", url)
        logger.warning("No pipeline diagram will be saved.")
        raise PipelineDrawingError(
            "There was an issue with https://mermaid.ink/, see the stacktrace for details."
        ) from exc

    if resp.status_code >= 400:
        logger.warning("Failed to draw the pipeline: https://mermaid.ink/img/ returned status %s", resp.status_code)
        logger.info("Exact URL requested: %s", url)
        logger.warning("No pipeline diagram will be saved.")
        try:
            resp.raise_for_status()
        except requests.HTTPError as exc:
            raise PipelineDrawingError(
                "There was an issue with https://mermaid.ink/, see the stacktrace for details."
            ) from exc

    return resp.content
56 |
57 |
def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
    """
    Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
    with `mermaid` codeblocks and it will be automatically rendered.

    Each component is drawn as a box showing its name, its class name and its optional inputs. Optional
    inputs are input sockets that are not mandatory and have no sender, plus every variadic socket. Edges
    from the "input" node and to the "output" node are drawn from/to `i{*}` / `o{*}` shapes.

    :param graph: a graph prepared for drawing. Component nodes carry `instance` and `input_sockets` data,
        and edges carry `label` and `conn_type` attributes.
    :return: the Mermaid document, built from `MERMAID_STYLED_TEMPLATE`.
    """
    # HTML list items describing the optional inputs of each node ("" when there are none).
    optional_sockets = {
        comp: "".join(
            [
                f"<li>{name} ({_type_name(socket.type)})</li>"
                for name, socket in data.get("input_sockets", {}).items()
                if (not socket.is_mandatory and not socket.senders) or socket.is_variadic
            ]
        )
        for comp, data in graph.nodes(data=True)
    }
    optional_inputs = {
        comp: f"<br><br>Optional inputs:<ul style='text-align:left;'>{items}</ul>" if items else ""
        for comp, items in optional_sockets.items()
    }

    states = {
        comp: (
            f"{comp}[\"<b>{comp}</b><br><small><i>"
            f"{type(data['instance']).__name__}{optional_inputs[comp]}</i></small>\"]:::component"
        )
        for comp, data in graph.nodes(data=True)
        if comp not in ["input", "output"]
    }

    # Edge labels show the socket names on the first line and the connection type below.
    connections_list = [
        f"{states[from_comp]} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\" --> "
        f"{states[to_comp]}"
        for from_comp, to_comp, conn_data in graph.edges(data=True)
        if from_comp != "input" and to_comp != "output"
    ]
    input_connections = [
        f"i{{*}} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\" --> {states[to_comp]}"
        for _, to_comp, conn_data in graph.out_edges("input", data=True)
    ]
    output_connections = [
        f"{states[from_comp]} -- \"{conn_data['label']}<br><small><i>{conn_data['conn_type']}</i></small>\"--> o{{*}}"
        for from_comp, _, conn_data in graph.in_edges("output", data=True)
    ]
    connections = "\n".join(connections_list + input_connections + output_connections)

    graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
    logger.debug("Mermaid diagram:\n%s", graph_styled)

    return graph_styled
103 |
--------------------------------------------------------------------------------
/canals/pipeline/validation.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Dict, Any
5 | import logging
6 |
7 | import networkx # type:ignore
8 |
9 | from canals.errors import PipelineValidationError
10 | from canals.component.sockets import InputSocket
11 | from canals.pipeline.descriptions import find_pipeline_inputs, describe_pipeline_inputs_as_string
12 |
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
def validate_pipeline_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]) -> Dict[str, Any]:
    """
    Make sure the pipeline is properly built and that the input received makes sense.

    Checks, in order, that:
    - the pipeline exposes at least one input socket;
    - every key of `input_values` is the name of a node in `graph`;
    - every mandatory, non-variadic pipeline input socket receives a non-None value;
    - input values are only sent to sockets that exist and are not already fed by another
      (non-variadic) connection.

    :param graph: the pipeline graph.
    :param input_values: the pipeline input, keyed by component name; each value maps socket names to values.
    :returns: `input_values`, returned unchanged.
    :raises PipelineValidationError: if the pipeline has no inputs at all.
    :raises ValueError: if any of the other checks fails.
    """
    if not any(find_pipeline_inputs(graph).values()):
        raise PipelineValidationError("This pipeline has no inputs.")

    # Make sure the input keys are all nodes of the pipeline
    unknown_components = [key for key in input_values if key not in graph.nodes]
    if unknown_components:
        all_inputs = describe_pipeline_inputs_as_string(graph)
        raise ValueError(
            f"Pipeline received data for unknown component(s): {', '.join(unknown_components)}\n\n{all_inputs}"
        )

    # Make sure all necessary sockets are connected
    _validate_input_sockets_are_connected(graph, input_values)

    # Make sure that the pipeline input is only sent to nodes that won't receive data from other nodes
    _validate_nodes_receive_only_expected_input(graph, input_values)

    return input_values
40 |
41 |
def _validate_input_sockets_are_connected(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]):
    """
    Make sure all the input nodes are receiving all the values they need, either from the Pipeline's input or from
    other nodes.

    For every socket returned by `find_pipeline_inputs`, a missing key, an explicit `None` value, or `None` given for
    the whole node all count as "no value". That is only an error for mandatory, non-variadic sockets.

    :param graph: the pipeline graph.
    :param input_values: the pipeline input, keyed by component name; each value maps socket names to values
        (or is `None`).
    :raises ValueError: naming the first mandatory, non-variadic socket that has no value.
    """
    valid_inputs = find_pipeline_inputs(graph)
    for node, sockets in valid_inputs.items():
        # The values for this node don't depend on the socket being checked, so look them up once.
        # `or {}` also covers a node explicitly given `None`.
        inputs_for_node = input_values.get(node) or {}
        for socket in sockets:
            # `.get` returns None both for a missing key and for an explicit None value.
            missing_input_value = inputs_for_node.get(socket.name) is None
            if missing_input_value and socket.is_mandatory and not socket.is_variadic:
                all_inputs = describe_pipeline_inputs_as_string(graph)
                raise ValueError(f"Missing input: {node}.{socket.name}\n\n{all_inputs}")
59 |
60 |
def _validate_nodes_receive_only_expected_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]):
    """
    Make sure that every input node is only receiving input values from EITHER the pipeline's input or another node,
    but never from both.

    `None` values, and `None` given for a whole node, are ignored.

    :param graph: the pipeline graph; every key of `input_values` is expected to be one of its nodes.
    :param input_values: the pipeline input, keyed by component name; each value maps socket names to values
        (or is `None`).
    :raises ValueError: if a value targets a socket the component doesn't have, or a non-variadic socket that
        already has senders.
    """
    for node, input_data in input_values.items():
        # A node may be given `None` instead of a dict of values (the missing-input check tolerates that too);
        # there is nothing to validate then, and calling `.items()` on it would raise AttributeError.
        if not input_data:
            continue
        for socket_name, value in input_data.items():
            if value is None:
                continue
            if socket_name not in graph.nodes[node]["input_sockets"]:
                all_inputs = describe_pipeline_inputs_as_string(graph)
                raise ValueError(
                    f"Component {node} is not expecting any input value called {socket_name}.\n\n{all_inputs}",
                )

            input_socket: InputSocket = graph.nodes[node]["input_sockets"][socket_name]
            # Variadic sockets may legitimately receive values from both the pipeline input and other nodes.
            if input_socket.senders and not input_socket.is_variadic:
                raise ValueError(f"The input {socket_name} of {node} is already sent by: {input_socket.senders}")
79 |
--------------------------------------------------------------------------------
/canals/serialization.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import inspect
5 | from typing import Type, Dict, Any
6 |
7 | from canals.errors import DeserializationError, SerializationError
8 |
9 |
def component_to_dict(obj: Any) -> Dict[str, Any]:
    """
    Serialize a component instance to a dictionary; this is the marshaller used by the Pipeline.

    If the instance has a `to_dict` method, its result is returned as-is. Otherwise the value of every
    `__init__` parameter is read back from the instance attribute with the same name (falling back to the
    parameter's default value when that attribute is missing), and the result is built with `default_to_dict`.

    :param obj: the component instance to serialize.
    :returns: the serialized component.
    :raises SerializationError: if an `__init__` parameter without a default value has no matching
        attribute on `obj`.
    """
    if hasattr(obj, "to_dict"):
        return obj.to_dict()

    init_parameters = {}
    for name, param in inspect.signature(obj.__init__).parameters.items():
        # Ignore catch-all `*args` / `**kwargs`-style parameters (e.g. those of the default constructor)
        # whatever they are named: they can't be read back from the instance. The name check is kept so that
        # parameters literally called `args` / `kwargs` keep being skipped as before.
        if name in ("args", "kwargs") or param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue
        try:
            # This only works if the Component constructor assigns the init
            # parameter to an instance variable or property with the same name
            param_value = getattr(obj, name)
        except AttributeError as e:
            # If the parameter doesn't have a default value, raise an error
            if param.default is param.empty:
                raise SerializationError(
                    f"Cannot determine the value of the init parameter '{name}' for the class {obj.__class__.__name__}. "
                    f"You can fix this error by assigning 'self.{name} = {name}' or adding a "
                    f"custom serialization method 'to_dict' to the class."
                ) from e
            # In case the init parameter was not assigned, we use the default value
            param_value = param.default
        init_parameters[name] = param_value

    return default_to_dict(obj, **init_parameters)
40 |
41 |
def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
    """
    Deserialize `data` into an instance of `cls`; this is the unmarshaller used by the Pipeline.

    A class that defines its own `from_dict` handles the data itself; any other class goes through
    the generic `default_from_dict`.
    """
    return cls.from_dict(data) if hasattr(cls, "from_dict") else default_from_dict(cls, data)
51 |
52 |
def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]:
    """
    Utility function to serialize an object to a dictionary.
    This is mostly necessary for Components but it can be used by any object.

    `init_parameters` are parameters passed to the object class `__init__`.
    They must be defined explicitly as they'll be used when creating a new
    instance of `obj` with `from_dict`. Omitting them might cause deserialisation
    errors or unexpected behaviours later, when calling `from_dict`.

    The `type` entry is the fully qualified class path, `"<module>.<class name>"`,
    which is also what `default_from_dict` expects.

    An example usage:

    ```python
    class MyClass:
        def __init__(self, my_param: int = 10):
            self.my_param = my_param

        def to_dict(self):
            return default_to_dict(self, my_param=self.my_param)


    obj = MyClass(my_param=1000)
    data = obj.to_dict()
    assert data == {
        "type": f"{MyClass.__module__}.MyClass",  # e.g. "__main__.MyClass" in a script
        "init_parameters": {
            "my_param": 1000,
        },
    }
    ```

    :param obj: the object to serialize.
    :param init_parameters: keyword arguments to store under `init_parameters`.
    :returns: a dictionary with the `type` and `init_parameters` keys.
    """
    obj_class = obj.__class__
    return {
        "type": f"{obj_class.__module__}.{obj_class.__name__}",
        "init_parameters": init_parameters,
    }
88 |
89 |
def default_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
    """
    Utility function to deserialize a dictionary to an object.
    This is mostly necessary for Components but it can be used by any object.

    The function will raise a `DeserializationError` if the `type` field in `data` is
    missing or if it doesn't match `"<module>.<class name>"` of `cls`.

    If `data` contains an `init_parameters` field, it will be used as keyword arguments to create
    a new instance of `cls`. A missing or `None` `init_parameters` creates the instance with no arguments.

    :param cls: the class to instantiate.
    :param data: the serialized object, as produced by `default_to_dict`.
    :returns: a new instance of `cls`.
    :raises DeserializationError: if `type` is missing or doesn't match `cls`.
    """
    if "type" not in data:
        raise DeserializationError("Missing 'type' in serialization data")
    if data["type"] != f"{cls.__module__}.{cls.__name__}":
        raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")
    # `or {}` also accepts an explicit `"init_parameters": None`, which `cls(**None)` would reject with a TypeError.
    init_params = data.get("init_parameters") or {}
    return cls(**init_params)
107 |
--------------------------------------------------------------------------------
/canals/testing/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
--------------------------------------------------------------------------------
/canals/testing/factory.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Any, Dict, Optional, Tuple, Type
5 |
6 | from canals import component, Component
7 | from canals.serialization import default_to_dict, default_from_dict
8 |
9 |
def component_class(
    name: str,
    input_types: Optional[Dict[str, Any]] = None,
    output_types: Optional[Dict[str, Any]] = None,
    output: Optional[Dict[str, Any]] = None,
    bases: Optional[Tuple[type, ...]] = None,
    extra_fields: Optional[Dict[str, Any]] = None,
) -> Type[Component]:
    """
    Utility function to create a Component class with the given name and input and output types.

    If `output` is set but `output_types` is not, `output_types` will be set to the types of the values in `output`.
    If instead `output_types` is set but `output` is not, the component's `run` method will return a dictionary
    with the same keys as `output_types`, all with a value of None.

    ### Usage

    Create a component class with default input and output types:
    ```python
    MyFakeComponent = component_class("MyFakeComponent")
    component = MyFakeComponent()
    output = component.run(value=1)
    assert output == {"value": None}
    ```

    Create a component class with a "value" input of type `int` and with a "value" output of `10`:
    ```python
    MyFakeComponent = component_class(
        "MyFakeComponent",
        input_types={"value": int},
        output={"value": 10}
    )
    component = MyFakeComponent()
    output = component.run(value=1)
    assert output == {"value": 10}
    ```

    Create a component class with a custom base class:
    ```python
    MyFakeComponent = component_class(
        "MyFakeComponent",
        bases=(MyBaseClass,)
    )
    component = MyFakeComponent()
    assert isinstance(component, MyBaseClass)
    ```

    Create a component class with an extra field `my_field`:
    ```python
    MyFakeComponent = component_class(
        "MyFakeComponent",
        extra_fields={"my_field": 10}
    )
    component = MyFakeComponent()
    assert component.my_field == 10
    ```

    Args:
        name: Name of the component class.
        input_types: Dictionary of string and type that defines the inputs of the component,
            if set to None the created component will expect a single input "value" of Any type.
            Defaults to None.
        output_types: Dictionary of string and type that defines the outputs of the component.
            If set to None, it is inferred from the values of `output` when that is given; otherwise the
            created component will return a single output "value" of NoneType and None value.
            Defaults to None.
        output: Actual output dictionary returned by the created component's `run` (a shallow copy is
            returned on every call). If set to None, `run` returns a dictionary with the keys of
            `output_types`, all mapped to None. Defaults to None.
        bases: Base classes for this component, if set to None the only base is object. Defaults to None.
        extra_fields: Extra class attributes for the Component. On a name clash they override the generated
            `__init__`, `run`, `to_dict` and `from_dict`. Defaults to None.

    :return: A class definition that can be used as a component.
    """
    if input_types is None:
        input_types = {"value": Any}
    if output_types is None and output is not None:
        output_types = {key: type(value) for key, value in output.items()}
    elif output_types is None:
        output_types = {"value": type(None)}

    def init(self):
        component.set_input_types(self, **input_types)
        component.set_output_types(self, **output_types)

    # Both arguments are necessary to correctly define
    # run but pylint doesn't like that we don't use them.
    # It's fine ignoring the warning here.
    def run(self, **kwargs):  # pylint: disable=unused-argument
        if output is not None:
            # Return a fresh dict each time: handing out the shared `output` object would let a caller
            # that mutates the result change what every later `run` call returns.
            return dict(output)
        return {key: None for key in output_types}

    def to_dict(self):
        return default_to_dict(self)

    def from_dict(cls, data: Dict[str, Any]):
        return default_from_dict(cls, data)

    fields = {
        "__init__": init,
        "run": run,
        "to_dict": to_dict,
        "from_dict": classmethod(from_dict),
    }
    if extra_fields is not None:
        fields = {**fields, **extra_fields}

    if bases is None:
        bases = (object,)

    cls = type(name, bases, fields)
    return component(cls)
122 |
--------------------------------------------------------------------------------
/canals/type_utils.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Union, get_args, get_origin, Any
5 |
6 | import logging
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def _is_optional(type_: type) -> bool:
13 | """
14 | Utility method that returns whether a type is Optional.
15 | """
16 | return get_origin(type_) is Union and type(None) in get_args(type_)
17 |
18 |
19 | def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
20 | """
21 | Checks whether the source type is equal or a subtype of the destination type. Used to validate pipeline connections.
22 |
23 | Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
24 | typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
25 | with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
26 |
27 | Consider simplifying the typing of your components if you observe unexpected errors during component connection.
28 | """
29 | if sender == receiver or receiver is Any:
30 | return True
31 |
32 | if sender is Any:
33 | return False
34 |
35 | try:
36 | if issubclass(sender, receiver):
37 | return True
38 | except TypeError: # typing classes can't be used with issubclass, so we deal with them below
39 | pass
40 |
41 | sender_origin = get_origin(sender)
42 | receiver_origin = get_origin(receiver)
43 |
44 | if sender_origin is not Union and receiver_origin is Union:
45 | return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
46 |
47 | if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
48 | return False
49 |
50 | sender_args = get_args(sender)
51 | receiver_args = get_args(receiver)
52 | if len(sender_args) > len(receiver_args):
53 | return False
54 |
55 | return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
56 |
57 |
58 | def _type_name(type_):
59 | """
60 | Util methods to get a nice readable representation of a type.
61 | Handles Optional and Literal in a special way to make it more readable.
62 | """
63 | # Literal args are strings, so we wrap them in quotes to make it clear
64 | if isinstance(type_, str):
65 | return f"'{type_}'"
66 |
67 | name = getattr(type_, "__name__", str(type_))
68 |
69 | if name.startswith("typing."):
70 | name = name[7:]
71 | if "[" in name:
72 | name = name.split("[")[0]
73 | args = get_args(type_)
74 | if name == "Union" and type(None) in args and len(args) == 2:
75 | # Optional is technically a Union of type and None
76 | # but we want to display it as Optional
77 | name = "Optional"
78 |
79 | if args:
80 | args = ", ".join([_type_name(a) for a in args if a is not type(None)])
81 | return f"{name}[{args}]"
82 |
83 | return f"{name}"
84 |
--------------------------------------------------------------------------------
/docs/api-docs/canals.md:
--------------------------------------------------------------------------------
1 | # Canals
2 |
3 | ::: canals
4 |
5 | ::: canals.__about__
6 |
7 | ::: canals.errors
8 |
9 | ::: canals.type_utils
10 |
11 | ::: canals.serialization
12 |
--------------------------------------------------------------------------------
/docs/api-docs/component.md:
--------------------------------------------------------------------------------
1 | # Component API
2 |
3 | ::: canals.component
4 |
5 | ::: canals.component.component
6 |
7 | ::: canals.component.sockets
8 |
--------------------------------------------------------------------------------
/docs/api-docs/pipeline.md:
--------------------------------------------------------------------------------
1 | # Pipeline
2 |
3 | ::: canals.pipeline
4 |
5 | ::: canals.pipeline.draw
6 |
7 | ::: canals.pipeline.draw.draw
8 |
9 | ::: canals.pipeline.draw.graphviz
10 |
11 | ::: canals.pipeline.draw.mermaid
12 |
13 | ::: canals.pipeline.pipeline
14 |
15 | ::: canals.pipeline.validation
16 |
--------------------------------------------------------------------------------
/docs/api-docs/testing.md:
--------------------------------------------------------------------------------
1 | # Testing
2 |
3 | ::: canals.testing
4 |
5 | ::: canals.testing.factory
6 |
--------------------------------------------------------------------------------
/docs/concepts/components.md:
--------------------------------------------------------------------------------
1 | # Components
2 |
3 | In order to be recognized as components and work in a Pipeline, Components must follow the contract below.
4 |
5 | ## Requirements
6 |
7 | ### `@component` decorator
8 |
9 | All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.
10 |
11 | ### `@component.input`
12 |
13 | All components must decorate one single method with the `@component.input` decorator. This method must return a dataclass, which will be used as structure of the input of the component.
14 |
15 | For example, if the node is expecting a list of Documents, the fields of the returned dataclass should be `documents: List[Document]`. Note that you don't need to decorate the dataclass yourself: `@component.input` will add the decorator for you.
16 |
17 | Here is an example of such method:
18 |
19 | ```python
20 | @component.input
21 | def input(self):
22 | class Input:
23 | value: int
24 | add: int
25 |
26 | return Input
27 | ```
28 |
29 | Defaults are allowed, as are default factories and other dataclass properties.
30 |
31 | By default `@component.input` sets `None` as default for all fields, regardless of their definition: this gives you the
32 | possibility of passing a part of the input to the pipeline without defining every field of the component. For example,
33 | using the above definition, you can create an Input dataclass as:
34 |
35 | ```python
36 | self.input(add=3)
37 | ```
38 |
39 | and the resulting dataclass will look like `Input(value=None, add=3)`.
40 |
41 | However, if you don't explicitly define them as Optionals, Pipeline will make sure to collect all the values of this
42 | dataclass before calling the `run()` method, making them in practice non-optional.
43 |
44 | If you instead define a specific field as Optional in the dataclass, then Pipeline will **not** wait for them, and will
45 | run the component as soon as all the non-optional fields have received a value or, if all fields are optional, if at
46 | least one of them received it.
47 |
48 | This behavior allows Canals to define loops by not waiting on both incoming inputs of the entry component of the loop,
49 | and instead running as soon as at least one of them receives a value.
50 |
51 | ### `@component.output`
52 |
53 | All components must decorate one single method with the `@component.output` decorator. This method must return a dataclass, which will be used as structure of the output of the component.
54 |
55 | For example, if the node is producing a list of Documents, the fields of the returned dataclass should be `documents: List[Document]`. Note that you don't need to decorate the dataclass yourself: `@component.output` will add the decorator for you.
56 |
57 | Here is an example of such method:
58 |
59 | ```python
60 | @component.output
61 | def output(self):
62 | class Output:
63 | value: int
64 |
65 | return Output
66 | ```
67 |
68 | Defaults are allowed, as are default factories and other dataclass properties.
69 |
70 | ### `__init__(self, **kwargs)`
71 |
72 | Optional method.
73 |
74 | Components may have an `__init__` method where they define:
75 |
76 | - `self.defaults = {parameter_name: parameter_default_value, ...}`:
77 | All values defined here will be sent to the `run()` method when the Pipeline calls it.
78 | If any of these parameters is also receiving input from other components, those have precedence.
79 | This collection of values is supposed to replace the need for default values in `run()` and make them
80 | dynamically configurable. Keep in mind that only these defaults will count at runtime: defaults given to
81 | the `Input` dataclass (see above) will be ignored.
82 |
83 | - `self.init_parameters = {same parameters that the __init__ method received}`:
84 | In this dictionary you can store any state the components wish to be persisted when they are saved.
85 | These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
86 | Note that by default the `@component` decorator saves the arguments automatically.
87 | However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
88 | Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
89 |
90 | Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
91 | dictionaries containing only such values. Anything else (objects, functions, etc) will raise an exception at init
92 | time. If there's the need for such values, consider serializing them to a string.
93 |
94 | _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_
95 |
96 | The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
97 | validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc...) refer to
98 | the `warm_up()` method.
99 |
100 |
101 | ### `warm_up(self)`
102 |
103 | Optional method.
104 |
105 | This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
106 | because Pipeline will not keep track of which components it called `warm_up()` on.
107 |
108 |
109 | ### `run(self, data)`
110 |
111 | Mandatory method.
112 |
113 | This is the method where the main functionality of the component should be carried out. It's called by
114 | `Pipeline.run()`.
115 |
116 | When the component should run, Pipeline will call this method with an instance of the dataclass returned by the method decorated with `@component.input`. This dataclass contains:
117 |
118 | - all the input values coming from other components connected to it,
119 | - if any is missing, the corresponding value defined in `self.defaults`, if it exists.
120 |
121 | `run()` must return a single instance of the dataclass declared through the method decorated with `@component.output`.
122 |
123 |
124 | ## Example components
125 |
126 | Here is an example of a simple component that adds a fixed value to its input and returns their sum.
127 |
128 | ```python
129 | from typing import Optional
130 | from canals.component import component
131 |
132 | @component
133 | class AddFixedValue:
134 | """
135 | Adds the value of `add` to `value`. If not given, `add` defaults to 1.
136 | """
137 |
138 | @component.input # type: ignore
139 | def input(self):
140 | class Input:
141 | value: int
142 | add: int
143 |
144 | return Input
145 |
146 | @component.output # type: ignore
147 | def output(self):
148 | class Output:
149 | value: int
150 |
151 | return Output
152 |
153 | def __init__(self, add: Optional[int] = 1):
154 | if add:
155 | self.defaults = {"add": add}
156 |
157 | def run(self, data):
158 | return self.output(value=data.value + data.add)
159 | ```
160 |
161 | See the `sample_components` package for examples of more complex components with variable inputs and outputs, and so on.
162 |
--------------------------------------------------------------------------------
/docs/concepts/concepts.md:
--------------------------------------------------------------------------------
1 | # Core concepts
2 |
3 | Canals is a **component orchestration engine**. It can be used to connect a group of smaller objects, called Components,
4 | that perform well-defined tasks into a network, called Pipeline, to achieve a larger goal.
5 |
6 | Components are Python objects that can execute a task, like reading a file, performing calculations, or making API
7 | calls. Canals connects these objects together: it builds a graph of components and takes care of managing their
8 | execution order, making sure that each object receives the input it expects from the other components of the pipeline at the right time.
9 |
10 | Canals relies on two main concepts: Components and Pipelines.
11 |
12 | ## What is a Component?
13 |
14 | A Component is a Python class that performs a well-defined task: for example a REST API call, a mathematical operation,
15 | a data transformation, writing something to a file or a database, and so on.
16 |
17 | To be recognized as a Component by Canals, a Python class needs to respect these rules:
18 |
19 | 1. Must be decorated with the `@component` decorator.
20 | 2. Must have a `run()` method that accepts a `data` parameter of type `ComponentInput` and returns a single object of type `ComponentOutput`.
21 |
22 | For example, the following is a Component that sums up two numbers:
23 |
24 | ```python
25 | from typing import Optional
26 | from canals.component import component, ComponentInput, ComponentOutput
27 |
28 | @component
29 | class AddFixedValue:
30 | """
31 | Adds the value of `add` to `value`. If not given, `add` defaults to 1.
32 | """
33 |
34 | @component.input
35 | def input(self):
36 | class Input:
37 | value: int
38 | add: int
39 |
40 | return Input
41 |
42 | @component.output
43 | def output(self):
44 | class Output:
45 | value: int
46 |
47 | return Output
48 |
49 | def __init__(self, add: Optional[int] = 1):
50 | if add:
51 | self.defaults = {"add": add}
52 |
53 | def run(self, data):
54 | return self.output(value=data.value + data.add)
55 | ```
56 |
57 | We will see the details of all of these requirements below.
58 |
59 | ## What is a Pipeline?
60 |
61 | A Pipeline is a network of Components. Pipelines define which components send their output to which others, make
62 | sure all the connections are valid, and take care of calling each component's `run()` method in the right order.
63 |
64 | Pipeline connects components together through so-called connections, which are the edges of the pipeline graph.
65 | Pipeline is going to make sure that all the connections are valid based on the inputs and output that Components have
66 | declared.
67 |
68 | For example, if a component produces a value of type `List[Document]` and another component expects an input
69 | of type `List[Document]`, Pipeline will be able to connect them. Otherwise, it will raise an exception.
70 |
71 | This is a simple example of how a Pipeline is created:
72 |
73 |
74 | ```python
75 | from canals.pipeline import Pipeline
76 |
77 | # Some Canals components
78 | from my_components import AddFixedValue, MultiplyBy
79 |
80 | pipeline = Pipeline()
81 |
82 | # Components can be initialized as standalone objects.
83 | # These instances can be added to the Pipeline in several places.
84 | multiplication = MultiplyBy(multiply_by=2)
85 | addition = AddFixedValue(add=1)
86 |
87 | # Components are added with a name and a component
88 | pipeline.add_component("double", multiplication)
89 | pipeline.add_component("add_one", addition)
90 | pipeline.add_component("add_one_again", addition) # Component instances can be reused
91 | pipeline.add_component("add_two", AddFixedValue(add=2))
92 |
93 | # Connect the components together
94 | pipeline.connect(connect_from="double", connect_to="add_one")
95 | pipeline.connect(connect_from="add_one", connect_to="add_one_again")
96 | pipeline.connect(connect_from="add_one_again", connect_to="add_two")
97 |
98 | # Pipeline can be drawn
99 | pipeline.draw("pipeline.jpg")
100 |
101 | # Pipelines are run by giving them the data that the input nodes expect.
102 | results = pipeline.run(data={"double": multiplication.input(value=1)})
103 |
104 | print(results)
105 |
106 | # prints {"add_two": AddFixedValue.Output(value=6)}
107 | ```
108 |
109 | This is what the pipeline's graph looks like:
110 |
111 | ```mermaid
112 | graph TD;
113 | double -- value -> value --> add_one
114 | add_one -- value -> value --> add_one_again
115 | add_one_again -- value -> value --> add_two
116 | IN([input]) -- value --> double
117 | add_two -- value --> OUT([output])
118 | ```
119 |
--------------------------------------------------------------------------------
/docs/concepts/pipelines.md:
--------------------------------------------------------------------------------
1 | # Pipelines
2 |
3 | Canals aims to support pipelines of (close to) arbitrary complexity. It currently supports a variety of different topologies, such as:
4 |
5 | - Simple linear pipelines
6 | - Branching pipelines where all or only some branches are executed
7 | - Pipelines merging a variable number of inputs, depending on decisions taken upstream
8 | - Simple loops
9 | - Multiple entry components, either alternative or parallel
10 | - Multiple exit components, either alternative or parallel
11 |
12 | Check the pipeline's test suite for some examples.
13 |
14 | ## Validation
15 |
16 | Pipeline performs validation on the connection type level: when calling `Pipeline.connect()`, it uses the `@component.input` and `@component.output` dataclass fields to make sure that the connection is possible.
17 |
18 | On top of this, specific connections can be specified with the syntax `component_name.input_or_output_field`.
19 |
20 | For example, let's imagine we have two components with the following I/O declared:
21 |
22 | ```python
23 | @component
24 | class ComponentA:
25 |
26 | @component.input
27 | def input(self):
28 | class Input:
29 | input_value: int
30 |
31 | return Input
32 |
33 | @component.output
34 | def output(self):
35 | class Output:
36 | output_value: str
37 |
38 | return Output
39 |
40 | def run(self, data):
41 | return self.output(output_value="hello")
42 |
43 | @component
44 | class ComponentB:
45 |
46 | @component.input
47 | def input(self):
48 | class Input:
49 | input_value: str
50 |
51 | return Input
52 |
53 | @component.output
54 | def output(self):
55 | class Output:
56 | output_value: List[str]
57 |
58 | return Output
59 |
60 | def run(self, data):
61 | return self.output(output_value=["h", "e", "l", "l", "o"])
62 | ```
63 |
64 | This is the behavior of `Pipeline.connect()`:
65 |
66 | ```python
67 | pipeline.add_component('component_a', ComponentA())
68 | pipeline.add_component('component_b', ComponentB())
69 |
70 | # All of these succeed
71 | pipeline.connect('component_a', 'component_b')
72 | pipeline.connect('component_a.output_value', 'component_b')
73 | pipeline.connect('component_a', 'component_b.input_value')
74 | pipeline.connect('component_a.output_value', 'component_b.input_value')
75 | ```
76 |
77 | These, instead, fail:
78 |
79 | ```python
80 | pipeline.connect('component_a', 'component_a')
81 | # canals.errors.PipelineConnectError: Cannot connect 'component_a' with 'component_a': no matching connections available.
82 | # 'component_a':
83 | # - output_value (str)
84 | # 'component_a':
85 | # - input_value (int, available)
86 |
87 | pipeline.connect('component_b', 'component_a')
88 | # canals.errors.PipelineConnectError: Cannot connect 'component_b' with 'component_a': no matching connections available.
89 | # 'component_b':
90 | # - output_value (List[str])
91 | # 'component_a':
92 | # - input_value (int, available)
93 | ```
94 |
95 | In addition, component names are validated:
96 |
97 | ```python
98 | pipeline.connect('component_a', 'component_c')
99 | # ValueError: Component named component_c not found in the pipeline.
100 | ```
101 |
102 | The same applies to input and output names, when they are specified:
103 |
104 | ```python
105 | pipeline.connect('component_a.input_value', 'component_b')
106 | # canals.errors.PipelineConnectError: 'component_a.input_value' does not exist. Output connections of component_a are: output_value (type str)
107 |
108 | pipeline.connect('component_a', 'component_b.output_value')
109 | # canals.errors.PipelineConnectError: 'component_b.output_value' does not exist. Input connections of component_b are: input_value (type str)
110 | ```
111 |
112 | ## Save and Load
113 |
114 | Pipelines can be serialized to Python dictionaries, which can then be dumped to JSON or to any other suitable format, like YAML, TOML, HCL, etc. These pipelines can later be loaded back.
115 |
116 | Here is an example of Pipeline saving and loading:
117 |
118 | ```python
119 | from haystack.pipelines import Pipeline, save_pipelines, load_pipelines
120 |
121 | pipe1 = Pipeline()
122 | pipe2 = Pipeline()
123 |
124 | # .. assemble the pipelines ...
125 |
126 | # Save the pipelines
127 | save_pipelines(
128 | pipelines={
129 | "pipe1": pipe1,
130 | "pipe2": pipe2,
131 | },
132 | path="my_pipelines.json",
133 | _writer=json.dumps
134 | )
135 |
136 | # Load the pipelines
137 | new_pipelines = load_pipelines(
138 | path="my_pipelines.json",
139 | _reader=json.loads
140 | )
141 |
142 | assert new_pipelines["pipe1"] == pipe1
143 | assert new_pipelines["pipe2"] == pipe2
144 | ```
145 |
146 | Note how the save/load functions accept a `_writer`/`_reader` function: this choice frees us from committing strongly to a specific serialization format. Although a default will be set (be it YAML, TOML, HCL or anything else), the decision can be overridden by passing another explicit reader/writer function to the `save_pipelines`/`load_pipelines` functions.
147 |
148 | This is what the resulting file will look like, assuming a JSON writer was chosen.
149 |
150 | `my_pipelines.json`
151 |
152 | ```python
153 | {
154 | "pipelines": {
155 | "pipe1": {
156 | # All the components that would be added with a
157 | # Pipeline.add_component() call
158 | "components": {
159 | "first_addition": {
160 | "type": "AddValue",
161 | "init_parameters": {
162 | "add": 1
163 | },
164 | },
165 | "double": {
166 | "type": "Double",
167 | "init_parameters": {}
168 | },
169 | "second_addition": {
170 | "type": "AddValue",
171 | "init_parameters": {
172 | "add": 1
173 | },
174 | },
175 | # This is how instances of the same component are reused
176 | "third_addition": {
177 | "refer_to": "pipe1.first_addition"
178 | },
179 | },
180 | # All the connections that would be made with a
181 | # Pipeline.connect() call
182 | "connections": [
183 | ("first_addition", "double", "value/value"),
184 | ("double", "second_addition", "value/value"),
185 | ("second_addition", "third_addition", "value/value"),
186 | ],
187 | # All other Pipeline.__init__() parameters go here.
188 | "metadata": {"type": "test pipeline", "author": "me"},
189 | "max_loops_allowed": 100,
190 | },
191 | "pipe2": {
192 | "components": {
193 | "first_addition": {
194 | # We can reference components from other pipelines too!
195 | "refer_to": "pipe1.first_addition",
196 | },
197 | "double": {
198 | "type": "Double",
199 | "init_parameters": {}
200 | },
201 | "second_addition": {
202 | "refer_to": "pipe1.second_addition"
203 | },
204 | },
205 | "connections": [
206 | ("first_addition", "double", "value/value"),
207 | ("double", "second_addition", "value/value"),
208 | ],
209 | "metadata": {"type": "another test pipeline", "author": "you"},
210 | "max_loops_allowed": 100,
211 | },
212 | },
213 | # A list of "dependencies" for the application.
214 | # Used to ensure all external components are present when loading.
215 | "dependencies": ["my_custom_components_module"],
216 | }
217 | ```
218 |
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | # Canals
2 |
3 |
4 |
5 |
6 |
7 | [](https://pypi.org/project/canals)
8 | [](https://pypi.org/project/canals)
9 |
10 |
11 |
12 | Canals is a **component orchestration engine**. Components are Python objects that can execute a task, like reading a file, performing calculations, or making API calls. Canals connects these objects together: it builds a graph of components and takes care of managing their execution order, making sure that each object receives the input it expects from the other components of the pipeline.
13 |
14 | Canals powers version 2.0 of [Haystack](https://github.com/deepset-ai/haystack).
15 |
16 | ## Installation
17 |
18 | Running:
19 |
20 | ```console
21 | pip install canals
22 | ```
23 |
24 | gives you the bare minimum necessary to run Canals.
25 |
26 | To be able to draw pipelines, please make sure you have either an internet connection (to reach the Mermaid graph renderer at `https://mermaid.ink`) or [graphviz](https://graphviz.org/download/) installed and then install Canals as:
27 |
28 | ### Mermaid
29 | ```console
30 | pip install canals[mermaid]
31 | ```
32 |
33 | ### GraphViz
34 | ```console
35 | sudo apt install graphviz # You may need `graphviz-dev` too
36 | pip install canals[graphviz]
37 | ```
38 |
--------------------------------------------------------------------------------
/images/canals-logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/images/canals-logo-dark.png
--------------------------------------------------------------------------------
/images/canals-logo-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/images/canals-logo-light.png
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: Canals
2 | site_url: https://deepset-ai.github.io/canals/
3 |
4 | theme:
5 | name: material
6 | features:
7 | - content.code.copy
8 |
9 | plugins:
10 | - search
11 | - mermaid2
12 | - mkdocstrings
13 |
14 | markdown_extensions:
15 | - pymdownx.highlight:
16 | anchor_linenums: true
17 | line_spans: __span
18 | pygments_lang_class: true
19 | - pymdownx.inlinehilite
20 | - pymdownx.snippets
21 | - pymdownx.superfences:
22 | preserve_tabs: true
23 | custom_fences:
24 | - name: mermaid
25 | class: mermaid
26 | format: !!python/name:pymdownx.superfences.fence_code_format
27 |
28 | extra_javascript:
29 | - optionalConfig.js
30 | - https://unpkg.com/mermaid@9.4.0/dist/mermaid.min.js
31 | - extra-loader.js
32 |
33 | nav:
34 | - Get Started: index.md
35 | - Concepts:
36 | - Core Concepts: concepts/concepts.md
37 | - Components: concepts/components.md
38 | - Pipelines: concepts/pipelines.md
39 | - API Docs:
40 | - Canals: api-docs/canals.md
41 | - Component: api-docs/component.md
42 | - Pipeline: api-docs/pipeline.md
43 | - Testing: api-docs/testing.md
44 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [project]
6 | name = "canals"
7 | description = 'A component orchestration engine for Haystack'
8 | readme = "README.md"
9 | requires-python = ">=3.8"
10 | license = "Apache-2.0"
11 | keywords = []
12 | authors = [{ name = "ZanSara", email = "sara.zanzottera@deepset.ai" }]
13 | classifiers = [
14 | "Development Status :: 3 - Alpha",
15 | "License :: Freely Distributable",
16 | "License :: OSI Approved :: Apache Software License",
17 | "Programming Language :: Python",
18 | "Programming Language :: Python :: 3.8",
19 | "Programming Language :: Python :: 3.9",
20 | "Programming Language :: Python :: 3.10",
21 | "Programming Language :: Python :: 3.11",
22 | "Programming Language :: Python :: Implementation :: CPython",
23 | "Programming Language :: Python :: Implementation :: PyPy",
24 | ]
25 | dynamic = ["version"]
26 | dependencies = [
27 | "networkx", # Pipeline graphs
28 | "requests", # Mermaid diagrams
29 | "typing_extensions",
30 | ]
31 |
32 | [project.optional-dependencies]
33 | dev = [
34 | "hatch",
35 | "pre-commit",
36 | "mypy",
37 | "pylint==2.15.10",
38 | "black[jupyter]==22.6.0",
39 | "pytest",
40 | "pytest-cov",
41 | "requests",
42 | "coverage",
43 | ]
44 | docs = ["mkdocs-material", "mkdocstrings[python]", "mkdocs-mermaid2-plugin"]
45 |
46 | [project.urls]
47 | Documentation = "https://github.com/deepset-ai/canals#readme"
48 | Issues = "https://github.com/deepset-ai/canals/issues"
49 | Source = "https://github.com/deepset-ai/canals"
50 |
51 | [tool.hatch.version]
52 | path = "canals/__about__.py"
53 |
54 | [tool.hatch.build]
55 | include = ["/canals/**/*.py"]
56 |
57 | [tool.hatch.envs.default]
58 | dependencies = ["pytest", "pytest-cov", "requests"]
59 |
60 | [tool.hatch.envs.default.scripts]
61 | cov = "pytest --cov-report xml:coverage.xml --cov-config=pyproject.toml --cov=canals --cov=tests {args}"
62 | no-cov = "cov --no-cov {args}"
63 |
64 | [[tool.hatch.envs.test.matrix]]
65 | python = ["38", "39", "310", "311"]
66 |
67 | [tool.coverage.run]
68 | branch = true
69 | parallel = true
70 | omit = ["canals/__about__.py"]
71 |
72 | [tool.coverage.report]
73 | exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
74 |
75 | [tool.black]
76 | line-length = 120
77 |
78 | [tool.pylint.'MESSAGES CONTROL']
79 | max-line-length = 120
80 | good-names = "e"
81 | max-args = 10
82 | max-locals = 15
83 | disable = [
84 | "fixme",
85 | "line-too-long",
86 | "missing-class-docstring",
87 | "missing-module-docstring",
88 | "too-few-public-methods",
89 | ]
90 |
--------------------------------------------------------------------------------
/sample_components/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components.concatenate import Concatenate
5 | from sample_components.subtract import Subtract
6 | from sample_components.parity import Parity
7 | from sample_components.remainder import Remainder
8 | from sample_components.accumulate import Accumulate
9 | from sample_components.threshold import Threshold
10 | from sample_components.add_value import AddFixedValue
11 | from sample_components.repeat import Repeat
12 | from sample_components.sum import Sum
13 | from sample_components.greet import Greet
14 | from sample_components.double import Double
15 | from sample_components.joiner import StringJoiner, StringListJoiner, FirstIntSelector
16 | from sample_components.hello import Hello
17 | from sample_components.text_splitter import TextSplitter
18 | from sample_components.merge_loop import MergeLoop
19 | from sample_components.self_loop import SelfLoop
20 | from sample_components.fstring import FString
21 |
# Public names re-exported by `from sample_components import *`.
# Every component class imported above is listed here.
__all__ = [
    "Concatenate",
    "Subtract",
    "Parity",
    "Remainder",
    "Accumulate",
    "Threshold",
    "AddFixedValue",
    "MergeLoop",
    "Repeat",
    "Sum",
    "Greet",
    "Double",
    "StringJoiner",
    "Hello",
    "TextSplitter",
    "StringListJoiner",
    "FirstIntSelector",
    "SelfLoop",
    "FString",
]
43 |
--------------------------------------------------------------------------------
/sample_components/accumulate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Callable, Optional, Dict, Any
5 | import sys
6 | import builtins
7 | from importlib import import_module
8 |
9 | from canals.serialization import default_to_dict
10 | from canals.component import component
11 | from canals.errors import ComponentDeserializationError
12 |
13 |
14 | def _default_function(first: int, second: int) -> int:
15 | return first + second
16 |
17 |
@component
class Accumulate:
    """
    Accumulates the value flowing through the connection into an internal attribute.
    The sum function can be customized.

    Example of how to deal with serialization when some of the parameters
    are not directly serializable.
    """

    def __init__(self, function: Optional[Callable] = None):
        """
        :param function: the function to use to accumulate the values.
            It is called as `function(current_state, incoming_value)`, so it must
            take exactly two values. If None, the values are summed.
        """
        self.state = 0
        self.function: Callable = _default_function if function is None else function  # type: ignore

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize the component, storing the accumulator function by name.

        Builtin functions are stored by their bare name (e.g. `"max"`); any other
        function is stored as `"<module>.<function name>"`.

        :raises ValueError: if the function's module is not found in `sys.modules`.
        """
        module = sys.modules.get(self.function.__module__)
        if not module:
            raise ValueError("Could not locate the import module.")
        if module is builtins:
            function_name = self.function.__name__
        else:
            function_name = f"{module.__name__}.{self.function.__name__}"

        return default_to_dict(self, function=function_name)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Accumulate":
        """
        Rebuild the component from the output of `to_dict`, importing the accumulator function.

        :raises ComponentDeserializationError: if `type` is missing or does not match this class.
        """
        if "type" not in data:
            raise ComponentDeserializationError("Missing 'type' in component serialization data")
        if data["type"] != f"{cls.__module__}.{cls.__name__}":
            raise ComponentDeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")

        init_params = data.get("init_parameters", {})

        accumulator_function = None
        if "function" in init_params:
            module_name, _, function_name = init_params["function"].rpartition(".")
            # `to_dict` writes builtins as a bare name with no module prefix, and
            # `import_module("")` would raise, so bare names are looked up in `builtins`.
            # NOTE: this imports whatever module the data names - only load trusted data.
            module = import_module(module_name) if module_name else builtins
            accumulator_function = getattr(module, function_name)

        return cls(function=accumulator_function)

    @component.output_types(value=int)
    def run(self, value: int):
        """
        Accumulates the value flowing through the connection into an internal attribute.
        The sum function can be customized.

        :param value: the value to fold into the running state.
        :return: the updated state under the `value` key.
        """
        self.state = self.function(self.state, value)
        return {"value": self.state}
77 |
--------------------------------------------------------------------------------
/sample_components/add_value.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Optional
5 |
6 | from canals import component
7 |
8 |
@component
class AddFixedValue:
    """
    Adds two values together.
    """

    def __init__(self, add: int = 1):
        # Addend used whenever `run` is called without an explicit one.
        self.add = add

    @component.output_types(result=int)
    def run(self, value: int, add: Optional[int] = None):
        """
        Adds two values together.

        :param value: the number to increment.
        :param add: the amount to add; falls back to the init-time value when None.
        """
        addend = self.add if add is None else add
        return {"result": value + addend}
26 |
--------------------------------------------------------------------------------
/sample_components/concatenate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Union, List
5 |
6 | from canals import component
7 |
8 |
@component
class Concatenate:
    """
    Concatenates two values
    """

    @component.output_types(value=List[str])
    def run(self, first: Union[List[str], str], second: Union[List[str], str]):
        """
        Concatenates two values into a single list of strings.

        A plain string is treated as a one-element list, so the result is always a
        new list and the inputs are never modified.

        :param first: a string or a list of strings, placed first in the output.
        :param second: a string or a list of strings, appended after `first`.
        :raises TypeError: if either input is neither a string nor a list.
        """
        return {"value": self._as_list(first, "first") + self._as_list(second, "second")}

    @staticmethod
    def _as_list(item: Union[List[str], str], name: str) -> List[str]:
        """Wrap a string in a list, pass lists through, and reject anything else."""
        if isinstance(item, str):
            return [item]
        if isinstance(item, list):
            return item
        # Previously such inputs left the result unbound and surfaced as UnboundLocalError.
        raise TypeError(f"'{name}' must be a string or a list of strings, got {type(item).__name__}")
29 |
--------------------------------------------------------------------------------
/sample_components/double.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals import component
5 |
6 |
@component
class Double:
    """
    Doubles the input value.
    """

    @component.output_types(value=int)
    def run(self, value: int):
        """
        Multiplies the incoming value by two.

        :param value: the number to double.
        """
        doubled = value * 2
        return {"value": doubled}
19 |
--------------------------------------------------------------------------------
/sample_components/fstring.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List, Any, Optional
5 |
6 | from canals import component
7 |
8 |
@component
class FString:
    """
    Takes a template string and a list of variables in input and returns the formatted string in output.
    """

    def __init__(self, template: str, variables: Optional[List[str]] = None):
        """
        :param template: the default template, rendered with `str.format`.
        :param variables: names of the extra inputs this component accepts; each one
            is declared as an input of type Any.
        :raises ValueError: if one of the variables is named 'template'.
        """
        self.template = template
        # Copy so later mutations of the caller's list don't affect this component.
        self.variables = list(variables or [])
        if "template" in self.variables:
            raise ValueError("The variable name 'template' is reserved and cannot be used.")
        component.set_input_types(self, **{variable: Any for variable in self.variables})

    @component.output_types(string=str)
    def run(self, template: Optional[str] = None, **kwargs):
        """
        Takes a template string and a list of variables in input and returns the formatted string in output.

        If the template is not given (None), the component will use the one given at initialization.
        An explicitly passed empty template is honored rather than silently replaced.
        """
        if template is None:
            template = self.template
        return {"string": template.format(**kwargs)}
32 |
--------------------------------------------------------------------------------
/sample_components/greet.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Optional
5 | import logging
6 |
7 | from canals import component
8 |
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 |
@component
class Greet:
    """
    Logs a greeting message without affecting the value passing on the connection.
    """

    def __init__(
        self,
        message: str = "\nGreeting component says: Hi! The value is {value}\n",
        log_level: str = "INFO",
    ):
        """
        :param message: the message to log. Can use `{value}` to embed the value.
        :param log_level: the name of the level to log at (e.g. "INFO", "DEBUG").
        :raises ValueError: if `log_level` is given but is not a logging level name.
        """
        if log_level and self._resolve_log_level(log_level) is None:
            raise ValueError(f"This log level does not exist: {log_level}")
        self.message = message
        self.log_level = log_level

    @staticmethod
    def _resolve_log_level(log_level: Optional[str]) -> Optional[int]:
        """
        Map a level name such as "INFO" to its numeric value.

        Returns None when the name is not a string or does not name an integer
        attribute of the `logging` module (e.g. "getLogger"). This lets callers raise
        a clear ValueError instead of an AttributeError, or a crash inside `logger.log`.
        """
        if not isinstance(log_level, str):
            return None
        level = getattr(logging, log_level, None)
        # `logging` also exposes boolean flags (e.g. `raiseExceptions`); bool is an int subclass.
        if isinstance(level, bool) or not isinstance(level, int):
            return None
        return level

    @component.output_types(value=int)
    def run(self, value: int, message: Optional[str] = None, log_level: Optional[str] = None):
        """
        Logs a greeting message without affecting the value passing on the connection.

        :param value: the value embedded in the message and passed through unchanged.
        :param message: overrides the init-time message when not empty.
        :param log_level: overrides the init-time level when not empty.
        :raises ValueError: if the resulting log level is not a valid level name.
        """
        if not message:
            message = self.message
        if not log_level:
            log_level = self.log_level

        # `is None` rather than falsiness, so NOTSET (0) is accepted.
        level = self._resolve_log_level(log_level)
        if level is None:
            raise ValueError(f"This log level does not exist: {log_level}")

        logger.log(level=level, msg=message.format(value=value))
        return {"value": value}
49 |
--------------------------------------------------------------------------------
/sample_components/hello.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals import component
5 |
6 |
@component
class Hello:
    @component.output_types(output=str)
    def run(self, word: str):
        """
        Takes a string in input and returns "Hello, <word>!"
        in output.
        """
        greeting = f"Hello, {word}!"
        return {"output": greeting}
16 |
--------------------------------------------------------------------------------
/sample_components/joiner.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List
5 |
6 | from canals import component
7 | from canals.component.types import Variadic
8 |
9 |
@component
class StringJoiner:
    @component.output_types(output=str)
    def run(self, input_str: Variadic[str]):
        """
        Joins the strings received from multiple input nodes into one
        space-separated string. `input_str` is Variadic, so it arrives
        as a List[str].
        """
        separator = " "
        joined = separator.join(input_str)
        return {"output": joined}
20 |
21 |
@component
class StringListJoiner:
    @component.output_types(output=List[str])
    def run(self, inputs: Variadic[List[str]]):
        """
        Take lists of strings from multiple input nodes and concatenate them
        into a single list returned in output. Since `inputs` is Variadic,
        we know we'll receive a List[List[str]].

        The output is declared as List[str] to match what is actually returned
        (it was previously declared as `str`).
        """
        merged: List[str] = [string for list_of_strings in inputs for string in list_of_strings]
        return {"output": merged}
36 |
37 |
@component
class FirstIntSelector:
    @component.output_types(output=int)
    def run(self, inputs: Variadic[int]):
        """
        Take ints from multiple input nodes and return the first one
        that is not None. Since `inputs` is Variadic, we know we'll
        receive a List[int]. If every input is None, nothing is returned.
        """
        first = next((candidate for candidate in inputs if candidate is not None), None)  # type: ignore
        if first is None:
            return {}
        return {"output": first}
51 |
--------------------------------------------------------------------------------
/sample_components/merge_loop.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List, Any, Optional, Dict
5 | import sys
6 |
7 | from canals import component
8 | from canals.errors import DeserializationError
9 | from canals.serialization import default_to_dict
10 |
11 |
@component
class MergeLoop:
    """
    Forwards the first non-None value received on any of its inputs.

    Every input is declared as Optional[expected_type]; the single output `value`
    is declared as expected_type.
    """

    def __init__(self, expected_type: Any, inputs: List[str]):
        """
        :param expected_type: the type of the values flowing through the component.
        :param inputs: the names of the input sockets to create.
        """
        component.set_input_types(self, **{input_name: Optional[expected_type] for input_name in inputs})
        component.set_output_types(self, value=expected_type)

        # Keep a string form of the type so the component can be serialized.
        if expected_type.__module__ == "builtins":
            self.expected_type = f"builtins.{expected_type.__name__}"
        elif expected_type.__module__ == "typing":
            self.expected_type = str(expected_type)
        else:
            self.expected_type = f"{expected_type.__module__}.{expected_type.__name__}"

        self.inputs = inputs

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the component; `expected_type` is stored in its string form."""
        return default_to_dict(
            self,
            expected_type=self.expected_type,
            inputs=self.inputs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "MergeLoop":
        """
        Rebuild the component from the output of `to_dict`.

        :raises DeserializationError: if `type`, `expected_type` or `inputs` is missing,
            if `type` does not match this class, or if the type cannot be resolved.
        """
        if "type" not in data:
            raise DeserializationError("Missing 'type' in component serialization data")
        if data["type"] != f"{cls.__module__}.{cls.__name__}":
            raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")

        init_params = data.get("init_parameters", {})

        if "expected_type" not in init_params:
            raise DeserializationError("Missing 'expected_type' field in 'init_parameters'")

        if "inputs" not in init_params:
            raise DeserializationError("Missing 'inputs' field in 'init_parameters'")

        expected_type = cls._resolve_type(init_params["expected_type"])
        inputs = init_params["inputs"]

        return cls(expected_type=expected_type, inputs=inputs)

    @staticmethod
    def _resolve_type(fully_qualified_type_name: str) -> Any:
        """
        Turn a "<module>.<TypeName>" string back into the type object.

        The type is looked up first in its own module (imported if necessary), then in
        this module's namespace as a fallback. Parameterized typing strings such as
        "typing.List[int]" cannot be resolved this way and raise an error.

        :raises DeserializationError: if the type cannot be found.
        """
        from importlib import import_module  # pylint: disable=import-outside-toplevel

        module_name, _, type_name = fully_qualified_type_name.rpartition(".")
        search_path = []
        if module_name:
            try:
                # NOTE: this imports whatever module the data names - only load trusted data.
                search_path.append(import_module(module_name))
            except (ImportError, TypeError, ValueError):
                # Not an importable module name, e.g. the prefix of a parameterized typing string.
                pass
        search_path.append(sys.modules[__name__])

        for module in search_path:
            if hasattr(module, type_name):
                return getattr(module, type_name)
        raise DeserializationError(
            f"Can't find type '{type_name}', import '{fully_qualified_type_name}' to fix the issue"
        )

    def run(self, **kwargs):
        """
        :param kwargs: find the first non-None value and return it.
        :return: `{"value": <first non-None value>}`, or an empty dict if all are None.
        """
        for value in kwargs.values():
            if value is not None:
                return {"value": value}
        return {}
73 |
--------------------------------------------------------------------------------
/sample_components/parity.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals import component
5 |
6 |
@component
class Parity:  # pylint: disable=too-few-public-methods
    """
    Redirects the value, unchanged, along the 'even' connection if even, or along the 'odd' one if odd.
    """

    @component.output_types(even=int, odd=int)
    def run(self, value: int):
        """
        :param value: The value to check for parity
        """
        branch = "odd" if value % 2 else "even"
        return {branch: value}
22 |
--------------------------------------------------------------------------------
/sample_components/remainder.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals import component
5 |
6 |
@component
class Remainder:
    """
    Redirects the value, unchanged, to the output named after its remainder
    modulo `divisor` (`remainder_is_0`, `remainder_is_1`, ...).
    """

    def __init__(self, divisor=3):
        """
        :param divisor: a positive integer. One output `remainder_is_<n>` is declared
            for every possible remainder n in 0..divisor-1.
        :raises ValueError: if the divisor is zero or negative.
        """
        if divisor == 0:
            raise ValueError("Can't divide by zero")
        if divisor < 0:
            # range(divisor) would be empty, so no outputs would be declared,
            # while `run` would still emit a (non-positive) remainder key.
            raise ValueError(f"The divisor must be a positive integer, got {divisor}")
        self.divisor = divisor
        component.set_output_types(self, **{f"remainder_is_{val}": int for val in range(divisor)})

    def run(self, value: int):
        """
        :param value: the value to check the remainder of.
        """
        remainder = value % self.divisor
        return {f"remainder_is_{remainder}": value}
22 |
--------------------------------------------------------------------------------
/sample_components/repeat.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List
5 |
6 | from canals import component
7 |
8 |
@component
class Repeat:
    def __init__(self, outputs: List[str]):
        """
        :param outputs: names of the output sockets; each one receives the input value.
        """
        self.outputs = outputs
        component.set_output_types(self, **dict.fromkeys(outputs, int))

    def run(self, value: int):
        """
        :param value: the value to repeat.
        """
        return dict.fromkeys(self.outputs, value)
20 |
--------------------------------------------------------------------------------
/sample_components/self_loop.py:
--------------------------------------------------------------------------------
1 | from canals import component
2 | from canals.component.types import Variadic
3 |
4 |
@component
class SelfLoop:
    """
    Decreases the initial value in steps of 1 until the target value is reached.
    For no good reason it uses a self-loop to do so :)
    """

    def __init__(self, target: int = 0):
        # Value at which the component emits `final_result` instead of `current_value`.
        self.target = target

    @component.output_types(current_value=int, final_result=int)
    def run(self, values: Variadic[int]):
        """
        Decreases the input value in steps of 1 until the target value is reached.

        Only the first received value is used. The decremented value goes to
        `final_result` when it equals the target, otherwise to `current_value`
        (presumably wired back into this same component's input).
        """
        decremented = values[0] - 1  # type: ignore
        output_name = "final_result" if decremented == self.target else "current_value"
        return {output_name: decremented}
25 |
--------------------------------------------------------------------------------
/sample_components/subtract.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals import component
5 |
6 |
@component
class Subtract:
    """
    Compute the difference between two values.
    """

    @component.output_types(difference=int)
    def run(self, first_value: int, second_value: int):
        """
        :param first_value: name of the connection carrying the value to subtract from.
        :param second_value: name of the connection carrying the value to subtract.
        """
        difference = first_value - second_value
        return {"difference": difference}
20 |
--------------------------------------------------------------------------------
/sample_components/sum.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals import component
5 | from canals.component.types import Variadic
6 |
7 |
@component
class Sum:
    @component.output_types(total=int)
    def run(self, values: Variadic[int]):
        """
        :param values: the values to sum.
        """
        total = sum(values)
        return {"total": total}
16 |
--------------------------------------------------------------------------------
/sample_components/text_splitter.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List
5 |
6 | from canals import component
7 |
8 |
@component
class TextSplitter:
    @component.output_types(output=List[str])
    def run(self, sentence: str):
        """
        Takes a sentence in input and returns its whitespace-separated words
        in output.
        """
        words = sentence.split()
        return {"output": words}
18 |
--------------------------------------------------------------------------------
/sample_components/threshold.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Optional
5 |
6 | from canals import component
7 |
8 |
@component
class Threshold:  # pylint: disable=too-few-public-methods
    """
    Redirects the value, unchanged, along a different connection whether the value is above
    or below the given threshold.

    :param threshold: the number to compare the input value against. This is also a parameter.
    """

    def __init__(self, threshold: int = 10):
        """
        :param threshold: the default number to compare the input value against.
        """
        self.threshold = threshold

    @component.output_types(above=int, below=int)
    def run(self, value: int, threshold: Optional[int] = None):
        """
        Sends `value` to `below` when it is strictly smaller than the threshold,
        otherwise to `above` (a value equal to the threshold goes `above`).

        :param value: the number to route.
        :param threshold: overrides the init-time threshold when not None.
        """
        limit = self.threshold if threshold is None else threshold
        branch = "below" if value < limit else "above"
        return {branch: value}
38 |
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
--------------------------------------------------------------------------------
/test/component/test_component.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from typing import Any, Optional
3 |
4 | import pytest
5 |
6 | from canals import component
7 | from canals.component.descriptions import find_component_inputs, find_component_outputs
8 | from canals.errors import ComponentError
9 | from canals.component import InputSocket, OutputSocket, Component
10 |
11 |
def test_correct_declaration():
    """A class exposing run(), to_dict() and from_dict() is accepted and registered."""

    @component
    class MockComponent:
        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    # Creating an instance must not raise
    instance = MockComponent()
    assert instance

    registered_class = component.registry["test_component.MockComponent"]
    assert registered_class == MockComponent
29 |
30 |
def test_correct_declaration_with_additional_readonly_property():
    """A component may define a read-only property next to run()."""

    @component
    class MockComponent:
        @property
        def store(self):
            return "test_store"

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    # Creating an instance must not raise
    first_instance = MockComponent()
    assert first_instance

    registered_class = component.registry["test_component.MockComponent"]
    assert registered_class == MockComponent

    second_instance = MockComponent()
    assert second_instance.store == "test_store"
53 |
54 |
def test_correct_declaration_with_additional_writable_property():
    """A component may define a read/write property next to run(), and the setter works."""

    @component
    class MockComponent:
        @property
        def store(self):
            # Return what the setter stored. The getter used to return a constant, which made
            # the final assertion pass even if the setter did nothing.
            return self._store

        @store.setter
        def store(self, value):
            self._store = value

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(output_value=int)
        def run(self, input_value: int):
            return {"output_value": input_value}

    # The decorated class must be registered and instantiable
    assert component.registry["test_component.MockComponent"] == MockComponent
    comp = MockComponent()
    comp.store = "test_store"
    assert comp.store == "test_store"
82 |
83 |
def test_missing_run():
    """Decorating a class that has no run() method raises ComponentError."""
    # Raw string: "\(" inside a plain string literal is an invalid escape sequence
    # (DeprecationWarning, SyntaxWarning since Python 3.12). The resulting regex is unchanged.
    with pytest.raises(ComponentError, match=r"must have a 'run\(\)' method"):

        @component
        class MockComponent:
            def another_method(self, input_value: int):
                return {"output_value": input_value}
91 |
92 |
def test_set_input_types():
    """set_input_types() called from __init__ declares the instance's input sockets."""

    class MockComponent:
        def __init__(self):
            component.set_input_types(self, value=Any)

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        @component.output_types(value=int)
        def run(self, **kwargs):
            return {"value": 1}

    instance = MockComponent()
    expected_sockets = {"value": InputSocket("value", Any)}
    assert instance.__canals_input__ == expected_sockets

    run_output = instance.run()
    assert run_output == {"value": 1}
112 |
113 |
def test_set_output_types():
    """set_output_types() called from __init__ declares the instance's output sockets."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

        def run(self, value: int):
            return {"value": 1}

    instance = MockComponent()
    expected_sockets = {"value": OutputSocket("value", int)}
    assert instance.__canals_output__ == expected_sockets


def test_output_types_decorator_with_compatible_type():
    """The output_types decorator on run() declares the output sockets."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

    instance = MockComponent()
    expected_sockets = {"value": OutputSocket("value", int)}
    assert instance.__canals_output__ == expected_sockets
150 |
151 |
def test_component_decorator_set_it_as_component():
    """Instances of a decorated class satisfy isinstance(..., Component)."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self, value: int):
            return {"value": 1}

        def to_dict(self):
            return {}

        @classmethod
        def from_dict(cls, data):
            return cls()

    instance = MockComponent()
    is_component = isinstance(instance, Component)
    assert is_component
168 |
169 |
def _mandatory_input(expected_type):
    """Build the description find_component_inputs() returns for a mandatory, non-variadic input."""
    return {"is_mandatory": True, "is_variadic": False, "type": expected_type}


def test_inputs_method_no_inputs():
    """A run() without parameters exposes no inputs."""

    @component
    class MockComponent:
        def run(self):
            return {"value": 1}

    instance = MockComponent()
    assert find_component_inputs(instance) == {}


def test_inputs_method_one_input():
    """A single annotated parameter becomes a single mandatory input."""

    @component
    class MockComponent:
        def run(self, value: int):
            return {"value": 1}

    instance = MockComponent()
    assert find_component_inputs(instance) == {"value": _mandatory_input(int)}


def test_inputs_method_multiple_inputs():
    """Every annotated parameter becomes its own input."""

    @component
    class MockComponent:
        def run(self, value1: int, value2: str):
            return {"value": 1}

    instance = MockComponent()
    expected = {"value1": _mandatory_input(int), "value2": _mandatory_input(str)}
    assert find_component_inputs(instance) == expected


def test_inputs_method_multiple_inputs_optional():
    """An Optional annotation without a default is still a mandatory input."""

    @component
    class MockComponent:
        def run(self, value1: int, value2: Optional[str]):
            return {"value": 1}

    instance = MockComponent()
    expected = {"value1": _mandatory_input(int), "value2": _mandatory_input(typing.Optional[str])}
    assert find_component_inputs(instance) == expected


def test_inputs_method_variadic_positional_args():
    """With run(*args), inputs come from set_input_types() in __init__."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_input_types(self, value=Any)

        def run(self, *args):
            return {"value": 1}

    instance = MockComponent()
    assert find_component_inputs(instance) == {"value": _mandatory_input(typing.Any)}


def test_inputs_method_variadic_keyword_positional_args():
    """With run(**kwargs), inputs come from set_input_types() in __init__."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_input_types(self, value=Any)

        def run(self, **kwargs):
            return {"value": 1}

    instance = MockComponent()
    assert find_component_inputs(instance) == {"value": _mandatory_input(typing.Any)}


def test_inputs_dynamic_from_init():
    """Inputs declared in __init__ are reported alongside run()'s explicit parameter."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_input_types(self, value=int)

        def run(self, value: int, **kwargs):
            return {"value": 1}

    instance = MockComponent()
    assert find_component_inputs(instance) == {"value": _mandatory_input(int)}
253 |
254 |
def test_outputs_method_no_outputs():
    """Without output_types, run() exposes no outputs."""

    @component
    class MockComponent:
        def run(self):
            return {}

    instance = MockComponent()
    assert find_component_outputs(instance) == {}


def test_outputs_method_one_output():
    """A single declared output type is reported."""

    @component
    class MockComponent:
        @component.output_types(value=int)
        def run(self):
            return {"value": 1}

    instance = MockComponent()
    expected = {"value": {"type": int}}
    assert find_component_outputs(instance) == expected


def test_outputs_method_multiple_outputs():
    """Every declared output type is reported."""

    @component
    class MockComponent:
        @component.output_types(value1=int, value2=str)
        def run(self):
            return {"value1": 1, "value2": "test"}

    instance = MockComponent()
    expected = {"value1": {"type": int}, "value2": {"type": str}}
    assert find_component_outputs(instance) == expected


def test_outputs_dynamic_from_init():
    """Outputs declared with set_output_types() in __init__ are reported."""

    @component
    class MockComponent:
        def __init__(self):
            component.set_output_types(self, value=int)

        def run(self):
            return {"value": 1}

    instance = MockComponent()
    expected = {"value": {"type": int}}
    assert find_component_outputs(instance) == expected
298 |
--------------------------------------------------------------------------------
/test/component/test_connection.py:
--------------------------------------------------------------------------------
1 | from canals.component.connection import Connection
2 | from canals.component.sockets import InputSocket, OutputSocket
3 | from canals.errors import PipelineConnectError
4 |
5 | import pytest
6 |
7 |
@pytest.mark.parametrize(
    "c,expected",
    [
        (
            Connection(
                "source_component",
                OutputSocket("out", int),
                "destination_component",
                InputSocket("in", int),
            ),
            "source_component.out (int) --> (int) destination_component.in",
        ),
        (
            Connection(None, None, "destination_component", InputSocket("in", int)),
            "input needed --> (int) destination_component.in",
        ),
        (
            Connection("source_component", OutputSocket("out", int), None, None),
            "source_component.out (int) --> output",
        ),
        (
            Connection(None, None, None, None),
            "input needed --> output",
        ),
    ],
)
def test_repr(c, expected):
    """str() of a Connection describes both ends, with placeholders for missing ones."""
    rendered = str(c)
    assert rendered == expected
25 |
26 |
def test_is_mandatory():
    """is_mandatory follows the receiver socket and is False when there is no receiver."""
    mandatory_receiver = Connection(None, None, "destination_component", InputSocket("in", int))
    assert mandatory_receiver.is_mandatory

    optional_receiver = Connection(
        None, None, "destination_component", InputSocket("in", int, is_mandatory=False)
    )
    assert not optional_receiver.is_mandatory

    no_receiver = Connection("source_component", OutputSocket("out", int), None, None)
    assert not no_receiver.is_mandatory
36 |
37 |
def test_from_list_of_sockets():
    """When exactly one sender/receiver pair has matching types, that pair is chosen."""
    senders = [OutputSocket("out_int", int), OutputSocket("out_str", str)]
    receivers = [InputSocket("in_str", str)]

    connection = Connection.from_list_of_sockets("from_node", senders, "to_node", receivers)
    assert connection.sender_socket.name == "out_str"  # type:ignore


def test_from_list_of_sockets_not_possible():
    """No type-compatible pair among several sockets raises PipelineConnectError."""
    senders = [OutputSocket("out_int", int), OutputSocket("out_str", str)]
    receivers = [InputSocket("in_list", list), InputSocket("in_tuple", tuple)]

    with pytest.raises(PipelineConnectError, match="no matching connections available"):
        Connection.from_list_of_sockets("from_node", senders, "to_node", receivers)


def test_from_list_of_sockets_too_many():
    """More than one type-compatible pair is ambiguous and raises PipelineConnectError."""
    senders = [OutputSocket("out_int", int), OutputSocket("out_str", str)]
    receivers = [InputSocket("in_int", int), InputSocket("in_str", str)]

    with pytest.raises(PipelineConnectError, match="more than one connection is possible"):
        Connection.from_list_of_sockets("from_node", senders, "to_node", receivers)


def test_from_list_of_sockets_only_one():
    """A single sender and a single receiver with mismatching types raise PipelineConnectError."""
    senders = [OutputSocket("out_int", int)]
    receivers = [InputSocket("in_str", str)]

    with pytest.raises(PipelineConnectError, match="their declared input and output types do not match"):
        Connection.from_list_of_sockets("from_node", senders, "to_node", receivers)
93 |
--------------------------------------------------------------------------------
/test/conftest.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import pytest
4 |
5 | from unittest.mock import patch, MagicMock
6 |
7 |
@pytest.fixture
def test_files():
    """Path of the ``test_files`` directory that sits next to this conftest module."""
    conftest_dir = Path(__file__).parent
    return conftest_dir / "test_files"
11 |
12 |
@pytest.fixture(autouse=True)
def mock_mermaid_request(test_files):
    """
    Prevents real requests to https://mermaid.ink/

    Patches ``requests.get`` as seen by ``canals.pipeline.draw.mermaid`` so that it returns a
    response with status 200 whose content is the bytes of
    ``test_files/mermaid_mock/test_response.png``.
    """
    with patch("canals.pipeline.draw.mermaid.requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.status_code = 200
        # read_bytes() opens and closes the file itself. A bare open(...).read() leaked one
        # file handle per test, because this fixture is autouse.
        mock_response.content = (test_files / "mermaid_mock" / "test_response.png").read_bytes()
        mock_get.return_value = mock_response
        yield
24 |
--------------------------------------------------------------------------------
/test/pipeline/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/test/pipeline/__init__.py
--------------------------------------------------------------------------------
/test/pipeline/integration/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_complex_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 | import logging
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import (
10 | Accumulate,
11 | AddFixedValue,
12 | Greet,
13 | Parity,
14 | Threshold,
15 | Double,
16 | Sum,
17 | Repeat,
18 | Subtract,
19 | MergeLoop,
20 | )
21 |
logging.basicConfig(level=logging.DEBUG)


def test_complex_pipeline(tmp_path):
    """
    Builds a large pipeline with decisions (Parity, Threshold), a loop (MergeLoop <-> Double),
    fan-outs (Repeat) and merges (Sum, Subtract), then checks its leaf outputs.
    """
    merger = MergeLoop(expected_type=int, inputs=["in_1", "in_2"])
    adder = Sum()

    pipe = Pipeline(max_loops_allowed=2)
    pipe.add_component("greet_first", Greet(message="Hello, the value is {value}."))
    pipe.add_component("accumulate_1", Accumulate())
    pipe.add_component("add_two", AddFixedValue(add=2))
    pipe.add_component("parity", Parity())
    pipe.add_component("add_one", AddFixedValue(add=1))
    pipe.add_component("accumulate_2", Accumulate())

    pipe.add_component("loop_merger", merger)
    pipe.add_component("below_10", Threshold(threshold=10))
    pipe.add_component("double", Double())

    pipe.add_component("greet_again", Greet(message="Hello again, now the value is {value}."))
    pipe.add_component("sum", adder)

    pipe.add_component("greet_enumerator", Greet(message="Hello from enumerator, here the value became {value}."))
    pipe.add_component("enumerate", Repeat(outputs=["first", "second"]))
    pipe.add_component("add_three", AddFixedValue(add=3))

    pipe.add_component("diff", Subtract())
    pipe.add_component("greet_one_last_time", Greet(message="Bye bye! The value here is {value}!"))
    pipe.add_component("replicate", Repeat(outputs=["first", "second"]))
    pipe.add_component("add_five", AddFixedValue(add=5))
    pipe.add_component("add_four", AddFixedValue(add=4))
    pipe.add_component("accumulate_3", Accumulate())

    # (sender, receiver) pairs, connected in this exact order
    edges = [
        ("greet_first", "accumulate_1"),
        ("accumulate_1", "add_two"),
        ("add_two", "parity"),
        # even branch
        ("parity.even", "greet_again"),
        ("greet_again", "sum.values"),
        ("sum", "diff.first_value"),
        ("diff", "greet_one_last_time"),
        ("greet_one_last_time", "replicate"),
        ("replicate.first", "add_five.value"),
        ("replicate.second", "add_four.value"),
        ("add_four", "accumulate_3"),
        # odd branch, entering the loop
        ("parity.odd", "add_one.value"),
        ("add_one", "loop_merger.in_1"),
        ("loop_merger", "below_10"),
        # loop body
        ("below_10.below", "double"),
        ("double", "loop_merger.in_2"),
        # loop exit
        ("below_10.above", "accumulate_2"),
        ("accumulate_2", "diff.second_value"),
        # enumerator branch
        ("greet_enumerator", "enumerate"),
        ("enumerate.second", "sum.values"),
        ("enumerate.first", "add_three.value"),
        ("add_three", "sum.values"),
    ]
    for sender, receiver in edges:
        pipe.connect(sender, receiver)

    pipe.draw(tmp_path / "complex_pipeline.png")

    outcome = pipe.run({"greet_first": {"value": 1}, "greet_enumerator": {"value": 1}})
    assert outcome == {"accumulate_3": {"value": -7}, "add_five": {"result": -6}}


if __name__ == "__main__":
    test_complex_pipeline(Path(__file__).parent)
92 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_default_value.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals import Pipeline, component
8 | from sample_components import AddFixedValue, Sum
9 |
10 | import logging
11 |
logging.basicConfig(level=logging.DEBUG)


@component
class WithDefault:
    """
    Adds two integers, where the second one has a default value.

    Used to check that a pipeline falls back to a ``run()`` default when an input is not given.
    """

    # The declared output must match the key run() actually returns ("c"). It was declared as
    # "b", which is the name of an input and is never returned.
    @component.output_types(c=int)
    def run(self, a: int, b: int = 2):
        return {"c": a + b}
20 |
21 |
def test_pipeline(tmp_path):
    """
    Regression test for https://github.com/deepset-ai/canals/issues/105: when an input is not
    passed at run time, the default declared in ``run()`` is used.
    """
    pipe = Pipeline()
    pipe.add_component("with_defaults", WithDefault())
    pipe.draw(tmp_path / "default_value.png")

    scenarios = [
        # Pass all the inputs
        ({"a": 40, "b": 30}, 70),
        # Rely on default value for 'b'
        ({"a": 40}, 42),
    ]
    for component_inputs, expected_sum in scenarios:
        results = pipe.run({"with_defaults": component_inputs})
        pprint(results)
        assert results == {"with_defaults": {"c": expected_sum}}


if __name__ == "__main__":
    test_pipeline(Path(__file__).parent)
41 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_distinct_loops_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import *
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import AddFixedValue, MergeLoop, Remainder, FirstIntSelector
10 |
11 | import logging
12 |
logging.basicConfig(level=logging.DEBUG)


def test_pipeline_equally_long_branches(tmp_path):
    """
    Two single-step loops (+1 and +2) feed back into MergeLoop until the value is divisible by 3.
    """
    pipe = Pipeline(max_loops_allowed=10)
    pipe.add_component("merge", MergeLoop(expected_type=int, inputs=["in", "in_1", "in_2"]))
    pipe.add_component("remainder", Remainder(divisor=3))
    pipe.add_component("add_one", AddFixedValue(add=1))
    pipe.add_component("add_two", AddFixedValue(add=2))

    pipe.connect("merge.value", "remainder.value")
    pipe.connect("remainder.remainder_is_1", "add_two.value")
    pipe.connect("remainder.remainder_is_2", "add_one.value")
    pipe.connect("add_two", "merge.in_2")
    pipe.connect("add_one", "merge.in_1")

    pipe.draw(tmp_path / "distinct_loops_pipeline_same_branches.png")

    # (initial value, final value once divisible by 3)
    for start, final in [(0, 0), (3, 3), (4, 6), (5, 6), (6, 6)]:
        results = pipe.run({"merge": {"in": start}})
        pprint(results)
        assert results == {"remainder": {"remainder_is_0": final}}
50 |
51 |
def test_pipeline_differing_branches(tmp_path):
    """
    Like the equally-long case, but the +2 loop is made of two chained +1 components.
    """
    pipe = Pipeline(max_loops_allowed=10)
    pipe.add_component("merge", MergeLoop(expected_type=int, inputs=["in", "in_1", "in_2"]))
    pipe.add_component("remainder", Remainder(divisor=3))
    pipe.add_component("add_one", AddFixedValue(add=1))
    pipe.add_component("add_two_1", AddFixedValue(add=1))
    pipe.add_component("add_two_2", AddFixedValue(add=1))

    pipe.connect("merge.value", "remainder.value")
    pipe.connect("remainder.remainder_is_1", "add_two_1.value")
    pipe.connect("add_two_1", "add_two_2.value")
    pipe.connect("add_two_2", "merge.in_2")
    pipe.connect("remainder.remainder_is_2", "add_one.value")
    pipe.connect("add_one", "merge.in_1")

    pipe.draw(tmp_path / "distinct_loops_pipeline_different_branches.png")

    # (initial value, final value once divisible by 3)
    for start, final in [(0, 0), (3, 3), (4, 6), (5, 6), (6, 6)]:
        results = pipe.run({"merge": {"in": start}})
        pprint(results)
        assert results == {"remainder": {"remainder_is_0": final}}
88 |
89 |
def test_pipeline_differing_branches_variadic(tmp_path):
    """
    Same graph as the differing-branches case, but loops merge through the variadic
    ``inputs`` socket of FirstIntSelector instead of MergeLoop.
    """
    pipe = Pipeline(max_loops_allowed=10)
    pipe.add_component("merge", FirstIntSelector())
    pipe.add_component("remainder", Remainder(divisor=3))
    pipe.add_component("add_one", AddFixedValue(add=1))
    pipe.add_component("add_two_1", AddFixedValue(add=1))
    pipe.add_component("add_two_2", AddFixedValue(add=1))

    pipe.connect("merge", "remainder.value")
    pipe.connect("remainder.remainder_is_1", "add_two_1.value")
    pipe.connect("add_two_1", "add_two_2.value")
    pipe.connect("add_two_2", "merge.inputs")
    pipe.connect("remainder.remainder_is_2", "add_one.value")
    pipe.connect("add_one", "merge.inputs")

    pipe.draw(tmp_path / "distinct_loops_pipeline_different_branches_variadic.png")

    # (initial value, final value once divisible by 3)
    for start, final in [(0, 0), (3, 3), (4, 6), (5, 6), (6, 6)]:
        results = pipe.run({"merge": {"inputs": start}})
        pprint(results)
        assert results == {"remainder": {"remainder_is_0": final}}


if __name__ == "__main__":
    here = Path(__file__).parent
    test_pipeline_equally_long_branches(here)
    test_pipeline_differing_branches(here)
    test_pipeline_differing_branches_variadic(here)
132 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_double_loop_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import *
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop
10 |
11 | import logging
12 |
logging.basicConfig(level=logging.DEBUG)


def test_pipeline(tmp_path):
    """
    Two nested loops share one MergeLoop: an inner one gated by ``below_5`` and an outer one
    gated by ``below_10``. Checks the final output and the accumulator's state.
    """
    acc = Accumulate()

    pipe = Pipeline(max_loops_allowed=10)
    pipe.add_component("add_one", AddFixedValue(add=1))
    pipe.add_component("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2", "in_3"]))
    pipe.add_component("below_10", Threshold(threshold=10))
    pipe.add_component("below_5", Threshold(threshold=5))
    pipe.add_component("add_three", AddFixedValue(add=3))
    pipe.add_component("accumulator", acc)
    pipe.add_component("add_two", AddFixedValue(add=2))

    for sender, receiver in [
        ("add_one.result", "merge.in_1"),
        ("merge.value", "below_10.value"),
        ("below_10.below", "accumulator.value"),
        ("accumulator.value", "below_5.value"),
        ("below_5.above", "add_three.value"),
        ("below_5.below", "merge.in_2"),
        ("add_three.result", "merge.in_3"),
        ("below_10.above", "add_two.value"),
    ]:
        pipe.connect(sender, receiver)

    pipe.draw(tmp_path / "double_loop_pipeline.png")

    results = pipe.run({"add_one": {"value": 3}})
    pprint(results)
    print("accumulator: ", acc.state)

    assert results == {"add_two": {"result": 13}}
    assert acc.state == 8


if __name__ == "__main__":
    test_pipeline(Path(__file__).parent)
49 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_dynamic_inputs_pipeline.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from canals import Pipeline
3 | from sample_components import FString, Hello, TextSplitter
4 |
5 |
def test_pipeline(tmp_path):
    """
    FString declares its ``greeting`` input through ``variables``, and its template can be
    overridden at run time.
    """
    pipe = Pipeline()
    pipe.add_component("hello", Hello())
    pipe.add_component("fstring", FString(template="This is the greeting: {greeting}!", variables=["greeting"]))
    pipe.add_component("splitter", TextSplitter())
    pipe.connect("hello.output", "fstring.greeting")
    pipe.connect("fstring.string", "splitter.sentence")

    pipe.draw(tmp_path / "dynamic_inputs_pipeline.png")

    with_init_template = pipe.run({"hello": {"word": "Alice"}})
    expected_words = ["This", "is", "the", "greeting:", "Hello,", "Alice!!"]
    assert with_init_template == {"splitter": {"output": expected_words}}

    with_runtime_template = pipe.run({"hello": {"word": "Alice"}, "fstring": {"template": "Received: {greeting}"}})
    expected_words = ["Received:", "Hello,", "Alice!"]
    assert with_runtime_template == {"splitter": {"output": expected_words}}


if __name__ == "__main__":
    test_pipeline(Path(__file__).parent)
25 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_fixed_decision_and_merge_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import AddFixedValue, Parity, Double, Subtract
10 |
logging.basicConfig(level=logging.DEBUG)


def test_pipeline(tmp_path):
    """
    A Parity decision whose even branch fans out to two adders and whose odd branch merges
    back through Subtract. Each run checks the leaf outputs of the branch that was taken.
    """
    pipeline = Pipeline()
    pipeline.add_component("add_one", AddFixedValue())
    pipeline.add_component("parity", Parity())
    pipeline.add_component("add_ten", AddFixedValue(add=10))
    pipeline.add_component("double", Double())
    pipeline.add_component("add_four", AddFixedValue(add=4))
    pipeline.add_component("add_two", AddFixedValue())
    pipeline.add_component("add_two_as_well", AddFixedValue())
    pipeline.add_component("diff", Subtract())

    pipeline.connect("add_one.result", "parity.value")
    pipeline.connect("parity.even", "add_four.value")
    pipeline.connect("parity.odd", "double.value")
    pipeline.connect("add_ten.result", "diff.first_value")
    pipeline.connect("double.value", "diff.second_value")
    pipeline.connect("parity.odd", "add_ten.value")
    pipeline.connect("add_four.result", "add_two.value")
    pipeline.connect("add_four.result", "add_two_as_well.value")

    pipeline.draw(tmp_path / "fixed_decision_and_merge_pipeline.png")

    results = pipeline.run(
        {
            "add_one": {"value": 1},
            "add_two": {"add": 2},
            "add_two_as_well": {"add": 2},
        }
    )
    pprint(results)

    # 1 + 1 = 2 is even: 2 + 4 = 6, then + 2 = 8 on BOTH leaf adders. The comparison used to be
    # a bare expression without `assert`, and it omitted add_two_as_well.
    assert results == {
        "add_two": {"result": 8},
        "add_two_as_well": {"result": 8},
    }

    results = pipeline.run(
        {
            "add_one": {"value": 2},
            "add_two": {"add": 2},
            "add_two_as_well": {"add": 2},
        }
    )
    pprint(results)

    # 2 + 1 = 3 is odd: add_ten gives 13, double gives 6, and diff gives 13 - 6 = 7.
    assert results == {
        "diff": {"difference": 7},
    }


if __name__ == "__main__":
    test_pipeline(Path(__file__).parent)
65 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_fixed_decision_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import AddFixedValue, Parity, Double
9 |
10 | import logging
11 |
logging.basicConfig(level=logging.DEBUG)


def test_pipeline(tmp_path):
    """A Parity decision routes even values through +10 and +3, and odd values through Double."""
    pipe = Pipeline()
    pipe.add_component("add_one", AddFixedValue(add=1))
    pipe.add_component("parity", Parity())
    pipe.add_component("add_ten", AddFixedValue(add=10))
    pipe.add_component("double", Double())
    pipe.add_component("add_three", AddFixedValue(add=3))

    pipe.connect("add_one.result", "parity.value")
    pipe.connect("parity.even", "add_ten.value")
    pipe.connect("parity.odd", "double.value")
    pipe.connect("add_ten.result", "add_three.value")

    pipe.draw(tmp_path / "fixed_decision_pipeline.png")

    expectations = [
        (1, {"add_three": {"result": 15}}),
        (2, {"double": {"value": 6}}),
    ]
    for start, expected in expectations:
        results = pipe.run({"add_one": {"value": start}})
        pprint(results)
        assert results == expected


if __name__ == "__main__":
    test_pipeline(Path(__file__).parent)
41 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_fixed_merging_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import AddFixedValue, Subtract
9 |
10 | import logging
11 |
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 |
def test_pipeline(tmp_path):
    """Two addition chains merge into a fixed-input Subtract component.

    first_addition -> second_addition -> diff.first_value
    third_addition                    -> diff.second_value
    diff -> fourth_addition
    """
    components = {
        "first_addition": AddFixedValue(add=2),
        "second_addition": AddFixedValue(add=2),
        "third_addition": AddFixedValue(add=2),
        "diff": Subtract(),
        "fourth_addition": AddFixedValue(add=1),
    }
    pipeline = Pipeline()
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in [
        ("first_addition.result", "second_addition.value"),
        ("second_addition.result", "diff.first_value"),
        ("third_addition.result", "diff.second_value"),
        ("diff", "fourth_addition.value"),
    ]:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "fixed_merging_pipeline.png")

    run_data = {
        "first_addition": {"value": 1},
        "third_addition": {"value": 1},
    }
    results = pipeline.run(run_data)
    pprint(results)

    assert results == {"fourth_addition": {"result": 3}}
39 |
40 |
if __name__ == "__main__":
    # Running the module directly writes the drawn pipeline next to this file.
    output_dir = Path(__file__).parent
    test_pipeline(output_dir)
43 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_joiners.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import StringJoiner, StringListJoiner, Hello, TextSplitter
9 |
10 | import logging
11 |
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 |
def test_joiner(tmp_path):
    """StringJoiner combines the greetings coming from two separate branches."""
    pipeline = Pipeline()
    for hello_name in ("hello_one", "hello_two", "hello_three"):
        pipeline.add_component(hello_name, Hello())
    pipeline.add_component("joiner", StringJoiner())

    # hello_one -> hello_two -> joiner <- hello_three
    pipeline.connect("hello_one", "hello_two")
    pipeline.connect("hello_two", "joiner")
    pipeline.connect("hello_three", "joiner")

    pipeline.draw(tmp_path / "joiner_pipeline.png")

    run_data = {"hello_one": {"word": "world"}, "hello_three": {"word": "my friend"}}
    expected = {"joiner": {"output": "Hello, my friend! Hello, Hello, world!!"}}
    assert pipeline.run(run_data) == expected
30 |
31 |
def test_joiner_with_lists(tmp_path):
    """StringListJoiner concatenates the word lists produced by two TextSplitters."""
    splitter_names = ["first", "second"]

    pipeline = Pipeline()
    for splitter in splitter_names:
        pipeline.add_component(splitter, TextSplitter())
    pipeline.add_component("joiner", StringListJoiner())
    for splitter in splitter_names:
        pipeline.connect(splitter, "joiner")

    pipeline.draw(tmp_path / "joiner_list_pipeline.png")

    sentences = {"first": {"sentence": "Hello world!"}, "second": {"sentence": "How are you?"}}
    results = pipeline.run(sentences)
    assert results == {"joiner": {"output": ["Hello", "world!", "How", "are", "you?"]}}
45 |
46 |
def test_joiner_with_pipeline_run(tmp_path):
    """A joiner accepts a connected sender's output plus a value passed directly to run()."""
    pipeline = Pipeline()
    pipeline.add_component("hello", Hello())
    pipeline.add_component("joiner", StringJoiner())
    pipeline.connect("hello", "joiner")

    pipeline.draw(tmp_path / "joiner_with_pipeline_run.png")

    run_data = {
        "hello": {"word": "world"},
        # Handed straight to the joiner, alongside whatever "hello" produces.
        "joiner": {"input_str": "another string!"},
    }
    results = pipeline.run(run_data)
    assert results == {"joiner": {"output": "another string! Hello, world!"}}
57 |
58 |
if __name__ == "__main__":
    # Run every joiner test by hand; drawings land next to this file.
    output_dir = Path(__file__).parent
    for joiner_test in (test_joiner, test_joiner_with_lists, test_joiner_with_pipeline_run):
        joiner_test(output_dir)
63 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_linear_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import AddFixedValue, Double
9 |
10 | import logging
11 |
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 |
def test_pipeline(tmp_path):
    """Straight chain: first_addition -> double -> second_addition.

    With input 1 the chain yields (1 + 2) * 2 + 1 == 7.
    """
    pipeline = Pipeline()
    pipeline.add_component("first_addition", AddFixedValue(add=2))
    pipeline.add_component("second_addition", AddFixedValue())
    pipeline.add_component("double", Double())
    for sender, receiver in (("first_addition", "double"), ("double", "second_addition")):
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "linear_pipeline.png")

    results = pipeline.run({"first_addition": {"value": 1}})
    pprint(results)

    expected = {"second_addition": {"result": 7}}
    assert results == expected
29 |
30 |
if __name__ == "__main__":
    # Running the module directly writes the drawn pipeline next to this file.
    output_dir = Path(__file__).parent
    test_pipeline(output_dir)
33 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_looping_and_merge_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import *
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import Accumulate, AddFixedValue, Threshold, Sum, FirstIntSelector, MergeLoop
10 |
11 | import logging
12 |
13 | logging.basicConfig(level=logging.DEBUG)
14 |
15 |
def test_pipeline_fixed(tmp_path):
    """Loop through a fixed-input MergeLoop until the value clears the threshold.

    Values below 10 go through add_one into the counter, which feeds back into
    the merge. The first value at or above 10 leaves via add_two and is summed
    with the extra value given to "sum" in run().
    """
    accumulator = Accumulate()
    pipeline = Pipeline(max_loops_allowed=10)
    for name, instance in (
        ("add_zero", AddFixedValue(add=0)),
        ("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2"])),
        ("sum", Sum()),
        ("below_10", Threshold(threshold=10)),
        ("add_one", AddFixedValue(add=1)),
        ("counter", accumulator),
        ("add_two", AddFixedValue(add=2)),
    ):
        pipeline.add_component(name, instance)

    edges = [
        ("add_zero", "merge.in_1"),
        ("merge", "below_10.value"),
        ("below_10.below", "add_one.value"),
        ("add_one.result", "counter.value"),
        ("counter.value", "merge.in_2"),
        ("below_10.above", "add_two.value"),
        ("add_two.result", "sum.values"),
    ]
    for sender, receiver in edges:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "looping_and_fixed_merge_pipeline.png")

    results = pipeline.run(
        {"add_zero": {"value": 8}, "sum": {"values": 2}},
    )
    pprint(results)
    print("accumulate: ", accumulator.state)

    assert results == {"sum": {"total": 23}}
    assert accumulator.state == 19
45 |
46 |
def test_pipeline_variadic(tmp_path):
    """Same loop as test_pipeline_fixed, but merging through a variadic FirstIntSelector."""
    accumulator = Accumulate()
    pipeline = Pipeline(max_loops_allowed=10)
    for name, instance in (
        ("add_zero", AddFixedValue(add=0)),
        ("merge", FirstIntSelector()),
        ("sum", Sum()),
        ("below_10", Threshold(threshold=10)),
        ("add_one", AddFixedValue(add=1)),
        ("counter", accumulator),
        ("add_two", AddFixedValue(add=2)),
    ):
        pipeline.add_component(name, instance)

    edges = [
        ("add_zero", "merge"),
        ("merge", "below_10.value"),
        ("below_10.below", "add_one.value"),
        ("add_one.result", "counter.value"),
        # The loop re-enters through the selector's variadic "inputs" socket.
        ("counter.value", "merge.inputs"),
        ("below_10.above", "add_two.value"),
        ("add_two.result", "sum.values"),
    ]
    for sender, receiver in edges:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "looping_and_variadic_merge_pipeline.png")

    results = pipeline.run(
        {"add_zero": {"value": 8}, "sum": {"values": 2}},
    )
    pprint(results)
    print("accumulate: ", accumulator.state)

    assert results == {"sum": {"total": 23}}
    assert accumulator.state == 19
76 |
77 |
if __name__ == "__main__":
    # Run both merge variants by hand; drawings land next to this file.
    output_dir = Path(__file__).parent
    for merge_test in (test_pipeline_fixed, test_pipeline_variadic):
        merge_test(output_dir)
81 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_looping_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import *
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop, FirstIntSelector
10 |
11 | import logging
12 |
13 | logging.basicConfig(level=logging.DEBUG)
14 |
15 |
def test_pipeline(tmp_path):
    """Accumulate values through a MergeLoop until the Threshold lets them out.

    add_one turns 3 into 4, which enters the loop. The accumulator's output
    feeds back into the merge until it is no longer below 10, then add_two
    adds 2 to it.
    """
    accumulator = Accumulate()

    pipeline = Pipeline(max_loops_allowed=10)
    components = {
        "add_one": AddFixedValue(add=1),
        "merge": MergeLoop(expected_type=int, inputs=["in_1", "in_2"]),
        "below_10": Threshold(threshold=10),
        "accumulator": accumulator,
        "add_two": AddFixedValue(add=2),
    }
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    connections = [
        ("add_one.result", "merge.in_1"),
        ("merge.value", "below_10.value"),
        ("below_10.below", "accumulator.value"),
        ("accumulator.value", "merge.in_2"),
        ("below_10.above", "add_two.value"),
    ]
    for sender, receiver in connections:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "looping_pipeline.png")

    results = pipeline.run({"add_one": {"value": 3}})
    pprint(results)

    assert results == {"add_two": {"result": 18}}
    assert accumulator.state == 16
39 |
40 |
def test_pipeline_direct_io_loop(tmp_path):
    """The loop's entry and exit are exposed directly as pipeline input and output.

    The caller feeds merge.in_1. below_10.above has no receiver, so its value is
    what the pipeline returns.
    """
    accumulator = Accumulate()

    pipeline = Pipeline(max_loops_allowed=10)
    components = {
        "merge": MergeLoop(expected_type=int, inputs=["in_1", "in_2"]),
        "below_10": Threshold(threshold=10),
        "accumulator": accumulator,
    }
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in (
        ("merge.value", "below_10.value"),
        ("below_10.below", "accumulator.value"),
        ("accumulator.value", "merge.in_2"),
    ):
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "looping_pipeline_direct_io_loop.png")

    results = pipeline.run({"merge": {"in_1": 4}})
    pprint(results)

    assert results == {"below_10": {"above": 16}}
    assert accumulator.state == 16
60 |
61 |
def test_pipeline_fixed_merger_input(tmp_path):
    """The caller feeds the MergeLoop's fixed in_1 socket directly; add_two collects the exit."""
    accumulator = Accumulate()

    pipeline = Pipeline(max_loops_allowed=10)
    components = {
        "merge": MergeLoop(expected_type=int, inputs=["in_1", "in_2"]),
        "below_10": Threshold(threshold=10),
        "accumulator": accumulator,
        "add_two": AddFixedValue(add=2),
    }
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in (
        ("merge.value", "below_10.value"),
        ("below_10.below", "accumulator.value"),
        ("accumulator.value", "merge.in_2"),
        ("below_10.above", "add_two.value"),
    ):
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "looping_pipeline_fixed_merger_input.png")

    results = pipeline.run({"merge": {"in_1": 4}})
    pprint(results)
    print("accumulator: ", accumulator.state)

    assert results == {"add_two": {"result": 18}}
    assert accumulator.state == 16
84 |
85 |
def test_pipeline_variadic_merger_input(tmp_path):
    """The caller's value and the loop-back share FirstIntSelector's variadic "inputs" socket."""
    accumulator = Accumulate()

    pipeline = Pipeline(max_loops_allowed=10)
    components = {
        "merge": FirstIntSelector(),
        "below_10": Threshold(threshold=10),
        "accumulator": accumulator,
        "add_two": AddFixedValue(add=2),
    }
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in (
        ("merge", "below_10.value"),
        ("below_10.below", "accumulator.value"),
        ("accumulator.value", "merge.inputs"),
        ("below_10.above", "add_two.value"),
    ):
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "looping_pipeline_variadic_merger_input.png")

    results = pipeline.run({"merge": {"inputs": 4}})
    pprint(results)

    assert results == {"add_two": {"result": 18}}
    assert accumulator.state == 16
107 |
108 |
if __name__ == "__main__":
    # Run every looping test by hand; drawings land next to this file.
    output_dir = Path(__file__).parent
    for looping_test in (
        test_pipeline,
        test_pipeline_direct_io_loop,
        test_pipeline_fixed_merger_input,
        test_pipeline_variadic_merger_input,
    ):
        looping_test(output_dir)
114 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_mutable_inputs.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | from canals import Pipeline, component
4 | from sample_components import StringListJoiner
5 |
6 |
@component
class InputMangler:
    """Test component that deliberately mutates the list it receives.

    It appends "extra_item" in place and returns that same list object.
    """

    @component.output_types(mangled_list=List[str])
    def run(self, input_list: List[str]):
        # In-place on purpose: list += mutates the caller's object, it does not copy.
        input_list += ["extra_item"]
        return {"mangled_list": input_list}
13 |
14 |
def test_mutable_inputs():
    """Two components handed the same list object must not see each other's mutations."""
    manglers = ("mangler1", "mangler2")
    joiners = ("concat1", "concat2")

    pipe = Pipeline()
    for name in manglers:
        pipe.add_component(name, InputMangler())
    for name in joiners:
        pipe.add_component(name, StringListJoiner())
    for sender, receiver in zip(manglers, joiners):
        pipe.connect(sender, receiver)

    # The very same list object is passed to both manglers.
    shared = ["foo", "bar"]
    result = pipe.run(data={"mangler1": {"input_list": shared}, "mangler2": {"input_list": shared}})

    # If both manglers appended to one shared object, "extra_item" would show up twice.
    expected = ["foo", "bar", "extra_item"]
    assert result["concat1"]["output"] == expected
    assert result["concat2"]["output"] == expected
28 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_parallel_branches_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import AddFixedValue, Repeat, Double
9 |
10 | import logging
11 |
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 |
def test_pipeline(tmp_path):
    """Repeat fans one value out to parallel branches, each of which reports its own result.

    repeat.first  -> add_ten
    repeat.second -> double
    repeat.second -> add_three -> add_one_again
    """
    components = {
        "add_one": AddFixedValue(add=1),
        "repeat": Repeat(outputs=["first", "second"]),
        "add_ten": AddFixedValue(add=10),
        "double": Double(),
        "add_three": AddFixedValue(add=3),
        "add_one_again": AddFixedValue(add=1),
    }
    pipeline = Pipeline()
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in [
        ("add_one.result", "repeat.value"),
        ("repeat.first", "add_ten.value"),
        ("repeat.second", "double.value"),
        ("repeat.second", "add_three.value"),
        ("add_three.result", "add_one_again.value"),
    ]:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "parallel_branches_pipeline.png")

    results = pipeline.run({"add_one": {"value": 1}})
    pprint(results)

    expected = {
        "add_one_again": {"result": 6},
        "add_ten": {"result": 12},
        "double": {"value": 4},
    }
    assert results == expected
40 |
41 |
if __name__ == "__main__":
    # Running the module directly writes the drawn pipeline next to this file.
    output_dir = Path(__file__).parent
    test_pipeline(output_dir)
44 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_self_loop.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Optional
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals import component
9 | from canals.pipeline import Pipeline
10 | from sample_components import AddFixedValue, SelfLoop
11 |
12 | import logging
13 |
14 | logging.basicConfig(level=logging.DEBUG)
15 |
16 |
def test_pipeline_one_node(tmp_path):
    """A lone SelfLoop component whose current_value output feeds its own values input."""
    pipeline = Pipeline(max_loops_allowed=10)
    pipeline.add_component("self_loop", SelfLoop())
    pipeline.connect("self_loop.current_value", "self_loop.values")

    pipeline.draw(tmp_path / "self_looping_pipeline_one_node.png")

    results = pipeline.run({"self_loop": {"values": 5}})
    pprint(results)

    final_result = results["self_loop"]["final_result"]
    assert final_result == 0
28 |
29 |
def test_pipeline(tmp_path):
    """A SelfLoop sandwiched between two AddFixedValue components.

    add_1 -> self_loop (looping on itself) -> add_2
    """
    pipeline = Pipeline(max_loops_allowed=10)
    for name, instance in (
        ("add_1", AddFixedValue()),
        ("self_loop", SelfLoop()),
        ("add_2", AddFixedValue()),
    ):
        pipeline.add_component(name, instance)

    for sender, receiver in (
        ("add_1", "self_loop.values"),
        ("self_loop.current_value", "self_loop.values"),
        ("self_loop.final_result", "add_2.value"),
    ):
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "self_looping_pipeline.png")

    results = pipeline.run({"add_1": {"value": 5}})
    pprint(results)

    final_sum = results["add_2"]["result"]
    assert final_sum == 1
45 |
46 |
if __name__ == "__main__":
    # Run both self-loop tests by hand; drawings land next to this file.
    output_dir = Path(__file__).parent
    for self_loop_test in (test_pipeline_one_node, test_pipeline):
        self_loop_test(output_dir)
50 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_variable_decision_and_merge_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 | from pathlib import Path
6 | from pprint import pprint
7 |
8 | from canals.pipeline import Pipeline
9 | from sample_components import AddFixedValue, Remainder, Double, Sum
10 |
11 | logging.basicConfig(level=logging.DEBUG)
12 |
13 |
def test_pipeline(tmp_path):
    """Remainder routes to different branches, and every branch's output merges into Sum.

    Sum always receives add_one's output. It also receives add_ten's output
    when the remainder is 0, or double's and add_one_again's outputs when it is 1.
    """
    components = {
        "add_one": AddFixedValue(),
        "parity": Remainder(divisor=2),
        "add_ten": AddFixedValue(add=10),
        "double": Double(),
        "add_four": AddFixedValue(add=4),
        "add_one_again": AddFixedValue(),
        "sum": Sum(),
    }
    pipeline = Pipeline()
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    wiring = [
        ("add_one.result", "parity.value"),
        ("parity.remainder_is_0", "add_ten.value"),
        ("parity.remainder_is_1", "double.value"),
        ("add_one.result", "sum.values"),
        ("add_ten.result", "sum.values"),
        ("double.value", "sum.values"),
        ("parity.remainder_is_1", "add_four.value"),
        ("add_four.result", "add_one_again.value"),
        ("add_one_again.result", "sum.values"),
    ]
    for sender, receiver in wiring:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "variable_decision_and_merge_pipeline.png")

    # The same pipeline instance is run twice, once down each parity branch.
    for start_value, expected_total in ((1, 14), (2, 17)):
        results = pipeline.run({"add_one": {"value": start_value}})
        pprint(results)
        assert results == {"sum": {"total": expected_total}}
43 |
44 |
if __name__ == "__main__":
    # Running the module directly writes the drawn pipeline next to this file.
    output_dir = Path(__file__).parent
    test_pipeline(output_dir)
47 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_variable_decision_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import AddFixedValue, Remainder, Double
9 |
10 | import logging
11 |
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 |
def test_pipeline(tmp_path):
    """Remainder(divisor=3) sends the value down exactly one of three branches.

    Input 1 becomes 2, whose remainder is 2, so only add_three -> add_one_again runs.
    """
    components = {
        "add_one": AddFixedValue(add=1),
        "remainder": Remainder(divisor=3),
        "add_ten": AddFixedValue(add=10),
        "double": Double(),
        "add_three": AddFixedValue(add=3),
        "add_one_again": AddFixedValue(add=1),
    }
    pipeline = Pipeline()
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in [
        ("add_one.result", "remainder.value"),
        ("remainder.remainder_is_0", "add_ten.value"),
        ("remainder.remainder_is_1", "double.value"),
        ("remainder.remainder_is_2", "add_three.value"),
        ("add_three.result", "add_one_again.value"),
    ]:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "variable_decision_pipeline.png")

    results = pipeline.run({"add_one": {"value": 1}})
    pprint(results)

    assert results == {"add_one_again": {"result": 6}}
36 |
37 |
if __name__ == "__main__":
    # Running the module directly writes the drawn pipeline next to this file.
    output_dir = Path(__file__).parent
    test_pipeline(output_dir)
40 |
--------------------------------------------------------------------------------
/test/pipeline/integration/test_variable_merging_pipeline.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from pathlib import Path
5 | from pprint import pprint
6 |
7 | from canals.pipeline import Pipeline
8 | from sample_components import AddFixedValue, Sum
9 |
10 | import logging
11 |
12 | logging.basicConfig(level=logging.DEBUG)
13 |
14 |
def test_pipeline(tmp_path):
    """Three senders merge into Sum's variadic input, and the total gets one more addition.

    first_addition = 3, second_addition = 5, third_addition = 3, so
    sum = 11 and fourth_addition = 12.
    """
    components = {
        "first_addition": AddFixedValue(add=2),
        "second_addition": AddFixedValue(add=2),
        "third_addition": AddFixedValue(add=2),
        "sum": Sum(),
        "fourth_addition": AddFixedValue(add=1),
    }
    pipeline = Pipeline()
    for name, instance in components.items():
        pipeline.add_component(name, instance)

    for sender, receiver in [
        ("first_addition.result", "second_addition.value"),
        ("first_addition.result", "sum.values"),
        ("second_addition.result", "sum.values"),
        ("third_addition.result", "sum.values"),
        ("sum.total", "fourth_addition.value"),
    ]:
        pipeline.connect(sender, receiver)

    pipeline.draw(tmp_path / "variable_merging_pipeline.png")

    run_data = {
        "first_addition": {"value": 1},
        "third_addition": {"value": 1},
    }
    results = pipeline.run(run_data)
    pprint(results)

    assert results == {"fourth_addition": {"result": 12}}
40 |
41 |
if __name__ == "__main__":
    # Running the module directly writes the drawn pipeline next to this file.
    output_dir = Path(__file__).parent
    test_pipeline(output_dir)
44 |
--------------------------------------------------------------------------------
/test/pipeline/unit/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
--------------------------------------------------------------------------------
/test/pipeline/unit/test_draw.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import os
5 | import sys
6 | import filecmp
7 |
8 | from unittest.mock import patch, MagicMock
9 | import pytest
10 | import requests
11 |
12 | from canals.pipeline import Pipeline
13 | from canals.pipeline.draw import _draw, _convert
14 | from canals.errors import PipelineDrawingError
15 | from sample_components import Double, AddFixedValue
16 |
17 |
@pytest.mark.skipif(sys.platform.lower().startswith("darwin"), reason="the available graphviz version is too recent")
@pytest.mark.skipif(sys.platform.lower().startswith("win"), reason="pygraphviz is not really available in Windows")
def test_draw_pygraphviz(tmp_path, test_files):
    """Rendering with graphviz produces a file byte-identical to the stored reference image."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    pipe.connect("comp1", "comp2")

    rendered = tmp_path / "test_pipe.jpg"
    _draw(pipe.graph, rendered, engine="graphviz")

    reference = test_files / "pipeline_draw" / "pygraphviz.jpg"
    assert os.path.exists(rendered)
    assert filecmp.cmp(rendered, reference)
29 |
30 |
def test_draw_mermaid_image(tmp_path, test_files):
    """The mermaid-image engine writes a file matching the stored mock response image."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    for sender, receiver in (("comp1", "comp2"), ("comp2", "comp1")):
        pipe.connect(sender, receiver)

    rendered = tmp_path / "test_pipe.jpg"
    _draw(pipe.graph, rendered, engine="mermaid-image")
    assert os.path.exists(rendered)
    assert filecmp.cmp(rendered, test_files / "mermaid_mock" / "test_response.png")
41 |
42 |
def test_draw_mermaid_img_failing_request(tmp_path):
    """An HTTP error from mermaid.ink surfaces as a PipelineDrawingError."""
    pipe = Pipeline()
    pipe.add_component("comp1", Double())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")
    pipe.connect("comp2", "comp1")

    with patch("canals.pipeline.draw.mermaid.requests.get") as mock_get:
        mock_response = MagicMock()
        mock_response.status_code = 429
        mock_response.content = '{"error": "too many requests"}'
        # Use side_effect rather than assigning a plain function. A function stored on an
        # instance attribute is not bound, so a `def raise_for_status(self)` would fail with
        # TypeError (missing `self`) when called, and the test could pass for the wrong reason.
        mock_response.raise_for_status.side_effect = requests.HTTPError()
        mock_get.return_value = mock_response

        with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"):
            _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image")
63 |
64 |
def test_draw_mermaid_text(tmp_path):
    """The mermaid-text engine writes the Mermaid graph definition to a markdown file."""
    pipe = Pipeline()
    pipe.add_component("comp1", AddFixedValue(add=3))
    pipe.add_component("comp2", Double())
    pipe.connect("comp1.result", "comp2.value")
    pipe.connect("comp2.value", "comp1.value")

    output_file = tmp_path / "test_pipe.md"
    _draw(pipe.graph, output_file, engine="mermaid-text")
    assert os.path.exists(output_file)
    # read_text() opens and closes the file. A bare open(...).read() leaves the handle
    # open (ResourceWarning) until garbage collection.
    assert (
        output_file.read_text()
        == """
%%{ init: {'theme': 'neutral' } }%%

graph TD;

comp1["comp1
AddFixedValue
Optional inputs:"]:::component -- "result -> value
int" --> comp2["comp2
Double"]:::component
comp2["comp2
Double"]:::component -- "value -> value
int" --> comp1["comp1
AddFixedValue
Optional inputs:"]:::component

classDef component text-align:center;
"""
    )
87 |
88 |
def test_draw_unknown_engine(tmp_path):
    """_draw rejects an engine name it does not recognise."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    for sender, receiver in (("comp1", "comp2"), ("comp2", "comp1")):
        pipe.connect(sender, receiver)

    target = tmp_path / "test_pipe.jpg"
    with pytest.raises(ValueError, match="Unknown rendering engine 'unknown'"):
        _draw(pipe.graph, target, engine="unknown")
98 |
99 |
def test_convert_unknown_engine(tmp_path):
    """_convert rejects an engine name it does not recognise."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    for sender, receiver in (("comp1", "comp2"), ("comp2", "comp1")):
        pipe.connect(sender, receiver)

    with pytest.raises(ValueError, match="Unknown rendering engine 'unknown'"):
        _convert(pipe.graph, engine="unknown")
109 |
--------------------------------------------------------------------------------
/test/pipeline/unit/test_validation_pipeline_io.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 | import inspect
5 |
6 | from canals.pipeline import Pipeline
7 | from canals.component.types import Variadic
8 | from canals.errors import PipelineValidationError
9 | from canals.component.sockets import InputSocket, OutputSocket
10 | from canals.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
11 | from sample_components import Double, AddFixedValue, Sum, Parity
12 |
13 |
def test_find_pipeline_input_no_input():
    """In a closed two-component cycle every input is fed internally, so none is exposed."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    pipe.connect("comp1", "comp2")
    pipe.connect("comp2", "comp1")

    expected = {"comp1": [], "comp2": []}
    assert find_pipeline_inputs(pipe.graph) == expected
22 |
23 |
def test_find_pipeline_input_one_input():
    """Only the head of a two-step chain exposes its input socket."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    pipe.connect("comp1", "comp2")

    expected = {
        "comp1": [InputSocket(name="value", type=int)],
        "comp2": [],
    }
    assert find_pipeline_inputs(pipe.graph) == expected
34 |
35 |
def test_find_pipeline_input_two_inputs_same_component():
    """Both sockets of the entry component are reported, the optional one as non-mandatory."""
    pipe = Pipeline()
    pipe.add_component("comp1", AddFixedValue())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")

    comp1_sockets = [
        InputSocket(name="value", type=int),
        InputSocket(name="add", type=Optional[int], is_mandatory=False),
    ]
    assert find_pipeline_inputs(pipe.graph) == {"comp1": comp1_sockets, "comp2": []}
49 |
50 |
def test_find_pipeline_input_some_inputs_different_components():
    """Unconnected input sockets are collected from every component that has them."""
    pipe = Pipeline()
    pipe.add_component("comp1", AddFixedValue())
    pipe.add_component("comp2", Double())
    pipe.add_component("comp3", AddFixedValue())
    # comp3 receives both of its inputs from upstream, so it exposes nothing.
    pipe.connect("comp1.result", "comp3.value")
    pipe.connect("comp2.value", "comp3.add")

    add_fixed_value_inputs = [
        InputSocket(name="value", type=int),
        InputSocket(name="add", type=Optional[int], is_mandatory=False),
    ]
    expected = {
        "comp1": add_fixed_value_inputs,
        "comp2": [InputSocket(name="value", type=int)],
        "comp3": [],
    }
    assert find_pipeline_inputs(pipe.graph) == expected
67 |
68 |
def test_find_pipeline_variable_input_nodes_in_the_pipeline():
    """Disconnected components expose all their inputs, including variadic ones."""
    pipe = Pipeline()
    for name, instance in (("comp1", AddFixedValue()), ("comp2", Double()), ("comp3", Sum())):
        pipe.add_component(name, instance)

    expected = {
        "comp1": [
            InputSocket(name="value", type=int),
            InputSocket(name="add", type=Optional[int], is_mandatory=False),
        ],
        "comp2": [InputSocket(name="value", type=int)],
        "comp3": [InputSocket(name="values", type=Variadic[int])],
    }
    assert find_pipeline_inputs(pipe.graph) == expected
85 |
86 |
def test_find_pipeline_output_no_output():
    """In a closed cycle every output is consumed internally, so none is exposed."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    for sender, receiver in (("comp1", "comp2"), ("comp2", "comp1")):
        pipe.connect(sender, receiver)

    expected = {"comp1": [], "comp2": []}
    assert find_pipeline_outputs(pipe.graph) == expected
95 |
96 |
def test_find_pipeline_output_one_output():
    """Only the tail of a two-step chain exposes its output socket."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    pipe.connect("comp1", "comp2")

    expected = {
        "comp1": [],
        "comp2": [OutputSocket(name="value", type=int)],
    }
    assert find_pipeline_outputs(pipe.graph) == expected
104 |
105 |
def test_find_pipeline_some_outputs_same_component():
    """Every unconnected output of the final component is reported."""
    pipe = Pipeline()
    pipe.add_component("comp1", Double())
    pipe.add_component("comp2", Parity())
    pipe.connect("comp1", "comp2")

    parity_outputs = [OutputSocket(name="even", type=int), OutputSocket(name="odd", type=int)]
    assert find_pipeline_outputs(pipe.graph) == {"comp1": [], "comp2": parity_outputs}
116 |
117 |
def test_find_pipeline_some_outputs_different_components():
    """Unconnected outputs are collected from each leaf of a fan-out."""
    pipe = Pipeline()
    for name, instance in (("comp1", Double()), ("comp2", Parity()), ("comp3", Double())):
        pipe.add_component(name, instance)
    for receiver in ("comp2", "comp3"):
        pipe.connect("comp1", receiver)

    expected = {
        "comp1": [],
        "comp2": [OutputSocket(name="even", type=int), OutputSocket(name="odd", type=int)],
        "comp3": [OutputSocket(name="value", type=int)],
    }
    assert find_pipeline_outputs(pipe.graph) == expected
133 |
134 |
def test_validate_pipeline_input_pipeline_with_no_inputs():
    """A fully cyclic pipeline exposes no inputs, and run() refuses it."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    for sender, receiver in (("comp1", "comp2"), ("comp2", "comp1")):
        pipe.connect(sender, receiver)

    with pytest.raises(PipelineValidationError, match="This pipeline has no inputs."):
        pipe.run({})
143 |
144 |
def test_validate_pipeline_input_unknown_component():
    """run() rejects data addressed to a component that is not in the pipeline."""
    pipe = Pipeline()
    pipe.add_component("comp1", Double())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")
    # Raw string: "\(" in a plain literal is an invalid escape sequence
    # (DeprecationWarning, and SyntaxWarning since Python 3.12).
    with pytest.raises(ValueError, match=r"Pipeline received data for unknown component\(s\): test_component"):
        pipe.run({"test_component": {"value": 1}})
152 |
153 |
def test_validate_pipeline_input_all_necessary_input_is_present():
    """run() complains when the entry component's mandatory input is missing."""
    pipe = Pipeline()
    for name in ("comp1", "comp2"):
        pipe.add_component(name, Double())
    pipe.connect("comp1", "comp2")

    no_data = {}
    with pytest.raises(ValueError, match="Missing input: comp1.value"):
        pipe.run(no_data)
161 |
162 |
def test_validate_pipeline_input_all_necessary_input_is_present_considering_defaults():
    """Non-mandatory inputs may be omitted, but mandatory ones may not."""
    pipe = Pipeline()
    pipe.add_component("comp1", AddFixedValue())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")

    # Both are accepted: "add" is an optional input of AddFixedValue.
    for accepted in ({"comp1": {"value": 1}}, {"comp1": {"value": 1, "add": 2}}):
        pipe.run(accepted)

    with pytest.raises(ValueError, match="Missing input: comp1.value"):
        pipe.run({"comp1": {"add": 3}})
172 |
173 |
def test_validate_pipeline_input_only_expected_input_is_present():
    """Supplying data for an input already fed by an upstream component raises ValueError."""
    pipe = Pipeline()
    pipe.add_component("comp1", Double())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")
    # Raw string: "\[" / "\]" are invalid escape sequences in a plain literal.
    with pytest.raises(ValueError, match=r"The input value of comp2 is already sent by: \['comp1'\]"):
        pipe.run({"comp1": {"value": 1}, "comp2": {"value": 2}})
181 |
182 |
def test_validate_pipeline_input_only_expected_input_is_present_falsy():
    """A falsy value (0) for an input already fed by an upstream component is still rejected."""
    pipe = Pipeline()
    pipe.add_component("comp1", Double())
    pipe.add_component("comp2", Double())
    pipe.connect("comp1", "comp2")
    # Raw string: "\[" / "\]" are invalid escape sequences in a plain literal.
    with pytest.raises(ValueError, match=r"The input value of comp2 is already sent by: \['comp1'\]"):
        pipe.run({"comp1": {"value": 1}, "comp2": {"value": 0}})
190 |
191 |
def test_validate_pipeline_falsy_input_present():
    """A falsy but present input (0) is accepted and processed (0 * 2 == 0)."""
    pipeline = Pipeline()
    pipeline.add_component("comp", Double())
    result = pipeline.run({"comp": {"value": 0}})
    assert result == {"comp": {"value": 0}}
196 |
197 |
def test_validate_pipeline_falsy_input_missing():
    """An empty input dict for a component with a mandatory input is rejected."""
    pipeline = Pipeline()
    pipeline.add_component("comp", Double())
    no_values = {"comp": {}}
    with pytest.raises(ValueError, match="Missing input: comp.value"):
        pipeline.run(no_values)
203 |
204 |
def test_validate_pipeline_input_only_expected_input_is_present_including_unknown_names():
    """An input name that the component does not declare (``add`` on Double) is rejected."""
    pipeline = Pipeline()
    for name in ("comp1", "comp2"):
        pipeline.add_component(name, Double())
    pipeline.connect("comp1", "comp2")

    inputs = {"comp1": {"value": 1, "add": 2}}
    with pytest.raises(ValueError, match="Component comp1 is not expecting any input value called add"):
        pipeline.run(inputs)
213 |
214 |
def test_validate_pipeline_input_only_expected_input_is_present_and_defaults_dont_interfere():
    """A run-time ``add`` overrides the constructor one: (1 + 5) * 2 == 12."""
    pipeline = Pipeline()
    pipeline.add_component("comp1", AddFixedValue(add=10))
    pipeline.add_component("comp2", Double())
    pipeline.connect("comp1", "comp2")
    outputs = pipeline.run({"comp1": {"value": 1, "add": 5}})
    assert outputs == {"comp2": {"value": 12}}
221 |
--------------------------------------------------------------------------------
/test/sample_components/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from test.sample_components.test_accumulate import Accumulate
5 | from test.sample_components.test_add_value import AddFixedValue
6 | from test.sample_components.test_double import Double
7 | from test.sample_components.test_parity import Parity
8 | from test.sample_components.test_greet import Greet
9 | from test.sample_components.test_remainder import Remainder
10 | from test.sample_components.test_repeat import Repeat
11 | from test.sample_components.test_subtract import Subtract
12 | from test.sample_components.test_sum import Sum
13 | from test.sample_components.test_threshold import Threshold
14 |
--------------------------------------------------------------------------------
/test/sample_components/test_accumulate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components.accumulate import Accumulate, _default_function
5 |
6 |
def my_subtract(first, second):
    """Return ``first - second``; used as a custom Accumulate function in these tests."""
    difference = first - second
    return difference
9 |
10 |
def test_to_dict():
    """A default Accumulate serializes its function as the module-level default."""
    expected = {
        "type": "sample_components.accumulate.Accumulate",
        "init_parameters": {"function": "sample_components.accumulate._default_function"},
    }
    assert Accumulate().to_dict() == expected
18 |
19 |
def test_to_dict_with_custom_function():
    """A custom function is serialized by its fully-qualified dotted path."""
    expected = {
        "type": "sample_components.accumulate.Accumulate",
        "init_parameters": {"function": "test.sample_components.test_accumulate.my_subtract"},
    }
    assert Accumulate(function=my_subtract).to_dict() == expected
27 |
28 |
def test_from_dict():
    """Without a 'function' init parameter, the default function is used."""
    serialized = {
        "type": "sample_components.accumulate.Accumulate",
        "init_parameters": {},
    }
    restored = Accumulate.from_dict(serialized)
    assert restored.function == _default_function
36 |
37 |
def test_from_dict_with_default_function():
    """The dotted path of the default function resolves back to the same callable."""
    serialized = {
        "type": "sample_components.accumulate.Accumulate",
        "init_parameters": {"function": "sample_components.accumulate._default_function"},
    }
    restored = Accumulate.from_dict(serialized)
    assert restored.function == _default_function
45 |
46 |
def test_from_dict_with_custom_function():
    """The dotted path of a custom function is imported and resolved to that callable."""
    serialized = {
        "type": "sample_components.accumulate.Accumulate",
        "init_parameters": {"function": "test.sample_components.test_accumulate.my_subtract"},
    }
    restored = Accumulate.from_dict(serialized)
    assert restored.function == my_subtract
54 |
55 |
def test_accumulate_default():
    """With the default function the state is a running sum: 10, then 10 + 1."""
    component = Accumulate()
    for incoming, running_total in ((10, 10), (1, 11)):
        assert component.run(value=incoming) == {"value": running_total}
        assert component.state == running_total
65 |
66 |
def test_accumulate_callable():
    """With my_subtract each incoming value is subtracted from the state: -10, then -11."""
    component = Accumulate(function=my_subtract)
    for incoming, running_total in ((10, -10), (1, -11)):
        assert component.run(value=incoming) == {"value": running_total}
        assert component.state == running_total
77 |
--------------------------------------------------------------------------------
/test/sample_components/test_add_value.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components import AddFixedValue
5 | from canals.serialization import component_to_dict, component_from_dict
6 |
7 |
def test_run():
    """An ``add`` passed at run time is added to ``value``: 50 + 10 == 60."""
    adder = AddFixedValue()
    assert adder.run(value=50, add=10) == {"result": 60}
12 |
--------------------------------------------------------------------------------
/test/sample_components/test_concatenate.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components import Concatenate
5 | from canals.serialization import component_to_dict, component_from_dict
6 |
7 |
def test_input_lists():
    """Two single-item lists are joined into one list."""
    concatenate = Concatenate()
    output = concatenate.run(first=["This"], second=["That"])
    assert output == {"value": ["This", "That"]}
12 |
13 |
def test_input_strings():
    """Two plain strings are each treated as one element of the output list."""
    concatenate = Concatenate()
    output = concatenate.run(first="This", second="That")
    assert output == {"value": ["This", "That"]}
18 |
19 |
def test_input_first_list_second_string():
    """A list followed by a string produces one flat list."""
    concatenate = Concatenate()
    output = concatenate.run(first=["This"], second="That")
    assert output == {"value": ["This", "That"]}
24 |
25 |
def test_input_first_string_second_list():
    """A string followed by a list produces one flat list."""
    concatenate = Concatenate()
    output = concatenate.run(first="This", second=["That"])
    assert output == {"value": ["This", "That"]}
30 |
--------------------------------------------------------------------------------
/test/sample_components/test_double.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from sample_components import Double
6 | from canals.serialization import component_to_dict, component_from_dict
7 |
8 |
def test_double_default():
    """Double multiplies its input by two."""
    doubler = Double()
    assert doubler.run(value=10) == {"value": 20}
13 |
--------------------------------------------------------------------------------
/test/sample_components/test_fstring.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from sample_components import FString
3 |
4 |
def test_fstring_with_one_var():
    """A single declared variable is substituted into the template."""
    formatter = FString(template="Hello, {name}!", variables=["name"])
    assert formatter.run(name="Alice") == {"string": "Hello, Alice!"}
9 |
10 |
def test_fstring_with_no_vars():
    """A template without placeholders comes back unchanged."""
    formatter = FString(template="No variables in this template.", variables=[])
    assert formatter.run() == {"string": "No variables in this template."}
15 |
16 |
def test_fstring_with_template_at_runtime():
    """A ``template`` passed to run() replaces the one given at construction."""
    formatter = FString(template="Hello {name}", variables=["name"])
    rendered = formatter.run(template="Goodbye {name}!", name="Alice")
    assert rendered == {"string": "Goodbye Alice!"}
21 |
22 |
def test_fstring_with_vars_mismatch():
    """A run-time template whose placeholder ({person}) is not supplied raises KeyError."""
    formatter = FString(template="Hello {name}", variables=["name"])
    with pytest.raises(KeyError):
        formatter.run(template="Goodbye {person}!", name="Alice")
27 |
28 |
def test_fstring_with_vars_in_excess():
    """Supplied variables that the template does not use are ignored."""
    formatter = FString(template="Hello {name}", variables=["name"])
    rendered = formatter.run(template="Goodbye!", name="Alice")
    assert rendered == {"string": "Goodbye!"}
33 |
34 |
def test_fstring_with_vars_missing():
    """A template placeholder with no matching value ({name}) raises KeyError."""
    formatter = FString(template="{greeting}, {name}!", variables=["name"])
    with pytest.raises(KeyError):
        formatter.run(greeting="Hello")
39 |
--------------------------------------------------------------------------------
/test/sample_components/test_greet.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 |
6 | from sample_components import Greet
7 | from canals.serialization import component_to_dict, component_from_dict
8 |
9 |
def test_greet_message(caplog):
    """Greet passes ``value`` through and logs the formatted message at the requested level."""
    caplog.set_level(logging.WARNING)
    greeter = Greet()
    outcome = greeter.run(value=10, message="Hello, that's {value}", log_level="WARNING")
    assert outcome == {"value": 10}
    assert "Hello, that's 10" in caplog.text
16 |
--------------------------------------------------------------------------------
/test/sample_components/test_merge_loop.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Dict
5 |
6 | import pytest
7 |
8 | from canals.errors import DeserializationError
9 |
10 | from sample_components import MergeLoop
11 |
12 |
def test_to_dict():
    """A builtin expected_type is serialized by its qualified name ("builtins.int")."""
    merge = MergeLoop(expected_type=int, inputs=["first", "second"])
    expected = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"expected_type": "builtins.int", "inputs": ["first", "second"]},
    }
    assert merge.to_dict() == expected
20 |
21 |
def test_to_dict_with_typing_class():
    """A typing alias expected_type is serialized as "typing.Dict"."""
    merge = MergeLoop(expected_type=Dict, inputs=["first", "second"])
    expected = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"expected_type": "typing.Dict", "inputs": ["first", "second"]},
    }
    assert merge.to_dict() == expected
32 |
33 |
def test_to_dict_with_custom_class():
    """A user-defined class as expected_type is serialized by its module path."""
    merge = MergeLoop(expected_type=MergeLoop, inputs=["first", "second"])
    expected = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {
            "expected_type": "sample_components.merge_loop.MergeLoop",
            "inputs": ["first", "second"],
        },
    }
    assert merge.to_dict() == expected
44 |
45 |
def test_from_dict():
    """Deserializing restores expected_type (as given) and the inputs list."""
    # NOTE(review): expected_type is asserted to stay the serialized string rather than
    # a resolved type object; confirm this is the intended contract.
    payload = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"expected_type": "builtins.int", "inputs": ["first", "second"]},
    }
    restored = MergeLoop.from_dict(payload)
    assert restored.expected_type == "builtins.int"
    assert restored.inputs == ["first", "second"]
54 |
55 |
def test_from_dict_with_typing_class():
    """A typing alias name round-trips through from_dict unchanged."""
    payload = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"expected_type": "typing.Dict", "inputs": ["first", "second"]},
    }
    restored = MergeLoop.from_dict(payload)
    assert restored.expected_type == "typing.Dict"
    assert restored.inputs == ["first", "second"]
67 |
68 |
def test_from_dict_with_custom_class():
    """A user-defined class path round-trips through from_dict unchanged."""
    class_path = "sample_components.merge_loop.MergeLoop"
    payload = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"expected_type": class_path, "inputs": ["first", "second"]},
    }
    restored = MergeLoop.from_dict(payload)
    assert restored.expected_type == class_path
    assert restored.inputs == ["first", "second"]
80 |
81 |
def test_from_dict_without_expected_type():
    """Deserialization fails when 'expected_type' is absent from init_parameters."""
    payload = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"inputs": ["first", "second"]},
    }
    with pytest.raises(DeserializationError, match="Missing 'expected_type' field in 'init_parameters'"):
        MergeLoop.from_dict(payload)
93 |
94 |
def test_from_dict_without_inputs():
    """Deserialization fails when 'inputs' is absent from init_parameters."""
    payload = {
        "type": "sample_components.merge_loop.MergeLoop",
        "init_parameters": {"expected_type": "sample_components.merge_loop.MergeLoop"},
    }
    with pytest.raises(DeserializationError, match="Missing 'inputs' field in 'init_parameters'"):
        MergeLoop.from_dict(payload)
106 |
107 |
def test_merge_first():
    """A value arriving on the first input is emitted as ``value``."""
    merger = MergeLoop(expected_type=int, inputs=["in_1", "in_2"])
    assert merger.run(in_1=5) == {"value": 5}
112 |
113 |
def test_merge_second():
    """A value arriving on the second input is emitted as ``value``."""
    merger = MergeLoop(expected_type=int, inputs=["in_1", "in_2"])
    assert merger.run(in_2=5) == {"value": 5}
118 |
119 |
def test_merge_nones():
    """When none of the declared inputs receives a value, the output is empty."""
    merger = MergeLoop(expected_type=int, inputs=["in_1", "in_2", "in_3"])
    assert merger.run() == {}
124 |
125 |
def test_merge_one():
    """With a single declared input, its value is emitted as ``value``."""
    merger = MergeLoop(expected_type=int, inputs=["in_1"])
    assert merger.run(in_1=1) == {"value": 1}
130 |
131 |
def test_merge_one_none():
    """With no declared inputs at all, running yields an empty output."""
    merger = MergeLoop(expected_type=int, inputs=[])
    assert merger.run() == {}
136 |
--------------------------------------------------------------------------------
/test/sample_components/test_parity.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components import Parity
5 | from canals.serialization import component_to_dict, component_from_dict
6 |
7 |
def test_parity():
    """Odd numbers go to the ``odd`` output and even numbers to ``even``."""
    parity = Parity()
    for number, expected in ((1, {"odd": 1}), (2, {"even": 2})):
        assert parity.run(value=number) == expected
14 |
--------------------------------------------------------------------------------
/test/sample_components/test_remainder.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import pytest
5 |
6 | from sample_components import Remainder
7 | from canals.serialization import component_to_dict, component_from_dict
8 |
9 |
def test_remainder_default():
    """With the default divisor, 4 is routed to ``remainder_is_1``."""
    splitter = Remainder()
    assert splitter.run(value=4) == {"remainder_is_1": 4}
14 |
15 |
def test_remainder_with_divisor():
    """With divisor 4, 4 % 4 == 0 routes the value to ``remainder_is_0``."""
    splitter = Remainder(divisor=4)
    assert splitter.run(value=4) == {"remainder_is_0": 4}
20 |
21 |
def test_remainder_zero():
    """A zero divisor is rejected when the component is constructed."""
    zero = 0
    with pytest.raises(ValueError):
        Remainder(divisor=zero)
25 |
--------------------------------------------------------------------------------
/test/sample_components/test_repeat.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components import Repeat
5 | from canals.serialization import component_to_dict, component_from_dict
6 |
7 |
def test_repeat_default():
    """The input value is copied to every declared output."""
    outputs = ["one", "two"]
    repeater = Repeat(outputs=outputs)
    assert repeater.run(value=10) == {"one": 10, "two": 10}
12 |
--------------------------------------------------------------------------------
/test/sample_components/test_subtract.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components import Subtract
5 | from canals.serialization import component_to_dict, component_from_dict
6 |
7 |
def test_subtract():
    """``difference`` is first_value minus second_value: 10 - 7 == 3."""
    subtractor = Subtract()
    assert subtractor.run(first_value=10, second_value=7) == {"difference": 3}
12 |
--------------------------------------------------------------------------------
/test/sample_components/test_sum.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
5 | from sample_components import Sum
6 | from canals.serialization import component_to_dict, component_from_dict
7 |
8 |
def test_sum_receives_no_values():
    """Summing an empty list yields a total of 0."""
    adder = Sum()
    assert adder.run(values=[]) == {"total": 0}
13 |
14 |
def test_sum_receives_one_value():
    """A single value is returned as the total."""
    adder = Sum()
    outcome = adder.run(values=[10])
    assert outcome == {"total": 10}
18 |
19 |
def test_sum_receives_few_values():
    """Several values are added together: 10 + 2 == 12."""
    adder = Sum()
    outcome = adder.run(values=[10, 2])
    assert outcome == {"total": 12}
23 |
--------------------------------------------------------------------------------
/test/sample_components/test_threshold.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from sample_components import Threshold
5 | from canals.serialization import component_to_dict, component_from_dict
6 |
7 |
def test_threshold():
    """With threshold 10, 5 is routed to ``below`` and 15 to ``above``."""
    gate = Threshold()
    for value, expected in ((5, {"below": 5}), (15, {"above": 15})):
        assert gate.run(value=value, threshold=10) == expected
16 |
--------------------------------------------------------------------------------
/test/test_files/mermaid_mock/test_response.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/test/test_files/mermaid_mock/test_response.png
--------------------------------------------------------------------------------
/test/test_files/pipeline_draw/pygraphviz.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/test/test_files/pipeline_draw/pygraphviz.jpg
--------------------------------------------------------------------------------
/test/test_serialization.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from unittest.mock import Mock
3 |
4 | import pytest
5 |
6 | from canals import Pipeline, component
7 | from canals.errors import DeserializationError
8 | from canals.testing import factory
9 | from canals.serialization import default_to_dict, default_from_dict
10 |
11 |
def test_default_component_to_dict():
    """With no extra keyword arguments, init_parameters is empty."""
    component_cls = factory.component_class("MyComponent")
    serialized = default_to_dict(component_cls())
    assert serialized == {
        "type": "canals.testing.factory.MyComponent",
        "init_parameters": {},
    }
20 |
21 |
def test_default_component_to_dict_with_init_parameters():
    """Extra keyword arguments to default_to_dict end up in init_parameters."""
    component_cls = factory.component_class("MyComponent")
    serialized = default_to_dict(component_cls(), some_key="some_value")
    assert serialized == {
        "type": "canals.testing.factory.MyComponent",
        "init_parameters": {"some_key": "some_value"},
    }
30 |
31 |
def test_default_component_from_dict():
    """init_parameters are passed as keyword arguments to the class constructor."""

    def custom_init(self, some_param):
        self.some_param = some_param

    component_cls = factory.component_class("MyComponent", extra_fields={"__init__": custom_init})
    data = {
        "type": "canals.testing.factory.MyComponent",
        "init_parameters": {"some_param": 10},
    }
    instance = default_from_dict(component_cls, data)
    assert isinstance(instance, component_cls)
    assert instance.some_param == 10
49 |
50 |
def test_default_component_from_dict_without_type():
    """Serialized data lacking a 'type' key is rejected."""
    empty_payload = {}
    with pytest.raises(DeserializationError, match="Missing 'type' in serialization data"):
        default_from_dict(Mock, empty_payload)
54 |
55 |
def test_default_component_from_dict_unregistered_component(request):
    """A 'type' that does not match the target class cannot be deserialized into it."""
    # The registry is global, so the test's own node name is used as the type name:
    # that makes a collision with a component registered by another test unlikely.
    unknown_type = request.node.name
    expected_error = f"Class '{unknown_type}' can't be deserialized as 'Mock'"
    with pytest.raises(DeserializationError, match=expected_error):
        default_from_dict(Mock, {"type": unknown_type})
63 |
64 |
def test_from_dict_import_type():
    """Pipeline.from_dict can build a component whose module has not been imported yet.

    The test removes the Greet class from the global component registry and its module
    from ``sys.modules``, then checks that ``Pipeline.from_dict`` still builds a Greet.
    """
    pipeline_dict = {
        "metadata": {},
        "max_loops_allowed": 100,
        "components": {
            "greeter": {
                "type": "sample_components.greet.Greet",
                "init_parameters": {
                    "message": "\nGreeting component says: Hi! The value is {value}\n",
                    "log_level": "INFO",
                },
            }
        },
        "connections": [],
    }

    # remove the target component from the registry if already there
    component.registry.pop("sample_components.greet.Greet", None)
    # remove the module from sys.modules if already there
    sys.modules.pop("sample_components.greet", None)

    p = Pipeline.from_dict(pipeline_dict)

    # Imported only after from_dict, so the comparison uses the freshly loaded module.
    from sample_components.greet import Greet

    # Exact type identity (not isinstance, not ==), per PEP 8 / E721.
    assert type(p.get_component("greeter")) is Greet
91 |
--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Set, Sequence, Tuple, Dict, Mapping, Literal, Union, Optional, Any
2 | from enum import Enum
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | from canals.type_utils import _type_name
8 |
9 |
class Class1:
    """Empty placeholder class used as a type fixture for ``_type_name``."""
12 |
13 |
class Class2:
    """Second empty placeholder class, unrelated to ``Class1``."""
16 |
17 |
class Class3(Class1):
    """Empty subclass of ``Class1``, used as an inherited-type fixture."""
20 |
21 |
class Enum1(Enum):
    """Enum whose member values are the fixture classes; used in a ``Literal[...]`` case."""

    TEST1 = Class1
    TEST2 = Class2
25 |
26 |
@pytest.mark.parametrize(
    "type_,expected",
    [
        pytest.param(str, "str", id="primitive-types"),
        pytest.param(Any, "Any", id="any"),
        pytest.param(Class1, "Class1", id="class"),
        pytest.param(Optional[int], "Optional[int]", id="shallow-optional-with-primitive"),
        pytest.param(Optional[Any], "Optional[Any]", id="shallow-optional-with-any"),
        pytest.param(Optional[Class1], "Optional[Class1]", id="shallow-optional-with-class"),
        pytest.param(Union[bool, Class1], "Union[bool, Class1]", id="shallow-union"),
        pytest.param(List[str], "List[str]", id="shallow-sequence-of-primitives"),
        pytest.param(List[Set[Sequence[str]]], "List[Set[Sequence[str]]]", id="nested-sequence-of-primitives"),
        pytest.param(
            Optional[List[Set[Sequence[str]]]],
            "Optional[List[Set[Sequence[str]]]]",
            id="optional-nested-sequence-of-primitives",
        ),
        pytest.param(
            List[Set[Sequence[Optional[str]]]],
            "List[Set[Sequence[Optional[str]]]]",
            id="nested-optional-sequence-of-primitives",
        ),
        pytest.param(List[Class1], "List[Class1]", id="shallow-sequence-of-classes"),
        pytest.param(List[Set[Sequence[Class1]]], "List[Set[Sequence[Class1]]]", id="nested-sequence-of-classes"),
        pytest.param(Dict[str, int], "Dict[str, int]", id="shallow-mapping-of-primitives"),
        pytest.param(
            Dict[str, Mapping[str, Dict[str, int]]],
            "Dict[str, Mapping[str, Dict[str, int]]]",
            id="nested-mapping-of-primitives",
        ),
        pytest.param(
            Dict[str, Mapping[Any, Dict[str, int]]],
            "Dict[str, Mapping[Any, Dict[str, int]]]",
            id="nested-mapping-of-primitives-with-any",
        ),
        pytest.param(Dict[str, Class1], "Dict[str, Class1]", id="shallow-mapping-of-classes"),
        pytest.param(
            Dict[str, Mapping[str, Dict[str, Class1]]],
            "Dict[str, Mapping[str, Dict[str, Class1]]]",
            id="nested-mapping-of-classes",
        ),
        pytest.param(
            Literal["a", "b", "c"],
            "Literal['a', 'b', 'c']",
            id="string-literal",
        ),
        pytest.param(
            Literal[1, 2, 3],
            "Literal[1, 2, 3]",
            id="primitive-literal",
        ),
        pytest.param(
            Literal[Enum1.TEST1],
            "Literal[Enum1.TEST1]",
            id="enum-literal",
        ),
        pytest.param(
            Tuple[Optional[Literal["a", "b", "c"]], Union[Path, Dict[int, Class1]]],
            "Tuple[Optional[Literal['a', 'b', 'c']], Union[Path, Dict[int, Class1]]]",
            id="deeply-nested-complex-type",
        ),
    ],
)
def test_type_name(type_, expected):
    """``_type_name`` renders each annotation as the expected readable string."""
    # The parameter is called ``expected`` rather than ``repr`` so it does not
    # shadow the builtin of that name.
    assert _type_name(type_) == expected
92 |
--------------------------------------------------------------------------------
/test/testing/test_factory.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from canals import component
4 | from canals.errors import ComponentError
5 | from canals.testing import factory
6 |
7 |
def test_component_class_default():
    """An unconfigured factory component returns {"value": None} for any keyword input."""
    comp = factory.component_class("MyComponent")()
    for kwargs in ({"value": 1}, {"value": "something"}, {"non_existing_input": 1}):
        assert comp.run(**kwargs) == {"value": None}
19 |
20 |
def test_component_class_is_registered():
    """Factory-built classes appear in the global registry under their dotted path."""
    created = factory.component_class("MyComponent")
    registered = component.registry["canals.testing.factory.MyComponent"]
    assert registered == created
24 |
25 |
def test_component_class_with_input_types():
    """The declared input type (int) is not enforced at run time."""
    comp = factory.component_class("MyComponent", input_types={"value": int})()
    for value in (1, "something"):
        assert comp.run(value=value) == {"value": None}
34 |
35 |
def test_component_class_with_output_types():
    """Declaring output types without an output still yields None for that output."""
    component_cls = factory.component_class("MyComponent", output_types={"value": int})
    comp = component_cls()
    assert comp.run(value=1) == {"value": None}
42 |
43 |
def test_component_class_with_output():
    """A fixed ``output`` dict is returned from run() as given."""
    component_cls = factory.component_class("MyComponent", output={"value": 100})
    comp = component_cls()
    assert comp.run(value=1) == {"value": 100}
49 |
50 |
def test_component_class_with_output_and_output_types():
    """The fixed output is returned even when it disagrees with the declared type (str)."""
    component_cls = factory.component_class(
        "MyComponent", output_types={"value": str}, output={"value": 100}
    )
    comp = component_cls()
    assert comp.run(value=1) == {"value": 100}
57 |
58 |
def test_component_class_with_bases():
    """Extra base classes passed to the factory are inherited by the built class."""
    base_classes = (Exception,)
    comp = factory.component_class("MyComponent", bases=base_classes)()
    assert isinstance(comp, Exception)
63 |
64 |
def test_component_class_with_extra_fields():
    """Extra fields become attributes that instances can read."""
    extra = {"my_field": 10}
    comp = factory.component_class("MyComponent", extra_fields=extra)()
    assert comp.my_field == 10
69 |
--------------------------------------------------------------------------------