├── .github ├── CODEOWNERS └── workflows │ ├── api-docs.yml │ ├── release.yml │ └── tests.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── canals ├── __about__.py ├── __init__.py ├── component │ ├── __init__.py │ ├── component.py │ ├── connection.py │ ├── descriptions.py │ ├── sockets.py │ └── types.py ├── errors.py ├── pipeline │ ├── __init__.py │ ├── descriptions.py │ ├── draw │ │ ├── __init__.py │ │ ├── draw.py │ │ ├── graphviz.py │ │ └── mermaid.py │ ├── pipeline.py │ └── validation.py ├── serialization.py ├── testing │ ├── __init__.py │ └── factory.py └── type_utils.py ├── docs ├── api-docs │ ├── canals.md │ ├── component.md │ ├── pipeline.md │ └── testing.md ├── concepts │ ├── components.md │ ├── concepts.md │ └── pipelines.md └── index.md ├── images ├── canals-logo-dark.png └── canals-logo-light.png ├── mkdocs.yml ├── pyproject.toml ├── sample_components ├── __init__.py ├── accumulate.py ├── add_value.py ├── concatenate.py ├── double.py ├── fstring.py ├── greet.py ├── hello.py ├── joiner.py ├── merge_loop.py ├── parity.py ├── remainder.py ├── repeat.py ├── self_loop.py ├── subtract.py ├── sum.py ├── text_splitter.py └── threshold.py └── test ├── __init__.py ├── component ├── test_component.py └── test_connection.py ├── conftest.py ├── pipeline ├── __init__.py ├── integration │ ├── __init__.py │ ├── test_complex_pipeline.py │ ├── test_default_value.py │ ├── test_distinct_loops_pipeline.py │ ├── test_double_loop_pipeline.py │ ├── test_dynamic_inputs_pipeline.py │ ├── test_fixed_decision_and_merge_pipeline.py │ ├── test_fixed_decision_pipeline.py │ ├── test_fixed_merging_pipeline.py │ ├── test_joiners.py │ ├── test_linear_pipeline.py │ ├── test_looping_and_merge_pipeline.py │ ├── test_looping_pipeline.py │ ├── test_mutable_inputs.py │ ├── test_parallel_branches_pipeline.py │ ├── test_self_loop.py │ ├── test_variable_decision_and_merge_pipeline.py │ ├── test_variable_decision_pipeline.py │ └── test_variable_merging_pipeline.py └── unit │ ├── __init__.py │ ├── test_connections.py │ ├── test_draw.py │ ├── test_pipeline.py │ └── test_validation_pipeline_io.py ├── sample_components ├── __init__.py ├── test_accumulate.py ├── test_add_value.py ├── test_concatenate.py ├── test_double.py ├── test_fstring.py ├── test_greet.py ├── test_merge_loop.py ├── test_parity.py ├── test_remainder.py ├── test_repeat.py ├── test_subtract.py ├── test_sum.py └── test_threshold.py ├── test_files ├── mermaid_mock │ └── test_response.png └── pipeline_draw │ └── pygraphviz.jpg ├── test_serialization.py ├── test_utils.py └── testing └── test_factory.py /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | # See https://help.github.com/articles/about-codeowners/ for syntax 2 | 3 | # Core Engineering will be the default owners for everything 4 | # in the repo. Unless a later match takes precedence, 5 | # @deepset-ai/core-engineering will be requested for review 6 | # when someone opens a pull request. 
7 | * @deepset-ai/core-engineering -------------------------------------------------------------------------------- /.github/workflows/api-docs.yml: -------------------------------------------------------------------------------- 1 | name: API Docs 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: 3.x 19 | 20 | - uses: actions/cache@v2 21 | with: 22 | key: ${{ github.ref }} 23 | path: .cache 24 | 25 | - run: pip install mkdocs-material mkdocstrings[python] mkdocs-mermaid2-plugin 26 | 27 | - run: mkdocs gh-deploy --force 28 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | tags: 6 | - "v[0-9].[0-9]+.[0-9]+*" 7 | 8 | jobs: 9 | release-on-pypi: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout 14 | uses: actions/checkout@v3 15 | 16 | - name: Install Hatch 17 | run: pip install hatch 18 | 19 | - name: Build 20 | run: hatch build 21 | 22 | - name: Publish on PyPi 23 | env: 24 | HATCH_INDEX_USER: __token__ 25 | HATCH_INDEX_AUTH: ${{ secrets.PYPI_API_TOKEN }} 26 | run: hatch publish -y 27 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | paths: 9 | - "**.py" 10 | - "**/pyproject.toml" 11 | 12 | env: 13 | COVERALLS_NOISY: true 14 | 15 | jobs: 16 | mypy: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Checkout 20 | uses: actions/checkout@v3 21 | 22 | - name: Setup Python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: '3.8' 26 | 27 | - name: Install Canals 28 | run: | 29 | sudo apt install graphviz libgraphviz-dev 30 | pip install .[dev] pygraphviz 31 | 32 | - name: Mypy 33 | run: | 34 | mkdir .mypy_cache/ 35 | mypy --install-types --non-interactive --ignore-missing-imports canals/ sample_components/ 36 | 37 | pylint: 38 | runs-on: ubuntu-latest 39 | steps: 40 | - name: Checkout 41 | uses: actions/checkout@v3 42 | 43 | - name: Setup Python 44 | uses: actions/setup-python@v4 45 | with: 46 | python-version: '3.8' 47 | 48 | - name: Install Canals 49 | run: | 50 | sudo apt install graphviz libgraphviz-dev 51 | pip install .[dev] pygraphviz 52 | 53 | - name: Pylint 54 | run: pylint -ry -j 0 canals/ sample_components/ 55 | 56 | black: 57 | runs-on: ubuntu-latest 58 | steps: 59 | - name: Checkout 60 | uses: actions/checkout@v3 61 | 62 | - name: Setup Python 63 | uses: actions/setup-python@v4 64 | with: 65 | python-version: '3.8' 66 | 67 | - name: Install Canals 68 | run: pip install .[dev] 69 | 70 | - name: Check status 71 | run: black canals/ --check 72 | 73 | tests: 74 | name: Unit / Python ${{ matrix.version }} / ${{ matrix.os }} 75 | strategy: 76 | fail-fast: false 77 | matrix: 78 | version: 79 | - "3.8" 80 | - "3.9" 81 | - "3.10" 82 | - "3.11" 83 | os: 84 | - ubuntu-latest 85 | - windows-latest 86 | - macos-latest 87 | runs-on: ${{ matrix.os }} 88 | steps: 89 | - name: Checkout 90 | uses: actions/checkout@v3 91 | 92 | - name: Setup Python 93 | uses: actions/setup-python@v4 94 | with: 95 | python-version: ${{ matrix.version }} 96 | 97 | - name: 
Install Canals (Ubuntu) 98 | if: matrix.os == 'ubuntu-latest' 99 | run: | 100 | sudo apt install graphviz libgraphviz-dev 101 | pip install .[dev] pygraphviz 102 | 103 | - name: Install Canals (MacOS) 104 | if: matrix.os == 'macos-latest' 105 | run: | 106 | # brew only offers graphviz 8, which seems to be incompatible with pygraphviz :( 107 | # brew install graphviz@2.49.0 108 | pip install .[dev] 109 | 110 | - name: Install Canals (Windows) 111 | if: matrix.os == 'windows-latest' 112 | run: | 113 | # Doesn't seem to work in CI :( 114 | # choco install graphviz 115 | # python -m pip install --global-option=build_ext ` 116 | # --global-option="-IC:\Program Files\Graphviz\include" ` 117 | # --global-option="-LC:\Program Files\Graphviz\lib" ` 118 | # pygraphviz 119 | pip install .[dev] 120 | 121 | - name: Run 122 | run: pytest --cov-report xml:coverage.xml --cov="canals" test/ 123 | 124 | - name: Coverage 125 | if: matrix.os == 'ubuntu-latest' && matrix.version == 3.11 126 | uses: coverallsapp/github-action@v2 127 | with: 128 | path-to-lcov: coverage.xml 129 | parallel: false 130 | debug: true 131 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | # Canals 134 | drafts/ 135 | .canals_debug/ 136 | test/**/*.png 137 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | fail_fast: true 2 | 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.2.0 6 | hooks: 7 | - id: check-ast # checks Python syntax 8 | - id: check-json # checks JSON syntax 9 | # - id: check-yaml # checks YAML syntax 10 | - id: check-toml # checks TOML syntax 11 | - id: end-of-file-fixer # checks there is a newline at the end of the file 12 | - id: trailing-whitespace # trims trailing whitespace 13 | - id: check-merge-conflict # checks for no merge conflict strings 14 | - id: check-shebang-scripts-are-executable # checks all shell scripts have executable permissions 15 | - id: mixed-line-ending # normalizes line endings 16 | #- id: no-commit-to-branch # prevents committing to main 17 | 18 | - repo: https://github.com/psf/black 19 | rev: 22.6.0 # IMPORTANT: keep this aligned with the black version in pyproject.toml 20 | hooks: 21 | - id: black-jupyter 22 | 23 | - repo: https://github.com/pre-commit/mirrors-mypy 24 | rev: 'v1.1.1' 25 | hooks: 26 | - id: mypy 27 | exclude: ^test/ 28 | args: [--ignore-missing-imports] 29 | additional_dependencies: ['types-requests'] 30 | 31 | - repo: https://github.com/pycqa/pylint 32 | rev: 'v2.17.0' 33 | hooks: 34 | - id: pylint 35 | exclude: ^test/ 36 | args: [ 37 | "--disable=import-error" # FIXME 38 | ] 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | > ⚠️ The project was merged into Haystack, namely into the [`core`](https://github.com/deepset-ai/haystack/tree/main/haystack/core) package. 2 | 3 | # Canals 4 | 5 |

6 | 7 | 8 |

9 |
10 | [![PyPI - Version](https://img.shields.io/pypi/v/canals.svg)](https://pypi.org/project/canals)
11 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/canals.svg)](https://pypi.org/project/canals)
12 |
13 |
14 |
15 | [![Coverage Status](https://coveralls.io/repos/github/deepset-ai/canals/badge.svg?branch=main)](https://coveralls.io/github/deepset-ai/canals?branch=main)
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 | Canals is a **component orchestration engine**. Components are Python objects that can execute a task, like reading a file, performing calculations, or making API calls. Canals connects these objects together: it builds a graph of components and takes care of managing their execution order, making sure that each object receives the input it expects from the other components of the pipeline.
33 |
34 | Canals powers version 2.0 of the [Haystack framework](https://github.com/deepset-ai/haystack).
35 |
36 | ## Installation
37 |
38 | To install Canals, run:
39 |
40 | ```console
41 | pip install canals
42 | ```
43 |
44 | To draw pipelines (with the `Pipeline.draw()` method), make sure you have either an internet connection (to reach the Mermaid graph renderer at `https://mermaid.ink`) or [graphviz](https://graphviz.org/download/) (version 2.49.0) installed. If you
45 | plan to use Mermaid, there are no additional steps to take; for graphviz
46 | you need to run:
47 |
48 | ### GraphViz
49 | ```console
50 | sudo apt install graphviz  # You may need `graphviz-dev` too
51 | pip install pygraphviz
52 | ```
53 |
--------------------------------------------------------------------------------
/canals/__about__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | __version__ = "0.11.0"
5 |
--------------------------------------------------------------------------------
/canals/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.__about__ import __version__
5 |
6 | from canals.component import component, Component
7 | from canals.pipeline.pipeline import Pipeline
8 |
9 | __all__ = ["component", "Component", "Pipeline"]
10 |
--------------------------------------------------------------------------------
/canals/component/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.component.component import component, Component
5 | from canals.component.sockets import InputSocket, OutputSocket
6 |
7 | __all__ = ["component", "Component", "InputSocket", "OutputSocket"]
8 |
--------------------------------------------------------------------------------
/canals/component/component.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | """
5 | Attributes:
6 |
7 | component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.
8 |
9 | All components must follow the contract below. This docstring is the source of truth for the components' contract.
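For a first orientation, here is a sketch of a minimal component respecting this contract. It mirrors `sample_components/double.py`; the class itself is illustrative and not part of this module:

```python
from canals import component


@component
class Double:
    @component.output_types(value=int)
    def run(self, value: int):
        # Return a dictionary whose keys match the declared output types
        return {"value": value * 2}
```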
10 | 11 |
12 | 13 | `@component` decorator 14 | 15 | All component classes must be decorated with the `@component` decorator. This allows Canals to discover them. 16 | 17 |
18 |
19 | `__init__(self, **kwargs)`
20 |
21 | Optional method.
22 |
23 | Components may have an `__init__` method where they define:
24 |
25 | - `self.init_parameters = {same parameters that the __init__ method received}`:
26 | In this dictionary you can store any state that the component wishes to persist when it is saved.
27 | These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
28 | Note that by default the `@component` decorator saves the arguments automatically.
29 | However, if a component sets its own `init_parameters` manually in `__init__()`, that value will be used instead.
30 | Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
31 |
32 | Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
33 | dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
34 | time. If you need such values, consider serializing them to a string.
35 |
36 | _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_
37 |
38 | The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
39 | validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.) refer to
40 | the `warm_up()` method.
41 |
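To illustrate, a hypothetical component (not part of this codebase) that follows these rules could look like:

```python
@component
class AddFixedValue:
    def __init__(self, add: int = 1):
        # Setting init_parameters explicitly is optional: the @component
        # machinery captures the __init__ arguments automatically, but a
        # manually assigned dictionary like this one takes precedence.
        self.init_parameters = {"add": add}
        self.add = add

    @component.output_types(result=int)
    def run(self, value: int):
        return {"result": value + self.add}
```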
43 | 44 | `warm_up(self)` 45 | 46 | Optional method. 47 | 48 | This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations, 49 | because Pipeline will not keep track of which components it called `warm_up()` on. 50 | 51 |
52 | 53 | `run(self, data)` 54 | 55 | Mandatory method. 56 | 57 | This is the method where the main functionality of the component should be carried out. It's called by 58 | `Pipeline.run()`. 59 | 60 | When the component should run, Pipeline will call this method with an instance of the dataclass returned by the 61 | method decorated with `@component.input`. This dataclass contains: 62 | 63 | - all the input values coming from other components connected to it, 64 | - if any is missing, the corresponding value defined in `self.defaults`, if it exists. 65 | 66 | `run()` must return a single instance of the dataclass declared through the method decorated with 67 | `@component.output`. 68 | 69 | """ 70 | 71 | import logging 72 | import inspect 73 | from typing import Protocol, runtime_checkable, Any 74 | from types import new_class 75 | from copy import deepcopy 76 | 77 | from canals.component.sockets import InputSocket, OutputSocket 78 | from canals.errors import ComponentError 79 | 80 | logger = logging.getLogger(__name__) 81 | 82 | 83 | @runtime_checkable 84 | class Component(Protocol): 85 | """ 86 | Note this is only used by type checking tools. 87 | 88 | In order to implement the `Component` protocol, custom components need to 89 | have a `run` method. The signature of the method and its return value 90 | won't be checked, i.e. classes with the following methods: 91 | 92 | def run(self, param: str) -> Dict[str, Any]: 93 | ... 94 | 95 | and 96 | 97 | def run(self, **kwargs): 98 | ... 99 | 100 | will be both considered as respecting the protocol. This makes the type 101 | checking much weaker, but we have other places where we ensure code is 102 | dealing with actual Components. 103 | 104 | The protocol is runtime checkable so it'll be possible to assert: 105 | 106 | isinstance(MyComponent, Component) 107 | """ 108 | 109 | def run(self, *args: Any, **kwargs: Any): # pylint: disable=missing-function-docstring 110 | ... 111 | 112 | 113 | class ComponentMeta(type): 114 | def __call__(cls, *args, **kwargs): 115 | """ 116 | This method is called when clients instantiate a Component and 117 | runs before __new__ and __init__. 118 | """ 119 | # This will call __new__ then __init__, giving us back the Component instance 120 | instance = super().__call__(*args, **kwargs) 121 | 122 | # Before returning, we have the chance to modify the newly created 123 | # Component instance, so we take the chance and set up the I/O sockets 124 | 125 | # If `component.set_output_types()` was called in the component constructor, 126 | # `__canals_output__` is already populated, no need to do anything. 127 | if not hasattr(instance, "__canals_output__"): 128 | # If that's not the case, we need to populate `__canals_output__` 129 | # 130 | # If the `run` method was decorated, it has a `_output_types_cache` field assigned 131 | # that stores the output specification. 132 | # We deepcopy the content of the cache to transfer ownership from the class method 133 | # to the actual instance, so that different instances of the same class won't share this data. 134 | instance.__canals_output__ = deepcopy(getattr(instance.run, "_output_types_cache", {})) 135 | 136 | # Create the sockets if set_input_types() wasn't called in the constructor. 137 | # If it was called and there are some parameters also in the `run()` method, these take precedence. 
138 | if not hasattr(instance, "__canals_input__"):
139 | instance.__canals_input__ = {}
140 | run_signature = inspect.signature(getattr(cls, "run"))
141 | for param in list(run_signature.parameters)[1:]:  # First is 'self' and it doesn't matter.
142 | if run_signature.parameters[param].kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:  # ignore `**kwargs`
143 | instance.__canals_input__[param] = InputSocket(
144 | name=param,
145 | type=run_signature.parameters[param].annotation,
146 | is_mandatory=run_signature.parameters[param].default == inspect.Parameter.empty,
147 | )
148 | return instance
149 |
150 |
151 | class _Component:
152 | """
153 | See module's docstring.
154 |
155 | Args:
156 | class_: the class that Canals should use as a component.
157 | serializable: whether to check, at init time, if the component can be saved with
158 | `save_pipelines()`.
159 |
160 | Returns:
161 | A class that can be recognized as a component.
162 |
163 | Raises:
164 | ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
165 | """
166 |
167 | def __init__(self):
168 | self.registry = {}
169 |
170 | def set_input_types(self, instance, **types):
171 | """
172 | Method that specifies the input types when 'kwargs' is passed to the run method.
173 |
174 | Use as:
175 |
176 | ```python
177 | @component
178 | class MyComponent:
179 |
180 | def __init__(self, value: int):
181 | component.set_input_types(self, value_1=str, value_2=str)
182 | ...
183 |
184 | @component.output_types(output_1=int, output_2=str)
185 | def run(self, **kwargs):
186 | return {"output_1": kwargs["value_1"], "output_2": ""}
187 | ```
188 |
189 | Note that if the `run()` method also specifies some parameters, those will take precedence.
190 |
191 | For example:
192 |
193 | ```python
194 | @component
195 | class MyComponent:
196 |
197 | def __init__(self, value: int):
198 | component.set_input_types(self, value_1=str, value_2=str)
199 | ...
200 |
201 | @component.output_types(output_1=int, output_2=str)
202 | def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
203 | return {"output_1": value_1, "output_2": ""}
204 | ```
205 |
206 | would add a mandatory `value_0` parameter, make the `value_1`
207 | parameter optional with a default None, and keep the `value_2`
208 | parameter mandatory as specified in `set_input_types`.
209 |
210 | """
211 | instance.__canals_input__ = {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}
212 |
213 | def set_output_types(self, instance, **types):
214 | """
215 | Method that specifies the output types when the 'run' method is not decorated
216 | with 'component.output_types'.
217 |
218 | Use as:
219 |
220 | ```python
221 | @component
222 | class MyComponent:
223 |
224 | def __init__(self, value: int):
225 | component.set_output_types(self, output_1=int, output_2=str)
226 | ...
227 |
228 | # no decorators here
229 | def run(self, value: int):
230 | return {"output_1": 1, "output_2": "2"}
231 | ```
232 | """
233 | instance.__canals_output__ = {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}
234 |
235 | def output_types(self, **types):
236 | """
237 | Decorator factory that specifies the output types of a component.
238 | 239 | Use as: 240 | 241 | ```python 242 | @component 243 | class MyComponent: 244 | @component.output_types(output_1=int, output_2=str) 245 | def run(self, value: int): 246 | return {"output_1": 1, "output_2": "2"} 247 | ``` 248 | """ 249 | 250 | def output_types_decorator(run_method): 251 | """ 252 | This happens at class creation time, and since we don't have the decorated 253 | class available here, we temporarily store the output types as an attribute of 254 | the decorated method. The ComponentMeta metaclass will use this data to create 255 | sockets at instance creation time. 256 | """ 257 | setattr( 258 | run_method, 259 | "_output_types_cache", 260 | {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, 261 | ) 262 | return run_method 263 | 264 | return output_types_decorator 265 | 266 | def _component(self, class_): 267 | """ 268 | Decorator validating the structure of the component and registering it in the components registry. 269 | """ 270 | logger.debug("Registering %s as a component", class_) 271 | 272 | # Check for required methods and fail as soon as possible 273 | if not hasattr(class_, "run"): 274 | raise ComponentError(f"{class_.__name__} must have a 'run()' method. See the docs for more information.") 275 | 276 | def copy_class_namespace(namespace): 277 | """ 278 | This is the callback that `typing.new_class` will use 279 | to populate the newly created class. We just copy 280 | the whole namespace from the decorated class. 281 | """ 282 | for key, val in dict(class_.__dict__).items(): 283 | namespace[key] = val 284 | 285 | # Recreate the decorated component class so it uses our metaclass 286 | class_ = new_class(class_.__name__, class_.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace) 287 | 288 | # Save the component in the class registry (for deserialization) 289 | class_path = f"{class_.__module__}.{class_.__name__}" 290 | if class_path in self.registry: 291 | # Corner case, but it may occur easily in notebooks when re-running cells. 292 | logger.debug( 293 | "Component %s is already registered. Previous imported from '%s', new imported from '%s'", 294 | class_path, 295 | self.registry[class_path], 296 | class_, 297 | ) 298 | self.registry[class_path] = class_ 299 | logger.debug("Registered Component %s", class_) 300 | 301 | return class_ 302 | 303 | def __call__(self, class_): 304 | return self._component(class_) 305 | 306 | 307 | component = _Component() 308 | -------------------------------------------------------------------------------- /canals/component/connection.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Optional, List, Tuple 3 | from dataclasses import dataclass 4 | 5 | from canals.component.sockets import InputSocket, OutputSocket 6 | from canals.type_utils import _type_name, _types_are_compatible 7 | from canals.errors import PipelineConnectError 8 | 9 | 10 | @dataclass 11 | class Connection: 12 | sender: Optional[str] 13 | sender_socket: Optional[OutputSocket] 14 | receiver: Optional[str] 15 | receiver_socket: Optional[InputSocket] 16 | 17 | def __post_init__(self): 18 | if self.sender and self.sender_socket and self.receiver and self.receiver_socket: 19 | # Make sure the receiving socket isn't already connected, unless it's variadic. 
Sending sockets can be 20 | # connected as many times as needed, so they don't need this check 21 | if self.receiver_socket.senders and not self.receiver_socket.is_variadic: 22 | raise PipelineConnectError( 23 | f"Cannot connect '{self.sender}.{self.sender_socket.name}' with '{self.receiver}.{self.receiver_socket.name}': " 24 | f"{self.receiver}.{self.receiver_socket.name} is already connected to {self.receiver_socket.senders}.\n" 25 | ) 26 | 27 | self.sender_socket.receivers.append(self.receiver) 28 | self.receiver_socket.senders.append(self.sender) 29 | 30 | def __repr__(self): 31 | if self.sender and self.sender_socket: 32 | sender_repr = f"{self.sender}.{self.sender_socket.name} ({_type_name(self.sender_socket.type)})" 33 | else: 34 | sender_repr = "input needed" 35 | 36 | if self.receiver and self.receiver_socket: 37 | receiver_repr = f"({_type_name(self.receiver_socket.type)}) {self.receiver}.{self.receiver_socket.name}" 38 | else: 39 | receiver_repr = "output" 40 | 41 | return f"{sender_repr} --> {receiver_repr}" 42 | 43 | def __hash__(self): 44 | """ 45 | Connection is used as a dictionary key in Pipeline, it must be hashable 46 | """ 47 | return hash( 48 | "-".join( 49 | [ 50 | self.sender if self.sender else "input", 51 | self.sender_socket.name if self.sender_socket else "", 52 | self.receiver if self.receiver else "output", 53 | self.receiver_socket.name if self.receiver_socket else "", 54 | ] 55 | ) 56 | ) 57 | 58 | @property 59 | def is_mandatory(self) -> bool: 60 | """ 61 | Returns True if the connection goes to a mandatory input socket, False otherwise 62 | """ 63 | if self.receiver_socket: 64 | return self.receiver_socket.is_mandatory 65 | return False 66 | 67 | @staticmethod 68 | def from_list_of_sockets( 69 | sender_node: str, sender_sockets: List[OutputSocket], receiver_node: str, receiver_sockets: List[InputSocket] 70 | ) -> "Connection": 71 | """ 72 | Find one single possible connection between two lists of sockets. 
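Sockets are matched by type compatibility first; if more than one pair is compatible, the single pair
whose socket names also match is chosen. If no pair matches by type, or the ambiguity cannot be resolved
by name, a `PipelineConnectError` is raised that lists the sockets available on both components.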
73 | """ 74 | # List all sender/receiver combinations of sockets that match by type 75 | possible_connections = [ 76 | (sender_sock, receiver_sock) 77 | for sender_sock, receiver_sock in itertools.product(sender_sockets, receiver_sockets) 78 | if _types_are_compatible(sender_sock.type, receiver_sock.type) 79 | ] 80 | 81 | # No connections seem to be possible 82 | if not possible_connections: 83 | connections_status_str = _connections_status( 84 | sender_node=sender_node, 85 | sender_sockets=sender_sockets, 86 | receiver_node=receiver_node, 87 | receiver_sockets=receiver_sockets, 88 | ) 89 | 90 | # Both sockets were specified: explain why the types don't match 91 | if len(sender_sockets) == len(receiver_sockets) and len(sender_sockets) == 1: 92 | raise PipelineConnectError( 93 | f"Cannot connect '{sender_node}.{sender_sockets[0].name}' with '{receiver_node}.{receiver_sockets[0].name}': " 94 | f"their declared input and output types do not match.\n{connections_status_str}" 95 | ) 96 | 97 | # Not both sockets were specified: explain there's no possible match on any pair 98 | connections_status_str = _connections_status( 99 | sender_node=sender_node, 100 | sender_sockets=sender_sockets, 101 | receiver_node=receiver_node, 102 | receiver_sockets=receiver_sockets, 103 | ) 104 | raise PipelineConnectError( 105 | f"Cannot connect '{sender_node}' with '{receiver_node}': " 106 | f"no matching connections available.\n{connections_status_str}" 107 | ) 108 | 109 | # There's more than one possible connection 110 | if len(possible_connections) > 1: 111 | # Try to match by name 112 | name_matches = [ 113 | (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name 114 | ] 115 | if len(name_matches) != 1: 116 | # TODO allow for multiple connections at once if there is no ambiguity? 117 | # TODO give priority to sockets that have no default values? 118 | connections_status_str = _connections_status( 119 | sender_node=sender_node, 120 | sender_sockets=sender_sockets, 121 | receiver_node=receiver_node, 122 | receiver_sockets=receiver_sockets, 123 | ) 124 | raise PipelineConnectError( 125 | f"Cannot connect '{sender_node}' with '{receiver_node}': more than one connection is possible " 126 | "between these components. Please specify the connection name, like: " 127 | f"pipeline.connect('{sender_node}.{possible_connections[0][0].name}', " 128 | f"'{receiver_node}.{possible_connections[0][1].name}').\n{connections_status_str}" 129 | ) 130 | 131 | match = possible_connections[0] 132 | return Connection(sender_node, match[0], receiver_node, match[1]) 133 | 134 | 135 | def _connections_status( 136 | sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket] 137 | ): 138 | """ 139 | Lists the status of the sockets, for error messages. 
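A typical rendering (with illustrative component and socket names):

'adder':
  - result: int
'printer':
  - value: int (available)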
140 | """
141 | sender_sockets_entries = []
142 | for sender_socket in sender_sockets:
143 | sender_sockets_entries.append(f"  - {sender_socket.name}: {_type_name(sender_socket.type)}")
144 | sender_sockets_list = "\n".join(sender_sockets_entries)
145 |
146 | receiver_sockets_entries = []
147 | for receiver_socket in receiver_sockets:
148 | if receiver_socket.senders:
149 | sender_status = f"sent by {','.join(receiver_socket.senders)}"
150 | else:
151 | sender_status = "available"
152 | receiver_sockets_entries.append(
153 | f"  - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
154 | )
155 | receiver_sockets_list = "\n".join(receiver_sockets_entries)
156 |
157 | return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
158 |
159 |
160 | def parse_connect_string(connection: str) -> Tuple[str, Optional[str]]:
161 | """
162 | Returns the component name and socket name parsed from a connect_to/from string.
163 | """
164 | if "." in connection:
165 | split_str = connection.split(".", maxsplit=1)
166 | return (split_str[0], split_str[1])
167 | return connection, None
168 |
--------------------------------------------------------------------------------
/canals/component/descriptions.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 |
3 |
4 | def find_component_inputs(component: Any) -> Dict[str, Dict[str, Any]]:
5 | """
6 | Returns a mapping of input names to their expected types and optionality for a given component.
7 |
8 | :param component: The target component to introspect.
9 | :return: A dictionary where keys are input names, with each key's value being another dictionary
10 | containing 'type' (the expected data type), 'is_mandatory' (a boolean indicating whether the input must be
provided), and 'is_variadic' (a boolean indicating whether the input accepts values from multiple senders).
11 |
12 | :raises ValueError: if the component's class is not properly decorated with @component.
13 | """
14 | if not hasattr(component, "__canals_input__"):
15 | raise ValueError(
16 | f"Component {component} does not have defined inputs or is improperly decorated. "
17 | "Ensure it is a valid @component with declared inputs."
18 | )
19 |
20 | return {
21 | name: {"type": socket.type, "is_mandatory": socket.is_mandatory, "is_variadic": socket.is_variadic}
22 | for name, socket in component.__canals_input__.items()
23 | }
24 |
25 |
26 | def find_component_outputs(component: Any) -> Dict[str, Dict[str, Any]]:
27 | """
28 | Returns a mapping of component output names to their expected types.
29 |
30 | :param component: The component being examined for its outputs.
31 | :return: A dictionary where each key is an output name and the value is a dictionary with a 'type' key
32 | indicating the data type of the output.
33 |
34 | :raises ValueError: if the component's class is not properly decorated with @component.
35 | """
36 | if not hasattr(component, "__canals_output__"):
37 | raise ValueError(
38 | f"The specified component {component} does not have defined outputs or is not properly decorated. "
39 | "Check that it is a valid @component with outputs specified."
40 | ) 41 | 42 | return {name: {"type": socket.type} for name, socket in component.__canals_output__.items()} 43 | -------------------------------------------------------------------------------- /canals/component/sockets.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import get_args, List, Type 5 | import logging 6 | from dataclasses import dataclass, field 7 | 8 | from canals.component.types import CANALS_VARIADIC_ANNOTATION 9 | 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @dataclass 15 | class InputSocket: 16 | name: str 17 | type: Type 18 | is_mandatory: bool = True 19 | is_variadic: bool = field(init=False) 20 | senders: List[str] = field(default_factory=list) 21 | 22 | def __post_init__(self): 23 | try: 24 | # __metadata__ is a tuple 25 | self.is_variadic = self.type.__metadata__[0] == CANALS_VARIADIC_ANNOTATION 26 | except AttributeError: 27 | self.is_variadic = False 28 | if self.is_variadic: 29 | # We need to "unpack" the type inside the Variadic annotation, 30 | # otherwise the pipeline connection api will try to match 31 | # `Annotated[type, CANALS_VARIADIC_ANNOTATION]`. 32 | # 33 | # Note1: Variadic is expressed as an annotation of one single type, 34 | # so the return value of get_args will always be a one-item tuple. 35 | # 36 | # Note2: a pipeline always passes a list of items when a component 37 | # input is declared as Variadic, so the type itself always wraps 38 | # an iterable of the declared type. For example, Variadic[int] 39 | # is eventually an alias for Iterable[int]. Since we're interested 40 | # in getting the inner type `int`, we call `get_args` twice: the 41 | # first time to get `List[int]` out of `Variadic`, the second time 42 | # to get `int` out of `List[int]`. 43 | self.type = get_args(get_args(self.type)[0])[0] 44 | 45 | 46 | @dataclass 47 | class OutputSocket: 48 | name: str 49 | type: type 50 | receivers: List[str] = field(default_factory=list) 51 | -------------------------------------------------------------------------------- /canals/component/types.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar, Iterable 2 | from typing_extensions import TypeAlias, Annotated # Python 3.8 compatibility 3 | 4 | CANALS_VARIADIC_ANNOTATION = "__canals__variadic_t" 5 | 6 | # # Generic type variable used in the Variadic container 7 | T = TypeVar("T") 8 | 9 | 10 | # Variadic is a custom annotation type we use to mark input types. 
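# For example (illustrative usage), a component that accepts the same input
# from any number of senders can declare:
#
#     def run(self, values: Variadic[int]): ...
#
# and the pipeline will pass it all the incoming values as a single iterable.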
11 | # This type does nothing more than "mark" the contained
12 | # type so that it can be used during `InputSocket` creation, where we
13 | # check that its annotation equals CANALS_VARIADIC_ANNOTATION
14 | Variadic: TypeAlias = Annotated[Iterable[T], CANALS_VARIADIC_ANNOTATION]
15 |
--------------------------------------------------------------------------------
/canals/errors.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | class PipelineError(Exception):
5 | pass
6 |
7 |
8 | class PipelineRuntimeError(Exception):
9 | pass
10 |
11 |
12 | class PipelineConnectError(PipelineError):
13 | pass
14 |
15 |
16 | class PipelineValidationError(PipelineError):
17 | pass
18 |
19 |
20 | class PipelineDrawingError(PipelineError):
21 | pass
22 |
23 |
24 | class PipelineMaxLoops(PipelineError):
25 | pass
26 |
27 |
28 | class PipelineUnmarshalError(PipelineError):
29 | pass
30 |
31 |
32 | class ComponentError(Exception):
33 | pass
34 |
35 |
36 | class ComponentDeserializationError(Exception):
37 | pass
38 |
39 |
40 | class DeserializationError(Exception):
41 | pass
42 |
43 |
44 | class SerializationError(Exception):
45 | pass
46 |
--------------------------------------------------------------------------------
/canals/pipeline/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.pipeline.pipeline import Pipeline
5 | from canals.errors import (
6 | PipelineError,
7 | PipelineRuntimeError,
8 | PipelineValidationError,
9 | PipelineConnectError,
10 | PipelineMaxLoops,
11 | )
12 |
--------------------------------------------------------------------------------
/canals/pipeline/descriptions.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import List, Dict
5 | import logging
6 |
7 | import networkx  # type:ignore
8 |
9 | from canals.type_utils import _type_name
10 | from canals.component.sockets import InputSocket, OutputSocket
11 |
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | def find_pipeline_inputs(graph: networkx.MultiDiGraph) -> Dict[str, List[InputSocket]]:
17 | """
18 | Collect components that have disconnected input sockets. Note that this method returns *ALL* disconnected
19 | input sockets, including all such sockets with default values.
20 | """
21 | return {
22 | name: [socket for socket in data.get("input_sockets", {}).values() if not socket.senders or socket.is_variadic]
23 | for name, data in graph.nodes(data=True)
24 | }
25 |
26 |
27 | def find_pipeline_outputs(graph: networkx.MultiDiGraph) -> Dict[str, List[OutputSocket]]:
28 | """
29 | Collect components that have disconnected output sockets. They define the pipeline output.
30 | """
31 | return {
32 | name: [socket for socket in data.get("output_sockets", {}).values() if not socket.receivers]
33 | for name, data in graph.nodes(data=True)
34 | }
35 |
36 |
37 | def describe_pipeline_inputs(graph: networkx.MultiDiGraph):
38 | """
39 | Returns a dictionary with the input names and types that this pipeline accepts.
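For example, a pipeline whose only open input is an `adder` component with a mandatory `value` socket
would be described as (illustrative): {"adder": {"value": {"type": int, "is_mandatory": True}}}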
40 | """
41 | inputs = {
42 | comp: {socket.name: {"type": socket.type, "is_mandatory": socket.is_mandatory} for socket in data}
43 | for comp, data in find_pipeline_inputs(graph).items()
44 | if data
45 | }
46 | return inputs
47 |
48 |
49 | def describe_pipeline_inputs_as_string(graph: networkx.MultiDiGraph):
50 | """
51 | Returns a string representation of the input names and types that this pipeline accepts.
52 | """
53 | inputs = describe_pipeline_inputs(graph)
54 | message = "This pipeline expects the following inputs:\n"
55 | for comp, sockets in inputs.items():
56 | if sockets:
57 | message += f"- {comp}:\n"
58 | for name, socket in sockets.items():
59 | message += f"  - {name}: {_type_name(socket['type'])}\n"
60 | return message
61 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/__init__.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from canals.pipeline.draw.draw import _draw, _convert, _convert_for_debug, RenderingEngines
5 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/draw.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Literal, Optional, Dict, get_args, Any
5 |
6 | import logging
7 | from pathlib import Path
8 |
9 | import networkx  # type:ignore
10 |
11 | from canals.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs
12 | from canals.pipeline.draw.graphviz import _to_agraph
13 | from canals.pipeline.draw.mermaid import _to_mermaid_image, _to_mermaid_text
14 | from canals.type_utils import _type_name
15 |
16 | logger = logging.getLogger(__name__)
17 | RenderingEngines = Literal["graphviz", "mermaid-image", "mermaid-text"]
18 |
19 |
20 | def _draw(
21 | graph: networkx.MultiDiGraph,
22 | path: Path,
23 | engine: RenderingEngines = "mermaid-image",
24 | style_map: Optional[Dict[str, str]] = None,
25 | ) -> None:
26 | """
27 | Renders the pipeline graph and saves it to file.
28 | """
29 | converted_graph = _convert(graph=graph, engine=engine, style_map=style_map)
30 |
31 | if engine == "graphviz":
32 | converted_graph.draw(path)
33 |
34 | elif engine == "mermaid-image":
35 | with open(path, "wb") as imagefile:
36 | imagefile.write(converted_graph)
37 |
38 | elif engine == "mermaid-text":
39 | with open(path, "w", encoding="utf-8") as textfile:
40 | textfile.write(converted_graph)
41 |
42 | else:
43 | raise ValueError(f"Unknown rendering engine '{engine}'. Choose one from: {get_args(RenderingEngines)}.")
44 |
45 | logger.debug("Pipeline diagram saved at %s", path)
46 |
47 |
48 | def _convert_for_debug(
49 | graph: networkx.MultiDiGraph,
50 | ) -> Any:
51 | """
52 | Renders the pipeline graph, with additional debug information, as Mermaid text that can be rendered later.
53 | """
54 | graph = _prepare_for_drawing(graph=graph, style_map={})
55 | return _to_mermaid_text(graph=graph)
56 |
57 |
58 | def _convert(
59 | graph: networkx.MultiDiGraph,
60 | engine: RenderingEngines = "mermaid-image",
61 | style_map: Optional[Dict[str, str]] = None,
62 | ) -> Any:
63 | """
64 | Renders the pipeline graph with the requested rendering engine and returns the result.
65 | """
66 | graph = _prepare_for_drawing(graph=graph, style_map=style_map or {})
67 |
68 | if engine == "graphviz":
69 | return _to_agraph(graph=graph)
70 |
71 | if engine == "mermaid-image":
72 | return _to_mermaid_image(graph=graph)
73 |
74 | if engine == "mermaid-text":
75 | return _to_mermaid_text(graph=graph)
76 |
77 | raise ValueError(f"Unknown rendering engine '{engine}'. Choose one from: {get_args(RenderingEngines)}.")
78 |
79 |
80 | def _prepare_for_drawing(graph: networkx.MultiDiGraph, style_map: Dict[str, str]) -> networkx.MultiDiGraph:
81 | """
82 | Prepares the graph to be drawn: adds explicit input and output nodes, labels the edges, applies the styles, etc.
83 | """
84 | # Apply the styles
85 | if style_map:
86 | for node, style in style_map.items():
87 | graph.nodes[node]["style"] = style
88 |
89 | # Label the edges
90 | for inp, outp, key, data in graph.edges(keys=True, data=True):
91 | data["label"] = f"{data['from_socket'].name} -> {data['to_socket'].name}"
92 | graph.add_edge(inp, outp, key=key, **data)
93 |
94 | # Draw the inputs
95 | graph.add_node("input")
96 | for node, in_sockets in find_pipeline_inputs(graph).items():
97 | for in_socket in in_sockets:
98 | if not in_socket.senders and in_socket.is_mandatory:
99 | # If this socket has no sender it could be a socket that receives input
100 | # directly when running the Pipeline. We can't know that for sure, so when
101 | # in doubt we draw it as receiving input directly.
102 | graph.add_edge("input", node, label=in_socket.name, conn_type=_type_name(in_socket.type))
103 |
104 | # Draw the outputs
105 | graph.add_node("output")
106 | for node, out_sockets in find_pipeline_outputs(graph).items():
107 | for out_socket in out_sockets:
108 | graph.add_edge(node, "output", label=out_socket.name, conn_type=_type_name(out_socket.type))
109 |
110 | return graph
111 |
--------------------------------------------------------------------------------
/canals/pipeline/draw/graphviz.py:
--------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 |
6 | import networkx  # type:ignore
7 |
8 | from networkx.drawing.nx_agraph import to_agraph as nx_to_agraph  # type:ignore
9 |
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | # pyright: reportMissingImports=false
14 | # pylint: disable=unused-import,import-outside-toplevel
15 |
16 |
17 | def _to_agraph(graph: networkx.MultiDiGraph):
18 | """
19 | Renders a pipeline graph using PyGraphViz. You need to install it and all its system dependencies for it to work.
20 | """
21 | try:
22 | import pygraphviz  # type: ignore
23 | except (ModuleNotFoundError, ImportError) as exc:
24 | raise ImportError(
25 | "Can't use 'pygraphviz' to draw this pipeline: pygraphviz could not be imported. "
26 | "Make sure pygraphviz is installed and all its system dependencies are set up correctly."
27 |         ) from exc
28 |
29 |     for inp, outp, key, data in graph.out_edges("input", keys=True, data=True):
30 |         data["style"] = "dashed"
31 |         graph.add_edge(inp, outp, key=key, **data)
32 |
33 |     for inp, outp, key, data in graph.in_edges("output", keys=True, data=True):
34 |         data["style"] = "dashed"
35 |         graph.add_edge(inp, outp, key=key, **data)
36 |
37 |     graph.nodes["input"]["shape"] = "plain"
38 |     graph.nodes["output"]["shape"] = "plain"
39 |     agraph = nx_to_agraph(graph)
40 |     agraph.layout("dot")
41 |     return agraph
42 |
-------------------------------------------------------------------------------- /canals/pipeline/draw/mermaid.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | import logging
5 | import base64
6 |
7 | import requests
8 | import networkx # type:ignore
9 |
10 | from canals.errors import PipelineDrawingError
11 | from canals.type_utils import _type_name
12 |
13 | logger = logging.getLogger(__name__)
14 |
15 |
16 | MERMAID_STYLED_TEMPLATE = """
17 | %%{{ init: {{'theme': 'neutral' }} }}%%
18 |
19 | graph TD;
20 |
21 | {connections}
22 |
23 | classDef component text-align:center;
24 | """
25 |
26 |
27 | def _to_mermaid_image(graph: networkx.MultiDiGraph):
28 |     """
29 |     Renders a pipeline using Mermaid (hosted version at 'https://mermaid.ink'). Requires Internet access.
30 |     """
31 |     graph_styled = _to_mermaid_text(graph=graph)
32 |
33 |     graphbytes = graph_styled.encode("ascii")
34 |     base64_bytes = base64.b64encode(graphbytes)
35 |     base64_string = base64_bytes.decode("ascii")
36 |     url = "https://mermaid.ink/img/" + base64_string
37 |
38 |     logger.debug("Rendering graph at %s", url)
39 |     try:
40 |         resp = requests.get(url, timeout=10)
41 |         if resp.status_code >= 400:
42 |             logger.warning("Failed to draw the pipeline: https://mermaid.ink/img/ returned status %s", resp.status_code)
43 |             logger.info("Exact URL requested: %s", url)
44 |             logger.warning("No pipeline diagram will be saved.")
45 |             resp.raise_for_status()
46 |
47 |     except Exception as exc: # pylint: disable=broad-except
48 |         logger.warning("Failed to draw the pipeline: could not connect to https://mermaid.ink/img/ (%s)", exc)
49 |         logger.info("Exact URL requested: %s", url)
50 |         logger.warning("No pipeline diagram will be saved.")
51 |         raise PipelineDrawingError(
52 |             "There was an issue with https://mermaid.ink/, see the stacktrace for details."
53 |         ) from exc
54 |
55 |     return resp.content
56 |
57 |
58 | def _to_mermaid_text(graph: networkx.MultiDiGraph) -> str:
59 |     """
60 |     Converts a Networkx graph into Mermaid syntax. The output of this function can be used in the documentation
61 |     with `mermaid` codeblocks and it will be automatically rendered.
62 |     """
63 |     sockets = {
64 |         comp: "".join(
65 |             [
66 |                 f"<li>{name} ({_type_name(socket.type)})</li>"
67 |                 for name, socket in data.get("input_sockets", {}).items()
68 |                 if (not socket.is_mandatory and not socket.senders) or socket.is_variadic
69 |             ]
70 |         )
71 |         for comp, data in graph.nodes(data=True)
72 |     }
73 |     optional_inputs = {
74 |         comp: f"<br><br>Optional inputs:<ul>{sockets}</ul>" if sockets else ""
75 |         for comp, sockets in sockets.items()
76 |     }
77 |
78 |     states = {
79 |         comp: f"{comp}[\"{comp}<br>{type(data['instance']).__name__}{optional_inputs[comp]}\"]:::component"
80 |         for comp, data in graph.nodes(data=True)
81 |         if comp not in ["input", "output"]
82 |     }
83 |
84 |     connections_list = [
85 |         f"{states[from_comp]} -- \"{conn_data['label']}<br>{conn_data['conn_type']}\" --> {states[to_comp]}"
86 |         for from_comp, to_comp, conn_data in graph.edges(data=True)
87 |         if from_comp != "input" and to_comp != "output"
88 |     ]
89 |     input_connections = [
90 |         f"i{{*}} -- \"{conn_data['label']}<br>{conn_data['conn_type']}\" --> {states[to_comp]}"
91 |         for _, to_comp, conn_data in graph.out_edges("input", data=True)
92 |     ]
93 |     output_connections = [
94 |         f"{states[from_comp]} -- \"{conn_data['label']}<br>{conn_data['conn_type']}\"--> o{{*}}"
95 |         for from_comp, _, conn_data in graph.in_edges("output", data=True)
96 |     ]
97 |     connections = "\n".join(connections_list + input_connections + output_connections)
98 |
99 |     graph_styled = MERMAID_STYLED_TEMPLATE.format(connections=connections)
100 |     logger.debug("Mermaid diagram:\n%s", graph_styled)
101 |
102 |     return graph_styled
103 |
-------------------------------------------------------------------------------- /canals/pipeline/validation.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Dict, Any
5 | import logging
6 |
7 | import networkx # type:ignore
8 |
9 | from canals.errors import PipelineValidationError
10 | from canals.component.sockets import InputSocket
11 | from canals.pipeline.descriptions import find_pipeline_inputs, describe_pipeline_inputs_as_string
12 |
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def validate_pipeline_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]) -> Dict[str, Any]:
18 |     """
19 |     Make sure the pipeline is properly built and that the input received makes sense.
20 |     Returns the input values, validated and updated as needed.
21 |     """
22 |     if not any(sockets for sockets in find_pipeline_inputs(graph).values()):
23 |         raise PipelineValidationError("This pipeline has no inputs.")
24 |
25 |     # Make sure the input keys are all nodes of the pipeline
26 |     unknown_components = [key for key in input_values.keys() if key not in graph.nodes]
27 |     if unknown_components:
28 |         all_inputs = describe_pipeline_inputs_as_string(graph)
29 |         raise ValueError(
30 |             f"Pipeline received data for unknown component(s): {', '.join(unknown_components)}\n\n{all_inputs}"
31 |         )
32 |
33 |     # Make sure all necessary sockets are connected
34 |     _validate_input_sockets_are_connected(graph, input_values)
35 |
36 |     # Make sure that the pipeline input is only sent to nodes that won't receive data from other nodes
37 |     _validate_nodes_receive_only_expected_input(graph, input_values)
38 |
39 |     return input_values
40 |
41 |
42 | def _validate_input_sockets_are_connected(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]):
43 |     """
44 |     Make sure all the input nodes are receiving all the values they need, either from the Pipeline's input or from
45 |     other nodes.
46 |     """
47 |     valid_inputs = find_pipeline_inputs(graph)
48 |     for node, sockets in valid_inputs.items():
49 |         for socket in sockets:
50 |             inputs_for_node = input_values.get(node, {})
51 |             missing_input_value = (
52 |                 inputs_for_node is None
53 |                 or socket.name not in inputs_for_node.keys()
54 |                 or inputs_for_node.get(socket.name, None) is None
55 |             )
56 |             if missing_input_value and socket.is_mandatory and not socket.is_variadic:
57 |                 all_inputs = describe_pipeline_inputs_as_string(graph)
58 |                 raise ValueError(f"Missing input: {node}.{socket.name}\n\n{all_inputs}")
59 |
60 |
61 | def _validate_nodes_receive_only_expected_input(graph: networkx.MultiDiGraph, input_values: Dict[str, Any]):
62 |     """
63 |     Make sure that every input node is only receiving input values from EITHER the pipeline's input or another node,
64 |     but never from both.
65 | """ 66 | for node, input_data in input_values.items(): 67 | for socket_name in input_data.keys(): 68 | if input_data.get(socket_name, None) is None: 69 | continue 70 | if not socket_name in graph.nodes[node]["input_sockets"].keys(): 71 | all_inputs = describe_pipeline_inputs_as_string(graph) 72 | raise ValueError( 73 | f"Component {node} is not expecting any input value called {socket_name}.\n\n{all_inputs}", 74 | ) 75 | 76 | input_socket: InputSocket = graph.nodes[node]["input_sockets"][socket_name] 77 | if input_socket.senders and not input_socket.is_variadic: 78 | raise ValueError(f"The input {socket_name} of {node} is already sent by: {input_socket.senders}") 79 | -------------------------------------------------------------------------------- /canals/serialization.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import inspect 5 | from typing import Type, Dict, Any 6 | 7 | from canals.errors import DeserializationError, SerializationError 8 | 9 | 10 | def component_to_dict(obj: Any) -> Dict[str, Any]: 11 | """ 12 | The marshaller used by the Pipeline. If a `to_dict` method is present in the 13 | component instance, that will be used instead of the default method. 14 | """ 15 | if hasattr(obj, "to_dict"): 16 | return obj.to_dict() 17 | 18 | init_parameters = {} 19 | for name, param in inspect.signature(obj.__init__).parameters.items(): 20 | # Ignore `args` and `kwargs`, used by the default constructor 21 | if name in ("args", "kwargs"): 22 | continue 23 | try: 24 | # This only works if the Component constructor assigns the init 25 | # parameter to an instance variable or property with the same name 26 | param_value = getattr(obj, name) 27 | except AttributeError as e: 28 | # If the parameter doesn't have a default value, raise an error 29 | if param.default is param.empty: 30 | raise SerializationError( 31 | f"Cannot determine the value of the init parameter '{name}' for the class {obj.__class__.__name__}." 32 | f"You can fix this error by assigning 'self.{name} = {name}' or adding a " 33 | f"custom serialization method 'to_dict' to the class." 34 | ) from e 35 | # In case the init parameter was not assigned, we use the default value 36 | param_value = param.default 37 | init_parameters[name] = param_value 38 | 39 | return default_to_dict(obj, **init_parameters) 40 | 41 | 42 | def component_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any: 43 | """ 44 | The unmarshaller used by the Pipeline. If a `from_dict` method is present in the 45 | component instance, that will be used instead of the default method. 46 | """ 47 | if hasattr(cls, "from_dict"): 48 | return cls.from_dict(data) 49 | 50 | return default_from_dict(cls, data) 51 | 52 | 53 | def default_to_dict(obj: Any, **init_parameters) -> Dict[str, Any]: 54 | """ 55 | Utility function to serialize an object to a dictionary. 56 | This is mostly necessary for Components but it can be used by any object. 57 | 58 | `init_parameters` are parameters passed to the object class `__init__`. 59 | They must be defined explicitly as they'll be used when creating a new 60 | instance of `obj` with `from_dict`. Omitting them might cause deserialisation 61 | errors or unexpected behaviours later, when calling `from_dict`. 
62 |
63 |     An example usage:
64 |
65 |     ```python
66 |     class MyClass:
67 |         def __init__(self, my_param: int = 10):
68 |             self.my_param = my_param
69 |
70 |         def to_dict(self):
71 |             return default_to_dict(self, my_param=self.my_param)
72 |
73 |
74 |     obj = MyClass(my_param=1000)
75 |     data = obj.to_dict()
76 |     assert data == {
77 |         "type": "__main__.MyClass",  # the type name is module-qualified
78 |         "init_parameters": {
79 |             "my_param": 1000,
80 |         },
81 |     }
82 |     ```
83 |     """
84 |     return {
85 |         "type": f"{obj.__class__.__module__}.{obj.__class__.__name__}",
86 |         "init_parameters": init_parameters,
87 |     }
88 |
89 |
90 | def default_from_dict(cls: Type[object], data: Dict[str, Any]) -> Any:
91 |     """
92 |     Utility function to deserialize a dictionary to an object.
93 |     This is mostly necessary for Components but it can be used by any object.
94 |
95 |     The function will raise a `DeserializationError` if the `type` field in `data` is
96 |     missing or it doesn't match the type of `cls`.
97 |
98 |     If `data` contains an `init_parameters` field it will be used as parameters to create
99 |     a new instance of `cls`.
100 |     """
101 |     init_params = data.get("init_parameters", {})
102 |     if "type" not in data:
103 |         raise DeserializationError("Missing 'type' in serialization data")
104 |     if data["type"] != f"{cls.__module__}.{cls.__name__}":
105 |         raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'")
106 |     return cls(**init_params)
107 |
-------------------------------------------------------------------------------- /canals/testing/__init__.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 |
-------------------------------------------------------------------------------- /canals/testing/factory.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Any, Dict, Optional, Tuple, Type
5 |
6 | from canals import component, Component
7 | from canals.serialization import default_to_dict, default_from_dict
8 |
9 |
10 | def component_class(
11 |     name: str,
12 |     input_types: Optional[Dict[str, Any]] = None,
13 |     output_types: Optional[Dict[str, Any]] = None,
14 |     output: Optional[Dict[str, Any]] = None,
15 |     bases: Optional[Tuple[type, ...]] = None,
16 |     extra_fields: Optional[Dict[str, Any]] = None,
17 | ) -> Type[Component]:
18 |     """
19 |     Utility function to create a Component class with the given name and input and output types.
20 |
21 |     If `output` is set but `output_types` is not, `output_types` will be set to the types of the values in `output`.
22 |     If instead `output_types` is set but `output` is not, the component's `run` method will return a dictionary
23 |     with the same keys as `output_types`, all with a value of None.
24 |
25 |     ### Usage
26 |
27 |     Create a component class with default input and output types:
28 |     ```python
29 |     MyFakeComponent = component_class("MyFakeComponent")
30 |     component = MyFakeComponent()
31 |     output = component.run(value=1)
32 |     assert output == {"value": None}
33 |     ```
34 |
35 |     Create a component class with a "value" input of type `int` and with a "value" output of `10`:
36 |     ```python
37 |     MyFakeComponent = component_class(
38 |         "MyFakeComponent",
39 |         input_types={"value": int},
40 |         output={"value": 10}
41 |     )
42 |     component = MyFakeComponent()
43 |     output = component.run(value=1)
44 |     assert output == {"value": 10}
45 |     ```
46 |
47 |     Create a component class with a custom base class:
48 |     ```python
49 |     MyFakeComponent = component_class(
50 |         "MyFakeComponent",
51 |         bases=(MyBaseClass,)
52 |     )
53 |     component = MyFakeComponent()
54 |     assert isinstance(component, MyBaseClass)
55 |     ```
56 |
57 |     Create a component class with an extra field `my_field`:
58 |     ```python
59 |     MyFakeComponent = component_class(
60 |         "MyFakeComponent",
61 |         extra_fields={"my_field": 10}
62 |     )
63 |     component = MyFakeComponent()
64 |     assert component.my_field == 10
65 |     ```
66 |
67 |     Args:
68 |         name: Name of the component class.
69 |         input_types: Dictionary of string and type that defines the inputs of the component;
70 |             if set to None, the created component will expect a single input "value" of Any type.
71 |             Defaults to None.
72 |         output_types: Dictionary of string and type that defines the outputs of the component;
73 |             if set to None, the created component will return a single output "value" of NoneType and None value.
74 |             Defaults to None.
75 |         output: Actual output dictionary returned by the created component's run;
76 |             if set to None, it will return a dictionary of string and None values.
77 |             Keys will be the same as the keys of output_types. Defaults to None.
78 |         bases: Base classes for this component; if set to None, the only base is `object`. Defaults to None.
79 |         extra_fields: Extra fields for the Component. Defaults to None.
80 |
81 |     :return: A class definition that can be used as a component.
82 |     """
83 |     if input_types is None:
84 |         input_types = {"value": Any}
85 |     if output_types is None and output is not None:
86 |         output_types = {key: type(value) for key, value in output.items()}
87 |     elif output_types is None:
88 |         output_types = {"value": type(None)}
89 |
90 |     def init(self):
91 |         component.set_input_types(self, **input_types)
92 |         component.set_output_types(self, **output_types)
93 |
94 |     # Both arguments are necessary to correctly define
95 |     # run but pylint doesn't like that we don't use them.
96 |     # It's fine ignoring the warning here.
97 |     def run(self, **kwargs): # pylint: disable=unused-argument
98 |         if output is not None:
99 |             return output
100 |         return {name: None for name in output_types.keys()}
101 |
102 |     def to_dict(self):
103 |         return default_to_dict(self)
104 |
105 |     def from_dict(cls, data: Dict[str, Any]):
106 |         return default_from_dict(cls, data)
107 |
108 |     fields = {
109 |         "__init__": init,
110 |         "run": run,
111 |         "to_dict": to_dict,
112 |         "from_dict": classmethod(from_dict),
113 |     }
114 |     if extra_fields is not None:
115 |         fields = {**fields, **extra_fields}
116 |
117 |     if bases is None:
118 |         bases = (object,)
119 |
120 |     cls = type(name, bases, fields)
121 |     return component(cls)
122 |
-------------------------------------------------------------------------------- /canals/type_utils.py: --------------------------------------------------------------------------------
1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
2 | #
3 | # SPDX-License-Identifier: Apache-2.0
4 | from typing import Union, get_args, get_origin, Any
5 |
6 | import logging
7 |
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | def _is_optional(type_: type) -> bool:
13 |     """
14 |     Utility method that returns whether a type is Optional.
15 |     """
16 |     return get_origin(type_) is Union and type(None) in get_args(type_)
17 |
18 |
19 | def _types_are_compatible(sender, receiver): # pylint: disable=too-many-return-statements
20 |     """
21 |     Checks whether the sender type is equal to, or a subtype of, the receiver type. Used to validate pipeline connections.
22 |
23 |     Note: this method has no pretense to perform proper type matching. It especially does not deal with aliasing of
24 |     typing classes such as `List` or `Dict` to their runtime counterparts `list` and `dict`. It also does not deal well
25 |     with "bare" types, so `List` is treated differently from `List[Any]`, even though they should be the same.
26 |
27 |     Consider simplifying the typing of your components if you observe unexpected errors during component connection.
28 |     """
29 |     if sender == receiver or receiver is Any:
30 |         return True
31 |
32 |     if sender is Any:
33 |         return False
34 |
35 |     try:
36 |         if issubclass(sender, receiver):
37 |             return True
38 |     except TypeError: # typing classes can't be used with issubclass, so we deal with them below
39 |         pass
40 |
41 |     sender_origin = get_origin(sender)
42 |     receiver_origin = get_origin(receiver)
43 |
44 |     if sender_origin is not Union and receiver_origin is Union:
45 |         return any(_types_are_compatible(sender, union_arg) for union_arg in get_args(receiver))
46 |
47 |     if not sender_origin or not receiver_origin or sender_origin != receiver_origin:
48 |         return False
49 |
50 |     sender_args = get_args(sender)
51 |     receiver_args = get_args(receiver)
52 |     if len(sender_args) > len(receiver_args):
53 |         return False
54 |
55 |     return all(_types_are_compatible(*args) for args in zip(sender_args, receiver_args))
56 |
57 |
58 | def _type_name(type_):
59 |     """
60 |     Utility method to get a nice readable representation of a type.
61 |     Handles Optional and Literal in a special way to make it more readable.
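    Examples (illustrative only, not an exhaustive spec of the formatting):
        _type_name(int)           -> "int"
        _type_name(Optional[int]) -> "Optional[int]"
        _type_name(List[str])     -> "List[str]"
        _type_name(Literal["a"])  -> "Literal['a']"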
62 | """ 63 | # Literal args are strings, so we wrap them in quotes to make it clear 64 | if isinstance(type_, str): 65 | return f"'{type_}'" 66 | 67 | name = getattr(type_, "__name__", str(type_)) 68 | 69 | if name.startswith("typing."): 70 | name = name[7:] 71 | if "[" in name: 72 | name = name.split("[")[0] 73 | args = get_args(type_) 74 | if name == "Union" and type(None) in args and len(args) == 2: 75 | # Optional is technically a Union of type and None 76 | # but we want to display it as Optional 77 | name = "Optional" 78 | 79 | if args: 80 | args = ", ".join([_type_name(a) for a in args if a is not type(None)]) 81 | return f"{name}[{args}]" 82 | 83 | return f"{name}" 84 | -------------------------------------------------------------------------------- /docs/api-docs/canals.md: -------------------------------------------------------------------------------- 1 | # Canals 2 | 3 | ::: canals 4 | 5 | ::: canals.__about__ 6 | 7 | ::: canals.errors 8 | 9 | ::: canals.type_utils 10 | 11 | ::: canals.serialization 12 | -------------------------------------------------------------------------------- /docs/api-docs/component.md: -------------------------------------------------------------------------------- 1 | # Component API 2 | 3 | ::: canals.component 4 | 5 | ::: canals.component.component 6 | 7 | ::: canals.component.sockets 8 | -------------------------------------------------------------------------------- /docs/api-docs/pipeline.md: -------------------------------------------------------------------------------- 1 | # Pipeline 2 | 3 | ::: canals.pipeline 4 | 5 | ::: canals.pipeline.draw 6 | 7 | ::: canals.pipeline.draw.draw 8 | 9 | ::: canals.pipeline.draw.graphviz 10 | 11 | ::: canals.pipeline.draw.mermaid 12 | 13 | ::: canals.pipeline.pipeline 14 | 15 | ::: canals.pipeline.validation 16 | -------------------------------------------------------------------------------- /docs/api-docs/testing.md: -------------------------------------------------------------------------------- 1 | # Testing 2 | 3 | ::: canals.testing 4 | 5 | ::: canals.testing.factory 6 | -------------------------------------------------------------------------------- /docs/concepts/components.md: -------------------------------------------------------------------------------- 1 | # Components 2 | 3 | In order to be recognized as components and work in a Pipeline, Components must follow the contract below. 4 | 5 | ## Requirements 6 | 7 | ### `@component` decorator 8 | 9 | All component classes must be decorated with the `@component` decorator. This allows Canals to discover them. 10 | 11 | ### `@component.input` 12 | 13 | All components must decorate one single method with the `@component.input` decorator. This method must return a dataclass, which will be used as structure of the input of the component. 14 | 15 | For example, if the node is expecting a list of Documents, the fields of the returned dataclass should be `documents: List[Document]`. Note that you don't need to decorate the dataclass youself: `@component.input` will add the decorator for you. 16 | 17 | Here is an example of such method: 18 | 19 | ```python 20 | @component.input 21 | def input(self): 22 | class Input: 23 | value: int 24 | add: int 25 | 26 | return Input 27 | ``` 28 | 29 | Defaults are allowed, as much as default factories and other dataclass properties. 
30 |
31 | By default `@component.input` sets `None` as the default for all fields, regardless of their definition: this gives you the
32 | possibility of passing only a part of the input to the pipeline without defining every field of the component. For example,
33 | using the above definition, you can create an Input dataclass as:
34 |
35 | ```python
36 | self.input(add=3)
37 | ```
38 |
39 | and the resulting dataclass will look like `Input(value=None, add=3)`.
40 |
41 | However, if you don't explicitly define the fields as Optional, Pipeline will make sure to collect all the values of this
42 | dataclass before calling the `run()` method, making them in practice non-optional.
43 |
44 | If you instead define a specific field as Optional in the dataclass, Pipeline will **not** wait for it, and will
45 | run the component as soon as all the non-optional fields have received a value or, if all fields are optional, as soon as
46 | at least one of them has received one.
47 |
48 | This behavior allows Canals to define loops by not waiting on both incoming inputs of the entry component of the loop,
49 | and instead running as soon as at least one of them receives a value.
50 |
51 | ### `@component.output`
52 |
53 | All components must decorate one single method with the `@component.output` decorator. This method must return a dataclass, which will be used as the structure of the component's output.
54 |
55 | For example, if the node is producing a list of Documents, the returned dataclass should have a field like `documents: List[Document]`. Note that you don't need to decorate the dataclass yourself: `@component.output` will add the decorator for you.
56 |
57 | Here is an example of such a method:
58 |
59 | ```python
60 | @component.output
61 | def output(self):
62 |     class Output:
63 |         value: int
64 |
65 |     return Output
66 | ```
67 |
68 | Defaults are allowed, as much as default factories and other dataclass properties.
69 |
70 | ### `__init__(self, **kwargs)`
71 |
72 | Optional method.
73 |
74 | Components may have an `__init__` method where they define:
75 |
76 | - `self.defaults = {parameter_name: parameter_default_value, ...}`:
77 |     All values defined here will be sent to the `run()` method when the Pipeline calls it.
78 |     If any of these parameters is also receiving input from other components, those have precedence.
79 |     This collection of values is supposed to replace the need for default values in `run()` and make them
80 |     dynamically configurable. Keep in mind that only these defaults will count at runtime: defaults given to
81 |     the `Input` dataclass (see above) will be ignored.
82 |
83 | - `self.init_parameters = {same parameters that the __init__ method received}`:
84 |     In this dictionary you can store any state the component wishes to persist when it is saved.
85 |     These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
86 |     Note that by default the `@component` decorator saves the arguments automatically.
87 |     However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
88 |     Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
89 |
90 | Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
91 | dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
92 | time. If you need such values, consider serializing them to a string. A sketch of both attributes follows below.
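As an illustration only — the component and its `multiply_by` parameter are invented for this sketch — an `__init__` using both attributes described above could look like:

```python
def __init__(self, multiply_by: int = 2):
    # Sent to run() when the Pipeline calls it, unless another component's
    # output already provides this value (connected inputs take precedence).
    self.defaults = {"multiply_by": multiply_by}
    # Persisted on save and passed back to __init__ on load.
    # Must contain only JSON-serializable values.
    self.init_parameters = {"multiply_by": multiply_by}
```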
93 |
94 | _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_
95 |
96 | The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
97 | validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.) refer to
98 | the `warm_up()` method.
99 |
100 |
101 | ### `warm_up(self)`
102 |
103 | Optional method.
104 |
105 | This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
106 | because Pipeline will not keep track of which components it called `warm_up()` on.
107 |
108 |
109 | ### `run(self, data)`
110 |
111 | Mandatory method.
112 |
113 | This is the method where the main functionality of the component should be carried out. It's called by
114 | `Pipeline.run()`.
115 |
116 | When the component should run, Pipeline will call this method with an instance of the dataclass returned by the method decorated with `@component.input`. This dataclass contains:
117 |
118 | - all the input values coming from other components connected to it,
119 | - if any is missing, the corresponding value defined in `self.defaults`, if it exists.
120 |
121 | `run()` must return a single instance of the dataclass declared through the method decorated with `@component.output`.
122 |
123 |
124 | ## Example components
125 |
126 | Here is an example of a simple component that adds a fixed value to its input and returns their sum.
127 |
128 | ```python
129 | from typing import Optional
130 | from canals.component import component
131 |
132 | @component
133 | class AddFixedValue:
134 |     """
135 |     Adds the value of `add` to `value`. If not given, `add` defaults to 1.
136 |     """
137 |
138 |     @component.input  # type: ignore
139 |     def input(self):
140 |         class Input:
141 |             value: int
142 |             add: int
143 |
144 |         return Input
145 |
146 |     @component.output  # type: ignore
147 |     def output(self):
148 |         class Output:
149 |             value: int
150 |
151 |         return Output
152 |
153 |     def __init__(self, add: Optional[int] = 1):
154 |         if add:
155 |             self.defaults = {"add": add}
156 |
157 |     def run(self, data):
158 |         return self.output(value=data.value + data.add)
159 | ```
160 |
161 | See `tests/sample_components` for examples of more complex components with variable inputs and outputs, and so on.
162 |
-------------------------------------------------------------------------------- /docs/concepts/concepts.md: --------------------------------------------------------------------------------
1 | # Core concepts
2 |
3 | Canals is a **component orchestration engine**. It can be used to connect a group of smaller objects, called Components,
4 | that perform well-defined tasks into a network, called Pipeline, to achieve a larger goal.
5 |
6 | Components are Python objects that can execute a task, like reading a file, performing calculations, or making API
7 | calls. Canals connects these objects together: it builds a graph of components and takes care of managing their
8 | execution order, making sure that each object receives the input it expects from the other components of the pipeline at the right time.
9 |
10 | Canals relies on two main concepts: Components and Pipelines.
11 |
12 | ## What is a Component?
13 |
14 | A Component is a Python class that performs a well-defined task: for example a REST API call, a mathematical operation,
15 | a data transformation, writing something to a file or a database, and so on.
16 |
17 | To be recognized as a Component by Canals, a Python class needs to respect these rules:
18 |
19 | 1. Must be decorated with the `@component` decorator.
20 | 2. Must have a `run()` method that accepts a `data` parameter of type `ComponentInput` and returns a single object of type `ComponentOutput`.
21 |
22 | For example, the following is a Component that sums up two numbers:
23 |
24 | ```python
25 | from typing import Optional
26 | from canals.component import component, ComponentInput, ComponentOutput
27 |
28 | @component
29 | class AddFixedValue:
30 |     """
31 |     Adds the value of `add` to `value`. If not given, `add` defaults to 1.
32 |     """
33 |
34 |     @component.input
35 |     def input(self):
36 |         class Input:
37 |             value: int
38 |             add: int
39 |
40 |         return Input
41 |
42 |     @component.output
43 |     def output(self):
44 |         class Output:
45 |             value: int
46 |
47 |         return Output
48 |
49 |     def __init__(self, add: Optional[int] = 1):
50 |         if add:
51 |             self.defaults = {"add": add}
52 |
53 |     def run(self, data):
54 |         return self.output(value=data.value + data.add)
55 | ```
56 |
57 | We will see the details of all of these requirements below.
58 |
59 | ## What is a Pipeline?
60 |
61 | A Pipeline is a network of Components. Pipelines define which components receive input from and send output to which
62 | others, make sure all the connections are valid, and take care of calling each component's `run()` method in the right order.
63 |
64 | Pipeline connects components together through so-called connections, which are the edges of the pipeline graph.
65 | Pipeline is going to make sure that all the connections are valid based on the inputs and outputs that Components have
66 | declared.
67 |
68 | For example, if a component produces a value of type `List[Document]` and another component expects an input
69 | of type `List[Document]`, Pipeline will be able to connect them. Otherwise, it will raise an exception.
70 |
71 | This is a simple example of how a Pipeline is created:
72 |
73 |
74 | ```python
75 | from canals.pipeline import Pipeline
76 |
77 | # Some Canals components
78 | from my_components import AddFixedValue, MultiplyBy
79 |
80 | pipeline = Pipeline()
81 |
82 | # Components can be initialized as standalone objects.
83 | # These instances can be added to the Pipeline in several places.
84 | multiplication = MultiplyBy(multiply_by=2)
85 | addition = AddFixedValue(add=1)
86 |
87 | # Components are added with a name and a component instance
88 | pipeline.add_component("double", multiplication)
89 | pipeline.add_component("add_one", addition)
90 | pipeline.add_component("add_one_again", addition)  # Component instances can be reused
91 | pipeline.add_component("add_two", AddFixedValue(add=2))
92 |
93 | # Connect the components together
94 | pipeline.connect(connect_from="double", connect_to="add_one")
95 | pipeline.connect(connect_from="add_one", connect_to="add_one_again")
96 | pipeline.connect(connect_from="add_one_again", connect_to="add_two")
97 |
98 | # Pipeline can be drawn
99 | pipeline.draw("pipeline.jpg")
100 |
101 | # Pipelines are run by giving them the data that the input nodes expect.
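# The keys of `data` are component names; the values are instances of each
# component's Input dataclass. Only true entry points need a value here:
# every other component receives its input through the connections above.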
102 | results = pipeline.run(data={"double": multiplication.input(value=1)})
103 |
104 | print(results)
105 |
106 | # prints {"add_two": AddFixedValue.Output(value=6)}
107 | ```
108 |
109 | This is what the pipeline's graph looks like:
110 |
111 | ```mermaid
112 | graph TD;
113 | double -- value -> value --> add_one
114 | add_one -- value -> value --> add_one_again
115 | add_one_again -- value -> value --> add_two
116 | IN([input]) -- value --> double
117 | add_two -- value --> OUT([output])
118 | ```
119 |
-------------------------------------------------------------------------------- /docs/concepts/pipelines.md: --------------------------------------------------------------------------------
1 | # Pipelines
2 |
3 | Canals aims to support pipelines of (close to) arbitrary complexity. It currently supports a variety of different topologies, such as:
4 |
5 | - Simple linear pipelines
6 | - Branching pipelines where all or only some branches are executed
7 | - Pipelines merging a variable number of inputs, depending on decisions taken upstream
8 | - Simple loops
9 | - Multiple entry components, either alternative or parallel
10 | - Multiple exit components, either alternative or parallel
11 |
12 | Check the pipeline's test suite for some examples.
13 |
14 | ## Validation
15 |
16 | Pipeline performs validation at the connection type level: when calling `Pipeline.connect()`, it uses the `@component.input` and `@component.output` dataclass fields to make sure that the connection is possible.
17 |
18 | On top of this, specific connections can be specified with the syntax `component_name.input_or_output_field`.
19 |
20 | For example, let's imagine we have two components with the following I/O declared:
21 |
22 | ```python
23 | @component
24 | class ComponentA:
25 |
26 |     @component.input
27 |     def input(self):
28 |         class Input:
29 |             input_value: int
30 |
31 |         return Input
32 |
33 |     @component.output
34 |     def output(self):
35 |         class Output:
36 |             output_value: str
37 |
38 |         return Output
39 |
40 |     def run(self, data):
41 |         return self.output(output_value="hello")
42 |
43 | @component
44 | class ComponentB:
45 |
46 |     @component.input
47 |     def input(self):
48 |         class Input:
49 |             input_value: str
50 |
51 |         return Input
52 |
53 |     @component.output
54 |     def output(self):
55 |         class Output:
56 |             output_value: List[str]
57 |
58 |         return Output
59 |
60 |     def run(self, data):
61 |         return self.output(output_value=["h", "e", "l", "l", "o"])
62 | ```
63 |
64 | This is the behavior of `Pipeline.connect()`:
65 |
66 | ```python
67 | pipeline.add_component('component_a', ComponentA())
68 | pipeline.add_component('component_b', ComponentB())
69 |
70 | # All of these succeed
71 | pipeline.connect('component_a', 'component_b')
72 | pipeline.connect('component_a.output_value', 'component_b')
73 | pipeline.connect('component_a', 'component_b.input_value')
74 | pipeline.connect('component_a.output_value', 'component_b.input_value')
75 | ```
76 |
77 | These, instead, fail:
78 |
79 | ```python
80 | pipeline.connect('component_a', 'component_a')
81 | # canals.errors.PipelineConnectError: Cannot connect 'component_a' with 'component_a': no matching connections available.
82 | # 'component_a':
83 | #  - output_value (str)
84 | # 'component_a':
85 | #  - input_value (int, available)
86 |
87 | pipeline.connect('component_b', 'component_a')
88 | # canals.errors.PipelineConnectError: Cannot connect 'component_b' with 'component_a': no matching connections available.
89 | # 'component_b':
90 | #  - output_value (List[str])
91 | # 'component_a':
92 | #  - input_value (int, available)
93 | ```
94 |
95 | In addition, component names are validated:
96 |
97 | ```python
98 | pipeline.connect('component_a', 'component_c')
99 | # ValueError: Component named component_c not found in the pipeline.
100 | ```
101 |
102 | Input and output names are validated too, when explicitly stated:
103 |
104 | ```python
105 | pipeline.connect('component_a.typo', 'component_b')
106 | # canals.errors.PipelineConnectError: 'component_a.typo does not exist. Output connections of component_a are: output_value (type str)
107 |
108 | pipeline.connect('component_a', 'component_b.output_value')
109 | # canals.errors.PipelineConnectError: 'component_b.output_value does not exist. Input connections of component_b are: input_value (type str)
110 | ```
111 |
112 | ## Save and Load
113 |
114 | Pipelines can be serialized to Python dictionaries, which can then be dumped to JSON or to any other suitable format, like YAML, TOML, HCL, etc. Serialized pipelines can then be loaded back.
115 |
116 | Here is an example of Pipeline saving and loading:
117 |
118 | ```python
119 | import json
120 | from canals.pipeline import Pipeline, save_pipelines, load_pipelines
121 | pipe1 = Pipeline()
122 | pipe2 = Pipeline()
123 |
124 | # .. assemble the pipelines ...
125 |
126 | # Save the pipelines
127 | save_pipelines(
128 |     pipelines={
129 |         "pipe1": pipe1,
130 |         "pipe2": pipe2,
131 |     },
132 |     path="my_pipelines.json",
133 |     _writer=json.dumps
134 | )
135 |
136 | # Load the pipelines
137 | new_pipelines = load_pipelines(
138 |     path="my_pipelines.json",
139 |     _reader=json.loads
140 | )
141 |
142 | assert new_pipelines["pipe1"] == pipe1
143 | assert new_pipelines["pipe2"] == pipe2
144 | ```
145 |
146 | Note how the save/load functions accept a `_writer`/`_reader` function: this choice frees us from committing strongly to a specific markup language, and although a default will be set (be it YAML, TOML, HCL, or anything else), the decision can be overridden by passing another explicit reader/writer function to the `save_pipelines`/`load_pipelines` functions.
147 |
148 | This is how the resulting file would look, assuming a JSON writer was chosen.
149 |
150 | `my_pipelines.json`
151 |
152 | ```python
153 | {
154 |     "pipelines": {
155 |         "pipe1": {
156 |             # All the components that would be added with a
157 |             # Pipeline.add_component() call
158 |             "components": {
159 |                 "first_addition": {
160 |                     "type": "AddValue",
161 |                     "init_parameters": {
162 |                         "add": 1
163 |                     },
164 |                 },
165 |                 "double": {
166 |                     "type": "Double",
167 |                     "init_parameters": {}
168 |                 },
169 |                 "second_addition": {
170 |                     "type": "AddValue",
171 |                     "init_parameters": {
172 |                         "add": 1
173 |                     },
174 |                 },
175 |                 # This is how instances of the same component are reused
176 |                 "third_addition": {
177 |                     "refer_to": "pipe1.first_addition"
178 |                 },
179 |             },
180 |             # All the connections that would be made with a
181 |             # Pipeline.connect() call
182 |             "connections": [
183 |                 ("first_addition", "double", "value/value"),
184 |                 ("double", "second_addition", "value/value"),
185 |                 ("second_addition", "third_addition", "value/value"),
186 |             ],
187 |             # All other Pipeline.__init__() parameters go here.
188 |             "metadata": {"type": "test pipeline", "author": "me"},
189 |             "max_loops_allowed": 100,
190 |         },
191 |         "pipe2": {
192 |             "components": {
193 |                 "first_addition": {
194 |                     # We can reference components from other pipelines too!
195 | "refer_to": "pipe1.first_addition", 196 | }, 197 | "double": { 198 | "type": "Double", 199 | "init_parameters": {} 200 | }, 201 | "second_addition": { 202 | "refer_to": "pipe1.second_addition" 203 | }, 204 | }, 205 | "connections": [ 206 | ("first_addition", "double", "value/value"), 207 | ("double", "second_addition", "value/value"), 208 | ], 209 | "metadata": {"type": "another test pipeline", "author": "you"}, 210 | "max_loops_allowed": 100, 211 | }, 212 | }, 213 | # A list of "dependencies" for the application. 214 | # Used to ensure all external components are present when loading. 215 | "dependencies": ["my_custom_components_module"], 216 | } 217 | ``` 218 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Canals 2 | 3 |

4 | <!-- Canals logo (images/canals-logo-light.png for light mode, images/canals-logo-dark.png for dark mode) is embedded here -->
5 |
6 |
7 | [![PyPI - Version](https://img.shields.io/pypi/v/canals.svg)](https://pypi.org/project/canals)
8 | [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/canals.svg)](https://pypi.org/project/canals)
9 |
10 |
    11 | 12 | Canals is a **component orchestration engine**. Components are Python objects that can execute a task, like reading a file, performing calculations, or making API calls. Canals connects these objects together: it builds a graph of components and takes care of managing their execution order, making sure that each object receives the input it expects from the other components of the pipeline. 13 | 14 | Canals powers version 2.0 of [Haystack](https://github.com/deepset-ai/haystack). 15 | 16 | ## Installation 17 | 18 | Running: 19 | 20 | ```console 21 | pip install canals 22 | ``` 23 | 24 | gives you the bare minimum necessary to run Canals. 25 | 26 | To be able to draw pipelines, please make sure you have either an internet connection (to reach the Mermaid graph renderer at `https://mermaid.ink`) or [graphviz](https://graphviz.org/download/) installed and then install Canals as: 27 | 28 | ### Mermaid 29 | ```console 30 | pip install canals[mermaid] 31 | ``` 32 | 33 | ### GraphViz 34 | ```console 35 | sudo apt install graphviz # You may need `graphviz-dev` too 36 | pip install canals[graphviz] 37 | ``` 38 | -------------------------------------------------------------------------------- /images/canals-logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/images/canals-logo-dark.png -------------------------------------------------------------------------------- /images/canals-logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/images/canals-logo-light.png -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: Canals 2 | site_url: https://deepset-ai.github.io/canals/ 3 | 4 | theme: 5 | name: material 6 | features: 7 | - content.code.copy 8 | 9 | plugins: 10 | - search 11 | - mermaid2 12 | - mkdocstrings 13 | 14 | markdown_extensions: 15 | - pymdownx.highlight: 16 | anchor_linenums: true 17 | line_spans: __span 18 | pygments_lang_class: true 19 | - pymdownx.inlinehilite 20 | - pymdownx.snippets 21 | - pymdownx.superfences: 22 | preserve_tabs: true 23 | custom_fences: 24 | - name: mermaid 25 | class: mermaid 26 | format: !!python/name:pymdownx.superfences.fence_code_format 27 | 28 | extra_javascript: 29 | - optionalConfig.js 30 | - https://unpkg.com/mermaid@9.4.0/dist/mermaid.min.js 31 | - extra-loader.js 32 | 33 | nav: 34 | - Get Started: index.md 35 | - Concepts: 36 | - Core Concepts: concepts/concepts.md 37 | - Components: concepts/components.md 38 | - Pipelines: concepts/pipelines.md 39 | - API Docs: 40 | - Canals: api-docs/canals.md 41 | - Component: api-docs/component.md 42 | - Pipeline: api-docs/pipeline.md 43 | - Testing: api-docs/testing.md 44 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [project] 6 | name = "canals" 7 | description = 'A component orchestration engine for Haystack' 8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | license = "Apache-2.0" 11 | keywords = [] 12 | authors = [{ name = "ZanSara", email = 
"sara.zanzottera@deepset.ai" }] 13 | classifiers = [ 14 | "Development Status :: 3 - Alpha", 15 | "License :: Freely Distributable", 16 | "License :: OSI Approved :: Apache Software License", 17 | "Programming Language :: Python", 18 | "Programming Language :: Python :: 3.8", 19 | "Programming Language :: Python :: 3.9", 20 | "Programming Language :: Python :: 3.10", 21 | "Programming Language :: Python :: 3.11", 22 | "Programming Language :: Python :: Implementation :: CPython", 23 | "Programming Language :: Python :: Implementation :: PyPy", 24 | ] 25 | dynamic = ["version"] 26 | dependencies = [ 27 | "networkx", # Pipeline graphs 28 | "requests", # Mermaid diagrams 29 | "typing_extensions", 30 | ] 31 | 32 | [project.optional-dependencies] 33 | dev = [ 34 | "hatch", 35 | "pre-commit", 36 | "mypy", 37 | "pylint==2.15.10", 38 | "black[jupyter]==22.6.0", 39 | "pytest", 40 | "pytest-cov", 41 | "requests", 42 | "coverage", 43 | ] 44 | docs = ["mkdocs-material", "mkdocstrings[python]", "mkdocs-mermaid2-plugin"] 45 | 46 | [project.urls] 47 | Documentation = "https://github.com/deepset-ai/canals#readme" 48 | Issues = "https://github.com/deepset-ai/canals/issues" 49 | Source = "https://github.com/deepset-ai/canals" 50 | 51 | [tool.hatch.version] 52 | path = "canals/__about__.py" 53 | 54 | [tool.hatch.build] 55 | include = ["/canals/**/*.py"] 56 | 57 | [tool.hatch.envs.default] 58 | dependencies = ["pytest", "pytest-cov", "requests"] 59 | 60 | [tool.hatch.envs.default.scripts] 61 | cov = "pytest --cov-report xml:coverage.xml --cov-config=pyproject.toml --cov=canals --cov=tests {args}" 62 | no-cov = "cov --no-cov {args}" 63 | 64 | [[tool.hatch.envs.test.matrix]] 65 | python = ["38", "39", "310", "311"] 66 | 67 | [tool.coverage.run] 68 | branch = true 69 | parallel = true 70 | omit = ["canals/__about__.py"] 71 | 72 | [tool.coverage.report] 73 | exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"] 74 | 75 | [tool.black] 76 | line-length = 120 77 | 78 | [tool.pylint.'MESSAGES CONTROL'] 79 | max-line-length = 120 80 | good-names = "e" 81 | max-args = 10 82 | max-locals = 15 83 | disable = [ 84 | "fixme", 85 | "line-too-long", 86 | "missing-class-docstring", 87 | "missing-module-docstring", 88 | "too-few-public-methods", 89 | ] 90 | -------------------------------------------------------------------------------- /sample_components/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components.concatenate import Concatenate 5 | from sample_components.subtract import Subtract 6 | from sample_components.parity import Parity 7 | from sample_components.remainder import Remainder 8 | from sample_components.accumulate import Accumulate 9 | from sample_components.threshold import Threshold 10 | from sample_components.add_value import AddFixedValue 11 | from sample_components.repeat import Repeat 12 | from sample_components.sum import Sum 13 | from sample_components.greet import Greet 14 | from sample_components.double import Double 15 | from sample_components.joiner import StringJoiner, StringListJoiner, FirstIntSelector 16 | from sample_components.hello import Hello 17 | from sample_components.text_splitter import TextSplitter 18 | from sample_components.merge_loop import MergeLoop 19 | from sample_components.self_loop import SelfLoop 20 | from sample_components.fstring import FString 21 | 22 | __all__ = [ 23 | "Concatenate", 
24 | "Subtract", 25 | "Parity", 26 | "Remainder", 27 | "Accumulate", 28 | "Threshold", 29 | "AddFixedValue", 30 | "MergeLoop", 31 | "Repeat", 32 | "Sum", 33 | "Greet", 34 | "Double", 35 | "StringJoiner", 36 | "Hello", 37 | "TextSplitter", 38 | "StringListJoiner", 39 | "FirstIntSelector", 40 | "SelfLoop", 41 | "FString", 42 | ] 43 | -------------------------------------------------------------------------------- /sample_components/accumulate.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Callable, Optional, Dict, Any 5 | import sys 6 | import builtins 7 | from importlib import import_module 8 | 9 | from canals.serialization import default_to_dict 10 | from canals.component import component 11 | from canals.errors import ComponentDeserializationError 12 | 13 | 14 | def _default_function(first: int, second: int) -> int: 15 | return first + second 16 | 17 | 18 | @component 19 | class Accumulate: 20 | """ 21 | Accumulates the value flowing through the connection into an internal attribute. 22 | The sum function can be customized. 23 | 24 | Example of how to deal with serialization when some of the parameters 25 | are not directly serializable. 26 | """ 27 | 28 | def __init__(self, function: Optional[Callable] = None): 29 | """ 30 | :param function: the function to use to accumulate the values. 31 | The function must take exactly two values. 32 | If it's a callable, it's used as it is. 33 | If it's a string, the component will look for it in sys.modules and 34 | import it at need. This is also a parameter. 35 | """ 36 | self.state = 0 37 | self.function: Callable = _default_function if function is None else function # type: ignore 38 | 39 | def to_dict(self) -> Dict[str, Any]: # pylint: disable=missing-function-docstring 40 | module = sys.modules.get(self.function.__module__) 41 | if not module: 42 | raise ValueError("Could not locate the import module.") 43 | if module == builtins: 44 | function_name = self.function.__name__ 45 | else: 46 | function_name = f"{module.__name__}.{self.function.__name__}" 47 | 48 | return default_to_dict(self, function=function_name) 49 | 50 | @classmethod 51 | def from_dict(cls, data: Dict[str, Any]) -> "Accumulate": # pylint: disable=missing-function-docstring 52 | if "type" not in data: 53 | raise ComponentDeserializationError("Missing 'type' in component serialization data") 54 | if data["type"] != f"{cls.__module__}.{cls.__name__}": 55 | raise ComponentDeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'") 56 | 57 | init_params = data.get("init_parameters", {}) 58 | 59 | accumulator_function = None 60 | if "function" in init_params: 61 | parts = init_params["function"].split(".") 62 | module_name = ".".join(parts[:-1]) 63 | function_name = parts[-1] 64 | module = import_module(module_name) 65 | accumulator_function = getattr(module, function_name) 66 | 67 | return cls(function=accumulator_function) 68 | 69 | @component.output_types(value=int) 70 | def run(self, value: int): 71 | """ 72 | Accumulates the value flowing through the connection into an internal attribute. 73 | The sum function can be customized. 
74 | """ 75 | self.state = self.function(self.state, value) 76 | return {"value": self.state} 77 | -------------------------------------------------------------------------------- /sample_components/add_value.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Optional 5 | 6 | from canals import component 7 | 8 | 9 | @component 10 | class AddFixedValue: 11 | """ 12 | Adds two values together. 13 | """ 14 | 15 | def __init__(self, add: int = 1): 16 | self.add = add 17 | 18 | @component.output_types(result=int) 19 | def run(self, value: int, add: Optional[int] = None): 20 | """ 21 | Adds two values together. 22 | """ 23 | if add is None: 24 | add = self.add 25 | return {"result": value + add} 26 | -------------------------------------------------------------------------------- /sample_components/concatenate.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Union, List 5 | 6 | from canals import component 7 | 8 | 9 | @component 10 | class Concatenate: 11 | """ 12 | Concatenates two values 13 | """ 14 | 15 | @component.output_types(value=List[str]) 16 | def run(self, first: Union[List[str], str], second: Union[List[str], str]): 17 | """ 18 | Concatenates two values 19 | """ 20 | if isinstance(first, str) and isinstance(second, str): 21 | res = [first, second] 22 | elif isinstance(first, list) and isinstance(second, list): 23 | res = first + second 24 | elif isinstance(first, list) and isinstance(second, str): 25 | res = first + [second] 26 | elif isinstance(first, str) and isinstance(second, list): 27 | res = [first] + second 28 | return {"value": res} 29 | -------------------------------------------------------------------------------- /sample_components/double.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from canals import component 5 | 6 | 7 | @component 8 | class Double: 9 | """ 10 | Doubles the input value. 11 | """ 12 | 13 | @component.output_types(value=int) 14 | def run(self, value: int): 15 | """ 16 | Doubles the input value. 17 | """ 18 | return {"value": value * 2} 19 | -------------------------------------------------------------------------------- /sample_components/fstring.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import List, Any, Optional 5 | 6 | from canals import component 7 | 8 | 9 | @component 10 | class FString: 11 | """ 12 | Takes a template string and a list of variables in input and returns the formatted string in output. 
13 | """ 14 | 15 | def __init__(self, template: str, variables: Optional[List[str]] = None): 16 | self.template = template 17 | self.variables = variables or [] 18 | if "template" in self.variables: 19 | raise ValueError("The variable name 'template' is reserved and cannot be used.") 20 | component.set_input_types(self, **{variable: Any for variable in self.variables}) 21 | 22 | @component.output_types(string=str) 23 | def run(self, template: Optional[str] = None, **kwargs): 24 | """ 25 | Takes a template string and a list of variables in input and returns the formatted string in output. 26 | 27 | If the template is not given, the component will use the one given at initialization. 28 | """ 29 | if not template: 30 | template = self.template 31 | return {"string": template.format(**kwargs)} 32 | -------------------------------------------------------------------------------- /sample_components/greet.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Optional 5 | import logging 6 | 7 | from canals import component 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @component 14 | class Greet: 15 | """ 16 | Logs a greeting message without affecting the value passing on the connection. 17 | """ 18 | 19 | def __init__( 20 | self, 21 | message: str = "\nGreeting component says: Hi! The value is {value}\n", 22 | log_level: str = "INFO", 23 | ): 24 | """ 25 | :param message: the message to log. Can use `{value}` to embed the value. 26 | :param log_level: the level to log at. 27 | """ 28 | if log_level and not getattr(logging, log_level): 29 | raise ValueError(f"This log level does not exist: {log_level}") 30 | self.message = message 31 | self.log_level = log_level 32 | 33 | @component.output_types(value=int) 34 | def run(self, value: int, message: Optional[str] = None, log_level: Optional[str] = None): 35 | """ 36 | Logs a greeting message without affecting the value passing on the connection. 37 | """ 38 | if not message: 39 | message = self.message 40 | if not log_level: 41 | log_level = self.log_level 42 | 43 | level = getattr(logging, log_level, None) 44 | if not level: 45 | raise ValueError(f"This log level does not exist: {log_level}") 46 | 47 | logger.log(level=level, msg=message.format(value=value)) 48 | return {"value": value} 49 | -------------------------------------------------------------------------------- /sample_components/hello.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from canals import component 5 | 6 | 7 | @component 8 | class Hello: 9 | @component.output_types(output=str) 10 | def run(self, word: str): 11 | """ 12 | Takes a string in input and returns "Hello, !" 13 | in output. 
14 | """ 15 | return {"output": f"Hello, {word}!"} 16 | -------------------------------------------------------------------------------- /sample_components/joiner.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import List 5 | 6 | from canals import component 7 | from canals.component.types import Variadic 8 | 9 | 10 | @component 11 | class StringJoiner: 12 | @component.output_types(output=str) 13 | def run(self, input_str: Variadic[str]): 14 | """ 15 | Take strings from multiple input nodes and join them 16 | into a single one returned in output. Since `input_str` 17 | is Variadic, we know we'll receive a List[str]. 18 | """ 19 | return {"output": " ".join(input_str)} 20 | 21 | 22 | @component 23 | class StringListJoiner: 24 | @component.output_types(output=str) 25 | def run(self, inputs: Variadic[List[str]]): 26 | """ 27 | Take list of strings from multiple input nodes and join them 28 | into a single one returned in output. Since `input_str` 29 | is Variadic, we know we'll receive a List[List[str]]. 30 | """ 31 | retval: List[str] = [] 32 | for list_of_strings in inputs: 33 | retval += list_of_strings 34 | 35 | return {"output": retval} 36 | 37 | 38 | @component 39 | class FirstIntSelector: 40 | @component.output_types(output=int) 41 | def run(self, inputs: Variadic[int]): 42 | """ 43 | Take intd from multiple input nodes and return the first one 44 | that is not None. Since `input` is Variadic, we know we'll 45 | receive a List[int]. 46 | """ 47 | for inp in inputs: # type: ignore 48 | if inp is not None: 49 | return {"output": inp} 50 | return {} 51 | -------------------------------------------------------------------------------- /sample_components/merge_loop.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import List, Any, Optional, Dict 5 | import sys 6 | 7 | from canals import component 8 | from canals.errors import DeserializationError 9 | from canals.serialization import default_to_dict 10 | 11 | 12 | @component 13 | class MergeLoop: 14 | def __init__(self, expected_type: Any, inputs: List[str]): 15 | component.set_input_types(self, **{input_name: Optional[expected_type] for input_name in inputs}) 16 | component.set_output_types(self, value=expected_type) 17 | 18 | if expected_type.__module__ == "builtins": 19 | self.expected_type = f"builtins.{expected_type.__name__}" 20 | elif expected_type.__module__ == "typing": 21 | self.expected_type = str(expected_type) 22 | else: 23 | self.expected_type = f"{expected_type.__module__}.{expected_type.__name__}" 24 | 25 | self.inputs = inputs 26 | 27 | def to_dict(self) -> Dict[str, Any]: # pylint: disable=missing-function-docstring 28 | return default_to_dict( 29 | self, 30 | expected_type=self.expected_type, 31 | inputs=self.inputs, 32 | ) 33 | 34 | @classmethod 35 | def from_dict(cls, data: Dict[str, Any]) -> "MergeLoop": # pylint: disable=missing-function-docstring 36 | if "type" not in data: 37 | raise DeserializationError("Missing 'type' in component serialization data") 38 | if data["type"] != f"{cls.__module__}.{cls.__name__}": 39 | raise DeserializationError(f"Class '{data['type']}' can't be deserialized as '{cls.__name__}'") 40 | 41 | init_params = data.get("init_parameters", {}) 42 | 43 | if "expected_type" not in 
init_params: 44 | raise DeserializationError("Missing 'expected_type' field in 'init_parameters'") 45 | 46 | if "inputs" not in init_params: 47 | raise DeserializationError("Missing 'inputs' field in 'init_parameters'") 48 | 49 | module = sys.modules[__name__] 50 | fully_qualified_type_name = init_params["expected_type"] 51 | if fully_qualified_type_name.startswith("builtins."): 52 | module = sys.modules["builtins"] 53 | type_name = fully_qualified_type_name.split(".")[-1] 54 | try: 55 | expected_type = getattr(module, type_name) 56 | except AttributeError as exc: 57 | raise DeserializationError( 58 | f"Can't find type '{type_name}', import '{fully_qualified_type_name}' to fix the issue" 59 | ) from exc 60 | 61 | inputs = init_params["inputs"] 62 | 63 | return cls(expected_type=expected_type, inputs=inputs) 64 | 65 | def run(self, **kwargs): 66 | """ 67 | :param kwargs: the incoming connection values; the first non-None one is returned. 68 | """ 69 | for value in kwargs.values(): 70 | if value is not None: 71 | return {"value": value} 72 | return {} 73 | -------------------------------------------------------------------------------- /sample_components/parity.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from canals import component 5 | 6 | 7 | @component 8 | class Parity: # pylint: disable=too-few-public-methods 9 | """ 10 | Redirects the value, unchanged, along the 'even' connection if even, or along the 'odd' one if odd. 11 | """ 12 | 13 | @component.output_types(even=int, odd=int) 14 | def run(self, value: int): 15 | """ 16 | :param value: the value to check for parity. 17 | """ 18 | remainder = value % 2 19 | if remainder: 20 | return {"odd": value} 21 | return {"even": value} 22 | -------------------------------------------------------------------------------- /sample_components/remainder.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from canals import component 5 | 6 | 7 | @component 8 | class Remainder: 9 | def __init__(self, divisor=3): 10 | if divisor == 0: 11 | raise ValueError("Can't divide by zero") 12 | self.divisor = divisor 13 | component.set_output_types(self, **{f"remainder_is_{val}": int for val in range(divisor)}) 14 | 15 | def run(self, value: int): 16 | """ 17 | :param value: the value to check the remainder of. 18 | """ 19 | remainder = value % self.divisor 20 | output = {f"remainder_is_{remainder}": value} 21 | return output 22 | -------------------------------------------------------------------------------- /sample_components/repeat.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import List 5 | 6 | from canals import component 7 | 8 | 9 | @component 10 | class Repeat: 11 | def __init__(self, outputs: List[str]): 12 | self.outputs = outputs 13 | component.set_output_types(self, **{k: int for k in outputs}) 14 | 15 | def run(self, value: int): 16 | """ 17 | :param value: the value to repeat.
18 | """ 19 | return {val: value for val in self.outputs} 20 | -------------------------------------------------------------------------------- /sample_components/self_loop.py: -------------------------------------------------------------------------------- 1 | from canals import component 2 | from canals.component.types import Variadic 3 | 4 | 5 | @component 6 | class SelfLoop: 7 | """ 8 | Decreases the initial value in steps of 1 until the target value is reached. 9 | For no good reason it uses a self-loop to do so :) 10 | """ 11 | 12 | def __init__(self, target: int = 0): 13 | self.target = target 14 | 15 | @component.output_types(current_value=int, final_result=int) 16 | def run(self, values: Variadic[int]): 17 | """ 18 | Decreases the input value in steps of 1 until the target value is reached. 19 | """ 20 | value = values[0] # type: ignore 21 | value -= 1 22 | if value == self.target: 23 | return {"final_result": value} 24 | return {"current_value": value} 25 | -------------------------------------------------------------------------------- /sample_components/subtract.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from canals import component 5 | 6 | 7 | @component 8 | class Subtract: 9 | """ 10 | Compute the difference between two values. 11 | """ 12 | 13 | @component.output_types(difference=int) 14 | def run(self, first_value: int, second_value: int): 15 | """ 16 | :param first_value: name of the connection carrying the value to subtract from. 17 | :param second_value: name of the connection carrying the value to subtract. 18 | """ 19 | return {"difference": first_value - second_value} 20 | -------------------------------------------------------------------------------- /sample_components/sum.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from canals import component 5 | from canals.component.types import Variadic 6 | 7 | 8 | @component 9 | class Sum: 10 | @component.output_types(total=int) 11 | def run(self, values: Variadic[int]): 12 | """ 13 | :param value: the values to sum. 14 | """ 15 | return {"total": sum(values)} 16 | -------------------------------------------------------------------------------- /sample_components/text_splitter.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import List 5 | 6 | from canals import component 7 | 8 | 9 | @component 10 | class TextSplitter: 11 | @component.output_types(output=List[str]) 12 | def run(self, sentence: str): 13 | """ 14 | Takes a sentence in input and returns its words" 15 | in output. 
16 | """ 17 | return {"output": sentence.split()} 18 | -------------------------------------------------------------------------------- /sample_components/threshold.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Optional 5 | 6 | from canals import component 7 | 8 | 9 | @component 10 | class Threshold: # pylint: disable=too-few-public-methods 11 | """ 12 | Redirects the value, unchanged, along a different connection whether the value is above 13 | or below the given threshold. 14 | 15 | :param threshold: the number to compare the input value against. This is also a parameter. 16 | """ 17 | 18 | def __init__(self, threshold: int = 10): 19 | """ 20 | :param threshold: the number to compare the input value against. 21 | """ 22 | self.threshold = threshold 23 | 24 | @component.output_types(above=int, below=int) 25 | def run(self, value: int, threshold: Optional[int] = None): 26 | """ 27 | Redirects the value, unchanged, along a different connection whether the value is above 28 | or below the given threshold. 29 | 30 | :param threshold: the number to compare the input value against. This is also a parameter. 31 | """ 32 | if threshold is None: 33 | threshold = self.threshold 34 | 35 | if value < threshold: 36 | return {"below": value} 37 | return {"above": value} 38 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | -------------------------------------------------------------------------------- /test/component/test_component.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from typing import Any, Optional 3 | 4 | import pytest 5 | 6 | from canals import component 7 | from canals.component.descriptions import find_component_inputs, find_component_outputs 8 | from canals.errors import ComponentError 9 | from canals.component import InputSocket, OutputSocket, Component 10 | 11 | 12 | def test_correct_declaration(): 13 | @component 14 | class MockComponent: 15 | def to_dict(self): 16 | return {} 17 | 18 | @classmethod 19 | def from_dict(cls, data): 20 | return cls() 21 | 22 | @component.output_types(output_value=int) 23 | def run(self, input_value: int): 24 | return {"output_value": input_value} 25 | 26 | # Verifies also instantiation works with no issues 27 | assert MockComponent() 28 | assert component.registry["test_component.MockComponent"] == MockComponent 29 | 30 | 31 | def test_correct_declaration_with_additional_readonly_property(): 32 | @component 33 | class MockComponent: 34 | @property 35 | def store(self): 36 | return "test_store" 37 | 38 | def to_dict(self): 39 | return {} 40 | 41 | @classmethod 42 | def from_dict(cls, data): 43 | return cls() 44 | 45 | @component.output_types(output_value=int) 46 | def run(self, input_value: int): 47 | return {"output_value": input_value} 48 | 49 | # Verifies that instantiation works with no issues 50 | assert MockComponent() 51 | assert component.registry["test_component.MockComponent"] == MockComponent 52 | assert MockComponent().store == "test_store" 53 | 54 | 55 | def test_correct_declaration_with_additional_writable_property(): 56 | @component 57 | class MockComponent: 58 | @property 59 | 
def store(self): 60 | return "test_store" 61 | 62 | @store.setter 63 | def store(self, value): 64 | self._store = value 65 | 66 | def to_dict(self): 67 | return {} 68 | 69 | @classmethod 70 | def from_dict(cls, data): 71 | return cls() 72 | 73 | @component.output_types(output_value=int) 74 | def run(self, input_value: int): 75 | return {"output_value": input_value} 76 | 77 | # Verifies that instantiation works with no issues 78 | assert component.registry["test_component.MockComponent"] == MockComponent 79 | comp = MockComponent() 80 | comp.store = "test_store" 81 | assert comp.store == "test_store" 82 | 83 | 84 | def test_missing_run(): 85 | with pytest.raises(ComponentError, match="must have a 'run\(\)' method"): 86 | 87 | @component 88 | class MockComponent: 89 | def another_method(self, input_value: int): 90 | return {"output_value": input_value} 91 | 92 | 93 | def test_set_input_types(): 94 | class MockComponent: 95 | def __init__(self): 96 | component.set_input_types(self, value=Any) 97 | 98 | def to_dict(self): 99 | return {} 100 | 101 | @classmethod 102 | def from_dict(cls, data): 103 | return cls() 104 | 105 | @component.output_types(value=int) 106 | def run(self, **kwargs): 107 | return {"value": 1} 108 | 109 | comp = MockComponent() 110 | assert comp.__canals_input__ == {"value": InputSocket("value", Any)} 111 | assert comp.run() == {"value": 1} 112 | 113 | 114 | def test_set_output_types(): 115 | @component 116 | class MockComponent: 117 | def __init__(self): 118 | component.set_output_types(self, value=int) 119 | 120 | def to_dict(self): 121 | return {} 122 | 123 | @classmethod 124 | def from_dict(cls, data): 125 | return cls() 126 | 127 | def run(self, value: int): 128 | return {"value": 1} 129 | 130 | comp = MockComponent() 131 | assert comp.__canals_output__ == {"value": OutputSocket("value", int)} 132 | 133 | 134 | def test_output_types_decorator_with_compatible_type(): 135 | @component 136 | class MockComponent: 137 | @component.output_types(value=int) 138 | def run(self, value: int): 139 | return {"value": 1} 140 | 141 | def to_dict(self): 142 | return {} 143 | 144 | @classmethod 145 | def from_dict(cls, data): 146 | return cls() 147 | 148 | comp = MockComponent() 149 | assert comp.__canals_output__ == {"value": OutputSocket("value", int)} 150 | 151 | 152 | def test_component_decorator_set_it_as_component(): 153 | @component 154 | class MockComponent: 155 | @component.output_types(value=int) 156 | def run(self, value: int): 157 | return {"value": 1} 158 | 159 | def to_dict(self): 160 | return {} 161 | 162 | @classmethod 163 | def from_dict(cls, data): 164 | return cls() 165 | 166 | comp = MockComponent() 167 | assert isinstance(comp, Component) 168 | 169 | 170 | def test_inputs_method_no_inputs(): 171 | @component 172 | class MockComponent: 173 | def run(self): 174 | return {"value": 1} 175 | 176 | comp = MockComponent() 177 | assert find_component_inputs(comp) == {} 178 | 179 | 180 | def test_inputs_method_one_input(): 181 | @component 182 | class MockComponent: 183 | def run(self, value: int): 184 | return {"value": 1} 185 | 186 | comp = MockComponent() 187 | assert find_component_inputs(comp) == {"value": {"is_mandatory": True, "is_variadic": False, "type": int}} 188 | 189 | 190 | def test_inputs_method_multiple_inputs(): 191 | @component 192 | class MockComponent: 193 | def run(self, value1: int, value2: str): 194 | return {"value": 1} 195 | 196 | comp = MockComponent() 197 | assert find_component_inputs(comp) == { 198 | "value1": {"is_mandatory": True, 
"is_variadic": False, "type": int}, 199 | "value2": {"is_mandatory": True, "is_variadic": False, "type": str}, 200 | } 201 | 202 | 203 | def test_inputs_method_multiple_inputs_optional(): 204 | @component 205 | class MockComponent: 206 | def run(self, value1: int, value2: Optional[str]): 207 | return {"value": 1} 208 | 209 | comp = MockComponent() 210 | assert find_component_inputs(comp) == { 211 | "value1": {"is_mandatory": True, "is_variadic": False, "type": int}, 212 | "value2": {"is_mandatory": True, "is_variadic": False, "type": typing.Optional[str]}, 213 | } 214 | 215 | 216 | def test_inputs_method_variadic_positional_args(): 217 | @component 218 | class MockComponent: 219 | def __init__(self): 220 | component.set_input_types(self, value=Any) 221 | 222 | def run(self, *args): 223 | return {"value": 1} 224 | 225 | comp = MockComponent() 226 | assert find_component_inputs(comp) == {"value": {"is_mandatory": True, "is_variadic": False, "type": typing.Any}} 227 | 228 | 229 | def test_inputs_method_variadic_keyword_positional_args(): 230 | @component 231 | class MockComponent: 232 | def __init__(self): 233 | component.set_input_types(self, value=Any) 234 | 235 | def run(self, **kwargs): 236 | return {"value": 1} 237 | 238 | comp = MockComponent() 239 | assert find_component_inputs(comp) == {"value": {"is_mandatory": True, "is_variadic": False, "type": typing.Any}} 240 | 241 | 242 | def test_inputs_dynamic_from_init(): 243 | @component 244 | class MockComponent: 245 | def __init__(self): 246 | component.set_input_types(self, value=int) 247 | 248 | def run(self, value: int, **kwargs): 249 | return {"value": 1} 250 | 251 | comp = MockComponent() 252 | assert find_component_inputs(comp) == {"value": {"is_mandatory": True, "is_variadic": False, "type": int}} 253 | 254 | 255 | def test_outputs_method_no_outputs(): 256 | @component 257 | class MockComponent: 258 | def run(self): 259 | return {} 260 | 261 | comp = MockComponent() 262 | assert find_component_outputs(comp) == {} 263 | 264 | 265 | def test_outputs_method_one_output(): 266 | @component 267 | class MockComponent: 268 | @component.output_types(value=int) 269 | def run(self): 270 | return {"value": 1} 271 | 272 | comp = MockComponent() 273 | assert find_component_outputs(comp) == {"value": {"type": int}} 274 | 275 | 276 | def test_outputs_method_multiple_outputs(): 277 | @component 278 | class MockComponent: 279 | @component.output_types(value1=int, value2=str) 280 | def run(self): 281 | return {"value1": 1, "value2": "test"} 282 | 283 | comp = MockComponent() 284 | assert find_component_outputs(comp) == {"value1": {"type": int}, "value2": {"type": str}} 285 | 286 | 287 | def test_outputs_dynamic_from_init(): 288 | @component 289 | class MockComponent: 290 | def __init__(self): 291 | component.set_output_types(self, value=int) 292 | 293 | def run(self): 294 | return {"value": 1} 295 | 296 | comp = MockComponent() 297 | assert find_component_outputs(comp) == {"value": {"type": int}} 298 | -------------------------------------------------------------------------------- /test/component/test_connection.py: -------------------------------------------------------------------------------- 1 | from canals.component.connection import Connection 2 | from canals.component.sockets import InputSocket, OutputSocket 3 | from canals.errors import PipelineConnectError 4 | 5 | import pytest 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "c,expected", 10 | [ 11 | ( 12 | Connection("source_component", OutputSocket("out", int), "destination_component", 
InputSocket("in", int)), 13 | "source_component.out (int) --> (int) destination_component.in", 14 | ), 15 | ( 16 | Connection(None, None, "destination_component", InputSocket("in", int)), 17 | "input needed --> (int) destination_component.in", 18 | ), 19 | (Connection("source_component", OutputSocket("out", int), None, None), "source_component.out (int) --> output"), 20 | (Connection(None, None, None, None), "input needed --> output"), 21 | ], 22 | ) 23 | def test_repr(c, expected): 24 | assert str(c) == expected 25 | 26 | 27 | def test_is_mandatory(): 28 | c = Connection(None, None, "destination_component", InputSocket("in", int)) 29 | assert c.is_mandatory 30 | 31 | c = Connection(None, None, "destination_component", InputSocket("in", int, is_mandatory=False)) 32 | assert not c.is_mandatory 33 | 34 | c = Connection("source_component", OutputSocket("out", int), None, None) 35 | assert not c.is_mandatory 36 | 37 | 38 | def test_from_list_of_sockets(): 39 | sender_sockets = [ 40 | OutputSocket("out_int", int), 41 | OutputSocket("out_str", str), 42 | ] 43 | 44 | receiver_sockets = [ 45 | InputSocket("in_str", str), 46 | ] 47 | 48 | c = Connection.from_list_of_sockets("from_node", sender_sockets, "to_node", receiver_sockets) 49 | assert c.sender_socket.name == "out_str" # type:ignore 50 | 51 | 52 | def test_from_list_of_sockets_not_possible(): 53 | sender_sockets = [ 54 | OutputSocket("out_int", int), 55 | OutputSocket("out_str", str), 56 | ] 57 | 58 | receiver_sockets = [ 59 | InputSocket("in_list", list), 60 | InputSocket("in_tuple", tuple), 61 | ] 62 | 63 | with pytest.raises(PipelineConnectError, match="no matching connections available"): 64 | Connection.from_list_of_sockets("from_node", sender_sockets, "to_node", receiver_sockets) 65 | 66 | 67 | def test_from_list_of_sockets_too_many(): 68 | sender_sockets = [ 69 | OutputSocket("out_int", int), 70 | OutputSocket("out_str", str), 71 | ] 72 | 73 | receiver_sockets = [ 74 | InputSocket("in_int", int), 75 | InputSocket("in_str", str), 76 | ] 77 | 78 | with pytest.raises(PipelineConnectError, match="more than one connection is possible"): 79 | Connection.from_list_of_sockets("from_node", sender_sockets, "to_node", receiver_sockets) 80 | 81 | 82 | def test_from_list_of_sockets_only_one(): 83 | sender_sockets = [ 84 | OutputSocket("out_int", int), 85 | ] 86 | 87 | receiver_sockets = [ 88 | InputSocket("in_str", str), 89 | ] 90 | 91 | with pytest.raises(PipelineConnectError, match="their declared input and output types do not match"): 92 | Connection.from_list_of_sockets("from_node", sender_sockets, "to_node", receiver_sockets) 93 | -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pytest 4 | 5 | from unittest.mock import patch, MagicMock 6 | 7 | 8 | @pytest.fixture 9 | def test_files(): 10 | return Path(__file__).parent / "test_files" 11 | 12 | 13 | @pytest.fixture(autouse=True) 14 | def mock_mermaid_request(test_files): 15 | """ 16 | Prevents real requests to https://mermaid.ink/ 17 | """ 18 | with patch("canals.pipeline.draw.mermaid.requests.get") as mock_get: 19 | mock_response = MagicMock() 20 | mock_response.status_code = 200 21 | mock_response.content = open(test_files / "mermaid_mock" / "test_response.png", "rb").read() 22 | mock_get.return_value = mock_response 23 | yield 24 | -------------------------------------------------------------------------------- 
/test/pipeline/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/test/pipeline/__init__.py -------------------------------------------------------------------------------- /test/pipeline/integration/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_complex_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | import logging 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import ( 10 | Accumulate, 11 | AddFixedValue, 12 | Greet, 13 | Parity, 14 | Threshold, 15 | Double, 16 | Sum, 17 | Repeat, 18 | Subtract, 19 | MergeLoop, 20 | ) 21 | 22 | logging.basicConfig(level=logging.DEBUG) 23 | 24 | 25 | def test_complex_pipeline(tmp_path): 26 | loop_merger = MergeLoop(expected_type=int, inputs=["in_1", "in_2"]) 27 | summer = Sum() 28 | 29 | pipeline = Pipeline(max_loops_allowed=2) 30 | pipeline.add_component("greet_first", Greet(message="Hello, the value is {value}.")) 31 | pipeline.add_component("accumulate_1", Accumulate()) 32 | pipeline.add_component("add_two", AddFixedValue(add=2)) 33 | pipeline.add_component("parity", Parity()) 34 | pipeline.add_component("add_one", AddFixedValue(add=1)) 35 | pipeline.add_component("accumulate_2", Accumulate()) 36 | 37 | pipeline.add_component("loop_merger", loop_merger) 38 | pipeline.add_component("below_10", Threshold(threshold=10)) 39 | pipeline.add_component("double", Double()) 40 | 41 | pipeline.add_component("greet_again", Greet(message="Hello again, now the value is {value}.")) 42 | pipeline.add_component("sum", summer) 43 | 44 | pipeline.add_component("greet_enumerator", Greet(message="Hello from enumerator, here the value became {value}.")) 45 | pipeline.add_component("enumerate", Repeat(outputs=["first", "second"])) 46 | pipeline.add_component("add_three", AddFixedValue(add=3)) 47 | 48 | pipeline.add_component("diff", Subtract()) 49 | pipeline.add_component("greet_one_last_time", Greet(message="Bye bye! 
The value here is {value}!")) 50 | pipeline.add_component("replicate", Repeat(outputs=["first", "second"])) 51 | pipeline.add_component("add_five", AddFixedValue(add=5)) 52 | pipeline.add_component("add_four", AddFixedValue(add=4)) 53 | pipeline.add_component("accumulate_3", Accumulate()) 54 | 55 | pipeline.connect("greet_first", "accumulate_1") 56 | pipeline.connect("accumulate_1", "add_two") 57 | pipeline.connect("add_two", "parity") 58 | 59 | pipeline.connect("parity.even", "greet_again") 60 | pipeline.connect("greet_again", "sum.values") 61 | pipeline.connect("sum", "diff.first_value") 62 | pipeline.connect("diff", "greet_one_last_time") 63 | pipeline.connect("greet_one_last_time", "replicate") 64 | pipeline.connect("replicate.first", "add_five.value") 65 | pipeline.connect("replicate.second", "add_four.value") 66 | pipeline.connect("add_four", "accumulate_3") 67 | 68 | pipeline.connect("parity.odd", "add_one.value") 69 | pipeline.connect("add_one", "loop_merger.in_1") 70 | pipeline.connect("loop_merger", "below_10") 71 | 72 | pipeline.connect("below_10.below", "double") 73 | pipeline.connect("double", "loop_merger.in_2") 74 | 75 | pipeline.connect("below_10.above", "accumulate_2") 76 | pipeline.connect("accumulate_2", "diff.second_value") 77 | 78 | pipeline.connect("greet_enumerator", "enumerate") 79 | pipeline.connect("enumerate.second", "sum.values") 80 | 81 | pipeline.connect("enumerate.first", "add_three.value") 82 | pipeline.connect("add_three", "sum.values") 83 | 84 | pipeline.draw(tmp_path / "complex_pipeline.png") 85 | 86 | results = pipeline.run({"greet_first": {"value": 1}, "greet_enumerator": {"value": 1}}) 87 | assert results == {"accumulate_3": {"value": -7}, "add_five": {"result": -6}} 88 | 89 | 90 | if __name__ == "__main__": 91 | test_complex_pipeline(Path(__file__).parent) 92 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_default_value.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals import Pipeline, component 8 | from sample_components import AddFixedValue, Sum 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | @component 16 | class WithDefault: 17 | @component.output_types(c=int) 18 | def run(self, a: int, b: int = 2): 19 | return {"c": a + b} 20 | 21 | 22 | def test_pipeline(tmp_path): 23 | # https://github.com/deepset-ai/canals/issues/105 24 | pipeline = Pipeline() 25 | pipeline.add_component("with_defaults", WithDefault()) 26 | pipeline.draw(tmp_path / "default_value.png") 27 | 28 | # Pass all the inputs 29 | results = pipeline.run({"with_defaults": {"a": 40, "b": 30}}) 30 | pprint(results) 31 | assert results == {"with_defaults": {"c": 70}} 32 | 33 | # Rely on default value for 'b' 34 | results = pipeline.run({"with_defaults": {"a": 40}}) 35 | pprint(results) 36 | assert results == {"with_defaults": {"c": 42}} 37 | 38 | 39 | if __name__ == "__main__": 40 | test_pipeline(Path(__file__).parent) 41 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_distinct_loops_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import *
5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import AddFixedValue, MergeLoop, Remainder, FirstIntSelector 10 | 11 | import logging 12 | 13 | logging.basicConfig(level=logging.DEBUG) 14 | 15 | 16 | def test_pipeline_equally_long_branches(tmp_path): 17 | pipeline = Pipeline(max_loops_allowed=10) 18 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in", "in_1", "in_2"])) 19 | pipeline.add_component("remainder", Remainder(divisor=3)) 20 | pipeline.add_component("add_one", AddFixedValue(add=1)) 21 | pipeline.add_component("add_two", AddFixedValue(add=2)) 22 | 23 | pipeline.connect("merge.value", "remainder.value") 24 | pipeline.connect("remainder.remainder_is_1", "add_two.value") 25 | pipeline.connect("remainder.remainder_is_2", "add_one.value") 26 | pipeline.connect("add_two", "merge.in_2") 27 | pipeline.connect("add_one", "merge.in_1") 28 | 29 | pipeline.draw(tmp_path / "distinct_loops_pipeline_same_branches.png") 30 | 31 | results = pipeline.run({"merge": {"in": 0}}) 32 | pprint(results) 33 | assert results == {"remainder": {"remainder_is_0": 0}} 34 | 35 | results = pipeline.run({"merge": {"in": 3}}) 36 | pprint(results) 37 | assert results == {"remainder": {"remainder_is_0": 3}} 38 | 39 | results = pipeline.run({"merge": {"in": 4}}) 40 | pprint(results) 41 | assert results == {"remainder": {"remainder_is_0": 6}} 42 | 43 | results = pipeline.run({"merge": {"in": 5}}) 44 | pprint(results) 45 | assert results == {"remainder": {"remainder_is_0": 6}} 46 | 47 | results = pipeline.run({"merge": {"in": 6}}) 48 | pprint(results) 49 | assert results == {"remainder": {"remainder_is_0": 6}} 50 | 51 | 52 | def test_pipeline_differing_branches(tmp_path): 53 | pipeline = Pipeline(max_loops_allowed=10) 54 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in", "in_1", "in_2"])) 55 | pipeline.add_component("remainder", Remainder(divisor=3)) 56 | pipeline.add_component("add_one", AddFixedValue(add=1)) 57 | pipeline.add_component("add_two_1", AddFixedValue(add=1)) 58 | pipeline.add_component("add_two_2", AddFixedValue(add=1)) 59 | 60 | pipeline.connect("merge.value", "remainder.value") 61 | pipeline.connect("remainder.remainder_is_1", "add_two_1.value") 62 | pipeline.connect("add_two_1", "add_two_2.value") 63 | pipeline.connect("add_two_2", "merge.in_2") 64 | pipeline.connect("remainder.remainder_is_2", "add_one.value") 65 | pipeline.connect("add_one", "merge.in_1") 66 | 67 | pipeline.draw(tmp_path / "distinct_loops_pipeline_different_branches.png") 68 | 69 | results = pipeline.run({"merge": {"in": 0}}) 70 | pprint(results) 71 | assert results == {"remainder": {"remainder_is_0": 0}} 72 | 73 | results = pipeline.run({"merge": {"in": 3}}) 74 | pprint(results) 75 | assert results == {"remainder": {"remainder_is_0": 3}} 76 | 77 | results = pipeline.run({"merge": {"in": 4}}) 78 | pprint(results) 79 | assert results == {"remainder": {"remainder_is_0": 6}} 80 | 81 | results = pipeline.run({"merge": {"in": 5}}) 82 | pprint(results) 83 | assert results == {"remainder": {"remainder_is_0": 6}} 84 | 85 | results = pipeline.run({"merge": {"in": 6}}) 86 | pprint(results) 87 | assert results == {"remainder": {"remainder_is_0": 6}} 88 | 89 | 90 | def test_pipeline_differing_branches_variadic(tmp_path): 91 | pipeline = Pipeline(max_loops_allowed=10) 92 | pipeline.add_component("merge", FirstIntSelector()) 93 | pipeline.add_component("remainder", Remainder(divisor=3)) 94 | 
pipeline.add_component("add_one", AddFixedValue(add=1)) 95 | pipeline.add_component("add_two_1", AddFixedValue(add=1)) 96 | pipeline.add_component("add_two_2", AddFixedValue(add=1)) 97 | 98 | pipeline.connect("merge", "remainder.value") 99 | pipeline.connect("remainder.remainder_is_1", "add_two_1.value") 100 | pipeline.connect("add_two_1", "add_two_2.value") 101 | pipeline.connect("add_two_2", "merge.inputs") 102 | pipeline.connect("remainder.remainder_is_2", "add_one.value") 103 | pipeline.connect("add_one", "merge.inputs") 104 | 105 | pipeline.draw(tmp_path / "distinct_loops_pipeline_different_branches_variadic.png") 106 | 107 | results = pipeline.run({"merge": {"inputs": 0}}) 108 | pprint(results) 109 | assert results == {"remainder": {"remainder_is_0": 0}} 110 | 111 | results = pipeline.run({"merge": {"inputs": 3}}) 112 | pprint(results) 113 | assert results == {"remainder": {"remainder_is_0": 3}} 114 | 115 | results = pipeline.run({"merge": {"inputs": 4}}) 116 | pprint(results) 117 | assert results == {"remainder": {"remainder_is_0": 6}} 118 | 119 | results = pipeline.run({"merge": {"inputs": 5}}) 120 | pprint(results) 121 | assert results == {"remainder": {"remainder_is_0": 6}} 122 | 123 | results = pipeline.run({"merge": {"inputs": 6}}) 124 | pprint(results) 125 | assert results == {"remainder": {"remainder_is_0": 6}} 126 | 127 | 128 | if __name__ == "__main__": 129 | test_pipeline_equally_long_branches(Path(__file__).parent) 130 | test_pipeline_differing_branches(Path(__file__).parent) 131 | test_pipeline_differing_branches_variadic(Path(__file__).parent) 132 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_double_loop_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import * 5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop 10 | 11 | import logging 12 | 13 | logging.basicConfig(level=logging.DEBUG) 14 | 15 | 16 | def test_pipeline(tmp_path): 17 | accumulator = Accumulate() 18 | 19 | pipeline = Pipeline(max_loops_allowed=10) 20 | pipeline.add_component("add_one", AddFixedValue(add=1)) 21 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2", "in_3"])) 22 | pipeline.add_component("below_10", Threshold(threshold=10)) 23 | pipeline.add_component("below_5", Threshold(threshold=5)) 24 | pipeline.add_component("add_three", AddFixedValue(add=3)) 25 | pipeline.add_component("accumulator", accumulator) 26 | pipeline.add_component("add_two", AddFixedValue(add=2)) 27 | 28 | pipeline.connect("add_one.result", "merge.in_1") 29 | pipeline.connect("merge.value", "below_10.value") 30 | pipeline.connect("below_10.below", "accumulator.value") 31 | pipeline.connect("accumulator.value", "below_5.value") 32 | pipeline.connect("below_5.above", "add_three.value") 33 | pipeline.connect("below_5.below", "merge.in_2") 34 | pipeline.connect("add_three.result", "merge.in_3") 35 | pipeline.connect("below_10.above", "add_two.value") 36 | 37 | pipeline.draw(tmp_path / "double_loop_pipeline.png") 38 | 39 | results = pipeline.run({"add_one": {"value": 3}}) 40 | pprint(results) 41 | print("accumulator: ", accumulator.state) 42 | 43 | assert results == {"add_two": {"result": 13}} 44 | assert accumulator.state == 8 45 | 46 | 
47 | if __name__ == "__main__": 48 | test_pipeline(Path(__file__).parent) 49 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_dynamic_inputs_pipeline.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from canals import Pipeline 3 | from sample_components import FString, Hello, TextSplitter 4 | 5 | 6 | def test_pipeline(tmp_path): 7 | pipeline = Pipeline() 8 | pipeline.add_component("hello", Hello()) 9 | pipeline.add_component("fstring", FString(template="This is the greeting: {greeting}!", variables=["greeting"])) 10 | pipeline.add_component("splitter", TextSplitter()) 11 | pipeline.connect("hello.output", "fstring.greeting") 12 | pipeline.connect("fstring.string", "splitter.sentence") 13 | 14 | pipeline.draw(tmp_path / "dynamic_inputs_pipeline.png") 15 | 16 | output = pipeline.run({"hello": {"word": "Alice"}}) 17 | assert output == {"splitter": {"output": ["This", "is", "the", "greeting:", "Hello,", "Alice!!"]}} 18 | 19 | output = pipeline.run({"hello": {"word": "Alice"}, "fstring": {"template": "Received: {greeting}"}}) 20 | assert output == {"splitter": {"output": ["Received:", "Hello,", "Alice!"]}} 21 | 22 | 23 | if __name__ == "__main__": 24 | test_pipeline(Path(__file__).parent) 25 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_fixed_decision_and_merge_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import logging 5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import AddFixedValue, Parity, Double, Subtract 10 | 11 | logging.basicConfig(level=logging.DEBUG) 12 | 13 | 14 | def test_pipeline(tmp_path): 15 | pipeline = Pipeline() 16 | pipeline.add_component("add_one", AddFixedValue()) 17 | pipeline.add_component("parity", Parity()) 18 | pipeline.add_component("add_ten", AddFixedValue(add=10)) 19 | pipeline.add_component("double", Double()) 20 | pipeline.add_component("add_four", AddFixedValue(add=4)) 21 | pipeline.add_component("add_two", AddFixedValue()) 22 | pipeline.add_component("add_two_as_well", AddFixedValue()) 23 | pipeline.add_component("diff", Subtract()) 24 | 25 | pipeline.connect("add_one.result", "parity.value") 26 | pipeline.connect("parity.even", "add_four.value") 27 | pipeline.connect("parity.odd", "double.value") 28 | pipeline.connect("add_ten.result", "diff.first_value") 29 | pipeline.connect("double.value", "diff.second_value") 30 | pipeline.connect("parity.odd", "add_ten.value") 31 | pipeline.connect("add_four.result", "add_two.value") 32 | pipeline.connect("add_four.result", "add_two_as_well.value") 33 | 34 | pipeline.draw(tmp_path / "fixed_decision_and_merge_pipeline.png") 35 | 36 | results = pipeline.run( 37 | { 38 | "add_one": {"value": 1}, 39 | "add_two": {"add": 2}, 40 | "add_two_as_well": {"add": 2}, 41 | } 42 | ) 43 | pprint(results) 44 | 45 | assert results == { 46 | "add_two": {"result": 8}, "add_two_as_well": {"result": 8}, 47 | } 48 | 49 | results = pipeline.run( 50 | { 51 | "add_one": {"value": 2}, 52 | "add_two": {"add": 2}, 53 | "add_two_as_well": {"add": 2}, 54 | } 55 | ) 56 | pprint(results) 57 | 58 | assert results == { 59 | "diff": {"difference": 7}, 60 | } 61 | 62 | 63 | if __name__ == "__main__": 64 | test_pipeline(Path(__file__).parent) 65 |
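# A minimal illustrative sketch (not exercised by the pipeline above, and the
# function name is ours): Parity emits on exactly one of its sockets per call,
# which is why each pipeline.run() above exits through a single branch.
def test_parity_emits_on_one_socket_only():
    parity = Parity()
    assert parity.run(2) == {"even": 2}
    assert parity.run(3) == {"odd": 3}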
-------------------------------------------------------------------------------- /test/pipeline/integration/test_fixed_decision_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline import Pipeline 8 | from sample_components import AddFixedValue, Parity, Double 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_pipeline(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("add_one", AddFixedValue(add=1)) 18 | pipeline.add_component("parity", Parity()) 19 | pipeline.add_component("add_ten", AddFixedValue(add=10)) 20 | pipeline.add_component("double", Double()) 21 | pipeline.add_component("add_three", AddFixedValue(add=3)) 22 | 23 | pipeline.connect("add_one.result", "parity.value") 24 | pipeline.connect("parity.even", "add_ten.value") 25 | pipeline.connect("parity.odd", "double.value") 26 | pipeline.connect("add_ten.result", "add_three.value") 27 | 28 | pipeline.draw(tmp_path / "fixed_decision_pipeline.png") 29 | 30 | results = pipeline.run({"add_one": {"value": 1}}) 31 | pprint(results) 32 | assert results == {"add_three": {"result": 15}} 33 | 34 | results = pipeline.run({"add_one": {"value": 2}}) 35 | pprint(results) 36 | assert results == {"double": {"value": 6}} 37 | 38 | 39 | if __name__ == "__main__": 40 | test_pipeline(Path(__file__).parent) 41 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_fixed_merging_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline import Pipeline 8 | from sample_components import AddFixedValue, Subtract 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_pipeline(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("first_addition", AddFixedValue(add=2)) 18 | pipeline.add_component("second_addition", AddFixedValue(add=2)) 19 | pipeline.add_component("third_addition", AddFixedValue(add=2)) 20 | pipeline.add_component("diff", Subtract()) 21 | pipeline.add_component("fourth_addition", AddFixedValue(add=1)) 22 | 23 | pipeline.connect("first_addition.result", "second_addition.value") 24 | pipeline.connect("second_addition.result", "diff.first_value") 25 | pipeline.connect("third_addition.result", "diff.second_value") 26 | pipeline.connect("diff", "fourth_addition.value") 27 | 28 | pipeline.draw(tmp_path / "fixed_merging_pipeline.png") 29 | 30 | results = pipeline.run( 31 | { 32 | "first_addition": {"value": 1}, 33 | "third_addition": {"value": 1}, 34 | } 35 | ) 36 | pprint(results) 37 | 38 | assert results == {"fourth_addition": {"result": 3}} 39 | 40 | 41 | if __name__ == "__main__": 42 | test_pipeline(Path(__file__).parent) 43 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_joiners.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline 
import Pipeline 8 | from sample_components import StringJoiner, StringListJoiner, Hello, TextSplitter 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_joiner(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("hello_one", Hello()) 18 | pipeline.add_component("hello_two", Hello()) 19 | pipeline.add_component("hello_three", Hello()) 20 | pipeline.add_component("joiner", StringJoiner()) 21 | 22 | pipeline.connect("hello_one", "hello_two") 23 | pipeline.connect("hello_two", "joiner") 24 | pipeline.connect("hello_three", "joiner") 25 | 26 | pipeline.draw(tmp_path / "joiner_pipeline.png") 27 | 28 | results = pipeline.run({"hello_one": {"word": "world"}, "hello_three": {"word": "my friend"}}) 29 | assert results == {"joiner": {"output": "Hello, my friend! Hello, Hello, world!!"}} 30 | 31 | 32 | def test_joiner_with_lists(tmp_path): 33 | pipeline = Pipeline() 34 | pipeline.add_component("first", TextSplitter()) 35 | pipeline.add_component("second", TextSplitter()) 36 | pipeline.add_component("joiner", StringListJoiner()) 37 | 38 | pipeline.connect("first", "joiner") 39 | pipeline.connect("second", "joiner") 40 | 41 | pipeline.draw(tmp_path / "joiner_list_pipeline.png") 42 | 43 | results = pipeline.run({"first": {"sentence": "Hello world!"}, "second": {"sentence": "How are you?"}}) 44 | assert results == {"joiner": {"output": ["Hello", "world!", "How", "are", "you?"]}} 45 | 46 | 47 | def test_joiner_with_pipeline_run(tmp_path): 48 | pipeline = Pipeline() 49 | pipeline.add_component("hello", Hello()) 50 | pipeline.add_component("joiner", StringJoiner()) 51 | pipeline.connect("hello", "joiner") 52 | 53 | pipeline.draw(tmp_path / "joiner_with_pipeline_run.png") 54 | 55 | results = pipeline.run({"hello": {"word": "world"}, "joiner": {"input_str": "another string!"}}) 56 | assert results == {"joiner": {"output": "another string! 
Hello, world!"}} 57 | 58 | 59 | if __name__ == "__main__": 60 | test_joiner(Path(__file__).parent) 61 | test_joiner_with_lists(Path(__file__).parent) 62 | test_joiner_with_pipeline_run(Path(__file__).parent) 63 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_linear_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline import Pipeline 8 | from sample_components import AddFixedValue, Double 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_pipeline(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("first_addition", AddFixedValue(add=2)) 18 | pipeline.add_component("second_addition", AddFixedValue()) 19 | pipeline.add_component("double", Double()) 20 | pipeline.connect("first_addition", "double") 21 | pipeline.connect("double", "second_addition") 22 | 23 | pipeline.draw(tmp_path / "linear_pipeline.png") 24 | 25 | results = pipeline.run({"first_addition": {"value": 1}}) 26 | pprint(results) 27 | 28 | assert results == {"second_addition": {"result": 7}} 29 | 30 | 31 | if __name__ == "__main__": 32 | test_pipeline(Path(__file__).parent) 33 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_looping_and_merge_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import * 5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import Accumulate, AddFixedValue, Threshold, Sum, FirstIntSelector, MergeLoop 10 | 11 | import logging 12 | 13 | logging.basicConfig(level=logging.DEBUG) 14 | 15 | 16 | def test_pipeline_fixed(tmp_path): 17 | accumulator = Accumulate() 18 | pipeline = Pipeline(max_loops_allowed=10) 19 | pipeline.add_component("add_zero", AddFixedValue(add=0)) 20 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2"])) 21 | pipeline.add_component("sum", Sum()) 22 | pipeline.add_component("below_10", Threshold(threshold=10)) 23 | pipeline.add_component("add_one", AddFixedValue(add=1)) 24 | pipeline.add_component("counter", accumulator) 25 | pipeline.add_component("add_two", AddFixedValue(add=2)) 26 | 27 | pipeline.connect("add_zero", "merge.in_1") 28 | pipeline.connect("merge", "below_10.value") 29 | pipeline.connect("below_10.below", "add_one.value") 30 | pipeline.connect("add_one.result", "counter.value") 31 | pipeline.connect("counter.value", "merge.in_2") 32 | pipeline.connect("below_10.above", "add_two.value") 33 | pipeline.connect("add_two.result", "sum.values") 34 | 35 | pipeline.draw(tmp_path / "looping_and_fixed_merge_pipeline.png") 36 | 37 | results = pipeline.run( 38 | {"add_zero": {"value": 8}, "sum": {"values": 2}}, 39 | ) 40 | pprint(results) 41 | print("accumulate: ", accumulator.state) 42 | 43 | assert results == {"sum": {"total": 23}} 44 | assert accumulator.state == 19 45 | 46 | 47 | def test_pipeline_variadic(tmp_path): 48 | accumulator = Accumulate() 49 | pipeline = Pipeline(max_loops_allowed=10) 50 | pipeline.add_component("add_zero", AddFixedValue(add=0)) 51 | pipeline.add_component("merge", 
FirstIntSelector()) 52 | pipeline.add_component("sum", Sum()) 53 | pipeline.add_component("below_10", Threshold(threshold=10)) 54 | pipeline.add_component("add_one", AddFixedValue(add=1)) 55 | pipeline.add_component("counter", accumulator) 56 | pipeline.add_component("add_two", AddFixedValue(add=2)) 57 | 58 | pipeline.connect("add_zero", "merge") 59 | pipeline.connect("merge", "below_10.value") 60 | pipeline.connect("below_10.below", "add_one.value") 61 | pipeline.connect("add_one.result", "counter.value") 62 | pipeline.connect("counter.value", "merge.inputs") 63 | pipeline.connect("below_10.above", "add_two.value") 64 | pipeline.connect("add_two.result", "sum.values") 65 | 66 | pipeline.draw(tmp_path / "looping_and_variadic_merge_pipeline.png") 67 | 68 | results = pipeline.run( 69 | {"add_zero": {"value": 8}, "sum": {"values": 2}}, 70 | ) 71 | pprint(results) 72 | print("accumulate: ", accumulator.state) 73 | 74 | assert results == {"sum": {"total": 23}} 75 | assert accumulator.state == 19 76 | 77 | 78 | if __name__ == "__main__": 79 | test_pipeline_fixed(Path(__file__).parent) 80 | test_pipeline_variadic(Path(__file__).parent) 81 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_looping_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import * 5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import Accumulate, AddFixedValue, Threshold, MergeLoop, FirstIntSelector 10 | 11 | import logging 12 | 13 | logging.basicConfig(level=logging.DEBUG) 14 | 15 | 16 | def test_pipeline(tmp_path): 17 | accumulator = Accumulate() 18 | 19 | pipeline = Pipeline(max_loops_allowed=10) 20 | pipeline.add_component("add_one", AddFixedValue(add=1)) 21 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2"])) 22 | pipeline.add_component("below_10", Threshold(threshold=10)) 23 | pipeline.add_component("accumulator", accumulator) 24 | pipeline.add_component("add_two", AddFixedValue(add=2)) 25 | 26 | pipeline.connect("add_one.result", "merge.in_1") 27 | pipeline.connect("merge.value", "below_10.value") 28 | pipeline.connect("below_10.below", "accumulator.value") 29 | pipeline.connect("accumulator.value", "merge.in_2") 30 | pipeline.connect("below_10.above", "add_two.value") 31 | 32 | pipeline.draw(tmp_path / "looping_pipeline.png") 33 | 34 | results = pipeline.run({"add_one": {"value": 3}}) 35 | pprint(results) 36 | 37 | assert results == {"add_two": {"result": 18}} 38 | assert accumulator.state == 16 39 | 40 | 41 | def test_pipeline_direct_io_loop(tmp_path): 42 | accumulator = Accumulate() 43 | 44 | pipeline = Pipeline(max_loops_allowed=10) 45 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2"])) 46 | pipeline.add_component("below_10", Threshold(threshold=10)) 47 | pipeline.add_component("accumulator", accumulator) 48 | 49 | pipeline.connect("merge.value", "below_10.value") 50 | pipeline.connect("below_10.below", "accumulator.value") 51 | pipeline.connect("accumulator.value", "merge.in_2") 52 | 53 | pipeline.draw(tmp_path / "looping_pipeline_direct_io_loop.png") 54 | 55 | results = pipeline.run({"merge": {"in_1": 4}}) 56 | pprint(results) 57 | 58 | assert results == {"below_10": {"above": 16}} 59 | assert accumulator.state == 16 60 | 61 | 62 | 
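# A short trace of test_pipeline_direct_io_loop above (assuming Accumulate adds
# by default, consistent with the asserted state): merge emits 4 -> below_10
# (below) -> accumulator: state 4 -> loop -> below_10 (4, below) -> accumulator:
# state 8 -> loop -> below_10 (8, below) -> accumulator: state 16 -> loop ->
# below_10 (16, above), so the run ends with {"below_10": {"above": 16}}.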
def test_pipeline_fixed_merger_input(tmp_path): 63 | accumulator = Accumulate() 64 | 65 | pipeline = Pipeline(max_loops_allowed=10) 66 | pipeline.add_component("merge", MergeLoop(expected_type=int, inputs=["in_1", "in_2"])) 67 | pipeline.add_component("below_10", Threshold(threshold=10)) 68 | pipeline.add_component("accumulator", accumulator) 69 | pipeline.add_component("add_two", AddFixedValue(add=2)) 70 | 71 | pipeline.connect("merge.value", "below_10.value") 72 | pipeline.connect("below_10.below", "accumulator.value") 73 | pipeline.connect("accumulator.value", "merge.in_2") 74 | pipeline.connect("below_10.above", "add_two.value") 75 | 76 | pipeline.draw(tmp_path / "looping_pipeline_fixed_merger_input.png") 77 | 78 | results = pipeline.run({"merge": {"in_1": 4}}) 79 | pprint(results) 80 | print("accumulator: ", accumulator.state) 81 | 82 | assert results == {"add_two": {"result": 18}} 83 | assert accumulator.state == 16 84 | 85 | 86 | def test_pipeline_variadic_merger_input(tmp_path): 87 | accumulator = Accumulate() 88 | 89 | pipeline = Pipeline(max_loops_allowed=10) 90 | pipeline.add_component("merge", FirstIntSelector()) 91 | pipeline.add_component("below_10", Threshold(threshold=10)) 92 | pipeline.add_component("accumulator", accumulator) 93 | pipeline.add_component("add_two", AddFixedValue(add=2)) 94 | 95 | pipeline.connect("merge", "below_10.value") 96 | pipeline.connect("below_10.below", "accumulator.value") 97 | pipeline.connect("accumulator.value", "merge.inputs") 98 | pipeline.connect("below_10.above", "add_two.value") 99 | 100 | pipeline.draw(tmp_path / "looping_pipeline_variadic_merger_input.png") 101 | 102 | results = pipeline.run({"merge": {"inputs": 4}}) 103 | pprint(results) 104 | 105 | assert results == {"add_two": {"result": 18}} 106 | assert accumulator.state == 16 107 | 108 | 109 | if __name__ == "__main__": 110 | test_pipeline(Path(__file__).parent) 111 | test_pipeline_direct_io_loop(Path(__file__).parent) 112 | test_pipeline_fixed_merger_input(Path(__file__).parent) 113 | test_pipeline_variadic_merger_input(Path(__file__).parent) 114 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_mutable_inputs.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from canals import Pipeline, component 4 | from sample_components import StringListJoiner 5 | 6 | 7 | @component 8 | class InputMangler: 9 | @component.output_types(mangled_list=List[str]) 10 | def run(self, input_list: List[str]): 11 | input_list.append("extra_item") 12 | return {"mangled_list": input_list} 13 | 14 | 15 | def test_mutable_inputs(): 16 | pipe = Pipeline() 17 | pipe.add_component("mangler1", InputMangler()) 18 | pipe.add_component("mangler2", InputMangler()) 19 | pipe.add_component("concat1", StringListJoiner()) 20 | pipe.add_component("concat2", StringListJoiner()) 21 | pipe.connect("mangler1", "concat1") 22 | pipe.connect("mangler2", "concat2") 23 | 24 | mylist = ["foo", "bar"] 25 | 26 | result = pipe.run(data={"mangler1": {"input_list": mylist}, "mangler2": {"input_list": mylist}}) 27 | assert result["concat1"]["output"] == result["concat2"]["output"] == ["foo", "bar", "extra_item"] 28 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_parallel_branches_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | 
# SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline import Pipeline 8 | from sample_components import AddFixedValue, Repeat, Double 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_pipeline(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("add_one", AddFixedValue(add=1)) 18 | pipeline.add_component("repeat", Repeat(outputs=["first", "second"])) 19 | pipeline.add_component("add_ten", AddFixedValue(add=10)) 20 | pipeline.add_component("double", Double()) 21 | pipeline.add_component("add_three", AddFixedValue(add=3)) 22 | pipeline.add_component("add_one_again", AddFixedValue(add=1)) 23 | 24 | pipeline.connect("add_one.result", "repeat.value") 25 | pipeline.connect("repeat.first", "add_ten.value") 26 | pipeline.connect("repeat.second", "double.value") 27 | pipeline.connect("repeat.second", "add_three.value") 28 | pipeline.connect("add_three.result", "add_one_again.value") 29 | 30 | pipeline.draw(tmp_path / "parallel_branches_pipeline.png") 31 | 32 | results = pipeline.run({"add_one": {"value": 1}}) 33 | pprint(results) 34 | 35 | assert results == { 36 | "add_one_again": {"result": 6}, 37 | "add_ten": {"result": 12}, 38 | "double": {"value": 4}, 39 | } 40 | 41 | 42 | if __name__ == "__main__": 43 | test_pipeline(Path(__file__).parent) 44 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_self_loop.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Optional 5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals import component 9 | from canals.pipeline import Pipeline 10 | from sample_components import AddFixedValue, SelfLoop 11 | 12 | import logging 13 | 14 | logging.basicConfig(level=logging.DEBUG) 15 | 16 | 17 | def test_pipeline_one_node(tmp_path): 18 | pipeline = Pipeline(max_loops_allowed=10) 19 | pipeline.add_component("self_loop", SelfLoop()) 20 | pipeline.connect("self_loop.current_value", "self_loop.values") 21 | 22 | pipeline.draw(tmp_path / "self_looping_pipeline_one_node.png") 23 | 24 | results = pipeline.run({"self_loop": {"values": 5}}) 25 | pprint(results) 26 | 27 | assert results["self_loop"]["final_result"] == 0 28 | 29 | 30 | def test_pipeline(tmp_path): 31 | pipeline = Pipeline(max_loops_allowed=10) 32 | pipeline.add_component("add_1", AddFixedValue()) 33 | pipeline.add_component("self_loop", SelfLoop()) 34 | pipeline.add_component("add_2", AddFixedValue()) 35 | pipeline.connect("add_1", "self_loop.values") 36 | pipeline.connect("self_loop.current_value", "self_loop.values") 37 | pipeline.connect("self_loop.final_result", "add_2.value") 38 | 39 | pipeline.draw(tmp_path / "self_looping_pipeline.png") 40 | 41 | results = pipeline.run({"add_1": {"value": 5}}) 42 | pprint(results) 43 | 44 | assert results["add_2"]["result"] == 1 45 | 46 | 47 | if __name__ == "__main__": 48 | test_pipeline_one_node(Path(__file__).parent) 49 | test_pipeline(Path(__file__).parent) 50 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_variable_decision_and_merge_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: 
Apache-2.0 4 | import logging 5 | from pathlib import Path 6 | from pprint import pprint 7 | 8 | from canals.pipeline import Pipeline 9 | from sample_components import AddFixedValue, Remainder, Double, Sum 10 | 11 | logging.basicConfig(level=logging.DEBUG) 12 | 13 | 14 | def test_pipeline(tmp_path): 15 | pipeline = Pipeline() 16 | pipeline.add_component("add_one", AddFixedValue()) 17 | pipeline.add_component("parity", Remainder(divisor=2)) 18 | pipeline.add_component("add_ten", AddFixedValue(add=10)) 19 | pipeline.add_component("double", Double()) 20 | pipeline.add_component("add_four", AddFixedValue(add=4)) 21 | pipeline.add_component("add_one_again", AddFixedValue()) 22 | pipeline.add_component("sum", Sum()) 23 | 24 | pipeline.connect("add_one.result", "parity.value") 25 | pipeline.connect("parity.remainder_is_0", "add_ten.value") 26 | pipeline.connect("parity.remainder_is_1", "double.value") 27 | pipeline.connect("add_one.result", "sum.values") 28 | pipeline.connect("add_ten.result", "sum.values") 29 | pipeline.connect("double.value", "sum.values") 30 | pipeline.connect("parity.remainder_is_1", "add_four.value") 31 | pipeline.connect("add_four.result", "add_one_again.value") 32 | pipeline.connect("add_one_again.result", "sum.values") 33 | 34 | pipeline.draw(tmp_path / "variable_decision_and_merge_pipeline.png") 35 | 36 | results = pipeline.run({"add_one": {"value": 1}}) 37 | pprint(results) 38 | assert results == {"sum": {"total": 14}} 39 | 40 | results = pipeline.run({"add_one": {"value": 2}}) 41 | pprint(results) 42 | assert results == {"sum": {"total": 17}} 43 | 44 | 45 | if __name__ == "__main__": 46 | test_pipeline(Path(__file__).parent) 47 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_variable_decision_pipeline.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline import Pipeline 8 | from sample_components import AddFixedValue, Remainder, Double 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_pipeline(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("add_one", AddFixedValue(add=1)) 18 | pipeline.add_component("remainder", Remainder(divisor=3)) 19 | pipeline.add_component("add_ten", AddFixedValue(add=10)) 20 | pipeline.add_component("double", Double()) 21 | pipeline.add_component("add_three", AddFixedValue(add=3)) 22 | pipeline.add_component("add_one_again", AddFixedValue(add=1)) 23 | 24 | pipeline.connect("add_one.result", "remainder.value") 25 | pipeline.connect("remainder.remainder_is_0", "add_ten.value") 26 | pipeline.connect("remainder.remainder_is_1", "double.value") 27 | pipeline.connect("remainder.remainder_is_2", "add_three.value") 28 | pipeline.connect("add_three.result", "add_one_again.value") 29 | 30 | pipeline.draw(tmp_path / "variable_decision_pipeline.png") 31 | 32 | results = pipeline.run({"add_one": {"value": 1}}) 33 | pprint(results) 34 | 35 | assert results == {"add_one_again": {"result": 6}} 36 | 37 | 38 | if __name__ == "__main__": 39 | test_pipeline(Path(__file__).parent) 40 | -------------------------------------------------------------------------------- /test/pipeline/integration/test_variable_merging_pipeline.py: -------------------------------------------------------------------------------- 1 | # 
SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from canals.pipeline import Pipeline 8 | from sample_components import AddFixedValue, Sum 9 | 10 | import logging 11 | 12 | logging.basicConfig(level=logging.DEBUG) 13 | 14 | 15 | def test_pipeline(tmp_path): 16 | pipeline = Pipeline() 17 | pipeline.add_component("first_addition", AddFixedValue(add=2)) 18 | pipeline.add_component("second_addition", AddFixedValue(add=2)) 19 | pipeline.add_component("third_addition", AddFixedValue(add=2)) 20 | pipeline.add_component("sum", Sum()) 21 | pipeline.add_component("fourth_addition", AddFixedValue(add=1)) 22 | 23 | pipeline.connect("first_addition.result", "second_addition.value") 24 | pipeline.connect("first_addition.result", "sum.values") 25 | pipeline.connect("second_addition.result", "sum.values") 26 | pipeline.connect("third_addition.result", "sum.values") 27 | pipeline.connect("sum.total", "fourth_addition.value") 28 | 29 | pipeline.draw(tmp_path / "variable_merging_pipeline.png") 30 | 31 | results = pipeline.run( 32 | { 33 | "first_addition": {"value": 1}, 34 | "third_addition": {"value": 1}, 35 | } 36 | ) 37 | pprint(results) 38 | 39 | assert results == {"fourth_addition": {"result": 12}} 40 | 41 | 42 | if __name__ == "__main__": 43 | test_pipeline(Path(__file__).parent) 44 | -------------------------------------------------------------------------------- /test/pipeline/unit/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | -------------------------------------------------------------------------------- /test/pipeline/unit/test_draw.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import os 5 | import sys 6 | import filecmp 7 | 8 | from unittest.mock import patch, MagicMock 9 | import pytest 10 | import requests 11 | 12 | from canals.pipeline import Pipeline 13 | from canals.pipeline.draw import _draw, _convert 14 | from canals.errors import PipelineDrawingError 15 | from sample_components import Double, AddFixedValue 16 | 17 | 18 | @pytest.mark.skipif(sys.platform.lower().startswith("darwin"), reason="the available graphviz version is too recent") 19 | @pytest.mark.skipif(sys.platform.lower().startswith("win"), reason="pygraphviz is not really available in Windows") 20 | def test_draw_pygraphviz(tmp_path, test_files): 21 | pipe = Pipeline() 22 | pipe.add_component("comp1", Double()) 23 | pipe.add_component("comp2", Double()) 24 | pipe.connect("comp1", "comp2") 25 | 26 | _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="graphviz") 27 | assert os.path.exists(tmp_path / "test_pipe.jpg") 28 | assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "pipeline_draw" / "pygraphviz.jpg") 29 | 30 | 31 | def test_draw_mermaid_image(tmp_path, test_files): 32 | pipe = Pipeline() 33 | pipe.add_component("comp1", Double()) 34 | pipe.add_component("comp2", Double()) 35 | pipe.connect("comp1", "comp2") 36 | pipe.connect("comp2", "comp1") 37 | 38 | _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image") 39 | assert os.path.exists(tmp_path / "test_pipe.jpg") 40 | assert filecmp.cmp(tmp_path / "test_pipe.jpg", test_files / "mermaid_mock" / "test_response.png") 41 | 42 | 43 | def 
test_draw_mermaid_img_failing_request(tmp_path): 44 | pipe = Pipeline() 45 | pipe.add_component("comp1", Double()) 46 | pipe.add_component("comp2", Double()) 47 | pipe.connect("comp1", "comp2") 48 | pipe.connect("comp2", "comp1") 49 | 50 | with patch("canals.pipeline.draw.mermaid.requests.get") as mock_get: 51 | 52 | def raise_for_status(self): 53 | raise requests.HTTPError() 54 | 55 | mock_response = MagicMock() 56 | mock_response.status_code = 429 57 | mock_response.content = '{"error": "too many requests"}' 58 | mock_response.raise_for_status = raise_for_status 59 | mock_get.return_value = mock_response 60 | 61 | with pytest.raises(PipelineDrawingError, match="There was an issue with https://mermaid.ink/"): 62 | _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="mermaid-image") 63 | 64 | 65 | def test_draw_mermaid_text(tmp_path): 66 | pipe = Pipeline() 67 | pipe.add_component("comp1", AddFixedValue(add=3)) 68 | pipe.add_component("comp2", Double()) 69 | pipe.connect("comp1.result", "comp2.value") 70 | pipe.connect("comp2.value", "comp1.value") 71 | 72 | _draw(pipe.graph, tmp_path / "test_pipe.md", engine="mermaid-text") 73 | assert os.path.exists(tmp_path / "test_pipe.md") 74 | assert ( 75 | open(tmp_path / "test_pipe.md", "r").read() 76 | == """ 77 | %%{ init: {'theme': 'neutral' } }%% 78 | 79 | graph TD; 80 | 81 | comp1["comp1
AddFixedValue

Optional inputs:
• add (Optional[int])
"]:::component -- "result -> value
int" --> comp2["comp2
Double"]:::component 82 | comp2["comp2
Double"]:::component -- "value -> value
int" --> comp1["comp1
AddFixedValue

Optional inputs:
• add (Optional[int])
    "]:::component 83 | 84 | classDef component text-align:center; 85 | """ 86 | ) 87 | 88 | 89 | def test_draw_unknown_engine(tmp_path): 90 | pipe = Pipeline() 91 | pipe.add_component("comp1", Double()) 92 | pipe.add_component("comp2", Double()) 93 | pipe.connect("comp1", "comp2") 94 | pipe.connect("comp2", "comp1") 95 | 96 | with pytest.raises(ValueError, match="Unknown rendering engine 'unknown'"): 97 | _draw(pipe.graph, tmp_path / "test_pipe.jpg", engine="unknown") 98 | 99 | 100 | def test_convert_unknown_engine(tmp_path): 101 | pipe = Pipeline() 102 | pipe.add_component("comp1", Double()) 103 | pipe.add_component("comp2", Double()) 104 | pipe.connect("comp1", "comp2") 105 | pipe.connect("comp2", "comp1") 106 | 107 | with pytest.raises(ValueError, match="Unknown rendering engine 'unknown'"): 108 | _convert(pipe.graph, engine="unknown") 109 | -------------------------------------------------------------------------------- /test/pipeline/unit/test_validation_pipeline_io.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | import inspect 5 | 6 | from canals.pipeline import Pipeline 7 | from canals.component.types import Variadic 8 | from canals.errors import PipelineValidationError 9 | from canals.component.sockets import InputSocket, OutputSocket 10 | from canals.pipeline.descriptions import find_pipeline_inputs, find_pipeline_outputs 11 | from sample_components import Double, AddFixedValue, Sum, Parity 12 | 13 | 14 | def test_find_pipeline_input_no_input(): 15 | pipe = Pipeline() 16 | pipe.add_component("comp1", Double()) 17 | pipe.add_component("comp2", Double()) 18 | pipe.connect("comp1", "comp2") 19 | pipe.connect("comp2", "comp1") 20 | 21 | assert find_pipeline_inputs(pipe.graph) == {"comp1": [], "comp2": []} 22 | 23 | 24 | def test_find_pipeline_input_one_input(): 25 | pipe = Pipeline() 26 | pipe.add_component("comp1", Double()) 27 | pipe.add_component("comp2", Double()) 28 | pipe.connect("comp1", "comp2") 29 | 30 | assert find_pipeline_inputs(pipe.graph) == { 31 | "comp1": [InputSocket(name="value", type=int)], 32 | "comp2": [], 33 | } 34 | 35 | 36 | def test_find_pipeline_input_two_inputs_same_component(): 37 | pipe = Pipeline() 38 | pipe.add_component("comp1", AddFixedValue()) 39 | pipe.add_component("comp2", Double()) 40 | pipe.connect("comp1", "comp2") 41 | 42 | assert find_pipeline_inputs(pipe.graph) == { 43 | "comp1": [ 44 | InputSocket(name="value", type=int), 45 | InputSocket(name="add", type=Optional[int], is_mandatory=False), 46 | ], 47 | "comp2": [], 48 | } 49 | 50 | 51 | def test_find_pipeline_input_some_inputs_different_components(): 52 | pipe = Pipeline() 53 | pipe.add_component("comp1", AddFixedValue()) 54 | pipe.add_component("comp2", Double()) 55 | pipe.add_component("comp3", AddFixedValue()) 56 | pipe.connect("comp1.result", "comp3.value") 57 | pipe.connect("comp2.value", "comp3.add") 58 | 59 | assert find_pipeline_inputs(pipe.graph) == { 60 | "comp1": [ 61 | InputSocket(name="value", type=int), 62 | InputSocket(name="add", type=Optional[int], is_mandatory=False), 63 | ], 64 | "comp2": [InputSocket(name="value", type=int)], 65 | "comp3": [], 66 | } 67 | 68 | 69 | def test_find_pipeline_variable_input_nodes_in_the_pipeline(): 70 | pipe = Pipeline() 71 | pipe.add_component("comp1", AddFixedValue()) 72 | pipe.add_component("comp2", Double()) 73 | pipe.add_component("comp3", Sum()) 74 | 75 | assert find_pipeline_inputs(pipe.graph) == { 76 | "comp1": [ 77 | 
InputSocket(name="value", type=int), 78 | InputSocket(name="add", type=Optional[int], is_mandatory=False), 79 | ], 80 | "comp2": [InputSocket(name="value", type=int)], 81 | "comp3": [ 82 | InputSocket(name="values", type=Variadic[int]), 83 | ], 84 | } 85 | 86 | 87 | def test_find_pipeline_output_no_output(): 88 | pipe = Pipeline() 89 | pipe.add_component("comp1", Double()) 90 | pipe.add_component("comp2", Double()) 91 | pipe.connect("comp1", "comp2") 92 | pipe.connect("comp2", "comp1") 93 | 94 | assert find_pipeline_outputs(pipe.graph) == {"comp1": [], "comp2": []} 95 | 96 | 97 | def test_find_pipeline_output_one_output(): 98 | pipe = Pipeline() 99 | pipe.add_component("comp1", Double()) 100 | pipe.add_component("comp2", Double()) 101 | pipe.connect("comp1", "comp2") 102 | 103 | assert find_pipeline_outputs(pipe.graph) == {"comp1": [], "comp2": [OutputSocket(name="value", type=int)]} 104 | 105 | 106 | def test_find_pipeline_some_outputs_same_component(): 107 | pipe = Pipeline() 108 | pipe.add_component("comp1", Double()) 109 | pipe.add_component("comp2", Parity()) 110 | pipe.connect("comp1", "comp2") 111 | 112 | assert find_pipeline_outputs(pipe.graph) == { 113 | "comp1": [], 114 | "comp2": [OutputSocket(name="even", type=int), OutputSocket(name="odd", type=int)], 115 | } 116 | 117 | 118 | def test_find_pipeline_some_outputs_different_components(): 119 | pipe = Pipeline() 120 | pipe.add_component("comp1", Double()) 121 | pipe.add_component("comp2", Parity()) 122 | pipe.add_component("comp3", Double()) 123 | pipe.connect("comp1", "comp2") 124 | pipe.connect("comp1", "comp3") 125 | 126 | assert find_pipeline_outputs(pipe.graph) == { 127 | "comp1": [], 128 | "comp2": [OutputSocket(name="even", type=int), OutputSocket(name="odd", type=int)], 129 | "comp3": [ 130 | OutputSocket(name="value", type=int), 131 | ], 132 | } 133 | 134 | 135 | def test_validate_pipeline_input_pipeline_with_no_inputs(): 136 | pipe = Pipeline() 137 | pipe.add_component("comp1", Double()) 138 | pipe.add_component("comp2", Double()) 139 | pipe.connect("comp1", "comp2") 140 | pipe.connect("comp2", "comp1") 141 | with pytest.raises(PipelineValidationError, match="This pipeline has no inputs."): 142 | pipe.run({}) 143 | 144 | 145 | def test_validate_pipeline_input_unknown_component(): 146 | pipe = Pipeline() 147 | pipe.add_component("comp1", Double()) 148 | pipe.add_component("comp2", Double()) 149 | pipe.connect("comp1", "comp2") 150 | with pytest.raises(ValueError, match="Pipeline received data for unknown component\(s\): test_component"): 151 | pipe.run({"test_component": {"value": 1}}) 152 | 153 | 154 | def test_validate_pipeline_input_all_necessary_input_is_present(): 155 | pipe = Pipeline() 156 | pipe.add_component("comp1", Double()) 157 | pipe.add_component("comp2", Double()) 158 | pipe.connect("comp1", "comp2") 159 | with pytest.raises(ValueError, match="Missing input: comp1.value"): 160 | pipe.run({}) 161 | 162 | 163 | def test_validate_pipeline_input_all_necessary_input_is_present_considering_defaults(): 164 | pipe = Pipeline() 165 | pipe.add_component("comp1", AddFixedValue()) 166 | pipe.add_component("comp2", Double()) 167 | pipe.connect("comp1", "comp2") 168 | pipe.run({"comp1": {"value": 1}}) 169 | pipe.run({"comp1": {"value": 1, "add": 2}}) 170 | with pytest.raises(ValueError, match="Missing input: comp1.value"): 171 | pipe.run({"comp1": {"add": 3}}) 172 | 173 | 174 | def test_validate_pipeline_input_only_expected_input_is_present(): 175 | pipe = Pipeline() 176 | pipe.add_component("comp1", Double()) 177 | 
pipe.add_component("comp2", Double()) 178 | pipe.connect("comp1", "comp2") 179 | with pytest.raises(ValueError, match="The input value of comp2 is already sent by: \['comp1'\]"): 180 | pipe.run({"comp1": {"value": 1}, "comp2": {"value": 2}}) 181 | 182 | 183 | def test_validate_pipeline_input_only_expected_input_is_present_falsy(): 184 | pipe = Pipeline() 185 | pipe.add_component("comp1", Double()) 186 | pipe.add_component("comp2", Double()) 187 | pipe.connect("comp1", "comp2") 188 | with pytest.raises(ValueError, match="The input value of comp2 is already sent by: \['comp1'\]"): 189 | pipe.run({"comp1": {"value": 1}, "comp2": {"value": 0}}) 190 | 191 | 192 | def test_validate_pipeline_falsy_input_present(): 193 | pipe = Pipeline() 194 | pipe.add_component("comp", Double()) 195 | assert pipe.run({"comp": {"value": 0}}) == {"comp": {"value": 0}} 196 | 197 | 198 | def test_validate_pipeline_falsy_input_missing(): 199 | pipe = Pipeline() 200 | pipe.add_component("comp", Double()) 201 | with pytest.raises(ValueError, match="Missing input: comp.value"): 202 | pipe.run({"comp": {}}) 203 | 204 | 205 | def test_validate_pipeline_input_only_expected_input_is_present_including_unknown_names(): 206 | pipe = Pipeline() 207 | pipe.add_component("comp1", Double()) 208 | pipe.add_component("comp2", Double()) 209 | pipe.connect("comp1", "comp2") 210 | 211 | with pytest.raises(ValueError, match="Component comp1 is not expecting any input value called add"): 212 | pipe.run({"comp1": {"value": 1, "add": 2}}) 213 | 214 | 215 | def test_validate_pipeline_input_only_expected_input_is_present_and_defaults_dont_interfere(): 216 | pipe = Pipeline() 217 | pipe.add_component("comp1", AddFixedValue(add=10)) 218 | pipe.add_component("comp2", Double()) 219 | pipe.connect("comp1", "comp2") 220 | assert pipe.run({"comp1": {"value": 1, "add": 5}}) == {"comp2": {"value": 12}} 221 | -------------------------------------------------------------------------------- /test/sample_components/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from test.sample_components.test_accumulate import Accumulate 5 | from test.sample_components.test_add_value import AddFixedValue 6 | from test.sample_components.test_double import Double 7 | from test.sample_components.test_parity import Parity 8 | from test.sample_components.test_greet import Greet 9 | from test.sample_components.test_remainder import Remainder 10 | from test.sample_components.test_repeat import Repeat 11 | from test.sample_components.test_subtract import Subtract 12 | from test.sample_components.test_sum import Sum 13 | from test.sample_components.test_threshold import Threshold 14 | -------------------------------------------------------------------------------- /test/sample_components/test_accumulate.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components.accumulate import Accumulate, _default_function 5 | 6 | 7 | def my_subtract(first, second): 8 | return first - second 9 | 10 | 11 | def test_to_dict(): 12 | accumulate = Accumulate() 13 | res = accumulate.to_dict() 14 | assert res == { 15 | "type": "sample_components.accumulate.Accumulate", 16 | "init_parameters": {"function": "sample_components.accumulate._default_function"}, 17 | } 18 | 19 | 20 | def 
test_to_dict_with_custom_function(): 21 | accumulate = Accumulate(function=my_subtract) 22 | res = accumulate.to_dict() 23 | assert res == { 24 | "type": "sample_components.accumulate.Accumulate", 25 | "init_parameters": {"function": "test.sample_components.test_accumulate.my_subtract"}, 26 | } 27 | 28 | 29 | def test_from_dict(): 30 | data = { 31 | "type": "sample_components.accumulate.Accumulate", 32 | "init_parameters": {}, 33 | } 34 | accumulate = Accumulate.from_dict(data) 35 | assert accumulate.function == _default_function 36 | 37 | 38 | def test_from_dict_with_default_function(): 39 | data = { 40 | "type": "sample_components.accumulate.Accumulate", 41 | "init_parameters": {"function": "sample_components.accumulate._default_function"}, 42 | } 43 | accumulate = Accumulate.from_dict(data) 44 | assert accumulate.function == _default_function 45 | 46 | 47 | def test_from_dict_with_custom_function(): 48 | data = { 49 | "type": "sample_components.accumulate.Accumulate", 50 | "init_parameters": {"function": "test.sample_components.test_accumulate.my_subtract"}, 51 | } 52 | accumulate = Accumulate.from_dict(data) 53 | assert accumulate.function == my_subtract 54 | 55 | 56 | def test_accumulate_default(): 57 | component = Accumulate() 58 | results = component.run(value=10) 59 | assert results == {"value": 10} 60 | assert component.state == 10 61 | 62 | results = component.run(value=1) 63 | assert results == {"value": 11} 64 | assert component.state == 11 65 | 66 | 67 | def test_accumulate_callable(): 68 | component = Accumulate(function=my_subtract) 69 | 70 | results = component.run(value=10) 71 | assert results == {"value": -10} 72 | assert component.state == -10 73 | 74 | results = component.run(value=1) 75 | assert results == {"value": -11} 76 | assert component.state == -11 77 | -------------------------------------------------------------------------------- /test/sample_components/test_add_value.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components import AddFixedValue 5 | from canals.serialization import component_to_dict, component_from_dict 6 | 7 | 8 | def test_run(): 9 | component = AddFixedValue() 10 | results = component.run(value=50, add=10) 11 | assert results == {"result": 60} 12 | -------------------------------------------------------------------------------- /test/sample_components/test_concatenate.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components import Concatenate 5 | from canals.serialization import component_to_dict, component_from_dict 6 | 7 | 8 | def test_input_lists(): 9 | component = Concatenate() 10 | res = component.run(first=["This"], second=["That"]) 11 | assert res == {"value": ["This", "That"]} 12 | 13 | 14 | def test_input_strings(): 15 | component = Concatenate() 16 | res = component.run(first="This", second="That") 17 | assert res == {"value": ["This", "That"]} 18 | 19 | 20 | def test_input_first_list_second_string(): 21 | component = Concatenate() 22 | res = component.run(first=["This"], second="That") 23 | assert res == {"value": ["This", "That"]} 24 | 25 | 26 | def test_input_first_string_second_list(): 27 | component = Concatenate() 28 | res = component.run(first="This", second=["That"]) 29 | assert res == {"value": ["This", 
"That"]} 30 | -------------------------------------------------------------------------------- /test/sample_components/test_double.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from sample_components import Double 6 | from canals.serialization import component_to_dict, component_from_dict 7 | 8 | 9 | def test_double_default(): 10 | component = Double() 11 | results = component.run(value=10) 12 | assert results == {"value": 20} 13 | -------------------------------------------------------------------------------- /test/sample_components/test_fstring.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from sample_components import FString 3 | 4 | 5 | def test_fstring_with_one_var(): 6 | fstring = FString(template="Hello, {name}!", variables=["name"]) 7 | output = fstring.run(name="Alice") 8 | assert output == {"string": "Hello, Alice!"} 9 | 10 | 11 | def test_fstring_with_no_vars(): 12 | fstring = FString(template="No variables in this template.", variables=[]) 13 | output = fstring.run() 14 | assert output == {"string": "No variables in this template."} 15 | 16 | 17 | def test_fstring_with_template_at_runtime(): 18 | fstring = FString(template="Hello {name}", variables=["name"]) 19 | output = fstring.run(template="Goodbye {name}!", name="Alice") 20 | assert output == {"string": "Goodbye Alice!"} 21 | 22 | 23 | def test_fstring_with_vars_mismatch(): 24 | fstring = FString(template="Hello {name}", variables=["name"]) 25 | with pytest.raises(KeyError): 26 | fstring.run(template="Goodbye {person}!", name="Alice") 27 | 28 | 29 | def test_fstring_with_vars_in_excess(): 30 | fstring = FString(template="Hello {name}", variables=["name"]) 31 | output = fstring.run(template="Goodbye!", name="Alice") 32 | assert output == {"string": "Goodbye!"} 33 | 34 | 35 | def test_fstring_with_vars_missing(): 36 | fstring = FString(template="{greeting}, {name}!", variables=["name"]) 37 | with pytest.raises(KeyError): 38 | fstring.run(greeting="Hello") 39 | -------------------------------------------------------------------------------- /test/sample_components/test_greet.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import logging 5 | 6 | from sample_components import Greet 7 | from canals.serialization import component_to_dict, component_from_dict 8 | 9 | 10 | def test_greet_message(caplog): 11 | caplog.set_level(logging.WARNING) 12 | component = Greet() 13 | results = component.run(value=10, message="Hello, that's {value}", log_level="WARNING") 14 | assert results == {"value": 10} 15 | assert "Hello, that's 10" in caplog.text 16 | -------------------------------------------------------------------------------- /test/sample_components/test_merge_loop.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from typing import Dict 5 | 6 | import pytest 7 | 8 | from canals.errors import DeserializationError 9 | 10 | from sample_components import MergeLoop 11 | 12 | 13 | def test_to_dict(): 14 | component = MergeLoop(expected_type=int, inputs=["first", "second"]) 15 | res = component.to_dict() 16 | assert res == { 17 | "type": 
"sample_components.merge_loop.MergeLoop", 18 | "init_parameters": {"expected_type": "builtins.int", "inputs": ["first", "second"]}, 19 | } 20 | 21 | 22 | def test_to_dict_with_typing_class(): 23 | component = MergeLoop(expected_type=Dict, inputs=["first", "second"]) 24 | res = component.to_dict() 25 | assert res == { 26 | "type": "sample_components.merge_loop.MergeLoop", 27 | "init_parameters": { 28 | "expected_type": "typing.Dict", 29 | "inputs": ["first", "second"], 30 | }, 31 | } 32 | 33 | 34 | def test_to_dict_with_custom_class(): 35 | component = MergeLoop(expected_type=MergeLoop, inputs=["first", "second"]) 36 | res = component.to_dict() 37 | assert res == { 38 | "type": "sample_components.merge_loop.MergeLoop", 39 | "init_parameters": { 40 | "expected_type": "sample_components.merge_loop.MergeLoop", 41 | "inputs": ["first", "second"], 42 | }, 43 | } 44 | 45 | 46 | def test_from_dict(): 47 | data = { 48 | "type": "sample_components.merge_loop.MergeLoop", 49 | "init_parameters": {"expected_type": "builtins.int", "inputs": ["first", "second"]}, 50 | } 51 | component = MergeLoop.from_dict(data) 52 | assert component.expected_type == "builtins.int" 53 | assert component.inputs == ["first", "second"] 54 | 55 | 56 | def test_from_dict_with_typing_class(): 57 | data = { 58 | "type": "sample_components.merge_loop.MergeLoop", 59 | "init_parameters": { 60 | "expected_type": "typing.Dict", 61 | "inputs": ["first", "second"], 62 | }, 63 | } 64 | component = MergeLoop.from_dict(data) 65 | assert component.expected_type == "typing.Dict" 66 | assert component.inputs == ["first", "second"] 67 | 68 | 69 | def test_from_dict_with_custom_class(): 70 | data = { 71 | "type": "sample_components.merge_loop.MergeLoop", 72 | "init_parameters": { 73 | "expected_type": "sample_components.merge_loop.MergeLoop", 74 | "inputs": ["first", "second"], 75 | }, 76 | } 77 | component = MergeLoop.from_dict(data) 78 | assert component.expected_type == "sample_components.merge_loop.MergeLoop" 79 | assert component.inputs == ["first", "second"] 80 | 81 | 82 | def test_from_dict_without_expected_type(): 83 | data = { 84 | "type": "sample_components.merge_loop.MergeLoop", 85 | "init_parameters": { 86 | "inputs": ["first", "second"], 87 | }, 88 | } 89 | with pytest.raises(DeserializationError) as exc: 90 | MergeLoop.from_dict(data) 91 | 92 | exc.match("Missing 'expected_type' field in 'init_parameters'") 93 | 94 | 95 | def test_from_dict_without_inputs(): 96 | data = { 97 | "type": "sample_components.merge_loop.MergeLoop", 98 | "init_parameters": { 99 | "expected_type": "sample_components.merge_loop.MergeLoop", 100 | }, 101 | } 102 | with pytest.raises(DeserializationError) as exc: 103 | MergeLoop.from_dict(data) 104 | 105 | exc.match("Missing 'inputs' field in 'init_parameters'") 106 | 107 | 108 | def test_merge_first(): 109 | component = MergeLoop(expected_type=int, inputs=["in_1", "in_2"]) 110 | results = component.run(in_1=5) 111 | assert results == {"value": 5} 112 | 113 | 114 | def test_merge_second(): 115 | component = MergeLoop(expected_type=int, inputs=["in_1", "in_2"]) 116 | results = component.run(in_2=5) 117 | assert results == {"value": 5} 118 | 119 | 120 | def test_merge_nones(): 121 | component = MergeLoop(expected_type=int, inputs=["in_1", "in_2", "in_3"]) 122 | results = component.run() 123 | assert results == {} 124 | 125 | 126 | def test_merge_one(): 127 | component = MergeLoop(expected_type=int, inputs=["in_1"]) 128 | results = component.run(in_1=1) 129 | assert results == {"value": 1} 130 | 131 | 132 | 
def test_merge_one_none(): 133 | component = MergeLoop(expected_type=int, inputs=[]) 134 | results = component.run() 135 | assert results == {} 136 | -------------------------------------------------------------------------------- /test/sample_components/test_parity.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components import Parity 5 | from canals.serialization import component_to_dict, component_from_dict 6 | 7 | 8 | def test_parity(): 9 | component = Parity() 10 | results = component.run(value=1) 11 | assert results == {"odd": 1} 12 | results = component.run(value=2) 13 | assert results == {"even": 2} 14 | -------------------------------------------------------------------------------- /test/sample_components/test_remainder.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | import pytest 5 | 6 | from sample_components import Remainder 7 | from canals.serialization import component_to_dict, component_from_dict 8 | 9 | 10 | def test_remainder_default(): 11 | component = Remainder() 12 | results = component.run(value=4) 13 | assert results == {"remainder_is_1": 4} 14 | 15 | 16 | def test_remainder_with_divisor(): 17 | component = Remainder(divisor=4) 18 | results = component.run(value=4) 19 | assert results == {"remainder_is_0": 4} 20 | 21 | 22 | def test_remainder_zero(): 23 | with pytest.raises(ValueError): 24 | Remainder(divisor=0) 25 | -------------------------------------------------------------------------------- /test/sample_components/test_repeat.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components import Repeat 5 | from canals.serialization import component_to_dict, component_from_dict 6 | 7 | 8 | def test_repeat_default(): 9 | component = Repeat(outputs=["one", "two"]) 10 | results = component.run(value=10) 11 | assert results == {"one": 10, "two": 10} 12 | -------------------------------------------------------------------------------- /test/sample_components/test_subtract.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components import Subtract 5 | from canals.serialization import component_to_dict, component_from_dict 6 | 7 | 8 | def test_subtract(): 9 | component = Subtract() 10 | results = component.run(first_value=10, second_value=7) 11 | assert results == {"difference": 3} 12 | -------------------------------------------------------------------------------- /test/sample_components/test_sum.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | 5 | from sample_components import Sum 6 | from canals.serialization import component_to_dict, component_from_dict 7 | 8 | 9 | def test_sum_receives_no_values(): 10 | component = Sum() 11 | results = component.run(values=[]) 12 | assert results == {"total": 0} 13 | 14 | 15 | def test_sum_receives_one_value(): 16 | component = Sum() 17 | assert component.run(values=[10]) == {"total": 10} 
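# Sum's "values" socket is variadic, so even a single input is passed as a list.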
18 | 19 | 20 | def test_sum_receives_few_values(): 21 | component = Sum() 22 | assert component.run(values=[10, 2]) == {"total": 12} 23 | -------------------------------------------------------------------------------- /test/sample_components/test_threshold.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 2 | # 3 | # SPDX-License-Identifier: Apache-2.0 4 | from sample_components import Threshold 5 | from canals.serialization import component_to_dict, component_from_dict 6 | 7 | 8 | def test_threshold(): 9 | component = Threshold() 10 | 11 | results = component.run(value=5, threshold=10) 12 | assert results == {"below": 5} 13 | 14 | results = component.run(value=15, threshold=10) 15 | assert results == {"above": 15} 16 | -------------------------------------------------------------------------------- /test/test_files/mermaid_mock/test_response.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/test/test_files/mermaid_mock/test_response.png -------------------------------------------------------------------------------- /test/test_files/pipeline_draw/pygraphviz.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deepset-ai/canals/5fd3dcff3cebff81bb6fc38ed3303c2c2a725321/test/test_files/pipeline_draw/pygraphviz.jpg -------------------------------------------------------------------------------- /test/test_serialization.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from unittest.mock import Mock 3 | 4 | import pytest 5 | 6 | from canals import Pipeline, component 7 | from canals.errors import DeserializationError 8 | from canals.testing import factory 9 | from canals.serialization import default_to_dict, default_from_dict 10 | 11 | 12 | def test_default_component_to_dict(): 13 | MyComponent = factory.component_class("MyComponent") 14 | comp = MyComponent() 15 | res = default_to_dict(comp) 16 | assert res == { 17 | "type": "canals.testing.factory.MyComponent", 18 | "init_parameters": {}, 19 | } 20 | 21 | 22 | def test_default_component_to_dict_with_init_parameters(): 23 | MyComponent = factory.component_class("MyComponent") 24 | comp = MyComponent() 25 | res = default_to_dict(comp, some_key="some_value") 26 | assert res == { 27 | "type": "canals.testing.factory.MyComponent", 28 | "init_parameters": {"some_key": "some_value"}, 29 | } 30 | 31 | 32 | def test_default_component_from_dict(): 33 | def custom_init(self, some_param): 34 | self.some_param = some_param 35 | 36 | extra_fields = {"__init__": custom_init} 37 | MyComponent = factory.component_class("MyComponent", extra_fields=extra_fields) 38 | comp = default_from_dict( 39 | MyComponent, 40 | { 41 | "type": "canals.testing.factory.MyComponent", 42 | "init_parameters": { 43 | "some_param": 10, 44 | }, 45 | }, 46 | ) 47 | assert isinstance(comp, MyComponent) 48 | assert comp.some_param == 10 49 | 50 | 51 | def test_default_component_from_dict_without_type(): 52 | with pytest.raises(DeserializationError, match="Missing 'type' in serialization data"): 53 | default_from_dict(Mock, {}) 54 | 55 | 56 | def test_default_component_from_dict_unregistered_component(request): 57 | # We use the test function name as component name to make sure it's not registered. 
58 | # Since the registry is global we risk having a component with the same name registered in another test. 59 | component_name = request.node.name 60 | 61 | with pytest.raises(DeserializationError, match=f"Class '{component_name}' can't be deserialized as 'Mock'"): 62 | default_from_dict(Mock, {"type": component_name}) 63 | 64 | 65 | def test_from_dict_import_type(): 66 | pipeline_dict = { 67 | "metadata": {}, 68 | "max_loops_allowed": 100, 69 | "components": { 70 | "greeter": { 71 | "type": "sample_components.greet.Greet", 72 | "init_parameters": { 73 | "message": "\nGreeting component says: Hi! The value is {value}\n", 74 | "log_level": "INFO", 75 | }, 76 | } 77 | }, 78 | "connections": [], 79 | } 80 | 81 | # remove the target component from the registry if already there 82 | component.registry.pop("sample_components.greet.Greet", None) 83 | # remove the module from sys.modules if already there 84 | sys.modules.pop("sample_components.greet", None) 85 | 86 | p = Pipeline.from_dict(pipeline_dict) 87 | 88 | from sample_components.greet import Greet 89 | 90 | assert type(p.get_component("greeter")) == Greet 91 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Set, Sequence, Tuple, Dict, Mapping, Literal, Union, Optional, Any 2 | from enum import Enum 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from canals.type_utils import _type_name 8 | 9 | 10 | class Class1: 11 | ... 12 | 13 | 14 | class Class2: 15 | ... 16 | 17 | 18 | class Class3(Class1): 19 | ... 20 | 21 | 22 | class Enum1(Enum): 23 | TEST1 = Class1 24 | TEST2 = Class2 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "type_,repr", 29 | [ 30 | pytest.param(str, "str", id="primitive-types"), 31 | pytest.param(Any, "Any", id="any"), 32 | pytest.param(Class1, "Class1", id="class"), 33 | pytest.param(Optional[int], "Optional[int]", id="shallow-optional-with-primitive"), 34 | pytest.param(Optional[Any], "Optional[Any]", id="shallow-optional-with-any"), 35 | pytest.param(Optional[Class1], "Optional[Class1]", id="shallow-optional-with-class"), 36 | pytest.param(Union[bool, Class1], "Union[bool, Class1]", id="shallow-union"), 37 | pytest.param(List[str], "List[str]", id="shallow-sequence-of-primitives"), 38 | pytest.param(List[Set[Sequence[str]]], "List[Set[Sequence[str]]]", id="nested-sequence-of-primitives"), 39 | pytest.param( 40 | Optional[List[Set[Sequence[str]]]], 41 | "Optional[List[Set[Sequence[str]]]]", 42 | id="optional-nested-sequence-of-primitives", 43 | ), 44 | pytest.param( 45 | List[Set[Sequence[Optional[str]]]], 46 | "List[Set[Sequence[Optional[str]]]]", 47 | id="nested-optional-sequence-of-primitives", 48 | ), 49 | pytest.param(List[Class1], "List[Class1]", id="shallow-sequence-of-classes"), 50 | pytest.param(List[Set[Sequence[Class1]]], "List[Set[Sequence[Class1]]]", id="nested-sequence-of-classes"), 51 | pytest.param(Dict[str, int], "Dict[str, int]", id="shallow-mapping-of-primitives"), 52 | pytest.param( 53 | Dict[str, Mapping[str, Dict[str, int]]], 54 | "Dict[str, Mapping[str, Dict[str, int]]]", 55 | id="nested-mapping-of-primitives", 56 | ), 57 | pytest.param( 58 | Dict[str, Mapping[Any, Dict[str, int]]], 59 | "Dict[str, Mapping[Any, Dict[str, int]]]", 60 | id="nested-mapping-of-primitives-with-any", 61 | ), 62 | pytest.param(Dict[str, Class1], "Dict[str, Class1]", id="shallow-mapping-of-classes"), 63 | pytest.param( 64 | Dict[str, Mapping[str, 
Dict[str, Class1]]], 65 | "Dict[str, Mapping[str, Dict[str, Class1]]]", 66 | id="nested-mapping-of-classes", 67 | ), 68 | pytest.param( 69 | Literal["a", "b", "c"], 70 | "Literal['a', 'b', 'c']", 71 | id="string-literal", 72 | ), 73 | pytest.param( 74 | Literal[1, 2, 3], 75 | "Literal[1, 2, 3]", 76 | id="primitive-literal", 77 | ), 78 | pytest.param( 79 | Literal[Enum1.TEST1], 80 | "Literal[Enum1.TEST1]", 81 | id="enum-literal", 82 | ), 83 | pytest.param( 84 | Tuple[Optional[Literal["a", "b", "c"]], Union[Path, Dict[int, Class1]]], 85 | "Tuple[Optional[Literal['a', 'b', 'c']], Union[Path, Dict[int, Class1]]]", 86 | id="deeply-nested-complex-type", 87 | ), 88 | ], 89 | ) 90 | def test_type_name(type_, repr): 91 | assert _type_name(type_) == repr 92 | -------------------------------------------------------------------------------- /test/testing/test_factory.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from canals import component 4 | from canals.errors import ComponentError 5 | from canals.testing import factory 6 | 7 | 8 | def test_component_class_default(): 9 | MyComponent = factory.component_class("MyComponent") 10 | comp = MyComponent() 11 | res = comp.run(value=1) 12 | assert res == {"value": None} 13 | 14 | res = comp.run(value="something") 15 | assert res == {"value": None} 16 | 17 | res = comp.run(non_existing_input=1) 18 | assert res == {"value": None} 19 | 20 | 21 | def test_component_class_is_registered(): 22 | MyComponent = factory.component_class("MyComponent") 23 | assert component.registry["canals.testing.factory.MyComponent"] == MyComponent 24 | 25 | 26 | def test_component_class_with_input_types(): 27 | MyComponent = factory.component_class("MyComponent", input_types={"value": int}) 28 | comp = MyComponent() 29 | res = comp.run(value=1) 30 | assert res == {"value": None} 31 | 32 | res = comp.run(value="something") 33 | assert res == {"value": None} 34 | 35 | 36 | def test_component_class_with_output_types(): 37 | MyComponent = factory.component_class("MyComponent", output_types={"value": int}) 38 | comp = MyComponent() 39 | 40 | res = comp.run(value=1) 41 | assert res == {"value": None} 42 | 43 | 44 | def test_component_class_with_output(): 45 | MyComponent = factory.component_class("MyComponent", output={"value": 100}) 46 | comp = MyComponent() 47 | res = comp.run(value=1) 48 | assert res == {"value": 100} 49 | 50 | 51 | def test_component_class_with_output_and_output_types(): 52 | MyComponent = factory.component_class("MyComponent", output_types={"value": str}, output={"value": 100}) 53 | comp = MyComponent() 54 | 55 | res = comp.run(value=1) 56 | assert res == {"value": 100} 57 | 58 | 59 | def test_component_class_with_bases(): 60 | MyComponent = factory.component_class("MyComponent", bases=(Exception,)) 61 | comp = MyComponent() 62 | assert isinstance(comp, Exception) 63 | 64 | 65 | def test_component_class_with_extra_fields(): 66 | MyComponent = factory.component_class("MyComponent", extra_fields={"my_field": 10}) 67 | comp = MyComponent() 68 | assert comp.my_field == 10 69 | --------------------------------------------------------------------------------