├── .github ├── dependabot.yml ├── pr-labeler.yml ├── release-drafter.yml └── workflows │ ├── draft.yml │ ├── pr_labeler.yml │ └── pypi_deploy.yaml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── agent_dingo ├── __init__.py ├── agent │ ├── __init__.py │ ├── agent.py │ ├── chat_context.py │ ├── docgen.py │ ├── function_descriptor.py │ ├── helpers.py │ ├── langchain.py │ ├── parser.py │ └── registry.py ├── core │ ├── blocks.py │ ├── message.py │ ├── output_parser.py │ └── state.py ├── llm │ ├── gemini.py │ ├── litellm.py │ ├── llama_cpp.py │ └── openai.py ├── rag │ ├── base.py │ ├── chunkers │ │ ├── __init__.py │ │ └── recursive.py │ ├── embedders │ │ ├── __init__.py │ │ ├── openai.py │ │ └── sentence_transformer.py │ ├── prompt_modifiers.py │ ├── readers │ │ ├── __init__.py │ │ ├── list.py │ │ ├── pdf.py │ │ ├── web.py │ │ └── word.py │ └── vector_stores │ │ ├── chromadb.py │ │ └── qdrant.py ├── serve.py └── utils.py ├── pyproject.toml └── tests ├── __init__.py ├── fake_llm.py ├── test_agent ├── __init__.py ├── test_agent.py ├── test_extract_substr.py ├── test_get_required_args.py ├── test_parser.py └── test_registry.py └── test_core ├── __init__.py ├── test_blocks.py ├── test_message.py ├── test_output_parser.py └── test_state.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | - package-ecosystem: "github-actions" 8 | directory: "/" 9 | schedule: 10 | interval: "monthly" 11 | -------------------------------------------------------------------------------- /.github/pr-labeler.yml: -------------------------------------------------------------------------------- 1 | feature: ['features/*', 'feature/*', 'feat/*', 'features-*', 'feature-*', 'feat-*'] 2 | fix: ['fixes/*', 'fix/*'] 3 | chore: ['chore/*'] 4 | dependencies: ['update/*'] 5 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | categories: 4 | - title: "🚀 Features" 5 | labels: 6 | - "feature" 7 | - "enhancement" 8 | - title: "🐛 Bug Fixes" 9 | labels: 10 | - "fix" 11 | - "bugfix" 12 | - "bug" 13 | - title: "🧹 Maintenance" 14 | labels: 15 | - "maintenance" 16 | - "dependencies" 17 | - "refactoring" 18 | - "cosmetic" 19 | - "chore" 20 | - title: "📝️ Documentation" 21 | labels: 22 | - "documentation" 23 | - "docs" 24 | change-template: "- $TITLE (#$NUMBER)" 25 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 26 | version-resolver: 27 | major: 28 | labels: 29 | - "major" 30 | minor: 31 | labels: 32 | - "minor" 33 | patch: 34 | labels: 35 | - "patch" 36 | default: patch 37 | template: | 38 | ## Changes 39 | 40 | $CHANGES 41 | -------------------------------------------------------------------------------- /.github/workflows/draft.yml: -------------------------------------------------------------------------------- 1 | # Drafts the next Release notes as Pull Requests are merged (or commits are pushed) into "main" or "master" 2 | name: Draft next release 3 | 4 | on: 5 | push: 6 | branches: [main, "master"] 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | update-release-draft: 13 | permissions: 14 | contents: write 15 | pull-requests: write 16 | runs-on: ubuntu-latest 17 | 
steps: 18 | - uses: release-drafter/release-drafter@v5 19 | env: 20 | GITHUB_TOKEN: ${{ github.token }} 21 | -------------------------------------------------------------------------------- /.github/workflows/pr_labeler.yml: -------------------------------------------------------------------------------- 1 | # This workflow will apply the corresponding label on a pull request 2 | name: PR Labeler 3 | 4 | on: 5 | pull_request_target: 6 | 7 | permissions: 8 | contents: read 9 | pull-requests: write 10 | 11 | jobs: 12 | pr-labeler: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: TimonVS/pr-labeler-action@v4 16 | with: 17 | repo-token: ${{ github.token }} 18 | -------------------------------------------------------------------------------- /.github/workflows/pypi_deploy.yaml: -------------------------------------------------------------------------------- 1 | name: PyPi Deploy 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | 15 | - name: Setup Python 16 | uses: actions/setup-python@v4 17 | with: 18 | python-version: '3.10' 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build twine 24 | 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: __token__ 28 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 29 | run: | 30 | python -m build 31 | twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | test.py 162 | tmp*.py -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 
14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | https://www.linkedin.com/in/iryna-kondrashchenko-673800155/. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. 
This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to agent_dingo 2 | 3 | Welcome! We appreciate your interest in contributing to `agent_dingo`. Whether you're a developer, designer, writer, or simply passionate about open source, there are several ways you can help improve this project. This document will guide you through the process of contributing to our Python repository. 4 | 5 | ## How Can I Contribute? 6 | 7 | There are several ways you can contribute to this project: 8 | 9 | - Bug Fixes: Help us identify and fix issues in the codebase. 10 | - Feature Implementation: Implement new features and enhancements. 11 | - Documentation: Improve the project's documentation, including code comments and README files. 12 | - Testing: Write and improve test cases to ensure the project's quality and reliability. 13 | - Translations: Provide translations for the project's documentation or user interface. 14 | - Bug Reports and Feature Requests: Submit bug reports or suggest new features and improvements. 15 | 16 | **Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix. 17 | 18 | > ### Legal Notice 19 | > 20 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license. 
21 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Oleh Kostromin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

# Agent Dingo

A microframework for building LLM-powered pipelines and agents.
25 | 26 | _Dingo_ is a compact LLM orchestration framework designed for straightforward development of production-ready LLM-powered applications. It combines simplicity with flexibility, allowing for the efficient construction of pipelines and agents, while maintaining a high level of control over the process. 27 | 28 | 29 | ## Support us 🤝 30 | 31 | You can support the project in the following ways: 32 | 33 | - ⭐ Star Dingo on GitHub (click the star button in the top right corner) 34 | - 💡 Provide your feedback or propose ideas in the [issues](https://github.com/BeastByteAI/agent_dingo/issues) section or [Discord](https://discord.gg/YDAbwuWK7V) 35 | - 📰 Post about Dingo on LinkedIn or other platforms 36 | - 🔗 Check out our other projects: Scikit-LLM, Falcon 37 | 38 |
## Quick Start & Documentation 🚀

**Step 1:** Install `agent-dingo`

```bash
pip install agent-dingo
```

**Step 2:** Configure your OpenAI API key

```bash
export OPENAI_API_KEY=
```

**Step 3:** Build your pipeline

Example 1 (Linear Pipeline):

````python
from agent_dingo.llm.openai import OpenAI
from agent_dingo.core.blocks import PromptBuilder
from agent_dingo.core.message import UserMessage
from agent_dingo.core.state import ChatPrompt


# Model
gpt = OpenAI("gpt-3.5-turbo")

# Summary prompt block
summary_pb = PromptBuilder(
    [UserMessage("Summarize the text in 10 words: ```{text}```.")]
)

# Translation prompt block
translation_pb = PromptBuilder(
    [UserMessage("Translate the text into {language}: ```{summarized_text}```.")],
    from_state=["summarized_text"],
)

# Pipeline
pipeline = summary_pb >> gpt >> translation_pb >> gpt

input_text = """
Dingo is an ancient lineage of dog found in Australia. Exhibiting a lean and sturdy physique adapted for speed and endurance, dingoes feature a wedge-shaped skull and come in colorations like light ginger, black and tan, or creamy white. They share a close genetic relationship with the New Guinea singing dog, having diverged early from the domestic dog lineage. Dingoes typically form packs composed of a mated pair and their offspring, a social structure that has persisted throughout their roughly 3,500-year history in Australia.
"""

output = pipeline.run(text=input_text, language="french")
print(output)
````

Example 2 (Agent):

```python
from agent_dingo.agent import Agent
from agent_dingo.llm.openai import OpenAI
import requests

llm = OpenAI(model="gpt-3.5-turbo")
agent = Agent(llm, max_function_calls=3)

@agent.function
def get_temperature(city: str) -> str:
    """Retrieves the current temperature in a city.

    Parameters
    ----------
    city : str
        The city to get the temperature for.

    Returns
    -------
    str
        String representation of the JSON response from the weather API.
    """
    base_url = "https://api.openweathermap.org/data/2.5/weather"
    params = {
        "q": city,
        "appid": "",
        "units": "metric"
    }
    response = requests.get(base_url, params=params)
    data = response.json()
    return str(data)

pipeline = agent.as_pipeline()
```

For a more detailed overview and additional examples, please refer to the **[documentation](https://dingo.beastbyte.ai/)**.
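
Once built, the agent pipeline is invoked like any other pipeline. A minimal sketch of running Example 2 (the query wording is illustrative; `run` returns the parsed output together with the accumulated usage statistics):

```python
from agent_dingo.core.message import UserMessage
from agent_dingo.core.state import ChatPrompt

# the agent expects a ChatPrompt as the initial state
output, usage = pipeline.run(ChatPrompt([UserMessage("What is the temperature in Sydney?")]))
print(output)
```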
143 | -------------------------------------------------------------------------------- /agent_dingo/__init__.py: -------------------------------------------------------------------------------- 1 | # from agent_dingo.agent import AgentDingo 2 | 3 | __version__ = "0.1.0" 4 | __author__ = "Oleh Kostromin, Iryna Kondrashchenko" 5 | -------------------------------------------------------------------------------- /agent_dingo/agent/__init__.py: -------------------------------------------------------------------------------- 1 | from agent_dingo.agent.agent import Agent -------------------------------------------------------------------------------- /agent_dingo/agent/agent.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union, Optional, Tuple, List, Literal 2 | from agent_dingo.agent.parser import parse 3 | from agent_dingo.agent.helpers import get_required_args, construct_json_repr 4 | from agent_dingo.agent.docgen import generate_docstring 5 | from agent_dingo.agent.function_descriptor import FunctionDescriptor 6 | from agent_dingo.core.blocks import BaseLLM, BaseAgent, Context, ChatPrompt, KVData 7 | from agent_dingo.core.message import UserMessage 8 | from agent_dingo.agent.chat_context import ChatContext 9 | from agent_dingo.agent.registry import Registry as _Registry 10 | import json 11 | import os 12 | import inspect 13 | from asyncio import run as asyncio_run, to_thread 14 | import warnings 15 | 16 | 17 | class Agent(BaseAgent): 18 | 19 | def __init__( 20 | self, 21 | llm: BaseLLM, 22 | max_function_calls: int = 10, 23 | before_function_call: Callable = None, 24 | allow_codegen: Union[bool, Literal["env"]] = "env", 25 | name="agent", 26 | description: str = "A helpful agent", 27 | ): 28 | """The agent that can be used to register functions and chat with the LLM. 29 | 30 | Parameters 31 | ---------- 32 | llm : BaseLLM 33 | llm instance to use; must support function calls 34 | max_function_calls : int, optional 35 | max number of consecutive function calls, by default 10 36 | before_function_call : Callable, optional 37 | a callable to execute before function call, by default None 38 | allow_codegen : Union[bool, Literal["env"]], optional 39 | determines whether the function docstrings can be auto-generated by the LLM, by default "env" 40 | name : str, optional 41 | name of the agent, by default "agent" 42 | description : str, optional 43 | description of the agent (needed when used as a sub-agent), by default "A helpful agent" 44 | """ 45 | if not isinstance(allow_codegen, bool) and allow_codegen != "env": 46 | raise ValueError( 47 | "allow_codegen must be a boolean or the string 'env' to use the DINGO_ALLOW_CODEGEN environment variable" 48 | ) 49 | if not llm.supports_function_calls: 50 | raise ValueError( 51 | "Provided LLM does not support function calls and cannot be used with the agent." 52 | ) 53 | self.model = llm 54 | self._allow_codegen = allow_codegen 55 | self._registry = _Registry() 56 | self.max_function_calls = max_function_calls 57 | self.before_function_call = before_function_call 58 | self.name = name 59 | self.description = description 60 | self._registered = False 61 | 62 | def _is_codegen_allowed(self) -> bool: 63 | """Determines whether docstring generation is allowed. 64 | 65 | Returns 66 | ------- 67 | bool 68 | True if docstring generation is allowed, False otherwise. 
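            When ``allow_codegen`` is set to "env", the result is resolved from
            the DINGO_ALLOW_CODEGEN environment variable (unset defaults to
            allowed); otherwise the boolean flag passed to the constructor is
            returned.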
        """
        if self._allow_codegen == "env":
            # environment variables are strings; treat "0"/"false" as disabled
            return os.getenv("DINGO_ALLOW_CODEGEN", "1").lower() not in ("0", "false")
        return self._allow_codegen

    def register_descriptor(self, descriptor: FunctionDescriptor) -> None:
        """Registers a function descriptor with the agent.

        Parameters
        ----------
        descriptor : FunctionDescriptor
            The function descriptor to register.
        """
        if not isinstance(descriptor, FunctionDescriptor):
            raise ValueError("descriptor must be a FunctionDescriptor")
        if descriptor.required_context_keys is not None and self._registered:
            raise ValueError(
                "required_context_keys must be None if functions are registered after the agent"
            )
        self._registry.add(
            name=descriptor.name,
            func=descriptor.func,
            json_repr=descriptor.json_repr,
            requires_context=descriptor.requires_context,
            required_context_keys=descriptor.required_context_keys,
        )

    def register_function(
        self, func: Callable, required_context_keys: Optional[List[str]] = None
    ) -> None:
        """Registers a function with the agent.

        Parameters
        ----------
        func : Callable
            The function to register.
        required_context_keys : Optional[List[str]], optional
            Context keys that the function expects to be present, by default None.

        Raises
        ------
        ValueError
            If the function has no docstring and code generation is not allowed.
        """
        if required_context_keys is not None and self._registered:
            raise ValueError(
                "required_context_keys must be None if functions are registered after the agent"
            )
        if required_context_keys is not None:
            for key in required_context_keys:
                if not isinstance(key, str):
                    raise ValueError("required_context_keys must be a list of strings")
        docstring = func.__doc__
        if docstring is None:
            if not self._is_codegen_allowed():
                raise ValueError(
                    "Function has no docstring and code generation is not allowed"
                )
            docstring = generate_docstring(func, self.model)
        body, requires_context = parse(docstring)
        required_args = get_required_args(func)
        json_repr = construct_json_repr(
            func.__name__, body["description"], body["properties"], required_args
        )
        self._registry.add(
            func.__name__, func, json_repr, requires_context, required_context_keys
        )

    def _call_from_agent(self, query: str, chat_context: ChatContext) -> str:
        """Calls the agent from another agent (i.e. when it is used as a sub-agent).

        Parameters
        ----------
        query : str
            The natural language query.
        chat_context : ChatContext
            The chat context (a tuple of the pipeline context and store).

        Returns
        -------
        str
            The agent's response.
        """
        prompt = ChatPrompt([UserMessage(query)])
        response = self.forward(prompt, chat_context[0], chat_context[1])
        return response["_out_0"]

    async def _async_call_from_agent(
        self, query: str, chat_context: ChatContext
    ) -> str:
        """Calls the agent from another agent (i.e. when it is used as a sub-agent).
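
        This coroutine is the callable that gets registered when the agent is
        exposed to a parent agent via ``as_function_descriptor(as_async=True)``.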

        Parameters
        ----------
        query : str
            The natural language query.
        chat_context : ChatContext
            The chat context (a tuple of the pipeline context and store).

        Returns
        -------
        str
            The agent's response.
        """
        prompt = ChatPrompt([UserMessage(query)])
        response = await self.async_forward(prompt, chat_context[0], chat_context[1])
        return response["_out_0"]

    def as_function_descriptor(self, as_async: bool = False) -> FunctionDescriptor:
        """Exposes the agent as a FunctionDescriptor so that it can be registered
        as a sub-agent of another agent."""
        descriptor = FunctionDescriptor(
            name=self.name,
            func=self._call_from_agent if not as_async else self._async_call_from_agent,
            json_repr={
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The natural language query to send to the agent.",
                        }
                    },
                    "required": ["query"],
                },
            },
            requires_context=True,
            required_context_keys=self.get_required_context_keys(),
        )
        self._registered = True
        return descriptor

    def function(self, *args, **kwargs) -> Callable:
        """Registers a function with the agent and returns it unchanged.

        Can be used either as a bare decorator (``@agent.function``) or with a
        ``required_context_keys`` keyword argument
        (``@agent.function(required_context_keys=[...])``).

        Returns
        -------
        Callable
            The registered function (or the registering decorator, when called
            with keyword arguments).
        """

        def outer(required_context_keys):
            def register_decorator(func):
                self.register_function(
                    func, required_context_keys=required_context_keys
                )
                return func

            return register_decorator

        if len(args) == 1 and callable(args[0]) and not kwargs:
            func = args[0]
            self.register_function(func)
            return func
        else:
            return outer(kwargs.get("required_context_keys", None))

    def get_required_context_keys(self) -> List[str]:
        # this makes it possible to handle the case where a function is registered after the agent
        return self._registry.get_required_context_keys()

    def forward(
        self, state: ChatPrompt, context: Context, store: KVData
    ) -> KVData:
        """Sends the prompt to the LLM and returns the response. Executes
        function (tool) calls if the LLM requests them.

        Parameters
        ----------
        state : ChatPrompt
            The chat prompt to send to the LLM.
        context : Context
            The pipeline context.
        store : KVData
            The pipeline store (provides the usage meter).

        Returns
        -------
        KVData
            A KVData object whose `_out_0` key holds the final response.
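
        Notes
        -----
        Requested tool calls are executed in a loop; once ``max_function_calls``
        is reached, the functions are no longer offered to the LLM, which forces
        a plain-text answer.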
        """
        messages = state.dict
        n_calls = 0
        available_functions = self._registry.get_available_functions()
        chat_context = (context, store)
        while True:
            available_functions_i = (
                available_functions if n_calls < self.max_function_calls else None
            )
            response = self.model.send_message(
                messages,
                functions=available_functions_i,
                usage_meter=store.usage_meter,
            )
            if response.get("tool_calls"):
                messages.append(response)
                for function in response["tool_calls"]:
                    function_name = function["function"]["name"]
                    function_args = json.loads(function["function"]["arguments"])
                    f, requires_context = self._registry.get_function(function_name)
                    if requires_context:
                        function_args["chat_context"] = chat_context
                    if self.before_function_call:
                        f, function_args = self.before_function_call(
                            function_name, f, function_args
                        )
                    try:
                        if inspect.iscoroutinefunction(f):
                            warnings.warn("Async function is called from a sync agent.")
                            result = asyncio_run(f(**function_args))
                        else:
                            result = f(**function_args)
                    except Exception as e:
                        print(e)
                        result = "An error occurred while executing the function."
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": function["id"],
                            "content": result,
                        }
                    )
                n_calls += 1
            else:
                messages.append(response)
                return KVData(_out_0=response["content"])

    async def async_forward(
        self, state: ChatPrompt, context: Context, store: KVData
    ) -> KVData:
        """Asynchronous version of ``forward``: sends the prompt to the LLM and
        executes function (tool) calls if the LLM requests them.

        Parameters
        ----------
        state : ChatPrompt
            The chat prompt to send to the LLM.
        context : Context
            The pipeline context.
        store : KVData
            The pipeline store (provides the usage meter).

        Returns
        -------
        KVData
            A KVData object whose `_out_0` key holds the final response.
        """
        messages = state.dict
        n_calls = 0
        available_functions = self._registry.get_available_functions()
        chat_context = (context, store)
        while True:
            available_functions_i = (
                available_functions if n_calls < self.max_function_calls else None
            )
            response = await self.model.async_send_message(
                messages,
                functions=available_functions_i,
                usage_meter=store.usage_meter,
            )
            if response.get("tool_calls"):
                messages.append(response)
                for function in response["tool_calls"]:
                    function_name = function["function"]["name"]
                    function_args = json.loads(function["function"]["arguments"])
                    f, requires_context = self._registry.get_function(function_name)
                    if requires_context:
                        function_args["chat_context"] = chat_context
                    if self.before_function_call:
                        f, function_args = self.before_function_call(
                            function_name, f, function_args
                        )
                    try:
                        if inspect.iscoroutinefunction(f):
                            result = await f(**function_args)
                        else:
                            warnings.warn(
                                "Sync function is called from an async agent."
                            )
                            result = await to_thread(f, **function_args)
                    except Exception as e:
                        print(e)
                        result = "An error occurred while executing the function."
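                    # record the tool output so the model can see it on the next iteration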
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": function["id"],
                            "content": result,
                        }
                    )
                n_calls += 1
            else:
                messages.append(response)
                return KVData(_out_0=response["content"])

-------------------------------------------------------------------------------- /agent_dingo/agent/chat_context.py: --------------------------------------------------------------------------------
from agent_dingo.core.state import Context, KVData
from typing import Tuple

# mainly for legacy reasons
ChatContext = Tuple[Context, KVData]

-------------------------------------------------------------------------------- /agent_dingo/agent/docgen.py: --------------------------------------------------------------------------------
from typing import Callable
import inspect
from agent_dingo.core.blocks import BaseLLM
import re

_SYSTEM_MSG = "You are a code generation tool. Your responses are limited to providing the docstrings of functions."

_PROMPT = """
You will be provided with a Python function. Your task is to generate a docstring in Google style for that Python function and return only the docstring delimited by triple backticks. Do not return the function itself.

Python function:
```{code}```

Docstring:
"""


def generate_docstring(func: Callable, model: BaseLLM) -> str:
    """Generates a docstring for a given function.

    Parameters
    ----------
    func : Callable
        The function to generate a docstring for.
    model : BaseLLM
        The model to use for generating the docstring.

    Returns
    -------
    str
        The generated docstring.
    """
    code = inspect.getsource(func)
    messages = [
        {"role": "system", "content": _SYSTEM_MSG},
        {"role": "user", "content": _PROMPT.format(code=code)},
    ]
    response = model.send_message(messages, temperature=0.0)

    response = (
        response["content"]
        .replace("```python\n", "")
        .replace("```", "")
        .replace('"""', "")
        .replace("'''", "")
    )

    return extract_substr(response)


def extract_substr(input_string: str) -> str:
    """Extracts the description and args from a docstring.

    Parameters
    ----------
    input_string : str
        The docstring to extract the description and args from.

    Returns
    -------
    str
        Reduced docstring containing only the description and the args.
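        (Everything from the ``Returns:`` section onward is dropped.)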
    """

    # Find the 'Returns:' string and capture everything before it
    match = re.search(r"(.*?)Returns:", input_string, re.DOTALL)

    if match:
        # Extract the portion before 'Returns:' and remove leading/trailing whitespace
        before_returns = match.group(1).strip()

        # Remove everything after 'Returns:' including the next line
        result = re.sub(r"Returns:.*?(\n|$)", "", before_returns, flags=re.DOTALL)
    else:
        result = input_string  # 'Returns:' string not found

    return result

-------------------------------------------------------------------------------- /agent_dingo/agent/function_descriptor.py: --------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Callable, Optional, List


@dataclass
class FunctionDescriptor:
    name: str
    func: Callable
    json_repr: dict
    requires_context: bool
    required_context_keys: Optional[List[str]] = None

-------------------------------------------------------------------------------- /agent_dingo/agent/helpers.py: --------------------------------------------------------------------------------
from agent_dingo.core.state import Context
from typing import Callable, List
import inspect
from agent_dingo.agent.chat_context import ChatContext


def construct_json_repr(
    name: str, description: str, properties: dict, required: List[str]
) -> dict:
    """Constructs a JSON representation of a function.

    Parameters
    ----------
    name : str
        The name of the function.
    description : str
        The description of the function.
    properties : dict
        The properties of the function (arguments, their descriptions and types).
    required : List[str]
        The required arguments of the function.

    Returns
    -------
    dict
        The JSON representation of the function.
    """
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            # per the OpenAI function-calling schema, "required" belongs inside "parameters"
            "required": required,
        },
    }


def get_required_args(func: Callable) -> List[str]:
    """Returns a list of the required arguments of a function.

    Parameters
    ----------
    func : Callable
        The function.

    Returns
    -------
    List[str]
        A list of the required arguments of the function.
    """
    sig = inspect.signature(func)
    params = sig.parameters
    required_args = [
        name
        for name, param in params.items()
        if param.default == inspect.Parameter.empty
        and not (name == "chat_context" and param.annotation == ChatContext)
    ]
    return required_args

-------------------------------------------------------------------------------- /agent_dingo/agent/langchain.py: --------------------------------------------------------------------------------
from agent_dingo.agent.function_descriptor import FunctionDescriptor
from langchain.tools import (
    BaseTool as _BaseLangchainTool,
    format_tool_to_openai_function as _format_tool_to_openai_function,
)


def convert_langchain_tool(
    tool: _BaseLangchainTool, make_async: bool = False
) -> FunctionDescriptor:
    """Converts a langchain tool to a function descriptor.

    Parameters
    ----------
    tool : _BaseLangchainTool
        The langchain tool.
    make_async : bool, optional
        If True, the wrapped function awaits ``tool.arun`` instead of calling
        ``tool.run``, by default False.

    Returns
    -------
    FunctionDescriptor
        The function descriptor.
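
    Examples
    --------
    A hypothetical sketch (assumes the optional DuckDuckGo search tool is
    installed and an ``agent`` instance exists; the tool choice is
    illustrative):

    >>> from langchain.tools import DuckDuckGoSearchRun
    >>> descriptor = convert_langchain_tool(DuckDuckGoSearchRun())
    >>> agent.register_descriptor(descriptor)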
    """
    if not isinstance(tool, _BaseLangchainTool):
        raise ValueError("tool must be a subclass of langchain.tools.BaseTool")

    json_repr = _format_tool_to_openai_function(tool=tool)
    name = json_repr["name"]
    if make_async:

        async def func(__arg1):
            return await tool.arun(tool_input=__arg1)

    else:

        def func(__arg1):
            return tool.run(tool_input=__arg1)

    requires_context = False
    descriptor = FunctionDescriptor(
        name=name, func=func, json_repr=json_repr, requires_context=requires_context
    )
    return descriptor

-------------------------------------------------------------------------------- /agent_dingo/agent/parser.py: --------------------------------------------------------------------------------
from docstring_parser import parse as _parse
from typing import Tuple
import ast


_types = {
    "str": "string",
    "int": "integer",
    "float": "number",
    "bool": "boolean",
    "list": "array",
    "dict": "object",
}


def parse(docstring: str) -> Tuple[dict, bool]:
    """Parses a docstring.

    Parameters
    ----------
    docstring : str
        The docstring to parse.

    Returns
    -------
    Tuple[dict, bool]
        A dictionary containing the description and the arguments of the
        function, and a boolean indicating whether the function requires a
        ChatContext.

    Raises
    ------
    ValueError
        If the docstring has no description.
    """
    parsed = _parse(docstring)
    description = ""
    if parsed.short_description:
        description = parsed.short_description
    if parsed.long_description:
        if description != "":
            description += "\n" + parsed.long_description
        else:
            description = parsed.long_description
    if description == "":
        raise ValueError("Docstring has no description")
    args = {}
    requires_context = False
    for arg in parsed.params:
        if arg.arg_name == "chat_context" and arg.type_name == "ChatContext":
            requires_context = True
            continue
        d = {}
        if "Enum:" in arg.description:
            arg_description = arg.description.split("Enum:")[0].strip()
            enum = arg.description.split("Enum:")[1].strip()
            try:
                enum = ast.literal_eval(enum)
                d = {"description": arg_description, "enum": enum}
            except Exception:
                d = {"description": arg_description}
        else:
            d = {"description": arg.description}
        if arg.type_name in _types:
            args[arg.arg_name] = {"type": _types[arg.type_name]}
        else:
            args[arg.arg_name] = {"type": "string"}
        args[arg.arg_name].update(d)
    return {"description": description, "properties": args}, requires_context

-------------------------------------------------------------------------------- /agent_dingo/agent/registry.py: --------------------------------------------------------------------------------
from typing import Callable, List, Optional, Tuple


class Registry:
    """A registry for functions that can be called by the agent."""

    def __init__(self):
        self.__functions = {}
        self._required_context_keys = []

    def add(
        self,
        name: str,
        func: Callable,
        json_repr: dict,
        requires_context: bool,
        required_context_keys: Optional[List[str]] = None,
    ) -> None:
        """Adds a function to the registry.

        Parameters
        ----------
        name : str
            The name of the function.
        func : Callable
            The function.
        json_repr : dict
            The JSON representation of the function to be provided to the LLM.
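        required_context_keys : Optional[List[str]], optional
            Context keys that the function expects to be present; must be
            provided when requires_context is True, by default None.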
29 | requires_context : bool 30 | Indicates whether the function requires a ChatContext object as one of its arguments. 31 | """ 32 | if requires_context and required_context_keys is None: 33 | raise ValueError( 34 | "If requires_context is True, required_context_keys must be provided" 35 | ) 36 | self.__functions[name] = { 37 | "func": func, 38 | "json_repr": json_repr, 39 | "requires_context": requires_context, 40 | "required_context_keys": required_context_keys or [], 41 | } 42 | 43 | def get_function(self, name: str) -> Tuple[Callable, bool]: 44 | """Retrieves a function from the registry. 45 | 46 | Parameters 47 | ---------- 48 | name : str 49 | The name of the function. 50 | 51 | Returns 52 | ------- 53 | Tuple[Callable, bool] 54 | A tuple containing the function and a boolean indicating whether the function requires a ChatContext object as one of its arguments. 55 | """ 56 | try: 57 | return ( 58 | self.__functions[name]["func"], 59 | self.__functions[name]["requires_context"], 60 | ) 61 | except KeyError: 62 | return ( 63 | ( 64 | lambda *args, **kwargs: f"Error: function `{name}` is not available. Most likely, the name is incorrect." 65 | ), 66 | False, 67 | ) 68 | 69 | def get_available_functions(self) -> List[dict]: 70 | """Returns a list of JSON representations of the functions in the registry. 71 | 72 | Returns 73 | ------- 74 | List[dict] 75 | A list of JSON representations of the functions in the registry. 76 | """ 77 | return [self.__functions[name]["json_repr"] for name in self.__functions] 78 | 79 | def get_required_context_keys(self) -> List[str]: 80 | """Returns a list of keys that are required in the ChatContext object. 81 | 82 | Returns 83 | ------- 84 | List[str] 85 | A list of keys that are required in the ChatContext object. 86 | """ 87 | keys = [] 88 | for f in self.__functions.values(): 89 | keys.extend(f["required_context_keys"]) 90 | return keys 91 | -------------------------------------------------------------------------------- /agent_dingo/core/blocks.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Any, Coroutine, Optional, Union, List, Dict 3 | from abc import ABC, abstractmethod 4 | from agent_dingo.core.message import Message 5 | from agent_dingo.core.state import State, ChatPrompt, KVData, Context, Store, UsageMeter 6 | from agent_dingo.core.output_parser import BaseOutputParser, DefaultOutputParser 7 | import re 8 | import joblib 9 | import inspect 10 | import warnings 11 | 12 | 13 | import os 14 | 15 | if os.environ.get("DINGO_ALLOW_NESTED_ASYNCIO", False): 16 | import nest_asyncio 17 | 18 | nest_asyncio.apply() 19 | 20 | from asyncio import ( 21 | to_thread, 22 | gather, 23 | run as asyncio_run, 24 | ) 25 | 26 | 27 | class Block(ABC): 28 | """Base building block of a pipeline""" 29 | 30 | @abstractmethod 31 | def forward(self, state: Optional[State], context: Context, store: Store) -> State: 32 | pass 33 | 34 | async def async_forward( 35 | self, state: Optional[State], context: Context, store: Store 36 | ) -> State: 37 | # ideally, this is never called as all children should implement async forward 38 | warnings.warn( 39 | f"Called async_forward, but {self.__class__} block does not have an async implementation." 
40 | ) 41 | return await to_thread(self.forward, state, context, store) 42 | 43 | @abstractmethod 44 | def get_required_context_keys(self) -> List[str]: 45 | """Each block must specify the keys it requires from the context.""" 46 | pass 47 | 48 | def __rshift__(self, other: Block) -> Pipeline: 49 | return Pipeline() >> self >> other 50 | 51 | def __lshift__(self, other: Block) -> Pipeline: 52 | if isinstance(other, Pipeline): 53 | other.add_block(self) 54 | return other 55 | return Pipeline() >> other >> self 56 | 57 | def as_pipeline(self) -> Pipeline: 58 | return Pipeline() >> self 59 | 60 | def __and__(self, other: Block) -> Parallel: 61 | return Parallel() & self & other 62 | 63 | 64 | ######### REASONERS ######### 65 | 66 | 67 | class BaseReasoner(Block): 68 | """A reasoner is a block that takes a prompt and returns a KVData object.""" 69 | 70 | @abstractmethod 71 | def forward(self, state: ChatPrompt, context: Context, store: Store) -> KVData: 72 | pass 73 | 74 | 75 | class BaseLLM(BaseReasoner): 76 | """LLM is a type of reasoner that directly interacts with a language model.""" 77 | 78 | supports_function_calls = False 79 | 80 | @abstractmethod 81 | def send_message( 82 | self, messages, functions=None, usage_meter: UsageMeter = None, **kwargs 83 | ): 84 | pass 85 | 86 | @abstractmethod 87 | async def async_send_message( 88 | self, messages, functions=None, usage_meter: UsageMeter = None, **kwargs 89 | ): 90 | pass 91 | 92 | def forward(self, state: ChatPrompt, context: Context, store: Store) -> KVData: 93 | if not isinstance(state, ChatPrompt): 94 | raise TypeError(f"State must be a ChatPrompt, got {type(state)}") 95 | new_state = KVData( 96 | _out_0=self.process_prompt(state, usage_meter=store.usage_meter) 97 | ) 98 | return new_state 99 | 100 | async def async_forward(self, state: State | None, context: Context, store: Store): 101 | if not isinstance(state, ChatPrompt): 102 | raise TypeError(f"State must be a ChatPrompt, got {type(state)}") 103 | new_state = KVData( 104 | _out_0=await self.async_process_prompt(state, usage_meter=store.usage_meter) 105 | ) 106 | return new_state 107 | 108 | def __call__(self, state: ChatPrompt) -> str: 109 | return self.process_prompt(state) 110 | 111 | def get_required_context_keys(self) -> List[str]: 112 | return [] 113 | 114 | def process_prompt( 115 | self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs 116 | ): 117 | return self.send_message(prompt.dict, None, usage_meter)["content"] 118 | 119 | async def async_process_prompt( 120 | self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs 121 | ): 122 | return (await self.async_send_message(prompt.dict, None, usage_meter))[ 123 | "content" 124 | ] 125 | 126 | 127 | class BaseAgent(BaseReasoner): 128 | """An agent is a type of reasoner that can autonomously perform multi-step reasoning.""" 129 | 130 | pass 131 | 132 | 133 | ######### KVData Processors ######### 134 | 135 | 136 | class BaseKVDataProcessor(Block): 137 | """KVDataProcessor is a block that takes a KVData object and returns a KVData object.""" 138 | 139 | @abstractmethod 140 | def forward(self, state: KVData, context: Context, store: Store) -> KVData: 141 | pass 142 | 143 | 144 | class Squash(BaseKVDataProcessor): 145 | def __init__(self, template: str): 146 | """Squash block takes a KVData with multiple keys and squashes them into a single key using a template string. 
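
        For example, ``Squash("{0} / {1}")`` merges the outputs of two parallel
        branches into a single string.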

        Parameters
        ----------
        template : str
            Template string with (in-order) placeholders for each key in the KVData object.
        """
        self.template = template

    def forward(self, state: KVData, context: Context, store: Store) -> KVData:
        return KVData(_out_0=self.template.format(*state.values()))

    async def async_forward(self, state: KVData, context: Context, store: Store):
        return self.forward(state, context, store)

    def get_required_context_keys(self) -> List[str]:
        return []


######### Prompt Builders #########


class BasePromptBuilder(Block):
    """PromptBuilder is a block that takes a KVData object (or None) and returns a ChatPrompt."""

    @abstractmethod
    def forward(
        self, state: Optional[KVData], context: Context, store: Store
    ) -> ChatPrompt:
        pass


class PromptBuilder(BasePromptBuilder):
    def __init__(
        self,
        messages: list[Message],
        from_state: Optional[Union[List[str], Dict[str, str]]] = None,
        from_store: Optional[Union[List[str], Dict[str, str]]] = None,
    ):
        """
        PromptBuilder formats the list of messages (templates) with values from the state, store and context.

        Parameters
        ----------
        messages : list[Message]
            List of message templates to format.
        from_state : Optional[Union[List[str], Dict[str, str]]], optional
            A list of placeholder names, or a mapping of placeholder name -> state key,
            that defines which placeholders are populated from state values, by default None
        from_store : Optional[Union[List[str], Dict[str, str]]], optional
            A list of placeholder names, or a mapping of placeholder name -> store key
            (formatted as <name>.<key>), that defines which placeholders are populated
            from store values, by default None
        """
        self.messages = messages
        self._from_state_keys = []
        self._from_store_keys = []
        if from_state is None:
            self._from_state = {}
        elif isinstance(from_state, list):
            self._from_state = {}
            self._from_state_keys.extend(from_state)
            for i, k in enumerate(from_state):
                self._from_state[k] = f"_out_{i}"
        elif isinstance(from_state, dict):
            self._from_state = from_state
            self._from_state_keys.extend(from_state.keys())
        else:
            raise TypeError(
                f"from_state must be a list or dict, got {type(from_state)}"
            )
        if from_store is None:
            self._from_store = {}
        elif isinstance(from_store, list):
            self._from_store = {}
            self._from_store_keys.extend(from_store)
            for i, k in enumerate(from_store):
                self._from_store[k] = f"_out_{i}"
        elif isinstance(from_store, dict):
            self._from_store = from_store
            self._from_store_keys.extend(from_store.keys())
        else:
            raise TypeError(
                f"from_store must be a list or dict, got {type(from_store)}"
            )

        self._placeholder_names = self._get_placeholder_names()

    def _get_placeholder_names(self) -> set:
        placeholder_pattern = r"\{(\w+)\}"
        placeholder_names = set()

        for message in self.messages:
            found_placeholders = re.findall(placeholder_pattern, message.content)
            placeholder_names.update(found_placeholders)

        return placeholder_names

    def forward(
        self, state: Optional[KVData], context: Context, store: Store
    ) -> ChatPrompt:
        values = {}
        for n in self._placeholder_names:
            if n in self._from_state.keys():
                values[n] = state[self._from_state[n]]
            elif n in self._from_store.keys():
                if "." in self._from_store[n]:
                    outer, inner = self._from_store[n].split(".")
                    values[n] = store.get_data(outer)[inner]
                else:
                    raise ValueError(
                        "Store key must be formatted as <name>.<key>"
                    )
            elif n in context.keys():
                values[n] = context[n]
            else:
                raise KeyError(f"Could not find value for placeholder {n}")
        updated_messages = [type(m)(m.content.format(**values)) for m in self.messages]
        return ChatPrompt(updated_messages)

    async def async_forward(self, state: KVData, context: Context, store: Store):
        return self.forward(state, context, store)

    def get_required_context_keys(self) -> List[str]:
        keys = []
        for n in self._placeholder_names:
            if n not in self._from_state_keys and n not in self._from_store_keys:
                keys.append(n)
        return keys


######### Prompt Modifiers #########


class BasePromptModifier(Block):
    """A prompt modifier is a block that takes a ChatPrompt and returns a ChatPrompt."""

    @abstractmethod
    def forward(self, state: ChatPrompt, context: Context, store: Store) -> ChatPrompt:
        pass


######### Special Blocks #########


class Pipeline(Block):
    def __init__(self, output_parser: Optional[BaseOutputParser] = None):
        """
        A pipeline is a sequence of blocks that are executed in order.
        The pipeline itself is a block that can be used in other pipelines.

        Parameters
        ----------
        output_parser : Optional[BaseOutputParser], optional
            custom output parser of the last step, by default None
        """
        self.output_parser: BaseOutputParser = output_parser or DefaultOutputParser()
        self._blocks = []

    def add_block(self, block: Block):
        """
        Add a block to the pipeline.

        Parameters
        ----------
        block : Block
            Block to add to the pipeline.
        """
        if not isinstance(block, Block):
            raise TypeError(f"Expected a Block, got {type(block)}")
        self._blocks.append(block)

    def forward(self, state: Optional[State], context: Context, store: Store) -> State:
        running_state = state
        for block in self._blocks:
            running_state = block.forward(
                state=running_state, context=context, store=store
            )
        return running_state

    async def async_forward(
        self, state: Optional[State], context: Context, store: Store
    ) -> State:
        running_state = state
        for block in self._blocks:
            running_state = await block.async_forward(
                state=running_state, context=context, store=store
            )
        return running_state

    def run(self, _state: Optional[State] = None, **kwargs: Dict[str, str]):
        """
        Runs the pipeline with the given state and context (populated with kwargs).
        Each run initializes a new empty store.
        The output of the last block is parsed using the output_parser and returned.
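        Returns a tuple of the parsed output and the usage statistics
        accumulated by the store's usage meter.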
338 | 
339 |         Parameters
340 |         ----------
341 |         _state : Optional[State], optional
342 |             initial state, by default None
343 |         """
344 |         context = Context(**kwargs)
345 |         store = Store()
346 |         out = self.forward(state=_state, context=context, store=store)
347 |         return self.output_parser.parse(out), store.usage_meter.get_usage()
348 | 
349 |     async def async_run(
350 |         self, _state: Optional[State] = None, **kwargs: Dict[str, str]
351 |     ):
352 |         context = Context(**kwargs)
353 |         store = Store()
354 |         out = await self.async_forward(state=_state, context=context, store=store)
355 |         return self.output_parser.parse(out), store.usage_meter.get_usage()
356 | 
357 |     def __rshift__(self, other: Block) -> Pipeline:
358 |         self.add_block(other)
359 |         return self
360 | 
361 |     def get_required_context_keys(self) -> List[str]:
362 |         keys = []
363 |         for block in self._blocks:
364 |             keys.extend(block.get_required_context_keys())
365 |         return keys
366 | 
367 | 
368 | class Parallel(Block):
369 |     def __init__(self):
370 |         """
371 |         A parallel block executes multiple sub-blocks in parallel. The output of each block is stored as a separate key in the KVData object.
372 |         """
373 |         self.blocks = []
374 | 
375 |     def add_block(self, block: Block):
376 |         """
377 |         Add a block to the parallel block.
378 | 
379 |         Parameters
380 |         ----------
381 |         block : Block
382 |             Block to add.
383 |         """
384 |         if not isinstance(block, Block):
385 |             raise TypeError(f"Expected a Block, got {type(block)}")
386 |         self.blocks.append(block)
387 | 
388 |     def forward(self, state: Optional[State], context: Context, store: Store) -> State:
389 |         # run all blocks in parallel
390 |         states = joblib.Parallel(n_jobs=len(self.blocks), backend="threading")(
391 |             joblib.delayed(block.forward)(state=state, context=context, store=store)
392 |             for block in self.blocks
393 |         )
394 |         out = {}
395 |         for i, state in enumerate(states):
396 |             if i == 0 and isinstance(state, ChatPrompt):
397 |                 # allow a special case where the first block returns a ChatPrompt
398 |                 # the output of the remaining branches will be ignored
399 |                 return state
400 |             if not isinstance(state, KVData):
401 |                 raise TypeError(
402 |                     f"Expected KVData, got {type(state)} from block {i} of {len(states)}"
403 |                 )
404 |             if len(state.keys()) != 1:
405 |                 raise ValueError(
406 |                     f"Expected KVData with one key `_out_0`, got {len(state.keys())} keys from block {i} of {len(states)}"
407 |                 )
408 |             out[f"_out_{i}"] = state["_out_0"]
409 |         return KVData(**out)
410 | 
411 |     async def async_forward(
412 |         self, state: Optional[State], context: Context, store: Store
413 |     ) -> State:
414 |         tasks = [block.async_forward(state, context, store) for block in self.blocks]
415 |         states = await gather(*tasks)
416 | 
417 |         out = {}
418 |         for i, state in enumerate(states):
419 |             if i == 0 and isinstance(state, ChatPrompt):
420 |                 return state
421 |             if not isinstance(state, KVData):
422 |                 raise TypeError(
423 |                     f"Expected KVData, got {type(state)} from block {i} of {len(states)}"
424 |                 )
425 |             if len(state.keys()) != 1:
426 |                 raise ValueError(
427 |                     f"Expected KVData with one key `_out_0`, got {len(state.keys())} keys from block {i} of {len(states)}"
428 |                 )
429 |             out[f"_out_{i}"] = state["_out_0"]
430 |         return KVData(**out)
431 | 
432 |     def __and__(self, other: Block) -> Parallel:
433 |         self.add_block(other)
434 |         return self
435 | 
436 |     def get_required_context_keys(self) -> List[str]:
437 |         keys = []
438 |         for block in self.blocks:
439 |             keys.extend(block.get_required_context_keys())
440 |         return keys
441 | 
442 | class Identity(Block):
443 | """NO-OP block that returns the input state as is.""" 444 | 445 | def forward(self, state: Optional[State], context: Context, store: Store) -> State: 446 | return state 447 | 448 | async def async_forward( 449 | self, state: Optional[State], context: Context, store: Store 450 | ) -> State: 451 | return state 452 | 453 | def get_required_context_keys(self) -> List[str]: 454 | return [] 455 | 456 | 457 | class SaveState(Block): 458 | def __init__(self, key: str): 459 | """Saves the current state to the store. 460 | 461 | Parameters 462 | ---------- 463 | key : str 464 | Key to save the state under. 465 | """ 466 | self.key = key 467 | 468 | def get_required_context_keys(self) -> List[str]: 469 | return [] 470 | 471 | def forward(self, state: State | None, context: Context, store: Store) -> State: 472 | store.update(self.key, state) 473 | return state 474 | 475 | async def async_forward( 476 | self, state: State | None, context: Context, store: Store 477 | ) -> State: 478 | return self.forward(state, context, store) 479 | 480 | 481 | class LoadState(Block): 482 | def __init__(self, from_: str, key: str): 483 | """ 484 | Loads the state from the store. 485 | 486 | Parameters 487 | ---------- 488 | from_ : str 489 | Defines whether to load from a Prompt of KVData section of the store. 490 | key : str 491 | Key to load the state from. 492 | """ 493 | if from_ not in ["prompts", "data"]: 494 | raise ValueError(f"from_ must be 'store' or 'context', got {from_}") 495 | self.from_ = from_ 496 | self.key = key 497 | 498 | def get_required_context_keys(self) -> List[str]: 499 | return [] 500 | 501 | def forward(self, state: State | None, context: Context, store: Store) -> State: 502 | if self.from_ == "prompts": 503 | return store.get_prompt(self.key) 504 | elif self.from_ == "data": 505 | return store.get_data(self.key) 506 | else: 507 | raise ValueError(f"from_ must be 'store' or 'context', got {self.from_}") 508 | 509 | async def async_forward( 510 | self, state: State | None, context: Context, store: Store 511 | ) -> State: 512 | return self.forward(state, context, store) 513 | 514 | 515 | class InlineBlock(Block): 516 | def __init__(self, required_context_keys: Optional[List[str]] = None): 517 | """A decorator to convert a function into an inline block. 
518 | 519 | Parameters 520 | ---------- 521 | required_context_keys : Optional[List[str]], optional 522 | specifies the context keys required by the function, by default None 523 | """ 524 | self.required_context_keys = required_context_keys or [] 525 | self.func = None 526 | 527 | def get_required_context_keys(self) -> List[str]: 528 | return self.required_context_keys 529 | 530 | def __call__(self, func): 531 | self.func = func 532 | return self 533 | 534 | def _get_output(self, out) -> State: 535 | if isinstance(out, State.__args__): 536 | return out 537 | elif isinstance(out, dict): 538 | return KVData(**out) 539 | elif isinstance(out, str): 540 | return KVData(_out_0=out) 541 | elif isinstance(out, (list, tuple)): 542 | return KVData(**{f"_out_{i}": v for i, v in enumerate(out)}) 543 | raise TypeError(f"Expected a State, dict, str, or list, got {type(out)}") 544 | 545 | def forward(self, state: State | None, context: Context, store: Store) -> State: 546 | if inspect.iscoroutinefunction(self.func): 547 | warnings.warn(f"Called forward on an async inline block.") 548 | out = asyncio_run(self.func(state, context, store)) 549 | else: 550 | out = self.func(state, context, store) 551 | return self._get_output(out) 552 | 553 | async def async_forward( 554 | self, state: State | None, context: Context, store: Store 555 | ) -> State: 556 | if inspect.iscoroutinefunction(self.func): 557 | out = await self.func(state, context, store) 558 | else: 559 | warnings.warn(f"Called async_forward on a non-async inline block.") 560 | out = await to_thread(self.func, state, context, store) 561 | return self._get_output(out) 562 | -------------------------------------------------------------------------------- /agent_dingo/core/message.py: -------------------------------------------------------------------------------- 1 | class Message: 2 | """A base class to represent a message.""" 3 | 4 | role: str = "undefined" 5 | 6 | def __init__(self, content: str): 7 | self.content = content 8 | 9 | def __repr__(self): 10 | return f'Message(role="{self.role}" content="{self.content}")' 11 | 12 | @property 13 | def dict(self): 14 | return {"role": self.role, "content": self.content} 15 | 16 | 17 | class UserMessage(Message): 18 | role = "user" 19 | pass 20 | 21 | 22 | class SystemMessage(Message): 23 | role = "system" 24 | pass 25 | 26 | 27 | class AssistantMessage(Message): 28 | role = "assistant" 29 | pass 30 | -------------------------------------------------------------------------------- /agent_dingo/core/output_parser.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from agent_dingo.core.state import State, ChatPrompt, KVData 3 | 4 | 5 | class BaseOutputParser(ABC): 6 | """Base class for output parsers.""" 7 | 8 | def parse(self, output: State) -> str: 9 | """Delegates the parsing to the appropriate method based on the type of output. 10 | 11 | Parameters 12 | ---------- 13 | output : State 14 | The state object to parse. 
15 | 
16 |         Returns
17 |         -------
18 |         str
19 |             parsed output
20 |         """
21 |         if not isinstance(output, State.__args__):
22 |             raise TypeError(f"Expected ChatPrompt or KVData, got {type(output)}")
23 |         elif isinstance(output, KVData):
24 |             return self._parse_kvdata(output)
25 |         else:
26 |             return self._parse_chat(output)
27 | 
28 |     @abstractmethod
29 |     def _parse_chat(self, output: ChatPrompt) -> str:
30 |         pass
31 | 
32 |     @abstractmethod
33 |     def _parse_kvdata(self, output: KVData) -> str:
34 |         pass
35 | 
36 | 
37 | class DefaultOutputParser(BaseOutputParser):
38 |     """Default output parser. Expected output is a KVData with the key `_out_0`."""
39 | 
40 |     def _parse_chat(self, output: ChatPrompt) -> str:
41 |         raise RuntimeError(
42 |             "Cannot parse chat output. The output must be an instance of KVData."
43 |         )
44 | 
45 |     def _parse_kvdata(self, output: KVData) -> str:
46 |         if "_out_0" in output.keys():
47 |             return output["_out_0"]
48 |         else:
49 |             raise KeyError("Could not find output in KVData.")
50 | 
--------------------------------------------------------------------------------
/agent_dingo/core/state.py:
--------------------------------------------------------------------------------
 1 | from typing import Union, List, Any
 2 | from agent_dingo.core.message import Message
 3 | from threading import Lock
 4 | 
 5 | 
 6 | class ChatPrompt:
 7 |     def __init__(self, messages: List[Message]):
 8 |         """A collection of messages that are sent to the model.
 9 | 
10 |         Parameters
11 |         ----------
12 |         messages : List[Message]
13 |             The messages to send to the model.
14 |         """
15 |         self.messages = messages
16 | 
17 |     @property
18 |     def dict(self):
19 |         return [m.dict for m in self.messages]
20 | 
21 |     def __repr__(self):
22 |         return f"ChatPrompt({self.messages})"
23 | 
24 | 
25 | class KVData:
26 |     def __init__(self, **kwargs):
27 |         """A dictionary-like object that stores key-value pairs."""
28 |         self._dict = {}
29 |         for k, v in kwargs.items():
30 |             self._dict[k] = v
31 | 
32 |     def update(self, key: str, value: str):
33 |         """Add a new key-value pair. Existing keys are immutable and cannot be overwritten.
34 | 
35 |         Parameters
36 |         ----------
37 |         key : str
38 |             key to add
39 |         value : str
40 |             value to set
41 |         """
42 |         if not isinstance(key, str) or not isinstance(value, str):
43 |             raise TypeError("Both key and value must be strings.")
44 |         # make existing keys immutable
45 |         if key in self._dict:
46 |             raise KeyError(f"Key {key} already exists.")
47 |         self._dict[key] = value
48 | 
49 |     def __getitem__(self, key):
50 |         return self._dict[key]
51 | 
52 |     def __repr__(self):
53 |         return "KVData({0})".format(str(self._dict))
54 | 
55 |     def __dict__(self):
56 |         return self._dict
57 | 
58 |     def keys(self):
59 |         return self._dict.keys()
60 | 
61 |     def values(self):
62 |         return self._dict.values()
63 | 
64 |     @property
65 |     def dict(self):
66 |         return self._dict.copy()
67 | 
68 | 
69 | State = Union[ChatPrompt, KVData]
70 | 
71 | 
72 | class Context(KVData):
73 |     def update(self, key, value):
74 |         raise RuntimeError("Context is immutable.")
75 | 
76 | 
77 | class UsageMeter:
78 |     def __init__(self):
79 |         """An object that resides in the store and keeps track of the usage.
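        Prompt and completion token counts are aggregated across all LLM calls of a single pipeline run.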
It is updated by the LLMs.""" 80 | self.prompt_tokens = 0 81 | self.completion_tokens = 0 82 | self.last_finish_reason = None 83 | self._lock = Lock() 84 | 85 | def increment(self, prompt_tokens: int, completion_tokens: int) -> None: 86 | with self._lock: 87 | self.prompt_tokens += prompt_tokens 88 | self.completion_tokens += completion_tokens 89 | 90 | def get_usage(self) -> dict: 91 | return { 92 | "prompt_tokens": self.prompt_tokens, 93 | "completion_tokens": self.completion_tokens, 94 | "total_tokens": self.prompt_tokens + self.completion_tokens, 95 | } 96 | 97 | 98 | class Store: 99 | def __init__(self): 100 | """A simple key-value store that stores prompts, data, and other miscellaneous objects for the duration of a single pipeline run.""" 101 | self._data = {} 102 | self._prompts = {} 103 | self._misc = {} 104 | self.usage_meter = UsageMeter() 105 | self._lock = Lock() # probably not really needed 106 | 107 | def _update(self, key: str, item): 108 | if not isinstance(key, str): 109 | raise TypeError("Key must be a string.") 110 | if isinstance(item, ChatPrompt): 111 | self._prompts[key] = item 112 | elif isinstance(item, KVData): 113 | self._data[key] = item 114 | else: 115 | self._misc[key] = item 116 | 117 | def update(self, key: str, item: Any): 118 | """Update the store with a new item. 119 | 120 | Parameters 121 | ---------- 122 | key : str 123 | A key to store the item under. 124 | item : Any 125 | The item to store. 126 | """ 127 | with self._lock: 128 | self._update(key, item) 129 | 130 | def get_misc(self, key: str): 131 | with self._lock: 132 | return self._misc[key] 133 | 134 | def get_data(self, key: str) -> KVData: 135 | with self._lock: 136 | return self._data[key] 137 | 138 | def get_prompt(self, key: str) -> ChatPrompt: 139 | with self._lock: 140 | return self._prompts[key] 141 | -------------------------------------------------------------------------------- /agent_dingo/llm/gemini.py: -------------------------------------------------------------------------------- 1 | try: 2 | from vertexai import init as _vertex_init 3 | from vertexai.generative_models import ( 4 | Content, 5 | FunctionDeclaration, 6 | GenerativeModel, 7 | Part, 8 | Tool, 9 | ) 10 | except ImportError: 11 | raise ImportError( 12 | "VertexAI is not installed. Please install it using `pip install agent-dingo[vertexai]`" 13 | ) 14 | from typing import Optional, List 15 | from agent_dingo.core.blocks import BaseLLM 16 | from agent_dingo.core.state import UsageMeter 17 | import json 18 | 19 | _ROLES_MAP = { 20 | "user": "USER", 21 | "system": "USER", 22 | "assistant": "MODEL", 23 | } 24 | 25 | 26 | class Gemini(BaseLLM): 27 | def __init__( 28 | self, model: str, project: str, location: str, temperature: float = 0.7 29 | ): 30 | """ 31 | VertexAI Gemini LLM. 
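        Requires the google-cloud-aiplatform package, installable via `pip install agent-dingo[vertexai]`.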
32 | 33 | Parameters 34 | ---------- 35 | model : str 36 | model to use 37 | project : str 38 | project id to use 39 | location : str 40 | location to use 41 | temperature : float, optional 42 | generation temperature, by default 0.7 43 | """ 44 | _vertex_init(project=project, location=location) 45 | self._model = GenerativeModel(model) 46 | self.supports_function_calls = True 47 | self.temperature = temperature 48 | 49 | def _get_tools(self, functions: Optional[List]) -> List[Tool]: 50 | if functions is None: 51 | return [] 52 | declarations: List[FunctionDeclaration] = [] 53 | for f in functions: 54 | declaration = FunctionDeclaration( 55 | name=f["name"], 56 | description=f["description"], 57 | parameters=f["parameters"], 58 | ) 59 | declarations.append(declaration) 60 | tool = Tool(function_declarations=declarations) 61 | return [tool] 62 | 63 | def send_message( 64 | self, 65 | messages, 66 | functions=None, 67 | usage_meter: UsageMeter = None, 68 | temperature: Optional[float] = None, 69 | **kwargs, 70 | ): 71 | converted = self._openai_to_gemini(messages) 72 | out = self._model.generate_content( 73 | contents=converted, 74 | tools=self._get_tools(functions), 75 | generation_config={"temperature": temperature or self.temperature}, 76 | ) 77 | return self._postprocess_response(out, usage_meter) 78 | 79 | async def async_send_message( 80 | self, 81 | messages, 82 | functions=None, 83 | usage_meter: UsageMeter = None, 84 | temperature=None, 85 | **kwargs, 86 | ): 87 | converted = self._openai_to_gemini(messages) 88 | response = await self._model.generate_content_async( 89 | contents=converted, 90 | tools=self._get_tools(functions), 91 | generation_config={"temperature": temperature or self.temperature}, 92 | ) 93 | return self._postprocess_response(response, usage_meter) 94 | 95 | def _postprocess_response(self, response, usage_meter: UsageMeter = None): 96 | n_prompt_tokens = response._raw_response.usage_metadata.prompt_token_count 97 | n_completion_tokens = ( 98 | response._raw_response.usage_metadata.candidates_token_count 99 | ) 100 | if usage_meter: 101 | usage_meter.increment( 102 | prompt_tokens=n_prompt_tokens, 103 | completion_tokens=n_completion_tokens, 104 | ) 105 | return self._gemini_to_openai(response) 106 | 107 | def _openai_to_gemini(self, messages): 108 | """Converts OpenAI messages to VertexAI messages.""" 109 | 110 | converted = [] 111 | 112 | for message in messages: 113 | if "_cache" in message.keys(): 114 | converted.append(message["_cache"]) 115 | elif message["role"] in _ROLES_MAP.keys(): 116 | content = Content( 117 | role=_ROLES_MAP[message["role"]], 118 | parts=[ 119 | Part.from_text(message["content"]), 120 | ], 121 | ) 122 | converted.append(content) 123 | elif message["role"] == "tool": 124 | content = Content( 125 | role="function", 126 | parts=[ 127 | Part.from_function_response( 128 | name=message["tool_call_id"], 129 | response={ 130 | "content": message["content"], 131 | }, 132 | ) 133 | ], 134 | ) 135 | converted.append(content) 136 | else: 137 | raise ValueError(f"Invalid message {message}") 138 | return converted 139 | 140 | def _gemini_to_openai(self, response): 141 | """Converts the Gemini response to OpenAI response.""" 142 | try: 143 | content = response.candidates[0].content.parts[0].text 144 | except AttributeError: 145 | content = None 146 | function_call = response.candidates[0].function_calls 147 | 148 | transformed_calls = [] 149 | if len(function_call) > 0: 150 | for call in function_call: 151 | id_ = call.name 152 | type_ = "function" 
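                # Vertex AI function calls do not include a separate id, so the function name is reused as the tool call id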
153 |                 name = call.name
154 |                 args = {arg: call.args[arg] for arg in call.args}
155 |                 transformed_call = {
156 |                     "id": id_,
157 |                     "type": type_,
158 |                     "function": {
159 |                         "name": name,
160 |                         "arguments": json.dumps(args),
161 |                     },
162 |                 }
163 |                 transformed_calls.append(transformed_call)
164 | 
165 |         return {
166 |             "content": content,
167 |             "role": "assistant",
168 |             "tool_calls": transformed_calls,
169 |             "_cache": response.candidates[0].content,
170 |         }
171 | 
--------------------------------------------------------------------------------
/agent_dingo/llm/litellm.py:
--------------------------------------------------------------------------------
 1 | from typing import Optional, Dict
 2 | from agent_dingo.core.blocks import BaseLLM
 3 | from agent_dingo.core.state import UsageMeter
 4 | 
 5 | try:
 6 |     from litellm import completion, acompletion
 7 | except ImportError:
 8 |     raise ImportError(
 9 |         "litellm is not installed. Please install it using `pip install agent-dingo[litellm]`"
10 |     )
11 | 
12 | 
13 | class LiteLLM(BaseLLM):
14 |     def __init__(
15 |         self,
16 |         model: str,
17 |         temperature: float = 0.7,
18 |         completion_extra_kwargs: Optional[Dict] = None,
19 |     ):
20 |         """
21 |         LiteLLM client to interact with various LLM providers.
22 | 
23 |         Parameters
24 |         ----------
25 |         model : str
26 |             model to use
27 |         temperature : float, optional
28 |             generation temperature, by default 0.7
29 |         completion_extra_kwargs : Optional[Dict], optional
30 |             additional arguments to be passed to a completion method, by default None
31 |         """
32 | 
33 |         self.temperature = temperature
34 |         self.model = model
35 |         self.completion_extra_kwargs = completion_extra_kwargs or {}
36 | 
37 |     def send_message(
38 |         self,
39 |         messages,
40 |         functions=None,
41 |         usage_meter: UsageMeter = None,
42 |         temperature: Optional[float] = None,
43 |         **kwargs,
44 |     ):
45 |         response = completion(
46 |             messages=messages,
47 |             model=self.model,
48 |             temperature=temperature or self.temperature,
49 |             **self.completion_extra_kwargs,
50 |         )
51 |         self._log_usage(response, usage_meter)
52 |         return response["choices"][0]["message"]
53 | 
54 |     async def async_send_message(
55 |         self,
56 |         messages,
57 |         functions=None,
58 |         usage_meter: UsageMeter = None,
59 |         temperature=None,
60 |         **kwargs,
61 |     ):
62 |         response = await acompletion(
63 |             messages=messages,
64 |             model=self.model,
65 |             temperature=temperature or self.temperature,
66 |             **self.completion_extra_kwargs,
67 |         )
68 |         self._log_usage(response, usage_meter)
69 |         return response["choices"][0]["message"]
70 | 
71 |     def _log_usage(self, response, usage_meter: UsageMeter = None):
72 |         if usage_meter:
73 |             usage_meter.increment(
74 |                 prompt_tokens=response["usage"]["prompt_tokens"],
75 |                 completion_tokens=response["usage"]["completion_tokens"],
76 |             )
77 | 
--------------------------------------------------------------------------------
/agent_dingo/llm/llama_cpp.py:
--------------------------------------------------------------------------------
 1 | from typing import Optional, Any
 2 | from agent_dingo.core.blocks import BaseLLM
 3 | from agent_dingo.core.state import UsageMeter
 4 | 
 5 | try:
 6 |     from llama_cpp import Llama as _Llama
 7 | except ImportError:
 8 |     raise ImportError(
 9 |         "Llama.cpp is not installed. Please install it using `pip install agent-dingo[llama-cpp]`"
10 |     )
11 | import threading
12 | from concurrent.futures import ThreadPoolExecutor
13 | import asyncio
14 | 
15 | 
16 | class LlamaCPP(BaseLLM):
17 |     def __init__(
18 |         self, model: str, temperature: float = 0.7, verbose: bool = False, **kwargs: Any
19 |     ):
20 |         """
21 |         Llama.cpp client to run LLMs in the GGUF format locally.
22 | 
23 |         Parameters
24 |         ----------
25 |         model : str
26 |             path to the model file
27 |         temperature : float, optional
28 |             model temperature, by default 0.7
29 |         verbose : bool, optional
30 |             flag to enable verbosity, by default False
31 |         **kwargs : Any
32 |             additional arguments to be passed to the model constructor
33 |         """
34 | 
35 |         self.model = _Llama(model, verbose=verbose, **kwargs)
36 |         self.temperature = temperature
37 |         self._lock = threading.Lock()
38 |         self._executor: Optional[ThreadPoolExecutor] = None
39 | 
40 |     def send_message(
41 |         self,
42 |         messages,
43 |         functions=None,
44 |         usage_meter: UsageMeter = None,
45 |         temperature: Optional[float] = None,
46 |         **kwargs,
47 |     ):
48 |         with self._lock:
49 |             response = self.model.create_chat_completion(
50 |                 messages, temperature=temperature or self.temperature
51 |             )
52 |         self._log_usage(response, usage_meter)
53 |         return response["choices"][0]["message"]
54 | 
55 |     async def async_send_message(
56 |         self,
57 |         messages,
58 |         functions=None,
59 |         usage_meter: UsageMeter = None,
60 |         temperature=None,
61 |         **kwargs,
62 |     ):
63 |         loop = asyncio.get_event_loop()
64 |         response = await loop.run_in_executor(
65 |             self._get_executor(),
66 |             lambda: self.model.create_chat_completion(
67 |                 messages, temperature=temperature or self.temperature
68 |             ),
69 |         )
70 |         self._log_usage(response, usage_meter)
71 |         return response["choices"][0]["message"]
72 | 
73 |     def _get_executor(self):
74 |         if self._executor is None:
75 |             self._executor = ThreadPoolExecutor(max_workers=1)
76 |         return self._executor
77 | 
78 |     def _log_usage(self, response, usage_meter: UsageMeter = None):
79 |         if usage_meter:
80 |             usage_meter.increment(
81 |                 prompt_tokens=response["usage"]["prompt_tokens"],
82 |                 completion_tokens=response["usage"]["completion_tokens"],
83 |             )
84 | 
--------------------------------------------------------------------------------
/agent_dingo/llm/openai.py:
--------------------------------------------------------------------------------
 1 | from typing import Optional, List
 2 | from agent_dingo.core.blocks import BaseLLM
 3 | from agent_dingo.core.state import UsageMeter
 4 | import openai
 5 | from tenacity import retry, stop_after_attempt, wait_fixed
 6 | 
 7 | 
 8 | @retry(stop=stop_after_attempt(3), wait=wait_fixed(3))
 9 | def _send_message(
10 |     client: openai.OpenAI,
11 |     messages: List[dict],
12 |     model: str = "gpt-3.5-turbo-0613",
13 |     functions: Optional[List] = None,
14 |     temperature: float = 1.0,
15 | ) -> tuple:
16 |     """Sends messages to the LLM and returns the response.
17 | 
18 |     Parameters
19 |     ----------
20 |     messages : List[dict]
21 |         Messages to send to the LLM.
22 |     model : str, optional
23 |         Model to use, by default "gpt-3.5-turbo-0613"
24 |     functions : Optional[List], optional
25 |         List of functions to use, by default None
26 |     temperature : float, optional
27 |         Temperature to use, by default 1.
28 | 
29 |     Returns
30 |     -------
31 |     tuple
32 |         The response message and the full response object
33 |         (the latter is needed for usage accounting).
34 | 
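    Retries up to 3 times with a 3-second pause between attempts (via the `retry` decorator).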
35 | """ 36 | f = {} 37 | if functions is not None: 38 | f["tools"] = [{"type": "function", "function": f} for f in functions] 39 | f["tool_choice"] = "auto" 40 | response = client.chat.completions.create( 41 | model=model, messages=messages, temperature=temperature, **f 42 | ) 43 | return response.choices[0].message, response 44 | 45 | 46 | @retry(stop=stop_after_attempt(3), wait=wait_fixed(3)) 47 | async def _async_send_message( 48 | client: openai.AsyncOpenAI, 49 | messages: dict, 50 | model: str = "gpt-3.5-turbo-0613", 51 | functions: Optional[List] = None, 52 | temperature: float = 1.0, 53 | ) -> dict: 54 | f = {} 55 | if functions is not None: 56 | f["tools"] = [{"type": "function", "function": f} for f in functions] 57 | f["tool_choice"] = "auto" 58 | response = await client.chat.completions.create( 59 | model=model, messages=messages, temperature=temperature, **f 60 | ) 61 | return response.choices[0].message, response 62 | 63 | 64 | def to_dict(obj): 65 | if isinstance(obj, dict): 66 | return {k: to_dict(v) for k, v in obj.items()} 67 | elif isinstance(obj, list): 68 | return [to_dict(item) for item in obj] 69 | elif hasattr(obj, "__dict__"): 70 | return {k: to_dict(v) for k, v in obj.__dict__.items() if not k.startswith("_")} 71 | else: 72 | return obj 73 | 74 | 75 | class OpenAI(BaseLLM): 76 | def __init__( 77 | self, 78 | model: str, 79 | temperature: float = 0.7, 80 | base_url: Optional[str] = None, 81 | # TODO: Add per instance API key 82 | # TODO: Add remaining generation parameters 83 | ): 84 | """ 85 | OpenAI client to interact with compatible APIs. 86 | 87 | Parameters 88 | ---------- 89 | model : str 90 | model to use 91 | temperature : float, optional 92 | generation temperature, by default 0.7 93 | base_url : Optional[str], optional 94 | _description_, by default None 95 | """ 96 | self.model = model 97 | self.temperature = temperature 98 | self.client = openai.OpenAI(base_url=base_url) 99 | self.async_client = openai.AsyncOpenAI(base_url=base_url) 100 | if base_url is None: 101 | self.supports_function_calls = True 102 | 103 | def send_message( 104 | self, 105 | messages, 106 | functions=None, 107 | usage_meter: UsageMeter = None, 108 | temperature=None, 109 | **kwargs, 110 | ): 111 | response = _send_message( 112 | client=self.client, 113 | messages=messages, 114 | model=self.model, 115 | functions=functions, 116 | temperature=temperature or self.temperature, 117 | ) 118 | return self._postprocess_response(response, usage_meter) 119 | 120 | async def async_send_message( 121 | self, 122 | messages, 123 | functions=None, 124 | usage_meter: UsageMeter = None, 125 | temperature=None, 126 | **kwargs, 127 | ): 128 | response = await _async_send_message( 129 | client=self.async_client, 130 | messages=messages, 131 | model=self.model, 132 | functions=functions, 133 | temperature=temperature or self.temperature, 134 | ) 135 | return self._postprocess_response(response, usage_meter) 136 | 137 | def _postprocess_response(self, response, usage_meter: UsageMeter = None): 138 | res, full_res = to_dict(response[0]), to_dict(response[1]) 139 | if usage_meter: 140 | usage_meter.increment( 141 | prompt_tokens=full_res["usage"]["prompt_tokens"], 142 | completion_tokens=full_res["usage"]["completion_tokens"], 143 | ) 144 | if "function_call" in res.keys(): 145 | del res["function_call"] 146 | return res 147 | -------------------------------------------------------------------------------- /agent_dingo/rag/base.py: 
--------------------------------------------------------------------------------
 1 | from abc import ABC, abstractmethod
 2 | from typing import List, Optional
 3 | from dataclasses import dataclass
 4 | import hashlib
 5 | import json
 6 | from typing import Union
 7 | 
 8 | 
 9 | @dataclass
10 | class Document:
11 |     content: str
12 |     metadata: dict
13 | 
14 |     @property
15 |     def hash(self) -> str:
16 |         metadata = json.dumps(self.metadata, sort_keys=True)
17 |         return hashlib.sha256((self.content + metadata).encode()).hexdigest()
18 | 
19 | 
20 | @dataclass
21 | class Chunk:
22 |     content: str
23 |     parent: Document
24 |     embedding: Optional[List[float]] = None
25 | 
26 |     @property
27 |     def payload(self):
28 |         return {"content": self.content, "document_metadata": self.parent.metadata}
29 | 
30 |     @property
31 |     def hash(self) -> str:
32 |         parent_hash = self.parent.hash
33 |         embedding_hash = (
34 |             hashlib.sha256(json.dumps(self.embedding).encode()).hexdigest()
35 |             if self.embedding
36 |             else ""
37 |         )
38 |         content_hash = hashlib.sha256(self.content.encode()).hexdigest()
39 |         return hashlib.sha256(
40 |             (parent_hash + embedding_hash + content_hash).encode()
41 |         ).hexdigest()
42 | 
43 | 
44 | @dataclass
45 | class RetrievedChunk:
46 |     content: str
47 |     document_metadata: dict
48 |     score: float
49 | 
50 | 
51 | class BaseReader(ABC):
52 |     @abstractmethod
53 |     def read(self, *args, **kwargs) -> List[Document]:
54 |         pass
55 | 
56 | 
57 | class BaseChunker(ABC):
58 |     @abstractmethod
59 |     def chunk(self, document: Document) -> List[Chunk]:
60 |         pass
61 | 
62 | 
63 | class BaseEmbedder(ABC):
64 |     batch_size: int = 1
65 | 
66 |     def embed_chunks(self, chunks: List[Chunk]):
67 |         for i in range(0, len(chunks), self.batch_size):
68 |             batch = chunks[i : i + self.batch_size]
69 |             contents = [chunk.content for chunk in batch]
70 |             embeddings = self.embed(contents)
71 |             for chunk, embedding in zip(batch, embeddings):
72 |                 chunk.embedding = embedding
73 | 
74 |     @abstractmethod
75 |     async def async_embed(self, texts: Union[str, List[str]]) -> List[List[float]]:
76 |         pass
77 | 
78 |     @abstractmethod
79 |     def embed(self, texts: Union[str, List[str]]) -> List[List[float]]:
80 |         pass
81 | 
82 | 
83 | class BaseVectorStore(ABC):
84 |     @abstractmethod
85 |     def upsert_chunks(self, chunks: List[Chunk]):
86 |         pass
87 | 
88 |     @abstractmethod
89 |     def retrieve(self, k: int, embedding: List[float]) -> List[RetrievedChunk]:
90 |         pass
91 | 
92 |     @abstractmethod
93 |     async def async_retrieve(
94 |         self, k: int, embedding: List[float]
95 |     ) -> List[RetrievedChunk]:
96 |         pass
97 | 
--------------------------------------------------------------------------------
/agent_dingo/rag/chunkers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/agent_dingo/rag/chunkers/__init__.py
--------------------------------------------------------------------------------
/agent_dingo/rag/chunkers/recursive.py:
--------------------------------------------------------------------------------
 1 | from agent_dingo.rag.base import BaseChunker, Document, Chunk
 2 | from typing import List
 3 | import re
 4 | 
 5 | 
 6 | class RecursiveChunker(BaseChunker):
 7 |     def __init__(
 8 |         self, separators=None, chunk_size=512, keep_separator=False, merge_separator=" "
 9 |     ):
10 |         if separators is None:
11 |             separators = ["\n\n", "\n", " ", ""]
12 |         self.separators = separators
13 |         self.chunk_size = chunk_size
14 |         self.keep_separator = keep_separator
15 | 
self.merge_separator = merge_separator 16 | 17 | def chunk(self, documents: List[Document]) -> List[Chunk]: 18 | all_chunks = [] 19 | for doc in documents: 20 | chunks = self._split_text_recursive(doc.content, self.separators) 21 | chunks = self._merge_small_chunks(chunks) 22 | all_chunks.extend([Chunk(content=chunk, parent=doc) for chunk in chunks]) 23 | return all_chunks 24 | 25 | def _split_text_recursive(self, text, separators): 26 | if not separators: 27 | return [text] 28 | 29 | separator = separators[0] 30 | pattern = re.escape(separator) 31 | split_chunks = re.split(pattern, text) 32 | 33 | final_chunks = [] 34 | for i, chunk in enumerate(split_chunks): 35 | appended_chunk = ( 36 | separator + chunk if self.keep_separator and i > 0 else chunk 37 | ) 38 | if len(appended_chunk) <= self.chunk_size: 39 | final_chunks.append(appended_chunk) 40 | else: 41 | final_chunks.extend( 42 | self._split_text_recursive(appended_chunk, separators[1:]) 43 | ) 44 | 45 | return final_chunks 46 | 47 | def _merge_small_chunks(self, chunks): 48 | merged_chunks = [] 49 | current_chunk = "" 50 | 51 | for chunk in chunks: 52 | new_chunk = current_chunk + ( 53 | self.merge_separator + chunk if current_chunk else chunk 54 | ) 55 | if len(new_chunk) <= self.chunk_size: 56 | current_chunk = new_chunk 57 | else: 58 | if current_chunk: 59 | merged_chunks.append(current_chunk) 60 | current_chunk = chunk 61 | 62 | if current_chunk: 63 | merged_chunks.append(current_chunk) 64 | 65 | return merged_chunks 66 | -------------------------------------------------------------------------------- /agent_dingo/rag/embedders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/agent_dingo/rag/embedders/__init__.py -------------------------------------------------------------------------------- /agent_dingo/rag/embedders/openai.py: -------------------------------------------------------------------------------- 1 | import openai 2 | from agent_dingo.rag.base import BaseEmbedder 3 | from typing import Optional, List 4 | 5 | 6 | class OpenAIEmbedder(BaseEmbedder): 7 | def __init__( 8 | self, 9 | model: str = "text-embedding-3-small", 10 | base_url: Optional[str] = None, 11 | dimensions: Optional[int] = None, 12 | ): 13 | self.model = model 14 | self.client = openai.OpenAI(base_url=base_url) 15 | self.async_client = openai.AsyncOpenAI(base_url=base_url) 16 | self.params = { 17 | "model": self.model, 18 | } 19 | if dimensions: 20 | self.params["dimensions"] = dimensions 21 | 22 | def embed(self, texts: str) -> List[List[float]]: 23 | if isinstance(texts, str): 24 | texts = [texts] 25 | embeddings = [ 26 | i.embedding 27 | for i in (self.client.embeddings.create(**self.params, input=texts).data) 28 | ] 29 | return embeddings 30 | 31 | async def async_embed(self, texts: str) -> List[List[float]]: 32 | if isinstance(texts, str): 33 | texts = [texts] 34 | res = await self.async_client.embeddings.create(**self.params, input=texts) 35 | embeddings = [i.embedding for i in res.data] 36 | return embeddings 37 | -------------------------------------------------------------------------------- /agent_dingo/rag/embedders/sentence_transformer.py: -------------------------------------------------------------------------------- 1 | try: 2 | from sentence_transformers import SentenceTransformer as _SentenceTransformer 3 | except ImportError: 4 | raise ImportError( 5 | "SentenceTransformers is not installed. 
Please install it using `pip install agent-dingo[sentence-transformers]`" 6 | ) 7 | from agent_dingo.rag.base import BaseEmbedder 8 | from typing import List, Union 9 | import hashlib 10 | import concurrent.futures 11 | import asyncio 12 | import os 13 | 14 | 15 | class SentenceTransformer(BaseEmbedder): 16 | def __init__( 17 | self, model_name: str = "paraphrase-MiniLM-L6-v2", batch_size: int = 128 18 | ): 19 | self.model = _SentenceTransformer(model_name) 20 | self.model_name = model_name 21 | self._executor = None 22 | self.batch_size = batch_size 23 | os.environ["TOKENIZERS_PARALLELISM"] = os.environ.get( 24 | "TOKENIZERS_PARALLELISM", "false" 25 | ) 26 | 27 | def _prepare_executor(self) -> None: 28 | if self._executor is None: 29 | self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1) 30 | 31 | async def async_embed(self, text: Union[str, List[str]]) -> List[List[float]]: 32 | if isinstance(text, str): 33 | text = [text] 34 | self._prepare_executor() 35 | loop = asyncio.get_event_loop() 36 | return await loop.run_in_executor(self._executor, self.embed, text) 37 | 38 | def embed(self, texts: Union[str, List[str]]) -> List[List[float]]: 39 | if isinstance(texts, str): 40 | texts = [texts] 41 | embeddings = self.model.encode(texts) 42 | return [embedding.tolist() for embedding in embeddings] 43 | 44 | def hash(self) -> str: 45 | return hashlib.sha256( 46 | ("SentenceTransformer::" + self.model_name).encode() 47 | ).hexdigest() 48 | -------------------------------------------------------------------------------- /agent_dingo/rag/prompt_modifiers.py: -------------------------------------------------------------------------------- 1 | from agent_dingo.core.state import ChatPrompt, Context, Store 2 | from agent_dingo.core.message import UserMessage, SystemMessage 3 | from agent_dingo.core.blocks import BasePromptModifier as _BasePromptModifier 4 | from agent_dingo.rag.base import BaseEmbedder, BaseVectorStore, RetrievedChunk 5 | from typing import List, Optional 6 | from warnings import warn 7 | 8 | _DEFAULT_RAG_TEMPLATE = """ 9 | {original_message} 10 | 11 | Relevant documents: 12 | {documents} 13 | """ 14 | 15 | 16 | class RAGPromptModifier(_BasePromptModifier): 17 | def __init__( 18 | self, 19 | embedder: BaseEmbedder, 20 | vector_store: BaseVectorStore, 21 | n_chunks_to_retrieve: int = 5, 22 | retrieved_data_location: str = "system", 23 | rag_template: Optional[str] = None, 24 | ): 25 | if retrieved_data_location not in ["system", "user"]: 26 | raise ValueError( 27 | "retrieved_data_location must be one of 'system' or 'user'" 28 | ) 29 | self.embedder = embedder 30 | self.vector_store = vector_store 31 | self.retrieved_data_location = retrieved_data_location 32 | self.rag_template = rag_template or _DEFAULT_RAG_TEMPLATE 33 | self.n_chunks_to_retrieve = n_chunks_to_retrieve 34 | 35 | def forward(self, state: ChatPrompt, context: Context, store: Store) -> ChatPrompt: 36 | if not isinstance(state, ChatPrompt): 37 | raise ValueError("state must be a ChatPrompt") 38 | query = state.messages[-1].content 39 | query_embedding = self.embedder.embed(query)[0] 40 | try: 41 | retrieved_data = self.vector_store.retrieve( 42 | self.n_chunks_to_retrieve, 43 | query_embedding, 44 | ) 45 | except Exception as e: 46 | retrieved_data = [] 47 | warn("No data was retrieved") 48 | return self._forward(state, retrieved_data) 49 | 50 | async def async_forward( 51 | self, state: ChatPrompt, context: Context, store: Store 52 | ) -> ChatPrompt: 53 | if not isinstance(state, ChatPrompt): 54 | raise 
ValueError("state must be a ChatPrompt") 55 | query = state.messages[-1].content 56 | query_embedding = (await self.embedder.async_embed(query))[0] 57 | try: 58 | retrieved_data = await self.vector_store.async_retrieve( 59 | self.n_chunks_to_retrieve, 60 | query_embedding, 61 | ) 62 | except Exception as e: 63 | retrieved_data = [] 64 | warn("No data was retrieved") 65 | return self._forward(state, retrieved_data) 66 | 67 | def _forward( 68 | self, state: ChatPrompt, retrieved_data: List[RetrievedChunk] 69 | ) -> ChatPrompt: 70 | if len(retrieved_data) < 1: 71 | return state 72 | modified = False 73 | messages = [] 74 | target_message_type = ( 75 | SystemMessage if self.retrieved_data_location == "system" else UserMessage 76 | ) 77 | for message in state.messages: 78 | if isinstance(message, target_message_type) and not modified: 79 | modified_message = target_message_type( 80 | self.rag_template.format( 81 | original_message=message.content, 82 | documents="\n".join([str(i.__dict__) for i in retrieved_data]), 83 | ) 84 | ) 85 | modified = True 86 | else: 87 | modified_message = message.__class__(message.content) 88 | messages.append(modified_message) 89 | if not modified: 90 | raise ValueError( 91 | f"Could not find a {target_message_type.__name__} message to modify" 92 | ) 93 | return ChatPrompt(messages) 94 | 95 | def get_required_context_keys(self) -> List[str]: 96 | return [] 97 | -------------------------------------------------------------------------------- /agent_dingo/rag/readers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/agent_dingo/rag/readers/__init__.py -------------------------------------------------------------------------------- /agent_dingo/rag/readers/list.py: -------------------------------------------------------------------------------- 1 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document 2 | from typing import List, Optional 3 | 4 | 5 | class ListReader(_BaseReader): 6 | def read(self, inputs: List[str]) -> List[Document]: 7 | docs = [] 8 | for i in inputs: 9 | if not isinstance(i, str): 10 | raise ValueError("ListReader only accepts lists of strings") 11 | docs.append(Document(i, {"source": "memory"})) 12 | return docs 13 | -------------------------------------------------------------------------------- /agent_dingo/rag/readers/pdf.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | try: 4 | from PyPDF2 import PdfReader 5 | except ImportError: 6 | raise ImportError( 7 | "PyPDF2 is not installed. 
Please install it using `pip install agent-dingo[rag_default]`" 8 | ) 9 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document 10 | 11 | 12 | class PDFReader(_BaseReader): 13 | def read(self, file_path: str) -> List[Document]: 14 | docs = [] 15 | reader = PdfReader(file_path) 16 | for i, page in enumerate(reader.pages): 17 | text = page.extract_text() 18 | docs.append(Document(text, {"source": file_path, "page": i})) 19 | return docs 20 | -------------------------------------------------------------------------------- /agent_dingo/rag/readers/web.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document 3 | try: 4 | import requests 5 | from bs4 import BeautifulSoup 6 | except ImportError: 7 | raise ImportError( 8 | "requests or BeautifulSoup4 are not installed. Please install it using `pip install agent-dingo[rag_default]`" 9 | ) 10 | 11 | 12 | class WebpageReader(_BaseReader): 13 | def read(self, url: str) -> List[Document]: 14 | docs = [] 15 | response = requests.get(url) 16 | if response.status_code == 200: 17 | soup = BeautifulSoup(response.content, "html.parser") 18 | text = soup.get_text(" ") 19 | docs.append(Document(text, {"source": url})) 20 | else: 21 | raise ValueError(f"Error fetching {url}: {response.status_code}") 22 | return docs 23 | -------------------------------------------------------------------------------- /agent_dingo/rag/readers/word.py: -------------------------------------------------------------------------------- 1 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document 2 | from typing import List 3 | try: 4 | import docx # python-docx 5 | except ImportError: 6 | raise ImportError( 7 | "python-docx is not installed. Please install it using `pip install agent-dingo[rag_default]`" 8 | ) 9 | 10 | 11 | class WordDocumentReader(_BaseReader): 12 | def read(self, file_path: str) -> List[Document]: 13 | docs = [] 14 | doc = docx.Document(file_path) 15 | for i, para in enumerate(doc.paragraphs): 16 | text = para.text 17 | docs.append(Document(text, {"source": file_path, "paragraph": i})) 18 | return docs 19 | -------------------------------------------------------------------------------- /agent_dingo/rag/vector_stores/chromadb.py: -------------------------------------------------------------------------------- 1 | from agent_dingo.rag.base import ( 2 | BaseVectorStore as _BaseVectorStore, 3 | Chunk, 4 | RetrievedChunk, 5 | ) 6 | from agent_dingo.utils import sha256_to_uuid 7 | from typing import Optional, List 8 | 9 | try: 10 | import chromadb 11 | except ImportError: 12 | raise ImportError( 13 | "Chromadb is not installed. Please install it using `pip install agent-dingo[chromadb]`" 14 | ) 15 | 16 | import asyncio 17 | from concurrent.futures import ThreadPoolExecutor 18 | 19 | 20 | class ChromaDB(_BaseVectorStore): 21 | def __init__( 22 | self, 23 | collection_name: str, 24 | path: Optional[str] = None, 25 | host: Optional[str] = None, 26 | port: Optional[int] = None, 27 | recreate_collection: bool = False, 28 | upsert_batch_size: int = 32, 29 | ): 30 | """ 31 | ChromaDB vector store. 
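        Exactly one of `path` (local persistent client) or `host`/`port` (HTTP client) should be provided.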
32 | 33 | Parameters 34 | ---------- 35 | collection_name : str 36 | name of the collection to store the vectors 37 | path : Optional[str], optional 38 | path to the local database, by default None 39 | host : Optional[str], optional 40 | host, by default None 41 | port : Optional[int], optional 42 | port, by default None 43 | recreate_collection : bool, optional 44 | flag to control whether the collection should be recreated on init, by default False 45 | upsert_batch_size : int, optional 46 | batch size for upserting the documents, by default 32 47 | """ 48 | if path is not None and (host is not None or port is not None): 49 | raise ValueError("Either path or host/port must be specified, not both") 50 | if path is not None: 51 | self.client = chromadb.PersistentClient(path=path) 52 | else: 53 | self.client = chromadb.HttpClient(host=host, port=port) 54 | 55 | if recreate_collection: 56 | try: 57 | self.client.delete_collection(collection_name) 58 | except ValueError: 59 | pass 60 | 61 | self.collection = self.client.get_or_create_collection(collection_name) 62 | 63 | self.upsert_batch_size = upsert_batch_size 64 | 65 | self._executor = None 66 | 67 | def upsert_chunks(self, chunks: List[Chunk]): 68 | for i in range(0, len(chunks), self.upsert_batch_size): 69 | batch = chunks[i : i + self.upsert_batch_size] 70 | batch_contents = [chunk.content for chunk in batch] 71 | batch_ids = [sha256_to_uuid(chunk.hash) for chunk in batch] 72 | batch_embeddings = [chunk.embedding for chunk in batch] 73 | batch_metadata = [chunk.payload["document_metadata"] for chunk in batch] 74 | self.collection.upsert( 75 | ids=batch_ids, 76 | documents=batch_contents, 77 | embeddings=batch_embeddings, 78 | metadatas=batch_metadata, 79 | ) 80 | 81 | def retrieve(self, k: int, query: List[float]) -> List[RetrievedChunk]: 82 | search_result = self.collection.query( 83 | query_embeddings=query, 84 | n_results=k, 85 | include=["metadatas", "documents", "distances"], 86 | ) 87 | retrieved_chunks = [] 88 | for content, metadata, score in zip( 89 | search_result["documents"][0], 90 | search_result["metadatas"][0], 91 | search_result["distances"][0], 92 | ): 93 | retrieved_chunks.append(RetrievedChunk(content, metadata, score)) 94 | return retrieved_chunks 95 | 96 | async def async_retrieve(self, k: int, query: List[float]) -> List[RetrievedChunk]: 97 | loop = asyncio.get_event_loop() 98 | return await loop.run_in_executor(self._get_executor(), self.retrieve, k, query) 99 | 100 | def _get_executor(self) -> ThreadPoolExecutor: 101 | if self._executor is None: 102 | self._executor = ThreadPoolExecutor(max_workers=1) 103 | return self._executor 104 | -------------------------------------------------------------------------------- /agent_dingo/rag/vector_stores/qdrant.py: -------------------------------------------------------------------------------- 1 | from agent_dingo.rag.base import ( 2 | BaseVectorStore as _BaseVectorStore, 3 | Chunk, 4 | RetrievedChunk, 5 | ) 6 | from agent_dingo.utils import sha256_to_uuid 7 | 8 | try: 9 | from qdrant_client import QdrantClient, AsyncQdrantClient 10 | from qdrant_client.http import models 11 | except ImportError: 12 | raise ImportError( 13 | "Qdrant is not installed. 
Please install it using `pip install agent-dingo[qdrant]`"
14 |     )
15 | from typing import Optional, List
16 | from warnings import warn
17 | 
18 | 
19 | class Qdrant(_BaseVectorStore):
20 |     def __init__(
21 |         self,
22 |         collection_name: str,
23 |         embedding_size: int,
24 |         path: Optional[str] = None,
25 |         host: Optional[str] = None,
26 |         port: Optional[int] = None,
27 |         url: Optional[str] = None,
28 |         api_key: Optional[str] = None,
29 |         recreate_collection: bool = False,
30 |         upsert_batch_size: int = 32,
31 |         try_init: bool = True,
32 |     ):
33 |         """
34 |         Qdrant vector store.
35 | 
36 |         Parameters
37 |         ----------
38 |         collection_name : str
39 |             collection name
40 |         embedding_size : int
41 |             size of the vector embeddings
42 |         path : Optional[str], optional
43 |             path to the local database, does not support concurrent clients, by default None
44 |         host : Optional[str], optional
45 |             host, by default None
46 |         port : Optional[int], optional
47 |             port, by default None
48 |         url : Optional[str], optional
49 |             base url of qdrant provider, by default None
50 |         api_key : Optional[str], optional
51 |             api key of qdrant provider, by default None
52 |         recreate_collection : bool, optional
53 |             flag to control whether the collection should be recreated on init, by default False
54 |         upsert_batch_size : int, optional
55 |             batch size for upserting the documents, by default 32
56 |         try_init : bool, optional
57 |             flag to control whether the collection should be created on object initialization, by default True
58 | 
59 |         Raises
60 |         ------
61 |         ValueError
62 |             if the collection cannot be (re)created on init
63 |         """
64 |         raw_client_params = {
65 |             "host": host,
66 |             "path": path,
67 |             "port": port,
68 |             "url": url,
69 |             "api_key": api_key,
70 |         }
71 |         client_params = {k: v for k, v in raw_client_params.items() if v is not None}
72 |         self.client_params = client_params
73 |         self._client = None
74 |         self._async_client = None
75 |         self.collection_name = collection_name
76 |         self.embedding_size = embedding_size
77 |         self.recreate_collection = recreate_collection
78 |         self.upsert_batch_size = upsert_batch_size
79 |         if path and (try_init or recreate_collection):
80 |             warn(
81 |                 "Using local Qdrant storage will only work in a synchronous environment. Trying to call async methods will raise an error."
 82 |             )
 83 |         if try_init or recreate_collection:
 84 |             self._init_collection()
 85 | 
 86 |     def make_sync_client(self):
 87 |         self._client = (
 88 |             QdrantClient(**self.client_params)
 89 |             if self.client_params
 90 |             else QdrantClient(":memory:")
 91 |         )
 92 | 
 93 |     def make_async_client(self):
 94 |         self._async_client = (
 95 |             AsyncQdrantClient(**self.client_params)
 96 |             if self.client_params
 97 |             else AsyncQdrantClient(":memory:")
 98 |         )
 99 | 
100 |     @property
101 |     def client(self):
102 |         if self._client is None:
103 |             self.make_sync_client()
104 |         return self._client
105 | 
106 |     @property
107 |     def async_client(self):
108 |         if self._async_client is None:
109 |             self.make_async_client()
110 |         return self._async_client
111 | 
112 |     def _init_collection(self):
113 |         create_fn = (
114 |             self.client.create_collection
115 |             if not self.recreate_collection
116 |             else self.client.recreate_collection
117 |         )
118 |         try:
119 |             create_fn(
120 |                 collection_name=self.collection_name,
121 |                 vectors_config=models.VectorParams(
122 |                     size=self.embedding_size, distance=models.Distance.COSINE
123 |                 ),
124 |             )
125 |         except ValueError as e:
126 |             if self.recreate_collection:
127 |                 raise e
128 |             pass  # collection already exists
129 | 
130 |     def upsert_chunks(self, chunks: List[Chunk]):
131 |         for i in range(0, len(chunks), self.upsert_batch_size):
132 |             batch = chunks[i : i + self.upsert_batch_size]
133 |             points = []
134 |             for chunk in batch:
135 |                 if chunk.embedding is None:
136 |                     raise ValueError("Chunk must be embedded before upserting")
137 |                 point = models.PointStruct(
138 |                     vector=chunk.embedding,
139 |                     payload=chunk.payload,
140 |                     id=sha256_to_uuid(chunk.hash),
141 |                 )
142 |                 points.append(point)
143 |             self.client.upsert(points=points, collection_name=self.collection_name)
144 | 
145 |     def retrieve(self, k: int, query: List[float]):
146 |         search_result = self.client.search(
147 |             collection_name=self.collection_name,
148 |             query_vector=query,
149 |             limit=k,
150 |         )
151 |         return self._process_search_results(search_result)
152 | 
153 |     async def async_retrieve(self, k: int, query: List[float]):
154 |         search_result = await self.async_client.search(
155 |             collection_name=self.collection_name,
156 |             query_vector=query,
157 |             limit=k,
158 |         )
159 |         return self._process_search_results(search_result)
160 | 
161 |     def _process_search_results(self, search_result: List) -> List[RetrievedChunk]:
162 |         retrieved_chunks = []
163 |         for r in search_result:
164 |             content = r.payload["content"]
165 |             metadata = r.payload["document_metadata"]
166 |             score = r.score
167 |             retrieved_chunks.append(RetrievedChunk(content, metadata, score))
168 |         return retrieved_chunks
169 | 
--------------------------------------------------------------------------------
/agent_dingo/serve.py:
--------------------------------------------------------------------------------
 1 | from agent_dingo.core.state import State, Store, Context, ChatPrompt
 2 | from agent_dingo.core.blocks import Pipeline
 3 | from agent_dingo.core.message import UserMessage, SystemMessage, AssistantMessage
 4 | from fastapi import FastAPI, HTTPException
 5 | from pydantic import BaseModel
 6 | import uvicorn
 7 | from typing import List, Dict, Optional, Tuple, Union
 8 | from uuid import uuid4
 9 | import time
10 | 
11 | 
12 | class Message(BaseModel):
13 |     role: str
14 |     content: str
15 | 
16 | 
17 | class PipelineRunRequest(BaseModel):
18 |     model: str
19 |     messages: List[Message]
20 | 
21 | 
22 | class Usage(BaseModel):
23 |     prompt_tokens: int
24 |     completion_tokens: int
25 |     total_tokens: int
26 | 
27 | 
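# The Pydantic models below mirror the response objects of the OpenAI-compatible
# /models and /chat/completions endpoints exposed by make_app.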
28 | class Model(BaseModel): 29 | id: str 30 | object: str = "model" 31 | created: int 32 | owned_by: str = "dingo" 33 | 34 | 35 | class Models(BaseModel): 36 | models: List[Model] 37 | object: str = "list" 38 | 39 | 40 | class Choice(BaseModel): 41 | index: int 42 | message: Message 43 | logprobs: Optional[Dict] = None 44 | finish_reason: str = "stop" 45 | 46 | 47 | class PipelineOutputResponse(BaseModel): 48 | id: str 49 | object: str 50 | created: int 51 | model: str 52 | usage: Usage 53 | choices: List[Choice] 54 | 55 | 56 | _role_to_message_type = { 57 | "user": UserMessage, 58 | "system": SystemMessage, 59 | "assistant": AssistantMessage, 60 | } 61 | 62 | 63 | def _construct_response( 64 | output: str, usage: Usage, model: str 65 | ) -> PipelineOutputResponse: 66 | generated_uuid = str(uuid4()) 67 | current_timestamp = int(time.time()) 68 | return PipelineOutputResponse( 69 | id=generated_uuid, 70 | object="chat.completion", 71 | created=current_timestamp, 72 | usage=usage, 73 | model=model, 74 | choices=[ 75 | Choice( 76 | index=0, 77 | message=Message(role="assistant", content=output), 78 | finish_reason="stop", 79 | ) 80 | ], 81 | ) 82 | 83 | 84 | def _construct_pipeline_input( 85 | input_: List[Message], 86 | ) -> Tuple[ChatPrompt, Dict[str, str]]: 87 | messages = [] 88 | context = {} 89 | for m in input_: 90 | if m.role.startswith("context_"): 91 | key = m.role[8:] 92 | if key in context: 93 | raise ValueError(f"Context key {key} already exists.") 94 | context[key] = m.content 95 | else: 96 | msg = _role_to_message_type[m.role](m.content) 97 | messages.append(msg) 98 | state = ChatPrompt(messages) 99 | return state, context 100 | 101 | 102 | def make_app(pipeline: Union[Pipeline, Dict[str, Pipeline]], is_async: bool = False): 103 | app = FastAPI() 104 | created_at = int(time.time()) 105 | if isinstance(pipeline, Pipeline): 106 | available_pipelines = {"dingo": pipeline} 107 | else: 108 | for k, v in pipeline.items(): 109 | if not isinstance(v, Pipeline): 110 | raise ValueError(f"Pipeline {k} is not an instance of Pipeline.") 111 | available_pipelines = pipeline 112 | 113 | if is_async: 114 | 115 | @app.post("/chat/completions") 116 | async def run_pipeline(input: PipelineRunRequest) -> PipelineOutputResponse: 117 | state, context = _construct_pipeline_input(input.messages) 118 | selected_pipeline = available_pipelines[input.model] 119 | output, usage = await selected_pipeline.async_run(_state=state, **context) 120 | return _construct_response(output, Usage(**usage), model=input.model) 121 | 122 | else: 123 | 124 | @app.post("/chat/completions") 125 | def run_pipeline(input: PipelineRunRequest) -> PipelineOutputResponse: 126 | state, context = _construct_pipeline_input(input.messages) 127 | selected_pipeline = available_pipelines[input.model] 128 | output, usage = selected_pipeline.run(_state=state, **context) 129 | return _construct_response(output, Usage(**usage), model=input.model) 130 | 131 | @app.get("/models") 132 | async def get_models() -> Models: 133 | models = Models( 134 | models=[Model(id=k, created=created_at) for k in available_pipelines.keys()] 135 | ) 136 | return models 137 | 138 | return app 139 | 140 | 141 | def serve_pipeline( 142 | pipeline: Union[Pipeline, Dict[str, Pipeline]], 143 | is_async: bool = False, 144 | host: str = "0.0.0.0", 145 | port: int = 8000, 146 | ): 147 | app = make_app(pipeline, is_async) 148 | uvicorn.run(app, host=host, port=port) 149 | -------------------------------------------------------------------------------- 
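# Usage sketch (illustrative, not a file from the repository): wiring the blocks
# above into a minimal pipeline and exposing it via the OpenAI-compatible server
# from agent_dingo/serve.py. Assumes the `server` extra is installed and that the
# OpenAI LLM block picks up an API key from the environment; treat this as a
# sketch rather than canonical usage.
#
# from agent_dingo.core.blocks import Pipeline, PromptBuilder
# from agent_dingo.core.message import UserMessage
# from agent_dingo.llm.openai import OpenAI
# from agent_dingo.serve import serve_pipeline
#
# # "{query}" is not mapped to state or store keys, so PromptBuilder resolves it
# # from the run context (see PromptBuilder.get_required_context_keys).
# builder = PromptBuilder([UserMessage("{query}")])
# llm = OpenAI(model="gpt-3.5-turbo")  # any BaseLLM block should work here
#
# pipeline = Pipeline()
# pipeline.add_block(builder)
# pipeline.add_block(llm)
#
# # Direct invocation: run() returns (parsed_output, usage_dict).
# output, usage = pipeline.run(query="What is a dingo?")
#
# # Serve under the model name "dingo"; clients POST to /chat/completions with
# # model="dingo" and pass context values as messages with role="context_query".
# serve_pipeline({"dingo": pipeline}, is_async=False, host="0.0.0.0", port=8000)
--------------------------------------------------------------------------------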
/agent_dingo/utils.py:
--------------------------------------------------------------------------------
1 | def sha256_to_uuid(sha256_hash: str) -> str:
2 |     short_hash = sha256_hash[:32]
3 |     formatted_uuid = f"{short_hash[:8]}-{short_hash[8:12]}-{short_hash[12:16]}-{short_hash[16:20]}-{short_hash[20:32]}"
4 |     return formatted_uuid
5 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 | 
5 | [project]
6 | dependencies = [
7 |     "openai>=1.25.0,<2.0.0",
8 |     "docstring_parser>=0.15.0,<1.0.0",
9 |     "tenacity>=8.2.0,<9.0.0",
10 | ]
11 | name = "agent_dingo"
12 | version = "1.0.0"
13 | authors = [
14 |     { name="Oleh Kostromin", email="kostromin97@gmail.com" },
15 |     { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" },
16 | ]
17 | description = "A microframework for creating simple AI agents."
18 | readme = "README.md"
19 | license = {text = "MIT"}
20 | requires-python = ">=3.9"
21 | classifiers = [
22 |     "Programming Language :: Python :: 3",
23 |     "License :: OSI Approved :: MIT License",
24 |     "Operating System :: OS Independent",
25 | ]
26 | 
27 | [project.optional-dependencies]
28 | server = ["fastapi>=0.105.0,<1.0.0", "uvicorn>=0.20.0,<1.0.0"]
29 | langchain = ["langchain>=0.1.0,<0.2.0"]
30 | qdrant = ["qdrant-client>=1.9.0,<2.0.0"]
31 | chromadb = ["chromadb>=0.5.0,<1.0.0"]
32 | sentence-transformers = ["sentence-transformers>=2.3.0,<3.0.0"]
33 | rag-default = ["PyPDF2>=3.0.0,<4.0.0", "beautifulsoup4>=4.12.0,<5.0.0", "requests>=2.26.0,<3.0.0", "python-docx>=1.0.0,<2.0.0"]
34 | vertexai = ["google-cloud-aiplatform>=1.40.0,<2.0.0"]
35 | litellm = ["litellm>=1.30.0,<2.0.0"]
36 | llama-cpp = ["llama-cpp-python>=0.2.20,<0.3.0"]
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fake_llm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from agent_dingo.core.blocks import BaseLLM
3 | from agent_dingo.core.state import ChatPrompt, UsageMeter
4 | 
5 | import openai
6 | 
7 | 
8 | class FakeLLM(BaseLLM):
9 |     def __init__(
10 |         self,
11 |         model: str = "123",
12 |         temperature: float = 0.7,
13 |         base_url: Optional[str] = None,
14 |     ):
15 |         self.model = model
16 |         self.temperature = temperature
17 |         # a dummy api_key is provided so that the clients can be constructed
18 |         # even when no OPENAI_API_KEY environment variable is set
19 |         self.client = openai.OpenAI(base_url=base_url, api_key="fake-key")
20 |         self.async_client = openai.AsyncOpenAI(base_url=base_url, api_key="fake-key")
21 |         if base_url is None:
22 |             self.supports_function_calls = True
23 | 
24 |     def send_message(
25 |         self,
26 |         messages,
27 |         functions=None,
28 |         usage_meter: UsageMeter = None,
29 |         temperature=None,
30 |         **kwargs,
31 |     ):
32 |         # `res` mimics the assistant message dict returned by a real backend,
33 |         # while `full_res` mimics the raw OpenAI-style completion payload.
34 |         res, full_res = (
35 |             {"role": "assistant", "content": "Fake response"},
36 |             {
37 |                 "id": "chatcmpl-123",
38 |                 "object": "chat.completion",
39 |                 "created": 1677652288,
40 |                 "model": "gpt-3.5-turbo-0613",
41 |                 "system_fingerprint": "fp_44709d6fcb",
42 |                 "choices": [
43 |                     {
44 |                         "index": 0,
45 |                         "message": {
46 |                             "role": "assistant",
47 |                             "content": "Fake response",
48 |                         },
49 |                         "logprobs": None,
50 |                         "finish_reason": "stop",
51 |                     }
52 |                 ],
53 |                 "usage": {
54 |                     "prompt_tokens": 9,
55 |                     "completion_tokens": 12,
56 |                     "total_tokens": 21,
57 |                 },
58 |             },
59 |         )
60 |         if usage_meter:
61 |             usage_meter.increment(
62 |                 prompt_tokens=full_res["usage"]["prompt_tokens"],
63 |                 completion_tokens=full_res["usage"]["completion_tokens"],
64 |             )
65 |         if "function_call" in res:
66 |             del res["function_call"]
67 |         return res
68 | 
69 |     async def async_send_message(
70 |         self,
71 |         messages,
72 |         functions=None,
73 |         usage_meter: UsageMeter = None,
74 |         temperature=None,
75 |         **kwargs,
76 |     ):
77 |         return self.send_message(
78 |             messages, functions, usage_meter, temperature, **kwargs
79 |         )
80 | 
81 |     def process_prompt(
82 |         self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs
83 |     ):
84 |         return self.send_message(prompt.dict, None, usage_meter)["content"]
85 | 
86 |     async def async_process_prompt(
87 |         self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs
88 |     ):
89 |         return (await self.async_send_message(prompt.dict, None, usage_meter))[
90 |             "content"
91 |         ]
92 | 
--------------------------------------------------------------------------------
/tests/test_agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/tests/test_agent/__init__.py
--------------------------------------------------------------------------------
/tests/test_agent/test_agent.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent import Agent
3 | from agent_dingo.agent.function_descriptor import FunctionDescriptor
4 | from tests.fake_llm import FakeLLM
5 | 
6 | 
7 | class TestAgentDingo(unittest.TestCase):
8 |     def setUp(self):
9 |         llm = FakeLLM()
10 |         self.agent = Agent(llm)
11 | 
12 |     def test_register_function(self):
13 |         def func(arg: str):
14 |             """_summary_
15 | 
16 |             Parameters
17 |             ----------
18 |             arg : str
19 |                 _description_
20 |             """
21 |             pass
22 | 
23 |         self.agent.register_function(func)
24 |         self.assertEqual(len(self.agent._registry._Registry__functions), 1)
25 | 
26 |     def test_register_descriptor(self):
27 |         d = FunctionDescriptor(
28 |             name="function_from_descriptor",
29 |             func=lambda arg: None,
30 |             json_repr={},
31 |             requires_context=False,
32 |         )
33 |         self.agent.register_descriptor(d)
34 |         self.assertEqual(len(self.agent._registry._Registry__functions), 1)
35 |         self.assertIn(
36 |             "function_from_descriptor", self.agent._registry._Registry__functions.keys()
37 |         )
38 | 
39 |     def test_function_decorator(self):
40 |         @self.agent.function
41 |         def func(arg: str):
42 |             """_summary_
43 | 
44 |             Parameters
45 |             ----------
46 |             arg : str
47 |                 _description_
48 |             """
49 |             pass
50 | 
51 |         self.assertEqual(len(self.agent._registry._Registry__functions), 1)
52 | 
53 | 
54 | if __name__ == "__main__":
55 |     unittest.main()
56 | 
--------------------------------------------------------------------------------
/tests/test_agent/test_extract_substr.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.docgen import extract_substr
3 | 
4 | 
5 | class TestExtractSubstr(unittest.TestCase):
6 |     def test_extract_substr(self):
7 |         input_string = """Extracts the description and args from a docstring.
8 | 
9 |         Args:
10 |             input_string (str): The docstring to extract the description and args from.
11 | 
12 |         Returns:
13 |             str: Reduced docstring containing only the description and the args.
14 |         """
15 |         expected_output = """Extracts the description and args from a docstring.
16 | 
17 |         Args:
18 |             input_string (str): The docstring to extract the description and args from."""
19 |         self.assertEqual(extract_substr(input_string), expected_output)
20 | 
--------------------------------------------------------------------------------
/tests/test_agent/test_get_required_args.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.helpers import get_required_args
3 | from agent_dingo.agent.chat_context import ChatContext
4 | 
5 | 
6 | class TestGetRequiredArgs(unittest.TestCase):
7 |     def test_no_args(self):
8 |         def func():
9 |             pass
10 | 
11 |         self.assertEqual(get_required_args(func), [])
12 | 
13 |     def test_required_args(self):
14 |         def func(a, b, c):
15 |             pass
16 | 
17 |         self.assertEqual(get_required_args(func), ["a", "b", "c"])
18 | 
19 |     def test_optional_args(self):
20 |         def func(a, b, c=None):
21 |             pass
22 | 
23 |         self.assertEqual(get_required_args(func), ["a", "b"])
24 | 
25 |     def test_mixed_args(self):
26 |         def func(a, b, c=None, d=None):
27 |             pass
28 | 
29 |         self.assertEqual(get_required_args(func), ["a", "b"])
30 | 
31 |     def test_with_chat_context(self):
32 |         def func(a, b, chat_context: ChatContext):
33 |             pass
34 | 
35 |         self.assertEqual(get_required_args(func), ["a", "b"])
36 | 
37 |     def test_wrong_chat_context_type(self):
38 |         def func(a, b, chat_context: str):
39 |             pass
40 | 
41 |         self.assertEqual(get_required_args(func), ["a", "b", "chat_context"])
42 | 
43 |     def test_wrong_chat_context_name(self):
44 |         def func(a, b, chat_context_: ChatContext):
45 |             pass
46 | 
47 |         self.assertEqual(get_required_args(func), ["a", "b", "chat_context_"])
48 | 
49 | 
50 | if __name__ == "__main__":
51 |     unittest.main()
52 | 
--------------------------------------------------------------------------------
/tests/test_agent/test_parser.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.parser import parse
3 | 
4 | 
5 | class TestParser(unittest.TestCase):
6 |     def test_parse_google(self):
7 |         docstring = """Parses a docstring.
8 | 
9 |         Args:
10 |             docstring (str): The docstring to parse.
11 | 
12 |         Returns:
13 |             dict: A dictionary containing the description and the arguments of the function.
14 |         """
15 |         expected_output = (
16 |             {
17 |                 "description": "Parses a docstring.",
18 |                 "properties": {
19 |                     "docstring": {
20 |                         "type": "string",
21 |                         "description": "The docstring to parse.",
22 |                     }
23 |                 },
24 |             },
25 |             False,
26 |         )
27 |         self.assertEqual(parse(docstring), expected_output)
28 | 
29 |     def test_parse_numpy(self):
30 |         docstring = """Parses a docstring.
31 | 
32 |         Parameters
33 |         ----------
34 |         docstring : str
35 |             The docstring to parse.
36 | 
37 |         Returns
38 |         -------
39 |         dict
40 |             A dictionary containing the description and the arguments of the function.
41 |         """
42 |         expected_output = (
43 |             {
44 |                 "description": "Parses a docstring.",
45 |                 "properties": {
46 |                     "docstring": {
47 |                         "type": "string",
48 |                         "description": "The docstring to parse.",
49 |                     }
50 |                 },
51 |             },
52 |             False,
53 |         )
54 |         self.assertEqual(parse(docstring), expected_output)
55 | 
56 |     def test_parse_with_enum(self):
57 |         docstring = """Parses a docstring.
58 | 59 | Parameters 60 | ---------- 61 | arg1 : str 62 | The first argument. 63 | arg2 : int 64 | The second argument. 65 | arg3 : float 66 | The third argument. 67 | arg4 : bool 68 | The fourth argument. 69 | arg5 : list 70 | The fifth argument. 71 | arg6 : dict 72 | The sixth argument. 73 | arg7 : str 74 | The seventh argument. Enum: ['value1', 'value2', 'value3'] 75 | 76 | Returns 77 | ------- 78 | dict 79 | A dictionary containing the description and the arguments of the function. 80 | """ 81 | expected_output = ( 82 | { 83 | "description": "Parses a docstring.", 84 | "properties": { 85 | "arg1": {"type": "string", "description": "The first argument."}, 86 | "arg2": {"type": "integer", "description": "The second argument."}, 87 | "arg3": {"type": "number", "description": "The third argument."}, 88 | "arg4": {"type": "boolean", "description": "The fourth argument."}, 89 | "arg5": {"type": "array", "description": "The fifth argument."}, 90 | "arg6": {"type": "object", "description": "The sixth argument."}, 91 | "arg7": { 92 | "type": "string", 93 | "description": "The seventh argument.", 94 | "enum": ["value1", "value2", "value3"], 95 | }, 96 | }, 97 | }, 98 | False, 99 | ) 100 | self.assertEqual(parse(docstring), expected_output) 101 | 102 | def test_parse_with_context(self): 103 | docstring = """Parses a docstring. 104 | 105 | Parameters 106 | ---------- 107 | arg1 : str 108 | The first argument. 109 | chat_context : ChatContext 110 | The chat context. 111 | 112 | Returns 113 | ------- 114 | dict 115 | A dictionary containing the description and the arguments of the function. 116 | """ 117 | expected_output = ( 118 | { 119 | "description": "Parses a docstring.", 120 | "properties": { 121 | "arg1": {"type": "string", "description": "The first argument."} 122 | }, 123 | }, 124 | True, 125 | ) 126 | self.assertEqual(parse(docstring), expected_output) 127 | 128 | def test_parse_without_description(self): 129 | docstring = """Parameters 130 | ---------- 131 | arg1 : str 132 | The first argument. 133 | 134 | Returns 135 | ------- 136 | dict 137 | A dictionary containing the description and the arguments of the function. 
138 | """ 139 | with self.assertRaises(ValueError): 140 | parse(docstring) 141 | 142 | 143 | if __name__ == "__main__": 144 | unittest.main() 145 | -------------------------------------------------------------------------------- /tests/test_agent/test_registry.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from agent_dingo.agent.registry import Registry as _Registry 3 | 4 | 5 | class TestRegistry(unittest.TestCase): 6 | def setUp(self): 7 | self.registry = _Registry() 8 | 9 | def test_add_function(self): 10 | def func(): 11 | pass 12 | 13 | json_repr = {"name": "func"} 14 | self.registry.add("func", func, json_repr, False) 15 | self.assertEqual(len(self.registry._Registry__functions), 1) 16 | 17 | def test_get_function(self): 18 | def func(): 19 | pass 20 | 21 | json_repr = {"name": "func"} 22 | self.registry.add("func", func, json_repr, False) 23 | func, requires_context = self.registry.get_function("func") 24 | self.assertEqual(callable(func), True) 25 | self.assertEqual(requires_context, False) 26 | 27 | def test_get_available_functions(self): 28 | def func1(): 29 | pass 30 | 31 | def func2(): 32 | pass 33 | 34 | json_repr1 = {"name": "func1"} 35 | json_repr2 = {"name": "func2"} 36 | self.registry.add("func1", func1, json_repr1, False) 37 | self.registry.add( 38 | "func2", func2, json_repr2, True, required_context_keys=["any"] 39 | ) 40 | available_functions = self.registry.get_available_functions() 41 | self.assertEqual(len(available_functions), 2) 42 | self.assertEqual(available_functions[0]["name"], "func1") 43 | self.assertEqual(available_functions[1]["name"], "func2") 44 | 45 | 46 | if __name__ == "__main__": 47 | unittest.main() 48 | -------------------------------------------------------------------------------- /tests/test_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/tests/test_core/__init__.py -------------------------------------------------------------------------------- /tests/test_core/test_blocks.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from agent_dingo.core.blocks import ( 3 | Squash, 4 | PromptBuilder, 5 | Pipeline, 6 | Parallel, 7 | Identity, 8 | SaveState, 9 | LoadState, 10 | InlineBlock, 11 | ) 12 | from agent_dingo.core.state import State, ChatPrompt, KVData, Context, Store, UsageMeter 13 | from agent_dingo.core.message import Message 14 | 15 | 16 | class TestBlocks(unittest.TestCase): 17 | def test_squash(self): 18 | s = Squash("{0} {1}") 19 | state = KVData(_out_0="Hello", _out_1="World") 20 | context = Context() 21 | store = Store() 22 | self.assertEqual(s.forward(state, context, store)["_out_0"], "Hello World") 23 | 24 | def test_prompt_builder(self): 25 | pb = PromptBuilder([Message("Hello {name}")], from_state=["_out_0"]) 26 | state = KVData(_out_0="World") 27 | context = Context(name="World") 28 | store = Store() 29 | self.assertEqual( 30 | pb.forward(state, context, store).messages[0].content, "Hello World" 31 | ) 32 | 33 | def test_pipeline(self): 34 | p = Pipeline() 35 | p.add_block(Squash("{0} {1}")) 36 | state = KVData(_out_0="Hello", _out_1="World") 37 | context = Context() 38 | store = Store() 39 | self.assertEqual(p.forward(state, context, store)["_out_0"], "Hello World") 40 | self.assertEqual(p.run(state)[0], "Hello World") 41 | 42 | def test_pipeline_build(self): 43 | 
block1 = Identity() 44 | block2 = Identity() 45 | pipeline = block1 >> block2 46 | self.assertIsInstance(pipeline, Pipeline) 47 | self.assertEqual(len(pipeline._blocks), 2) 48 | self.assertIs(pipeline._blocks[0], block1) 49 | self.assertIs(pipeline._blocks[1], block2) 50 | 51 | def test_pipeline_build_with_parallel(self): 52 | block1 = Identity() 53 | block2 = Identity() 54 | squash = Squash("{0} {1}") 55 | pipeline = (block1 & block2) >> squash 56 | self.assertIsInstance(pipeline, Pipeline) 57 | self.assertEqual(len(pipeline._blocks), 2) 58 | self.assertIsInstance(pipeline._blocks[0], Parallel) 59 | self.assertIs(pipeline._blocks[1], squash) 60 | 61 | def test_identity(self): 62 | i = Identity() 63 | state = KVData(_out_0="Hello") 64 | context = Context() 65 | store = Store() 66 | self.assertIs(i.forward(state, context, store), state) 67 | 68 | def test_save_state(self): 69 | ss = SaveState("key") 70 | state = KVData(_out_0="Hello") 71 | context = Context() 72 | store = Store() 73 | self.assertEqual(ss.forward(state, context, store)["_out_0"], "Hello") 74 | self.assertIs(store.get_data("key"), state) 75 | 76 | def test_load_state(self): 77 | ls = LoadState("data", "key") 78 | state = KVData(_out_0="Hello") 79 | context = Context() 80 | store = Store() 81 | store.update("key", state) 82 | self.assertIs(ls.forward(state, context, store), state) 83 | 84 | def test_inline_block(self): 85 | ib = InlineBlock() 86 | state_ = KVData(_out_0="Hello") 87 | 88 | @ib 89 | def func(state, context, store): 90 | return state_ 91 | 92 | state = KVData(_out_0="World") 93 | context = Context() 94 | store = Store() 95 | self.assertIs(func.forward(state, context, store), state_) 96 | 97 | 98 | if __name__ == "__main__": 99 | unittest.main() 100 | -------------------------------------------------------------------------------- /tests/test_core/test_message.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from agent_dingo.core.message import ( 3 | Message, 4 | UserMessage, 5 | SystemMessage, 6 | AssistantMessage, 7 | ) 8 | 9 | 10 | class TestMessage(unittest.TestCase): 11 | def test_message(self): 12 | m = Message("Hello") 13 | self.assertEqual(m.content, "Hello") 14 | self.assertEqual(m.role, "undefined") 15 | self.assertEqual(m.dict, {"role": "undefined", "content": "Hello"}) 16 | 17 | def test_user_message(self): 18 | m = UserMessage("Hello") 19 | self.assertEqual(m.content, "Hello") 20 | self.assertEqual(m.role, "user") 21 | self.assertEqual(m.dict, {"role": "user", "content": "Hello"}) 22 | 23 | def test_system_message(self): 24 | m = SystemMessage("Hello") 25 | self.assertEqual(m.content, "Hello") 26 | self.assertEqual(m.role, "system") 27 | self.assertEqual(m.dict, {"role": "system", "content": "Hello"}) 28 | 29 | def test_assistant_message(self): 30 | m = AssistantMessage("Hello") 31 | self.assertEqual(m.content, "Hello") 32 | self.assertEqual(m.role, "assistant") 33 | self.assertEqual(m.dict, {"role": "assistant", "content": "Hello"}) 34 | 35 | 36 | if __name__ == "__main__": 37 | unittest.main() 38 | -------------------------------------------------------------------------------- /tests/test_core/test_output_parser.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from agent_dingo.core.output_parser import BaseOutputParser, DefaultOutputParser 3 | from agent_dingo.core.state import State, ChatPrompt, KVData 4 | from agent_dingo.core.message import Message 5 | 6 | 7 | class 
TestOutputParser(unittest.TestCase): 8 | def test_base_output_parser(self): 9 | class TestOutputParser(BaseOutputParser): 10 | def _parse_chat(self, output: ChatPrompt) -> str: 11 | return "chat" 12 | 13 | def _parse_kvdata(self, output: KVData) -> str: 14 | return "kvdata" 15 | 16 | parser = TestOutputParser() 17 | self.assertEqual(parser.parse(KVData(a="1")), "kvdata") 18 | self.assertEqual(parser.parse(ChatPrompt([Message("Hello")])), "chat") 19 | with self.assertRaises(TypeError): 20 | parser.parse("invalid") 21 | 22 | def test_default_output_parser(self): 23 | parser = DefaultOutputParser() 24 | with self.assertRaises(RuntimeError): 25 | parser.parse(ChatPrompt([Message("Hello")])) 26 | self.assertEqual(parser.parse(KVData(_out_0="output")), "output") 27 | with self.assertRaises(KeyError): 28 | parser.parse(KVData(a="1")) 29 | 30 | 31 | if __name__ == "__main__": 32 | unittest.main() 33 | -------------------------------------------------------------------------------- /tests/test_core/test_state.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from agent_dingo.core.state import ChatPrompt, KVData, Context, UsageMeter, Store 3 | from agent_dingo.core.message import Message 4 | 5 | 6 | class TestState(unittest.TestCase): 7 | def test_chat_prompt(self): 8 | cp = ChatPrompt([Message("Hello")]) 9 | self.assertEqual(cp.dict, [{"role": "undefined", "content": "Hello"}]) 10 | 11 | def test_kvdata(self): 12 | kv = KVData(a="1", b="2") 13 | self.assertEqual(kv["a"], "1") 14 | self.assertEqual(kv["b"], "2") 15 | self.assertEqual(kv.dict, {"a": "1", "b": "2"}) 16 | with self.assertRaises(KeyError): 17 | kv.update("a", "3") 18 | 19 | def test_context(self): 20 | ctx = Context(a="1", b="2") 21 | with self.assertRaises(RuntimeError): 22 | ctx.update("a", "3") 23 | 24 | def test_usage_meter(self): 25 | um = UsageMeter() 26 | um.increment(10, 20) 27 | self.assertEqual( 28 | um.get_usage(), 29 | {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, 30 | ) 31 | 32 | def test_store(self): 33 | st = Store() 34 | st.update("data", KVData(a="1")) 35 | st.update("prompt", ChatPrompt([Message("Hello")])) 36 | st.update("misc", "misc") 37 | self.assertEqual(st.get_data("data").dict, {"a": "1"}) 38 | self.assertEqual( 39 | st.get_prompt("prompt").dict, [{"role": "undefined", "content": "Hello"}] 40 | ) 41 | self.assertEqual(st.get_misc("misc"), "misc") 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | --------------------------------------------------------------------------------
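
A minimal usage sketch for the OpenAI-compatible server defined in agent_dingo/serve.py above. It assumes a pipeline is already being served locally via serve_pipeline (which defaults to port 8000 and registers a single Pipeline under the model name "dingo"); the context key "topic" and the message contents are illustrative placeholders, while the endpoint paths, payload shape, and the "context_" role convention follow make_app and _construct_pipeline_input.

import requests

# List the available pipelines (one entry per registered model name).
print(requests.get("http://localhost:8000/models").json())

payload = {
    "model": "dingo",  # default name assigned by make_app for a single Pipeline
    "messages": [
        # Roles prefixed with "context_" are routed into the pipeline context
        # rather than the chat prompt; "topic" is a hypothetical context key.
        {"role": "context_topic", "content": "weather"},
        {"role": "user", "content": "What will tomorrow be like?"},
    ],
}

# Run the pipeline through the OpenAI-style chat completions endpoint.
response = requests.post("http://localhost:8000/chat/completions", json=payload)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])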