├── .github ├── dependabot.yml └── workflows │ ├── docs.yml │ ├── smokeshow.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE.txt ├── Makefile ├── README.md ├── autochain ├── __init__.py ├── agent │ ├── __init__.py │ ├── base_agent.py │ ├── conversational_agent │ │ ├── __init__.py │ │ ├── conversational_agent.py │ │ ├── output_parser.py │ │ ├── prompt.py │ │ └── readme.md │ ├── message.py │ ├── openai_functions_agent │ │ ├── __init__.py │ │ ├── openai_functions_agent.py │ │ ├── output_parser.py │ │ ├── prompt.py │ │ └── readme.md │ ├── prompt_formatter.py │ └── structs.py ├── chain │ ├── __init__.py │ ├── base_chain.py │ ├── chain.py │ ├── constants.py │ └── langchain_wrapper_chain.py ├── errors.py ├── examples │ ├── __init__.py │ ├── get_weather_with_conversational_agent.py │ ├── get_weather_with_openai_function_agent.py │ ├── readme.md │ ├── upsale_goal_conversational_agent.py │ └── write_poem_with_conversational_agent.py ├── memory │ ├── __init__.py │ ├── base.py │ ├── buffer_memory.py │ ├── constants.py │ ├── long_term_memory.py │ └── redis_memory.py ├── models │ ├── __init__.py │ ├── ada_embedding.py │ ├── base.py │ ├── chat_openai.py │ ├── huggingface_text_generation_model.py │ └── readme.md ├── py.typed ├── tools │ ├── __init__.py │ ├── base.py │ ├── google_search │ │ ├── __init__.py │ │ ├── tool.py │ │ └── util.py │ ├── internal_search │ │ ├── __init__.py │ │ ├── base_search_tool.py │ │ ├── chromadb_tool.py │ │ ├── lancedb_tool.py │ │ └── pinecone_tool.py │ └── simple_handoff │ │ ├── __init__.py │ │ └── tool.py ├── utils.py └── workflows_evaluation │ ├── __init__.py │ ├── base_test.py │ ├── conversational_agent_eval │ ├── __init__.py │ ├── find_food_near_me_test.py │ └── generate_ads_test.py │ ├── langchain_eval │ ├── __init__.py │ ├── custom_langchain_output_parser.py │ ├── find_food_near_me_test.py │ ├── generate_ads_test.py │ ├── langchain_test_utils.py │ └── readme.md │ ├── openai_function_agent_eval │ ├── __init__.py │ ├── find_food_near_me_test.py │ ├── generate_ads_test.py │ └── get_weather_test.py │ └── test_utils.py ├── docs ├── agent.md ├── chain.md ├── components_overview.md ├── css │ ├── custom.css │ └── termynal.css ├── examples.md ├── img │ ├── autochain.drawio.png │ ├── autochain.drawio.svg │ ├── icon.png │ └── logo-margin │ │ └── logo.png ├── index.md ├── js │ ├── custom.js │ └── termynal.js ├── memory.md ├── robots.txt ├── tool.md └── workflow-evaluation.md ├── mkdocs.insiders.yml ├── mkdocs.yml ├── poetry.lock ├── pyproject.toml ├── test_utils ├── __init__.py └── pinecone_mocks.py └── tests ├── agent ├── test_conversational_agent.py └── test_openai_functions_agent.py ├── memory ├── test_buffer_memory.py ├── test_long_term_memory.py └── test_redis_memory.py ├── models ├── test_chat_openai.py └── test_openai_ada_encoder.py └── tools ├── test_base_tool.py ├── test_chromadb_tool.py ├── test_google_search.py ├── test_lancedb_tool.py ├── test_pinecone_tool.py └── test_simple_handoff.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | # GitHub Actions 4 | - package-ecosystem: "github-actions" 5 | directory: "/" 6 | schedule: 7 | interval: "daily" 8 | commit-message: 9 | prefix: ⬆ 10 | # Python 11 | - package-ecosystem: "pip" 12 | directory: "/" 13 | schedule: 14 | interval: "daily" 15 | commit-message: 16 | prefix: ⬆ 17 | -------------------------------------------------------------------------------- /.github/workflows/docs.yml: 
-------------------------------------------------------------------------------- 1 | name: Build Docs 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | types: 8 | - opened 9 | - synchronize 10 | jobs: 11 | build: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | # TODO: enable this once the README and index.md have the same content, with absolute links to a published website 16 | # - name: Ensure README has the same contents as docs/index.md 17 | # run: diff docs/index.md README.md 18 | - name: Set up Python 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: "3.10" 22 | - uses: actions/cache@v2 23 | id: cache 24 | with: 25 | path: ${{ env.pythonLocation }} 26 | key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }} 27 | - name: Install poetry 28 | if: steps.cache.outputs.cache-hit != 'true' 29 | shell: bash 30 | run: | 31 | python -m pip install --upgrade pip 32 | python -m pip install "poetry>=1.5.0" 33 | - name: Configure poetry 34 | run: python -m poetry config virtualenvs.create false 35 | - name: Install Dependencies 36 | if: steps.cache.outputs.cache-hit != 'true' 37 | run: python -m poetry install 38 | - name: Install Material for MkDocs Insiders 39 | run: python -m poetry run pip install git+https://${{ secrets.MK_DOCS_ACTIONS_TOKEN }}@github.com/squidfunk/mkdocs-material-insiders.git 40 | - uses: actions/cache@v2 41 | with: 42 | key: mkdocs-cards-${{ github.ref }} 43 | path: .cache 44 | - name: Build Docs with Insiders 45 | run: python -m poetry run mkdocs build --config-file mkdocs.insiders.yml 46 | - name: Publish to Cloudflare Pages 47 | uses: cloudflare/pages-action@v1 48 | with: 49 | apiToken: ${{ secrets.CF_API_TOKEN_PAGES }} 50 | accountId: ${{ secrets.CF_ACCOUNT_ID }} 51 | projectName: autochain 52 | directory: site/ 53 | gitHubToken: ${{ secrets.GITHUB_TOKEN }} 54 | wranglerVersion: '3' 55 | -------------------------------------------------------------------------------- /.github/workflows/smokeshow.yml: -------------------------------------------------------------------------------- 1 | name: Smokeshow 2 | 3 | on: 4 | workflow_run: 5 | workflows: [Test] 6 | types: [completed] 7 | 8 | permissions: 9 | statuses: write 10 | 11 | jobs: 12 | smokeshow: 13 | if: ${{ github.event.workflow_run.conclusion == 'success' }} 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.10' 20 | 21 | - run: pip install smokeshow 22 | 23 | - uses: dawidd6/action-download-artifact@v2.27.0 24 | with: 25 | github_token: ${{ secrets.GITHUB_TOKEN }} 26 | workflow: test.yml 27 | commit: ${{ github.event.workflow_run.head_sha }} 28 | 29 | - run: smokeshow upload coverage-html 30 | env: 31 | SMOKESHOW_GITHUB_STATUS_DESCRIPTION: Coverage {coverage-percentage} 32 | SMOKESHOW_GITHUB_COVERAGE_THRESHOLD: 100 33 | SMOKESHOW_GITHUB_CONTEXT: coverage 34 | SMOKESHOW_GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 35 | SMOKESHOW_GITHUB_PR_HEAD_SHA: ${{ github.event.workflow_run.head_sha }} 36 | SMOKESHOW_AUTH_KEY: ${{ secrets.SMOKESHOW_AUTH_KEY }} 37 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | types: 9 | - opened 10 | - synchronize 11 | 12 | jobs: 13 | test: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | matrix: 17 | python-version: 
["3.8", "3.9", "3.10", "3.11"] 18 | fail-fast: false 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python 23 | uses: actions/setup-python@v4 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - uses: actions/cache@v3 27 | id: python-cache 28 | with: 29 | path: ${{ env.pythonLocation }} 30 | key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('poetry.lock') }} 31 | 32 | - name: Install Poetry 33 | if: steps.python-cache.outputs.cache-hit != 'true' 34 | shell: bash 35 | run: | 36 | python -m pip install --upgrade pip 37 | python -m pip install "poetry>=1.5.0" 38 | 39 | - name: Configure Poetry 40 | shell: bash 41 | run: python -m poetry config virtualenvs.create false 42 | 43 | - name: Install Dependencies 44 | if: steps.python-cache.outputs.cache-hit != 'true' 45 | shell: bash 46 | run: python -m poetry install --all-extras 47 | 48 | # TODO: run lints, mypy, ruff 49 | - run: mkdir coverage 50 | - name: Test 51 | run: coverage run -m pytest tests 52 | env: 53 | COVERAGE_FILE: coverage/.coverage.${{ runner.os }}-py${{ matrix.python-version }} 54 | CONTEXT: ${{ runner.os }}-py${{ matrix.python-version }} 55 | - name: Store coverage files 56 | uses: actions/upload-artifact@v3 57 | with: 58 | name: coverage 59 | path: coverage 60 | coverage-combine: 61 | needs: [test] 62 | runs-on: ubuntu-latest 63 | 64 | steps: 65 | - uses: actions/checkout@v3 66 | 67 | - uses: actions/setup-python@v4 68 | with: 69 | python-version: '3.10' 70 | - name: Get coverage files 71 | uses: actions/download-artifact@v3 72 | with: 73 | name: coverage 74 | path: coverage 75 | 76 | - run: pip install coverage[toml] 77 | 78 | - run: ls -la coverage 79 | - run: coverage combine coverage 80 | - run: coverage report 81 | - run: coverage html --show-contexts --title "Coverage for ${{ github.sha }}" 82 | 83 | - name: Store coverage HTML 84 | uses: actions/upload-artifact@v3 85 | with: 86 | name: coverage-html 87 | path: htmlcov 88 | 89 | # https://github.com/marketplace/actions/alls-green#why 90 | check: # This job does nothing and is only used for the branch protection 91 | 92 | if: always() 93 | 94 | needs: 95 | - coverage-combine 96 | 97 | runs-on: ubuntu-latest 98 | 99 | steps: 100 | - name: Decide whether the needed jobs succeeded or failed 101 | uses: re-actors/alls-green@release/v1 102 | with: 103 | jobs: ${{ toJSON(needs) }} 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vs/ 2 | .vscode/ 3 | .idea/ 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | notebooks/ 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .envrc 111 | .venv 112 | .venvs 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | # macOS display setting files 138 | .DS_Store 139 | 140 | # Wandb directory 141 | wandb/ 142 | 143 | # asdf tool versions 144 | .tool-versions 145 | /.ruff_cache/ 146 | 147 | *.pkl 148 | *.bin 149 | 150 | # integration test artifacts 151 | data_map* 152 | \[('_type', 'fake'), ('stop', None)] 153 | 154 | .chroma/ 155 | test_results/ 156 | lancedb/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | default_language_version: 4 | python: python3.10 5 | repos: 6 | - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.4.0 8 | hooks: 9 | - id: check-added-large-files 10 | - id: check-toml 11 | - id: check-yaml 12 | args: 13 | - --unsafe 14 | - id: end-of-file-fixer 15 | - id: trailing-whitespace 16 | - repo: https://github.com/asottile/pyupgrade 17 | rev: v3.3.1 18 | hooks: 19 | - id: pyupgrade 20 | args: 21 | - --py3-plus 22 | - --keep-runtime-typing 23 | - repo: https://github.com/charliermarsh/ruff-pre-commit 24 | rev: v0.0.272 25 | hooks: 26 | - id: ruff 27 | args: 28 | - --fix 29 | - repo: https://github.com/psf/black 30 | rev: 23.3.0 31 | hooks: 32 | - id: black 33 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Harrison Chase 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this 
software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
PYTHON_FILES = ./autochain ./tests

TEST_ENV = . $(TEST_ENV_DIR)/bin/activate
TEST_ENV_DIR = $(CURDIR)/venv

.PHONY: black
black:
	$(TEST_ENV) && \
	black $(PYTHON_FILES)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AutoChain

Large language models (LLMs) have shown huge success in different text generation tasks and
enable developers to build generative agents based on objectives expressed in natural language.

However, most generative agents require heavy customization for specific purposes, and
supporting different use cases can sometimes be overwhelming with existing tools and
frameworks. As a result, it is still very challenging to build a custom generative agent.

In addition, evaluating such generative agents is usually done by manually trying out
different scenarios, which makes it a repetitive and expensive task.

AutoChain takes inspiration from LangChain and AutoGPT and aims to solve both problems by
providing a lightweight and extensible framework for developers to build their own agents
using LLMs with custom tools, and to [automatically evaluate](#workflow-evaluation) different
user scenarios with simulated conversations. Experienced users of LangChain will find AutoChain
easy to navigate, since the two share similar but simpler concepts.

The goal is to enable rapid iteration on generative agents, both by simplifying agent
customization and evaluation.

If you have any questions, please feel free to reach out to Yi Lu

## Features

- 🚀 lightweight and extensible generative agent pipeline
- 🔗 agents that can use different custom tools and
  support OpenAI [function calling](https://platform.openai.com/docs/guides/gpt/function-calling)
- 💾 simple memory tracking for conversation history and tools' outputs
- 🤖 automated agent multi-turn conversation evaluation with simulated conversations

## Setup

Quick install

```shell
pip install autochain
```

Or install from source after cloning this repository

```shell
cd autochain
pyenv virtualenv 3.10.11 venv
pyenv local venv

pip install .
```

Set `PYTHONPATH` and `OPENAI_API_KEY`

```shell
export OPENAI_API_KEY=
export PYTHONPATH=`pwd`
```

Run your first conversation with the agent interactively

```shell
python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py -i
```

## How does AutoChain simplify building agents?

AutoChain aims to provide a lightweight framework that simplifies the agent-building process in
a few ways compared to existing frameworks:

1. Easy prompt updates.
   Engineering and iterating over prompts is a crucial part of building a generative
   agent. AutoChain makes it very easy to update prompts and visualize prompt
   outputs. Run with the `-v` flag to print verbose prompts and outputs in the console,
   for example with the command shown right after this list.
2. Up to 2 layers of abstraction.
   As part of enabling rapid iteration, AutoChain removes most of the
   abstraction layers found in alternative frameworks.
3. Automated multi-turn evaluation.
   Evaluation is the most painful and least well-defined part of building generative agents.
   Updating the agent to perform better in one scenario often causes regressions in other use
   cases. AutoChain provides a testing framework to automatically evaluate the agent's ability
   under different user scenarios.
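For example, to chat with an agent interactively while also printing the full prompts (a
sketch assuming the `-i` and `-v` flags can be combined):

```shell
python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py -i -v
```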
## Example usage

If you have experience with LangChain, you already know 80% of the AutoChain interfaces.

AutoChain aims to make building custom generative agents as straightforward as possible, with
as few abstractions as possible.

The most basic example uses the default chain and `ConversationalAgent`:

```python
from autochain.chain.chain import Chain
from autochain.memory.buffer_memory import BufferMemory
from autochain.models.chat_openai import ChatOpenAI
from autochain.agent.conversational_agent.conversational_agent import ConversationalAgent

llm = ChatOpenAI(temperature=0)
memory = BufferMemory()
agent = ConversationalAgent.from_llm_and_tools(llm=llm)
chain = Chain(agent=agent, memory=memory)

print(chain.run("Write me a poem about AI")['message'])
```

Just like in LangChain, you can add a list of tools to the agent:

```python
tools = [
    Tool(
        name="Get weather",
        func=lambda *args, **kwargs: "Today is a sunny day",
        description="""This function returns the weather information"""
    )
]

memory = BufferMemory()
agent = ConversationalAgent.from_llm_and_tools(llm=llm, tools=tools)
chain = Chain(agent=agent, memory=memory)
print(chain.run("What is the weather today")['message'])
```

AutoChain also adds support
for [function calling](https://platform.openai.com/docs/guides/gpt/function-calling)
in OpenAI models. Behind the scenes, it turns the function spec into OpenAI format without
explicit instruction, so you can keep following the same `Tool` interface you are familiar with.

```python
llm = ChatOpenAI(temperature=0)
agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=llm, tools=tools)
```
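The function-calling agent drops into the same chain setup shown above. A minimal sketch
reusing the toy weather tool (the import path follows this repository's layout):

```python
from autochain.agent.openai_functions_agent.openai_functions_agent import OpenAIFunctionsAgent

llm = ChatOpenAI(temperature=0)
memory = BufferMemory()
agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=llm, tools=tools)
chain = Chain(agent=agent, memory=memory)

# The agent decides via OpenAI function calling whether to call the "Get weather" tool
print(chain.run("What is the weather today")['message'])
```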
See [more examples](./docs/examples.md) under `autochain/examples` and the [workflow
evaluation](./docs/workflow-evaluation.md) test cases, which can also be run interactively.

Read more about the detailed [components overview](./docs/components_overview.md)

## Workflow Evaluation

It is notoriously hard to evaluate generative agents in LangChain or AutoGPT. An agent's
behavior is nondeterministic and susceptible to small changes in the prompt or model. As such,
it is hard to know what effects an update to the agent will have on all relevant use cases.

The traditional path for evaluation is to run the agent through a large number of preset
queries and evaluate the generated responses. However, that approach is limited to single-turn
conversations, is generic rather than task-specific, and is expensive to verify.

To facilitate agent evaluation, AutoChain introduces the workflow evaluation framework. This
framework runs conversations between a generative agent and LLM-simulated test users. The test
users incorporate various user contexts and desired conversation outcomes, which enables easy
addition of test cases for new user scenarios and fast evaluation. The framework leverages LLMs
to evaluate whether a given multi-turn conversation has achieved the intended outcome.

Read more about our [evaluation strategy](./docs/workflow-evaluation.md).

### How to run workflow evaluations

You can either run your tests in interactive mode, or run the full suite of test cases at once.
`autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py` contains a few
example test cases.

To run all the cases defined in a test file:

```shell
python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py
```

To run your tests interactively, pass `-i`:

```shell
python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py -i
```

Looking for more details on how AutoChain works? See
our [components overview](./docs/components_overview.md)
--------------------------------------------------------------------------------
/autochain/__init__.py:
--------------------------------------------------------------------------------
__version__ = "0.0.1"
--------------------------------------------------------------------------------
/autochain/agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/agent/__init__.py
--------------------------------------------------------------------------------
/autochain/agent/base_agent.py:
--------------------------------------------------------------------------------
from __future__ import annotations

from abc import ABC, abstractmethod
from string import Template
from typing import Any, List, Optional, Sequence, Union

from autochain.agent.message import ChatMessageHistory
from autochain.agent.prompt_formatter import JSONPromptTemplate
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser
from autochain.models.base import BaseLanguageModel
from autochain.tools.base import Tool
from pydantic import BaseModel


class BaseAgent(BaseModel, ABC):
    output_parser: AgentOutputParser = None
    llm: BaseLanguageModel = None
    tools: Sequence[Tool] = []

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Sequence[Tool],
        prompt: str,
        output_parser: Optional[AgentOutputParser] = None,
        input_variables: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> BaseAgent:
        """Construct an agent from an LLM and tools."""

    def should_answer(
        self, should_answer_prompt_template: str = "", **kwargs
    ) -> Optional[AgentFinish]:
        """Determine if the agent should continue to answer user questions based on the latest
        user query"""
        return None
    @abstractmethod
    def plan(
        self,
        history: ChatMessageHistory,
        intermediate_steps: List[AgentAction],
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """
        Plan the next step: either take an action (AgentAction) or respond to the user
        (AgentFinish).
        Args:
            history: entire conversation history between user and agent including the latest query
            intermediate_steps: list of AgentAction that have been performed, with their outputs
            **kwargs: key-value pairs from the chain, which contain the query and other stored
                memories

        Returns:
            AgentAction or AgentFinish
        """

    def clarify_args_for_agent_action(
        self,
        agent_action: AgentAction,
        history: ChatMessageHistory,
        intermediate_steps: List[AgentAction],
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        """
        Ask a clarifying question if needed. When the agent is about to perform an action, this
        function can be used with a different prompt to ask a clarifying question about the
        inputs. Sometimes the planning response already contains the clarifying question, but we
        found it is more precise to use a separate prompt just for clarifying args.

        Args:
            agent_action: agent action about to be taken
            history: conversation history including the latest query
            intermediate_steps: list of agent actions taken so far
            **kwargs:

        Returns:
            Either a clarifying question (AgentFinish) or the planned action (AgentAction)
        """
        return agent_action

    def fix_action_input(
        self, tool: Tool, action: AgentAction, error: str
    ) -> Optional[AgentAction]:
        """If the tool failed due to an error, determine the fix for the inputs"""
        pass

    @staticmethod
    def get_prompt_template(
        prompt: str = "",
        input_variables: Optional[List[str]] = None,
    ) -> JSONPromptTemplate:
        """Create a prompt in the style of the zero-shot agent.

        Args:
            prompt: message to be injected between prefix and suffix.
            input_variables: list of input variables the final prompt will expect.

        Returns:
            A JSONPromptTemplate with the template assembled from the pieces here.
        """
        template = Template(prompt)

        if input_variables is None:
            input_variables = ["input", "agent_scratchpad"]
        return JSONPromptTemplate(template=template, input_variables=input_variables)

    def is_generation_confident(
        self,
        history: ChatMessageHistory,
        agent_output: Union[AgentAction, AgentFinish],
        min_confidence: int = 3,
    ) -> bool:
        """Check if the generation is confident enough to take action"""
        return True
--------------------------------------------------------------------------------
/autochain/agent/conversational_agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/agent/conversational_agent/__init__.py
--------------------------------------------------------------------------------
/autochain/agent/conversational_agent/output_parser.py:
--------------------------------------------------------------------------------
import json
from typing import Union

from autochain.agent.message import BaseMessage
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser


class ConvoJSONOutputParser(AgentOutputParser):
    def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
        response = self.load_json_output(message)

        action_name = response.get("tool", {}).get("name")
        action_args = response.get("tool", {}).get("args")

        if (
            "no" in response.get("thoughts", {}).get("need_use_tool", "").lower().strip()
            or not action_name
        ):
            output_message = response.get("response")
            if output_message:
                return AgentFinish(message=output_message, log=output_message)
            else:
                return AgentFinish(
                    message="Sorry, I don't understand", log=output_message
                )

        return AgentAction(
            tool=action_name,
            tool_input=action_args,
            model_response=response.get("response", ""),
        )

    def parse_clarification(
        self, message: BaseMessage, agent_action: AgentAction
    ) -> Union[AgentAction, AgentFinish]:
        response = self.load_json_output(message)

        has_arg_value = response.get("has_arg_value", "")
        clarifying_question = response.get("clarifying_question", "")

        if "no" in has_arg_value.lower() and clarifying_question:
            return AgentFinish(message=clarifying_question, log=clarifying_question)
        else:
            return agent_action
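# Parsing sketch: with the RESPONSE FORMAT defined in prompt.py below, a model
# reply such as this hypothetical one
#
#   {"thoughts": {"plan": "...", "need_use_tool": "Yes"},
#    "tool": {"name": "Get weather", "args": {"location": "SF"}},
#    "response": "Let me check the weather."}
#
# is parsed by ConvoJSONOutputParser.parse into
#   AgentAction(tool="Get weather", tool_input={"location": "SF"},
#               model_response="Let me check the weather.")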
--------------------------------------------------------------------------------
/autochain/agent/conversational_agent/prompt.py:
--------------------------------------------------------------------------------
from __future__ import annotations

PLANNING_PROMPT_TEMPLATE = """You are an assistant who tries to have a helpful conversation
with the user based on the previous conversation and previous tool outputs.
${prompt}
Use a tool when provided. If there is no tool available, respond with a helpful and polite
conversation. Find the next step without using the same tool with the same inputs.

Assistant has access to the following tools:
${tools}

Previous conversation so far:
${history}

Previous tools outputs:
${agent_scratchpad}

Please respond to the user question in JSON format as described below
RESPONSE FORMAT:
{
    "thoughts": {
        "plan": "Given previous tools outputs, what is the next step after the previous conversation",
        "need_use_tool": "answer with 'Yes' if requires more information not in previous tools outputs else 'No'"
    },
    "tool": {
        "name": "tool name, should be one of [${tool_names}] or empty if tool is not needed",
        "args": {
            "arg_name": "arg value from conversation history or tools outputs to run tool"
        }
    },
    "response": "response to user given tools outputs and conversations",
}

Ensure the response can be parsed by Python json.loads
"""

SHOULD_ANSWER_PROMPT_TEMPLATE = """You are a support agent.
Given the following conversation so far, has the assistant finished helping the user with all
the questions?
Answer with yes or no.

Conversation:
${history}
"""

FIX_TOOL_INPUT_PROMPT_TEMPLATE = """The tool has the following spec and provided inputs
Spec: "{tool_description}"
Inputs: "{inputs}"
Running this tool failed with the following error: "{error}"
What is the correct input in JSON format for this tool?
"""


CLARIFYING_QUESTION_PROMPT_TEMPLATE = """You are a support agent who is going to use the '${tool_name}' tool.
Check if you have enough information from the previous conversation and tool outputs to use the
tool based on the spec below.
"${tool_desp}"

Previous conversation so far:
${history}

Previous tools outputs:
${agent_scratchpad}

Please respond in JSON format as described below
RESPONSE FORMAT:
{
    "has_arg_value": "Do values for all input args for the '${tool_name}' tool exist? answer with Yes or No",
    "clarifying_question": "clarifying question to the user to ask for the missing information"
}
Ensure the response can be parsed by Python json.loads"""
--------------------------------------------------------------------------------
/autochain/agent/conversational_agent/readme.md:
--------------------------------------------------------------------------------
`ConversationalAgent` is a simple implementation of an agent that has a polite and helpful
conversation with users. If tools are available, it will also use them during the conversation.
If tools were available, it would also use tools during the 3 | conversation. -------------------------------------------------------------------------------- /autochain/agent/message.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from abc import abstractmethod 3 | from typing import Any, Dict, List 4 | 5 | from pydantic import BaseModel, Field 6 | 7 | 8 | class MessageType(enum.Enum): 9 | UserMessage = enum.auto() 10 | AIMessage = enum.auto() 11 | SystemMessage = enum.auto() 12 | FunctionMessage = enum.auto() 13 | 14 | 15 | class BaseMessage(BaseModel): 16 | """Message object.""" 17 | 18 | content: str 19 | additional_kwargs: dict = Field(default_factory=dict) 20 | 21 | @property 22 | @abstractmethod 23 | def type(self) -> str: 24 | """Type of the message, used for serialization.""" 25 | 26 | 27 | class UserMessage(BaseMessage): 28 | """Type of message that is spoken by the human.""" 29 | 30 | example: bool = False 31 | 32 | @property 33 | def type(self) -> str: 34 | """Type of the message, used for serialization.""" 35 | return "user" 36 | 37 | 38 | class AIMessage(BaseMessage): 39 | """Type of message that is spoken by the AI.""" 40 | 41 | example: bool = False 42 | function_call: Dict[str, Any] = {} 43 | 44 | @property 45 | def type(self) -> str: 46 | """Type of the message, used for serialization.""" 47 | return "ai" 48 | 49 | 50 | class SystemMessage(BaseMessage): 51 | """Type of message that is a system message.""" 52 | 53 | @property 54 | def type(self) -> str: 55 | """Type of the message, used for serialization.""" 56 | return "system" 57 | 58 | 59 | class FunctionMessage(BaseMessage): 60 | """Type of message that is a function message.""" 61 | 62 | name: str 63 | conversational_message: str = "" 64 | 65 | @property 66 | def type(self) -> str: 67 | """Type of the message, used for serialization.""" 68 | return "function" 69 | 70 | 71 | class ChatMessageHistory(BaseModel): 72 | messages: List[BaseMessage] = [] 73 | 74 | def save_message(self, message: str, message_type: MessageType, **kwargs): 75 | if message_type == MessageType.AIMessage: 76 | self.messages.append(AIMessage(content=message)) 77 | elif message_type == MessageType.UserMessage: 78 | self.messages.append(UserMessage(content=message)) 79 | elif message_type == MessageType.FunctionMessage: 80 | self.messages.append( 81 | FunctionMessage( 82 | content=message, 83 | name=kwargs["name"], 84 | conversational_message=kwargs["conversational_message"], 85 | ) 86 | ) 87 | elif message_type == MessageType.SystemMessage: 88 | self.messages.append(SystemMessage(content=message)) 89 | 90 | def format_message(self): 91 | string_messages = [] 92 | if len(self.messages) > 0: 93 | for m in self.messages: 94 | if isinstance(m, FunctionMessage): 95 | string_messages.append(f"Action: {m.conversational_message}") 96 | continue 97 | 98 | if isinstance(m, UserMessage): 99 | role = "User" 100 | elif isinstance(m, AIMessage): 101 | role = "Assistant" 102 | elif isinstance(m, SystemMessage): 103 | role = "System" 104 | else: 105 | continue 106 | string_messages.append(f"{role}: {m.content}") 107 | return "\n".join(string_messages) + "\n" 108 | return "" 109 | 110 | def get_latest_user_message(self) -> UserMessage: 111 | for message in reversed(self.messages): 112 | if isinstance(message, UserMessage): 113 | return message 114 | return UserMessage(content="n/a") 115 | 116 | def clear(self) -> None: 117 | self.messages = [] 118 | 
--------------------------------------------------------------------------------
/autochain/agent/openai_functions_agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/agent/openai_functions_agent/__init__.py
--------------------------------------------------------------------------------
/autochain/agent/openai_functions_agent/openai_functions_agent.py:
--------------------------------------------------------------------------------
from __future__ import annotations

import logging
from string import Template
from typing import Any, Dict, List, Optional, Union

from autochain.agent.base_agent import BaseAgent
from autochain.agent.message import ChatMessageHistory, SystemMessage, UserMessage
from autochain.agent.openai_functions_agent.output_parser import (
    OpenAIFunctionOutputParser,
)
from autochain.agent.openai_functions_agent.prompt import ESTIMATE_CONFIDENCE_PROMPT
from autochain.agent.structs import AgentAction, AgentFinish
from autochain.models.base import BaseLanguageModel, Generation
from autochain.tools.base import Tool
from autochain.utils import print_with_color
from colorama import Fore

logger = logging.getLogger(__name__)


class OpenAIFunctionsAgent(BaseAgent):
    """
    Agent that supports native function calling in OpenAI, which leverages function messages
    to determine which tool should be used.
    When no tool is selected, it responds just like a conversational agent.
    Tool descriptions are generated from the tool's type annotations.
    """

    llm: BaseLanguageModel = None
    allowed_tools: Dict[str, Tool] = {}
    tools: List[Tool] = []
    prompt: Optional[str] = None
    min_confidence: int = 3

    @classmethod
    def from_llm_and_tools(
        cls,
        llm: BaseLanguageModel,
        tools: Optional[List[Tool]] = None,
        output_parser: Optional[OpenAIFunctionOutputParser] = None,
        prompt: str = None,
        min_confidence: int = 3,
        **kwargs: Any,
    ) -> OpenAIFunctionsAgent:
        tools = tools or []

        allowed_tools = {tool.name: tool for tool in tools}
        _output_parser = output_parser or OpenAIFunctionOutputParser()
        return cls(
            llm=llm,
            allowed_tools=allowed_tools,
            output_parser=_output_parser,
            tools=tools,
            prompt=prompt,
            min_confidence=min_confidence,
            **kwargs,
        )

    def plan(
        self,
        history: ChatMessageHistory,
        intermediate_steps: List[AgentAction],
        retries: int = 2,
        **kwargs: Any,
    ) -> Union[AgentAction, AgentFinish]:
        agent_output: Union[AgentAction, AgentFinish, None] = None
        while retries > 0:
            print_with_color("Planning", Fore.LIGHTYELLOW_EX)

            final_messages = []
            if self.prompt:
                final_messages.append(SystemMessage(content=self.prompt))
            final_messages += history.messages

            logger.info(f"\nPlanning Input: {[m.content for m in final_messages]} \n")
            full_output: Generation = self.llm.generate(
                final_messages, self.tools
            ).generations[0]

            agent_output = self.output_parser.parse(full_output.message)
            print_with_color(
                f"Planning output: \nmessage content: {repr(full_output.message.content)}; "
                f"function_call: "
                f"{repr(full_output.message.function_call)}",
                Fore.YELLOW,
            )
            if isinstance(agent_output, AgentAction):
                print_with_color(
                    f"Plan to take action '{agent_output.tool}'", Fore.LIGHTYELLOW_EX
                )

            generation_is_confident = self.is_generation_confident(
                history=history,
                agent_output=agent_output,
                min_confidence=self.min_confidence,
            )
            if not generation_is_confident:
                retries -= 1
                print_with_color(
                    f"Generation is not confident, {retries} retries left",
                    Fore.LIGHTYELLOW_EX,
                )
                continue
            else:
                return agent_output

        # Out of retries: fall back to the last (low-confidence) output instead of
        # implicitly returning None
        return agent_output

    def is_generation_confident(
        self,
        history: ChatMessageHistory,
        agent_output: Union[AgentAction, AgentFinish],
        min_confidence: int = 3,
    ) -> bool:
        """
        Estimate the confidence of the generation
        Args:
            history: history of the conversation
            agent_output: the output from the agent
            min_confidence: minimum confidence score to be considered as confident
        """

        def _format_assistant_message(action_output: Union[AgentAction, AgentFinish]):
            if isinstance(action_output, AgentFinish):
                assistant_message = f"Assistant: {action_output.message}"
            elif isinstance(action_output, AgentAction):
                assistant_message = (
                    f"Action: {action_output.tool} with input: {action_output.tool_input}"
                )
            else:
                raise ValueError("Unsupported action for estimating confidence score")

            return assistant_message

        prompt = Template(ESTIMATE_CONFIDENCE_PROMPT).substitute(
            policy=self.prompt,
            conversation_history=history.format_message(),
            assistant_message=_format_assistant_message(agent_output),
        )
        logger.info(f"\nEstimate confidence prompt: {prompt} \n")

        message = UserMessage(content=prompt)

        full_output: Generation = self.llm.generate([message], self.tools).generations[
            0
        ]

        estimated_confidence = self.output_parser.parse_estimated_confidence(
            full_output.message
        )

        return estimated_confidence >= min_confidence
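# Usage sketch (hypothetical values): raising min_confidence makes the agent
# re-plan whenever the self-estimated confidence score of a generation falls
# below the threshold.
#
#   agent = OpenAIFunctionsAgent.from_llm_and_tools(
#       llm=ChatOpenAI(temperature=0), tools=tools, min_confidence=4
#   )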
--------------------------------------------------------------------------------
/autochain/agent/openai_functions_agent/output_parser.py:
--------------------------------------------------------------------------------
import json
import logging
import re
from typing import Union

from autochain.agent.message import AIMessage
from autochain.agent.structs import AgentAction, AgentFinish, AgentOutputParser

logger = logging.getLogger(__name__)


class OpenAIFunctionOutputParser(AgentOutputParser):
    def parse(self, message: AIMessage) -> Union[AgentAction, AgentFinish]:
        if message.function_call:
            action_name = message.function_call["name"]
            action_args = json.loads(message.function_call["arguments"])

            return AgentAction(
                tool=action_name,
                tool_input=action_args,
                model_response=message.content,
            )
        else:
            return AgentFinish(message=message.content, log=message.content)

    def parse_estimated_confidence(self, message: AIMessage) -> int:
        """Parse the estimated confidence from the message"""

        def find_first_integer(input_string):
            # Define a regular expression pattern to match integers
            pattern = re.compile(r"\d+")

            # Search for the first match in the input string
            match = pattern.search(input_string)

            # Check if a match is found
            if match:
                # Extract and return the matched integer
                return int(match.group())
            else:
                # Return 0 if no integer is found
                logger.info(f"\nCannot find confidence in message: {input_string}\n")
                return 0

        content = message.content.strip()

        return find_first_integer(content)
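# Parsing sketch: the confidence estimator expects a bare score in the reply;
# the hypothetical replies "Confidence: 4" and "4" both parse to 4 via
# parse_estimated_confidence, and 0 is returned when no integer is found.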
--------------------------------------------------------------------------------
/autochain/agent/openai_functions_agent/prompt.py:
--------------------------------------------------------------------------------
ESTIMATE_CONFIDENCE_PROMPT = """Given the system policy the assistant needs to strictly follow,
and the conversation history between the user and assistant so far,
"System policy: ${policy}
${conversation_history}"

How confident are you that the next step from the assistant should be the following:
"${assistant_message}"

Estimate the confidence from 1-5, 1 being the least confident and 5 being the most confident.
Confidence:
"""
--------------------------------------------------------------------------------
/autochain/agent/openai_functions_agent/readme.md:
--------------------------------------------------------------------------------
`OpenAIFunctionsAgent` is an agent that uses the OpenAI function calling introduced with the
0613 models. It uses a new type of message called `FunctionMessage` to convey function outputs,
and it specifies the available tools when calling the model for generation.
--------------------------------------------------------------------------------
/autochain/agent/prompt_formatter.py:
--------------------------------------------------------------------------------
from string import Template
from typing import Any, List

from pydantic import BaseModel, Extra

from autochain.agent.message import BaseMessage, UserMessage


class JSONPromptTemplate(BaseModel):
    """
    Format the prompt with a string Template and a dictionary of variables
    """

    template: Template
    """The prompt template."""

    input_variables: List[str]
    """A list of the names of the variables the prompt template expects."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    def format_prompt(self, **kwargs: Any) -> List[BaseMessage]:
        variables = {v: "" for v in self.input_variables}
        variables.update(kwargs)
        prompt = self.template.substitute(**variables)
        return [UserMessage(content=prompt)]
--------------------------------------------------------------------------------
/autochain/agent/structs.py:
--------------------------------------------------------------------------------
import json
from abc import abstractmethod
from typing import Any, Dict, List, Union

from autochain.agent.message import BaseMessage, UserMessage
from autochain.chain import constants
from autochain.models.base import Generation
from autochain.models.chat_openai import ChatOpenAI
from pydantic import BaseModel


class AgentAction(BaseModel):
    """Agent's action to take."""

    tool: str
    tool_input: Union[str, dict]
    """tool outputs"""
    tool_output: str = ""

    """log message for debugging"""
    log: str = ""

    """model response or """
    model_response: str = ""

    @property
    def response(self):
        """message to be stored in memory and shared with the next prompt"""
        if self.model_response and not self.tool_output:
            # share the model response or log message as output if the tool call fails
            return self.model_response
        return (
            f"Outputs from using tool '{self.tool}' for inputs {self.tool_input} "
            f"is '{self.tool_output}'\n"
        )


class AgentFinish(BaseModel):
    """Agent's return value."""

    message: str
    log: str
    intermediate_steps: List[AgentAction] = []

    def format_output(self) -> Dict[str, Any]:
        final_output = {
            "message": self.message,
            constants.INTERMEDIATE_STEPS: self.intermediate_steps,
        }
        return final_output


class AgentOutputParser(BaseModel):
    @staticmethod
    def load_json_output(message: BaseMessage) -> Dict[str, Any]:
        """If the message contains a json response, try to parse it into a dictionary"""
        text = message.content
        clean_text = ""

        try:
            clean_text = text[text.index("{") : text.rindex("}") + 1].strip()
            response = json.loads(clean_text)
        except Exception:
            # ask the model to repair malformed JSON before giving up
            llm = ChatOpenAI(temperature=0)
            messages = [
                UserMessage(
                    content=f"""Fix the following json into correct format
```json
{clean_text}
```
"""
                )
            ]
            full_output: Generation = llm.generate(messages).generations[0]
            response = json.loads(full_output.message.content)

        return response

    @abstractmethod
    def parse(self, message: BaseMessage) -> Union[AgentAction, AgentFinish]:
        """Parse text into agent action/finish."""

    def parse_clarification(
        self, message: BaseMessage, agent_action: AgentAction
    ) -> Union[AgentAction, AgentFinish]:
        """Parse clarification outputs"""
        return agent_action

    def parse_estimated_confidence(self, message: BaseMessage) -> int:
        """Parse the estimated confidence from the message"""
        return 1
--------------------------------------------------------------------------------
/autochain/chain/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/chain/__init__.py
--------------------------------------------------------------------------------
/autochain/chain/base_chain.py:
--------------------------------------------------------------------------------
"""
Base interface that all chains should implement
"""
import logging
import time
from abc import ABC, abstractmethod
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

from autochain.agent.base_agent import BaseAgent
from autochain.agent.message import ChatMessageHistory, MessageType
from autochain.agent.structs import AgentAction, AgentFinish
from autochain.chain import constants
from autochain.memory.base import BaseMemory
from autochain.tools.base import Tool
from pydantic import BaseModel

logger = logging.getLogger(__name__)


class BaseChain(BaseModel, ABC):
    """
    Base interface that all chains should implement.
    Chain will standardize inputs and outputs; the main entry point is the run function.
    """

    agent: Optional[BaseAgent] = None
    memory: Optional[BaseMemory] = None
    last_query: str = ""
    max_iterations: Optional[int] = 15
    max_execution_time: Optional[float] = None

    def prep_inputs(self, user_query: str) -> Dict[str, str]:
        """Load conversation history from memory and prep inputs."""
        inputs = {
            constants.CONVERSATION_HISTORY: ChatMessageHistory(),
            constants.INTERMEDIATE_STEPS: [],
        }
        if self.memory is not None:
            intermediate_steps = self.memory.load_memory(
                constants.INTERMEDIATE_STEPS, []
            )
            self.memory.save_conversation(
                message=user_query, message_type=MessageType.UserMessage
            )

            inputs[constants.CONVERSATION_HISTORY] = deepcopy(
                self.memory.load_conversation()
            )
            inputs[constants.INTERMEDIATE_STEPS] = deepcopy(intermediate_steps)

        return inputs

    def prep_output(
        self,
        inputs: Dict[str, str],
        output: AgentFinish,
        return_only_outputs: bool = False,
    ) -> Dict[str, Any]:
        """Save conversation into memory and prep outputs."""
        output_dict = output.format_output()
        if self.memory is not None:
            self.memory.save_conversation(
                message=output.message, message_type=MessageType.AIMessage
            )
            self.memory.save_memory(
                key=constants.INTERMEDIATE_STEPS, value=output.intermediate_steps
            )

        if return_only_outputs:
            return output_dict
        else:
            return {**inputs, **output_dict}

    def run(
        self,
        user_query: str,
        return_only_outputs: bool = False,
    ) -> Dict[str, Any]:
        """Wrapper around the _run function that formats the inputs and outputs

        Args:
            user_query: user query
            return_only_outputs: whether to return only outputs in the
                response. If True, only new keys generated by this chain will be
                returned. If False, both input keys and new keys generated by this
                chain will be returned. Defaults to False.
        """
        inputs = self.prep_inputs(user_query)
        logger.info(f"\n Input to agent: {inputs}")
        output = self._run(inputs)

        return self.prep_output(inputs, output, return_only_outputs)

    def _run(
        self,
        inputs: Dict[str, Any],
    ) -> AgentFinish:
        """
        Run inputs, including the user query and past conversation, through the agent and get
        the response back.
        Calls the take_next_step function to determine the next step after collecting all the
        inputs and memorized contents.
        """
        # Construct a mapping of tool name to tool for easy lookup
        name_to_tool_map = {tool.name: tool for tool in self.agent.tools}

        intermediate_steps: List[AgentAction] = inputs[constants.INTERMEDIATE_STEPS]

        # Let's start tracking the number of iterations and time elapsed
        iterations = 0
        time_elapsed = 0.0
        start_time = time.time()
        # We now enter the agent loop (until it returns something).
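        # Each iteration below: (1) check whether the agent should answer at all,
        # (2) otherwise plan the next step, (3) return early on AgentFinish, and
        # (4) record an AgentAction's tool output into memory before looping again.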
        while self._should_continue(iterations, time_elapsed):
            logger.info(f"\n Intermediate steps: {intermediate_steps}\n")
            next_step_output = self.should_answer(inputs=inputs)

            # if next_step_output is None, the agent should answer and take the next step
            if not next_step_output:
                next_step_output = self.take_next_step(
                    name_to_tool_map,
                    inputs,
                )

            if isinstance(next_step_output, AgentFinish):
                next_step_output.intermediate_steps = intermediate_steps
                return next_step_output

            # stores the action output in the conversation as a FunctionMessage, which can be
            # used by OpenAIFunctionsAgent
            if isinstance(next_step_output, AgentAction):
                self.memory.save_conversation(
                    message=str(next_step_output.tool_output),
                    name=next_step_output.tool,
                    conversational_message=f"{next_step_output.tool} with input: "
                    f"{next_step_output.tool_input}",
                    message_type=MessageType.FunctionMessage,
                )

            intermediate_steps.append(next_step_output)
            # update inputs
            inputs[constants.INTERMEDIATE_STEPS] = intermediate_steps
            inputs[constants.CONVERSATION_HISTORY] = self.memory.load_conversation()

            iterations += 1
            time_elapsed = time.time() - start_time

        # force termination when the loop should not continue
        output = AgentFinish(
            message="Agent stopped due to iteration limit or time limit.",
            log="",
            intermediate_steps=intermediate_steps,
        )
        return output

    @abstractmethod
    def take_next_step(
        self,
        name_to_tool_map: Dict[str, Tool],
        inputs: Dict[str, str],
    ) -> Union[AgentFinish, AgentAction]:
        """How the agent determines the next step after observing the inputs and intermediate
        steps"""

    def _should_continue(self, iterations: int, time_elapsed: float) -> bool:
        if self.max_iterations is not None and iterations >= self.max_iterations:
            return False
        if (
            self.max_execution_time is not None
            and time_elapsed >= self.max_execution_time
        ):
            return False

        return True

    def should_answer(self, inputs) -> Optional[AgentFinish]:
        """
        Let the agent determine whether it should continue to answer questions
        or whether this is the end of the conversation
        Args:
            inputs: dict containing the user query and other memorized contents

        Returns:
            None if the agent should answer;
            AgentFinish if it should NOT answer and respond to the user with a message
        """
        output = None
        # check if the agent should answer this query
        last_query = (
            inputs[constants.CONVERSATION_HISTORY].get_latest_user_message().content
        )
        if self.last_query != last_query:
            output = self.agent.should_answer(**inputs)
            self.last_query = last_query

        return output
--------------------------------------------------------------------------------
/autochain/chain/chain.py:
--------------------------------------------------------------------------------
"""Default implementation of Chain"""
import logging
from typing import Dict, Union

from autochain.agent.structs import AgentAction, AgentFinish
from autochain.chain.base_chain import BaseChain
from autochain.errors import ToolRunningError
from autochain.tools.base import Tool
from autochain.tools.simple_handoff.tool import HandOffToAgent

logger = logging.getLogger(__name__)
class Chain(BaseChain): 15 | """ 16 | Default chain with take_next_step implemented 17 | It handles a few common agent error cases, such as taking a repeated action with the same 18 | inputs, and decides whether the agent should continue the conversation 19 | """ 20 | 21 | return_intermediate_steps: bool = False 22 | handle_parsing_errors = True 23 | graceful_exit_tool: Tool = HandOffToAgent() 24 | 25 | def handle_repeated_action(self, agent_action: AgentAction) -> AgentFinish: 26 | print( 27 | f"Action taken before: {agent_action.tool}, " 28 | f"input: {agent_action.tool_input}" 29 | ) 30 | if agent_action.model_response: 31 | return AgentFinish( 32 | message=agent_action.response, 33 | log=f"Action taken before: {agent_action.tool}, " 34 | f"input: {agent_action.tool_input}", 35 | ) 36 | else: 37 | print("No response from agent. Gracefully exit due to repeated action") 38 | return AgentFinish( 39 | message=self.graceful_exit_tool.run(), 40 | log="Gracefully exit due to repeated action", 41 | ) 42 | 43 | def take_next_step( 44 | self, 45 | name_to_tool_map: Dict[str, Tool], 46 | inputs: Dict[str, str], 47 | ) -> (AgentFinish, AgentAction): 48 | """ 49 | How the agent determines the next step after observing the inputs and intermediate steps 50 | Args: 51 | name_to_tool_map: map of tool name to the actual tool object 52 | inputs: a dictionary of all inputs, such as user query, past conversation and 53 | tools outputs 54 | 55 | Returns: 56 | Either AgentFinish to respond to user or AgentAction to take the next action 57 | """ 58 | 59 | try: 60 | # Call the LLM to see what to do. 61 | output = self.agent.plan( 62 | **inputs, 63 | ) 64 | except Exception as e: 65 | if not self.handle_parsing_errors: 66 | raise e 67 | tool_output = f"Invalid or incomplete response due to {e}" 68 | print(tool_output) 69 | output = AgentFinish(message=self.graceful_exit_tool.run(), log=tool_output) 70 | return output 71 | 72 | if isinstance(output, AgentAction): 73 | output = self.agent.clarify_args_for_agent_action(output, **inputs) 74 | 75 | # If agent plans to respond with AgentFinish or there is a clarifying question, respond to 76 | # user by returning AgentFinish 77 | if isinstance(output, AgentFinish): 78 | return output 79 | 80 | if isinstance(output, AgentAction): 81 | tool_output = "" 82 | # Check if tool is supported 83 | if output.tool in name_to_tool_map: 84 | tool = name_to_tool_map[output.tool] 85 | 86 | # handle the case where the same action with the same input was taken before 87 | if output.tool_input == self.memory.load_memory(tool.name): 88 | return self.handle_repeated_action(output) 89 | 90 | self.memory.save_memory(tool.name, output.tool_input) 91 | # We then call the tool on the tool input to get a tool_output 92 | try: 93 | tool_output = tool.run(output.tool_input) 94 | except ToolRunningError as e: 95 | new_agent_action = self.agent.fix_action_input( 96 | tool, output, error=str(e) 97 | ) 98 | if ( 99 | new_agent_action 100 | and new_agent_action.tool_input != output.tool_input 101 | ): 102 | tool_output = tool.run(new_agent_action.tool_input) 103 | 104 | print( 105 | f"Took action '{tool.name}' with inputs '{output.tool_input}', " 106 | f"and the tool_output is {tool_output}" 107 | ) 108 | else: 109 | tool_output = f"Tool {output.tool} is not supported" 110 | 111 | output.tool_output = tool_output 112 | return output 113 | else: 114 | raise ValueError(f"Unsupported action: {type(output)}") 115 | -------------------------------------------------------------------------------- /autochain/chain/constants.py: 
-------------------------------------------------------------------------------- 1 | QUERY = "query" 2 | CONVERSATION_HISTORY = "history" 3 | 4 | # stores agent actions, which also contain their outputs 5 | INTERMEDIATE_STEPS = "intermediate_steps" 6 | -------------------------------------------------------------------------------- /autochain/chain/langchain_wrapper_chain.py: -------------------------------------------------------------------------------- 1 | """Wrapper around LangChain to follow the same chain interface""" 2 | from typing import Any, Dict, List, Optional 3 | 4 | from langchain.chains.base import Chain as LangChain 5 | from langchain.schema import BaseMemory 6 | 7 | from autochain.agent.structs import AgentFinish, AgentAction 8 | from autochain.chain.base_chain import BaseChain 9 | from autochain.tools.base import Tool 10 | 11 | 12 | class LangChainWrapperChain(BaseChain): 13 | """ 14 | Wrapper chain instantiated from a LangChain Chain object to match the AutoChain interface 15 | """ 16 | 17 | langchain: LangChain = None 18 | memory: Optional[BaseMemory] = None 19 | 20 | def __init__(self, langchain: LangChain, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.langchain = langchain 23 | self.memory = self.langchain.memory 24 | 25 | def run( 26 | self, 27 | user_query: str, 28 | **kwargs, 29 | ) -> Dict[str, Any]: 30 | response_msg: str = self.langchain.run(user_query) 31 | agent_finish = AgentFinish(message=response_msg, log="") 32 | return agent_finish.format_output() 33 | 34 | def take_next_step( 35 | self, 36 | name_to_tool_map: Dict[str, Tool], 37 | inputs: Dict[str, str], 38 | ) -> (AgentFinish, AgentAction): 39 | pass 40 | -------------------------------------------------------------------------------- /autochain/errors.py: -------------------------------------------------------------------------------- 1 | class OutputParserException(Exception): 2 | """Exception that output parsers should raise to signify a parsing error. 3 | 4 | This exists to differentiate parsing errors from other code or execution errors 5 | that also may arise inside the output parser. OutputParserExceptions will be 6 | available to catch and handle in ways to fix the parsing error, while other 7 | errors will be raised. 
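For example, an agent's output parser can raise this when the model reply does not match the expected format, so the caller can retry or exit gracefully.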
8 | """ 9 | 10 | pass 11 | 12 | 13 | class ToolRunningError(Exception): 14 | """Exception when tool fails to run""" 15 | 16 | def __init__(self, message): 17 | self.message = message 18 | -------------------------------------------------------------------------------- /autochain/examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/examples/__init__.py -------------------------------------------------------------------------------- /autochain/examples/get_weather_with_conversational_agent.py: -------------------------------------------------------------------------------- 1 | from autochain.chain.chain import Chain 2 | from autochain.memory.buffer_memory import BufferMemory 3 | from autochain.models.chat_openai import ChatOpenAI 4 | from autochain.tools.base import Tool 5 | from autochain.agent.conversational_agent.conversational_agent import ( 6 | ConversationalAgent, 7 | ) 8 | from autochain.utils import get_args 9 | 10 | # Set logging level 11 | _ = get_args() 12 | 13 | llm = ChatOpenAI(temperature=0) 14 | tools = [ 15 | Tool( 16 | name="Get weather", 17 | func=lambda *args, **kwargs: "Today is a sunny day", 18 | description="""This function returns the weather information""", 19 | ) 20 | ] 21 | 22 | memory = BufferMemory() 23 | agent = ConversationalAgent.from_llm_and_tools(llm=llm, tools=tools) 24 | chain = Chain(agent=agent, memory=memory) 25 | 26 | user_query = "what is the weather today" 27 | print(f">> User: {user_query}") 28 | print(f">> Assistant: {chain.run(user_query)['message']}") 29 | next_user_query = "Boston" 30 | print(f">> User: {next_user_query}") 31 | print(f">> Assistant: {chain.run(next_user_query)['message']}") 32 | -------------------------------------------------------------------------------- /autochain/examples/get_weather_with_openai_function_agent.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from autochain.agent.openai_functions_agent.openai_functions_agent import ( 4 | OpenAIFunctionsAgent, 5 | ) 6 | from autochain.chain.chain import Chain 7 | from autochain.memory.buffer_memory import BufferMemory 8 | from autochain.models.chat_openai import ChatOpenAI 9 | from autochain.tools.base import Tool 10 | from autochain.utils import get_args 11 | 12 | # Set logging level 13 | _ = get_args() 14 | 15 | 16 | def get_current_weather(location: str, unit: str = "fahrenheit"): 17 | """Get the current weather in a given location""" 18 | weather_info = { 19 | "location": location, 20 | "temperature": "72", 21 | "unit": unit, 22 | "forecast": ["sunny", "windy"], 23 | } 24 | return json.dumps(weather_info) 25 | 26 | 27 | tools = [ 28 | Tool( 29 | name="get_current_weather", 30 | func=get_current_weather, 31 | description="""Get the current weather in a given location""", 32 | ) 33 | ] 34 | 35 | memory = BufferMemory() 36 | llm = ChatOpenAI(temperature=0) 37 | agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=llm, tools=tools) 38 | chain = Chain(agent=agent, memory=memory) 39 | 40 | # example 41 | user_query = "What's the weather today?" 
42 | print(f">> User: {user_query}") 43 | print(f">> Assistant: {chain.run(user_query)['message']}") 44 | next_user_query = "Boston" 45 | print(f">> User: {next_user_query}") 46 | print(f">> Assistant: {chain.run(next_user_query)['message']}") 47 | -------------------------------------------------------------------------------- /autochain/examples/readme.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | We provide a few simple examples for illustration purpose for how to use `ConversationalAgent` 3 | and `OpenAIFunctionsAgent` with and without tools. 4 | However, examples are not limited here. Workflow evaluation contains more example usages of 5 | those agents for different purposes, such as generating advertisement by interactive with user. 6 | Please checkout examples under [workflow evaluation](../workflows_evaluation) for more details. -------------------------------------------------------------------------------- /autochain/examples/upsale_goal_conversational_agent.py: -------------------------------------------------------------------------------- 1 | from autochain.agent.conversational_agent.conversational_agent import ( 2 | ConversationalAgent, 3 | ) 4 | from autochain.chain.chain import Chain 5 | from autochain.memory.buffer_memory import BufferMemory 6 | from autochain.models.chat_openai import ChatOpenAI 7 | from autochain.utils import get_args 8 | 9 | # Set logging level 10 | _ = get_args() 11 | 12 | prompt = ( 13 | "You are a sales agent who wants to up sale all customer inquire. Your goal is " 14 | "introducing more expensive options to user" 15 | ) 16 | 17 | llm = ChatOpenAI(temperature=0) 18 | memory = BufferMemory() 19 | agent = ConversationalAgent.from_llm_and_tools(llm=llm, prompt=prompt) 20 | chain = Chain(agent=agent, memory=memory) 21 | 22 | user_query = "How much is this basic rice cooker" 23 | print(f">>> User: {user_query}") 24 | print( 25 | f""">>> Assistant: 26 | {chain.run("How much is this basic rice cooker")["message"]} 27 | """ 28 | ) 29 | -------------------------------------------------------------------------------- /autochain/examples/write_poem_with_conversational_agent.py: -------------------------------------------------------------------------------- 1 | from autochain.agent.conversational_agent.conversational_agent import ( 2 | ConversationalAgent, 3 | ) 4 | from autochain.chain.chain import Chain 5 | from autochain.memory.buffer_memory import BufferMemory 6 | from autochain.models.chat_openai import ChatOpenAI 7 | from autochain.utils import get_args 8 | 9 | # Set logging level 10 | _ = get_args() 11 | 12 | llm = ChatOpenAI(temperature=0) 13 | memory = BufferMemory() 14 | agent = ConversationalAgent.from_llm_and_tools(llm=llm) 15 | chain = Chain(agent=agent, memory=memory) 16 | 17 | user_query = "Write me a poem about AI" 18 | print(f">> User: {user_query}") 19 | print( 20 | f""">>> Assistant: 21 | {chain.run(user_query)["message"]} 22 | """ 23 | ) 24 | -------------------------------------------------------------------------------- /autochain/memory/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/memory/__init__.py -------------------------------------------------------------------------------- /autochain/memory/base.py: -------------------------------------------------------------------------------- 1 | """Common memory schema 
object.""" 2 | from __future__ import annotations 3 | 4 | from abc import ABC, abstractmethod 5 | from typing import ( 6 | Any, 7 | Dict, 8 | Optional, 9 | Union, 10 | ) 11 | 12 | from pydantic import BaseModel 13 | 14 | from autochain.agent.message import ChatMessageHistory, MessageType 15 | 16 | 17 | class BaseMemory(BaseModel, ABC): 18 | """Base interface for memory in chains.""" 19 | 20 | @abstractmethod 21 | def load_memory( 22 | self, key: Union[str, None] = None, default: Optional[Any] = None, **kwargs: Any 23 | ) -> Any: 24 | """Return key-value pairs given the text input to the chain.""" 25 | 26 | @abstractmethod 27 | def load_conversation(self, **kwargs) -> ChatMessageHistory: 28 | """Return key-value pairs given the text input to the chain.""" 29 | 30 | @abstractmethod 31 | def save_memory(self, key: str, value: Any) -> None: 32 | """Save the context of this model run to memory.""" 33 | 34 | @abstractmethod 35 | def save_conversation( 36 | self, message: str, message_type: MessageType, **kwargs 37 | ) -> None: 38 | """Save the context of this model run to memory.""" 39 | 40 | @abstractmethod 41 | def clear(self) -> None: 42 | """Clear memory contents.""" 43 | -------------------------------------------------------------------------------- /autochain/memory/buffer_memory.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from autochain.agent.message import ChatMessageHistory, MessageType 4 | from autochain.memory.base import BaseMemory 5 | 6 | 7 | class BufferMemory(BaseMemory): 8 | """Buffer for storing conversation memory and an in-memory kv store.""" 9 | 10 | conversation_history = ChatMessageHistory() 11 | kv_memory = {} 12 | 13 | def load_memory( 14 | self, key: Optional[str] = None, default: Optional[Any] = None, **kwargs 15 | ) -> Any: 16 | """Return history buffer by key or all memories.""" 17 | if not key: 18 | return self.kv_memory 19 | 20 | return self.kv_memory.get(key, default) 21 | 22 | def load_conversation(self, **kwargs) -> ChatMessageHistory: 23 | """Return history buffer and format it into a conversational string format.""" 24 | return self.conversation_history 25 | 26 | def save_memory(self, key: str, value: Any) -> None: 27 | self.kv_memory[key] = value 28 | 29 | def save_conversation( 30 | self, message: str, message_type: MessageType, **kwargs 31 | ) -> None: 32 | """Save context from this conversation to buffer.""" 33 | self.conversation_history.save_message( 34 | message=message, message_type=message_type, **kwargs 35 | ) 36 | 37 | def clear(self) -> None: 38 | """Clear memory contents.""" 39 | self.conversation_history.clear() 40 | self.kv_memory = {} 41 | -------------------------------------------------------------------------------- /autochain/memory/constants.py: -------------------------------------------------------------------------------- 1 | ONE_HOUR = 3600 2 | -------------------------------------------------------------------------------- /autochain/memory/long_term_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is an example implementation of long term memory and retrieve using query 3 | It contains three memory stores. 
4 | conversation_history stores all the messages, including FunctionMessage, between the user and the assistant, 5 | long_term_memory stores a collection of ChromaDoc (and could be modified to use another vector DB) 6 | kv_memory stores anything else as kv pairs 7 | """ 8 | from typing import Any, Optional 9 | 10 | from autochain.agent.message import ChatMessageHistory, MessageType 11 | from autochain.memory.base import BaseMemory 12 | from autochain.tools.internal_search.base_search_tool import BaseSearchTool 13 | from autochain.tools.internal_search.chromadb_tool import ChromaDBSearch, ChromaDoc 14 | from autochain.tools.internal_search.pinecone_tool import PineconeSearch, PineconeDoc 15 | from autochain.tools.internal_search.lancedb_tool import LanceDBSeach, LanceDBDoc 16 | 17 | SEARCH_PROVIDERS = (ChromaDBSearch, PineconeSearch, LanceDBSeach) 18 | SEARCH_DOC_TYPES = (ChromaDoc, PineconeDoc, LanceDBDoc) 19 | 20 | class LongTermMemory(BaseMemory): 21 | """Memory with a conversation buffer, an in-memory kv store, and a long-term vector store.""" 22 | 23 | conversation_history = ChatMessageHistory() 24 | kv_memory = {} 25 | long_term_memory: BaseSearchTool = None 26 | 27 | class Config: 28 | keep_untouched = SEARCH_PROVIDERS 29 | 30 | def load_memory( 31 | self, 32 | key: Optional[str] = None, 33 | default: Optional[Any] = None, 34 | top_k: int = 1, 35 | **kwargs 36 | ) -> Any: 37 | """Return the value by key, falling back to long-term memory retrieval.""" 38 | if key in self.kv_memory: 39 | return self.kv_memory[key] 40 | 41 | # else try to retrieve from long term memory 42 | result = self.long_term_memory.run({"query": key, "top_k": top_k}) 43 | return result or default 44 | 45 | def load_conversation(self, **kwargs) -> ChatMessageHistory: 46 | """Return the conversation history buffer.""" 47 | return self.conversation_history 48 | 49 | def save_memory(self, key: str, value: Any) -> None: 50 | if ( 51 | isinstance(value, list) 52 | and len(value) > 0 53 | and (isinstance(value[0], SEARCH_DOC_TYPES)) 54 | ): 55 | self.long_term_memory.add_docs(docs=value) 56 | elif key: 57 | self.kv_memory[key] = value 58 | 59 | def save_conversation( 60 | self, message: str, message_type: MessageType, **kwargs 61 | ) -> None: 62 | """Save context from this conversation to buffer.""" 63 | self.conversation_history.save_message( 64 | message=message, message_type=message_type, **kwargs 65 | ) 66 | 67 | def clear(self) -> None: 68 | """Clear memory contents.""" 69 | self.conversation_history.clear() 70 | self.long_term_memory.clear_index() 71 | self.kv_memory = {} 72 | -------------------------------------------------------------------------------- /autochain/memory/redis_memory.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from typing import Any, Optional, Dict 3 | 4 | from autochain.agent.message import ( 5 | ChatMessageHistory, 6 | MessageType, 7 | BaseMessage, 8 | AIMessage, 9 | UserMessage, 10 | FunctionMessage, 11 | SystemMessage, 12 | ) 13 | from autochain.memory.base import BaseMemory 14 | from redis import Redis 15 | 16 | from autochain.memory.constants import ONE_HOUR 17 | 18 | 19 | class RedisMemory(BaseMemory): 20 | """Store conversation info in redis memory.""" 21 | 22 | expire_time: int = ONE_HOUR 23 | redis_key_prefix: str 24 | redis_client: Redis 25 | 26 | class Config: 27 | """Configuration for this pydantic object.""" 28 | 29 | arbitrary_types_allowed = True 30 | 31 | def load_memory( 32 | self, key: Optional[str] = None, default: Optional[Any] = None, 
**kwargs 33 | ) -> Any: 34 | """Get the key's corresponding value from redis.""" 35 | if not key.startswith(self.redis_key_prefix): 36 | key = self.redis_key_prefix + f":{key}" 37 | pickled = self.redis_client.get(key) 38 | if not pickled: 39 | return default 40 | return pickle.loads(pickled) 41 | 42 | def load_conversation(self, **kwargs: Dict[str, Any]) -> ChatMessageHistory: 43 | """Return chat message history.""" 44 | redis_key = self.redis_key_prefix + f":{ChatMessageHistory.__name__}" 45 | return ChatMessageHistory(messages=self.load_memory(redis_key, [])) 46 | 47 | def save_memory(self, key: str, value: Any) -> None: 48 | """Save the key value pair to redis.""" 49 | if not key.startswith(self.redis_key_prefix): 50 | key = self.redis_key_prefix + f":{key}" 51 | pickled = pickle.dumps(value) 52 | self.redis_client.set(key, pickled, ex=self.expire_time) 53 | 54 | def save_conversation( 55 | self, message: str, message_type: MessageType, **kwargs 56 | ) -> None: 57 | """Save context from this conversation to redis.""" 58 | redis_key = self.redis_key_prefix + f":{ChatMessageHistory.__name__}" 59 | pickled = self.redis_client.get(redis_key) 60 | if pickled: 61 | messages: list[BaseMessage] = pickle.loads(pickled) 62 | else: 63 | messages = [] 64 | if message_type == MessageType.AIMessage: 65 | messages.append(AIMessage(content=message)) 66 | elif message_type == MessageType.UserMessage: 67 | messages.append(UserMessage(content=message)) 68 | elif message_type == MessageType.FunctionMessage: 69 | messages.append(FunctionMessage(content=message, name=kwargs["name"])) 70 | elif message_type == MessageType.SystemMessage: 71 | messages.append(SystemMessage(content=message)) 72 | else: 73 | raise ValueError(f"Unsupported message type: {message_type}") 74 | self.save_memory(redis_key, messages) 75 | 76 | def clear(self) -> None: 77 | """Clear redis memory.""" 78 | for key in self.redis_client.keys(f"{self.redis_key_prefix}:*"): 79 | self.redis_client.delete(key) 80 | -------------------------------------------------------------------------------- /autochain/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/models/__init__.py -------------------------------------------------------------------------------- /autochain/models/ada_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Optional, Any, Dict 3 | 4 | from pydantic import root_validator 5 | 6 | from autochain.tools.base import Tool 7 | 8 | from autochain.agent.message import BaseMessage 9 | 10 | from autochain.models.base import BaseLanguageModel, LLMResult, EmbeddingResult 11 | 12 | 13 | class OpenAIAdaEncoder(BaseLanguageModel): 14 | """ 15 | Text encoder using OpenAI Model 16 | """ 17 | 18 | client: Any #: :meta private: 19 | model_name: str = "text-embedding-ada-002" 20 | 21 | @root_validator() 22 | def validate_environment(cls, values: Dict) -> Dict: 23 | """Validate that api key and python package exists in environment.""" 24 | openai_api_key = os.environ["OPENAI_API_KEY"] 25 | try: 26 | import openai 27 | 28 | except ImportError: 29 | raise ValueError( 30 | "Could not import openai python package. " 31 | "Please install it with `pip install openai`." 
32 | ) 33 | openai.api_key = openai_api_key 34 | try: 35 | values["client"] = openai.Embedding 36 | except AttributeError: 37 | raise ValueError( 38 | "`openai` has no `Embedding` attribute, this is likely " 39 | "due to an old version of the openai package. Try upgrading it " 40 | "with `pip install --upgrade openai`." 41 | ) 42 | return values 43 | 44 | def generate( 45 | self, 46 | messages: List[BaseMessage], 47 | functions: Optional[List[Tool]] = None, 48 | stop: Optional[List[str]] = None, 49 | ) -> LLMResult: 50 | pass 51 | 52 | def encode(self, texts: List[str]) -> EmbeddingResult: 53 | def _format_response(texts, resp) -> EmbeddingResult: 54 | embeddings = [d.get("embedding") for d in resp.get("data", [])] 55 | return EmbeddingResult(texts=texts, embeddings=embeddings) 56 | 57 | params: Dict[str, Any] = { 58 | "model": self.model_name, 59 | "input": texts, 60 | **self._default_params, 61 | } 62 | 63 | response = self.generate_with_retry(**params) 64 | return _format_response(texts=texts, resp=response) 65 | -------------------------------------------------------------------------------- /autochain/models/base.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | from abc import abstractmethod 5 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 6 | 7 | from pydantic import Extra, Field, BaseModel 8 | from tenacity import ( 9 | before_sleep_log, 10 | retry, 11 | retry_if_exception_type, 12 | stop_after_attempt, 13 | wait_exponential, 14 | ) 15 | 16 | from autochain.agent.message import BaseMessage 17 | from autochain.tools.base import Tool 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Generation(BaseModel): 23 | """Output of a single generation.""" 24 | 25 | message: BaseMessage 26 | """Generated text output.""" 27 | 28 | generation_info: Optional[Dict[str, Any]] = None 29 | """Raw generation info response from the provider""" 30 | """May include things like reason for finishing (e.g. in OpenAI)""" 31 | # TODO: add log probs 32 | 33 | 34 | class LLMResult(BaseModel): 35 | """Class that contains all relevant information for an LLM Result.""" 36 | 37 | generations: List[Generation] 38 | """List of the things generated; each input could have 39 | multiple generations.""" 40 | llm_output: Optional[dict] = None 41 | """For arbitrary LLM provider specific output.""" 42 | 43 | 44 | class EmbeddingResult(BaseModel): 45 | texts: List[str] 46 | embeddings: List[List[float]] 47 | 48 | 49 | class BaseLanguageModel(BaseModel): 50 | """Base wrapper around chat large language models. 51 | 52 | To use, you should have the ``openai`` python package installed, and the 53 | environment variable ``OPENAI_API_KEY`` set with your API key. 54 | 55 | Any parameters that are valid to be passed to the openai.create call can be passed 56 | in, even if not explicitly saved on this class. 57 | 58 | Example: 59 | .. 
code-block:: python 60 | 61 | from autochain.models import ChatOpenAI 62 | openai = ChatOpenAI(model_name="gpt-3.5-turbo") 63 | """ 64 | 65 | client: Any #: :meta private: 66 | model_name: str = "gpt-3.5-turbo" 67 | """Model name to use.""" 68 | temperature: float = 0.7 69 | """What sampling temperature to use.""" 70 | model_kwargs: Dict[str, Any] = Field(default_factory=dict) 71 | """Holds any model parameters valid for `create` call not explicitly specified.""" 72 | openai_api_key: Optional[str] = None 73 | openai_organization: Optional[str] = None 74 | request_timeout: Optional[Union[float, Tuple[float, float]]] = None 75 | """Timeout for requests to OpenAI completion API. Default is 600 seconds.""" 76 | max_retries: int = 6 77 | """Maximum number of retries to make when generating.""" 78 | n: int = 1 79 | """Number of chat completions to generate for each prompt.""" 80 | max_tokens: Optional[int] = None 81 | """Maximum number of tokens to generate.""" 82 | 83 | class Config: 84 | """Configuration for this pydantic object.""" 85 | 86 | extra = Extra.ignore 87 | 88 | @property 89 | def _default_params(self) -> Dict[str, Any]: 90 | """Get the default parameters for calling OpenAI API.""" 91 | return { 92 | "model": self.model_name, 93 | "request_timeout": self.request_timeout, 94 | "max_tokens": self.max_tokens, 95 | "n": self.n, 96 | "temperature": self.temperature, 97 | **self.model_kwargs, 98 | } 99 | 100 | def _create_retry_decorator(self) -> Callable[[Any], Any]: 101 | import openai 102 | 103 | min_seconds = 1 104 | max_seconds = 60 105 | # Wait 2^x seconds between retries, starting at 106 | # 1 second and capping at 60 seconds between attempts 107 | return retry( 108 | reraise=True, 109 | stop=stop_after_attempt(self.max_retries), 110 | wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), 111 | retry=( 112 | retry_if_exception_type(openai.error.Timeout) 113 | | retry_if_exception_type(openai.error.APIError) 114 | | retry_if_exception_type(openai.error.APIConnectionError) 115 | | retry_if_exception_type(openai.error.RateLimitError) 116 | | retry_if_exception_type(openai.error.ServiceUnavailableError) 117 | ), 118 | before_sleep=before_sleep_log(logger, logging.WARNING), 119 | ) 120 | 121 | def generate_with_retry(self, **kwargs: Any) -> Any: 122 | """Use tenacity to retry the completion call.""" 123 | retry_decorator = self._create_retry_decorator() 124 | 125 | @retry_decorator 126 | def _generate_with_retry(**kwargs: Any) -> Any: 127 | return self.client.create(**kwargs) 128 | 129 | return _generate_with_retry(**kwargs) 130 | 131 | @abstractmethod 132 | def generate( 133 | self, 134 | messages: List[BaseMessage], 135 | functions: Optional[List[Tool]] = None, 136 | stop: Optional[List[str]] = None, 137 | ) -> LLMResult: 138 | pass 139 | 140 | def encode(self, texts: List[str]) -> EmbeddingResult: 141 | pass 142 | -------------------------------------------------------------------------------- /autochain/models/huggingface_text_generation_model.py: -------------------------------------------------------------------------------- 1 | """Hugging Face text-generation model wrapper.""" 2 | from __future__ import annotations 3 | 4 | import logging 5 | from typing import Any, Dict, List, Optional, Tuple, Union 6 | 7 | import torch 8 | from pydantic import Field 9 | from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer 10 | 11 | from autochain.agent.message import ( 12 | BaseMessage, 13 | AIMessage, 14 | ) 15 | from autochain.models.base import ( 16 | LLMResult, 17 
| Generation, 18 | BaseLanguageModel, 19 | ) 20 | from autochain.tools.base import Tool 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | class HuggingFaceTextGenerationModel(BaseLanguageModel): 26 | """Hugging Face model that supports the text-generation task 27 | 28 | Example: 29 | .. code-block:: python 30 | 31 | from autochain.models.huggingface_text_generation_model import HuggingFaceTextGenerationModel 32 | llm = HuggingFaceTextGenerationModel(model_name="mosaicml/mpt-7b", model_kwargs={"trust_remote_code":True}) 33 | """ 34 | 35 | model_name: str = "gpt2" 36 | """Model name to use. GPT2 is only for demonstration purposes. It does not work well for task 37 | planning""" 38 | temperature: float = 0 39 | """What sampling temperature to use.""" 40 | model_kwargs: Dict[str, Any] = Field(default_factory=dict) 41 | """Holds any model parameters valid for `create` call not explicitly specified.""" 42 | 43 | tokenizer_kwargs: Dict[str, Any] = Field(default_factory=dict) 44 | """Holds any model parameters valid for creating tokenizer.""" 45 | 46 | request_timeout: Optional[Union[float, Tuple[float, float]]] = None 47 | """Timeout for generation requests. Default is 600 seconds.""" 48 | max_retries: int = 6 49 | # TODO: support streaming 50 | # """Maximum number of retries to make when generating.""" 51 | # streaming: bool = False 52 | # """Whether to stream the results or not.""" 53 | # n: int = 1 54 | # """Number of chat completions to generate for each prompt.""" 55 | max_tokens: Optional[int] = 512 56 | """Maximum number of tokens to generate.""" 57 | 58 | default_stop_tokens: List[str] = ["."] 59 | """The model will generate tokens up to max_tokens, so it is good to have a 60 | default stop token""" 61 | 62 | model: Optional[AutoModelForCausalLM] 63 | tokenizer: Optional[AutoTokenizer] 64 | 65 | class Config: 66 | """Configuration for this pydantic object.""" 67 | 68 | arbitrary_types_allowed = True 69 | 70 | def __init__(self, *args, **kwargs): 71 | super().__init__(*args, **kwargs) 72 | if torch.cuda.is_available(): 73 | self.model_kwargs["device_map"] = "auto" 74 | 75 | self.tokenizer = AutoTokenizer.from_pretrained( 76 | self.model_name, **self.tokenizer_kwargs 77 | ) 78 | 79 | def generate( 80 | self, 81 | messages: List[BaseMessage], 82 | functions: Optional[List[Tool]] = None, 83 | stop: Optional[List[str]] = None, 84 | ) -> LLMResult: 85 | generator = pipeline( 86 | task="text-generation", 87 | model=self.model_name, 88 | tokenizer=self.tokenizer, 89 | max_new_tokens=self.max_tokens, 90 | temperature=self.temperature, 91 | **self.model_kwargs, 92 | ) 93 | 94 | prompt = self._construct_prompt_from_message(messages) 95 | generation = generator(prompt, do_sample=False) 96 | 97 | return self._create_llm_result(generation=generation, prompt=prompt, stop=stop) 98 | 99 | @staticmethod 100 | def _construct_prompt_from_message(messages: List[BaseMessage]): 101 | prompt = "" 102 | for msg in messages: 103 | prompt += msg.content 104 | return prompt 105 | 106 | @staticmethod 107 | def _enforce_stop_tokens(text: str, stop: List[str]) -> str: 108 | """Cut off the text as soon as any stop words occur.""" 109 | first_index = len(text) 110 | for s in stop: 111 | if s in text: 112 | first_index = min(text.index(s), first_index) 113 | 114 | return text[:first_index].strip() 115 | 116 | def _create_llm_result( 117 | self, generation: List[Dict[str, Any]], prompt: str, stop: List[str] 118 | ) -> LLMResult: 119 | text = generation[0]["generated_text"][len(prompt) :] 120 | 
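# The text-generation pipeline echoes the prompt, so only the continuation is kept above; it is re-tokenized below to enforce the max_tokens budget.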
if self.max_tokens: 121 | token_ids = self.tokenizer.encode(text)[: self.max_tokens] 122 | text = self.tokenizer.decode(token_ids) 123 | 124 | # it is better to have a default stop token so the model does not always generate to the max 125 | # sequence length 126 | stop = stop or self.default_stop_tokens 127 | text = self._enforce_stop_tokens(text=text, stop=stop) 128 | 129 | return LLMResult( 130 | generations=[Generation(message=AIMessage(content=text))], 131 | llm_output={ 132 | "token_usage": len(text.split()), 133 | "model_name": self.model_name, 134 | }, 135 | ) 136 | -------------------------------------------------------------------------------- /autochain/models/readme.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | ## Hugging Face text-generation model 4 | To use an open-sourced model from Hugging Face, AutoChain introduces 5 | `HuggingFaceTextGenerationModel` to support this use case. 6 | The following requirements must be installed: 7 | ```shell 8 | transformers 9 | torch 10 | accelerate 11 | ``` 12 | 13 | Example usage: 14 | ```python 15 | from autochain.models.huggingface_text_generation_model import ( 16 | HuggingFaceTextGenerationModel, 17 | ) 18 | from autochain.agent.conversational_agent.conversational_agent import ( 19 | ConversationalAgent, 20 | ) 21 | 22 | llm = HuggingFaceTextGenerationModel(model_name="mosaicml/mpt-7b", 23 | model_kwargs={"trust_remote_code":True}) 24 | agent = ConversationalAgent.from_llm_and_tools(llm=llm) 25 | ``` 26 | > Task planning may be too challenging a task for a "small" model -------------------------------------------------------------------------------- /autochain/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/py.typed -------------------------------------------------------------------------------- /autochain/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/tools/__init__.py -------------------------------------------------------------------------------- /autochain/tools/base.py: -------------------------------------------------------------------------------- 1 | """Base implementation for tools or skills.""" 2 | from __future__ import annotations 3 | 4 | import inspect 5 | from abc import ABC 6 | from typing import Any, Callable, Dict, Optional, Tuple, Type, Union 7 | 8 | from autochain.errors import ToolRunningError 9 | from pydantic import ( 10 | BaseModel, 11 | root_validator, 12 | ) 13 | 14 | 15 | class Tool(ABC, BaseModel): 16 | """Interface AutoChain tools must implement.""" 17 | 18 | name: Optional[str] = None 19 | """The unique name of the tool that clearly communicates its purpose. 20 | If not provided, it will be named after the func name. 21 | The more descriptive it is, the easier it is for the model to call the right tool 22 | """ 23 | 24 | description: str 25 | """Used to tell the model how/when/why to use the tool. 26 | You can provide few-shot examples as a part of the description. 
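A concrete, specific description makes it easier for the model to pick the right tool.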
27 | """ 28 | 29 | arg_description: Optional[Dict[str, Any]] = None 30 | """Dictionary of arg name and description when using OpenAIFunctionsAgent to provide 31 | additional argument information""" 32 | 33 | args_schema: Optional[Type[BaseModel]] = None 34 | """Pydantic model class to validate and parse the tool's input arguments.""" 35 | 36 | func: Union[Callable[..., str], None] = None 37 | 38 | @root_validator() 39 | def validate_environment(cls, values: Dict) -> Dict: 40 | """Validate that api key and python package exists in environment.""" 41 | func = values.get("func") 42 | if func and not values.get("name"): 43 | values["name"] = values["func"].__name__ 44 | 45 | # check if all args from arg_description exist in func args 46 | if values.get("arg_description") and func: 47 | inspection = inspect.getfullargspec(func) 48 | override_args = set(values["arg_description"].keys()) 49 | args = set(inspection.args) 50 | override_without_args = override_args - args 51 | if len(override_without_args) > 0: 52 | raise ValueError( 53 | f"Provide arg description for not existed args: {override_without_args}" 54 | ) 55 | 56 | return values 57 | 58 | def _parse_input( 59 | self, 60 | tool_input: Union[str, Dict], 61 | ) -> Union[str, Dict[str, Any]]: 62 | """Convert tool input to pydantic model.""" 63 | input_args = self.args_schema 64 | if isinstance(tool_input, str): 65 | if input_args is not None: 66 | key_ = next(iter(input_args.__fields__.keys())) 67 | input_args.validate({key_: tool_input}) 68 | return tool_input 69 | else: 70 | if input_args is not None: 71 | result = input_args.parse_obj(tool_input) 72 | return {k: v for k, v in result.dict().items() if k in tool_input} 73 | return tool_input 74 | 75 | def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]: 76 | # For backwards compatibility, if run_input is a string, 77 | # pass as a positional argument. 
78 | if isinstance(tool_input, str): 79 | return (tool_input,), {} 80 | else: 81 | return (), tool_input 82 | 83 | def _run( 84 | self, 85 | *args: Any, 86 | **kwargs: Any, 87 | ) -> str: 88 | return self.func(*args, **kwargs) 89 | 90 | def run( 91 | self, 92 | tool_input: Union[str, Dict] = "", 93 | **kwargs: Any, 94 | ) -> str: 95 | """Run the tool.""" 96 | try: 97 | parsed_input = self._parse_input(tool_input) 98 | except ValueError as e: 99 | # return exception as tool output 100 | raise ToolRunningError(message=f"Tool input args value Error: {e}") from e 101 | 102 | try: 103 | tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input) 104 | tool_output = self._run(*tool_args, **tool_kwargs) 105 | except (Exception, KeyboardInterrupt) as e: 106 | raise ToolRunningError( 107 | message=f"Failed to run tool {self.name} due to {e}" 108 | ) from e 109 | 110 | return tool_output 111 | -------------------------------------------------------------------------------- /autochain/tools/google_search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/tools/google_search/__init__.py -------------------------------------------------------------------------------- /autochain/tools/google_search/tool.py: -------------------------------------------------------------------------------- 1 | from autochain.tools.base import Tool 2 | from autochain.tools.google_search.util import GoogleSearchAPIWrapper 3 | 4 | 5 | class GoogleSearchTool(Tool): 6 | """Tool that has capability to query the Google Search API and get back json.""" 7 | 8 | name = "Google Search Results JSON" 9 | description = ( 10 | "A wrapper around Google Search. " 11 | "Useful for when you need to answer questions about current events. " 12 | "Input should be a search query. Output is a JSON array of the query results" 13 | ) 14 | num_results: int = 4 15 | api_wrapper: GoogleSearchAPIWrapper 16 | 17 | def _run( 18 | self, 19 | query: str, 20 | ) -> str: 21 | """Use the tool.""" 22 | return str(self.api_wrapper.results(query, self.num_results)) 23 | -------------------------------------------------------------------------------- /autochain/tools/google_search/util.py: -------------------------------------------------------------------------------- 1 | """Util that calls Google Search.""" 2 | from typing import Any, Dict, List, Optional 3 | 4 | from pydantic import BaseModel, Extra, root_validator 5 | 6 | from autochain.utils import get_from_dict_or_env 7 | 8 | 9 | class GoogleSearchAPIWrapper(BaseModel): 10 | """Wrapper for Google Search API. 11 | 12 | Adapted from: Instructions adapted from https://stackoverflow.com/questions/ 13 | 37083058/ 14 | programmatically-searching-google-in-python-using-custom-search 15 | 16 | TODO: DOCS for using it 17 | 1. Install google-api-python-client 18 | - If you don't already have a Google account, sign up. 19 | - If you have never created a Google APIs Console project, 20 | read the Managing Projects page and create a project in the Google API Console. 21 | - Install the library using pip install google-api-python-client 22 | The current version of the library is 2.70.0 at this time 23 | 24 | 2. To create an API key: 25 | - Navigate to the APIs & Services→Credentials panel in Cloud Console. 26 | - Select Create credentials, then select API key from the drop-down menu. 27 | - The API key created dialog box displays your newly created key. 
28 | - You now have an API_KEY 29 | 30 | 3. Setup Custom Search Engine so you can search the entire web 31 | - Create a custom search engine in this link. 32 | - In Sites to search, add any valid URL (i.e. www.stackoverflow.com). 33 | - That’s all you have to fill up, the rest doesn’t matter. 34 | In the left-side menu, click Edit search engine → {your search engine name} 35 | → Setup Set Search the entire web to ON. Remove the URL you added from 36 | the list of Sites to search. 37 | - Under Search engine ID you’ll find the search-engine-ID. 38 | 39 | 4. Enable the Custom Search API 40 | - Navigate to the APIs & Services→Dashboard panel in Cloud Console. 41 | - Click Enable APIs and Services. 42 | - Search for Custom Search API and click on it. 43 | - Click Enable. 44 | URL for it: https://console.cloud.google.com/apis/library/customsearch.googleapis 45 | .com 46 | """ 47 | 48 | search_engine: Any #: :meta private: 49 | google_api_key: Optional[str] = None 50 | google_cse_id: Optional[str] = None 51 | k: int = 10 52 | siterestrict: bool = False 53 | 54 | class Config: 55 | """Configuration for this pydantic object.""" 56 | 57 | extra = Extra.forbid 58 | 59 | def _google_search_results(self, search_term: str, **kwargs: Any) -> List[dict]: 60 | cse = self.search_engine.cse() 61 | if self.siterestrict: 62 | cse = cse.siterestrict() 63 | res = cse.list(q=search_term, cx=self.google_cse_id, **kwargs).execute() 64 | return res.get("items", []) 65 | 66 | @root_validator() 67 | def validate_environment(cls, values: Dict) -> Dict: 68 | """Validate that api key and python package exists in environment.""" 69 | google_api_key = get_from_dict_or_env( 70 | values, "google_api_key", "GOOGLE_API_KEY" 71 | ) 72 | values["google_api_key"] = google_api_key 73 | 74 | google_cse_id = get_from_dict_or_env(values, "google_cse_id", "GOOGLE_CSE_ID") 75 | values["google_cse_id"] = google_cse_id 76 | 77 | try: 78 | from googleapiclient.discovery import build 79 | 80 | except ImportError: 81 | raise ImportError( 82 | "google-api-python-client is not installed. " 83 | "Please install it with `pip install google-api-python-client`" 84 | ) 85 | 86 | service = build("customsearch", "v1", developerKey=google_api_key) 87 | values["search_engine"] = service 88 | 89 | return values 90 | 91 | def run(self, query: str) -> str: 92 | """Run query through GoogleSearch and parse result.""" 93 | snippets = [] 94 | results = self._google_search_results(query, num=self.k) 95 | if len(results) == 0: 96 | return "No good Google Search Result was found" 97 | for result in results: 98 | if "snippet" in result: 99 | snippets.append(result["snippet"]) 100 | 101 | return " ".join(snippets) 102 | 103 | def results(self, query: str, num_results: int) -> List[Dict]: 104 | """Run query through GoogleSearch and return metadata. 105 | 106 | Args: 107 | query: The query to search for. 108 | num_results: The number of results to return. 109 | 110 | Returns: 111 | A list of dictionaries with the following keys: 112 | snippet - The description of the result. 113 | title - The title of the result. 114 | link - The link to the result. 
115 | """ 116 | metadata_results = [] 117 | results = self._google_search_results(query, num=num_results) 118 | if len(results) == 0: 119 | return [{"Result": "No good Google Search Result was found"}] 120 | for result in results: 121 | metadata_result = { 122 | "title": result["title"], 123 | "link": result["link"], 124 | } 125 | if "snippet" in result: 126 | metadata_result["snippet"] = result["snippet"] 127 | metadata_results.append(metadata_result) 128 | 129 | return metadata_results 130 | -------------------------------------------------------------------------------- /autochain/tools/internal_search/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/tools/internal_search/__init__.py -------------------------------------------------------------------------------- /autochain/tools/internal_search/base_search_tool.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from typing import Any, List 3 | 4 | from pydantic import Extra, BaseModel 5 | 6 | 7 | class BaseSearchTool(BaseModel): 8 | class Config: 9 | """Configuration for this pydantic object.""" 10 | 11 | extra = Extra.forbid 12 | arbitrary_types_allowed = True 13 | 14 | @abstractmethod 15 | def _run( 16 | self, 17 | query: str, 18 | top_k: int = 2, 19 | *args: Any, 20 | **kwargs: Any, 21 | ) -> str: 22 | raise NotImplementedError 23 | 24 | @abstractmethod 25 | def add_docs(self, docs: List[Any], **kwargs): 26 | raise NotImplementedError 27 | 28 | @abstractmethod 29 | def clear_index(self): 30 | raise NotImplementedError 31 | -------------------------------------------------------------------------------- /autochain/tools/internal_search/chromadb_tool.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from dataclasses import dataclass, field 3 | from typing import List, Any, Dict, Optional 4 | 5 | import chromadb 6 | from chromadb.api import QueryResult 7 | from pydantic import Extra 8 | 9 | from autochain.tools.base import Tool 10 | from autochain.tools.internal_search.base_search_tool import BaseSearchTool 11 | 12 | 13 | @dataclass 14 | class ChromaDoc: 15 | doc: str 16 | metadata: Dict[str, Any] 17 | id: str = field(default_factory=lambda: str(uuid.uuid1())) 18 | 19 | 20 | class ChromaDBSearch(Tool, BaseSearchTool): 21 | """ 22 | Use ChromaDB as internal search tool 23 | """ 24 | 25 | collection_name: str = "index" 26 | collection: Optional[Any] = None 27 | 28 | class Config: 29 | """Configuration for this pydantic object.""" 30 | 31 | extra = Extra.forbid 32 | arbitrary_types_allowed = True 33 | 34 | def __init__(self, docs: List[ChromaDoc], **kwargs): 35 | super().__init__(**kwargs) 36 | client = chromadb.Client() 37 | 38 | collection = client.create_collection(self.collection_name) 39 | self.collection = collection 40 | 41 | # Add docs to the collection. Can also update and delete. Row-based API coming soon! 
42 | self.add_docs(docs=docs) 43 | 44 | def _run( 45 | self, 46 | query: str, 47 | top_k: int = 2, 48 | *args: Any, 49 | **kwargs: Any, 50 | ) -> str: 51 | def _format_output(query_result: QueryResult) -> str: 52 | """Only return the document since they are likely to be passed to prompt""" 53 | documents = query_result.get("documents", []) 54 | if len(documents) == 0: 55 | return "" 56 | 57 | docs = documents[0] 58 | return "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(docs)]) 59 | 60 | result = self.collection.query( 61 | query_texts=[query], 62 | n_results=top_k, 63 | ) 64 | return _format_output(result) 65 | 66 | def add_docs(self, docs: List[ChromaDoc], **kwargs): 67 | """Add a list of documents to collection""" 68 | if docs: 69 | self.collection.add( 70 | documents=[d.doc for d in docs], 71 | # we embed for you, or bring your own 72 | metadatas=[d.metadata for d in docs], 73 | # filter on arbitrary metadata! 74 | ids=[d.id for d in docs], # must be unique for each doc 75 | ) 76 | 77 | def clear_index(self): 78 | self.collection.delete() 79 | -------------------------------------------------------------------------------- /autochain/tools/internal_search/lancedb_tool.py: -------------------------------------------------------------------------------- 1 | from typing import List, Any 2 | from dataclasses import dataclass 3 | 4 | import lancedb 5 | import pandas as pd 6 | 7 | from autochain.tools.base import Tool 8 | from autochain.models.base import BaseLanguageModel 9 | from autochain.tools.internal_search.base_search_tool import BaseSearchTool 10 | 11 | @dataclass 12 | class LanceDBDoc: 13 | doc: str 14 | vector: List[float] = None 15 | 16 | class LanceDBSeach(Tool, BaseSearchTool): 17 | """ 18 | Use LanceDB as the internal search tool 19 | 20 | LanceDB is a vector database that supports vector search. 21 | 22 | Args: 23 | uri: the uri of the database. Default to "lancedb" 24 | table_name: the name of the table. Default to "table" 25 | metric: the metric used for vector search. Default to "cosine" 26 | encoder: the encoder used to encode the documents. Default to None 27 | docs: the documents to be indexed. 
Default to None 28 | """ 29 | class Config: 30 | """Configuration for this pydantic object.""" 31 | 32 | arbitrary_types_allowed = True 33 | 34 | docs: List[LanceDBDoc] 35 | uri: str = "lancedb" 36 | table_name: str = "table" 37 | metric: str = "cosine" 38 | encoder: BaseLanguageModel = None 39 | db: lancedb.db.DBConnection = None 40 | table: lancedb.table.Table = None 41 | def __init__(self, **kwargs) -> None: 42 | super().__init__(**kwargs) 43 | self.db = lancedb.connect(self.uri) 44 | if self.docs: 45 | self._encode_docs(self.docs) 46 | self._create_table(self.docs) 47 | 48 | def _create_table(self, docs: List[LanceDBDoc]) -> None: 49 | self.table = self.db.create_table(self.table_name, self._docs_to_dataframe(docs), mode="overwrite") 50 | 51 | def _encode_docs(self, docs: List[LanceDBDoc]) -> None: 52 | for doc in docs: 53 | if not doc.vector: 54 | if not self.encoder: 55 | raise ValueError("Encoder is not provided for encoding docs") 56 | doc.vector = self.encoder.encode([doc.doc]).embeddings[0] 57 | 58 | def _docs_to_dataframe(self, docs: List[LanceDBDoc]) -> pd.DataFrame: 59 | return pd.DataFrame( 60 | [ 61 | {"doc": doc.doc, "vector": doc.vector} 62 | for doc in docs 63 | ] 64 | ) 65 | 66 | def _run( 67 | self, 68 | query: str, 69 | top_k: int = 2, 70 | *args: Any, 71 | **kwargs: Any, 72 | ) -> str: 73 | if self.table is None: 74 | return "" 75 | 76 | embeddings = self.encoder.encode([query]).embeddings[0] 77 | result = self.table.search(embeddings).limit(top_k).to_df()["doc"].to_list() 78 | 79 | return "\n".join([f"Doc {i}: {doc}" for i, doc in enumerate(result)]) 80 | 81 | def add_docs(self, docs: List[LanceDBDoc], **kwargs): 82 | if not len(docs): 83 | return 84 | 85 | self._encode_docs(docs) 86 | self.table.add(self._docs_to_dataframe(docs)) if self.table else self._create_table(docs) 87 | 88 | def clear_index(self): 89 | if self.table_name in self.db.table_names(): 90 | self.db.drop_table(self.table_name) 91 | self.table = None 92 | -------------------------------------------------------------------------------- /autochain/tools/internal_search/pinecone_tool.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from dataclasses import dataclass, field 3 | from typing import List, Any, Optional, Dict 4 | 5 | import pinecone 6 | from pinecone import QueryResponse 7 | 8 | from autochain.models.base import BaseLanguageModel 9 | from autochain.tools.base import Tool 10 | from autochain.tools.internal_search.base_search_tool import BaseSearchTool 11 | 12 | 13 | @dataclass 14 | class PineconeDoc: 15 | doc: str 16 | vector: List[float] = None 17 | id: str = field(default_factory=lambda: str(uuid.uuid1())) 18 | 19 | 20 | class PineconeSearch(Tool, BaseSearchTool): 21 | """ 22 | Use Pinecone as the internal search tool 23 | """ 24 | 25 | docs: List[PineconeDoc] 26 | index_name: str = "index" 27 | index: Optional[Any] = None 28 | dimension: int = 8 29 | metric: str = "euclidean" 30 | encoder: BaseLanguageModel = None # such as OpenAIAdaEncoder 31 | id2doc: Dict[str, str] = {} 32 | 33 | def __init__(self, **kwargs): 34 | super().__init__(**kwargs) 35 | pinecone.create_index( 36 | self.index_name, dimension=self.dimension, metric=self.metric 37 | ) 38 | self.index = pinecone.Index(self.index_name) 39 | 40 | self.add_docs(self.docs) 41 | 42 | def _encode(self, doc: PineconeDoc) -> None: 43 | if not doc.vector and self.encoder: 44 | # TODO: encoder over batches 45 | doc.vector = self.encoder.encode([doc.doc]).embeddings[0] 46 | 47 | def 
_run( 48 | self, 49 | query: str, 50 | top_k: int = 2, 51 | include_values: bool = False, 52 | *args: Any, 53 | **kwargs: Any, 54 | ) -> str: 55 | def _format_output(query_response: QueryResponse) -> str: 56 | """Only return the document since they are likely to be passed to prompt""" 57 | documents = query_response.get("matches", []) 58 | if len(documents) == 0: 59 | return "" 60 | 61 | return "\n".join( 62 | [ 63 | f"Doc {i}: {self.id2doc[doc['id']]}" 64 | for i, doc in enumerate(documents) 65 | ] 66 | ) 67 | 68 | encoding = self.encoder.encode([query]).embeddings[0] 69 | 70 | response: QueryResponse = self.index.query( 71 | vector=encoding, top_k=top_k, include_values=include_values 72 | ) 73 | return _format_output(response) 74 | 75 | def add_docs(self, docs: List[PineconeDoc], **kwargs): 76 | if not len(docs): 77 | return 78 | 79 | for doc in docs: 80 | self._encode(doc) 81 | self.id2doc[doc.id] = doc.doc 82 | 83 | self.index.upsert([(d.id, d.vector) for d in docs]) 84 | 85 | def clear_index(self): 86 | pinecone.delete_index(self.index_name) 87 | pinecone.create_index( 88 | self.index_name, dimension=self.dimension, metric=self.metric 89 | ) 90 | self.index = pinecone.Index(self.index_name) 91 | -------------------------------------------------------------------------------- /autochain/tools/simple_handoff/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/tools/simple_handoff/__init__.py -------------------------------------------------------------------------------- /autochain/tools/simple_handoff/tool.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from autochain.tools.base import Tool 4 | 5 | 6 | class HandOffToAgent(Tool): 7 | name = "Hand off" 8 | description = "Hand off to a human agent" 9 | handoff_msg = "Let me hand you off to an agent now" 10 | 11 | def _run(self, *args: Any, **kwargs: Any) -> str: 12 | return self.handoff_msg 13 | -------------------------------------------------------------------------------- /autochain/utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | from typing import Optional, Dict, Any 5 | 6 | from colorama import Style 7 | 8 | 9 | def print_with_color(text: str, color: str): 10 | if os.getenv("NO_COLOR"): 11 | print(text) 12 | else: 13 | print(color + text) 14 | print(Style.RESET_ALL) 15 | 16 | 17 | def get_from_dict_or_env( 18 | data: Dict[str, Any], key: str, env_key: str, default: Optional[str] = None 19 | ) -> str: 20 | """Get a value from a dictionary or an environment variable.""" 21 | if key in data and data[key]: 22 | return data[key] 23 | else: 24 | return get_from_env(key, env_key, default=default) 25 | 26 | 27 | def get_from_env(key: str, env_key: str, default: Optional[str] = None) -> str: 28 | """Get a value from a dictionary or an environment variable.""" 29 | if env_key in os.environ and os.environ[env_key]: 30 | return os.environ[env_key] 31 | elif default is not None: 32 | return default 33 | else: 34 | raise ValueError( 35 | f"Did not find {key}, please add an environment variable" 36 | f" `{env_key}` which contains it, or pass" 37 | f" `{key}` as a named parameter." 
38 | ) 39 | 40 | 41 | def get_args(): 42 | """Add arguments for running tests interactively or setting verbosity""" 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument( 45 | "--interact", 46 | "-i", 47 | action="store_true", 48 | help="if run interactively", 49 | ) 50 | parser.add_argument( 51 | "--verbose", 52 | "-v", 53 | action="store_true", 54 | help="if show detailed contents, such as intermediate results and prompts", 55 | ) 56 | args = parser.parse_args() 57 | if args.verbose: 58 | logging.basicConfig(level=logging.INFO) 59 | 60 | return args 61 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/workflows_evaluation/__init__.py -------------------------------------------------------------------------------- /autochain/workflows_evaluation/base_test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from abc import ABC, abstractmethod 3 | from dataclasses import dataclass 4 | from typing import List, Tuple, Any, Dict 5 | 6 | import pandas as pd 7 | from colorama import Fore 8 | 9 | from autochain.agent.message import UserMessage 10 | from autochain.chain import constants 11 | from autochain.chain.base_chain import BaseChain 12 | from autochain.models.base import Generation 13 | from autochain.models.chat_openai import ChatOpenAI 14 | from autochain.tools.base import Tool 15 | from autochain.utils import print_with_color 16 | from autochain.workflows_evaluation.test_utils import parse_evaluation_response 17 | 18 | 19 | @dataclass 20 | class TestCase: 21 | """Standardized data class for each test case for BaseTest""" 22 | 23 | test_name: str = "" 24 | user_context: str = "" 25 | expected_outcome: str = "" 26 | 27 | 28 | class BaseTest(ABC): 29 | @property 30 | @abstractmethod 31 | def chain(self) -> BaseChain: 32 | """Chain to test with, which supports the run() function""" 33 | 34 | @property 35 | @abstractmethod 36 | def tools(self) -> List[Tool]: 37 | """Tools implementing the workflow policy""" 38 | 39 | @property 40 | @abstractmethod 41 | def test_cases(self) -> List[TestCase]: 42 | """List of test cases to evaluate""" 43 | 44 | 45 | class WorkflowTester: 46 | def __init__(self, tests: List[BaseTest], output_dir: str): 47 | self.chain = None 48 | self.tests = tests 49 | self.output_dir = output_dir 50 | self.llm = ChatOpenAI(temperature=0) 51 | 52 | def test_each_case(self, test_case: TestCase): 53 | self.chain.memory.clear() 54 | 55 | conversation_history = [] 56 | user_query = "" 57 | conversation_end = False 58 | max_turn = 8 59 | response = {} 60 | while not conversation_end and len(conversation_history) < max_turn: 61 | if not conversation_end: 62 | user_query = self.get_next_user_query( 63 | conversation_history, test_case.user_context 64 | ) 65 | conversation_history.append(("user", user_query)) 66 | print_with_color(f">> User: {user_query}", Fore.GREEN) 67 | 68 | response: Dict[str, Any] = self.chain.run(user_query) 69 | 70 | agent_message = response["message"] 71 | conversation_history.append(("assistant", agent_message)) 72 | print_with_color(f">> Assistant: {agent_message}", Fore.GREEN) 73 | 74 | conversation_end = self.determine_if_conversation_ends(agent_message) 75 | 76 | is_agent_helpful = self.determine_if_agent_solved_problem( 77 | conversation_history,
test_case.expected_outcome 78 | ) 79 | return conversation_history, is_agent_helpful, response 80 | 81 | def run_test(self, test): 82 | test_results = [] 83 | self.chain = test.chain 84 | for i, test_case in enumerate(test.test_cases): 85 | print( 86 | f"========== Start running test case: {test_case.test_name} ==========\n" 87 | ) 88 | conversation_history, is_agent_helpful, last_response = self.test_each_case( 89 | test_case 90 | ) 91 | test_results.append( 92 | { 93 | "test_name": test_case.test_name, 94 | "conversation_history": [ 95 | f"{user_type}: {message}" 96 | for user_type, message in conversation_history 97 | ], 98 | "num_turns": len(conversation_history), 99 | "expected_outcome": test_case.expected_outcome, 100 | "is_agent_helpful": is_agent_helpful, 101 | "actions_took": [ 102 | { 103 | "tool": action.tool, 104 | "tool_input": action.tool_input, 105 | "tool_output": action.tool_output, 106 | } 107 | for action in last_response[constants.INTERMEDIATE_STEPS] 108 | ], 109 | } 110 | ) 111 | 112 | df = pd.DataFrame(test_results) 113 | os.makedirs(self.output_dir, exist_ok=True) 114 | df.to_json( 115 | os.path.join(self.output_dir, f"{test.__class__.__name__}.jsonl"), 116 | lines=True, 117 | orient="records", 118 | ) 119 | 120 | def run_all_tests(self): 121 | for test in self.tests: 122 | self.run_test(test) 123 | 124 | def run_interactive(self): 125 | test = self.tests[0] 126 | self.chain = test.chain 127 | self.chain.memory.clear() 128 | 129 | while True: 130 | user_query = input(">> User: ") 131 | response = self.chain.run(user_query)["message"] 132 | print_with_color(f">> Assistant: {response}", Fore.GREEN) 133 | 134 | def determine_if_conversation_ends(self, last_utterance: str) -> bool: 135 | messages = [ 136 | UserMessage( 137 | content=f"""The most recent reply from assistant 138 | assistant: "{last_utterance}" 139 | Has the assistant finished assisting the user or tried to hand off to an agent? Answer with yes or no""" 140 | ), 141 | ] 142 | output: Generation = self.llm.generate(messages=messages).generations[0] 143 | 144 | if "yes" in output.message.content.lower(): 145 | # finished assisting; conversation should end 146 | return True 147 | else: 148 | # not yet finished; conversation should continue 149 | return False 150 | 151 | def get_next_user_query( 152 | self, conversation_history: List[Tuple[str, str]], user_context: str 153 | ) -> str: 154 | messages = [] 155 | conversation = "" 156 | 157 | for user_type, utterance in conversation_history: 158 | conversation += f"{user_type}: {utterance}\n" 159 | 160 | conversation += "user: " 161 | 162 | messages.append( 163 | UserMessage( 164 | content=f"""You are a user with access to the following context information about yourself. 165 | Based on the previous conversation, write the message to the assistant to help you with the goal described 166 | in the context without asking repetitive questions. 167 | Reply 'Thank you' if the goal is achieved. 168 | If you are not sure how to answer, respond with "hand off to agent".
169 | Context: 170 | "{user_context}" 171 | 172 | Previous conversation: 173 | {conversation}""" 174 | ) 175 | ) 176 | 177 | output: Generation = self.llm.generate( 178 | messages=messages, stop=[".", "?"] 179 | ).generations[0] 180 | return output.message.content 181 | 182 | def determine_if_agent_solved_problem( 183 | self, conversation_history: List[Tuple[str, str]], expected_outcome: str 184 | ) -> Dict[str, str]: 185 | messages = [] 186 | conversation = "" 187 | for user_type, utterance in conversation_history: 188 | conversation += f"{user_type}: {utterance}\n" 189 | 190 | messages.append( 191 | UserMessage( 192 | content=f"""You are an admin for the assistant and check whether the assistant meets the expected outcome based on the previous conversation. 193 | 194 | Previous conversation: 195 | {conversation} 196 | 197 | Expected outcome is "{expected_outcome}" 198 | Does the conversation reach the expected outcome for the user? Answer in JSON format 199 | {{ 200 | "reason": "explain step by step whether the conversation reaches the expected outcome", 201 | "rating": "rating from 1 to 5; 1 for not meeting the expected outcome at all, 5 for completely meeting the expected outcome", 202 | }}""" 203 | ) 204 | ) 205 | 206 | output: Generation = self.llm.generate(messages=messages).generations[0] 207 | return parse_evaluation_response(output.message) 208 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/conversational_agent_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/workflows_evaluation/conversational_agent_eval/__init__.py -------------------------------------------------------------------------------- /autochain/workflows_evaluation/conversational_agent_eval/find_food_near_me_test.py: -------------------------------------------------------------------------------- 1 | from autochain.tools.base import Tool 2 | from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester 3 | from autochain.workflows_evaluation.test_utils import ( 4 | create_chain_from_test, 5 | ) 6 | from autochain.utils import get_args 7 | 8 | 9 | def search_restaurant(location: str, **kwargs): 10 | """Returns a list of restaurants near the given location along with their food types""" 11 | return [ 12 | { 13 | "restaurant_name": "ABC dumplings", 14 | "food_type": "Chinese", 15 | }, 16 | { 17 | "restaurant_name": "KK sushi", 18 | "food_type": "Japanese", 19 | }, 20 | ] 21 | 22 | 23 | def get_menu(restaurant_name: str, **kwargs): 24 | """Returns the list of dishes on the menu for the given restaurant. Requires the restaurant_name input""" 25 | if "dumpling" in restaurant_name.lower(): 26 | return ["tan tan noodles", "mushroom fried rice", "pork buns"] 27 | elif "sushi" in restaurant_name.lower(): 28 | return ["unagi roll", "tuna sushi", "fried tofu"] 29 | else: 30 | return "not found" 31 | 32 | 33 | class TestFindFoodNearMe(BaseTest): 34 | prompt = """You are able to search restaurants and find the corresponding food type for the user. 35 | First, search restaurants for the user and respond with the restaurants that meet the user's food preference. 36 | Secondly, only if the user requests it, use the tool to get the menu. From the menu, respond to 37 | the user with dishes they might like. 38 | If no restaurant meets the user's requirements, reply with I don't know.
39 | """ 40 | 41 | tools = [ 42 | Tool( 43 | func=search_restaurant, 44 | description="""This function searches all available restaurants and their food types 45 | Input args: location""", 46 | ), 47 | Tool( 48 | func=get_menu, 49 | description="""This function gets the name of all dishes for the restaurant 50 | Input args: restaurant_name""", 51 | ), 52 | ] 53 | 54 | test_cases = [ 55 | TestCase( 56 | test_name="find a chinese restaurant", 57 | user_context="find the name of the any chinese restaurant; you are located in new " 58 | "york city", 59 | expected_outcome="found ABC dumplings", 60 | ), 61 | TestCase( 62 | test_name="failed to find any french restaurant", 63 | user_context="find the name of the any french restaurant; you are located in new " 64 | "york city", 65 | expected_outcome="cannot find any french restaurants", 66 | ), 67 | TestCase( 68 | test_name="find vegetarian option for a Japanese restaurant", 69 | user_context="find a Japanese restaurant and all the vegetarian options; you are located in new " 70 | "york city", 71 | expected_outcome="found KK sushi and fired tofu", 72 | ), 73 | ] 74 | 75 | chain = create_chain_from_test(tools=tools, prompt=prompt) 76 | 77 | 78 | if __name__ == "__main__": 79 | tester = WorkflowTester(tests=[TestFindFoodNearMe()], output_dir="./test_results") 80 | 81 | args = get_args() 82 | if args.interact: 83 | tester.run_interactive() 84 | else: 85 | tester.run_all_tests() 86 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py: -------------------------------------------------------------------------------- 1 | from autochain.tools.base import Tool 2 | from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester 3 | from autochain.workflows_evaluation.test_utils import ( 4 | create_chain_from_test, 5 | ) 6 | from autochain.utils import get_args 7 | 8 | 9 | def get_item_spec(item_name: str, **kwargs): 10 | if "toy" in item_name.lower(): 11 | return {"name": "toy bear", "color": "red", "age_group": "1-5 years old"} 12 | elif "printer" in item_name.lower(): 13 | return { 14 | "name": "Wireless Printer", 15 | "printer_type": "Printer, Scanner, Copier", 16 | "color_print_speed": "5.5 page per minute", 17 | "mono_print_speed": "7.5 page per minute", 18 | } 19 | else: 20 | return {} 21 | 22 | 23 | def search_image_path_for_item(item_name: str): 24 | if "toy" in item_name.lower(): 25 | return "[images/toy.png]" 26 | elif "printer" in item_name.lower(): 27 | return "[images/awesome_printer.png]" 28 | else: 29 | return "" 30 | 31 | 32 | class TestGenerateAds(BaseTest): 33 | prompt = """"Your goals is helping user to generate an advertisement for user requested 34 | product and find relevant image path for the item. 35 | You would first clarify what product you would write advertisement for and what are the key 36 | points should be included in the ads. 37 | Based on item name, you could get its specifications that can be used in advertisement. 38 | Then, you need to search and include an image path for the item at the bottom of advertisement. 39 | You could find relevant images path with tool provided and search of relevant image using query. 40 | Generate advertisement with image path. 
41 | """ 42 | 43 | tools = [ 44 | Tool( 45 | func=get_item_spec, 46 | description="""This function get item spec by searching for item name 47 | Input args: item_name: non-empty str""", 48 | ), 49 | Tool( 50 | func=search_image_path_for_item, 51 | description="""This function retrieves relevant image path for a given search query 52 | Input args: item_name: str""", 53 | ), 54 | ] 55 | 56 | test_cases = [ 57 | TestCase( 58 | test_name="ads for toy bear", 59 | user_context="Write me an advertisement for toy bear; item name is 'toy bear'. it is " 60 | "cute and made in USA, they should be " 61 | "included in the ads. Ads should include image", 62 | expected_outcome="generate an advertisement for toy bear and mentions it is cute. " 63 | "Also ads should include an image path", 64 | ), 65 | TestCase( 66 | test_name="printer ads", 67 | user_context="write me an advertisement for printer; item name is 'good printer'. " 68 | "printer is used and in good condition. " 69 | "Ads should include image", 70 | expected_outcome="generate an advertisement for wireless printer and mentions it is " 71 | "wireless, can be used as scanner and is used. Also ads should " 72 | "include an image path", 73 | ), 74 | ] 75 | 76 | chain = create_chain_from_test(tools=tools, prompt=prompt) 77 | 78 | 79 | if __name__ == "__main__": 80 | tester = WorkflowTester(tests=[TestGenerateAds()], output_dir="./test_results") 81 | 82 | args = get_args() 83 | if args.interact: 84 | tester.run_interactive() 85 | else: 86 | tester.run_all_tests() 87 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/langchain_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/workflows_evaluation/langchain_eval/__init__.py -------------------------------------------------------------------------------- /autochain/workflows_evaluation/langchain_eval/custom_langchain_output_parser.py: -------------------------------------------------------------------------------- 1 | import re 2 | from typing import Union 3 | 4 | from langchain.agents.agent import AgentOutputParser 5 | from langchain.agents.conversational.prompt import FORMAT_INSTRUCTIONS 6 | from langchain.schema import AgentAction, AgentFinish 7 | 8 | 9 | class CustomConvoOutputParser(AgentOutputParser): 10 | ai_prefix: str = "AI" 11 | 12 | def get_format_instructions(self) -> str: 13 | return FORMAT_INSTRUCTIONS 14 | 15 | def parse(self, text: str) -> Union[AgentAction, AgentFinish]: 16 | if f"{self.ai_prefix}:" in text: 17 | return AgentFinish( 18 | {"output": text.split(f"{self.ai_prefix}:")[-1].strip()}, text 19 | ) 20 | regex = r"Action: (.*?)[\n]*Action Input: (.*)" 21 | match = re.search(regex, text) 22 | if not match: 23 | print( 24 | f"\nLangChain OutputParserException: Could not parse LLM output: `{text}`" 25 | ) 26 | return AgentFinish({"output": text.strip()}, text.strip()) 27 | action = match.group(1) 28 | action_input = match.group(2) 29 | return AgentAction(action.strip(), action_input.strip(" ").strip('"'), text) 30 | 31 | @property 32 | def _type(self) -> str: 33 | return "conversational" 34 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/langchain_eval/find_food_near_me_test.py: -------------------------------------------------------------------------------- 1 | from langchain.agents 
import AgentType 2 | from langchain.tools import Tool as LCTool 3 | 4 | from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester 5 | from autochain.utils import get_args 6 | from autochain.workflows_evaluation.langchain_eval.langchain_test_utils import ( 7 | create_langchain_from_test, 8 | ) 9 | 10 | 11 | def search_restaurant(location: str, **kwargs): 12 | """Returns a list of restaurants near the given location along with their food types""" 13 | return [ 14 | { 15 | "restaurant_name": "ABC dumplings", 16 | "food_type": "Chinese", 17 | }, 18 | { 19 | "restaurant_name": "KK sushi", 20 | "food_type": "Japanese", 21 | }, 22 | ] 23 | 24 | 25 | def get_menu(restaurant_name: str, **kwargs): 26 | """Returns the list of dishes on the menu for the given restaurant. Requires the restaurant_name input""" 27 | if "dumpling" in restaurant_name.lower(): 28 | return ["tan tan noodles", "mushroom fried rice", "pork buns"] 29 | elif "sushi" in restaurant_name.lower(): 30 | return ["unagi roll", "tuna sushi", "fried tofu"] 31 | else: 32 | return "not found" 33 | 34 | 35 | class TestFindFoodNearMeWithLC(BaseTest): 36 | prompt = """You are able to search restaurants and find the corresponding food type for the user. 37 | First, search restaurants for the user and respond with the restaurants that meet the user's food preference. 38 | Secondly, only if the user requests it, use the tool to get the menu. From the menu, respond to 39 | the user with dishes they might like. 40 | If no restaurant meets the user's requirements, reply with I don't know. 41 | """ 42 | 43 | tools = [ 44 | LCTool( 45 | name="search restaurant", 46 | func=search_restaurant, 47 | description="""This function searches all available restaurants and their food types 48 | Input args: location""", 49 | ), 50 | LCTool( 51 | name="get menu", 52 | func=get_menu, 53 | description="""This function gets the name of all dishes for the restaurant 54 | Input args: restaurant_name""", 55 | ), 56 | ] 57 | 58 | test_cases = [ 59 | TestCase( 60 | test_name="find a chinese restaurant", 61 | user_context="find the name of any chinese restaurant; you are located in new " 62 | "york city", 63 | expected_outcome="found ABC dumplings", 64 | ), 65 | TestCase( 66 | test_name="failed to find any french restaurant", 67 | user_context="find the name of any french restaurant; you are located in new " 68 | "york city", 69 | expected_outcome="cannot find any french restaurants", 70 | ), 71 | TestCase( 72 | test_name="find vegetarian option for a Japanese restaurant", 73 | user_context="find a Japanese restaurant and all the vegetarian options; you are located in new " 74 | "york city", 75 | expected_outcome="found KK sushi and fried tofu", 76 | ), 77 | ] 78 | 79 | chain = create_langchain_from_test( 80 | tools=tools, 81 | agent_type=AgentType.CONVERSATIONAL_REACT_DESCRIPTION, 82 | prefix=prompt, 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | tests = WorkflowTester( 88 | tests=[TestFindFoodNearMeWithLC()], output_dir="./test_results" 89 | ) 90 | 91 | args = get_args() 92 | if args.interact: 93 | tests.run_interactive() 94 | else: 95 | tests.run_all_tests() 96 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/langchain_eval/generate_ads_test.py: -------------------------------------------------------------------------------- 1 | from langchain.agents import AgentType 2 | from langchain.tools import Tool as LCTool 3 | 4 | from autochain.workflows_evaluation.base_test import
BaseTest, TestCase, WorkflowTester 5 | from autochain.utils import get_args 6 | from autochain.workflows_evaluation.langchain_eval.langchain_test_utils import ( 7 | create_langchain_from_test, 8 | ) 9 | 10 | 11 | def get_item_spec(item_name: str, **kwargs): 12 | if "toy" in item_name.lower(): 13 | return {"name": "toy bear", "color": "red", "age_group": "1-5 years old"} 14 | elif "printer" in item_name.lower(): 15 | return { 16 | "name": "Wireless Printer", 17 | "printer_type": "Printer, Scanner, Copier", 18 | "color_print_speed": "5.5 page per minute", 19 | "mono_print_speed": "7.5 page per minute", 20 | } 21 | else: 22 | return {} 23 | 24 | 25 | def search_image_path_for_item(item_name: str): 26 | if "toy" in item_name.lower(): 27 | return "images/toy.png" 28 | elif "printer" in item_name.lower(): 29 | return "images/awesome_printer.png" 30 | else: 31 | return "" 32 | 33 | 34 | class TestGenerateAdsWithLC(BaseTest): 35 | prompt = """Your goal is to help the user generate an advertisement for the requested 36 | product and find a relevant image path for the item. 37 | You would first clarify what product to write the advertisement for and what key 38 | points should be included in the ads. 39 | Based on the item name, you could get its specifications that can be used in the advertisement. 40 | Then, you need to search for and include an image path for the item at the bottom of the advertisement. 41 | You could find a relevant image path with the tool provided by searching with a query. 42 | Generate the advertisement with the image path. 43 | 44 | TOOLS: 45 | ------ 46 | 47 | Assistant has access to the following tools: 48 | """ 49 | 50 | tools = [ 51 | LCTool( 52 | name="get item spec", 53 | func=get_item_spec, 54 | description="""This function gets the item spec by searching for the item name 55 | Input args: item_name: non-empty str""", 56 | ), 57 | LCTool( 58 | name="search image path for item", 59 | func=search_image_path_for_item, 60 | description="""This function retrieves a relevant image path for a given search query 61 | Input args: item_name: str""", 62 | ), 63 | ] 64 | 65 | test_cases = [ 66 | TestCase( 67 | test_name="ads for toy bear", 68 | user_context="Write me an advertisement for toy bear; item name is 'toy bear'. It is " 69 | "cute and made in the USA; these points should be " 70 | "included in the ads. Ads should include image", 71 | expected_outcome="generate an advertisement for toy bear and mentions it is cute. " 72 | "Also ads should include an image path", 73 | ), 74 | TestCase( 75 | test_name="printer ads", 76 | user_context="write me an advertisement for printer; item name is 'good printer'. " 77 | "printer is used and in good condition. " 78 | "Ads should include image", 79 | expected_outcome="generate an advertisement for wireless printer and mentions it is " 80 | "wireless, can be used as scanner and is used.
Also ads should " 81 | "include an image path", 82 | ), 83 | ] 84 | 85 | chain = create_langchain_from_test( 86 | tools=tools, 87 | agent_type=AgentType.CONVERSATIONAL_REACT_DESCRIPTION, 88 | prefix=prompt, 89 | ) 90 | 91 | 92 | if __name__ == "__main__": 93 | tests = WorkflowTester(tests=[TestGenerateAdsWithLC()], output_dir="./test_results") 94 | 95 | args = get_args() 96 | if args.interact: 97 | tests.run_interactive() 98 | else: 99 | tests.run_all_tests() 100 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/langchain_eval/langchain_test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from langchain.agents import AgentType, initialize_agent 4 | from langchain.base_language import BaseLanguageModel as LCBaseLanguageModel 5 | from langchain.chat_models import ChatOpenAI as LCChatOpenAIModel 6 | from langchain.memory import ConversationBufferMemory as LCConversationBufferMemory 7 | from langchain.schema import BaseMemory as LCBaseMemory 8 | from langchain.tools import Tool as LCTool 9 | from autochain.chain.langchain_wrapper_chain import LangChainWrapperChain 10 | from autochain.workflows_evaluation.langchain_eval.custom_langchain_output_parser import ( 11 | CustomConvoOutputParser, 12 | ) 13 | 14 | 15 | def create_langchain_from_test( 16 | tools: List[LCTool], 17 | agent_type: AgentType, 18 | memory: Optional[LCBaseMemory] = None, 19 | llm: Optional[LCBaseLanguageModel] = None, 20 | **kwargs, 21 | ): 22 | """ 23 | Create LangChainWrapperChain by instantiating LangChain agent 24 | Args: 25 | tools: list of langchain tool 26 | agent_type: LangChain AgentType 27 | memory: LangChain memory 28 | llm: LangChain language model 29 | 30 | Returns: 31 | LangChainWrapperChain 32 | """ 33 | llm = llm or LCChatOpenAIModel(temperature=0) 34 | memory = memory or LCConversationBufferMemory(memory_key="chat_history") 35 | 36 | # Create a more lenient output parser to work around the fragility of the LangChain Agent 37 | kwargs["output_parser"] = CustomConvoOutputParser() 38 | 39 | langchain = initialize_agent( 40 | tools, llm, agent=agent_type, verbose=True, memory=memory, agent_kwargs=kwargs 41 | ) 42 | 43 | return LangChainWrapperChain(langchain=langchain) 44 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/langchain_eval/readme.md: -------------------------------------------------------------------------------- 1 | # Evaluate LangChain Agent 2 | 3 | We created a few examples for evaluating LangChain agents with the AutoChain workflow evaluation 4 | framework. Users can configure the type of LangChain agent used in `LangChainWrapperChain`, which 5 | is just a simple wrapper around LangChain to adapt it to the AutoChain interface. 6 | 7 | To run any LangChain agent evaluation, users need to install LangChain first. 8 | 9 | ```shell 10 | pip install langchain 11 | ``` 12 | 13 | Sometimes the LangChain agent does not respond according to the format described in the prompt. 14 | To work around this problem, the agent uses a custom and more lenient output parser, 15 | `CustomConvoOutputParser`, which responds to the user directly when the output format does not match 16 | instead of raising an exception.
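
For reference, here is a minimal sketch of wiring this together with the `create_langchain_from_test` helper shown above. The echo tool is a made-up example for illustration only:

```python
from langchain.agents import AgentType
from langchain.tools import Tool as LCTool

from autochain.workflows_evaluation.langchain_eval.langchain_test_utils import (
    create_langchain_from_test,
)

# Hypothetical toy tool, only for illustration
echo_tool = LCTool(
    name="echo",
    func=lambda text: text,
    description="Returns the input text unchanged",
)

# Returns a LangChainWrapperChain, which follows the AutoChain BaseChain interface
chain = create_langchain_from_test(
    tools=[echo_tool],
    agent_type=AgentType.CONVERSATIONAL_REACT_DESCRIPTION,
)
print(chain.run("Repeat after me: hello")["message"])
```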
-------------------------------------------------------------------------------- /autochain/workflows_evaluation/openai_function_agent_eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/autochain/workflows_evaluation/openai_function_agent_eval/__init__.py -------------------------------------------------------------------------------- /autochain/workflows_evaluation/openai_function_agent_eval/find_food_near_me_test.py: -------------------------------------------------------------------------------- 1 | from autochain.agent.openai_functions_agent.openai_functions_agent import ( 2 | OpenAIFunctionsAgent, 3 | ) 4 | from autochain.models.chat_openai import ChatOpenAI 5 | from autochain.tools.base import Tool 6 | from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester 7 | from autochain.workflows_evaluation.test_utils import ( 8 | create_chain_from_test, 9 | ) 10 | from autochain.utils import get_args 11 | 12 | 13 | def search_restaurant(location: str, **kwargs): 14 | """Returns a list of restaurants near the given location along with their food types""" 15 | return [ 16 | { 17 | "restaurant_name": "ABC dumplings", 18 | "food_type": "Chinese", 19 | }, 20 | { 21 | "restaurant_name": "KK sushi", 22 | "food_type": "Japanese", 23 | }, 24 | ] 25 | 26 | 27 | def get_menu(restaurant_name: str, **kwargs): 28 | """Returns the list of dishes on the menu for the given restaurant. Requires the restaurant_name input""" 29 | if "dumpling" in restaurant_name.lower(): 30 | return ["tan tan noodles", "mushroom fried rice", "pork buns"] 31 | elif "sushi" in restaurant_name.lower(): 32 | return ["unagi roll", "tuna sushi", "fried tofu"] 33 | else: 34 | return "not found" 35 | 36 | 37 | class TestFindFoodNearMeWithFunctionCalling(BaseTest): 38 | prompt = """You are able to search restaurants and find the corresponding food type for the user. 39 | First, search restaurants for the user and respond with the restaurants that meet the user's food preference. 40 | Secondly, only if the user requests it, use the tool to get the menu. From the menu, respond to 41 | the user with dishes they might like. 42 | If no restaurant meets the user's requirements, reply with I don't know.
43 | """ 44 | 45 | tools = [ 46 | Tool( 47 | func=search_restaurant, 48 | description="""This function searches all available restaurants and their food types 49 | Input args: location""", 50 | ), 51 | Tool( 52 | func=get_menu, 53 | description="""This function gets the name of all dishes for the restaurant 54 | Input args: restaurant_name""", 55 | ), 56 | ] 57 | 58 | test_cases = [ 59 | TestCase( 60 | test_name="find a chinese restaurant", 61 | user_context="find the name of the any chinese restaurant and get menu", 62 | expected_outcome="found ABC dumplings", 63 | ), 64 | TestCase( 65 | test_name="failed to find any french restaurant", 66 | user_context="find the name of the any french restaurant and get menu", 67 | expected_outcome="cannot find any french restaurants", 68 | ), 69 | TestCase( 70 | test_name="find vegetarian option for a Japanese restaurant", 71 | user_context="find a Japanese restaurant and all the vegetarian options", 72 | expected_outcome="found KK sushi and fired tofu", 73 | ), 74 | ] 75 | 76 | llm = ChatOpenAI(temperature=0) 77 | chain = create_chain_from_test( 78 | tools=tools, agent_cls=OpenAIFunctionsAgent, llm=llm, prompt=prompt 79 | ) 80 | 81 | 82 | if __name__ == "__main__": 83 | tester = WorkflowTester( 84 | tests=[TestFindFoodNearMeWithFunctionCalling()], 85 | output_dir="./test_results", 86 | ) 87 | 88 | args = get_args() 89 | if args.interact: 90 | tester.run_interactive() 91 | else: 92 | tester.run_all_tests() 93 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/openai_function_agent_eval/generate_ads_test.py: -------------------------------------------------------------------------------- 1 | from autochain.agent.openai_functions_agent.openai_functions_agent import ( 2 | OpenAIFunctionsAgent, 3 | ) 4 | from autochain.models.chat_openai import ChatOpenAI 5 | from autochain.tools.base import Tool 6 | from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester 7 | from autochain.workflows_evaluation.test_utils import ( 8 | create_chain_from_test, 9 | ) 10 | from autochain.utils import get_args 11 | 12 | 13 | def get_item_spec(item_name: str, **kwargs): 14 | if "toy" in item_name.lower(): 15 | return {"name": "toy bear", "color": "red", "age_group": "1-5 years old"} 16 | elif "printer" in item_name.lower(): 17 | return { 18 | "name": "Wireless Printer", 19 | "printer_type": "Printer, Scanner, Copier", 20 | "color_print_speed": "5.5 page per minute", 21 | "mono_print_speed": "7.5 page per minute", 22 | } 23 | else: 24 | return {} 25 | 26 | 27 | def search_image_path_for_item(item_name: str): 28 | if "toy" in item_name.lower(): 29 | return "[images/toy.png]" 30 | elif "printer" in item_name.lower(): 31 | return "[images/awesome_printer.png]" 32 | else: 33 | return "" 34 | 35 | 36 | class TestGenerateAdsWithFunctionCalling(BaseTest): 37 | prompt = """"Your goals is helping user to generate an advertisement for user requested 38 | product and find relevant image path for the item. 39 | You would first clarify what product you would write advertisement for and what are the key 40 | points should be included in the ads. 41 | Based on item name, you could get its specifications that can be used in advertisement. 42 | Then, you need to search and include an image path for the item at the bottom of advertisement. 43 | You could find relevant images path with tool provided and search of relevant image using query. 44 | Generate advertisement with image path. 
45 | """ 46 | 47 | tools = [ 48 | Tool( 49 | func=get_item_spec, 50 | description="""This function get item spec by searching for item name 51 | Input args: item_name: non-empty str""", 52 | ), 53 | Tool( 54 | func=search_image_path_for_item, 55 | description="""This function retrieves relevant image path for a given search query 56 | Input args: item_name: str""", 57 | ), 58 | ] 59 | 60 | test_cases = [ 61 | TestCase( 62 | test_name="ads for toy bear", 63 | user_context="Write me an advertisement for toy bear; item name is 'toy bear'. it is " 64 | "cute and made in USA, they should be " 65 | "included in the ads. Ads should include image", 66 | expected_outcome="generate an advertisement for toy bear and mentions it is cute. " 67 | "Also ads should include an image path", 68 | ), 69 | TestCase( 70 | test_name="printer ads", 71 | user_context="write me an advertisement for printer; item name is 'good printer'. " 72 | "printer is used and in good condition. " 73 | "Ads should include image", 74 | expected_outcome="generate an advertisement for wireless printer and mentions it is " 75 | "wireless, can be used as scanner and is used. Also ads should " 76 | "include an image path", 77 | ), 78 | ] 79 | 80 | llm = ChatOpenAI(temperature=0) 81 | chain = create_chain_from_test( 82 | tools=tools, agent_cls=OpenAIFunctionsAgent, llm=llm, prompt=prompt 83 | ) 84 | 85 | 86 | if __name__ == "__main__": 87 | tester = WorkflowTester( 88 | tests=[TestGenerateAdsWithFunctionCalling()], 89 | output_dir="./test_results", 90 | ) 91 | 92 | args = get_args() 93 | if args.interact: 94 | tester.run_interactive() 95 | else: 96 | tester.run_all_tests() 97 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/openai_function_agent_eval/get_weather_test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from autochain.agent.openai_functions_agent.openai_functions_agent import ( 4 | OpenAIFunctionsAgent, 5 | ) 6 | from autochain.models.chat_openai import ChatOpenAI 7 | from autochain.tools.base import Tool 8 | from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester 9 | from autochain.workflows_evaluation.test_utils import ( 10 | create_chain_from_test, 11 | ) 12 | from autochain.utils import get_args 13 | 14 | 15 | def get_current_weather(location: str, unit: str = "fahrenheit"): 16 | """Get the current weather in a given location""" 17 | weather_info = { 18 | "location": location, 19 | "temperature": "72", 20 | "unit": unit, 21 | "forecast": ["sunny", "windy"], 22 | } 23 | return json.dumps(weather_info) 24 | 25 | 26 | class TestGetWeatherWithFunctionCalling(BaseTest): 27 | prompt = """You are a weather support agent tries to get weather information for requested 28 | user location""" 29 | 30 | tools = [ 31 | Tool( 32 | name="get_current_weather", 33 | func=get_current_weather, 34 | description="""Get the current weather in a given location""", 35 | ) 36 | ] 37 | 38 | test_cases = [ 39 | TestCase( 40 | test_name="get weather for boston", 41 | user_context="want to get current weather information; location in Boston", 42 | expected_outcome="found weather information in Boston", 43 | ), 44 | ] 45 | 46 | llm = ChatOpenAI(temperature=0) 47 | chain = create_chain_from_test( 48 | tools=tools, agent_cls=OpenAIFunctionsAgent, llm=llm, prompt=prompt 49 | ) 50 | 51 | 52 | if __name__ == "__main__": 53 | tester = WorkflowTester( 54 | tests=[TestGetWeatherWithFunctionCalling()], 
55 | output_dir="./test_results", 56 | ) 57 | 58 | args = get_args() 59 | if args.interact: 60 | tester.run_interactive() 61 | else: 62 | tester.run_all_tests() 63 | -------------------------------------------------------------------------------- /autochain/workflows_evaluation/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Dict 2 | 3 | from autochain.agent.conversational_agent.conversational_agent import ( 4 | ConversationalAgent, 5 | ) 6 | 7 | from autochain.agent.structs import AgentOutputParser 8 | from autochain.agent.message import BaseMessage 9 | from autochain.chain.chain import Chain 10 | from autochain.memory.base import BaseMemory 11 | from autochain.memory.buffer_memory import BufferMemory 12 | from autochain.models.base import BaseLanguageModel 13 | from autochain.models.chat_openai import ChatOpenAI 14 | from autochain.tools.base import Tool 15 | 16 | 17 | def create_chain_from_test( 18 | tools: List[Tool], 19 | memory: Optional[BaseMemory] = None, 20 | llm: Optional[BaseLanguageModel] = None, 21 | agent_cls=ConversationalAgent, 22 | **kwargs 23 | ): 24 | """ 25 | Create Chain for running tests 26 | Args: 27 | tools: list of autochain tools 28 | memory: memory store for chain 29 | llm: model for agent 30 | agent_cls: agent class to instantiate 31 | Returns: 32 | Chain 33 | """ 34 | llm = llm or ChatOpenAI(temperature=0) 35 | memory = memory or BufferMemory() 36 | agent = agent_cls.from_llm_and_tools(llm=llm, tools=tools, **kwargs) 37 | return Chain(agent=agent, memory=memory) 38 | 39 | 40 | def parse_evaluation_response(message: BaseMessage) -> Dict[str, str]: 41 | """ 42 | Parse the reason and rating from the evaluation call to determine if the conversation reaches the 43 | expected outcome 44 | """ 45 | response = AgentOutputParser.load_json_output(message) 46 | return { 47 | "rating": response.get("rating"), 48 | "reason": response.get("reason"), 49 | } 50 | -------------------------------------------------------------------------------- /docs/agent.md: -------------------------------------------------------------------------------- 1 | # Agent 2 | 3 | Agent is the component that implements the interface to interact with `Chain` by deciding how to 4 | respond to users or whether the agent needs to use any tool. 5 | 6 | It could use different prompts for the different functionalities an agent could have. The main goal 7 | for an agent is to plan for the next step and then either respond to the user with `AgentFinish` or take an 8 | action with `AgentAction`. 9 | 10 | There are a few typical interactions an agent should support: 11 | 12 | **prompt** 13 | Depending on the agent you are building, you might want to write different 14 | planning prompts. The policy controls the steps the agent should take in different situations. 15 | Those prompts could be string templates so that the agent can later substitute 16 | different values into the prompt for different use cases. 17 | 18 | **should_answer** 19 | Not all questions should be answered by the agent. If the agent decides that a query 20 | should not be handled by this agent, it could gracefully exit as early as 21 | possible. 22 | 23 | **plan** 24 | This is the core of the agent, which takes in all the stored memory, including past 25 | conversation history and tool outputs saved to previous `AgentAction`s, and prompts the 26 | model to output either `AgentFinish` or `AgentAction` for the next step. 27 | `AgentFinish` means the agent decides to respond back to the user with a 28 | message. While `plan` is not the only step that could output `AgentFinish`, `AgentFinish` is the **only** way to 29 | exit the chain and wait for the next user input. 30 | `AgentAction` means the agent decides to use a tool and wants to perform an action before responding 31 | to the user. Once the chain observes the agent would like to perform an action, it calls the 32 | corresponding tool and stores the tool output in the chain's memory for the next iteration of 33 | planning.
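
To make the control flow concrete, here is a minimal sketch of how a chain might dispatch on the agent's plan output. This is illustrative only: it assumes `AgentAction` and `AgentFinish` live in `autochain.agent.structs`, and `dispatch_step`/`tools_by_name` are made-up names rather than part of the AutoChain API.

```python
from autochain.agent.structs import AgentAction, AgentFinish


def dispatch_step(step, tools_by_name):
    """Illustrative only: route the output of the agent's plan."""
    if isinstance(step, AgentFinish):
        # The only way to exit the chain: respond and wait for the next user input
        return step
    # AgentAction: run the selected tool and keep its output for the next plan
    tool = tools_by_name[step.tool]
    step.tool_output = tool.run(step.tool_input)
    return step
```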
34 | 35 | **clarify_args_for_agent_action** 36 | When the agent wants to take an action with tools, it usually requires some input arguments, 37 | which may or may not exist in the past conversation history or action outputs. While the 38 | smartest agent would output `AgentFinish` with a response that asks the user for the missing information, 39 | that might not always be the case. To decouple the problem and make it simpler for the agent, we 40 | could add another step that explicitly asks the user clarifying questions if any argument is 41 | missing for a given tool to be used. This function will either output an `AgentFinish`, which 42 | asks the clarifying question, or an `AgentAction` that is the same as the action just checked, which means 43 | no more clarifying questions are needed. 44 | We skipped implementing this for the OpenAI agent with function calling and rely on its native 45 | response for clarifying questions. 46 | 47 | ### ConversationalAgent 48 | 49 | This is a basic agent with a simple default prompt template to have a nice conversation with the 50 | user. It could also use tools if provided. 51 | While it does not use native OpenAI function calling, this agent showcases the interaction between 52 | memory and prompts. 53 | Users could provide a custom prompt injected into the [prompt template](../autochain/agent/conversational_agent/prompt.py), 54 | which contains the prompt placeholder variable. 55 | 56 | ### OpenAIFunctionsAgent 57 | 58 | On June 13, 2023, OpenAI released [function calling](https://platform.openai.com/docs/guides/gpt/chat-completions-api), 59 | which is a new way for the model to use tools natively. 60 | We introduced `OpenAIFunctionsAgent` to support native function calling when tools are provided. 61 | To give a system message or instruction to the agent via prompt, users could provide the prompt when 62 | creating the agent, such as `agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=llm, tools=tools, prompt=prompt)` 63 | -------------------------------------------------------------------------------- /docs/chain.md: -------------------------------------------------------------------------------- 1 | # Chain 2 | 3 | If you have worked with LangChain before, you already know 80% of what a chain does. 4 | Chain is the stateful orchestrator for the agent, which controls when to involve the agent and in which way. It 5 | offers a framework to interact with agents by controlling the information flow. Chain leverages a 6 | memory component that memorizes past conversations and any other intermediate 7 | steps, such as tool outputs. 8 | 9 | A typical chain contains an `agent` to interact with, a list of `tools` that the agent might use, 10 | and `memory` that stores stateful information. 11 | 12 | The flow diagram below describes the high-level picture of the default chain interaction with an agent. 13 | 14 | ![alt text](./img/autochain.drawio.png) 15 |
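
In pseudocode, the loop in the diagram looks roughly like the snippet below. This is an illustration of the flow, not the actual source: `take_next_step` and `should_answer` are described later on this page, while `execute_and_memorize` is a hypothetical helper standing in for the tool-execution and memory-update logic, and the signatures are simplified.

```python
from autochain.agent.structs import AgentFinish


def run_sketch(chain, user_query: str):
    """Illustrative only: the default chain control flow."""
    if not chain.should_answer(user_query):
        return None  # gracefully exit early
    while True:
        step = chain.take_next_step(user_query)  # agent plans with memory + query
        if isinstance(step, AgentFinish):
            return step                          # respond and wait for next input
        chain.execute_and_memorize(step)         # hypothetical: run tool, save output
```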
16 | ## Differences with LangChain 17 | 18 | To remove abstraction and internal concepts, we expose a flatter and simplified interface 19 | for the chain, explained in detail below. 20 | The main difference is that we simplify the flow and removed as many internal concepts as possible. 21 | 22 | ## BaseChain and Chain 23 | 24 | This is the most generic interface for implementing any chain in AutoChain. It contains a few 25 | features users could override. `BaseChain` is the generic interface, and `Chain` is the default 26 | chain implementation, which implements the only abstract method, `take_next_step`. 27 | 28 | ### run 29 | 30 | This is the entry point to interact with the chain via `user_query`. Because there are often 31 | different ways to interact with the chain, any chain could override the `_run` function that 32 | handles the business logic for agent interaction, while still benefiting from the input and 33 | output standardization provided by `BaseChain`. 34 | 35 | ### _run 36 | 37 | This provides the standard way to manage memories and determines when the agent should stop 38 | answering the user query. Most of the time, users could reuse how 39 | we manage memory and just need to change what the agent should do next given inputs 40 | including `user_query` and memories. In that case, users would need to implement the 41 | `take_next_step` function. 42 | 43 | ### take_next_step 44 | 45 | This is an abstract method that implements the way you would like to interact with the agent and ask 46 | it to come up with the next step. The default implementation is in `Chain`, where it asks 47 | the `agent` to plan the next step and executes the `tools` selected by the `agent`. 48 | 49 | ### should_answer 50 | 51 | It is often unclear when the agent should stop responding to the user query. Sometimes the user would just 52 | say "Thank you" at the end, but the agent might not understand this as the end of the conversation, so 53 | it could still try to respond with more content, even clarifying questions in some cases. By 54 | default, the agent will always respond to the user until the user stops. In case that is not desired, we 55 | introduce the `should_answer` step in `BaseChain` to stop the agent from further interaction. 56 | -------------------------------------------------------------------------------- /docs/components_overview.md: -------------------------------------------------------------------------------- 1 | # Components overview 2 | 3 | There are a few key concepts in AutoChain, which could be easily extended to build new agents. 4 | 5 | ### Chain 6 | 7 | `Chain` is the overall *stateful* orchestrator for agent interaction. It determines when to use 8 | tools or respond to users. `Chain` is the only stateful component, so all the interactions with 9 | memory happen at the `Chain` level. By default, it saves all the chat conversation history and 10 | intermediate `AgentAction` with corresponding outputs at the `prep_input` and `prep_output` steps. 11 | 12 | `Agent` provides ways of interaction, while `Chain` determines how to 13 | interact with the agent. 14 | 15 | Read more about the [chain concept](./chain.md). 16 | 17 | ### Agent 18 | 19 | Agent is the *stateless* component that decides how to respond to the user or whether the agent 20 | needs to use tools. 21 | It could contain different prompts for the different functionalities an agent could have. The main goal 22 | for an agent is to plan for the next step: either respond to the user with `AgentFinish` or take an 23 | action with `AgentAction`. 24 | 25 | Read more about [agent](./agent.md). 26 | 27 | ### Tool 28 | 29 | The ability to use tools makes the agent incredibly more powerful, as shown in LangChain and 30 | AutoGPT. We follow a similar concept of "tool" as in LangChain here as well, as sketched below. 31 | All the tools in LangChain can be easily ported over to AutoChain if you like, since they follow 32 | a very similar interface. 33 | 34 | Read more about [tool](./tool.md).
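
As a quick illustration, a tool is just a named function plus a natural language description the agent can read when planning; the weather function below is a toy example:

```python
from autochain.tools.base import Tool

# Toy tool: the function body is illustrative
weather_tool = Tool(
    name="Get weather",
    func=lambda *args, **kwargs: "Today is a sunny day",
    description="This function returns the weather information",
)
```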
35 | 36 | ### Memory 37 | 38 | It is important for a chain to keep the memory for a particular conversation with a user. The 39 | memory 40 | interface exposes two ways to save memories. One is `save_conversation`, which saves the chat 41 | history between the agent and the user; the other is `save_memory`, which saves any additional information 42 | for specific business logic. 43 | 44 | By default, memory is saved/updated at the beginning and updated at the end at the `Chain` level. 45 | Memory saves the conversation history, including the latest user query, and intermediate 46 | steps, which is a list of `AgentAction` taken with corresponding outputs. 47 | All memorized contents are usually provided to the agent for planning the next step. 48 | 49 | Read more about [memory](./memory.md) -------------------------------------------------------------------------------- /docs/css/custom.css: -------------------------------------------------------------------------------- 1 | .termynal-comment { 2 | color: #4a968f; 3 | font-style: italic; 4 | display: block; 5 | } 6 | 7 | .termy [data-termynal] { 8 | white-space: pre-wrap; 9 | } 10 | 11 | a.external-link::after { 12 | /* \00A0 is a non-breaking space 13 | to make the mark be on the same line as the link 14 | */ 15 | content: "\00A0[↪]"; 16 | } 17 | 18 | a.internal-link::after { 19 | /* \00A0 is a non-breaking space 20 | to make the mark be on the same line as the link 21 | */ 22 | content: "\00A0↪"; 23 | } 24 | 25 | .shadow { 26 | box-shadow: 5px 5px 10px #999; 27 | } 28 | -------------------------------------------------------------------------------- /docs/css/termynal.css: -------------------------------------------------------------------------------- 1 | /** 2 | * termynal.js 3 | * 4 | * @author Ines Montani 5 | * @version 0.0.1 6 | * @license MIT 7 | */ 8 | 9 | :root { 10 | --color-bg: #252a33; 11 | --color-text: #eee; 12 | --color-text-subtle: #a2a2a2; 13 | } 14 | 15 | [data-termynal] { 16 | width: 750px; 17 | max-width: 100%; 18 | background: var(--color-bg); 19 | color: var(--color-text); 20 | /* font-size: 18px; */ 21 | font-size: 15px; 22 | /* font-family: 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; */ 23 | font-family: 'Roboto Mono', 'Fira Mono', Consolas, Menlo, Monaco, 'Courier New', Courier, monospace; 24 | border-radius: 4px; 25 | padding: 75px 45px 35px; 26 | position: relative; 27 | -webkit-box-sizing: border-box; 28 | box-sizing: border-box; 29 | } 30 | 31 | [data-termynal]:before { 32 | content: ''; 33 | position: absolute; 34 | top: 15px; 35 | left: 15px; 36 | display: inline-block; 37 | width: 15px; 38 | height: 15px; 39 | border-radius: 50%; 40 | /* A little hack to display the window buttons in one pseudo element.
*/ 41 | background: #d9515d; 42 | -webkit-box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 43 | box-shadow: 25px 0 0 #f4c025, 50px 0 0 #3ec930; 44 | } 45 | 46 | [data-termynal]:after { 47 | content: 'bash'; 48 | position: absolute; 49 | color: var(--color-text-subtle); 50 | top: 5px; 51 | left: 0; 52 | width: 100%; 53 | text-align: center; 54 | } 55 | 56 | a[data-terminal-control] { 57 | text-align: right; 58 | display: block; 59 | color: #aebbff; 60 | } 61 | 62 | [data-ty] { 63 | display: block; 64 | line-height: 2; 65 | } 66 | 67 | [data-ty]:before { 68 | /* Set up defaults and ensure empty lines are displayed. */ 69 | content: ''; 70 | display: inline-block; 71 | vertical-align: middle; 72 | } 73 | 74 | [data-ty="input"]:before, 75 | [data-ty-prompt]:before { 76 | margin-right: 0.75em; 77 | color: var(--color-text-subtle); 78 | } 79 | 80 | [data-ty="input"]:before { 81 | content: '$'; 82 | } 83 | 84 | [data-ty][data-ty-prompt]:before { 85 | content: attr(data-ty-prompt); 86 | } 87 | 88 | [data-ty-cursor]:after { 89 | content: attr(data-ty-cursor); 90 | font-family: monospace; 91 | margin-left: 0.5em; 92 | -webkit-animation: blink 1s infinite; 93 | animation: blink 1s infinite; 94 | } 95 | 96 | 97 | /* Cursor animation */ 98 | 99 | @-webkit-keyframes blink { 100 | 50% { 101 | opacity: 0; 102 | } 103 | } 104 | 105 | @keyframes blink { 106 | 50% { 107 | opacity: 0; 108 | } 109 | } 110 | -------------------------------------------------------------------------------- /docs/img/autochain.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/docs/img/autochain.drawio.png -------------------------------------------------------------------------------- /docs/img/icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/docs/img/icon.png -------------------------------------------------------------------------------- /docs/img/logo-margin/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Forethought-Technologies/AutoChain/5a1203bb01208b3e186c927222ae31702d309270/docs/img/logo-margin/logo.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # AutoChain 2 | 3 | Large language models (LLMs) have shown huge success in different text generation tasks and 4 | enable developers to build generative agents based on natural language objectives. 5 | 6 | However, most of the generative agents require heavy customization for specific purposes, and 7 | adapting to different use cases is sometimes overwhelming using existing tools 8 | and framework. As a result, it is still very challenging to build a customized generative agent. 9 | 10 | In addition, evaluating such agents powered by LLMs by trying different use 11 | cases under different potential user scenarios is a very manual and expensive task. 
12 | 13 | AutoChain takes inspiration from LangChain and AutoGPT and aims to solve 14 | both problems by providing a lightweight and extensible framework 15 | for developers to build their own conversational agents using LLMs with custom tools and 16 | [automatically evaluating](#workflow-evaluation) different user scenarios with simulated 17 | conversations. Experienced users of LangChain will find AutoChain easy to navigate since 18 | the two share similar but simpler concepts. 19 | 20 | The goal is to enable quick experiments with generative agents, knowing users will 21 | make more customizations as they build their own agent. 22 | 23 | If you have any questions, please feel free to reach out to Yi Lu 24 | 25 | ## Features 26 | 27 | - 🚀 lightweight and extensible generative agent pipeline. 28 | - 🔗 agent that can use different custom tools and 29 | support OpenAI [function calling](https://platform.openai.com/docs/guides/gpt/function-calling) 30 | - 💾 simple memory tracking for conversation history and tools' outputs 31 | - 🤖 automated agent multi-turn conversation evaluation with simulated conversations 32 | 33 | ## Setup 34 | 35 | Quick install 36 | 37 | ```shell 38 | pip install autochain 39 | ``` 40 | 41 | Or install from source after cloning the repo 42 | 43 | ```shell 44 | cd autochain 45 | pyenv virtualenv 3.10.11 venv 46 | pyenv local venv 47 | 48 | pip install . 49 | ``` 50 | 51 | Set `PYTHONPATH` and `OPENAI_API_KEY` 52 | 53 | ```shell 54 | export OPENAI_API_KEY= 55 | export PYTHONPATH=`pwd` 56 | ``` 57 | 58 | Run your first conversation with the agent interactively 59 | 60 | ```shell 61 | python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py -i 62 | ``` 63 | 64 | ## Example usage 65 | 66 | If you have experience with LangChain, you already know 80% of the AutoChain interfaces. 67 | 68 | AutoChain aims to make creating a new customized agent very straightforward with as few 69 | concepts as possible. 70 | Read about more [example usages](./examples.md). 71 | 72 | The most basic example uses the default chain and `ConversationalAgent`: 73 | 74 | ```python 75 | from autochain.chain.chain import Chain 76 | from autochain.memory.buffer_memory import BufferMemory 77 | from autochain.models.chat_openai import ChatOpenAI 78 | from autochain.agent.conversational_agent.conversational_agent import ConversationalAgent 79 | 80 | llm = ChatOpenAI(temperature=0) 81 | memory = BufferMemory() 82 | agent = ConversationalAgent.from_llm_and_tools(llm=llm) 83 | chain = Chain(agent=agent, memory=memory) 84 | 85 | print(chain.run("Write me a poem about AI")['message']) 86 | ``` 87 | 88 | Users could add a list of tools to the agent, similar to LangChain 89 | 90 | ```python 91 | tools = [Tool( 92 | name="Get weather", 93 | func=lambda *args, **kwargs: "Today is a sunny day", 94 | description="""This function returns the weather information""" 95 | )] 96 | 97 | memory = BufferMemory() 98 | agent = ConversationalAgent.from_llm_and_tools(llm=llm, tools=tools) 99 | chain = Chain(agent=agent, memory=memory) 100 | print(chain.run("What is the weather today")['message']) 101 | ``` 102 | 103 | AutoChain also adds support for [function calling](https://platform.openai.com/docs/guides/gpt/function-calling) 104 | for OpenAI models. It extrapolates the function spec in the OpenAI format without explicit user 105 | instruction, so users could follow the same `Tool` interface.
107 | 108 | ```python 109 | llm = ChatOpenAI(temperature=0) 110 | agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=llm, tools=tools) 111 | ``` 112 | 113 | Check out [more examples](./examples.md) under `autochain/examples` and [workflow 114 | evaluation](./workflow-evaluation.md) test cases, which can also be run interactively. 115 | 116 | ## How does AutoChain simplify building agents? 117 | 118 | AutoChain aims to provide a lightweight framework and simplifies the building process in a few 119 | ways compared with other existing frameworks 120 | 121 | 1. Easy prompt update 122 | Prompt engineering and iteration is one of the most important parts of building a generative 123 | agent. AutoChain makes it very obvious and easy to update prompts and visualize prompt 124 | outputs. Run with the `-v` flag to output verbose prompts and outputs in the console. 125 | 2. Up to 2 layers of abstraction 126 | Since the goal of AutoChain is enabling quick iterations, it chooses to remove most of the 127 | abstraction layers from alternative frameworks and make it easy to follow 128 | 3. Automated multi-turn evaluation 129 | The most painful and uncertain part of building a generative agent is how to evaluate its 130 | performance. Any change for one scenario could cause regression in other use cases. AutoChain 131 | provides an easy test framework to automatically evaluate the agent's ability under different 132 | user scenarios. 133 | 134 | Read more about detailed [components overview](./components_overview.md) 135 | 136 | ## Workflow Evaluation 137 | 138 | It is notoriously hard to evaluate generative agents in LangChain or AutoGPT. An agent's behavior 139 | is nondeterministic and susceptible to small changes to the prompt or model. It is very 140 | hard to know if the agent is behaving correctly under different scenarios. The current path for 141 | evaluation is running the agent through a large number of preset queries and evaluating the 142 | generated responses. However, that is limited to single-turn conversation, is general rather than 143 | specific to tasks, and is expensive to verify. 144 | 145 | To facilitate agent evaluation, AutoChain introduces the workflow evaluation framework. This 146 | framework runs conversations between a generative agent and LLM-simulated test users. The test 147 | users incorporate various user contexts and desired conversation outcomes, which enables easy 148 | addition of test cases for new user scenarios and fast evaluation. The framework leverages LLMs to 149 | evaluate whether a given multi-turn conversation has achieved the intended outcome. 150 | 151 | Read more about our [evaluation strategy](./workflow-evaluation.md). 152 | 153 | ### How to run workflow evaluations 154 | 155 | There are two modes for running workflow tests: interactively, or running all test cases. 156 | For example, in `autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py` 157 | there are already a few example test cases. 158 | 159 | Running all the test cases defined in the test: 160 | 161 | ```shell 162 | python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py 163 | ``` 164 | 165 | You can also have an interactive conversation with the agent by passing the interactive flag `-i`: 166 | 167 | ```shell 168 | python autochain/workflows_evaluation/conversational_agent_eval/generate_ads_test.py -i 169 | ``` 170 | 171 | More explanations of how AutoChain works? Check out the [components overview](./components_overview.md)
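
Putting the pieces above together, a minimal workflow test looks roughly like this; the tool and test case are illustrative:

```python
from autochain.tools.base import Tool
from autochain.workflows_evaluation.base_test import BaseTest, TestCase, WorkflowTester
from autochain.workflows_evaluation.test_utils import create_chain_from_test


class TestWeatherChat(BaseTest):
    prompt = "You are a weather support agent who gets weather information for the user."
    tools = [
        Tool(
            name="Get weather",
            func=lambda *args, **kwargs: "Today is a sunny day",
            description="This function returns the weather information",
        )
    ]
    test_cases = [
        TestCase(
            test_name="get weather",
            user_context="want to know today's weather",
            expected_outcome="found weather information",
        )
    ]
    chain = create_chain_from_test(tools=tools, prompt=prompt)


if __name__ == "__main__":
    tester = WorkflowTester(tests=[TestWeatherChat()], output_dir="./test_results")
    tester.run_all_tests()
```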
Check out the [components overview](./components_overview.md).
172 | 
--------------------------------------------------------------------------------
/docs/js/custom.js:
--------------------------------------------------------------------------------
1 | function setupTermynal() {
2 |     document.querySelectorAll(".use-termynal").forEach(node => {
3 |         node.style.display = "block";
4 |         new Termynal(node, {
5 |             lineDelay: 500
6 |         });
7 |     });
8 |     const progressLiteralStart = "---> 100%";
9 |     const promptLiteralStart = "$ ";
10 |     const customPromptLiteralStart = "# ";
11 |     const termynalActivateClass = "termy";
12 |     let termynals = [];
13 | 
14 |     function createTermynals() {
15 |         document
16 |             .querySelectorAll(`.${termynalActivateClass} .highlight`)
17 |             .forEach(node => {
18 |                 const text = node.textContent;
19 |                 const lines = text.split("\n");
20 |                 const useLines = [];
21 |                 let buffer = [];
22 |                 function saveBuffer() {
23 |                     if (buffer.length) {
24 |                         let isBlankSpace = true;
25 |                         buffer.forEach(line => {
26 |                             if (line) {
27 |                                 isBlankSpace = false;
28 |                             }
29 |                         });
30 |                         let dataValue = {};  // declared locally to avoid leaking a global
31 |                         if (isBlankSpace) {
32 |                             dataValue["delay"] = 0;
33 |                         }
34 |                         if (buffer[buffer.length - 1] === "") {
35 |                             // A last single <br> won't have effect
36 |                             // so put an additional one
37 |                             buffer.push("");
38 |                         }
39 |                         const bufferValue = buffer.join("<br>");
"); 40 | dataValue["value"] = bufferValue; 41 | useLines.push(dataValue); 42 | buffer = []; 43 | } 44 | } 45 | for (let line of lines) { 46 | if (line === progressLiteralStart) { 47 | saveBuffer(); 48 | useLines.push({ 49 | type: "progress" 50 | }); 51 | } else if (line.startsWith(promptLiteralStart)) { 52 | saveBuffer(); 53 | const value = line.replace(promptLiteralStart, "").trimEnd(); 54 | useLines.push({ 55 | type: "input", 56 | value: value 57 | }); 58 | } else if (line.startsWith("// ")) { 59 | saveBuffer(); 60 | const value = "💬 " + line.replace("// ", "").trimEnd(); 61 | useLines.push({ 62 | value: value, 63 | class: "termynal-comment", 64 | delay: 0 65 | }); 66 | } else if (line.startsWith(customPromptLiteralStart)) { 67 | saveBuffer(); 68 | const promptStart = line.indexOf(promptLiteralStart); 69 | if (promptStart === -1) { 70 | console.error("Custom prompt found but no end delimiter", line) 71 | } 72 | const prompt = line.slice(0, promptStart).replace(customPromptLiteralStart, "") 73 | let value = line.slice(promptStart + promptLiteralStart.length); 74 | useLines.push({ 75 | type: "input", 76 | value: value, 77 | prompt: prompt 78 | }); 79 | } else { 80 | buffer.push(line); 81 | } 82 | } 83 | saveBuffer(); 84 | const div = document.createElement("div"); 85 | node.replaceWith(div); 86 | const termynal = new Termynal(div, { 87 | lineData: useLines, 88 | noInit: true, 89 | lineDelay: 500 90 | }); 91 | termynals.push(termynal); 92 | }); 93 | } 94 | 95 | function loadVisibleTermynals() { 96 | termynals = termynals.filter(termynal => { 97 | if (termynal.container.getBoundingClientRect().top - innerHeight <= 0) { 98 | termynal.init(); 99 | return false; 100 | } 101 | return true; 102 | }); 103 | } 104 | window.addEventListener("scroll", loadVisibleTermynals); 105 | createTermynals(); 106 | loadVisibleTermynals(); 107 | } 108 | 109 | async function main() { 110 | setupTermynal() 111 | } 112 | 113 | main() 114 | -------------------------------------------------------------------------------- /docs/memory.md: -------------------------------------------------------------------------------- 1 | # Memory 2 | 3 | We have a simple memory interface to experiment with. Memory is accessible at the `Chain` level, 4 | and only at th `Chain` level, since it is the only stateful component. By default, memory saves 5 | conversation history, including the latest user query, and intermediate 6 | steps, which are `AgentAction` taken with corresponding outputs. 7 | 8 | `Chain` could collect all the memory and puts into `inputs` at `prep_inputs` step and updates 9 | memory at `prep_outputs` step. Constructed `inputs` will be passed to agent as kwargs. 10 | 11 | There are two parts of the memory, conversation history and key-value memory 12 | 13 | ## Conversation history 14 | 15 | Memory uses `ChatMessageHistory` to store all the conversation history between agent and user 16 | as instances of `BaseMessage`, including `FunctionMessage`, which is tool used and 17 | corresponding output. This make tracking all interactions easy and fit the same 18 | interface OpenAI API requires. 19 | 20 | ## Key-value memory 21 | 22 | Not only we could save conversation history, it allows saving any memory in key value pair 23 | format. By default, it saves all the `AgentActions` and corresponding outputs using key value 24 | pairs. This part is designed to be flexible and users could save anything to it with preferred 25 | storage. One way is using this as long term memory powered by internal search tool. 
29 | 
30 | ## Types of memory supported
31 | 
32 | AutoChain supports different types of memory for different use cases.
33 | 
34 | ### BufferMemory
35 | 
36 | This is the simplest implementation of memory. Everything is stored in RAM, with a Python dictionary
37 | as the key-value store. This is best suited for experimentation and iterating on prompts, and is the
38 | default type of memory AutoChain uses in examples and evaluation.
39 | 
40 | ### LongTermMemory
41 | 
42 | When a lot of information needs to be stored and only a small part of it is
43 | needed during the planning step, `LongTermMemory` enables agents to retrieve partial memory
44 | with an internal search tool, such as `ChromaDBSearch`, `PineconeSearch`, or `LanceDBSearch`. The search query is the
45 | key of the store, and it still follows the same interface as other memory implementations. These tools
46 | encode the text into a vector DB and retrieve it using the search query.
47 | 
48 | ### RedisMemory
49 | 
50 | Redis is also supported for saving information. This is useful when hosting AutoChain as a backend
51 | service on more than one server instance, in which case it's not possible to use RAM as memory.
52 | 
--------------------------------------------------------------------------------
/docs/robots.txt:
--------------------------------------------------------------------------------
1 | User-agent: *
2 | Allow: /
--------------------------------------------------------------------------------
/docs/tool.md:
--------------------------------------------------------------------------------
1 | # Tool
2 | 
3 | The ability to use tools makes the agent incredibly more powerful, as shown in LangChain and
4 | AutoGPT. We follow a similar concept of tools as LangChain here as well.
5 | All the tools in LangChain can be easily ported over to AutoChain, since they follow a very
6 | similar interface.
7 | A tool is essentially an object that implements a `run` function that takes in a dictionary of
8 | kwargs. Since input parsing can be reused, in most cases users just need to pass the
9 | callable function to create a new tool, and the LLM will generate the inputs on the fly when it
10 | needs to use the tool. As a result, the interface for `Tool` is as follows:
11 | 
12 | - **func**
13 |   The callable that is invoked by the `run` function. Its typing information is automatically
14 |   extracted when using `OpenAIFunctionsAgent`.
15 | 
16 | - **description**
17 |   A description of proper tool usage that makes it easy for the LLM to understand when it
18 |   should use this tool.
19 | 
20 | ## Other optional parameters
21 | 
22 | - **name**
23 |   The tool name, used as the identifier for the model to specify which tool to use. If this is not
24 |   provided, it defaults to the `func` name. Users might want to provide a more descriptive name if
25 |   the function name is not very obvious.
26 | 
27 | - **arg_description**
28 |   The function calling feature of OpenAI supports adding a description for each argument. Users can
29 |   pass a dictionary of argument names and descriptions using the `arg_description` parameter. They will be
30 |   formatted into the prompt when using `OpenAIFunctionsAgent`; a sketch combining these parameters follows below.
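Putting these parameters together, a minimal sketch of defining and invoking a tool (the weather function itself is hypothetical, used only for illustration):

```python
from autochain.tools.base import Tool

def get_current_weather(location: str, **kwargs):
    """Hypothetical weather lookup, for illustration only"""
    return f"It is sunny in {location}"

tool = Tool(
    name="get_weather",
    func=get_current_weather,
    description="Returns the current weather for a given location",
    arg_description={"location": "city and country, e.g. Toronto, Canada"},
)

# `run` takes a dictionary of kwargs matching the function signature
print(tool.run({"location": "Toronto, Canada"}))  # -> "It is sunny in Toronto, Canada"
```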
31 | 
32 | 
33 | ## Tools included
34 | ### GoogleSearchTool
35 | Migrated from LangChain; it also serves as an example of how users can easily migrate any tool from
36 | LangChain if needed.
37 | Users need to provide `google_api_key` and `google_cse_id` to
38 | search Google through the API. This gives the agent access to a search engine and other
39 | non-parametric information.
40 | 
41 | ### PineconeTool
42 | Internal search tool that can be used as long-term memory for the agent or for looking up relevant
43 | information that does not exist on the Internet. Currently, AutoChain supports `Pinecone` as
44 | long-term memory for the agent.
45 | 
46 | 
47 | ### ChromaDBTool
48 | Internal search tool that can be used as long-term memory for the agent or for looking up relevant
49 | information that does not exist on the Internet. Currently, AutoChain supports `ChromaDB` as
50 | long-term memory for the agent.
51 | 
52 | ### LanceDBTool
53 | Internal search tool that can be used as long-term memory for the agent or for looking up relevant
54 | information that does not exist on the Internet. Currently, AutoChain supports `LanceDB` as
55 | long-term memory for the agent. LanceDBTool is serverless, and does not require any setup.
--------------------------------------------------------------------------------
/mkdocs.insiders.yml:
--------------------------------------------------------------------------------
1 | INHERIT: mkdocs.yml
2 | plugins:
3 |   - search
4 |   - social
--------------------------------------------------------------------------------
/mkdocs.yml:
--------------------------------------------------------------------------------
1 | site_name: AutoChain
2 | site_description: "AutoChain: Build lightweight, extensible, and testable LLM Agents"
3 | # TODO: set the proper docs website URL once the domain is defined
4 | site_url: https://engineering.forethought.ai
5 | repo_url: https://github.com/Forethought-Technologies/AutoChain
6 | repo_name: AutoChain
7 | theme:
8 |   name: material
9 |   icon:
10 |     repo: fontawesome/brands/github
11 |   palette:
12 |     - scheme: slate
13 |       primary: white
14 |       accent: purple
15 |   # Disable Dark mode toggle
16 |   #     toggle:
17 |   #       icon: material/lightbulb-outline
18 |   #       name: Switch to light mode
19 |   #   - scheme: default
20 |   #     primary: white
21 |   #     accent: purple
22 |   #     toggle:
23 |   #       icon: material/lightbulb
24 |   #       name: Switch to dark mode
25 |   features:
26 |     - search.suggest
27 |     - search.highlight
28 |     - content.tabs.link
29 |   # TODO if AutoChain gets its own logo and icon, it could be put in these directories
30 |   # if the intention is to preserve Forethought's logo, remove this comment
31 |   logo: img/icon.png
32 |   favicon: img/icon.png
33 | language: en
34 | 
35 | nav:
36 |   - index.md
37 |   - examples.md
38 |   - workflow-evaluation.md
39 |   - components_overview.md
40 |   - chain.md
41 |   - agent.md
42 |   - tool.md
43 |   - memory.md
44 | 
45 | plugins:
46 |   - git-authors
47 | 
48 | markdown_extensions:
49 |   - toc:
50 |       permalink: true
51 |   - markdown.extensions.codehilite:
52 |       guess_lang: false
53 |   - admonition
54 |   - codehilite
55 |   - extra
56 |   - pymdownx.superfences:
57 |       custom_fences:
58 |         - name: mermaid
59 |           class: mermaid
60 |           format: !!python/name:pymdownx.superfences.fence_code_format ''
61 |   - pymdownx.tabbed:
62 |       alternate_style: true
63 |   - mdx_include
64 | 
65 | extra:
66 |   # TODO: do we want Google Analytics?
67 | # analytics: 68 | # provider: google 69 | # property: YY-xxxxx 70 | social: 71 | - icon: fontawesome/brands/twitter 72 | link: https://twitter.com/forethought_ai 73 | - icon: fontawesome/brands/linkedin 74 | link: https://www.linkedin.com/company/forethought-ai/ 75 | - icon: fontawesome/solid/globe 76 | link: https://forethought.ai/ 77 | 78 | extra_css: 79 | - css/termynal.css 80 | - css/custom.css 81 | 82 | extra_javascript: 83 | - js/termynal.js 84 | - js/custom.js 85 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "autochain" 3 | version = "0.0.5" 4 | description = "AutoChain: Build lightweight, extensible, and testable LLM Agents" 5 | # TODO: add URLs, homepage, documentation 6 | authors = [ 7 | "Yi Lu ", 8 | "Forethought Engineering ", 9 | ] 10 | readme = "README.md" 11 | classifiers = [ 12 | "Intended Audience :: Customer Service", 13 | "Intended Audience :: Developers", 14 | "Intended Audience :: Information Technology", 15 | "Intended Audience :: System Administrators", 16 | "Intended Audience :: Science/Research", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python :: 3", 19 | "Programming Language :: Python", 20 | "Topic :: Scientific/Engineering", 21 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 22 | "Topic :: Software Development :: Libraries :: Application Frameworks", 23 | "Topic :: Software Development :: Libraries :: Python Modules", 24 | "Topic :: Software Development :: Libraries", 25 | "Topic :: Software Development", 26 | "Typing :: Typed", 27 | "Development Status :: 4 - Beta", 28 | "Framework :: Pydantic", 29 | "Framework :: Pydantic :: 1", 30 | # TODO: define license 31 | # "License :: OSI Approved :: MIT License", 32 | "Programming Language :: Python :: 3 :: Only", 33 | "Programming Language :: Python :: 3.8", 34 | "Programming Language :: Python :: 3.9", 35 | "Programming Language :: Python :: 3.10", 36 | "Programming Language :: Python :: 3.11", 37 | "Programming Language :: Python :: 3.12", 38 | ] 39 | 40 | [tool.poetry.dependencies] 41 | python = "^3.8.1" 42 | colorama = ">=0.4.6" 43 | pydantic = "^1.10.9" 44 | google-api-python-client = {version = ">=2.89.0", optional = true} 45 | chromadb = ">=0.3.26" 46 | lancedb = {version = ">=0.1.16", optional = true} 47 | pandas = ">=2.0.2" 48 | openai = ">=0.27.8" 49 | types-colorama = ">=0.4.15.11" 50 | pytest = "^7.3.2" 51 | redis = "^4.6.0" 52 | pinecone-client = {version = "^2.2.2", optional = true} 53 | mkdocs-git-authors-plugin = "^0.7.2" 54 | tenacity = "^8.2.2" 55 | 56 | 57 | [tool.poetry.group.dev.dependencies] 58 | mypy = "^1.3.0" 59 | black = "^23.3.0" 60 | ruff = "^0.0.272" 61 | mkdocs = "^1.4.2" 62 | mkdocs-material = "^8.5.7" 63 | mdx-include = "^1.4.2" 64 | Pillow = "^9.3.0" 65 | CairoSVG = "^2.5.2" 66 | pytest = "^7.3.2" 67 | coverage = {extras = ["toml"], version = "^7.3.0"} 68 | 69 | [build-system] 70 | requires = ["poetry-core"] 71 | build-backend = "poetry.core.masonry.api" 72 | 73 | [tool.poetry.extras] 74 | google = ["google-api-python-client"] 75 | pinecone= ["pinecone-client"] 76 | lancedb = ["lancedb"] 77 | 78 | [tool.mypy] 79 | strict = true 80 | 81 | [[tool.mypy.overrides]] 82 | module = [ 83 | "googleapiclient.discovery", 84 | "chromadb", 85 | "chromadb.api", 86 | ] 87 | ignore_missing_imports = true 88 | 89 | [tool.ruff] 90 | select = [ 91 | "E", # pycodestyle errors 92 | "W", # 
pycodestyle warnings 93 | "F", # pyflakes 94 | "I", # isort 95 | "C", # flake8-comprehensions 96 | "B", # flake8-bugbear 97 | ] 98 | ignore = [ 99 | "E501", # line too long, handled by black 100 | "B008", # do not perform function calls in argument defaults 101 | "C901", # too complex 102 | ] 103 | 104 | [tool.ruff.per-file-ignores] 105 | # "__init__.py" = ["F401"] 106 | 107 | [tool.ruff.isort] 108 | known-third-party = ["autochain"] 109 | -------------------------------------------------------------------------------- /test_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .pinecone_mocks import DummyEncoder -------------------------------------------------------------------------------- /test_utils/pinecone_mocks.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | from unittest import mock 3 | 4 | import pytest 5 | 6 | from autochain.agent.message import BaseMessage 7 | from autochain.models.base import BaseLanguageModel, LLMResult, EmbeddingResult 8 | from autochain.tools.base import Tool 9 | 10 | 11 | class MockIndex: 12 | def __init__(self): 13 | self.kv = {} 14 | 15 | def upsert(self, id_vectors, *args, **kwargs): 16 | for id, vector in id_vectors: 17 | self.kv[id] = vector 18 | 19 | def query(self, vector, *args, **kwargs): 20 | for id, v in self.kv.items(): 21 | if vector == v: 22 | return { 23 | "matches": [ 24 | { 25 | "id": id, 26 | "score": 0.9, 27 | } 28 | ], 29 | "namespace": "", 30 | } 31 | else: 32 | return {} 33 | 34 | 35 | class DummyEncoder(BaseLanguageModel): 36 | def generate( 37 | self, 38 | messages: List[BaseMessage], 39 | functions: Optional[List[Tool]] = None, 40 | stop: Optional[List[str]] = None, 41 | ) -> LLMResult: 42 | pass 43 | 44 | def encode(self, texts: List[str]) -> EmbeddingResult: 45 | return EmbeddingResult( 46 | texts=texts, 47 | embeddings=[ 48 | [-0.025949304923415184, -0.012664584442973137, 0.017791053280234337] 49 | ], 50 | ) 51 | 52 | 53 | @pytest.fixture 54 | def pinecone_index_fixture(): 55 | with mock.patch( 56 | "pinecone.create_index", 57 | return_value=None, 58 | ), mock.patch( 59 | "pinecone.Index", 60 | return_value=MockIndex(), 61 | ), mock.patch( 62 | "pinecone.delete_index", 63 | return_value=None, 64 | ): 65 | yield 66 | -------------------------------------------------------------------------------- /tests/agent/test_conversational_agent.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from unittest import mock 4 | 5 | import pytest 6 | 7 | from autochain.agent.conversational_agent.conversational_agent import ( 8 | ConversationalAgent, 9 | ) 10 | from autochain.agent.message import ( 11 | ChatMessageHistory, 12 | MessageType, 13 | ) 14 | from autochain.agent.structs import AgentFinish 15 | 16 | from autochain.models.chat_openai import ChatOpenAI 17 | from autochain.tools.simple_handoff.tool import HandOffToAgent 18 | 19 | 20 | @pytest.fixture 21 | def openai_should_answer_fixture(): 22 | with mock.patch( 23 | "autochain.models.chat_openai.ChatOpenAI.generate_with_retry", 24 | side_effect=side_effect, 25 | ): 26 | yield 27 | 28 | 29 | def side_effect(*args, **kwargs): 30 | message = kwargs["messages"][0]["content"] 31 | 32 | if "good" in message: 33 | return { 34 | "choices": [ 35 | { 36 | "message": { 37 | "role": "assistant", 38 | "content": "yes, question is resolved", 39 | } 40 | } 41 | ], 42 | "usage": 10, 43 | } 44 | else: 45 | 
        return {
46 |             "choices": [
47 |                 {
48 |                     "message": {
49 |                         "role": "assistant",
50 |                         "content": "no, question is not resolved",
51 |                     }
52 |                 }
53 |             ],
54 |             "usage": 10,
55 |         }
56 | 
57 | 
58 | @pytest.fixture
59 | def openai_response_fixture():
60 |     with mock.patch(
61 |         "autochain.models.chat_openai.ChatOpenAI.generate_with_retry",
62 |         return_value={
63 |             "choices": [
64 |                 {
65 |                     "message": {
66 |                         "role": "assistant",
67 |                         "content": json.dumps(
68 |                             {
69 |                                 "thoughts": {
70 |                                     "plan": "Given workflow policy and previous tools outputs",
71 |                                     "need_use_tool": "Yes if needs to use another tool not used in previous tools outputs else No",
72 |                                 },
73 |                                 "tool": {"name": "", "args": {"arg_name": ""}},
74 |                                 "response": "response to user",
75 |                                 "workflow_finished": "No",
76 |                             }
77 |                         ),
78 |                     }
79 |                 }
80 |             ],
81 |             "usage": 10,
82 |         },
83 |     ):
84 |         yield
85 | 
86 | 
87 | def test_should_answer_prompt(openai_should_answer_fixture):
88 |     os.environ["OPENAI_API_KEY"] = "mock_api_key"
89 |     agent = ConversationalAgent.from_llm_and_tools(llm=ChatOpenAI(), tools=[])
90 | 
91 |     history = ChatMessageHistory()
92 |     history.save_message("good user query", MessageType.UserMessage)
93 |     inputs = {"history": history}
94 |     response = agent.should_answer(**inputs)
95 |     assert isinstance(response, AgentFinish)
96 | 
97 |     history = ChatMessageHistory()
98 |     history.save_message("bad user query", MessageType.UserMessage)
99 |     inputs = {"history": history}
100 |     agent = ConversationalAgent(llm=ChatOpenAI(), tools=[])
101 |     response = agent.should_answer(**inputs)
102 |     assert response is None
103 | 
104 | 
105 | def test_plan(openai_response_fixture):
106 |     os.environ["OPENAI_API_KEY"] = "mock_api_key"
107 |     agent = ConversationalAgent.from_llm_and_tools(
108 |         llm=ChatOpenAI(), tools=[HandOffToAgent()]
109 |     )
110 | 
111 |     history = ChatMessageHistory()
112 |     history.save_message("first user query", MessageType.UserMessage)
113 |     history.save_message("assistant response", MessageType.AIMessage)
114 |     history.save_message("second user query", MessageType.UserMessage)
115 | 
116 |     action = agent.plan(history=history, intermediate_steps=[])
117 |     assert isinstance(action, AgentFinish)
118 | 
--------------------------------------------------------------------------------
/tests/agent/test_openai_functions_agent.py:
--------------------------------------------------------------------------------
1 | from unittest import mock
2 | 
3 | import pytest
4 | from autochain.agent.message import (
5 |     ChatMessageHistory,
6 |     MessageType,
7 | )
8 | from autochain.agent.openai_functions_agent.openai_functions_agent import (
9 |     OpenAIFunctionsAgent,
10 | )
11 | from autochain.agent.structs import AgentAction, AgentFinish
12 | from autochain.models.chat_openai import ChatOpenAI
13 | 
14 | 
15 | @pytest.fixture
16 | def openai_function_calling_fixture():
17 |     with mock.patch(
18 |         "autochain.models.chat_openai.ChatOpenAI.generate_with_retry",
19 |         return_value={
20 |             "choices": [
21 |                 {
22 |                     "message": {
23 |                         "role": "assistant",
24 |                         "content": None,
25 |                         "function_call": {
26 |                             "name": "get_current_weather",
27 |                             "arguments": '{\n "location": "Toronto, Canada",\n "format": "celsius"\n}',
28 |                         },
29 |                     }
30 |                 }
31 |             ],
32 |             "usage": 10,
33 |         },
34 |     ):
35 |         yield
36 | 
37 | 
38 | @pytest.fixture
39 | def openai_response_fixture():
40 |     with mock.patch(
41 |         "autochain.models.chat_openai.ChatOpenAI.generate_with_retry",
42 |         return_value={
43 |             "choices": [
44 |                 {
45 |                     "message": {
46 |                         "role": "assistant",
47 |                         "content": "Sure, let me get that
information for you.", 48 | } 49 | } 50 | ], 51 | "usage": 10, 52 | }, 53 | ): 54 | yield 55 | 56 | 57 | @pytest.fixture 58 | def openai_estimate_confidence_fixture(): 59 | with mock.patch( 60 | "autochain.models.chat_openai.ChatOpenAI.generate_with_retry", 61 | return_value={ 62 | "choices": [ 63 | { 64 | "message": { 65 | "role": "assistant", 66 | "content": "the confidence is 4.", 67 | } 68 | } 69 | ], 70 | "usage": 10, 71 | }, 72 | ): 73 | yield 74 | 75 | 76 | @pytest.fixture 77 | def is_generation_confident_fixture(): 78 | with mock.patch( 79 | "autochain.agent.openai_functions_agent.openai_functions_agent.OpenAIFunctionsAgent.is_generation_confident", 80 | return_value=True, 81 | ): 82 | yield 83 | 84 | 85 | def test_function_calling_plan( 86 | openai_function_calling_fixture, is_generation_confident_fixture 87 | ): 88 | agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=ChatOpenAI(), tools=[]) 89 | 90 | history = ChatMessageHistory() 91 | history.save_message("first user query", MessageType.UserMessage) 92 | history.save_message("assistant response", MessageType.AIMessage) 93 | history.save_message("second user query", MessageType.UserMessage) 94 | 95 | action = agent.plan(history=history, intermediate_steps=[]) 96 | assert isinstance(action, AgentAction) 97 | assert action.tool == "get_current_weather" 98 | 99 | 100 | def test_response_plan(openai_response_fixture, is_generation_confident_fixture): 101 | agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=ChatOpenAI(), tools=[]) 102 | 103 | history = ChatMessageHistory() 104 | history.save_message("first user query", MessageType.UserMessage) 105 | history.save_message("assistant response", MessageType.AIMessage) 106 | history.save_message("second user query", MessageType.UserMessage) 107 | 108 | action = agent.plan(history=history, intermediate_steps=[]) 109 | assert isinstance(action, AgentFinish) 110 | 111 | 112 | def test_estimate_confidence(openai_estimate_confidence_fixture): 113 | agent = OpenAIFunctionsAgent.from_llm_and_tools(llm=ChatOpenAI(), tools=[]) 114 | 115 | history = ChatMessageHistory() 116 | history.save_message("first user query", MessageType.UserMessage) 117 | history.save_message("assistant response", MessageType.AIMessage) 118 | history.save_message("second user query", MessageType.UserMessage) 119 | 120 | agent_finish_output = AgentFinish(message="agent response", log="sample log") 121 | is_confident = agent.is_generation_confident( 122 | history=history, agent_output=agent_finish_output, min_confidence=3 123 | ) 124 | assert is_confident 125 | 126 | is_confident = agent.is_generation_confident( 127 | history=history, agent_output=agent_finish_output, min_confidence=5 128 | ) 129 | assert is_confident is False 130 | 131 | agent_action_output = AgentAction( 132 | tool="get_current_weather", 133 | tool_input={"location": "Toronto, Canada", "format": "celsius"}, 134 | ) 135 | is_confident = agent.is_generation_confident( 136 | history=history, agent_output=agent_action_output, min_confidence=3 137 | ) 138 | 139 | assert is_confident 140 | -------------------------------------------------------------------------------- /tests/memory/test_buffer_memory.py: -------------------------------------------------------------------------------- 1 | from autochain.agent.message import MessageType 2 | from autochain.memory.buffer_memory import BufferMemory 3 | 4 | 5 | def test_buffer_kv_memory(): 6 | memory = BufferMemory() 7 | memory.save_memory(key="k", value="v") 8 | value = memory.load_memory(key="k") 9 | assert value 
== "v" 10 | 11 | default_value = memory.load_memory(key="k2", default="v2") 12 | assert default_value == "v2" 13 | 14 | memory.clear() 15 | assert memory.load_memory(key="k") is None 16 | 17 | 18 | def test_buffer_conversation_memory(): 19 | memory = BufferMemory() 20 | memory.save_conversation("user query", MessageType.UserMessage) 21 | memory.save_conversation("response to user", MessageType.AIMessage) 22 | 23 | conversation = memory.load_conversation().format_message() 24 | assert conversation == "User: user query\nAssistant: response to user\n" 25 | 26 | memory.clear() 27 | message_after_clear = memory.load_conversation().format_message() 28 | assert message_after_clear == "" 29 | -------------------------------------------------------------------------------- /tests/memory/test_long_term_memory.py: -------------------------------------------------------------------------------- 1 | from autochain.agent.message import MessageType 2 | from autochain.memory.long_term_memory import LongTermMemory 3 | from autochain.tools.internal_search.chromadb_tool import ChromaDoc, ChromaDBSearch 4 | from autochain.tools.internal_search.pinecone_tool import PineconeSearch, PineconeDoc 5 | from autochain.tools.internal_search.lancedb_tool import LanceDBSeach, LanceDBDoc 6 | from test_utils.pinecone_mocks import DummyEncoder, pinecone_index_fixture 7 | 8 | 9 | def test_long_term_kv_memory_chromadb(): 10 | memory = LongTermMemory( 11 | long_term_memory=ChromaDBSearch(docs=[], description="long term memory") 12 | ) 13 | memory.save_memory(key="k", value="v") 14 | value = memory.load_memory(key="k") 15 | assert value == "v" 16 | 17 | default_value = memory.load_memory(key="k2", default="v2") 18 | assert default_value == "v2" 19 | 20 | memory.clear() 21 | assert memory.load_memory(key="k") is None 22 | 23 | 24 | def test_buffer_conversation_memory(): 25 | memory = LongTermMemory( 26 | long_term_memory=ChromaDBSearch(docs=[], description="long term memory") 27 | ) 28 | memory.save_conversation("user query", MessageType.UserMessage) 29 | memory.save_conversation("response to user", MessageType.AIMessage) 30 | 31 | conversation = memory.load_conversation().format_message() 32 | assert conversation == "User: user query\nAssistant: response to user\n" 33 | 34 | memory.clear() 35 | message_after_clear = memory.load_conversation().format_message() 36 | assert message_after_clear == "" 37 | 38 | 39 | def test_long_term_memory(): 40 | d = ChromaDoc("This is document1", metadata={"source": "notion"}) 41 | memory = LongTermMemory( 42 | long_term_memory=ChromaDBSearch(docs=[], description="long term memory") 43 | ) 44 | memory.save_memory(key="", value=[d]) 45 | 46 | value = memory.load_memory(key="document query") 47 | assert value == "Doc 0: This is document1" 48 | 49 | 50 | def test_long_term_kv_memory_pincode(pinecone_index_fixture): 51 | memory = LongTermMemory( 52 | long_term_memory=PineconeSearch( 53 | docs=[], description="long term memory", encoder=DummyEncoder() 54 | ) 55 | ) 56 | memory.save_memory(key="k", value="v") 57 | value = memory.load_memory(key="k") 58 | assert value == "v" 59 | 60 | default_value = memory.load_memory(key="k2", default="v2") 61 | assert default_value == "v2" 62 | 63 | memory.clear() 64 | assert memory.load_memory(key="k") is None 65 | 66 | 67 | def test_buffer_conversation_memory_pinecone(pinecone_index_fixture): 68 | memory = LongTermMemory( 69 | long_term_memory=PineconeSearch( 70 | docs=[], description="long term memory", encoder=DummyEncoder() 71 | ) 72 | ) 73 | 
memory.save_conversation("user query", MessageType.UserMessage) 74 | memory.save_conversation("response to user", MessageType.AIMessage) 75 | 76 | conversation = memory.load_conversation().format_message() 77 | assert conversation == "User: user query\nAssistant: response to user\n" 78 | 79 | memory.clear() 80 | message_after_clear = memory.load_conversation().format_message() 81 | assert message_after_clear == "" 82 | 83 | 84 | def test_long_term_memory_pinecone(pinecone_index_fixture): 85 | d = PineconeDoc( 86 | "This is document1", 87 | ) 88 | memory = LongTermMemory( 89 | long_term_memory=PineconeSearch( 90 | docs=[], description="long term memory", encoder=DummyEncoder() 91 | ) 92 | ) 93 | memory.save_memory(key="", value=[d]) 94 | 95 | value = memory.load_memory(key="document query") 96 | assert value == "Doc 0: This is document1" 97 | 98 | def test_long_term_kv_memory_lancedb(): 99 | memory = LongTermMemory( 100 | long_term_memory=LanceDBSeach( 101 | docs=[], description="long term memory", encoder=DummyEncoder() 102 | ) 103 | ) 104 | memory.save_memory(key="k", value="v") 105 | value = memory.load_memory(key="k") 106 | assert value == "v" 107 | 108 | default_value = memory.load_memory(key="k2", default="v2") 109 | assert default_value == "v2" 110 | 111 | memory.clear() 112 | assert memory.load_memory(key="k") is None 113 | 114 | 115 | def test_buffer_conversation_memory_lancedb(): 116 | memory = LongTermMemory( 117 | long_term_memory=LanceDBSeach( 118 | docs=[], description="long term memory", encoder=DummyEncoder() 119 | ) 120 | ) 121 | memory.save_conversation("user query", MessageType.UserMessage) 122 | memory.save_conversation("response to user", MessageType.AIMessage) 123 | 124 | conversation = memory.load_conversation().format_message() 125 | assert conversation == "User: user query\nAssistant: response to user\n" 126 | 127 | memory.clear() 128 | message_after_clear = memory.load_conversation().format_message() 129 | assert message_after_clear == "" 130 | 131 | 132 | def test_long_term_memory_lancedb(): 133 | d = LanceDBDoc( 134 | "This is document1", 135 | ) 136 | memory = LongTermMemory( 137 | long_term_memory=LanceDBSeach( 138 | docs=[], description="long term memory", encoder=DummyEncoder() 139 | ) 140 | ) 141 | memory.save_memory(key="", value=[d]) 142 | 143 | value = memory.load_memory(key="document query") 144 | assert value == "Doc 0: This is document1" 145 | -------------------------------------------------------------------------------- /tests/memory/test_redis_memory.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | from unittest.mock import MagicMock 3 | 4 | from autochain.agent.message import AIMessage, MessageType, UserMessage 5 | from autochain.memory.redis_memory import RedisMemory 6 | from redis.client import Redis 7 | 8 | 9 | def test_redis_kv_memory(): 10 | mock_redis = MagicMock(spec=Redis) 11 | pickled = pickle.dumps("v") 12 | mock_redis.get.side_effect = [pickled, None, None] 13 | 14 | memory = RedisMemory(redis_key_prefix="test", redis_client=mock_redis) 15 | 16 | memory.save_memory(key="k", value="v") 17 | value = memory.load_memory(key="k") 18 | assert value == "v" 19 | 20 | default_value = memory.load_memory(key="k2", default="v2") 21 | assert default_value == "v2" 22 | 23 | memory.clear() 24 | assert memory.load_memory(key="k") is None 25 | 26 | 27 | def test_redis_conversation_memory(): 28 | mock_redis = MagicMock(spec=Redis) 29 | user_query = "user query" 30 | ai_response = "response to 
user" 31 | user_message = UserMessage(content=user_query) 32 | ai_message = AIMessage(content=ai_response) 33 | mock_redis.get.side_effect = [ 34 | None, 35 | None, 36 | pickle.dumps([user_message, ai_message]), 37 | None, 38 | ] 39 | 40 | memory = RedisMemory(redis_key_prefix="test", redis_client=mock_redis) 41 | memory.save_conversation(user_query, MessageType.UserMessage) 42 | memory.save_conversation(ai_response, MessageType.AIMessage) 43 | 44 | conversation = memory.load_conversation().format_message() 45 | assert conversation == "User: user query\nAssistant: response to user\n" 46 | 47 | memory.clear() 48 | message_after_clear = memory.load_conversation().format_message() 49 | assert message_after_clear == "" 50 | -------------------------------------------------------------------------------- /tests/models/test_chat_openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import mock 3 | 4 | import pytest 5 | from autochain.tools.base import Tool 6 | 7 | from autochain.agent.message import UserMessage 8 | from autochain.models.base import LLMResult 9 | from autochain.models.chat_openai import ChatOpenAI, convert_tool_to_dict 10 | 11 | 12 | def sample_tool_func_no_type(k, *arg, **kwargs): 13 | return f"run with {k}" 14 | 15 | 16 | def sample_tool_func_with_type(k: int, *arg, **kwargs): 17 | return str(k + 1) 18 | 19 | 20 | def sample_tool_func_with_type_default(k: int, d: int = 1, *arg, **kwargs): 21 | return str(k + d + 1) 22 | 23 | 24 | @pytest.fixture 25 | def openai_completion_fixture(): 26 | with mock.patch( 27 | "openai.ChatCompletion.create", 28 | return_value={ 29 | "choices": [ 30 | {"message": {"role": "assistant", "content": "generated message"}} 31 | ], 32 | "usage": 10, 33 | }, 34 | ): 35 | yield 36 | 37 | 38 | def test_chat_completion(openai_completion_fixture): 39 | os.environ["OPENAI_API_KEY"] = "mock_api_key" 40 | model = ChatOpenAI(temperature=0) 41 | response = model.generate([UserMessage(content="test message")]) 42 | assert isinstance(response, LLMResult) 43 | assert len(response.generations) == 1 44 | assert response.generations[0].message.content == "generated message" 45 | 46 | 47 | def test_convert_tool_to_dict(): 48 | no_type_tool = Tool( 49 | func=sample_tool_func_no_type, 50 | description="""This is just a dummy tool without typing info""", 51 | ) 52 | 53 | tool_dict = convert_tool_to_dict(no_type_tool) 54 | 55 | assert tool_dict == { 56 | "name": "sample_tool_func_no_type", 57 | "description": "This is just a " "dummy tool without typing info", 58 | "parameters": { 59 | "type": "object", 60 | "properties": {"k": {"type": "string"}}, 61 | "required": ["k"], 62 | }, 63 | } 64 | 65 | with_type_tool = Tool( 66 | func=sample_tool_func_with_type, 67 | description="""This is just a dummy tool with typing info""", 68 | ) 69 | 70 | with_type_tool_dict = convert_tool_to_dict(with_type_tool) 71 | assert with_type_tool_dict == { 72 | "name": "sample_tool_func_with_type", 73 | "description": "This is just a dummy tool with typing info", 74 | "parameters": { 75 | "type": "object", 76 | "properties": {"k": {"type": "int"}}, 77 | "required": ["k"], 78 | }, 79 | } 80 | 81 | with_type_default_tool = Tool( 82 | func=sample_tool_func_with_type_default, 83 | description="""This is just a dummy tool with typing info""", 84 | ) 85 | 86 | with_type_default_tool_dict = convert_tool_to_dict(with_type_default_tool) 87 | assert with_type_default_tool_dict == { 88 | "name": "sample_tool_func_with_type_default", 89 | 
"description": "This is just a dummy tool with typing info", 90 | "parameters": { 91 | "type": "object", 92 | "properties": {"k": {"type": "int"}, "d": {"type": "int"}}, 93 | "required": ["k"], 94 | }, 95 | } 96 | 97 | with_type_and_desp_tool = Tool( 98 | func=sample_tool_func_with_type, 99 | description="""This is just a dummy tool with typing info""", 100 | arg_description={"k": "key of the arg"}, 101 | ) 102 | 103 | with_type_and_desp_tool_dict = convert_tool_to_dict(with_type_and_desp_tool) 104 | assert with_type_and_desp_tool_dict == { 105 | "name": "sample_tool_func_with_type", 106 | "description": "This is just a dummy tool with typing info", 107 | "parameters": { 108 | "type": "object", 109 | "properties": {"k": {"type": "int", "description": "key of the arg"}}, 110 | "required": ["k"], 111 | }, 112 | } 113 | -------------------------------------------------------------------------------- /tests/models/test_openai_ada_encoder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import mock 3 | 4 | import pytest 5 | 6 | from autochain.models.ada_embedding import OpenAIAdaEncoder 7 | from autochain.models.base import EmbeddingResult 8 | 9 | 10 | @pytest.fixture 11 | def ada_encoding_fixture(): 12 | with mock.patch( 13 | "openai.Embedding.create", 14 | return_value={ 15 | "object": "list", 16 | "data": [ 17 | { 18 | "object": "embedding", 19 | "index": 0, 20 | "embedding": [ 21 | -0.025949304923415184, 22 | -0.012664584442973137, 23 | 0.017791053280234337, 24 | ], 25 | } 26 | ], 27 | "model": "text-embedding-ada-002-v2", 28 | "usage": {"prompt_tokens": 2, "total_tokens": 2}, 29 | }, 30 | ): 31 | yield 32 | 33 | 34 | def test_ada_encoder(ada_encoding_fixture): 35 | text = "example text" 36 | os.environ["OPENAI_API_KEY"] = "mock_api_key" 37 | 38 | encoder = OpenAIAdaEncoder(temperature=0) 39 | response = encoder.encode([text]) 40 | 41 | assert response 42 | assert isinstance(response, EmbeddingResult) 43 | assert response.texts[0] == text 44 | assert len(response.embeddings[0]) > 0 45 | assert response.embeddings[0] == [ 46 | -0.025949304923415184, 47 | -0.012664584442973137, 48 | 0.017791053280234337, 49 | ] 50 | -------------------------------------------------------------------------------- /tests/tools/test_base_tool.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from autochain.tools.base import Tool 4 | 5 | 6 | def sample_tool_func(k, *arg, **kwargs): 7 | return f"run with {k}" 8 | 9 | 10 | def test_run_tool(): 11 | tool = Tool( 12 | func=sample_tool_func, 13 | description="""This is just a dummy tool""", 14 | ) 15 | 16 | output = tool.run("test") 17 | assert output == "run with test" 18 | 19 | 20 | def test_tool_name_override(): 21 | new_test_name = "new_name" 22 | tool = Tool( 23 | name=new_test_name, 24 | func=sample_tool_func, 25 | description="""This is just a dummy tool""", 26 | ) 27 | 28 | assert tool.name == new_test_name 29 | 30 | 31 | def test_arg_description(): 32 | valid_arg_description = {"k": "key of the arg"} 33 | 34 | invalid_arg_description = {"not_k": "key of the arg"} 35 | 36 | _ = Tool( 37 | func=sample_tool_func, 38 | description="""This is just a dummy tool""", 39 | arg_description=valid_arg_description, 40 | ) 41 | 42 | with pytest.raises(ValueError): 43 | _ = Tool( 44 | func=sample_tool_func, 45 | description="""This is just a dummy tool""", 46 | arg_description=invalid_arg_description, 47 | ) 48 | 
-------------------------------------------------------------------------------- /tests/tools/test_chromadb_tool.py: -------------------------------------------------------------------------------- 1 | from autochain.tools.internal_search.chromadb_tool import ChromaDBSearch, ChromaDoc 2 | 3 | 4 | def test_chromadb_tool_run(): 5 | d1 = ChromaDoc("This is document1", metadata={"source": "notion"}) 6 | 7 | d2 = ChromaDoc("This is document2", metadata={"source": "google-docs"}) 8 | 9 | t = ChromaDBSearch( 10 | docs=[d1, d2], name="internal_search", description="internal search" 11 | ) 12 | output = t.run({"query": "This is a query document", "n_results": 2}) 13 | assert output == "Doc 0: This is document1\nDoc 1: This is document2" 14 | -------------------------------------------------------------------------------- /tests/tools/test_google_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | from unittest import mock 3 | 4 | import pytest 5 | 6 | from autochain.tools.google_search.util import GoogleSearchAPIWrapper 7 | 8 | 9 | @pytest.fixture 10 | def google_search_fixture(): 11 | with mock.patch( 12 | "autochain.tools.google_search.util.GoogleSearchAPIWrapper._google_search_results", 13 | return_value=[{"snippet": "Barack Hussein Obama II"}], 14 | ): 15 | yield 16 | 17 | 18 | def test_google_search(google_search_fixture) -> None: 19 | """Test that call gives the correct answer.""" 20 | os.environ["GOOGLE_API_KEY"] = "mock_api_key" 21 | os.environ["GOOGLE_CSE_ID"] = "mock_cse_id" 22 | search = GoogleSearchAPIWrapper() 23 | output = search.run("What was Obama's first name?") 24 | assert "Barack Hussein Obama II" in output 25 | -------------------------------------------------------------------------------- /tests/tools/test_lancedb_tool.py: -------------------------------------------------------------------------------- 1 | from autochain.tools.internal_search.lancedb_tool import LanceDBDoc, LanceDBSeach 2 | from test_utils import DummyEncoder 3 | 4 | 5 | def test_lancedb_search(): 6 | docs = [LanceDBDoc(doc="test_document", id="A")] 7 | 8 | lancedb_search = LanceDBSeach( 9 | uri="lancedb", 10 | description="internal search with lancedb", 11 | docs=docs, 12 | encoder=DummyEncoder(), 13 | ) 14 | assert lancedb_search.docs[0].vector == [ 15 | -0.025949304923415184, 16 | -0.012664584442973137, 17 | 0.017791053280234337, 18 | ] 19 | assert lancedb_search.run({"query": "test question"}) == "Doc 0: test_document" 20 | -------------------------------------------------------------------------------- /tests/tools/test_pinecone_tool.py: -------------------------------------------------------------------------------- 1 | from autochain.tools.internal_search.pinecone_tool import PineconeSearch, PineconeDoc 2 | from test_utils.pinecone_mocks import ( 3 | DummyEncoder, 4 | pinecone_index_fixture, 5 | ) 6 | 7 | 8 | def test_pinecone_search(pinecone_index_fixture): 9 | docs = [PineconeDoc(doc="test_document", id="A")] 10 | 11 | pinecone_search = PineconeSearch( 12 | name="pinecone_search", 13 | description="internal search with pinecone", 14 | docs=docs, 15 | encoder=DummyEncoder(), 16 | ) 17 | assert pinecone_search.docs[0].vector == [ 18 | -0.025949304923415184, 19 | -0.012664584442973137, 20 | 0.017791053280234337, 21 | ] 22 | assert pinecone_search.run({"query": "test question"}) == "Doc 0: test_document" 23 | -------------------------------------------------------------------------------- /tests/tools/test_simple_handoff.py: 
-------------------------------------------------------------------------------- 1 | from autochain.tools.simple_handoff.tool import HandOffToAgent 2 | 3 | 4 | def test_simple_handoff() -> None: 5 | handoff = HandOffToAgent() 6 | msg = handoff.run() 7 | assert handoff.handoff_msg == msg 8 | --------------------------------------------------------------------------------