├── .github
│   └── workflows
│       └── tests.yml
├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── griptape
│   ├── __init__.py
│   └── flow
│       ├── __init__.py
│       ├── artifacts
│       │   ├── __init__.py
│       │   ├── error_output.py
│       │   ├── structure_artifact.py
│       │   └── text_output.py
│       ├── drivers
│       │   ├── __init__.py
│       │   ├── memory
│       │   │   ├── __init__.py
│       │   │   ├── disk_memory_driver.py
│       │   │   └── memory_driver.py
│       │   └── prompt
│       │       ├── __init__.py
│       │       ├── base_prompt_driver.py
│       │       ├── cohere_prompt_driver.py
│       │       ├── hugging_face_hub_prompt_driver.py
│       │       ├── hugging_face_pipeline_prompt_driver.py
│       │       └── openai_prompt_driver.py
│       ├── memory
│       │   ├── __init__.py
│       │   ├── buffer_pipeline_memory.py
│       │   ├── pipeline_memory.py
│       │   ├── pipeline_run.py
│       │   └── summary_pipeline_memory.py
│       ├── rules
│       │   ├── __init__.py
│       │   ├── json.py
│       │   ├── meta.py
│       │   └── rule.py
│       ├── schemas
│       │   ├── __init__.py
│       │   ├── base_schema.py
│       │   ├── drivers
│       │   │   ├── __init__.py
│       │   │   ├── openai_prompt_driver_schema.py
│       │   │   └── prompt_driver_schema.py
│       │   ├── memory
│       │   │   ├── __init__.py
│       │   │   ├── buffer_pipeline_memory_schema.py
│       │   │   ├── pipeline_memory_schema.py
│       │   │   ├── pipeline_run_schema.py
│       │   │   └── summary_pipeline_memory_schema.py
│       │   ├── polymorphic_schema.py
│       │   ├── rule_schema.py
│       │   ├── steps
│       │   │   ├── __init__.py
│       │   │   ├── prompt_step_schema.py
│       │   │   ├── step_schema.py
│       │   │   └── toolkit_step_schema.py
│       │   ├── structures
│       │   │   ├── __init__.py
│       │   │   ├── pipeline_schema.py
│       │   │   ├── structure_schema.py
│       │   │   └── workflow_schema.py
│       │   ├── summarizers
│       │   │   ├── __init__.py
│       │   │   ├── prompt_driver_summarizer_schema.py
│       │   │   └── summarizer_schema.py
│       │   └── tokenizers
│       │       ├── __init__.py
│       │       └── tiktoken_tokenizer_schema.py
│       ├── steps
│       │   ├── __init__.py
│       │   ├── prompt_step.py
│       │   ├── step.py
│       │   ├── tool_substep.py
│       │   └── toolkit_step.py
│       ├── structures
│       │   ├── __init__.py
│       │   ├── pipeline.py
│       │   ├── structure.py
│       │   └── workflow.py
│       ├── summarizers
│       │   ├── __init__.py
│       │   ├── prompt_driver_summarizer.py
│       │   └── summarizer.py
│       ├── templates
│       │   └── prompts
│       │       ├── context.j2
│       │       ├── memory.j2
│       │       ├── pipeline.j2
│       │       ├── run_context.j2
│       │       ├── steps
│       │       │   ├── prompt.j2
│       │       │   └── tool
│       │       │       ├── substep.j2
│       │       │       ├── substeps.j2
│       │       │       └── tool.j2
│       │       ├── summarize.j2
│       │       ├── tool.j2
│       │       └── workflow.j2
│       ├── tokenizers
│       │   ├── __init__.py
│       │   ├── base_tokenizer.py
│       │   ├── cohere_tokenizer.py
│       │   ├── hugging_face_tokenizer.py
│       │   └── tiktoken_tokenizer.py
│       └── utils
│           ├── __init__.py
│           ├── conversation.py
│           ├── j2.py
│           └── tool_loader.py
├── poetry.lock
├── pyproject.toml
└── tests
    ├── __init__.py
    ├── mocks
    │   ├── __init__.py
    │   ├── mock_driver.py
    │   ├── mock_failing_driver.py
    │   └── mock_value_driver.py
    └── unit
        ├── __init__.py
        ├── drivers
        │   ├── __init__.py
        │   ├── test_disk_memory_driver.py
        │   └── test_prompt_driver.py
        ├── memory
        │   ├── __init__.py
        │   ├── test_pipeline_buffer_memory.py
        │   ├── test_pipeline_memory.py
        │   └── test_pipeline_summary_memory.py
        ├── schemas
        │   ├── __init__.py
        │   ├── test_pipeline_schema.py
        │   └── test_workflow_schema.py
        ├── steps
        │   ├── __init__.py
        │   ├── test_prompt_step.py
        │   ├── test_tool_substep.py
        │   └── test_toolkit_step.py
        ├── structures
        │   ├── __init__.py
        │   ├── test_pipeline.py
        │   └── test_workflow.py
        ├── tokenizers
        │   ├── test_hugging_face_tokenizer.py
        │   └── test_tiktoken_tokenizer.py
        └── utils
            ├── __init__.py
            ├── test_conversation.py
            └── test_tool_loader.py

/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests
2 | 
3 | on:
4 |   push:
5 |     branches: [ "main" ]
6 |   pull_request:
7 |     branches: [ "main" ]
8 | 
9 | jobs:
10 |   test:
11 |     runs-on: ubuntu-latest
12 |     strategy:
13 |       fail-fast: false
14 |       matrix:
15 |         python-version: ["3.9", "3.10", "3.11"]
16 |     steps:
17 |       - name: Checkout repository
18 |         uses: actions/checkout@v3
19 | 
20 |       - name: Set up Python ${{ matrix.python-version }}
21 |         id: setup-python
22 |         uses: actions/setup-python@v4
23 |         with:
24 |           python-version: ${{ matrix.python-version }}
25 | 
26 |       - name: Install and configure Poetry
27 |         uses: snok/install-poetry@v1
28 |         with:
29 |           virtualenvs-create: true
30 |           virtualenvs-in-project: true
31 |           installer-parallel: true
32 | 
33 |       - name: Load cached venv
34 |         id: cached-poetry-dependencies
35 |         uses: actions/cache@v3
36 |         with:
37 |           path: .venv
38 |           key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
39 | 
40 |       - name: Install dependencies
41 |         if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
42 |         run: poetry install --no-interaction --no-root
43 | 
44 |       - name: Install project
45 |         run: poetry install --no-interaction
46 | 
47 |       - name: Run tests
48 |         run: |
49 |           source .venv/bin/activate
50 |           pytest tests/unit
51 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | .idea
3 | .DS_Store
4 | .huskyrc.json
5 | out
6 | log.log
7 | **/node_modules
8 | *.pyc
9 | *.vsix
10 | **/.vscode/.ropeproject/**
11 | **/testFiles/**/.cache/**
12 | *.noseids
13 | .nyc_output
14 | .vscode-test
15 | __pycache__
16 | npm-debug.log
17 | **/.mypy_cache/**
18 | !yarn.lock
19 | coverage/
20 | cucumber-report.json
21 | **/.vscode-test/**
22 | **/.vscode test/**
23 | **/.vscode-smoke/**
24 | **/.venv*/
25 | port.txt
26 | precommit.hook
27 | pythonFiles/lib/**
28 | debug_coverage*/**
29 | languageServer/**
30 | languageServer.*/**
31 | bin/**
32 | obj/**
33 | .pytest_cache
34 | tmp/**
35 | .python-version
36 | .vs/
37 | test-results*.xml
38 | xunit-test-results.xml
39 | build/ci/performance/performance-results.json
40 | !build/
41 | debug*.log
42 | debugpy*.log
43 | pydevd*.log
44 | nodeLanguageServer/**
45 | nodeLanguageServer.*/**
46 | dist/**
47 | *.egg-info
48 | 
49 | # translation files
50 | *.xlf
51 | *.nls.*.json
52 | *.i18n.json
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright Vasily Vasinov 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | version:
2 | 	@poetry version $(v)
3 | 	@git add pyproject.toml
4 | 	@git commit -m "Version bump v$$(poetry version -s)"
5 | 	@git tag v$$(poetry version -s)
6 | 	@git push
7 | 	@git push --tags
8 | 	@poetry build
9 | 	@poetry publish
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # griptape-flow
2 | 
3 | [![Tests](https://github.com/griptape-ai/griptape-flow/actions/workflows/tests.yml/badge.svg)](https://github.com/griptape-ai/griptape-flow/actions/workflows/tests.yml)
4 | [![PyPI Version](https://img.shields.io/pypi/v/griptape-flow.svg)](https://pypi.python.org/pypi/griptape-flow)
5 | [![Docs](https://readthedocs.org/projects/griptape/badge/)](https://griptape.readthedocs.io/en/latest/griptape_flow/)
6 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/griptape-ai/griptape-flow/blob/main/LICENSE)
7 | [![Griptape Discord](https://dcbadge.vercel.app/api/server/gnWRz88eym?compact=true&style=flat)](https://discord.gg/gnWRz88eym)
8 | 
9 | **griptape-flow** is a Python framework for creating workflow DAGs and pipelines that use large language models (LLMs) such as GPT, Claude, Titan, and Cohere.
10 | 
11 | **griptape-flow** is part of [griptape](https://github.com/griptape-ai/griptape), a modular Python framework for integrating data, APIs, tools, memory, and chain-of-thought reasoning into LLMs.
12 | 
13 | ## Documentation
14 | 
15 | Please refer to [Griptape Docs](https://griptape.readthedocs.io) for:
16 | 
17 | - Getting started guides.
18 | - Core concepts and design overviews.
19 | - Examples.
20 | - Contribution guidelines.
21 | 
22 | ## License
23 | 
24 | griptape-flow is available under the Apache 2.0 License. 
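25 | 
26 | ## Quick Start
27 | 
28 | A minimal sketch of a two-step pipeline. It assumes an `OPENAI_API_KEY` in your environment, and it assumes `Pipeline.add_steps()` and `Pipeline.run()` from `griptape/flow/structures/pipeline.py`, which is not excerpted in this snapshot, so treat the exact calls as illustrative:
29 | 
30 | ```python
31 | from griptape.flow.drivers import OpenAiPromptDriver
32 | from griptape.flow.steps import PromptStep
33 | from griptape.flow.structures import Pipeline
34 | 
35 | # OpenAiPromptDriver defaults to TiktokenTokenizer.DEFAULT_MODEL with temperature 0.5.
36 | pipeline = Pipeline(prompt_driver=OpenAiPromptDriver())
37 | 
38 | # Prompt templates are Jinja2 strings; "{{ args[0] }}" is assumed to expand to the
39 | # first argument passed to run().
40 | pipeline.add_steps(
41 |     PromptStep("{{ args[0] }}"),
42 |     PromptStep("Say the opposite of what you just said.")
43 | )
44 | 
45 | pipeline.run("Tell me a one-line joke about Python.")
46 | ```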
25 | -------------------------------------------------------------------------------- /griptape/__init__.py: -------------------------------------------------------------------------------- 1 | __path__ = __import__("pkgutil").extend_path(__path__, __name__) 2 | -------------------------------------------------------------------------------- /griptape/flow/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | PACKAGE_ABS_PATH = os.path.dirname(os.path.abspath(__file__)) 4 | 5 | __all__ = [ 6 | "PACKAGE_ABS_PATH" 7 | ] 8 | -------------------------------------------------------------------------------- /griptape/flow/artifacts/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.artifacts.structure_artifact import StructureArtifact 2 | from griptape.flow.artifacts.error_output import ErrorOutput 3 | from griptape.flow.artifacts.text_output import TextOutput 4 | 5 | 6 | __all__ = [ 7 | "StructureArtifact", 8 | "ErrorOutput", 9 | "TextOutput" 10 | ] 11 | -------------------------------------------------------------------------------- /griptape/flow/artifacts/error_output.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING 3 | from typing import Optional 4 | from attr import define, field 5 | from griptape.flow.artifacts import StructureArtifact 6 | 7 | 8 | if TYPE_CHECKING: 9 | from griptape.flow.steps import Step 10 | 11 | 12 | @define(frozen=True) 13 | class ErrorOutput(StructureArtifact): 14 | exception: Optional[Exception] = field(default=None, kw_only=True) 15 | step: Optional[Step] = field(default=None, kw_only=True) 16 | -------------------------------------------------------------------------------- /griptape/flow/artifacts/structure_artifact.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Optional 3 | from attr import define, field 4 | 5 | 6 | @define 7 | class StructureArtifact(ABC): 8 | value: Optional[any] = field() 9 | -------------------------------------------------------------------------------- /griptape/flow/artifacts/text_output.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from attr import define, field 3 | from griptape.flow.artifacts import StructureArtifact 4 | from griptape.flow.tokenizers import BaseTokenizer 5 | 6 | 7 | @define(frozen=True) 8 | class TextOutput(StructureArtifact): 9 | meta: Optional[any] = field(default=None) 10 | 11 | def token_count(self, tokenizer: BaseTokenizer) -> Optional[int]: 12 | if isinstance(self.value, str): 13 | return tokenizer.token_count(self.value) 14 | else: 15 | return None 16 | -------------------------------------------------------------------------------- /griptape/flow/drivers/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.drivers.prompt.base_prompt_driver import BasePromptDriver 2 | from griptape.flow.drivers.prompt.openai_prompt_driver import OpenAiPromptDriver 3 | from griptape.flow.drivers.prompt.cohere_prompt_driver import CoherePromptDriver 4 | from griptape.flow.drivers.prompt.hugging_face_pipeline_prompt_driver import HuggingFacePipelinePromptDriver 5 | from griptape.flow.drivers.prompt.hugging_face_hub_prompt_driver import HuggingFaceHubPromptDriver 6 | from 
griptape.flow.drivers.memory.memory_driver import MemoryDriver 7 | from griptape.flow.drivers.memory.disk_memory_driver import DiskMemoryDriver 8 | 9 | __all__ = [ 10 | "BasePromptDriver", 11 | "OpenAiPromptDriver", 12 | "CoherePromptDriver", 13 | "HuggingFacePipelinePromptDriver", 14 | "HuggingFaceHubPromptDriver", 15 | 16 | "MemoryDriver", 17 | "DiskMemoryDriver" 18 | ] 19 | -------------------------------------------------------------------------------- /griptape/flow/drivers/memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/drivers/memory/__init__.py -------------------------------------------------------------------------------- /griptape/flow/drivers/memory/disk_memory_driver.py: -------------------------------------------------------------------------------- 1 | from attr import define, field 2 | from griptape.flow.drivers import MemoryDriver 3 | from griptape.flow.memory import PipelineMemory 4 | 5 | 6 | @define 7 | class DiskMemoryDriver(MemoryDriver): 8 | file_path: str = field(default="griptape_memory.json", kw_only=True) 9 | 10 | def store(self, memory: PipelineMemory) -> None: 11 | with open(self.file_path, "w") as file: 12 | file.write(memory.to_json()) 13 | 14 | def load(self) -> PipelineMemory: 15 | with open(self.file_path, "r") as file: 16 | memory = PipelineMemory.from_json(file.read()) 17 | 18 | memory.driver = self 19 | 20 | return memory 21 | -------------------------------------------------------------------------------- /griptape/flow/drivers/memory/memory_driver.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from griptape.flow.memory import PipelineMemory 3 | 4 | 5 | class MemoryDriver(ABC): 6 | @abstractmethod 7 | def store(self, memory: PipelineMemory) -> None: 8 | ... 9 | 10 | @abstractmethod 11 | def load(self) -> PipelineMemory: 12 | ... 
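13 | 
14 | # Usage sketch (illustrative; names taken from the files above): a MemoryDriver
15 | # pairs with a PipelineMemory so runs are persisted automatically.
16 | # PipelineMemory.after_add_run() calls driver.store() after every run, and
17 | # DiskMemoryDriver.load() re-attaches itself to the memory it returns, so the
18 | # same state can be reloaded in a later process:
19 | #
20 | #     memory = PipelineMemory(driver=DiskMemoryDriver(file_path="memory.json"))
21 | #     memory.add_run(PipelineRun(input="hi", output="hello"))  # persisted via store()
22 | #     restored = DiskMemoryDriver(file_path="memory.json").load()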
13 | -------------------------------------------------------------------------------- /griptape/flow/drivers/prompt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/drivers/prompt/__init__.py -------------------------------------------------------------------------------- /griptape/flow/drivers/prompt/base_prompt_driver.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import logging 3 | import time 4 | from abc import ABC, abstractmethod 5 | from typing import TYPE_CHECKING 6 | from attr import define, field, Factory 7 | from griptape.flow.tokenizers import BaseTokenizer 8 | 9 | if TYPE_CHECKING: 10 | from griptape.flow.artifacts import TextOutput 11 | 12 | 13 | @define 14 | class BasePromptDriver(ABC): 15 | max_retries: int = field(default=8, kw_only=True) 16 | retry_delay: float = field(default=1, kw_only=True) 17 | type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) 18 | temperature: float = field(default=0.5, kw_only=True) 19 | model: str 20 | tokenizer: BaseTokenizer 21 | 22 | def run(self, **kwargs) -> TextOutput: 23 | for attempt in range(0, self.max_retries + 1): 24 | try: 25 | return self.try_run(**kwargs) 26 | except Exception as e: 27 | logging.error(f"PromptDriver.run attempt {attempt} failed: {e}\nRetrying in {self.retry_delay} seconds") 28 | 29 | if attempt < self.max_retries: 30 | time.sleep(self.retry_delay) 31 | else: 32 | raise e 33 | 34 | @abstractmethod 35 | def try_run(self, **kwargs) -> TextOutput: 36 | ... 37 | -------------------------------------------------------------------------------- /griptape/flow/drivers/prompt/cohere_prompt_driver.py: -------------------------------------------------------------------------------- 1 | import cohere 2 | from attr import define, field, Factory 3 | from griptape.flow.artifacts import TextOutput 4 | from griptape.flow.drivers import BasePromptDriver 5 | from griptape.flow.tokenizers import CohereTokenizer 6 | 7 | 8 | @define 9 | class CoherePromptDriver(BasePromptDriver): 10 | api_key: str = field(kw_only=True) 11 | model: str = field(default=CohereTokenizer.DEFAULT_MODEL, kw_only=True) 12 | client: cohere.Client = field( 13 | default=Factory(lambda self: cohere.Client(self.api_key), takes_self=True), kw_only=True 14 | ) 15 | tokenizer: CohereTokenizer = field( 16 | default=Factory(lambda self: CohereTokenizer(model=self.model, client=self.client), takes_self=True), 17 | kw_only=True 18 | ) 19 | 20 | def try_run(self, value: any) -> TextOutput: 21 | result = self.client.generate( 22 | value, 23 | model=self.model, 24 | temperature=self.temperature, 25 | end_sequences=[self.tokenizer.stop_sequence, "Input:"], 26 | max_tokens=self.tokenizer.tokens_left(value) 27 | ) 28 | 29 | if len(result.generations) == 1: 30 | generation = result.generations[0] 31 | 32 | return TextOutput( 33 | value=generation.text.strip(), 34 | meta=result.meta 35 | ) 36 | else: 37 | raise Exception("Completion with more than one choice is not supported yet.") -------------------------------------------------------------------------------- /griptape/flow/drivers/prompt/hugging_face_hub_prompt_driver.py: -------------------------------------------------------------------------------- 1 | from attr import define, field, Factory 2 | from huggingface_hub import InferenceApi 3 
| from transformers import AutoTokenizer 4 | from griptape.flow.artifacts import TextOutput 5 | from griptape.flow.drivers import BasePromptDriver 6 | from griptape.flow.tokenizers import HuggingFaceTokenizer 7 | 8 | 9 | @define 10 | class HuggingFaceHubPromptDriver(BasePromptDriver): 11 | SUPPORTED_TASKS = ["text2text-generation", "text-generation"] 12 | MAX_NEW_TOKENS = 250 13 | DEFAULT_PARAMS = { 14 | "return_full_text": False, 15 | "max_new_tokens": MAX_NEW_TOKENS 16 | } 17 | 18 | repo_id: str = field(kw_only=True) 19 | api_token: str = field(kw_only=True) 20 | use_gpu: bool = field(default=False, kw_only=True) 21 | params: dict = field(factory=dict, kw_only=True) 22 | model: str = field(default=Factory(lambda self: self.repo_id, takes_self=True), kw_only=True) 23 | client: InferenceApi = field( 24 | default=Factory( 25 | lambda self: InferenceApi(repo_id=self.repo_id, token=self.api_token, gpu=self.use_gpu), takes_self=True 26 | ), 27 | kw_only=True 28 | ) 29 | tokenizer: HuggingFaceTokenizer = field( 30 | default=Factory( 31 | lambda self: HuggingFaceTokenizer( 32 | tokenizer=AutoTokenizer.from_pretrained(self.repo_id), 33 | max_tokens=self.MAX_NEW_TOKENS 34 | ), takes_self=True 35 | ), 36 | kw_only=True 37 | ) 38 | 39 | def try_run(self, value: any) -> TextOutput: 40 | if self.client.task in self.SUPPORTED_TASKS: 41 | response = self.client( 42 | inputs=value, 43 | params=self.DEFAULT_PARAMS | self.params 44 | ) 45 | 46 | if len(response) == 1: 47 | return TextOutput( 48 | value=response[0]["generated_text"].strip() 49 | ) 50 | else: 51 | raise Exception("Completion with more than one choice is not supported yet.") 52 | else: 53 | raise Exception(f"Only models with the following tasks are supported: {self.SUPPORTED_TASKS}") 54 | -------------------------------------------------------------------------------- /griptape/flow/drivers/prompt/hugging_face_pipeline_prompt_driver.py: -------------------------------------------------------------------------------- 1 | from attr import define, field, Factory 2 | from transformers import pipeline, AutoTokenizer 3 | from griptape.flow.artifacts import TextOutput 4 | from griptape.flow.drivers import BasePromptDriver 5 | from griptape.flow.tokenizers import HuggingFaceTokenizer 6 | 7 | 8 | @define 9 | class HuggingFacePipelinePromptDriver(BasePromptDriver): 10 | SUPPORTED_TASKS = ["text2text-generation", "text-generation"] 11 | DEFAULT_PARAMS = { 12 | "return_full_text": False, 13 | "num_return_sequences": 1 14 | } 15 | 16 | model: str = field(kw_only=True) 17 | params: dict = field(factory=dict, kw_only=True) 18 | tokenizer: HuggingFaceTokenizer = field( 19 | default=Factory( 20 | lambda self: HuggingFaceTokenizer( 21 | tokenizer=AutoTokenizer.from_pretrained(self.model) 22 | ), takes_self=True 23 | ), 24 | kw_only=True 25 | ) 26 | 27 | def try_run(self, value: any) -> TextOutput: 28 | generator = pipeline( 29 | tokenizer=self.tokenizer.tokenizer, 30 | model=self.model, 31 | max_new_tokens=self.tokenizer.tokens_left(value) 32 | ) 33 | 34 | if generator.task in self.SUPPORTED_TASKS: 35 | extra_params = { 36 | "pad_token_id": self.tokenizer.tokenizer.eos_token_id 37 | } 38 | 39 | response = generator( 40 | value, 41 | **(self.DEFAULT_PARAMS | extra_params | self.params) 42 | ) 43 | 44 | if len(response) == 1: 45 | return TextOutput( 46 | value=response[0]["generated_text"].strip() 47 | ) 48 | else: 49 | raise Exception("Completion with more than one choice is not supported yet.") 50 | else: 51 | raise Exception(f"Only models with the following 
tasks are supported: {self.SUPPORTED_TASKS}") 52 | -------------------------------------------------------------------------------- /griptape/flow/drivers/prompt/openai_prompt_driver.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Optional 3 | import openai 4 | from attr import define, field, Factory 5 | from griptape.flow.artifacts import TextOutput 6 | from griptape.flow.drivers import BasePromptDriver 7 | from griptape.flow.tokenizers import TiktokenTokenizer 8 | 9 | 10 | @define 11 | class OpenAiPromptDriver(BasePromptDriver): 12 | api_type: str = field(default=openai.api_type, kw_only=True) 13 | api_version: Optional[str] = field(default=openai.api_version, kw_only=True) 14 | api_base: str = field(default=openai.api_base, kw_only=True) 15 | api_key: Optional[str] = field(default=openai.api_key, kw_only=True) 16 | organization: Optional[str] = field(default=openai.organization, kw_only=True) 17 | model: str = field(default=TiktokenTokenizer.DEFAULT_MODEL, kw_only=True) 18 | tokenizer: TiktokenTokenizer = field( 19 | default=Factory(lambda self: TiktokenTokenizer(model=self.model), takes_self=True), 20 | kw_only=True 21 | ) 22 | user: str = field(default="", kw_only=True) 23 | 24 | def __attrs_post_init__(self): 25 | openai.api_type = self.api_type 26 | openai.api_version = self.api_version 27 | openai.api_base = self.api_base 28 | openai.api_key = self.api_key 29 | openai.organization = self.organization 30 | 31 | def try_run(self, value: any) -> TextOutput: 32 | if self.tokenizer.is_chat(): 33 | return self.__run_chat(value) 34 | else: 35 | return self.__run_completion(value) 36 | 37 | def __run_chat(self, value: str) -> TextOutput: 38 | result = openai.ChatCompletion.create( 39 | model=self.tokenizer.model, 40 | messages=[ 41 | { 42 | "role": "user", 43 | "content": value 44 | } 45 | ], 46 | max_tokens=self.tokenizer.tokens_left(value), 47 | temperature=self.temperature, 48 | stop=self.tokenizer.stop_sequence, 49 | user=self.user 50 | ) 51 | 52 | if len(result.choices) == 1: 53 | return TextOutput( 54 | value=result.choices[0]["message"]["content"].strip(), 55 | meta={ 56 | "id": result["id"], 57 | "created": result["created"], 58 | "usage": json.dumps(result["usage"]) 59 | } 60 | ) 61 | else: 62 | raise Exception("Completion with more than one choice is not supported yet.") 63 | 64 | def __run_completion(self, value: str) -> TextOutput: 65 | result = openai.Completion.create( 66 | model=self.tokenizer.model, 67 | prompt=value, 68 | max_tokens=self.tokenizer.tokens_left(value), 69 | temperature=self.temperature, 70 | stop=self.tokenizer.stop_sequence, 71 | user=self.user 72 | ) 73 | 74 | if len(result.choices) == 1: 75 | return TextOutput( 76 | value=result.choices[0].text.strip(), 77 | meta={ 78 | "id": result["id"], 79 | "created": result["created"], 80 | "usage": json.dumps(result["usage"]) 81 | } 82 | ) 83 | else: 84 | raise Exception("Completion with more than one choice is not supported yet.") 85 | -------------------------------------------------------------------------------- /griptape/flow/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.memory.pipeline_run import PipelineRun 2 | from griptape.flow.memory.pipeline_memory import PipelineMemory 3 | from griptape.flow.memory.summary_pipeline_memory import SummaryPipelineMemory 4 | from griptape.flow.memory.buffer_pipeline_memory import BufferPipelineMemory 5 | 6 | 7 | __all__ = [ 
8 | "PipelineRun", 9 | "PipelineMemory", 10 | "SummaryPipelineMemory", 11 | "BufferPipelineMemory" 12 | ] 13 | -------------------------------------------------------------------------------- /griptape/flow/memory/buffer_pipeline_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import json 3 | from attr import define, field 4 | from griptape.flow.memory import PipelineMemory, PipelineRun 5 | 6 | 7 | @define 8 | class BufferPipelineMemory(PipelineMemory): 9 | buffer_size: int = field(default=1, kw_only=True) 10 | 11 | def process_add_run(self, run: PipelineRun) -> None: 12 | super().process_add_run(run) 13 | 14 | while len(self.runs) > self.buffer_size: 15 | self.runs.pop(0) 16 | 17 | def to_dict(self) -> dict: 18 | return BufferPipelineMemory().dump(self) 19 | 20 | @classmethod 21 | def from_dict(cls, memory_dict: dict) -> PipelineMemory: 22 | return BufferPipelineMemory().load(memory_dict) 23 | 24 | @classmethod 25 | def from_json(cls, memory_json: str) -> PipelineMemory: 26 | return BufferPipelineMemory.from_dict(json.loads(memory_json)) 27 | -------------------------------------------------------------------------------- /griptape/flow/memory/pipeline_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import json 3 | from typing import TYPE_CHECKING, Optional 4 | from attr import define, field, Factory 5 | from griptape.flow.memory import PipelineRun 6 | from griptape.flow.utils import J2 7 | 8 | if TYPE_CHECKING: 9 | from griptape.flow.drivers import MemoryDriver 10 | from griptape.flow.structures import Pipeline 11 | 12 | 13 | @define 14 | class PipelineMemory: 15 | type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) 16 | driver: Optional[MemoryDriver] = field(default=None, kw_only=True) 17 | runs: list[PipelineRun] = field(factory=list, kw_only=True) 18 | pipeline: Pipeline = field(init=False) 19 | 20 | def add_run(self, run: PipelineRun) -> PipelineMemory: 21 | self.before_add_run() 22 | self.process_add_run(run) 23 | self.after_add_run() 24 | 25 | return self 26 | 27 | def before_add_run(self) -> None: 28 | pass 29 | 30 | def process_add_run(self, run: PipelineRun) -> None: 31 | self.runs.append(run) 32 | 33 | def after_add_run(self) -> None: 34 | if self.driver: 35 | self.driver.store(self) 36 | 37 | def is_empty(self) -> bool: 38 | return not self.runs 39 | 40 | def to_prompt_string(self, last_n: Optional[int] = None) -> str: 41 | return J2("prompts/memory.j2").render( 42 | runs=self.runs if last_n is None else self.runs[-last_n:] 43 | ) 44 | 45 | def to_json(self) -> str: 46 | return json.dumps(self.to_dict(), indent=2) 47 | 48 | def to_dict(self) -> dict: 49 | from griptape.flow.schemas import PipelineMemorySchema 50 | 51 | return PipelineMemorySchema().dump(self) 52 | 53 | @classmethod 54 | def from_dict(cls, memory_dict: dict) -> PipelineMemory: 55 | from griptape.flow.schemas import PipelineMemorySchema 56 | 57 | return PipelineMemorySchema().load(memory_dict) 58 | 59 | @classmethod 60 | def from_json(cls, memory_json: str) -> PipelineMemory: 61 | return PipelineMemory.from_dict(json.loads(memory_json)) 62 | -------------------------------------------------------------------------------- /griptape/flow/memory/pipeline_run.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | 3 | from attr import define, 
field, Factory 4 | 5 | from griptape.flow.utils import J2 6 | 7 | 8 | @define 9 | class PipelineRun: 10 | id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) 11 | input: str = field(kw_only=True) 12 | output: str = field(kw_only=True) 13 | 14 | def render(self) -> str: 15 | return J2("prompts/run_context.j2").render( 16 | run=self 17 | ) 18 | -------------------------------------------------------------------------------- /griptape/flow/memory/summary_pipeline_memory.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import json 3 | from typing import TYPE_CHECKING 4 | from typing import Optional 5 | from attr import define, field 6 | from griptape.flow.utils import J2 7 | from griptape.flow.memory import PipelineMemory 8 | 9 | if TYPE_CHECKING: 10 | from griptape.flow.summarizers import Summarizer 11 | from griptape.flow.memory import PipelineRun 12 | 13 | 14 | @define 15 | class SummaryPipelineMemory(PipelineMemory): 16 | offset: int = field(default=1, kw_only=True) 17 | summarizer: Optional[Summarizer] = field(default=None, kw_only=True) 18 | summary: Optional[str] = field(default=None, kw_only=True) 19 | summary_index: int = field(default=0, kw_only=True) 20 | 21 | def unsummarized_runs(self, last_n: Optional[int] = None) -> list[PipelineRun]: 22 | summary_index_runs = self.runs[self.summary_index:] 23 | 24 | if last_n: 25 | last_n_runs = self.runs[-last_n:] 26 | 27 | if len(summary_index_runs) > len(last_n_runs): 28 | return last_n_runs 29 | else: 30 | return summary_index_runs 31 | else: 32 | return summary_index_runs 33 | 34 | def process_add_run(self, run: PipelineRun) -> None: 35 | super().process_add_run(run) 36 | 37 | if self.summarizer: 38 | unsummarized_runs = self.unsummarized_runs() 39 | runs_to_summarize = unsummarized_runs[:max(0, len(unsummarized_runs) - self.offset)] 40 | 41 | if len(runs_to_summarize) > 0: 42 | self.summary = self.summarizer.summarize(self, runs_to_summarize) 43 | self.summary_index = 1 + self.runs.index(runs_to_summarize[-1]) 44 | 45 | def to_prompt_string(self, last_n: Optional[int] = None): 46 | return J2("prompts/memory.j2").render( 47 | summary=self.summary, 48 | runs=self.unsummarized_runs(last_n) 49 | ) 50 | 51 | def to_dict(self) -> dict: 52 | return SummaryPipelineMemory().dump(self) 53 | 54 | @classmethod 55 | def from_dict(cls, memory_dict: dict) -> PipelineMemory: 56 | return SummaryPipelineMemory().load(memory_dict) 57 | 58 | @classmethod 59 | def from_json(cls, memory_json: str) -> PipelineMemory: 60 | return SummaryPipelineMemory.from_dict(json.loads(memory_json)) 61 | -------------------------------------------------------------------------------- /griptape/flow/rules/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.rules.rule import Rule 2 | from . import json 3 | from . 
import meta 4 | 5 | 6 | __all__ = [ 7 | "Rule", 8 | "json", 9 | "meta" 10 | ] 11 | -------------------------------------------------------------------------------- /griptape/flow/rules/json.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.rules.rule import Rule 2 | 3 | 4 | def return_valid_json() -> Rule: 5 | return Rule( 6 | "only output valid JSON" 7 | ) 8 | 9 | 10 | def return_array() -> Rule: 11 | return Rule( 12 | "only output a valid JSON array" 13 | ) 14 | 15 | 16 | def return_object() -> Rule: 17 | return Rule( 18 | "only output a valid JSON object" 19 | ) 20 | 21 | 22 | def put_answer_in_field(value: str) -> Rule: 23 | return Rule( 24 | f"only output a valid JSON object with your answer in '{value}'" 25 | ) 26 | -------------------------------------------------------------------------------- /griptape/flow/rules/meta.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.rules.rule import Rule 2 | 3 | 4 | def be_truthful() -> Rule: 5 | return Rule( 6 | "be truthful and say \"I don't know\" if you don't have the knowledge to answer a question" 7 | ) 8 | 9 | 10 | def speculate() -> Rule: 11 | return Rule( 12 | "say \"I don't know\" if you don't know the answer to the question but also be creative and speculate what the " 13 | "possible answer could be" 14 | ) 15 | 16 | 17 | def your_name_is(name: str) -> Rule: 18 | return Rule( 19 | f"respond to name \"{name}\"" 20 | ) 21 | -------------------------------------------------------------------------------- /griptape/flow/rules/rule.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from attr import define 3 | 4 | 5 | @define(frozen=True) 6 | class Rule: 7 | value: str 8 | 9 | -------------------------------------------------------------------------------- /griptape/flow/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.schemas.base_schema import BaseSchema 2 | 3 | from griptape.flow.schemas.polymorphic_schema import PolymorphicSchema 4 | 5 | from griptape.flow.schemas.rule_schema import RuleSchema 6 | 7 | from griptape.flow.schemas.tokenizers.tiktoken_tokenizer_schema import TiktokenTokenizerSchema 8 | 9 | from griptape.flow.schemas.drivers.prompt_driver_schema import PromptDriverSchema 10 | from griptape.flow.schemas.drivers.openai_prompt_driver_schema import OpenAiPromptDriverSchema 11 | 12 | from griptape.flow.schemas.steps.step_schema import StepSchema 13 | from griptape.flow.schemas.steps.prompt_step_schema import PromptStepSchema 14 | from griptape.flow.schemas.steps.toolkit_step_schema import ToolkitStepSchema 15 | 16 | from griptape.flow.schemas.summarizers.summarizer_schema import SummarizerSchema 17 | from griptape.flow.schemas.summarizers.prompt_driver_summarizer_schema import PromptDriverSummarizerSchema 18 | 19 | from griptape.flow.schemas.memory.pipeline_run_schema import PipelineRunSchema 20 | from griptape.flow.schemas.memory.pipeline_memory_schema import PipelineMemorySchema 21 | from griptape.flow.schemas.memory.buffer_pipeline_memory_schema import BufferPipelineMemorySchema 22 | from griptape.flow.schemas.memory.summary_pipeline_memory_schema import SummaryPipelineMemorySchema 23 | 24 | from griptape.flow.schemas.structures.structure_schema import StructureSchema 25 | from griptape.flow.schemas.structures.pipeline_schema import PipelineSchema 26 | from 
griptape.flow.schemas.structures.workflow_schema import WorkflowSchema 27 | 28 | __all__ = [ 29 | "BaseSchema", 30 | 31 | "PolymorphicSchema", 32 | 33 | "RuleSchema", 34 | 35 | "TiktokenTokenizerSchema", 36 | 37 | "PromptDriverSchema", 38 | "OpenAiPromptDriverSchema", 39 | 40 | "StepSchema", 41 | "PromptStepSchema", 42 | "ToolkitStepSchema", 43 | 44 | "SummarizerSchema", 45 | "PromptDriverSummarizerSchema", 46 | 47 | "PipelineRunSchema", 48 | "PipelineMemorySchema", 49 | "BufferPipelineMemorySchema", 50 | "SummaryPipelineMemorySchema", 51 | 52 | "StructureSchema", 53 | "PipelineSchema", 54 | "WorkflowSchema" 55 | ] 56 | -------------------------------------------------------------------------------- /griptape/flow/schemas/base_schema.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from marshmallow import Schema, fields 3 | 4 | 5 | class BaseSchema(Schema): 6 | schema_namespace = fields.Str(allow_none=True) 7 | 8 | @abstractmethod 9 | def make_obj(self, data, **kwargs): 10 | ... 11 | -------------------------------------------------------------------------------- /griptape/flow/schemas/drivers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/schemas/drivers/__init__.py -------------------------------------------------------------------------------- /griptape/flow/schemas/drivers/openai_prompt_driver_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import PromptDriverSchema 3 | 4 | 5 | class OpenAiPromptDriverSchema(PromptDriverSchema): 6 | api_type = fields.Str() 7 | api_version = fields.Str(allow_none=True) 8 | api_base = fields.Str() 9 | api_key = fields.Str(allow_none=True) 10 | organization = fields.Str(allow_none=True) 11 | model = fields.Str() 12 | temperature = fields.Float() 13 | user = fields.Str() 14 | 15 | @post_load 16 | def make_obj(self, data, **kwargs): 17 | from griptape.flow.drivers import OpenAiPromptDriver 18 | 19 | return OpenAiPromptDriver(**data) 20 | -------------------------------------------------------------------------------- /griptape/flow/schemas/drivers/prompt_driver_schema.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from marshmallow import fields 3 | from griptape.flow.schemas import BaseSchema, PolymorphicSchema 4 | 5 | 6 | class PromptDriverSchema(BaseSchema): 7 | class Meta: 8 | ordered = True 9 | 10 | max_retries = fields.Int() 11 | retry_delay = fields.Float() 12 | model = fields.Str() 13 | tokenizer = fields.Nested(PolymorphicSchema()) 14 | 15 | @abstractmethod 16 | def make_obj(self, data, **kwargs): 17 | ... 
18 | -------------------------------------------------------------------------------- /griptape/flow/schemas/memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/schemas/memory/__init__.py -------------------------------------------------------------------------------- /griptape/flow/schemas/memory/buffer_pipeline_memory_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import BaseSchema, PipelineRunSchema, PipelineMemorySchema 3 | 4 | 5 | class BufferPipelineMemorySchema(PipelineMemorySchema): 6 | buffer_size = fields.Int() 7 | 8 | @post_load 9 | def make_obj(self, data, **kwargs): 10 | from griptape.flow.memory import BufferPipelineMemory 11 | 12 | return BufferPipelineMemory(**data) 13 | -------------------------------------------------------------------------------- /griptape/flow/schemas/memory/pipeline_memory_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import BaseSchema, PipelineRunSchema 3 | 4 | 5 | class PipelineMemorySchema(BaseSchema): 6 | class Meta: 7 | ordered = True 8 | 9 | type = fields.Str(required=True) 10 | runs = fields.List(fields.Nested(PipelineRunSchema())) 11 | 12 | @post_load 13 | def make_obj(self, data, **kwargs): 14 | from griptape.flow.memory import PipelineMemory 15 | 16 | return PipelineMemory(**data) 17 | -------------------------------------------------------------------------------- /griptape/flow/schemas/memory/pipeline_run_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import BaseSchema 3 | 4 | 5 | class PipelineRunSchema(BaseSchema): 6 | class Meta: 7 | ordered = True 8 | 9 | id = fields.Str() 10 | input = fields.Str() 11 | output = fields.Str() 12 | 13 | @post_load 14 | def make_obj(self, data, **kwargs): 15 | from griptape.flow.memory import PipelineRun 16 | 17 | return PipelineRun(**data) 18 | -------------------------------------------------------------------------------- /griptape/flow/schemas/memory/summary_pipeline_memory_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import PipelineMemorySchema, PolymorphicSchema 3 | 4 | 5 | class SummaryPipelineMemorySchema(PipelineMemorySchema): 6 | offset = fields.Int() 7 | summary = fields.Str() 8 | summary_index = fields.Int() 9 | summarizer = fields.Nested(PolymorphicSchema()) 10 | 11 | @post_load 12 | def make_obj(self, data, **kwargs): 13 | from griptape.flow.memory import SummaryPipelineMemory 14 | 15 | return SummaryPipelineMemory(**data) 16 | -------------------------------------------------------------------------------- /griptape/flow/schemas/polymorphic_schema.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from pydoc import locate 3 | from typing import Optional 4 | from marshmallow import ValidationError, Schema 5 | from griptape.flow.schemas import BaseSchema 6 | 7 | 8 | class PolymorphicSchema(BaseSchema): 9 | """ 10 | PolymorphicSchema is based on 
https://github.com/marshmallow-code/marshmallow-oneofschema 11 | """ 12 | 13 | def get_schema(self, class_name: str, obj: Optional[object], schema_namespace: Optional[str]): 14 | if schema_namespace: 15 | namespace = schema_namespace 16 | elif obj is not None and hasattr(obj, "schema_namespace"): 17 | if locate(f"griptape.flow.schemas.{class_name}Schema"): 18 | namespace = "griptape.flow.schemas" 19 | elif obj.schema_namespace is None: 20 | namespace = obj.schema_namespace = f"{obj.__module__}_schema" 21 | else: 22 | namespace = obj.schema_namespace 23 | else: 24 | namespace = "griptape.flow.schemas" 25 | 26 | klass = locate(f"{namespace}.{class_name}Schema") 27 | 28 | if klass: 29 | return klass 30 | else: 31 | raise ValidationError(f"Missing schema for '{class_name}'") 32 | 33 | type_field = "type" 34 | type_field_remove = True 35 | 36 | def get_obj_type(self, obj): 37 | """Returns name of the schema during dump() calls, given the object 38 | being dumped.""" 39 | return obj.__class__.__name__ 40 | 41 | def get_data_type(self, data): 42 | """Returns name of the schema during load() calls, given the data being 43 | loaded. Defaults to looking up `type_field` in the data.""" 44 | data_type = data.get(self.type_field) 45 | if self.type_field in data and self.type_field_remove: 46 | data.pop(self.type_field) 47 | return data_type 48 | 49 | def dump(self, obj, *, many=None, **kwargs): 50 | errors = {} 51 | result_data = [] 52 | result_errors = {} 53 | many = self.many if many is None else bool(many) 54 | if not many: 55 | result = result_data = self._dump(obj, **kwargs) 56 | else: 57 | for idx, o in enumerate(obj): 58 | try: 59 | result = self._dump(o, **kwargs) 60 | result_data.append(result) 61 | except ValidationError as error: 62 | result_errors[idx] = error.normalized_messages() 63 | result_data.append(error.valid_data) 64 | 65 | result = result_data 66 | errors = result_errors 67 | 68 | if not errors: 69 | return result 70 | else: 71 | exc = ValidationError(errors, data=obj, valid_data=result) 72 | raise exc 73 | 74 | def _dump(self, obj, *, update_fields=True, **kwargs): 75 | obj_type = self.get_obj_type(obj) 76 | 77 | if not obj_type: 78 | return ( 79 | None, 80 | {"_schema": "Unknown object class: %s" % obj.__class__.__name__}, 81 | ) 82 | 83 | type_schema = self.get_schema(obj_type, obj, None) 84 | 85 | if not type_schema: 86 | return None, {"_schema": "Unsupported object type: %s" % obj_type} 87 | 88 | schema = type_schema if isinstance(type_schema, Schema) else type_schema() 89 | 90 | schema.context.update(getattr(self, "context", {})) 91 | 92 | result = schema.dump(obj, many=False, **kwargs) 93 | 94 | if result is not None: 95 | result[self.type_field] = obj_type 96 | 97 | return result 98 | 99 | def load(self, data, *, many=None, partial=None, unknown=None, **kwargs): 100 | errors = {} 101 | result_data = [] 102 | result_errors = {} 103 | many = self.many if many is None else bool(many) 104 | if partial is None: 105 | partial = self.partial 106 | if not many: 107 | try: 108 | result = result_data = self._load( 109 | data, partial=partial, unknown=unknown, **kwargs 110 | ) 111 | # result_data.append(result) 112 | except ValidationError as error: 113 | result_errors = error.normalized_messages() 114 | result_data.append(error.valid_data) 115 | else: 116 | for idx, item in enumerate(data): 117 | try: 118 | result = self._load(item, partial=partial, **kwargs) 119 | result_data.append(result) 120 | except ValidationError as error: 121 | result_errors[idx] = 
error.normalized_messages() 122 | result_data.append(error.valid_data) 123 | 124 | result = result_data 125 | errors = result_errors 126 | 127 | if not errors: 128 | return result 129 | else: 130 | exc = ValidationError(errors, data=data, valid_data=result) 131 | raise exc 132 | 133 | def _load(self, data, *, partial=None, unknown=None, **kwargs): 134 | if not isinstance(data, dict): 135 | raise ValidationError({"_schema": "Invalid data type: %s" % data}) 136 | 137 | data = dict(data) 138 | unknown = unknown or self.unknown 139 | data_type = self.get_data_type(data) 140 | 141 | if data_type is None: 142 | raise ValidationError( 143 | {self.type_field: ["Missing data for required field."]} 144 | ) 145 | 146 | schema_namespace = data.get("schema_namespace") 147 | 148 | try: 149 | type_schema = self.get_schema(data_type, None, schema_namespace) 150 | except TypeError: 151 | # data_type could be unhashable 152 | raise ValidationError({self.type_field: ["Invalid value: %s" % data_type]}) 153 | if not type_schema: 154 | raise ValidationError( 155 | {self.type_field: ["Unsupported value: %s" % data_type]} 156 | ) 157 | 158 | schema = type_schema if isinstance(type_schema, Schema) else type_schema() 159 | 160 | schema.context.update(getattr(self, "context", {})) 161 | 162 | return schema.load(data, many=False, partial=partial, unknown=unknown, **kwargs) 163 | 164 | def validate(self, data, *, many=None, partial=None): 165 | try: 166 | self.load(data, many=many, partial=partial) 167 | except ValidationError as ve: 168 | return ve.messages 169 | return {} 170 | 171 | @abstractmethod 172 | def make_obj(self, data, **kwargs): 173 | ... 174 | -------------------------------------------------------------------------------- /griptape/flow/schemas/rule_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import BaseSchema 3 | 4 | 5 | class RuleSchema(BaseSchema): 6 | value = fields.Str() 7 | 8 | @post_load 9 | def make_obj(self, data, **kwargs): 10 | from griptape.flow.rules import Rule 11 | 12 | return Rule(**data) 13 | -------------------------------------------------------------------------------- /griptape/flow/schemas/steps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/schemas/steps/__init__.py -------------------------------------------------------------------------------- /griptape/flow/schemas/steps/prompt_step_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import StepSchema, PolymorphicSchema 3 | 4 | 5 | class PromptStepSchema(StepSchema): 6 | prompt_template = fields.Str(required=True) 7 | context = fields.Dict(keys=fields.Str(), values=fields.Raw()) 8 | driver = fields.Nested(PolymorphicSchema(), allow_none=True) 9 | 10 | @post_load 11 | def make_obj(self, data, **kwargs): 12 | from griptape.flow.steps import PromptStep 13 | 14 | return PromptStep(**data) 15 | -------------------------------------------------------------------------------- /griptape/flow/schemas/steps/step_schema.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from marshmallow import fields 3 | from marshmallow_enum import EnumField 4 | from 
griptape.flow.schemas import BaseSchema 5 | from griptape.flow.steps import Step 6 | 7 | 8 | class StepSchema(BaseSchema): 9 | class Meta: 10 | ordered = True 11 | 12 | id = fields.Str() 13 | state = EnumField(Step.State) 14 | parent_ids = fields.List(fields.Str()) 15 | child_ids = fields.List(fields.Str()) 16 | 17 | @abstractmethod 18 | def make_obj(self, data, **kwargs): 19 | ... 20 | -------------------------------------------------------------------------------- /griptape/flow/schemas/steps/toolkit_step_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import PolymorphicSchema, StepSchema 3 | 4 | 5 | class ToolkitStepSchema(StepSchema): 6 | prompt_template = fields.Str(required=True) 7 | max_substeps = fields.Int(allow_none=True) 8 | tool_names = fields.List(fields.Str(), required=True) 9 | context = fields.Dict(keys=fields.Str(), values=fields.Raw()) 10 | driver = fields.Nested(PolymorphicSchema(), allow_none=True) 11 | 12 | @post_load 13 | def make_obj(self, data, **kwargs): 14 | from griptape.flow.steps import ToolkitStep 15 | 16 | return ToolkitStep(**data) 17 | -------------------------------------------------------------------------------- /griptape/flow/schemas/structures/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/schemas/structures/__init__.py -------------------------------------------------------------------------------- /griptape/flow/schemas/structures/pipeline_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import post_load, fields 2 | from griptape.flow.schemas import StructureSchema 3 | 4 | 5 | class PipelineSchema(StructureSchema): 6 | autoprune_memory = fields.Bool() 7 | 8 | @post_load 9 | def make_obj(self, data, **kwargs): 10 | from griptape.flow.structures import Pipeline 11 | 12 | return Pipeline(**data) 13 | -------------------------------------------------------------------------------- /griptape/flow/schemas/structures/structure_schema.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from marshmallow import fields 3 | from griptape.flow.schemas import PolymorphicSchema, RuleSchema, BaseSchema 4 | 5 | 6 | class StructureSchema(BaseSchema): 7 | class Meta: 8 | ordered = True 9 | 10 | id = fields.Str() 11 | type = fields.Str(required=True) 12 | prompt_driver = fields.Nested(PolymorphicSchema()) 13 | rules = fields.List(fields.Nested(RuleSchema())) 14 | steps = fields.List(fields.Nested(PolymorphicSchema())) 15 | 16 | @abstractmethod 17 | def make_obj(self, data, **kwargs): 18 | ... 
19 | -------------------------------------------------------------------------------- /griptape/flow/schemas/structures/workflow_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import post_load 2 | from griptape.flow.schemas import StructureSchema 3 | 4 | 5 | class WorkflowSchema(StructureSchema): 6 | @post_load 7 | def make_obj(self, data, **kwargs): 8 | from griptape.flow.structures import Workflow 9 | 10 | return Workflow(**data) 11 | -------------------------------------------------------------------------------- /griptape/flow/schemas/summarizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/schemas/summarizers/__init__.py -------------------------------------------------------------------------------- /griptape/flow/schemas/summarizers/prompt_driver_summarizer_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields 2 | from griptape.flow.schemas import PolymorphicSchema, SummarizerSchema 3 | 4 | 5 | class PromptDriverSummarizerSchema(SummarizerSchema): 6 | driver = fields.Nested(PolymorphicSchema()) 7 | 8 | def make_obj(self, data, **kwargs): 9 | from griptape.flow.summarizers import PromptDriverSummarizer 10 | 11 | return PromptDriverSummarizer(**data) 12 | -------------------------------------------------------------------------------- /griptape/flow/schemas/summarizers/summarizer_schema.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from marshmallow import fields 3 | from griptape.flow.schemas import BaseSchema 4 | 5 | 6 | class SummarizerSchema(BaseSchema): 7 | class Meta: 8 | ordered = True 9 | 10 | type = fields.Str(required=True) 11 | 12 | @abstractmethod 13 | def make_obj(self, data, **kwargs): 14 | ... 
15 | -------------------------------------------------------------------------------- /griptape/flow/schemas/tokenizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/griptape/flow/schemas/tokenizers/__init__.py -------------------------------------------------------------------------------- /griptape/flow/schemas/tokenizers/tiktoken_tokenizer_schema.py: -------------------------------------------------------------------------------- 1 | from marshmallow import fields, post_load 2 | from griptape.flow.schemas import BaseSchema 3 | 4 | 5 | class TiktokenTokenizerSchema(BaseSchema): 6 | model = fields.Str() 7 | stop_sequence = fields.Str() 8 | 9 | @post_load 10 | def make_obj(self, data, **kwargs): 11 | from griptape.flow.tokenizers import TiktokenTokenizer 12 | 13 | return TiktokenTokenizer(**data) 14 | -------------------------------------------------------------------------------- /griptape/flow/steps/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.steps.step import Step 2 | from griptape.flow.steps.prompt_step import PromptStep 3 | from griptape.flow.steps.tool_substep import ToolSubstep 4 | from griptape.flow.steps.toolkit_step import ToolkitStep 5 | 6 | __all__ = [ 7 | "Step", 8 | "PromptStep", 9 | "ToolSubstep", 10 | "ToolkitStep" 11 | ] 12 | -------------------------------------------------------------------------------- /griptape/flow/steps/prompt_step.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Optional 3 | from attr import define, field 4 | from griptape.flow.utils import J2 5 | from griptape.flow.steps import Step 6 | from griptape.flow.artifacts import TextOutput 7 | 8 | if TYPE_CHECKING: 9 | from griptape.flow.drivers import BasePromptDriver 10 | 11 | 12 | @define 13 | class PromptStep(Step): 14 | prompt_template: str = field(default="{{ args[0] }}") 15 | context: dict[str, any] = field(factory=dict, kw_only=True) 16 | driver: Optional[BasePromptDriver] = field(default=None, kw_only=True) 17 | 18 | def before_run(self) -> None: 19 | super().before_run() 20 | 21 | self.structure.logger.info(f"Step {self.id}\nInput: {self.render_prompt()}") 22 | 23 | def run(self) -> TextOutput: 24 | self.output = self.active_driver().run(value=self.structure.to_prompt_string(self)) 25 | 26 | return self.output 27 | 28 | def after_run(self) -> None: 29 | super().after_run() 30 | 31 | self.structure.logger.info(f"Step {self.id}\nOutput: {self.output.value}") 32 | 33 | def active_driver(self) -> BasePromptDriver: 34 | if self.driver is None: 35 | return self.structure.prompt_driver 36 | else: 37 | return self.driver 38 | 39 | def render_prompt(self) -> str: 40 | return J2().render_from_string( 41 | self.prompt_template, 42 | **self.full_context 43 | ) 44 | 45 | def render(self) -> str: 46 | return J2("prompts/steps/prompt.j2").render( 47 | step=self 48 | ) 49 | 50 | @property 51 | def full_context(self) -> dict[str, any]: 52 | structure_context = self.structure.context(self) 53 | 54 | structure_context.update(self.context) 55 | 56 | return structure_context 57 | -------------------------------------------------------------------------------- /griptape/flow/steps/step.py: -------------------------------------------------------------------------------- 1 | from __future__ 
import annotations 2 | import uuid 3 | from abc import ABC, abstractmethod 4 | from enum import Enum 5 | from typing import TYPE_CHECKING, Optional 6 | from attr import define, field, Factory 7 | from griptape.flow.artifacts import ErrorOutput 8 | 9 | if TYPE_CHECKING: 10 | from griptape.flow.artifacts import TextOutput, StructureArtifact 11 | from griptape.flow.steps import Step 12 | from griptape.flow.structures import Structure 13 | 14 | 15 | @define 16 | class Step(ABC): 17 | class State(Enum): 18 | PENDING = 1 19 | EXECUTING = 2 20 | FINISHED = 3 21 | 22 | id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) 23 | state: State = field(default=State.PENDING, kw_only=True) 24 | type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) 25 | parent_ids: list[str] = field(factory=list, kw_only=True) 26 | child_ids: list[str] = field(factory=list, kw_only=True) 27 | 28 | output: Optional[StructureArtifact] = field(default=None, init=False) 29 | structure: Optional[Structure] = field(default=None, init=False) 30 | 31 | @property 32 | def parents(self) -> list[Step]: 33 | return [self.structure.find_step(parent_id) for parent_id in self.parent_ids] 34 | 35 | @property 36 | def children(self) -> list[Step]: 37 | return [self.structure.find_step(child_id) for child_id in self.child_ids] 38 | 39 | def add_child(self, child: Step) -> Step: 40 | if self.structure: 41 | child.structure = self.structure 42 | elif child.structure: 43 | self.structure = child.structure 44 | 45 | if child not in self.structure.steps: 46 | self.structure.steps.append(child) 47 | 48 | if self not in self.structure.steps: 49 | self.structure.steps.append(self) 50 | 51 | if child.id not in self.child_ids: 52 | self.child_ids.append(child.id) 53 | 54 | if self.id not in child.parent_ids: 55 | child.parent_ids.append(self.id) 56 | 57 | return child 58 | 59 | def add_parent(self, parent: Step) -> Step: 60 | if self.structure: 61 | parent.structure = self.structure 62 | elif parent.structure: 63 | self.structure = parent.structure 64 | 65 | if parent not in self.structure.steps: 66 | self.structure.steps.append(parent) 67 | 68 | if self not in self.structure.steps: 69 | self.structure.steps.append(self) 70 | 71 | if parent.id not in self.parent_ids: 72 | self.parent_ids.append(parent.id) 73 | 74 | if self.id not in parent.child_ids: 75 | parent.child_ids.append(self.id) 76 | 77 | return parent 78 | 79 | def is_pending(self) -> bool: 80 | return self.state == Step.State.PENDING 81 | 82 | def is_finished(self) -> bool: 83 | return self.state == Step.State.FINISHED 84 | 85 | def is_executing(self) -> bool: 86 | return self.state == Step.State.EXECUTING 87 | 88 | def before_run(self) -> None: 89 | pass 90 | 91 | def after_run(self) -> None: 92 | pass 93 | 94 | def execute(self) -> StructureArtifact: 95 | try: 96 | self.state = Step.State.EXECUTING 97 | 98 | self.before_run() 99 | 100 | self.output = self.run() 101 | 102 | self.after_run() 103 | except Exception as e: 104 | self.structure.logger.error(f"Step {self.id}\n{e}", exc_info=True) 105 | 106 | self.output = ErrorOutput(str(e), exception=e, step=self) 107 | finally: 108 | self.state = Step.State.FINISHED 109 | 110 | return self.output 111 | 112 | def can_execute(self) -> bool: 113 | return self.state == Step.State.PENDING and all(parent.is_finished() for parent in self.parents) 114 | 115 | def reset(self) -> Step: 116 | self.state = Step.State.PENDING 117 | self.output = None 118 | 119 | return self 120 | 121 | 
@abstractmethod 122 | def run(self) -> TextOutput: 123 | ... 124 | -------------------------------------------------------------------------------- /griptape/flow/steps/tool_substep.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import ast 3 | import json 4 | import re 5 | from typing import TYPE_CHECKING, Optional 6 | from attr import define, field 7 | from jsonschema.exceptions import ValidationError 8 | from jsonschema.validators import validate 9 | from griptape.flow.artifacts import TextOutput, ErrorOutput 10 | from griptape.flow.steps import PromptStep 11 | from griptape.core import BaseTool 12 | from griptape.flow.utils import J2 13 | 14 | if TYPE_CHECKING: 15 | from griptape.flow.artifacts import StructureArtifact 16 | from griptape.flow.steps import ToolkitStep 17 | 18 | 19 | @define 20 | class ToolSubstep(PromptStep): 21 | THOUGHT_PATTERN = r"^Thought:\s*(.*)$" 22 | ACTION_PATTERN = r"^Action:\s*({.*})$" 23 | OUTPUT_PATTERN = r"^Output:\s?([\s\S]*)$" 24 | INVALID_ACTION_ERROR_MSG = f"invalid action input, try again" 25 | 26 | parent_step_id: Optional[str] = field(default=None, kw_only=True) 27 | thought: Optional[str] = field(default=None, kw_only=True) 28 | tool_name: Optional[str] = field(default=None, kw_only=True) 29 | tool_action: Optional[str] = field(default=None, kw_only=True) 30 | tool_value: Optional[str] = field(default=None, kw_only=True) 31 | 32 | _tool: Optional[BaseTool] = None 33 | 34 | def attach(self, parent_step: ToolkitStep): 35 | self.parent_step_id = parent_step.id 36 | self.structure = parent_step.structure 37 | self.__init_from_prompt(self.render_prompt()) 38 | 39 | @property 40 | def toolkit_step(self) -> Optional[ToolkitStep]: 41 | return self.structure.find_step(self.parent_step_id) 42 | 43 | @property 44 | def parents(self) -> list[ToolSubstep]: 45 | return [self.toolkit_step.find_substep(parent_id) for parent_id in self.parent_ids] 46 | 47 | @property 48 | def children(self) -> list[ToolSubstep]: 49 | return [self.toolkit_step.find_substep(child_id) for child_id in self.child_ids] 50 | 51 | def before_run(self) -> None: 52 | self.structure.logger.info(f"Substep {self.id}\n{self.render_prompt()}") 53 | 54 | def run(self) -> StructureArtifact: 55 | try: 56 | if self.tool_name == "error": 57 | self.output = ErrorOutput(self.tool_value, step=self) 58 | else: 59 | if self._tool: 60 | observation = self.structure.tool_loader.executor.execute( 61 | getattr(self._tool, self.tool_action), 62 | self.tool_value.encode() 63 | ).decode() 64 | else: 65 | observation = "tool not found" 66 | 67 | self.output = TextOutput(observation) 68 | except Exception as e: 69 | self.structure.logger.error(f"Substep {self.id}\n{e}", exc_info=True) 70 | 71 | self.output = ErrorOutput(str(e), exception=e, step=self) 72 | finally: 73 | return self.output 74 | 75 | def after_run(self) -> None: 76 | self.structure.logger.info(f"Substep {self.id}\nObservation: {self.output.value}") 77 | 78 | def render(self) -> str: 79 | return J2("prompts/steps/tool/substep.j2").render( 80 | substep=self 81 | ) 82 | 83 | def to_json(self) -> str: 84 | json_dict = {} 85 | 86 | if self.tool_name: 87 | json_dict["tool"] = self.tool_name 88 | 89 | if self.tool_action: 90 | json_dict["action"] = self.tool_action 91 | 92 | if self.tool_value: 93 | json_dict["value"] = self.tool_value 94 | 95 | return json.dumps(json_dict) 96 | 97 | def add_child(self, child: ToolSubstep) -> ToolSubstep: 98 | if child.id not in self.child_ids: 
99 | self.child_ids.append(child.id) 100 | 101 | if self.id not in child.parent_ids: 102 | child.parent_ids.append(self.id) 103 | 104 | return child 105 | 106 | def add_parent(self, parent: ToolSubstep) -> ToolSubstep: 107 | if parent.id not in self.parent_ids: 108 | self.parent_ids.append(parent.id) 109 | 110 | if self.id not in parent.child_ids: 111 | parent.child_ids.append(self.id) 112 | 113 | return parent 114 | 115 | def __init_from_prompt(self, value: str) -> None: 116 | thought_matches = re.findall(self.THOUGHT_PATTERN, value, re.MULTILINE) 117 | action_matches = re.findall(self.ACTION_PATTERN, value, re.MULTILINE) 118 | output_matches = re.findall(self.OUTPUT_PATTERN, value, re.MULTILINE) 119 | 120 | if self.thought is None and len(thought_matches) > 0: 121 | self.thought = thought_matches[-1] 122 | 123 | if len(action_matches) > 0: 124 | try: 125 | parsed_value = ast.literal_eval(action_matches[-1]) 126 | 127 | # Load the tool name; throw exception if the key is not present 128 | if self.tool_name is None: 129 | self.tool_name = parsed_value["tool"] 130 | 131 | # Load the tool action; throw exception if the key is not present 132 | if self.tool_action is None: 133 | self.tool_action = parsed_value["action"] 134 | 135 | # Load the tool itself 136 | if self.tool_name: 137 | self._tool = self.toolkit_step.find_tool(self.tool_name) 138 | 139 | # Validate input based on tool schema 140 | if self._tool: 141 | validate( 142 | instance=parsed_value["value"], 143 | schema=self._tool.action_schema(getattr(self._tool, self.tool_action)) 144 | ) 145 | 146 | # Load optional input value; don't throw exceptions if key is not present 147 | if self.tool_value is None: 148 | self.tool_value = str(parsed_value.get("value")) 149 | 150 | except SyntaxError as e: 151 | self.structure.logger.error(f"Step {self.toolkit_step.id}\nSyntax error: {e}") 152 | 153 | self.tool_name = "error" 154 | self.tool_value = f"syntax error: {e}" 155 | except ValidationError as e: 156 | self.structure.logger.error(f"Step {self.toolkit_step.id}\nInvalid JSON: {e}") 157 | 158 | self.tool_name = "error" 159 | self.tool_value = f"JSON validation error: {e}" 160 | except Exception as e: 161 | self.structure.logger.error(f"Step {self.toolkit_step.id}\nError parsing tool action: {e}") 162 | 163 | self.tool_name = "error" 164 | self.tool_value = f"error: {self.INVALID_ACTION_ERROR_MSG}" 165 | elif self.output is None and len(output_matches) > 0: 166 | self.output = TextOutput(output_matches[-1]) 167 | -------------------------------------------------------------------------------- /griptape/flow/steps/toolkit_step.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from abc import ABC, abstractmethod 3 | from typing import TYPE_CHECKING, Optional 4 | from attr import define, field 5 | from griptape.core import BaseTool 6 | from griptape.flow.utils import J2 7 | from griptape.flow.steps import PromptStep 8 | from griptape.flow.artifacts import TextOutput, ErrorOutput 9 | 10 | if TYPE_CHECKING: 11 | from griptape.flow.steps import ToolSubstep 12 | 13 | 14 | @define 15 | class ToolkitStep(PromptStep, ABC): 16 | DEFAULT_MAX_SUBSTEPS = 20 17 | 18 | tool_names: list[str] = field(kw_only=True) 19 | max_substeps: int = field(default=DEFAULT_MAX_SUBSTEPS, kw_only=True) 20 | _substeps: list[ToolSubstep] = field(factory=list) 21 | 22 | @tool_names.validator 23 | def validate_tool_names(self, _, tool_names) -> None: 24 | if len(tool_names) > len(set(tool_names)): 25 
| raise ValueError("tool names have to be unique") 26 | 27 | @property 28 | def tools(self) -> list[BaseTool]: 29 | return [ 30 | t for t in [self.structure.tool_loader.load_tool(t) for t in self.tool_names] if t is not None 31 | ] 32 | 33 | def run(self) -> TextOutput: 34 | from griptape.flow.steps import ToolSubstep 35 | 36 | self._substeps.clear() 37 | 38 | substep = self.add_substep( 39 | ToolSubstep( 40 | self.active_driver().run(value=self.structure.to_prompt_string(self)).value 41 | ) 42 | ) 43 | 44 | while True: 45 | if substep.output is None: 46 | if len(self._substeps) >= self.max_substeps: 47 | substep.output = ErrorOutput( 48 | f"Exceeded maximum tool execution limit of {self.max_substeps} per step", 49 | step=self 50 | ) 51 | elif substep.tool_name is None: 52 | # handle case when the LLM failed to follow the ReAct prompt and didn't return a proper action 53 | substep.output = TextOutput(substep.prompt_template) 54 | else: 55 | substep.before_run() 56 | substep.run() 57 | substep.after_run() 58 | 59 | substep = self.add_substep( 60 | ToolSubstep( 61 | self.active_driver().run(value=self.structure.to_prompt_string(self)).value 62 | ) 63 | ) 64 | else: 65 | break 66 | 67 | self.output = substep.output 68 | 69 | return self.output 70 | 71 | def render(self) -> str: 72 | return J2("prompts/steps/tool/tool.j2").render( 73 | step=self, 74 | substeps=self._substeps 75 | ) 76 | 77 | def find_substep(self, step_id: str) -> Optional[ToolSubstep]: 78 | return next((step for step in self._substeps if step.id == step_id), None) 79 | 80 | def add_substep(self, substep: ToolSubstep) -> ToolSubstep: 81 | substep.attach(self) 82 | 83 | if len(self._substeps) > 0: 84 | self._substeps[-1].add_child(substep) 85 | 86 | self._substeps.append(substep) 87 | 88 | return substep 89 | 90 | def find_tool(self, tool_name: str) -> Optional[BaseTool]: 91 | return next( 92 | (t for t in self.tools if t.name == tool_name), 93 | None 94 | ) 95 | -------------------------------------------------------------------------------- /griptape/flow/structures/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.structures.structure import Structure 2 | from griptape.flow.structures.pipeline import Pipeline 3 | from griptape.flow.structures.workflow import Workflow 4 | 5 | 6 | __all__ = [ 7 | "Structure", 8 | "Pipeline", 9 | "Workflow" 10 | ] 11 | -------------------------------------------------------------------------------- /griptape/flow/structures/pipeline.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import json 3 | from typing import TYPE_CHECKING, Optional 4 | from attr import define, field 5 | from griptape.flow.artifacts import ErrorOutput 6 | from griptape.flow.structures import Structure 7 | from griptape.flow.memory import PipelineMemory, PipelineRun 8 | from griptape.flow.utils import J2 9 | 10 | if TYPE_CHECKING: 11 | from griptape.flow.steps import Step 12 | 13 | 14 | @define 15 | class Pipeline(Structure): 16 | memory: Optional[PipelineMemory] = field(default=None, kw_only=True) 17 | autoprune_memory: bool = field(default=True, kw_only=True) 18 | 19 | def __attrs_post_init__(self): 20 | super().__attrs_post_init__() 21 | 22 | if self.memory: 23 | self.memory.pipeline = self 24 | 25 | def first_step(self) -> Optional[Step]: 26 | return None if self.is_empty() else self.steps[0] 27 | 28 | def last_step(self) -> Optional[Step]: 29 | return None if self.is_empty() 
else self.steps[-1] 30 | 31 | def finished_steps(self) -> list[Step]: 32 | return [s for s in self.steps if s.is_finished()] 33 | 34 | def add_step(self, step: Step) -> Step: 35 | if self.last_step(): 36 | self.last_step().add_child(step) 37 | else: 38 | step.structure = self 39 | 40 | self.steps.append(step) 41 | 42 | return step 43 | 44 | def prompt_stack(self, step: Step) -> list[str]: 45 | final_stack = super().prompt_stack(step) 46 | step_prompt = J2("prompts/pipeline.j2").render( 47 | has_memory=self.memory is not None, 48 | finished_steps=self.finished_steps(), 49 | current_step=step 50 | ) 51 | 52 | if self.memory: 53 | if self.autoprune_memory: 54 | last_n = len(self.memory.runs) 55 | should_prune = True 56 | 57 | while should_prune and last_n > 0: 58 | temp_stack = final_stack.copy() 59 | temp_stack.append(step_prompt) 60 | 61 | temp_stack.append(self.memory.to_prompt_string(last_n)) 62 | 63 | if self.prompt_driver.tokenizer.tokens_left(self.stack_to_prompt_string(temp_stack)) > 0: 64 | should_prune = False 65 | else: 66 | last_n -= 1 67 | 68 | if last_n > 0: 69 | final_stack.append(self.memory.to_prompt_string(last_n)) 70 | else: 71 | final_stack.append(self.memory.to_prompt_string()) 72 | 73 | final_stack.append(step_prompt) 74 | 75 | return final_stack 76 | 77 | def run(self, *args) -> Step: 78 | self._execution_args = args 79 | 80 | [step.reset() for step in self.steps] 81 | 82 | self.__run_from_step(self.first_step()) 83 | 84 | if self.memory: 85 | run = PipelineRun( 86 | input=self.first_step().render_prompt(), 87 | output=self.last_step().output.value 88 | ) 89 | 90 | self.memory.add_run(run) 91 | 92 | self._execution_args = () 93 | 94 | return self.last_step() 95 | 96 | def context(self, step: Step) -> dict[str, any]: 97 | context = super().context(step) 98 | 99 | context.update( 100 | { 101 | "input": step.parents[0].output.value if step.parents and step.parents[0].output else None, 102 | "parent": step.parents[0] if step.parents else None, 103 | "child": step.children[0] if step.children else None 104 | } 105 | ) 106 | 107 | return context 108 | 109 | def to_dict(self) -> dict: 110 | from griptape.flow.schemas import PipelineSchema 111 | 112 | return PipelineSchema().dump(self) 113 | 114 | @classmethod 115 | def from_dict(cls, pipeline_dict: dict) -> Pipeline: 116 | from griptape.flow.schemas import PipelineSchema 117 | 118 | return PipelineSchema().load(pipeline_dict) 119 | 120 | @classmethod 121 | def from_json(cls, pipeline_json: str) -> Pipeline: 122 | return Pipeline.from_dict(json.loads(pipeline_json)) 123 | 124 | def __run_from_step(self, step: Optional[Step]) -> None: 125 | if step is None: 126 | return 127 | else: 128 | if isinstance(step.execute(), ErrorOutput): 129 | return 130 | else: 131 | self.__run_from_step(next(iter(step.children), None)) 132 | -------------------------------------------------------------------------------- /griptape/flow/structures/structure.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import json 3 | import logging 4 | import uuid 5 | from abc import ABC, abstractmethod 6 | from logging import Logger 7 | from typing import Optional, Union, TYPE_CHECKING 8 | from attr import define, field, Factory 9 | from rich.logging import RichHandler 10 | from griptape.flow.drivers import BasePromptDriver, OpenAiPromptDriver 11 | from griptape.flow.utils import J2, ToolLoader 12 | 13 | if TYPE_CHECKING: 14 | from griptape.flow.rules import Rule 15 | from 
griptape.flow.steps import Step 16 | 17 | 18 | @define 19 | class Structure(ABC): 20 | LOGGER_NAME = "griptape-flow" 21 | 22 | id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) 23 | type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) 24 | prompt_driver: BasePromptDriver = field(default=OpenAiPromptDriver(), kw_only=True) 25 | rules: list[Rule] = field(factory=list, kw_only=True) 26 | steps: list[Step] = field(factory=list, kw_only=True) 27 | custom_logger: Optional[Logger] = field(default=None, kw_only=True) 28 | logger_level: int = field(default=logging.INFO, kw_only=True) 29 | tool_loader: ToolLoader = field(default=ToolLoader(), kw_only=True) 30 | 31 | _execution_args: tuple = () 32 | _logger: Optional[Logger] = None 33 | 34 | def __attrs_post_init__(self): 35 | for step in self.steps: 36 | step.structure = self 37 | 38 | @property 39 | def execution_args(self) -> tuple: 40 | return self._execution_args 41 | 42 | @property 43 | def logger(self) -> Logger: 44 | if self.custom_logger: 45 | return self.custom_logger 46 | else: 47 | if self._logger is None: 48 | self._logger = logging.getLogger(self.LOGGER_NAME) 49 | 50 | self._logger.propagate = False 51 | self._logger.level = self.logger_level 52 | 53 | self._logger.handlers = [ 54 | RichHandler( 55 | show_time=True, 56 | show_path=False 57 | ) 58 | ] 59 | 60 | return self._logger 61 | 62 | def is_finished(self) -> bool: 63 | return all(s.is_finished() for s in self.steps) 64 | 65 | def is_executing(self) -> bool: 66 | return any(s for s in self.steps if s.is_executing()) 67 | 68 | def is_empty(self) -> bool: 69 | return not self.steps 70 | 71 | def find_step(self, step_id: str) -> Optional[Step]: 72 | return next((step for step in self.steps if step.id == step_id), None) 73 | 74 | def add_steps(self, *steps: Step) -> list[Step]: 75 | return [self.add_step(s) for s in steps] 76 | 77 | def prompt_stack(self, step: Step) -> list[str]: 78 | from griptape.flow.steps import ToolkitStep 79 | 80 | tools = step.tools if isinstance(step, ToolkitStep) else [] 81 | 82 | stack = [ 83 | J2("prompts/context.j2").render( 84 | rules=self.rules, 85 | tool_names=str.join(", ", [tool.name for tool in tools]), 86 | tools=[J2("prompts/tool.j2").render(tool=tool) for tool in tools] 87 | ) 88 | ] 89 | 90 | return stack 91 | 92 | def to_prompt_string(self, step: Step) -> str: 93 | return self.stack_to_prompt_string(self.prompt_stack(step)) 94 | 95 | def stack_to_prompt_string(self, stack: list[str]) -> str: 96 | return str.join("\n", stack) 97 | 98 | def to_json(self) -> str: 99 | return json.dumps(self.to_dict(), indent=2) 100 | 101 | def context(self, step: Step) -> dict[str, any]: 102 | return { 103 | "args": self.execution_args, 104 | "structure": self, 105 | } 106 | 107 | @abstractmethod 108 | def add_step(self, step: Step) -> Step: 109 | ... 110 | 111 | @abstractmethod 112 | def run(self, *args) -> Union[Step, list[Step]]: 113 | ... 114 | 115 | @abstractmethod 116 | def to_dict(self) -> dict: 117 | ... 118 | 119 | @classmethod 120 | @abstractmethod 121 | def from_dict(cls, workflow_dict: dict) -> Structure: 122 | ... 123 | 124 | @classmethod 125 | @abstractmethod 126 | def from_json(cls, workflow_json: str) -> Structure: 127 | ... 
128 | -------------------------------------------------------------------------------- /griptape/flow/structures/workflow.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import concurrent.futures as futures 3 | import json 4 | from graphlib import TopologicalSorter 5 | from attr import define, field 6 | from griptape.flow.artifacts import ErrorOutput 7 | from griptape.flow.steps import Step 8 | from griptape.flow.structures import Structure 9 | from griptape.flow.utils import J2 10 | 11 | 12 | @define 13 | class Workflow(Structure): 14 | executor: futures.Executor = field(default=futures.ThreadPoolExecutor(), kw_only=True) 15 | 16 | def add_step(self, step: Step) -> Step: 17 | step.structure = self 18 | 19 | self.steps.append(step) 20 | 21 | return step 22 | 23 | def prompt_stack(self, step: Step) -> list[str]: 24 | stack = Structure.prompt_stack(self, step) 25 | 26 | stack.append( 27 | J2("prompts/workflow.j2").render( 28 | step=step 29 | ) 30 | ) 31 | 32 | return stack 33 | 34 | def run(self, *args) -> list[Step]: 35 | self._execution_args = args 36 | ordered_steps = self.order_steps() 37 | exit_loop = False 38 | 39 | while not self.is_finished() and not exit_loop: 40 | futures_list = {} 41 | 42 | for step in ordered_steps: 43 | if step.can_execute(): 44 | future = self.executor.submit(step.execute) 45 | futures_list[future] = step 46 | 47 | # Wait for all tasks to complete 48 | for future in futures.as_completed(futures_list): 49 | if isinstance(future.result(), ErrorOutput): 50 | exit_loop = True 51 | 52 | break 53 | 54 | self._execution_args = () 55 | 56 | return self.output_steps() 57 | 58 | def context(self, step: Step) -> dict[str, any]: 59 | context = super().context(step) 60 | 61 | context.update( 62 | { 63 | "inputs": {parent.id: parent.output.value if parent.output else "" for parent in step.parents}, 64 | "parents": {parent.id: parent for parent in step.parents}, 65 | "children": {child.id: child for child in step.children} 66 | } 67 | ) 68 | 69 | return context 70 | 71 | def output_steps(self) -> list[Step]: 72 | return [step for step in self.steps if not step.children] 73 | 74 | def to_graph(self) -> dict[str, set[str]]: 75 | graph: dict[str, set[str]] = {} 76 | 77 | for key_step in self.steps: 78 | graph[key_step.id] = set() 79 | 80 | for value_step in self.steps: 81 | if key_step.id in value_step.child_ids: 82 | graph[key_step.id].add(value_step.id) 83 | 84 | return graph 85 | 86 | def order_steps(self) -> list[Step]: 87 | return [self.find_step(step_id) for step_id in TopologicalSorter(self.to_graph()).static_order()] 88 | 89 | def to_dict(self) -> dict: 90 | from griptape.flow.schemas import WorkflowSchema 91 | 92 | return WorkflowSchema().dump(self) 93 | 94 | @classmethod 95 | def from_dict(cls, workflow_dict: dict) -> Workflow: 96 | from griptape.flow.schemas import WorkflowSchema 97 | 98 | return WorkflowSchema().load(workflow_dict) 99 | 100 | @classmethod 101 | def from_json(cls, workflow_json: str) -> Workflow: 102 | return Workflow.from_dict(json.loads(workflow_json)) 103 | -------------------------------------------------------------------------------- /griptape/flow/summarizers/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.summarizers.summarizer import Summarizer 2 | from griptape.flow.summarizers.prompt_driver_summarizer import PromptDriverSummarizer 3 | 4 | __all__ = [ 5 | "Summarizer", 6 | 
"PromptDriverSummarizer" 7 | ] -------------------------------------------------------------------------------- /griptape/flow/summarizers/prompt_driver_summarizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Optional 3 | from attr import define, field 4 | from griptape.flow.utils import J2 5 | from griptape.flow.drivers import BasePromptDriver 6 | from griptape.flow.summarizers.summarizer import Summarizer 7 | 8 | 9 | if TYPE_CHECKING: 10 | from griptape.flow.memory import PipelineMemory, PipelineRun 11 | 12 | 13 | @define 14 | class PromptDriverSummarizer(Summarizer): 15 | driver: BasePromptDriver = field(kw_only=True) 16 | 17 | def summarize(self, memory: PipelineMemory, runs: list[PipelineRun]) -> Optional[str]: 18 | try: 19 | if len(runs) > 0: 20 | return self.driver.run( 21 | value=J2("prompts/summarize.j2").render( 22 | summary=memory.summary, 23 | runs=runs 24 | ) 25 | ).value 26 | else: 27 | return memory.summary 28 | except Exception as e: 29 | self.pipeline.logger.error(f"Error summarizing memory: {type(e).__name__}({e})") 30 | 31 | return memory.summary 32 | -------------------------------------------------------------------------------- /griptape/flow/summarizers/summarizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING, Optional 3 | from abc import ABC, abstractmethod 4 | from attr import define, field, Factory 5 | 6 | if TYPE_CHECKING: 7 | from griptape.flow.memory import PipelineMemory, PipelineRun 8 | 9 | 10 | @define 11 | class Summarizer(ABC): 12 | type: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) 13 | 14 | @abstractmethod 15 | def summarize(self, memory: PipelineMemory, runs: list[PipelineRun]) -> Optional[str]: 16 | ... 17 | -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/context.j2: -------------------------------------------------------------------------------- 1 | {% if tools|length > 0 %} 2 | You are an assistant that follows rules and can use tools to answer questions and complete tasks one at a time. To use a tool, use this structure: 3 | 4 | Input: 5 | Thought: 6 | Action: {"tool": "", "action": "" "value": } 7 | Observation: 8 | ...repeat Thought/Action/Observation until you can respond to the original request 9 | Thought: I have enough information to respond to the original request 10 | Output: 11 | 12 | All tool action input JSON is based on the JSON Schema Draft-07 format. 13 | 14 | You have access only to the following tools: [{{ tool_names }}]. NEVER make up tools and tool names. If you encounter an error from a tool you should try to fix it. Don't request extra information from the user. If you don't need to use a tool or if you don't know which tool to use, respond like this: 15 | 16 | Input: 17 | Output: 18 | 19 | # Tool Descriptions 20 | {% for tool in tools %} 21 | {{ tool }} 22 | {% endfor %} 23 | {% else %} 24 | You are an assistant that follows rules and answers questions. 
Here is the conversation structure that I want you to use: 25 | 26 | Input: <original request> 27 | Output: <your response> 28 | {% endif %} 29 | 30 | {% if rules|length > 0 %} 31 | When answering questions, follow these additional rules: 32 | {% for rule in rules %} 33 | Rule #{{loop.index}} 34 | {{ rule.value }} 35 | 36 | {% endfor %} 37 | {% endif %} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/memory.j2: -------------------------------------------------------------------------------- 1 | {% if summary %} 2 | Summary of the conversation so far: 3 | 4 | {{ summary }} 5 | 6 | Conversation begins: 7 | {% else %} 8 | Conversation begins: 9 | {% endif %} 10 | {% for run in runs %} 11 | {{ run.render() }} 12 | {% endfor %} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/pipeline.j2: -------------------------------------------------------------------------------- 1 | {% if not has_memory %} 2 | Conversation begins: 3 | {% endif %} 4 | {% for step in finished_steps %} 5 | {{ step.render() }} 6 | {% endfor %} 7 | {{ current_step.render() }} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/run_context.j2: -------------------------------------------------------------------------------- 1 | Input: {{ run.input }} 2 | Output: {{ run.output }} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/steps/prompt.j2: -------------------------------------------------------------------------------- 1 | Input: {{ step.render_prompt() }} 2 | {% if step.output %} 3 | Output: {{ step.output.value }} 4 | 5 | {% else %} 6 | Output: 7 | {% endif %} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/steps/tool/substep.j2: -------------------------------------------------------------------------------- 1 | {% if substep.thought %} 2 | Thought: {{ substep.thought }} 3 | {% endif %} 4 | Action: {{ substep.to_json() }} 5 | {% if substep.output %} 6 | Observation: {{ substep.output.value }} 7 | {% endif %} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/steps/tool/substeps.j2: -------------------------------------------------------------------------------- 1 | {% for substep in substeps %}{{ substep.render() }} 2 | {% endfor %} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/steps/tool/tool.j2: -------------------------------------------------------------------------------- 1 | Input: {{ step.render_prompt() }} 2 | {% if step.output %} 3 | Output: {{ step.output.value }} 4 | {% else %} 5 | {% for substep in substeps %}{{ substep.render() }} 6 | {% endfor %} 7 | {% endif %} -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/summarize.j2: -------------------------------------------------------------------------------- 1 | {% if summary %} 2 | Conversation summary: 3 | {{ summary }} 4 | 5 | Update summary with this: 6 | {% else %} 7 | Summarize the following conversation: 8 | {% endif %} 9 | 10 | {% for run in runs %} 11 | Input: {{ run.input }} 12 | Output: {{ run.output }} 13 | 14 | {% endfor %} 15 | {% if summary %} 16 | Updated short summary: 17 | {% else %} 18 | Short summary: 19 | {% endif %} 
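The templates above are rendered through the J2 helper defined later in griptape/flow/utils/j2.py. As one concrete example, PromptDriverSummarizer (shown earlier) builds its summarization prompt by rendering prompts/summarize.j2 with the current summary and the unsummarized runs. The following is a minimal sketch of that rendering step, not part of the repository itself; it assumes griptape-flow is installed so the packaged templates resolve:

from griptape.flow.memory import PipelineRun
from griptape.flow.utils import J2

# Render summarize.j2 the same way PromptDriverSummarizer.summarize() does.
# With summary=None the template asks for a fresh summary; with an existing
# summary string it asks the model to update that summary with the new runs.
prompt = J2("prompts/summarize.j2").render(
    summary=None,
    runs=[PipelineRun(input="What is 2 + 2?", output="4")],
)

print(prompt)  # "Summarize the following conversation:" followed by the run's Input/Output lines

Feeding this prompt to any BasePromptDriver yields the summary string that SummaryPipelineMemory stores and prepends to later prompts.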
-------------------------------------------------------------------------------- /griptape/flow/templates/prompts/tool.j2: -------------------------------------------------------------------------------- 1 | ## Tool: {{ tool.name }} 2 | {% if tool.metadata %} 3 | Tool metadata: {{ tool.metadata }} 4 | {% endif %} 5 | {% for action in tool.actions() %} 6 | Tool action name: {{ tool.action_name(action) }} 7 | Tool action description: {{ tool.action_description(action) }} 8 | {% endfor %} 9 | -------------------------------------------------------------------------------- /griptape/flow/templates/prompts/workflow.j2: -------------------------------------------------------------------------------- 1 | Conversation begins: 2 | {{ step.render() }} -------------------------------------------------------------------------------- /griptape/flow/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.tokenizers.base_tokenizer import BaseTokenizer 2 | from griptape.flow.tokenizers.tiktoken_tokenizer import TiktokenTokenizer 3 | from griptape.flow.tokenizers.cohere_tokenizer import CohereTokenizer 4 | from griptape.flow.tokenizers.hugging_face_tokenizer import HuggingFaceTokenizer 5 | 6 | 7 | __all__ = [ 8 | "BaseTokenizer", 9 | "TiktokenTokenizer", 10 | "CohereTokenizer", 11 | "HuggingFaceTokenizer" 12 | ] 13 | -------------------------------------------------------------------------------- /griptape/flow/tokenizers/base_tokenizer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from attr import define, field 3 | 4 | 5 | @define 6 | class BaseTokenizer(ABC): 7 | DEFAULT_STOP_SEQUENCE = "Observation:" 8 | 9 | stop_sequence: str = field(default=DEFAULT_STOP_SEQUENCE, kw_only=True) 10 | 11 | @property 12 | @abstractmethod 13 | def max_tokens(self) -> int: 14 | ... 15 | 16 | def tokens_left(self, text: str) -> int: 17 | diff = self.max_tokens - self.token_count(text) 18 | 19 | if diff > 0: 20 | return diff 21 | else: 22 | return 0 23 | 24 | def token_count(self, text: str) -> int: 25 | return len(self.encode(text)) 26 | 27 | @abstractmethod 28 | def encode(self, text: str) -> list[int]: 29 | ... 30 | 31 | @abstractmethod 32 | def decode(self, tokens: list[int]) -> str: 33 | ... 
34 | -------------------------------------------------------------------------------- /griptape/flow/tokenizers/cohere_tokenizer.py: -------------------------------------------------------------------------------- 1 | import cohere 2 | from attr import define, field 3 | from griptape.flow.tokenizers import BaseTokenizer 4 | 5 | 6 | @define(frozen=True) 7 | class CohereTokenizer(BaseTokenizer): 8 | DEFAULT_MODEL = "xlarge" 9 | MAX_TOKENS = 2048 10 | 11 | model: str = field(default=DEFAULT_MODEL, kw_only=True) 12 | client: cohere.Client = field(kw_only=True) 13 | 14 | @property 15 | def max_tokens(self) -> int: 16 | return self.MAX_TOKENS 17 | 18 | def token_count(self, text: str) -> int: 19 | return len(self.encode(text)) 20 | 21 | def encode(self, text: str) -> list[int]: 22 | return self.client.tokenize(text=text).tokens 23 | 24 | def decode(self, tokens: list[int]) -> str: 25 | return self.client.detokenize(tokens=tokens).text 26 | -------------------------------------------------------------------------------- /griptape/flow/tokenizers/hugging_face_tokenizer.py: -------------------------------------------------------------------------------- 1 | from attr import define, field, Factory 2 | from griptape.flow.tokenizers import BaseTokenizer 3 | from transformers import PreTrainedTokenizerBase 4 | 5 | 6 | @define(frozen=True) 7 | class HuggingFaceTokenizer(BaseTokenizer): 8 | tokenizer: PreTrainedTokenizerBase = field(kw_only=True) 9 | max_tokens: int = field( 10 | default=Factory(lambda self: self.tokenizer.model_max_length, takes_self=True), 11 | kw_only=True 12 | ) 13 | 14 | def token_count(self, text: str) -> int: 15 | return len(self.encode(text)) 16 | 17 | def encode(self, text: str) -> list[int]: 18 | return self.tokenizer.encode(text) 19 | 20 | def decode(self, tokens: list[int]) -> str: 21 | return self.tokenizer.decode(tokens) 22 | -------------------------------------------------------------------------------- /griptape/flow/tokenizers/tiktoken_tokenizer.py: -------------------------------------------------------------------------------- 1 | from attr import define, field 2 | import tiktoken 3 | from griptape.flow.tokenizers import BaseTokenizer 4 | 5 | 6 | @define(frozen=True) 7 | class TiktokenTokenizer(BaseTokenizer): 8 | DEFAULT_MODEL = "gpt-3.5-turbo" 9 | DEFAULT_ENCODING = "cl100k_base" 10 | DEFAULT_MAX_TOKENS = 2049 11 | TOKEN_OFFSET = 8 12 | 13 | MODEL_PREFIXES_TO_MAX_TOKENS = { 14 | "gpt-4-32k": 32768, 15 | "gpt-4": 8192, 16 | "gpt-3.5-turbo": 4096, 17 | "text-davinci-003": 4097, 18 | "text-davinci-002": 4097, 19 | "code-davinci-002": 8001 20 | } 21 | 22 | CHAT_API_PREFIXES = [ 23 | "gpt-3.5-turbo", 24 | "gpt-4" 25 | ] 26 | 27 | model: str = field(default=DEFAULT_MODEL, kw_only=True) 28 | 29 | @property 30 | def encoding(self) -> tiktoken.Encoding: 31 | try: 32 | return tiktoken.encoding_for_model(self.model) 33 | except KeyError: 34 | return tiktoken.get_encoding(self.DEFAULT_ENCODING) 35 | 36 | @property 37 | def max_tokens(self) -> int: 38 | tokens = next((v for k, v in self.MODEL_PREFIXES_TO_MAX_TOKENS.items() if self.model.startswith(k)), None) 39 | 40 | return (tokens if tokens else self.DEFAULT_MAX_TOKENS) - self.TOKEN_OFFSET 41 | 42 | def encode(self, text: str) -> list[int]: 43 | return self.encoding.encode(text, allowed_special={self.stop_sequence}) 44 | 45 | def decode(self, tokens: list[int]) -> str: 46 | return self.encoding.decode(tokens) 47 | 48 | def is_chat(self) -> bool: 49 | return next((p for p in self.CHAT_API_PREFIXES if self.model.startswith(p)), None) is not 
None -------------------------------------------------------------------------------- /griptape/flow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from griptape.flow.utils.j2 import J2 3 | from griptape.flow.utils.conversation import Conversation 4 | from griptape.flow.utils.tool_loader import ToolLoader 5 | 6 | __all__ = [ 7 | "J2", 8 | "Conversation", 9 | "ToolLoader" 10 | ] 11 | 12 | 13 | def minify_json(value: str) -> str: 14 | return json.dumps(json.loads(value), separators=(',', ':')) 15 | -------------------------------------------------------------------------------- /griptape/flow/utils/conversation.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import TYPE_CHECKING 3 | from attr import define, field 4 | 5 | if TYPE_CHECKING: 6 | from griptape.flow.memory import PipelineMemory 7 | 8 | 9 | @define(frozen=True) 10 | class Conversation: 11 | memory: PipelineMemory = field() 12 | 13 | def lines(self) -> list[str]: 14 | lines = [] 15 | 16 | for run in self.memory.runs: 17 | lines.append(f"Q: {run.input}") 18 | lines.append(f"A: {run.output}") 19 | 20 | return lines 21 | 22 | def to_string(self) -> str: 23 | return str.join("\n", self.lines()) 24 | -------------------------------------------------------------------------------- /griptape/flow/utils/j2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | from attr import define, field, Factory 4 | from jinja2 import Environment, FileSystemLoader 5 | from griptape.flow.tokenizers import TiktokenTokenizer, BaseTokenizer 6 | import griptape 7 | 8 | 9 | @define(frozen=True) 10 | class J2: 11 | template_name: Optional[str] = field(default=None) 12 | templates_dir: str = field(default=os.path.join(griptape.flow.PACKAGE_ABS_PATH, "templates"), kw_only=True) 13 | tokenizer: BaseTokenizer = field(default=TiktokenTokenizer(), kw_only=True) 14 | environment: Environment = field( 15 | default=Factory( 16 | lambda self: Environment( 17 | loader=FileSystemLoader(self.templates_dir), 18 | trim_blocks=True, 19 | lstrip_blocks=True 20 | ), 21 | takes_self=True 22 | ), 23 | kw_only=True 24 | ) 25 | 26 | def render(self, **kwargs): 27 | if not kwargs.get("stop_sequence"): 28 | kwargs["stop_sequence"] = self.tokenizer.stop_sequence 29 | 30 | return self.environment.get_template(self.template_name).render(kwargs) 31 | 32 | def render_from_string(self, value: str, **kwargs): 33 | if not kwargs.get("stop_sequence"): 34 | kwargs["stop_sequence"] = self.tokenizer.stop_sequence 35 | 36 | return self.environment.from_string(value).render(kwargs) 37 | -------------------------------------------------------------------------------- /griptape/flow/utils/tool_loader.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from attr import define, field 3 | from griptape.core import BaseTool, BaseExecutor 4 | from griptape.core.executors import LocalExecutor 5 | 6 | 7 | @define 8 | class ToolLoader: 9 | tools: list[BaseTool] = field(factory=list, kw_only=True) 10 | executor: BaseExecutor = field(default=LocalExecutor(), kw_only=True) 11 | 12 | @tools.validator 13 | def validate_tools(self, _, tools) -> None: 14 | tool_names = [t.name for t in tools] 15 | 16 | if len(tool_names) > len(set(tool_names)): 17 | raise ValueError("tools have to be unique") 18 | 19 | 
def load_tool(self, tool_name: str) -> Optional[BaseTool]: 20 | return next((t for t in self.tools if t.name == tool_name), None) 21 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "griptape-flow" 3 | version = "0.12.1" 4 | description = "Python framework for LLM workflows and pipelines." 5 | authors = ["Griptape "] 6 | license = "Apache 2.0" 7 | readme = "README.md" 8 | repository = "https://github.com/griptape-ai/griptape-flow" 9 | 10 | packages = [ 11 | {include = "griptape"} 12 | ] 13 | 14 | [tool.poetry.dependencies] 15 | python = "^3.9" 16 | griptape-core = ">= 0.9.2" 17 | python-dotenv = ">=0.21" 18 | openai = ">=0.27" 19 | cohere = ">=4" 20 | attrs = ">=22" 21 | jinja2 = ">=3.1" 22 | jsonschema = ">=4" 23 | marshmallow = ">=3" 24 | marshmallow-enum = ">=1.5" 25 | graphlib = "*" 26 | tiktoken = ">=0.3" 27 | rich = ">=13" 28 | stopit = "*" 29 | transformers = ">=4" 30 | huggingface-hub = ">=0.13" 31 | torch = ">= 2" 32 | 33 | [tool.poetry.group.test.dependencies] 34 | griptape-tools = ">= 0.6.0" 35 | pytest = "~=7.1" 36 | pytest-cover = "*" 37 | twine = ">=4" 38 | 39 | [build-system] 40 | requires = ["poetry-core"] 41 | build-backend = "poetry.core.masonry.api" 42 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/__init__.py -------------------------------------------------------------------------------- /tests/mocks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/mocks/__init__.py -------------------------------------------------------------------------------- /tests/mocks/mock_driver.py: -------------------------------------------------------------------------------- 1 | from attr import define 2 | from griptape.flow.drivers import BasePromptDriver 3 | from griptape.flow.tokenizers import TiktokenTokenizer, BaseTokenizer 4 | from griptape.flow.artifacts import TextOutput 5 | 6 | 7 | @define 8 | class MockDriver(BasePromptDriver): 9 | model: str = "test-model" 10 | tokenizer: BaseTokenizer = TiktokenTokenizer() 11 | 12 | def try_run(self, value: str) -> TextOutput: 13 | return TextOutput(value=f"mock output", meta={}) 14 | -------------------------------------------------------------------------------- /tests/mocks/mock_failing_driver.py: -------------------------------------------------------------------------------- 1 | from attr import define 2 | from griptape.flow.drivers import BasePromptDriver 3 | from griptape.flow.tokenizers import TiktokenTokenizer, BaseTokenizer 4 | from griptape.flow.artifacts import TextOutput 5 | 6 | 7 | @define 8 | class MockFailingDriver(BasePromptDriver): 9 | max_failures: int 10 | current_attempt: int = 0 11 | model: str = "test-model" 12 | tokenizer: BaseTokenizer = TiktokenTokenizer() 13 | 14 | def try_run(self, **kwargs) -> TextOutput: 15 | if self.current_attempt < self.max_failures: 16 | self.current_attempt += 1 17 | 18 | raise Exception(f"failed attempt") 19 | else: 20 | return TextOutput("success") 21 | -------------------------------------------------------------------------------- 
/tests/mocks/mock_value_driver.py: -------------------------------------------------------------------------------- 1 | from attr import define 2 | from griptape.flow.drivers import BasePromptDriver 3 | from griptape.flow.tokenizers import TiktokenTokenizer, BaseTokenizer 4 | from griptape.flow.artifacts import TextOutput 5 | 6 | 7 | @define 8 | class MockValueDriver(BasePromptDriver): 9 | value: str 10 | model: str = "test-model" 11 | tokenizer: BaseTokenizer = TiktokenTokenizer() 12 | 13 | def try_run(self, **kwargs) -> TextOutput: 14 | return TextOutput(value=self.value, meta={}) 15 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/drivers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/drivers/__init__.py -------------------------------------------------------------------------------- /tests/unit/drivers/test_disk_memory_driver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | from tests.mocks.mock_driver import MockDriver 4 | from griptape.flow.drivers import DiskMemoryDriver 5 | from griptape.flow.memory import PipelineMemory 6 | from griptape.flow.steps import PromptStep 7 | from griptape.flow.structures import Pipeline 8 | 9 | 10 | class TestDiskMemoryDriver: 11 | MEMORY_FILE_PATH = "test_memory.json" 12 | 13 | @pytest.fixture(autouse=True) 14 | def run_before_and_after_tests(self): 15 | self.__delete_file(self.MEMORY_FILE_PATH) 16 | 17 | yield 18 | 19 | self.__delete_file(self.MEMORY_FILE_PATH) 20 | 21 | def test_store(self): 22 | prompt_driver = MockDriver() 23 | memory_driver = DiskMemoryDriver(file_path=self.MEMORY_FILE_PATH) 24 | memory = PipelineMemory(driver=memory_driver) 25 | pipeline = Pipeline(prompt_driver=prompt_driver, memory=memory) 26 | 27 | pipeline.add_step( 28 | PromptStep("test") 29 | ) 30 | 31 | try: 32 | with open(self.MEMORY_FILE_PATH, "r"): 33 | assert False 34 | except FileNotFoundError: 35 | assert True 36 | 37 | pipeline.run() 38 | 39 | with open(self.MEMORY_FILE_PATH, "r"): 40 | assert True 41 | 42 | def test_load(self): 43 | prompt_driver = MockDriver() 44 | memory_driver = DiskMemoryDriver(file_path=self.MEMORY_FILE_PATH) 45 | memory = PipelineMemory(driver=memory_driver) 46 | pipeline = Pipeline(prompt_driver=prompt_driver, memory=memory) 47 | 48 | pipeline.add_step( 49 | PromptStep("test") 50 | ) 51 | 52 | pipeline.run() 53 | pipeline.run() 54 | 55 | new_memory = memory_driver.load() 56 | 57 | assert new_memory.type == "PipelineMemory" 58 | assert len(new_memory.runs) == 2 59 | assert new_memory.runs[0].input == "test" 60 | assert new_memory.runs[0].output == "mock output" 61 | 62 | def __delete_file(self, file_path): 63 | try: 64 | os.remove(file_path) 65 | except FileNotFoundError: 66 | pass -------------------------------------------------------------------------------- /tests/unit/drivers/test_prompt_driver.py: -------------------------------------------------------------------------------- 1 | from tests.mocks.mock_failing_driver import MockFailingDriver 2 | from 
griptape.flow.artifacts import ErrorOutput, TextOutput 3 | from griptape.flow.steps import PromptStep 4 | from griptape.flow.structures import Pipeline 5 | 6 | 7 | class TestPromptDriver: 8 | def test_run_retries_success(self): 9 | driver = MockFailingDriver(max_failures=1, max_retries=1, retry_delay=0.01) 10 | pipeline = Pipeline(prompt_driver=driver) 11 | 12 | pipeline.add_step( 13 | PromptStep("test") 14 | ) 15 | 16 | assert isinstance(pipeline.run().output, TextOutput) 17 | 18 | def test_run_retries_failure(self): 19 | driver = MockFailingDriver(max_failures=2, max_retries=1, retry_delay=0.01) 20 | pipeline = Pipeline(prompt_driver=driver) 21 | 22 | pipeline.add_step( 23 | PromptStep("test") 24 | ) 25 | 26 | assert isinstance(pipeline.run().output, ErrorOutput) 27 | -------------------------------------------------------------------------------- /tests/unit/memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/memory/__init__.py -------------------------------------------------------------------------------- /tests/unit/memory/test_pipeline_buffer_memory.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.steps import PromptStep 2 | from griptape.flow.structures import Pipeline 3 | from griptape.flow.memory import BufferPipelineMemory 4 | from tests.mocks.mock_driver import MockDriver 5 | 6 | 7 | class TestBufferMemory: 8 | def test_after_run(self): 9 | memory = BufferPipelineMemory(buffer_size=2) 10 | 11 | pipeline = Pipeline(memory=memory, prompt_driver=MockDriver()) 12 | 13 | pipeline.add_steps( 14 | PromptStep("test"), 15 | PromptStep("test"), 16 | PromptStep("test"), 17 | PromptStep("test") 18 | ) 19 | 20 | pipeline.run() 21 | pipeline.run() 22 | 23 | assert len(pipeline.memory.runs) == 2 24 | -------------------------------------------------------------------------------- /tests/unit/memory/test_pipeline_memory.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.memory import PipelineMemory, PipelineRun 2 | 3 | 4 | class TestMemory: 5 | def test_is_empty(self): 6 | memory = PipelineMemory() 7 | 8 | assert memory.is_empty() 9 | 10 | memory.add_run(PipelineRun(input="test", output="test")) 11 | 12 | assert not memory.is_empty() 13 | 14 | def test_add_run(self): 15 | memory = PipelineMemory() 16 | run = PipelineRun(input="test", output="test") 17 | 18 | memory.add_run(run) 19 | 20 | assert memory.runs[0] == run 21 | 22 | def test_to_string(self): 23 | memory = PipelineMemory() 24 | run = PipelineRun(input="test", output="test") 25 | 26 | memory.add_run(run) 27 | 28 | assert "Input: test\nOutput: test" in memory.to_prompt_string() -------------------------------------------------------------------------------- /tests/unit/memory/test_pipeline_summary_memory.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.summarizers import PromptDriverSummarizer 2 | from griptape.flow.memory import SummaryPipelineMemory 3 | from tests.mocks.mock_driver import MockDriver 4 | from griptape.flow.steps import PromptStep 5 | from griptape.flow.structures import Pipeline 6 | 7 | 8 | class TestSummaryMemory: 9 | def test_unsummarized_steps(self): 10 | memory = SummaryPipelineMemory(offset=1, summarizer=PromptDriverSummarizer(driver=MockDriver())) 11 | 12 | pipeline = 
Pipeline(memory=memory, prompt_driver=MockDriver()) 13 | 14 | pipeline.add_steps( 15 | PromptStep("test") 16 | ) 17 | 18 | pipeline.run() 19 | pipeline.run() 20 | pipeline.run() 21 | pipeline.run() 22 | 23 | assert len(memory.unsummarized_runs()) == 1 24 | 25 | def test_after_run(self): 26 | memory = SummaryPipelineMemory(offset=1, summarizer=PromptDriverSummarizer(driver=MockDriver())) 27 | 28 | pipeline = Pipeline(memory=memory, prompt_driver=MockDriver()) 29 | 30 | pipeline.add_steps( 31 | PromptStep("test") 32 | ) 33 | 34 | pipeline.run() 35 | pipeline.run() 36 | pipeline.run() 37 | pipeline.run() 38 | 39 | assert memory.summary is not None 40 | assert memory.summary_index == 3 41 | -------------------------------------------------------------------------------- /tests/unit/schemas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/schemas/__init__.py -------------------------------------------------------------------------------- /tests/unit/schemas/test_pipeline_schema.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.drivers import OpenAiPromptDriver 2 | from griptape.flow.tokenizers import TiktokenTokenizer 3 | from griptape.flow.steps import PromptStep, ToolkitStep, Step 4 | from griptape.flow.structures import Pipeline 5 | from griptape.flow.schemas import PipelineSchema 6 | 7 | 8 | class TestPipelineSchema: 9 | def test_serialization(self): 10 | pipeline = Pipeline( 11 | autoprune_memory=False, 12 | prompt_driver=OpenAiPromptDriver( 13 | tokenizer=TiktokenTokenizer(stop_sequence=""), 14 | temperature=0.12345 15 | ) 16 | ) 17 | 18 | tools = [ 19 | "calculator", 20 | "google_search" 21 | ] 22 | 23 | tool_step = ToolkitStep("test tool prompt", tool_names=["calculator"]) 24 | 25 | pipeline.add_steps( 26 | PromptStep("test prompt"), 27 | tool_step, 28 | ToolkitStep("test router step", tool_names=tools) 29 | ) 30 | 31 | pipeline_dict = PipelineSchema().dump(pipeline) 32 | 33 | assert pipeline_dict["autoprune_memory"] is False 34 | assert len(pipeline_dict["steps"]) == 3 35 | assert pipeline_dict["steps"][0]["state"] == "PENDING" 36 | assert pipeline_dict["steps"][0]["child_ids"][0] == pipeline.steps[1].id 37 | assert pipeline_dict["steps"][1]["parent_ids"][0] == pipeline.steps[0].id 38 | assert len(pipeline_dict["steps"][-1]["tool_names"]) == 2 39 | assert pipeline_dict["prompt_driver"]["temperature"] == 0.12345 40 | assert pipeline_dict["prompt_driver"]["tokenizer"]["stop_sequence"] == "" 41 | 42 | def test_deserialization(self): 43 | pipeline = Pipeline( 44 | autoprune_memory=False, 45 | prompt_driver=OpenAiPromptDriver( 46 | tokenizer=TiktokenTokenizer(stop_sequence=""), 47 | temperature=0.12345 48 | ) 49 | ) 50 | 51 | tools = [ 52 | "calculator", 53 | "google_search" 54 | ] 55 | 56 | tool_step = ToolkitStep("test tool prompt", tool_names=["calculator"]) 57 | 58 | pipeline.add_steps( 59 | PromptStep("test prompt"), 60 | tool_step, 61 | ToolkitStep("test router step", tool_names=tools) 62 | ) 63 | 64 | pipeline_dict = PipelineSchema().dump(pipeline) 65 | deserialized_pipeline = PipelineSchema().load(pipeline_dict) 66 | 67 | assert deserialized_pipeline.autoprune_memory is False 68 | assert len(deserialized_pipeline.steps) == 3 69 | assert deserialized_pipeline.steps[0].child_ids[0] == pipeline.steps[1].id 70 | assert deserialized_pipeline.steps[0].state == 
Step.State.PENDING 71 | assert deserialized_pipeline.steps[1].parent_ids[0] == pipeline.steps[0].id 72 | assert len(deserialized_pipeline.last_step().tool_names) == 2 73 | assert deserialized_pipeline.prompt_driver.temperature == 0.12345 74 | assert deserialized_pipeline.prompt_driver.tokenizer.stop_sequence == "" 75 | -------------------------------------------------------------------------------- /tests/unit/schemas/test_workflow_schema.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.drivers import OpenAiPromptDriver 2 | from griptape.flow.rules import Rule 3 | from griptape.flow.tokenizers import TiktokenTokenizer 4 | from griptape.flow.steps import PromptStep, ToolkitStep 5 | from griptape.flow.structures import Workflow 6 | from griptape.flow.schemas import WorkflowSchema 7 | 8 | 9 | class TestWorkflowSchema: 10 | def test_serialization(self): 11 | workflow = Workflow( 12 | prompt_driver=OpenAiPromptDriver( 13 | tokenizer=TiktokenTokenizer(stop_sequence=""), 14 | temperature=0.12345 15 | ), 16 | rules=[ 17 | Rule("test rule 1"), 18 | Rule("test rule 2"), 19 | ] 20 | ) 21 | 22 | tools = [ 23 | "calculator", 24 | "google_search" 25 | ] 26 | 27 | workflow.add_steps( 28 | PromptStep("test prompt"), 29 | ToolkitStep("test tool prompt", tool_names=["calculator"]) 30 | ) 31 | 32 | step = ToolkitStep("test router step", tool_names=tools) 33 | 34 | workflow.steps[0].add_child(step) 35 | workflow.steps[1].add_child(step) 36 | 37 | workflow_dict = WorkflowSchema().dump(workflow) 38 | 39 | assert len(workflow_dict["steps"]) == 3 40 | assert len(workflow_dict["rules"]) == 2 41 | assert workflow_dict["steps"][0]["state"] == "PENDING" 42 | assert workflow_dict["steps"][0]["child_ids"][0] == step.id 43 | assert workflow.steps[0].id in step.parent_ids 44 | assert workflow.steps[1].id in step.parent_ids 45 | assert len(workflow_dict["steps"][-1]["tool_names"]) == 2 46 | assert workflow_dict["prompt_driver"]["temperature"] == 0.12345 47 | assert workflow_dict["prompt_driver"]["tokenizer"]["stop_sequence"] == "" 48 | assert workflow_dict["rules"][0]["value"] == "test rule 1" 49 | 50 | def test_deserialization(self): 51 | workflow = Workflow( 52 | prompt_driver=OpenAiPromptDriver( 53 | tokenizer=TiktokenTokenizer(stop_sequence=""), 54 | temperature=0.12345 55 | ), 56 | rules=[ 57 | Rule("test rule 1"), 58 | Rule("test rule 2"), 59 | ] 60 | ) 61 | 62 | tools = [ 63 | "calculator", 64 | "google_search" 65 | ] 66 | 67 | workflow.add_steps( 68 | PromptStep("test prompt"), 69 | ToolkitStep("test tool prompt", tool_names=["calculator"]) 70 | ) 71 | 72 | step = ToolkitStep("test router step", tool_names=tools) 73 | 74 | workflow.steps[0].add_child(step) 75 | workflow.steps[1].add_child(step) 76 | 77 | workflow_dict = WorkflowSchema().dump(workflow) 78 | deserialized_workflow = WorkflowSchema().load(workflow_dict) 79 | 80 | assert len(deserialized_workflow.steps) == 3 81 | assert len(deserialized_workflow.rules) == 2 82 | assert deserialized_workflow.steps[0].child_ids[0] == step.id 83 | assert deserialized_workflow.steps[0].id in step.parent_ids 84 | assert deserialized_workflow.steps[1].id in step.parent_ids 85 | assert len(deserialized_workflow.steps[-1].tool_names) == 2 86 | assert deserialized_workflow.prompt_driver.temperature == 0.12345 87 | assert deserialized_workflow.prompt_driver.tokenizer.stop_sequence == "" 88 | assert deserialized_workflow.rules[0].value == "test rule 1" 89 | 
-------------------------------------------------------------------------------- /tests/unit/steps/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/steps/__init__.py -------------------------------------------------------------------------------- /tests/unit/steps/test_prompt_step.py: -------------------------------------------------------------------------------- 1 | from griptape.flow.steps import PromptStep 2 | from tests.mocks.mock_driver import MockDriver 3 | from griptape.flow.structures import Pipeline 4 | 5 | 6 | class TestPromptStep: 7 | def test_run(self): 8 | step = PromptStep("test") 9 | pipeline = Pipeline(prompt_driver=MockDriver()) 10 | 11 | pipeline.add_step(step) 12 | 13 | assert step.run().value == "mock output" 14 | 15 | def test_render_prompt(self): 16 | step = PromptStep("{{ test }}", context={"test": "test value"}) 17 | 18 | Pipeline().add_step(step) 19 | 20 | assert step.render_prompt() == "test value" 21 | 22 | def test_full_context(self): 23 | parent = PromptStep("parent") 24 | step = PromptStep("test", context={"foo": "bar"}) 25 | child = PromptStep("child") 26 | pipeline = Pipeline(prompt_driver=MockDriver()) 27 | 28 | pipeline.add_steps(parent, step, child) 29 | 30 | pipeline.run() 31 | 32 | context = step.full_context 33 | 34 | assert context["foo"] == "bar" 35 | assert context["input"] == parent.output.value 36 | assert context["structure"] == pipeline 37 | assert context["parent"] == parent 38 | assert context["child"] == child 39 | -------------------------------------------------------------------------------- /tests/unit/steps/test_tool_substep.py: -------------------------------------------------------------------------------- 1 | import json 2 | from griptape.flow.steps import ToolkitStep, ToolSubstep 3 | from griptape.flow.structures import Pipeline 4 | 5 | 6 | class TestToolSubstep: 7 | def test_to_json(self): 8 | valid_input = """Thought: need to test\nAction: {"tool": "test", "action": "test action", "value": "test input"}\nObservation: test 9 | observation\nOutput: test output""" 10 | 11 | step = ToolkitStep(tool_names=[]) 12 | Pipeline().add_step(step) 13 | substep = step.add_substep(ToolSubstep(valid_input)) 14 | json_dict = json.loads(substep.to_json()) 15 | 16 | assert json_dict["tool"] == "test" 17 | assert json_dict["action"] == "test action" 18 | assert json_dict["value"] == "test input" 19 | -------------------------------------------------------------------------------- /tests/unit/steps/test_toolkit_step.py: -------------------------------------------------------------------------------- 1 | from griptape.tools import Calculator, WebSearch 2 | from griptape.flow.artifacts import ErrorOutput 3 | from griptape.flow.steps import ToolkitStep, ToolSubstep 4 | from griptape.flow.utils import ToolLoader 5 | from tests.mocks.mock_value_driver import MockValueDriver 6 | from griptape.flow.structures import Pipeline 7 | 8 | 9 | class TestToolkitStep: 10 | def test_init(self): 11 | assert len(ToolkitStep("test", tool_names=["Calculator", "WebSearch"]).tool_names) == 2 12 | try: 13 | ToolkitStep("test", tool_names=["Calculator", "Calculator"]) 14 | # duplicate tool names must be rejected at construction time 15 | assert False, "expected ValueError for duplicate tool names" 16 | except ValueError: 17 | pass 18 | def test_run(self): 19 | output = """Output: done""" 20 | 21 | tools = [ 22 | Calculator(), 23 | WebSearch() 24 | ] 25 | 26 | step = ToolkitStep("test", tool_names=["Calculator", "WebSearch"]) 27 
| pipeline = Pipeline( 28 | prompt_driver=MockValueDriver(output), 29 | tool_loader=ToolLoader(tools=tools) 30 | ) 31 | 32 | pipeline.add_step(step) 33 | 34 | result = pipeline.run() 35 | 36 | assert len(step.tools) == 2 37 | assert len(step._substeps) == 1 38 | assert result.output.value == "done" 39 | 40 | def test_run_max_substeps(self): 41 | output = """Action: {"tool": "test"}""" 42 | 43 | step = ToolkitStep("test", tool_names=["Calculator"], max_substeps=3) 44 | pipeline = Pipeline(prompt_driver=MockValueDriver(output)) 45 | 46 | pipeline.add_step(step) 47 | 48 | pipeline.run() 49 | 50 | assert len(step._substeps) == 3 51 | assert isinstance(step.output, ErrorOutput) 52 | 53 | def test_init_from_prompt_1(self): 54 | valid_input = """Thought: need to test\nAction: {"tool": "test", "action": "test action", "value": "test input"}\nObservation: test 55 | observation\nOutput: test output""" 56 | step = ToolkitStep("test", tool_names=["Calculator"]) 57 | 58 | Pipeline().add_step(step) 59 | 60 | substep = step.add_substep(ToolSubstep(valid_input)) 61 | 62 | assert substep.thought == "need to test" 63 | assert substep.tool_name == "test" 64 | assert substep.tool_action == "test action" 65 | assert substep.tool_value == "test input" 66 | assert substep.output is None 67 | 68 | def test_init_from_prompt_2(self): 69 | valid_input = """Thought: need to test\nObservation: test 70 | observation\nOutput: test output""" 71 | step = ToolkitStep("test", tool_names=["Calculator"]) 72 | 73 | Pipeline().add_step(step) 74 | 75 | substep = step.add_substep(ToolSubstep(valid_input)) 76 | 77 | assert substep.thought == "need to test" 78 | assert substep.tool_name is None 79 | assert substep.tool_action is None 80 | assert substep.tool_value is None 81 | assert substep.output.value == "test output" 82 | 83 | def test_add_substep(self): 84 | step = ToolkitStep("test", tool_names=["Calculator"]) 85 | substep1 = ToolSubstep("test1", tool_name="test", tool_action="test", tool_value="test") 86 | substep2 = ToolSubstep("test2", tool_name="test", tool_action="test", tool_value="test") 87 | 88 | Pipeline().add_step(step) 89 | 90 | step.add_substep(substep1) 91 | step.add_substep(substep2) 92 | 93 | assert len(step._substeps) == 2 94 | 95 | assert len(substep1.children) == 1 96 | assert len(substep1.parents) == 0 97 | assert substep1.children[0] == substep2 98 | 99 | assert len(substep2.children) == 0 100 | assert len(substep2.parents) == 1 101 | assert substep2.parents[0] == substep1 102 | 103 | def test_find_substep(self): 104 | step = ToolkitStep("test", tool_names=["Calculator"]) 105 | substep1 = ToolSubstep("test1", tool_name="test", tool_action="test", tool_value="test") 106 | substep2 = ToolSubstep("test2", tool_name="test", tool_action="test", tool_value="test") 107 | 108 | Pipeline().add_step(step) 109 | 110 | step.add_substep(substep1) 111 | step.add_substep(substep2) 112 | 113 | assert step.find_substep(substep1.id) == substep1 114 | assert step.find_substep(substep2.id) == substep2 115 | 116 | def test_find_tool(self): 117 | tool = Calculator() 118 | step = ToolkitStep("test", tool_names=[tool.name]) 119 | 120 | Pipeline( 121 | tool_loader=ToolLoader(tools=[tool]) 122 | ).add_step(step) 123 | 124 | assert step.find_tool(tool.name) == tool 125 | -------------------------------------------------------------------------------- /tests/unit/structures/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/structures/__init__.py -------------------------------------------------------------------------------- /tests/unit/structures/test_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | from griptape.flow.artifacts import TextOutput 3 | from griptape.flow.rules import Rule 4 | from griptape.flow.tokenizers import TiktokenTokenizer 5 | from griptape.flow.steps import PromptStep, Step 6 | from griptape.flow.memory import PipelineMemory 7 | from tests.mocks.mock_driver import MockDriver 8 | from griptape.flow.structures import Pipeline 9 | 10 | 11 | class TestPipeline: 12 | def test_constructor(self): 13 | rule = Rule("test") 14 | driver = MockDriver() 15 | pipeline = Pipeline(prompt_driver=driver, rules=[rule]) 16 | 17 | assert pipeline.prompt_driver is driver 18 | assert pipeline.first_step() is None 19 | assert pipeline.last_step() is None 20 | assert pipeline.rules[0].value == "test" 21 | assert pipeline.memory is None 22 | 23 | def test_with_memory(self): 24 | first_step = PromptStep("test1") 25 | second_step = PromptStep("test2") 26 | third_step = PromptStep("test3") 27 | 28 | pipeline = Pipeline( 29 | prompt_driver=MockDriver(), 30 | memory=PipelineMemory() 31 | ) 32 | 33 | pipeline.add_steps(first_step, second_step, third_step) 34 | 35 | assert pipeline.memory is not None 36 | assert len(pipeline.memory.runs) == 0 37 | 38 | pipeline.run() 39 | pipeline.run() 40 | pipeline.run() 41 | 42 | assert len(pipeline.memory.runs) == 3 43 | 44 | def test_steps_order(self): 45 | first_step = PromptStep("test1") 46 | second_step = PromptStep("test2") 47 | third_step = PromptStep("test3") 48 | 49 | pipeline = Pipeline( 50 | prompt_driver=MockDriver() 51 | ) 52 | 53 | pipeline.add_step(first_step) 54 | pipeline.add_step(second_step) 55 | pipeline.add_step(third_step) 56 | 57 | assert pipeline.first_step().id == first_step.id 58 | assert pipeline.steps[1].id == second_step.id 59 | assert pipeline.steps[2].id == third_step.id 60 | assert pipeline.last_step().id == third_step.id 61 | 62 | def test_add_step(self): 63 | first_step = PromptStep("test1") 64 | second_step = PromptStep("test2") 65 | 66 | pipeline = Pipeline( 67 | prompt_driver=MockDriver() 68 | ) 69 | 70 | pipeline.add_step(first_step) 71 | pipeline.add_step(second_step) 72 | 73 | assert len(pipeline.steps) == 2 74 | assert first_step in pipeline.steps 75 | assert second_step in pipeline.steps 76 | assert first_step.structure == pipeline 77 | assert second_step.structure == pipeline 78 | assert len(first_step.parents) == 0 79 | assert len(first_step.children) == 1 80 | assert len(second_step.parents) == 1 81 | assert len(second_step.children) == 0 82 | 83 | def test_add_steps(self): 84 | first_step = PromptStep("test1") 85 | second_step = PromptStep("test2") 86 | 87 | pipeline = Pipeline( 88 | prompt_driver=MockDriver() 89 | ) 90 | 91 | pipeline.add_steps(first_step, second_step) 92 | 93 | assert len(pipeline.steps) == 2 94 | assert first_step in pipeline.steps 95 | assert second_step in pipeline.steps 96 | assert first_step.structure == pipeline 97 | assert second_step.structure == pipeline 98 | assert len(first_step.parents) == 0 99 | assert len(first_step.children) == 1 100 | assert len(second_step.parents) == 1 101 | assert len(second_step.children) == 0 102 | 103 | def test_prompt_stack_without_memory(self): 104 | pipeline = Pipeline( 105 | prompt_driver=MockDriver() 
106 | ) 107 | 108 | step1 = PromptStep("test") 109 | step2 = PromptStep("test") 110 | 111 | pipeline.add_step(step1) 112 | 113 | # context and first input 114 | assert len(pipeline.prompt_stack(step1)) == 2 115 | 116 | pipeline.run() 117 | 118 | pipeline.add_step(step2) 119 | 120 | # context and second input 121 | assert len(pipeline.prompt_stack(step2)) == 2 122 | 123 | def test_prompt_stack_with_memory(self): 124 | pipeline = Pipeline( 125 | prompt_driver=MockDriver(), 126 | memory=PipelineMemory() 127 | ) 128 | 129 | step1 = PromptStep("test") 130 | step2 = PromptStep("test") 131 | 132 | pipeline.add_step(step1) 133 | 134 | # context and first input 135 | assert len(pipeline.prompt_stack(step1)) == 2 136 | 137 | pipeline.run() 138 | 139 | pipeline.add_step(step2) 140 | 141 | # context, memory, and second input 142 | assert len(pipeline.prompt_stack(step2)) == 3 143 | 144 | def test_to_prompt_string(self): 145 | pipeline = Pipeline( 146 | prompt_driver=MockDriver(), 147 | ) 148 | 149 | step = PromptStep("test") 150 | 151 | pipeline.add_step(step) 152 | 153 | pipeline.run() 154 | 155 | assert "mock output" in pipeline.to_prompt_string(step) 156 | 157 | def test_step_output_token_count(self): 158 | text = "foobar" 159 | 160 | assert TextOutput(text).token_count(TiktokenTokenizer()) == TiktokenTokenizer().token_count(text) 161 | 162 | def test_run(self): 163 | step = PromptStep("test") 164 | pipeline = Pipeline(prompt_driver=MockDriver()) 165 | pipeline.add_step(step) 166 | 167 | assert step.state == Step.State.PENDING 168 | 169 | result = pipeline.run() 170 | 171 | assert "mock output" in result.output.value 172 | assert step.state == Step.State.FINISHED 173 | 174 | def test_run_with_args(self): 175 | step = PromptStep("{{ args[0] }}-{{ args[1] }}") 176 | pipeline = Pipeline(prompt_driver=MockDriver()) 177 | pipeline.add_steps(step) 178 | 179 | pipeline._execution_args = ("test1", "test2") 180 | 181 | assert step.render_prompt() == "test1-test2" 182 | 183 | pipeline.run() 184 | 185 | assert step.render_prompt() == "-" 186 | 187 | def test_to_json(self): 188 | pipeline = Pipeline() 189 | 190 | pipeline.add_steps( 191 | PromptStep("test prompt"), 192 | PromptStep("test prompt") 193 | ) 194 | 195 | assert len(json.loads(pipeline.to_json())["steps"]) == 2 196 | 197 | def test_to_dict(self): 198 | pipeline = Pipeline() 199 | 200 | pipeline.add_steps( 201 | PromptStep("test prompt"), 202 | PromptStep("test prompt") 203 | ) 204 | 205 | assert len(pipeline.to_dict()["steps"]) == 2 206 | 207 | def test_from_json(self): 208 | pipeline = Pipeline() 209 | 210 | pipeline.add_steps( 211 | PromptStep("test prompt"), 212 | PromptStep("test prompt") 213 | ) 214 | 215 | workflow_json = pipeline.to_json() 216 | 217 | assert len(Pipeline.from_json(workflow_json).steps) == 2 218 | 219 | def test_from_dict(self): 220 | pipeline = Pipeline() 221 | 222 | pipeline.add_steps( 223 | PromptStep("test prompt"), 224 | PromptStep("test prompt") 225 | ) 226 | 227 | workflow_json = pipeline.to_dict() 228 | 229 | assert len(Pipeline.from_dict(workflow_json).steps) == 2 230 | 231 | def test_context(self): 232 | parent = PromptStep("parent") 233 | step = PromptStep("test") 234 | child = PromptStep("child") 235 | pipeline = Pipeline(prompt_driver=MockDriver()) 236 | 237 | pipeline.add_steps(parent, step, child) 238 | 239 | context = pipeline.context(step) 240 | 241 | assert context["input"] is None 242 | 243 | pipeline.run() 244 | 245 | context = pipeline.context(step) 246 | 247 | assert context["input"] == 
parent.output.value 248 | assert context["structure"] == pipeline 249 | assert context["parent"] == parent 250 | assert context["child"] == child 251 | -------------------------------------------------------------------------------- /tests/unit/structures/test_workflow.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tests.mocks.mock_driver import MockDriver 3 | from griptape.flow.rules import Rule 4 | from griptape.flow.steps import PromptStep, Step 5 | from griptape.flow.structures import Workflow 6 | 7 | 8 | class TestWorkflow: 9 | def test_constructor(self): 10 | rule = Rule("test") 11 | driver = MockDriver() 12 | workflow = Workflow(prompt_driver=driver, rules=[rule]) 13 | 14 | assert workflow.prompt_driver is driver 15 | assert len(workflow.steps) == 0 16 | assert workflow.rules[0].value == "test" 17 | 18 | def test_add_step(self): 19 | first_step = PromptStep("test1") 20 | second_step = PromptStep("test2") 21 | 22 | workflow = Workflow( 23 | prompt_driver=MockDriver() 24 | ) 25 | 26 | workflow.add_step(first_step) 27 | workflow.add_step(second_step) 28 | 29 | assert len(workflow.steps) == 2 30 | assert first_step in workflow.steps 31 | assert second_step in workflow.steps 32 | assert first_step.structure == workflow 33 | assert second_step.structure == workflow 34 | assert len(first_step.parents) == 0 35 | assert len(first_step.children) == 0 36 | assert len(second_step.parents) == 0 37 | assert len(second_step.children) == 0 38 | 39 | def test_add_steps(self): 40 | first_step = PromptStep("test1") 41 | second_step = PromptStep("test2") 42 | 43 | workflow = Workflow( 44 | prompt_driver=MockDriver() 45 | ) 46 | 47 | workflow.add_steps(first_step, second_step) 48 | 49 | assert len(workflow.steps) == 2 50 | assert first_step in workflow.steps 51 | assert second_step in workflow.steps 52 | assert first_step.structure == workflow 53 | assert second_step.structure == workflow 54 | assert len(first_step.parents) == 0 55 | assert len(first_step.children) == 0 56 | assert len(second_step.parents) == 0 57 | assert len(second_step.children) == 0 58 | 59 | def test_run(self): 60 | step1 = PromptStep("test") 61 | step2 = PromptStep("test") 62 | workflow = Workflow(prompt_driver=MockDriver()) 63 | workflow.add_steps(step1, step2) 64 | 65 | assert step1.state == Step.State.PENDING 66 | assert step2.state == Step.State.PENDING 67 | 68 | workflow.run() 69 | 70 | assert step1.state == Step.State.FINISHED 71 | assert step2.state == Step.State.FINISHED 72 | 73 | def test_run_with_args(self): 74 | step = PromptStep("{{ args[0] }}-{{ args[1] }}") 75 | workflow = Workflow(prompt_driver=MockDriver()) 76 | workflow.add_steps(step) 77 | 78 | workflow._execution_args = ("test1", "test2") 79 | 80 | assert step.render_prompt() == "test1-test2" 81 | 82 | workflow.run() 83 | 84 | assert step.render_prompt() == "-" 85 | 86 | def test_run_topology_1(self): 87 | step1 = PromptStep("prompt1") 88 | step2 = PromptStep("prompt2") 89 | step3 = PromptStep("prompt3") 90 | workflow = Workflow(prompt_driver=MockDriver()) 91 | 92 | # step1 splits into step2 and step3 93 | workflow.add_step(step1) 94 | step1.add_child(step2) 95 | step3.add_parent(step1) 96 | 97 | workflow.run() 98 | 99 | assert step1.state == Step.State.FINISHED 100 | assert step2.state == Step.State.FINISHED 101 | assert step3.state == Step.State.FINISHED 102 | 103 | def test_run_topology_2(self): 104 | step1 = PromptStep("test1") 105 | step2 = PromptStep("test2") 106 | step3 = PromptStep("test3") 
107 | workflow = Workflow(prompt_driver=MockDriver()) 108 | 109 | # step1 and step2 converge into step3 110 | workflow.add_steps(step1, step2) 111 | step1.add_child(step3) 112 | step3.add_parent(step2) 113 | 114 | workflow.run() 115 | 116 | assert step1.state == Step.State.FINISHED 117 | assert step2.state == Step.State.FINISHED 118 | assert step3.state == Step.State.FINISHED 119 | 120 | def test_output_steps(self): 121 | step1 = PromptStep("prompt1") 122 | step2 = PromptStep("prompt2") 123 | step3 = PromptStep("prompt3") 124 | workflow = Workflow(prompt_driver=MockDriver()) 125 | 126 | workflow.add_step(step1) 127 | step1.add_child(step2) 128 | step3.add_parent(step1) 129 | 130 | assert len(workflow.output_steps()) == 2 131 | assert step2 in workflow.output_steps() 132 | assert step3 in workflow.output_steps() 133 | 134 | def test_to_graph(self): 135 | step1 = PromptStep("prompt1", id="step1") 136 | step2 = PromptStep("prompt2", id="step2") 137 | step3 = PromptStep("prompt3", id="step3") 138 | workflow = Workflow(prompt_driver=MockDriver()) 139 | 140 | workflow.add_step(step1) 141 | step1.add_child(step2) 142 | step3.add_parent(step1) 143 | 144 | graph = workflow.to_graph() 145 | 146 | assert "step1" in graph["step2"] 147 | assert "step1" in graph["step3"] 148 | 149 | def test_order_steps(self): 150 | step1 = PromptStep("prompt1") 151 | step2 = PromptStep("prompt2") 152 | step3 = PromptStep("prompt3") 153 | workflow = Workflow(prompt_driver=MockDriver()) 154 | 155 | workflow.add_step(step1) 156 | step1.add_child(step2) 157 | step3.add_parent(step1) 158 | 159 | ordered_steps = workflow.order_steps() 160 | 161 | assert ordered_steps[0] == step1 162 | assert ordered_steps[1] == step2 or ordered_steps[1] == step3 163 | assert ordered_steps[2] == step2 or ordered_steps[2] == step3 164 | 165 | def test_to_json(self): 166 | workflow = Workflow() 167 | 168 | workflow.add_steps( 169 | PromptStep("test prompt"), 170 | PromptStep("test prompt") 171 | ) 172 | 173 | assert len(json.loads(workflow.to_json())["steps"]) == 2 174 | 175 | def test_to_dict(self): 176 | workflow = Workflow() 177 | 178 | workflow.add_steps( 179 | PromptStep("test prompt"), 180 | PromptStep("test prompt") 181 | ) 182 | 183 | assert len(workflow.to_dict()["steps"]) == 2 184 | 185 | def test_from_json(self): 186 | workflow = Workflow() 187 | 188 | workflow.add_steps( 189 | PromptStep("test prompt"), 190 | PromptStep("test prompt") 191 | ) 192 | 193 | workflow_json = workflow.to_json() 194 | 195 | assert len(Workflow.from_json(workflow_json).steps) == 2 196 | 197 | def test_from_dict(self): 198 | workflow = Workflow() 199 | 200 | workflow.add_steps( 201 | PromptStep("test prompt"), 202 | PromptStep("test prompt") 203 | ) 204 | 205 | workflow_json = workflow.to_dict() 206 | 207 | assert len(Workflow.from_dict(workflow_json).steps) == 2 208 | 209 | def test_context(self): 210 | parent = PromptStep("parent") 211 | step = PromptStep("test") 212 | child = PromptStep("child") 213 | workflow = Workflow(prompt_driver=MockDriver()) 214 | 215 | workflow.add_step(parent) 216 | 217 | parent.add_child(step) 218 | step.add_child(child) 219 | 220 | context = workflow.context(step) 221 | 222 | assert context["inputs"] == {parent.id: ""} 223 | 224 | workflow.run() 225 | 226 | context = workflow.context(step) 227 | 228 | assert context["inputs"] == {parent.id: parent.output.value} 229 | assert context["structure"] == workflow 230 | assert context["parents"] == {parent.id: parent} 231 | assert context["children"] == {child.id: child} 232 | 
-------------------------------------------------------------------------------- /tests/unit/tokenizers/test_hugging_face_tokenizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from transformers import GPT2Tokenizer 3 | from griptape.flow.tokenizers import HuggingFaceTokenizer 4 | 5 | 6 | class TestHuggingFaceTokenizer: 7 | @pytest.fixture 8 | def tokenizer(self): 9 | return HuggingFaceTokenizer( 10 | tokenizer=GPT2Tokenizer.from_pretrained("gpt2") 11 | ) 12 | 13 | def test_encode(self, tokenizer): 14 | assert tokenizer.encode("foo bar") == [21943, 2318] 15 | 16 | def test_decode(self, tokenizer): 17 | assert tokenizer.decode([21943, 2318]) == "foo bar" 18 | 19 | def test_token_count(self, tokenizer): 20 | assert tokenizer.token_count("foo bar huzzah") == 5 21 | 22 | def test_tokens_left(self, tokenizer): 23 | assert tokenizer.tokens_left("foo bar huzzah") == 1019 24 | -------------------------------------------------------------------------------- /tests/unit/tokenizers/test_tiktoken_tokenizer.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from griptape.flow.tokenizers import TiktokenTokenizer 3 | 4 | 5 | class TestTiktokenTokenizer: 6 | @pytest.fixture 7 | def tokenizer(self): 8 | return TiktokenTokenizer() 9 | 10 | def test_encode(self, tokenizer): 11 | assert tokenizer.encode("foo bar") == [8134, 3703] 12 | 13 | def test_decode(self, tokenizer): 14 | assert tokenizer.decode([8134, 3703]) == "foo bar" 15 | 16 | def test_token_count(self, tokenizer): 17 | assert tokenizer.token_count("foo bar huzzah") == 5 18 | 19 | def test_tokens_left(self, tokenizer): 20 | assert tokenizer.tokens_left("foo bar huzzah") == 4083 21 | 22 | def test_encoding(self, tokenizer): 23 | assert tokenizer.encoding.name == "cl100k_base" 24 | -------------------------------------------------------------------------------- /tests/unit/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/griptape-ai/griptape-flow/c5e88f53a1fa6088ae3a281b300c6b3478a8e68e/tests/unit/utils/__init__.py -------------------------------------------------------------------------------- /tests/unit/utils/test_conversation.py: -------------------------------------------------------------------------------- 1 | from tests.mocks.mock_driver import MockDriver 2 | from griptape.flow.memory import PipelineMemory 3 | from griptape.flow.steps import PromptStep 4 | from griptape.flow.structures import Pipeline 5 | from griptape.flow.utils import Conversation 6 | 7 | 8 | class TestConversation: 9 | def test_lines(self): 10 | pipeline = Pipeline(prompt_driver=MockDriver(), memory=PipelineMemory()) 11 | 12 | pipeline.add_steps( 13 | PromptStep("question 1") 14 | ) 15 | 16 | pipeline.run() 17 | pipeline.run() 18 | 19 | lines = Conversation(pipeline.memory).lines() 20 | 21 | assert lines[0] == "Q: question 1" 22 | assert lines[1] == "A: mock output" 23 | assert lines[2] == "Q: question 1" 24 | assert lines[3] == "A: mock output" 25 | 26 | def test_to_string(self): 27 | pipeline = Pipeline(prompt_driver=MockDriver(), memory=PipelineMemory()) 28 | 29 | pipeline.add_steps( 30 | PromptStep("question 1") 31 | ) 32 | 33 | pipeline.run() 34 | 35 | string = Conversation(pipeline.memory).to_string() 36 | 37 | assert string == "Q: question 1\nA: mock output" 38 | -------------------------------------------------------------------------------- 
/tests/unit/utils/test_tool_loader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from griptape.tools import Calculator, WebScraper 3 | from griptape.flow.utils import ToolLoader 4 | 5 | 6 | class TestToolLoader: 7 | def test_init(self): 8 | loader = ToolLoader(tools=[Calculator(), WebScraper()]) 9 | 10 | assert len(loader.tools) == 2 11 | 12 | # loading two tools with the same name should be rejected 13 | with pytest.raises(ValueError): 14 | ToolLoader(tools=[Calculator(), Calculator()]) 15 | 16 | def test_load_tool(self): 17 | loader = ToolLoader(tools=[Calculator(name="MyCalculator"), WebScraper()]) 18 | 19 | assert isinstance(loader.load_tool("MyCalculator"), Calculator) --------------------------------------------------------------------------------