├── .github
├── dependabot.yml
├── pr-labeler.yml
├── release-drafter.yml
└── workflows
│ ├── draft.yml
│ ├── pr_labeler.yml
│ └── pypi_deploy.yaml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── agent_dingo
├── __init__.py
├── agent
│ ├── __init__.py
│ ├── agent.py
│ ├── chat_context.py
│ ├── docgen.py
│ ├── function_descriptor.py
│ ├── helpers.py
│ ├── langchain.py
│ ├── parser.py
│ └── registry.py
├── core
│ ├── blocks.py
│ ├── message.py
│ ├── output_parser.py
│ └── state.py
├── llm
│ ├── gemini.py
│ ├── litellm.py
│ ├── llama_cpp.py
│ └── openai.py
├── rag
│ ├── base.py
│ ├── chunkers
│ │ ├── __init__.py
│ │ └── recursive.py
│ ├── embedders
│ │ ├── __init__.py
│ │ ├── openai.py
│ │ └── sentence_transformer.py
│ ├── prompt_modifiers.py
│ ├── readers
│ │ ├── __init__.py
│ │ ├── list.py
│ │ ├── pdf.py
│ │ ├── web.py
│ │ └── word.py
│ └── vector_stores
│ │ ├── chromadb.py
│ │ └── qdrant.py
├── serve.py
└── utils.py
├── pyproject.toml
└── tests
├── __init__.py
├── fake_llm.py
├── test_agent
├── __init__.py
├── test_agent.py
├── test_extract_substr.py
├── test_get_required_args.py
├── test_parser.py
└── test_registry.py
└── test_core
├── __init__.py
├── test_blocks.py
├── test_message.py
├── test_output_parser.py
└── test_state.py
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "monthly"
7 | - package-ecosystem: "github-actions"
8 | directory: "/"
9 | schedule:
10 | interval: "monthly"
11 |
--------------------------------------------------------------------------------
/.github/pr-labeler.yml:
--------------------------------------------------------------------------------
1 | feature: ['features/*', 'feature/*', 'feat/*', 'features-*', 'feature-*', 'feat-*']
2 | fix: ['fixes/*', 'fix/*']
3 | chore: ['chore/*']
4 | dependencies: ['update/*']
5 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name-template: "v$RESOLVED_VERSION"
2 | tag-template: "v$RESOLVED_VERSION"
3 | categories:
4 | - title: "🚀 Features"
5 | labels:
6 | - "feature"
7 | - "enhancement"
8 | - title: "🐛 Bug Fixes"
9 | labels:
10 | - "fix"
11 | - "bugfix"
12 | - "bug"
13 | - title: "🧹 Maintenance"
14 | labels:
15 | - "maintenance"
16 | - "dependencies"
17 | - "refactoring"
18 | - "cosmetic"
19 | - "chore"
20 | - title: "📝️ Documentation"
21 | labels:
22 | - "documentation"
23 | - "docs"
24 | change-template: "- $TITLE (#$NUMBER)"
25 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
26 | version-resolver:
27 | major:
28 | labels:
29 | - "major"
30 | minor:
31 | labels:
32 | - "minor"
33 | patch:
34 | labels:
35 | - "patch"
36 | default: patch
37 | template: |
38 | ## Changes
39 |
40 | $CHANGES
41 |
--------------------------------------------------------------------------------
/.github/workflows/draft.yml:
--------------------------------------------------------------------------------
1 | # Drafts the next release notes as pull requests are merged (or commits are pushed) into "main" or "master"
2 | name: Draft next release
3 |
4 | on:
5 | push:
6 | branches: [main, "master"]
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | update-release-draft:
13 | permissions:
14 | contents: write
15 | pull-requests: write
16 | runs-on: ubuntu-latest
17 | steps:
18 | - uses: release-drafter/release-drafter@v5
19 | env:
20 | GITHUB_TOKEN: ${{ github.token }}
21 |
--------------------------------------------------------------------------------
/.github/workflows/pr_labeler.yml:
--------------------------------------------------------------------------------
1 | # This workflow will apply the corresponding label on a pull request
2 | name: PR Labeler
3 |
4 | on:
5 | pull_request_target:
6 |
7 | permissions:
8 | contents: read
9 | pull-requests: write
10 |
11 | jobs:
12 | pr-labeler:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: TimonVS/pr-labeler-action@v4
16 | with:
17 | repo-token: ${{ github.token }}
18 |
--------------------------------------------------------------------------------
/.github/workflows/pypi_deploy.yaml:
--------------------------------------------------------------------------------
1 | name: PyPi Deploy
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout code
13 | uses: actions/checkout@v4
14 |
15 | - name: Setup Python
16 | uses: actions/setup-python@v4
17 | with:
18 | python-version: '3.10'
19 |
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install build twine
24 |
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: __token__
28 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
29 | run: |
30 | python -m build
31 | twine upload dist/*
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | test.py
162 | tmp*.py
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | https://www.linkedin.com/in/iryna-kondrashchenko-673800155/.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to agent_dingo
2 |
3 | Welcome! We appreciate your interest in contributing to `agent_dingo`. Whether you're a developer, designer, writer, or simply passionate about open source, there are several ways you can help improve this project. This document will guide you through the process of contributing to our Python repository.
4 |
5 | ## How Can I Contribute?
6 |
7 | There are several ways you can contribute to this project:
8 |
9 | - Bug Fixes: Help us identify and fix issues in the codebase.
10 | - Feature Implementation: Implement new features and enhancements.
11 | - Documentation: Improve the project's documentation, including code comments and README files.
12 | - Testing: Write and improve test cases to ensure the project's quality and reliability.
13 | - Translations: Provide translations for the project's documentation or user interface.
14 | - Bug Reports and Feature Requests: Submit bug reports or suggest new features and improvements.
15 |
16 | **Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix.
17 |
18 | > ### Legal Notice
19 | >
20 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license.
21 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Oleh Kostromin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | Agent Dingo
6 |
7 |
8 |
9 |
A microframework for building LLM-powered pipelines and agents.
25 |
26 | _Dingo_ is a compact LLM orchestration framework designed for straightforward development of production-ready LLM-powered applications. It combines simplicity with flexibility, allowing for the efficient construction of pipelines and agents, while maintaining a high level of control over the process.
27 |
28 |
29 | ## Support us 🤝
30 |
31 | You can support the project in the following ways:
32 |
33 | - ⭐ Star Dingo on GitHub (click the star button in the top right corner)
34 | - 💡 Provide your feedback or propose ideas in the [issues](https://github.com/BeastByteAI/agent_dingo/issues) section or [Discord](https://discord.gg/YDAbwuWK7V)
35 | - 📰 Post about Dingo on LinkedIn or other platforms
36 | - 🔗 Check out our other projects: Scikit-LLM, Falcon
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 | ## Quick Start & Documentation 🚀
56 |
57 | **Step 1:** Install `agent-dingo`
58 |
59 | ```bash
60 | pip install agent-dingo
61 | ```
62 |
63 | **Step 2:** Configure your OpenAI API key
64 |
65 | ```bash
66 | export OPENAI_API_KEY=<YOUR_KEY>
67 | ```
68 |
69 | **Step 3:** Build your pipeline
70 |
71 | Example 1 (Linear Pipeline):
72 |
73 | ````python
74 | from agent_dingo.llm.openai import OpenAI
75 | from agent_dingo.core.blocks import PromptBuilder
76 | from agent_dingo.core.message import UserMessage
77 | from agent_dingo.core.state import ChatPrompt
78 |
79 |
80 | # Model
81 | gpt = OpenAI("gpt-3.5-turbo")
82 |
83 | # Summary prompt block
84 | summary_pb = PromptBuilder(
85 | [UserMessage("Summarize the text in 10 words: ```{text}```.")]
86 | )
87 |
88 | # Translation prompt block
89 | translation_pb = PromptBuilder(
90 | [UserMessage("Translate the text into {language}: ```{summarized_text}```.")],
91 | from_state=["summarized_text"],
92 | )
93 |
94 | # Pipeline
95 | pipeline = summary_pb >> gpt >> translation_pb >> gpt
96 |
97 | input_text = """
98 | Dingo is an ancient lineage of dog found in Australia, exhibiting a lean and sturdy physique adapted for speed and endurance, dingoes feature a wedge-shaped skull and come in colorations like light ginger, black and tan, or creamy white. They share a close genetic relationship with the New Guinea singing dog, diverging early from the domestic dog lineage. Dingoes typically form packs composed of a mated pair and their offspring, indicating social structures that have persisted through their history, dating back approximately 3,500 years in Australia.
99 | """
100 |
101 | output = pipeline.run(text = input_text, language = "french")
102 | print(output)
103 | ````
104 |
105 | Example 2 (Agent):
106 |
107 | ```python
108 | from agent_dingo.agent import Agent
109 | from agent_dingo.llm.openai import OpenAI
110 | import requests
111 |
112 | llm = OpenAI(model="gpt-3.5-turbo")
113 | agent = Agent(llm, max_function_calls=3)
114 |
115 | @agent.function
116 | def get_temperature(city: str) -> str:
117 | """Retrieves the current temperature in a city.
118 |
119 | Parameters
120 | ----------
121 | city : str
122 | The city to get the temperature for.
123 |
124 | Returns
125 | -------
126 | str
127 | String representation of the json response from the weather api.
128 | """
129 | base_url = "https://api.openweathermap.org/data/2.5/weather"
130 | params = {
131 | "q": city,
132 | "appid": "<YOUR_OPENWEATHERMAP_API_KEY>",
133 | "units": "metric"
134 | }
135 | response = requests.get(base_url, params=params)
136 | data = response.json()
137 | return str(data)
138 |
139 | pipeline = agent.as_pipeline()
140 | ```
141 |
142 | For a more detailed overview and additional examples, please refer to the **[documentation](https://dingo.beastbyte.ai/)**.
143 |
--------------------------------------------------------------------------------
/agent_dingo/__init__.py:
--------------------------------------------------------------------------------
"""Agent Dingo: a microframework for building LLM-powered pipelines and agents."""

__version__ = "0.1.0"
__author__ = "Oleh Kostromin, Iryna Kondrashchenko"
5 |
--------------------------------------------------------------------------------
/agent_dingo/agent/__init__.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.agent.agent import Agent
--------------------------------------------------------------------------------
/agent_dingo/agent/agent.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Union, Optional, Tuple, List, Literal
2 | from agent_dingo.agent.parser import parse
3 | from agent_dingo.agent.helpers import get_required_args, construct_json_repr
4 | from agent_dingo.agent.docgen import generate_docstring
5 | from agent_dingo.agent.function_descriptor import FunctionDescriptor
6 | from agent_dingo.core.blocks import BaseLLM, BaseAgent, Context, ChatPrompt, KVData
7 | from agent_dingo.core.message import UserMessage
8 | from agent_dingo.agent.chat_context import ChatContext
9 | from agent_dingo.agent.registry import Registry as _Registry
10 | import json
11 | import os
12 | import inspect
13 | from asyncio import run as asyncio_run, to_thread
14 | import warnings
15 |
16 |
17 | class Agent(BaseAgent):
18 |
def __init__(
    self,
    llm: BaseLLM,
    max_function_calls: int = 10,
    before_function_call: Optional[Callable] = None,
    allow_codegen: Union[bool, Literal["env"]] = "env",
    name: str = "agent",
    description: str = "A helpful agent",
):
    """Creates an agent that can register functions and chat with an LLM.

    Parameters
    ----------
    llm : BaseLLM
        The LLM backing the agent; it must support function calls.
    max_function_calls : int, optional
        Upper bound on consecutive function calls offered to the LLM, by default 10.
    before_function_call : Optional[Callable], optional
        Hook invoked as ``hook(function_name, func, function_args)`` before each
        function call; it must return a ``(func, function_args)`` pair. By default None.
    allow_codegen : Union[bool, Literal["env"]], optional
        Whether missing function docstrings may be generated by the LLM. The
        string "env" defers to the DINGO_ALLOW_CODEGEN environment variable.
        By default "env".
    name : str, optional
        Name of the agent, by default "agent".
    description : str, optional
        Description of the agent, used when it is exposed as a sub-agent.
        By default "A helpful agent".

    Raises
    ------
    ValueError
        If ``allow_codegen`` is neither a bool nor "env", or if the LLM does
        not support function calls.
    """
    codegen_setting_is_valid = isinstance(allow_codegen, bool) or allow_codegen == "env"
    if not codegen_setting_is_valid:
        raise ValueError(
            "allow_codegen must be a boolean or the string 'env' to use the DINGO_ALLOW_CODEGEN environment variable"
        )
    if not llm.supports_function_calls:
        raise ValueError(
            "Provided LLM does not support function calls and cannot be used with the agent."
        )

    # Public configuration.
    self.model = llm
    self.name = name
    self.description = description
    self.max_function_calls = max_function_calls
    self.before_function_call = before_function_call

    # Internal state; `_registered` flips to True once the agent is exported
    # via `as_function_descriptor`.
    self._allow_codegen = allow_codegen
    self._registry = _Registry()
    self._registered = False
61 |
def _is_codegen_allowed(self) -> bool:
    """Determines whether docstring generation by the LLM is allowed.

    When the agent was created with ``allow_codegen="env"``, the
    DINGO_ALLOW_CODEGEN environment variable decides. An unset variable allows
    generation. The values "", "0", "false", "f", "no", "n" and "off"
    (case-insensitive, surrounding whitespace ignored) disable it. Any other
    value enables it.

    Returns
    -------
    bool
        True if docstring generation is allowed, False otherwise.
    """
    if self._allow_codegen != "env":
        return self._allow_codegen
    raw_value = os.getenv("DINGO_ALLOW_CODEGEN")
    if raw_value is None:
        # Unset variable: code generation is allowed by default.
        return True
    # `bool(raw_value)` would treat "false"/"0" as True, so parse the value.
    return raw_value.strip().lower() not in {"", "0", "false", "f", "no", "n", "off"}
73 |
def register_descriptor(self, descriptor: FunctionDescriptor) -> None:
    """Registers a function descriptor with the agent.

    Parameters
    ----------
    descriptor : FunctionDescriptor
        The function descriptor to register.

    Raises
    ------
    ValueError
        If ``descriptor`` is not a FunctionDescriptor, or if it declares
        ``required_context_keys`` after this agent has already been exported
        via ``as_function_descriptor``.
    """
    # Validate the type first: the check below reads an attribute that an
    # arbitrary object may not have (which used to surface as AttributeError).
    if not isinstance(descriptor, FunctionDescriptor):
        raise ValueError("descriptor must be a FunctionDescriptor")
    if descriptor.required_context_keys is not None and self._registered:
        raise ValueError(
            "required_context_keys must be None if functions are registered after the agent"
        )
    self._registry.add(
        name=descriptor.name,
        func=descriptor.func,
        json_repr=descriptor.json_repr,
        requires_context=descriptor.requires_context,
        required_context_keys=descriptor.required_context_keys,
    )
95 |
def register_function(
    self, func: Callable, required_context_keys: Optional[List[str]] = None
) -> None:
    """Registers a function with the agent.

    The function's docstring is parsed to build the JSON representation that
    is stored in the registry. If the function has no docstring, one is
    generated with the agent's LLM, provided code generation is allowed.

    Parameters
    ----------
    func : Callable
        The function to register.
    required_context_keys : Optional[List[str]], optional
        Context keys required by the function, by default None. Must be None
        if the agent has already been exported via ``as_function_descriptor``.

    Raises
    ------
    ValueError
        If ``required_context_keys`` is given after the agent was exported or
        is not a list of strings, or if the function has no docstring and code
        generation is not allowed.
    """
    if required_context_keys is not None:
        if self._registered:
            raise ValueError(
                "required_context_keys must be None if functions are registered after the agent"
            )
        # A bare string would otherwise pass the element check character by character.
        if isinstance(required_context_keys, str) or not all(
            isinstance(key, str) for key in required_context_keys
        ):
            raise ValueError("required_context_keys must be a list of strings")
    docstring = func.__doc__
    if docstring is None:
        if not self._is_codegen_allowed():
            raise ValueError(
                "Function has no docstring and code generation is not allowed"
            )
        docstring = generate_docstring(func, self.model)
    body, requires_context = parse(docstring)
    required_args = get_required_args(func)
    json_repr = construct_json_repr(
        func.__name__, body["description"], body["properties"], required_args
    )
    self._registry.add(
        func.__name__, func, json_repr, requires_context, required_context_keys
    )
134 |
def _call_from_agent(self, query: str, chat_context: ChatContext) -> str:
    """Runs this agent on a query issued by another (parent) agent.

    Parameters
    ----------
    query : str
        The natural language query sent by the calling agent.
    chat_context : ChatContext
        The ``(context, store)`` pair of the calling agent, forwarded to ``forward``.

    Returns
    -------
    str
        The content of this agent's final response.
    """
    request = ChatPrompt([UserMessage(query)])
    parent_context, parent_store = chat_context[0], chat_context[1]
    return self.forward(request, parent_context, parent_store)["_out_0"]
148 |
async def _async_call_from_agent(
    self, query: str, chat_context: ChatContext
) -> str:
    """Asynchronously runs this agent on a query issued by another (parent) agent.

    Parameters
    ----------
    query : str
        The natural language query sent by the calling agent.
    chat_context : ChatContext
        The ``(context, store)`` pair of the calling agent, forwarded to ``async_forward``.

    Returns
    -------
    str
        The content of this agent's final response.
    """
    request = ChatPrompt([UserMessage(query)])
    parent_context, parent_store = chat_context[0], chat_context[1]
    outcome = await self.async_forward(request, parent_context, parent_store)
    return outcome["_out_0"]
164 |
def as_function_descriptor(self, as_async: bool = False) -> FunctionDescriptor:
    """Wraps this agent in a FunctionDescriptor so another agent can call it as a function.

    Parameters
    ----------
    as_async : bool, optional
        If truthy, the descriptor calls the coroutine ``_async_call_from_agent``;
        otherwise it calls ``_call_from_agent``. By default False.

    Returns
    -------
    FunctionDescriptor
        A descriptor named after the agent, taking a single required ``query``
        string and requiring the chat context.

    Notes
    -----
    After this call the agent is marked as registered, so functions registered
    later must not declare ``required_context_keys``.
    """
    entry_point = self._async_call_from_agent if as_async else self._call_from_agent
    query_schema = {
        "type": "string",
        "description": "The natural language query to send to the agent.",
    }
    parameters_schema = {
        "type": "object",
        "properties": {"query": query_schema},
        "required": ["query"],
    }
    descriptor = FunctionDescriptor(
        name=self.name,
        func=entry_point,
        json_repr={
            "name": self.name,
            "description": self.description,
            "parameters": parameters_schema,
        },
        requires_context=True,
        required_context_keys=self.get_required_context_keys(),
    )
    self._registered = True
    return descriptor
188 |
def function(self, *args, **kwargs) -> Callable:
    """Decorator that registers a function with the agent.

    Supports both the bare form (``@agent.function``) and the parameterized
    form (``@agent.function(required_context_keys=[...])``).

    Parameters
    ----------
    *args
        In the bare form, the single function being decorated.
    **kwargs
        Optional ``required_context_keys``, forwarded to ``register_function``.

    Returns
    -------
    Callable
        In the bare form, the decorated function itself; otherwise a decorator
        that registers and returns the function it is applied to.
    """
    used_bare = len(args) == 1 and not kwargs and callable(args[0])
    if used_bare:
        target = args[0]
        self.register_function(target)
        return target

    context_keys = kwargs.get("required_context_keys", None)

    def decorator(target):
        self.register_function(target, required_context_keys=context_keys)
        return target

    return decorator
218 |
def get_required_context_keys(self) -> List[str]:
    """Returns the context keys required by the agent's registered functions.

    The registry is queried on every call instead of being cached, so the
    result also covers functions registered after the agent itself was
    registered elsewhere.
    """
    registry = self._registry
    return registry.get_required_context_keys()
222 |
def forward(
    self, state: ChatPrompt, context: Context, store: KVData
) -> KVData:
    """Sends the prompt to the LLM and returns its final answer, executing the
    functions (tool calls) the LLM requests along the way.

    Parameters
    ----------
    state : ChatPrompt
        The prompt whose messages (``state.dict``) are sent to the LLM.
    context : Context
        The pipeline context; passed together with ``store`` as
        ``chat_context`` to registered functions that require it.
    store : KVData
        The pipeline store; its ``usage_meter`` is forwarded to the LLM.

    Returns
    -------
    KVData
        ``_out_0`` holds the content of the LLM's final (non tool-call) response.

    Notes
    -----
    Functions are offered to the LLM only while fewer than
    ``max_function_calls`` calls have been made; afterwards the LLM is called
    with ``functions=None``. Malformed arguments and exceptions raised by a
    function are reported back to the LLM as tool messages (and emitted as
    warnings) rather than aborting the conversation.
    """
    messages = state.dict
    n_calls = 0
    available_functions = self._registry.get_available_functions()
    chat_context = (context, store)
    while True:
        available_functions_i = (
            available_functions if n_calls < self.max_function_calls else None
        )
        response = self.model.send_message(
            messages,
            functions=available_functions_i,
            usage_meter=store.usage_meter,
        )
        messages.append(response)
        if not response.get("tool_calls"):
            return KVData(_out_0=response["content"])
        for function in response["tool_calls"]:
            function_name = function["function"]["name"]
            try:
                function_args = json.loads(function["function"]["arguments"])
            except json.JSONDecodeError as e:
                # The LLM produced invalid JSON; tell it instead of crashing,
                # since every tool call still needs a matching tool message.
                warnings.warn(
                    f"Could not parse arguments for function `{function_name}`: {e}"
                )
                result = "An error occurred while parsing the function arguments."
            else:
                result = self._run_function_call(
                    function_name, function_args, chat_context
                )
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": function["id"],
                    "content": result,
                }
            )
            n_calls += 1

def _run_function_call(
    self, function_name: str, function_args: dict, chat_context: tuple
) -> str:
    """Executes one registered function requested by the LLM and returns its result as a string.

    Exceptions raised by the function are reported with ``warnings.warn`` and
    replaced by a generic error message, so a failing function does not abort
    the conversation.
    """
    f, requires_context = self._registry.get_function(function_name)
    if requires_context:
        function_args["chat_context"] = chat_context
    if self.before_function_call:
        f, function_args = self.before_function_call(function_name, f, function_args)
    try:
        if inspect.iscoroutinefunction(f):
            warnings.warn("Async function is called from a sync agent.")
            result = asyncio_run(f(**function_args))
        else:
            result = f(**function_args)
    except Exception as e:
        warnings.warn(f"Function `{function_name}` raised an exception: {e!r}")
        return "An error occurred while executing the function."
    # Tool message content must be a string; coerce other return values.
    return result if isinstance(result, str) else str(result)
284 |
285 | async def async_forward(
286 | self, state: ChatPrompt, context: Context, store: KVData
287 | ) -> KVData:
288 | """Async version of ``forward``: sends the conversation to the LLM and awaits the tool calls it requests until it replies without any.
289 |
290 | Parameters
291 | ----------
292 | state : ChatPrompt
293 | The conversation; its ``dict`` form is the message list sent to the LLM, and every tool-call request and tool result is appended to that list.
294 | context : Context
295 | The pipeline context; ``(context, store)`` is passed as ``chat_context`` to registered functions that require it.
store : KVData
    Annotated as KVData, but ``store.usage_meter`` is forwarded to the LLM, so presumably the pipeline Store is expected -- TODO confirm.
296 | Returns
297 | -------
298 | KVData
299 | ``KVData(_out_0=...)`` holding the ``content`` of the final LLM response (the first one without tool calls).
300 | """
301 | messages = state.dict  # message list sent to the LLM; extended with tool calls/results below
302 | n_calls = 0
303 | available_functions = self._registry.get_available_functions()
304 | chat_context = (context, store)  # same shape as the ChatContext alias: Tuple[Context, KVData]
305 | while True:
306 | available_functions_i = (
307 | available_functions if n_calls < self.max_function_calls else None  # stop offering functions once the call budget is used up
308 | )
309 | response = await self.model.async_send_message(
310 | messages,
311 | functions=available_functions_i,
312 | usage_meter=store.usage_meter,
313 | )
314 | if response.get("tool_calls"):  # the LLM asked for one or more function calls
315 | messages.append(response)
316 | for function in response["tool_calls"]:
317 | function_name = function["function"]["name"]
318 | function_args = json.loads(function["function"]["arguments"])  # NOTE(review): parsed outside the try below, so malformed JSON arguments from the LLM raise out of async_forward() instead of being reported back
319 | f, requires_context = self._registry.get_function(function_name)  # unknown names yield a stub that returns an error string instead of raising
320 | if requires_context:
321 | function_args["chat_context"] = chat_context  # injected by the agent, not supplied by the LLM
322 | if self.before_function_call:  # optional hook that may replace the function and/or its arguments
323 | f, function_args = self.before_function_call(
324 | function_name, f, function_args
325 | )
326 | try:
327 | if inspect.iscoroutinefunction(f):
328 | result = await f(**function_args)
329 | else:
330 | warnings.warn(
331 | "Sync function is called from an async agent."
332 | )
333 | result = await to_thread(f, **function_args)  # run sync functions in a worker thread so the event loop is not blocked
334 | except Exception as e:
335 | print(e)  # errors are printed and reported to the LLM as the tool result instead of being raised
336 | result = "An error occurred while executing the function."
337 | messages.append(
338 | {
339 | "role": "tool",
340 | "tool_call_id": function["id"],
341 | "content": result,  # NOTE(review): passed through unconverted; non-str return values may need str()/json.dumps for the LLM API -- confirm
342 | }
343 | )
344 | n_calls += 1  # counts toward self.max_function_calls
345 | else:
346 | messages.append(response)
347 | return KVData(_out_0=response["content"])  # final answer: a response without tool calls
348 |
--------------------------------------------------------------------------------
/agent_dingo/agent/chat_context.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.core.state import Context, KVData
2 | from typing import Tuple
3 |
4 | # Legacy alias for the (context, store) pair that the agent passes as the `chat_context` argument to registered functions that request it.
5 | ChatContext = Tuple[Context, KVData]
6 |
--------------------------------------------------------------------------------
/agent_dingo/agent/docgen.py:
--------------------------------------------------------------------------------
1 | from typing import Callable
2 | import inspect
3 | from agent_dingo.core.blocks import BaseLLM
4 | import re
5 |
# System prompt that restricts the model to emitting docstrings only.
_SYSTEM_MSG = "You are a code generation tool. Your responses are limited to providing the docstrings of functions."

# User prompt template; ``{code}`` is replaced with the function's source code.
_PROMPT = """
You will be provided with a Python function. Your task is to generate a docstring in a Google style for that python function and return only the docstring delimited by triple backticks. Do not return the function itself.

Python function:
```{code}```

Docstring:
"""
16 |
17 |
def generate_docstring(func: Callable, model: BaseLLM) -> str:
    """Ask an LLM to write a docstring for ``func`` and return its reduced form.

    Parameters
    ----------
    func : Callable
        The function to generate a docstring for; its source is read with
        ``inspect.getsource``.
    model : BaseLLM
        The LLM used to generate the docstring (queried with temperature 0).

    Returns
    -------
    str
        The generated docstring with code fences and triple quotes removed,
        truncated before its ``Returns:`` section.
    """
    source = inspect.getsource(func)
    conversation = [
        {"role": "system", "content": _SYSTEM_MSG},
        {"role": "user", "content": _PROMPT.format(code=source)},
    ]
    reply = model.send_message(conversation, temperature=0.0)

    text = reply["content"]
    # Strip markdown fences and quote delimiters, in this exact order, so that
    # a leading "```python\n" is removed before bare "```" fences.
    for delimiter in ("```python\n", "```", '"""', "'''"):
        text = text.replace(delimiter, "")

    return extract_substr(text)
49 |
50 |
def extract_substr(input_string: str) -> str:
    """Extracts the description and args from a docstring.

    Everything from the first ``Returns:`` marker onwards is dropped and the
    remaining text is stripped of surrounding whitespace. If the marker does
    not occur, the input is returned unchanged (and not stripped).

    Parameters
    ----------
    input_string : str
        The docstring to extract the description and args from.

    Returns
    -------
    str
        Reduced docstring containing only the description and the args.
    """
    before_returns, marker, _ = input_string.partition("Returns:")
    if not marker:
        # 'Returns:' not found: keep the docstring exactly as it was.
        return input_string
    # The text before the *first* 'Returns:' cannot itself contain the marker,
    # so no further cleanup is needed.
    return before_returns.strip()
78 |
--------------------------------------------------------------------------------
/agent_dingo/agent/function_descriptor.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Callable, Optional, List
3 |
4 |
5 | @dataclass
6 | class FunctionDescriptor:  # a callable plus the metadata needed to expose it to the LLM
7 | name: str  # function name; convert_langchain_tool takes it from json_repr["name"]
8 | func: Callable  # the callable itself (convert_langchain_tool may produce a coroutine function)
9 | json_repr: dict  # JSON description (name, description, parameters) provided to the LLM
10 | requires_context: bool  # whether func expects a `chat_context` argument
11 | required_context_keys: Optional[List[str]] = None  # context keys func needs; presumably mandatory when requires_context is True (cf. Registry.add)
12 |
--------------------------------------------------------------------------------
/agent_dingo/agent/helpers.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.core.state import Context
2 | from typing import Callable, List
3 | import inspect
4 | from agent_dingo.agent.chat_context import ChatContext
5 |
6 |
def construct_json_repr(
    name: str, description: str, properties: dict, required: List[str]
) -> dict:
    """Constructs a JSON representation of a function.

    The layout follows the OpenAI function-calling format, where
    ``parameters`` is a JSON-Schema object. ``required`` belongs inside that
    schema (as produced for langchain tools as well). It is additionally kept at
    the top level so existing code reading it from there keeps working.

    Parameters
    ----------
    name : str
        The name of the function.
    description : str
        The description of the function.
    properties : dict
        The properties of the function (arguments, their descriptions and types).
    required : List[str]
        The required arguments of the function.

    Returns
    -------
    dict
        The JSON representation of the function.
    """
    return {
        "name": name,
        "description": description,
        "parameters": {
            "type": "object",
            "properties": properties,
            # JSON Schema reads the list of mandatory properties from here.
            "required": list(required),
        },
        # Backward compatibility only; not part of the parameter schema.
        "required": required,
    }
37 |
38 |
def get_required_args(func: Callable) -> List[str]:
    """Returns a list of the required arguments of a function.

    A parameter is required when it has no default value. Variadic parameters
    (``*args`` / ``**kwargs``) are never reported, and a ``chat_context``
    parameter annotated as ``ChatContext`` is skipped because the agent injects
    it instead of the LLM. String annotations (as produced by
    ``from __future__ import annotations``) are recognised too.

    Parameters
    ----------
    func : Callable
        The function.

    Returns
    -------
    List[str]
        A list of the required arguments of the function, in signature order.
    """
    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
    required_args = []
    for name, param in inspect.signature(func).parameters.items():
        # Identity check: defaults such as numpy arrays or objects with a custom
        # __eq__ make `==` ambiguous or misleading.
        if param.default is not inspect.Parameter.empty:
            continue
        if param.kind in variadic_kinds:
            continue
        if name == "chat_context" and (
            param.annotation == "ChatContext" or param.annotation == ChatContext
        ):
            continue
        required_args.append(name)
    return required_args
61 |
--------------------------------------------------------------------------------
/agent_dingo/agent/langchain.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.agent.function_descriptor import FunctionDescriptor
2 | from langchain.tools import (
3 | BaseTool as _BaseLangchainTool,
4 | format_tool_to_openai_function as _format_tool_to_openai_function,
5 | )
6 |
7 |
def convert_langchain_tool(
    tool: _BaseLangchainTool, make_async: bool = False
) -> FunctionDescriptor:
    """Wrap a langchain tool in a FunctionDescriptor usable by the agent.

    Parameters
    ----------
    tool : _BaseLangchainTool
        The langchain tool to wrap.
    make_async : bool, optional
        If True, the wrapper is a coroutine function delegating to
        ``tool.arun``; otherwise it is a plain function calling ``tool.run``.
        By default False.

    Returns
    -------
    FunctionDescriptor
        Descriptor carrying the tool's OpenAI-style JSON representation; it
        never requires a chat context.

    Raises
    ------
    ValueError
        If ``tool`` is not a langchain ``BaseTool`` instance.
    """
    if isinstance(tool, _BaseLangchainTool):
        json_repr = _format_tool_to_openai_function(tool=tool)
    else:
        raise ValueError("tool must be a subclass of langchain.tools.BaseTool")

    # Keep the parameter name `__arg1`: presumably it matches the argument
    # name in the generated schema for single-input tools -- do not rename.
    if not make_async:

        def func(__arg1):
            return tool.run(tool_input=__arg1)

    else:

        async def func(__arg1):
            return await tool.arun(tool_input=__arg1)

    return FunctionDescriptor(
        name=json_repr["name"],
        func=func,
        json_repr=json_repr,
        requires_context=False,
    )
43 |
--------------------------------------------------------------------------------
/agent_dingo/agent/parser.py:
--------------------------------------------------------------------------------
1 | from docstring_parser import parse as _parse
2 | import ast
3 |
4 |
# Maps Python type names used in docstrings to JSON-Schema type names.
# Type names missing from this table fall back to "string" in parse().
_types = dict(
    str="string",
    int="integer",
    float="number",
    bool="boolean",
    list="array",
    dict="object",
)
13 |
14 |
def parse(docstring: str) -> tuple:
    """Parses a docstring into an LLM function description.

    Parameters
    ----------
    docstring : str
        The docstring to parse.

    Returns
    -------
    tuple
        ``(spec, requires_context)``.
        ``spec`` is a dict with:

        - ``"description"``: the short and long descriptions joined by a newline.
        - ``"properties"``: one JSON-Schema property per documented argument.
          An ``Enum: [...]`` suffix in an argument description becomes an
          ``"enum"`` entry when it is a valid Python literal.

        ``requires_context`` is True when a ``chat_context`` argument of type
        ``ChatContext`` is documented; that argument is not exposed as a property.

    Raises
    ------
    ValueError
        If the docstring has no description.
    """
    parsed = _parse(docstring)
    description = "\n".join(
        part for part in (parsed.short_description, parsed.long_description) if part
    )
    if description == "":
        raise ValueError("Docstring has no description")

    args = {}
    requires_context = False
    for arg in parsed.params:
        if arg.arg_name == "chat_context" and arg.type_name == "ChatContext":
            # Injected by the agent at call time, so it is hidden from the LLM.
            requires_context = True
            continue
        # docstring_parser leaves `description` as None for arguments that are
        # documented without any text; treat that as an empty description.
        arg_text = arg.description or ""
        prop = {"type": _types.get(arg.type_name, "string")}
        if "Enum:" in arg_text:
            pieces = arg_text.split("Enum:")
            prop["description"] = pieces[0].strip()
            try:
                prop["enum"] = ast.literal_eval(pieces[1].strip())
            except Exception:
                # Not a valid Python literal: keep the description, drop the enum.
                pass
        else:
            prop["description"] = arg_text
        args[arg.arg_name] = prop
    return {"description": description, "properties": args}, requires_context
67 |
--------------------------------------------------------------------------------
/agent_dingo/agent/registry.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, List, Optional, Tuple
2 |
3 |
class Registry:
    """A registry for functions that can be called by the agent.

    Each entry, keyed by function name, stores the callable, the JSON
    representation offered to the LLM and its chat-context requirements.
    """

    def __init__(self):
        self.__functions = {}
        self._required_context_keys = []

    def add(
        self,
        name: str,
        func: Callable,
        json_repr: dict,
        requires_context: bool,
        required_context_keys: Optional[List[str]] = None,
    ) -> None:
        """Adds a function to the registry, replacing any entry with the same name.

        Parameters
        ----------
        name : str
            The name of the function.
        func : Callable
            The function.
        json_repr : dict
            The JSON representation of the function to be provided to the LLM.
        requires_context : bool
            Indicates whether the function requires a ChatContext object as one of its arguments.
        required_context_keys : Optional[List[str]], optional
            Context keys the function needs. Mandatory when ``requires_context``
            is True; stored as an empty list when omitted.

        Raises
        ------
        ValueError
            If ``requires_context`` is True and ``required_context_keys`` is None.
        """
        keys_missing = required_context_keys is None
        if keys_missing and requires_context:
            raise ValueError(
                "If requires_context is True, required_context_keys must be provided"
            )
        entry = {}
        entry["func"] = func
        entry["json_repr"] = json_repr
        entry["requires_context"] = requires_context
        entry["required_context_keys"] = required_context_keys or []
        self.__functions[name] = entry

    def get_function(self, name: str) -> Tuple[Callable, bool]:
        """Retrieves a function from the registry.

        Parameters
        ----------
        name : str
            The name of the function.

        Returns
        -------
        Tuple[Callable, bool]
            The function and whether it requires a ChatContext argument. For
            an unknown name, a stub accepting any arguments is returned; it
            returns an error message string instead of raising.
        """
        entry = self.__functions.get(name)
        if entry is None:

            def _unknown_function(*args, **kwargs):
                return f"Error: function `{name}` is not available. Most likely, the name is incorrect."

            return (_unknown_function, False)
        return (entry["func"], entry["requires_context"])

    def get_available_functions(self) -> List[dict]:
        """Returns the JSON representations of all registered functions, in registration order."""
        return [entry["json_repr"] for entry in self.__functions.values()]

    def get_required_context_keys(self) -> List[str]:
        """Returns all context keys required by the registered functions.

        Keys are concatenated in registration order; duplicates are kept.
        """
        return [
            key
            for entry in self.__functions.values()
            for key in entry["required_context_keys"]
        ]
91 |
--------------------------------------------------------------------------------
/agent_dingo/core/blocks.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from typing import Any, Coroutine, Optional, Union, List, Dict
3 | from abc import ABC, abstractmethod
4 | from agent_dingo.core.message import Message
5 | from agent_dingo.core.state import State, ChatPrompt, KVData, Context, Store, UsageMeter
6 | from agent_dingo.core.output_parser import BaseOutputParser, DefaultOutputParser
7 | import re
8 | import joblib
9 | import inspect
10 | import warnings
11 |
12 |
13 | import os
14 |
# Opt-in support for nested event loops: asyncio.run (used by the sync code
# paths for async functions) cannot be called while another loop is running.
# Only explicit values enable it; "", "0", "false", "no" and "off" do not
# (previously any non-empty value, even "0", turned it on).
if os.environ.get("DINGO_ALLOW_NESTED_ASYNCIO", "").strip().lower() not in (
    "",
    "0",
    "false",
    "no",
    "off",
):
    import nest_asyncio

    nest_asyncio.apply()
19 |
20 | from asyncio import (
21 | to_thread,
22 | gather,
23 | run as asyncio_run,
24 | )
25 |
26 |
class Block(ABC):
    """Base building block of a pipeline.

    Concrete blocks implement ``forward`` (ideally also ``async_forward``) and
    declare the context keys they read. Blocks compose with ``>>`` / ``<<``
    into a Pipeline and with ``&`` into a Parallel block.
    """

    @abstractmethod
    def forward(self, state: Optional[State], context: Context, store: Store) -> State:
        """Synchronously transform ``state``; implemented by subclasses."""

    async def async_forward(
        self, state: Optional[State], context: Context, store: Store
    ) -> State:
        """Fallback that runs ``forward`` in a worker thread.

        A warning is emitted because every concrete block is expected to
        provide its own async implementation.
        """
        notice = f"Called async_forward, but {self.__class__} block does not have an async implementation."
        warnings.warn(notice)
        return await to_thread(self.forward, state, context, store)

    @abstractmethod
    def get_required_context_keys(self) -> List[str]:
        """Each block must specify the keys it requires from the context."""

    def __rshift__(self, other: Block) -> Pipeline:
        """``a >> b``: a new pipeline running ``a`` then ``b``."""
        pipeline = Pipeline()
        pipeline.add_block(self)
        pipeline.add_block(other)
        return pipeline

    def __lshift__(self, other: Block) -> Pipeline:
        """``a << b``: run ``b`` then ``a``; appends to ``b`` in place if it is a Pipeline."""
        if not isinstance(other, Pipeline):
            pipeline = Pipeline()
            pipeline.add_block(other)
            pipeline.add_block(self)
            return pipeline
        other.add_block(self)
        return other

    def as_pipeline(self) -> Pipeline:
        """Wrap this block in a single-step pipeline."""
        pipeline = Pipeline()
        pipeline.add_block(self)
        return pipeline

    def __and__(self, other: Block) -> Parallel:
        """``a & b``: a new Parallel block running ``a`` and ``b``."""
        parallel = Parallel()
        parallel.add_block(self)
        parallel.add_block(other)
        return parallel
62 |
63 |
64 | ######### REASONERS #########
65 |
66 |
class BaseReasoner(Block):
    """Reasoner block: consumes a ChatPrompt and produces a KVData object.

    In this module, LLMs (direct model calls) and agents (multi-step
    reasoning) are the two kinds of reasoners.
    """

    @abstractmethod
    def forward(self, state: ChatPrompt, context: Context, store: Store) -> KVData:
        """Turn the prompt in ``state`` into a KVData result."""
73 |
74 |
class BaseLLM(BaseReasoner):
    """LLM is a type of reasoner that directly interacts with a language model.

    Subclasses implement ``send_message`` / ``async_send_message``. The
    helpers below read the reply text from the ``"content"`` key of the
    returned message.
    """

    # Presumably set to True by subclasses whose backend supports function calling.
    supports_function_calls = False

    @abstractmethod
    def send_message(
        self, messages, functions=None, usage_meter: UsageMeter = None, **kwargs
    ):
        """Send ``messages`` to the model and return the response message."""

    @abstractmethod
    async def async_send_message(
        self, messages, functions=None, usage_meter: UsageMeter = None, **kwargs
    ):
        """Async counterpart of ``send_message``."""

    def forward(self, state: ChatPrompt, context: Context, store: Store) -> KVData:
        """Answer the prompt in ``state``; the reply text is returned as ``_out_0``.

        Raises
        ------
        TypeError
            If ``state`` is not a ChatPrompt.
        """
        if not isinstance(state, ChatPrompt):
            raise TypeError(f"State must be a ChatPrompt, got {type(state)}")
        new_state = KVData(
            _out_0=self.process_prompt(state, usage_meter=store.usage_meter)
        )
        return new_state

    async def async_forward(self, state: State | None, context: Context, store: Store):
        """Async counterpart of ``forward``."""
        if not isinstance(state, ChatPrompt):
            raise TypeError(f"State must be a ChatPrompt, got {type(state)}")
        new_state = KVData(
            _out_0=await self.async_process_prompt(state, usage_meter=store.usage_meter)
        )
        return new_state

    def __call__(self, state: ChatPrompt) -> str:
        """Convenience: return the model's reply text for ``state``."""
        return self.process_prompt(state)

    def get_required_context_keys(self) -> List[str]:
        return []

    def process_prompt(
        self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs
    ):
        """Send ``prompt`` without functions and return the reply text.

        Extra keyword arguments (e.g. sampling parameters) are forwarded to
        ``send_message`` instead of being silently dropped.
        """
        return self.send_message(prompt.dict, None, usage_meter, **kwargs)["content"]

    async def async_process_prompt(
        self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs
    ):
        """Async counterpart of ``process_prompt``; extra kwargs are forwarded too."""
        response = await self.async_send_message(
            prompt.dict, None, usage_meter, **kwargs
        )
        return response["content"]
125 |
126 |
class BaseAgent(BaseReasoner):
    """Reasoner that can autonomously perform multi-step reasoning.

    Marker base class: it adds no behaviour of its own to BaseReasoner.
    """
131 |
132 |
133 | ######### KVData Processors #########
134 |
135 |
class BaseKVDataProcessor(Block):
    """Block that maps one KVData object to another KVData object."""

    @abstractmethod
    def forward(self, state: KVData, context: Context, store: Store) -> KVData:
        """Transform the incoming KVData ``state``; implemented by subclasses."""
142 |
143 |
class Squash(BaseKVDataProcessor):
    def __init__(self, template: str):
        """Merge every value of a multi-key KVData into a single ``_out_0`` string.

        Parameters
        ----------
        template : str
            Template string with (in-order) positional placeholders, filled
            with the KVData values in their iteration order.
        """
        self.template = template

    def get_required_context_keys(self) -> List[str]:
        return []

    def forward(self, state: KVData, context: Context, store: Store) -> KVData:
        """Format the template with the state's values and wrap the result."""
        ordered_values = list(state.values())
        squashed = self.template.format(*ordered_values)
        return KVData(_out_0=squashed)

    async def async_forward(self, state: KVData, context: Context, store: Store):
        """Formatting is cheap and synchronous, so this simply delegates to ``forward``."""
        return self.forward(state, context, store)
163 |
164 |
165 | ######### Prompt Builders #########
166 |
167 |
class BasePromptBuilder(Block):
    """Block that produces a ChatPrompt from a KVData object or from nothing (None)."""

    @abstractmethod
    def forward(
        self, state: Optional[KVData], context: Context, store: Store
    ) -> ChatPrompt:
        """Build the prompt; ``state`` may be None. Implemented by subclasses."""
176 |
177 |
class PromptBuilder(BasePromptBuilder):
    def __init__(
        self,
        messages: list[Message],
        from_state: Optional[Union[List[str], Dict[str, str]]] = None,
        from_store: Optional[Union[List[str], Dict[str, str]]] = None,
    ):
        """
        PromptBuilder formats the list of messages (templates) with values from the state, store and context.

        Placeholders are ``str.format`` fields such as ``{name}``. Escaped braces
        (``{{name}}``) are literal text and are not treated as placeholders.
        Placeholders not covered by ``from_state`` / ``from_store`` are read
        from the context.

        Parameters
        ----------
        messages : list[Message]
            List of message templates to format.
        from_state : Optional[Union[List[str], Dict[str, str]]], optional
            Placeholders populated from the state. Either a list of placeholder
            names (the i-th name is read from state key ``_out_{i}``) or a
            mapping placeholder -> state key. By default None.
        from_store : Optional[Union[List[str], Dict[str, str]]], optional
            Placeholders populated from the store. Either a mapping
            placeholder -> store key formatted as ``"<data key>.<field>"``
            (resolved as ``store.get_data(<data key>)[<field>]``), or a list of
            placeholder names. By default None.

        Raises
        ------
        TypeError
            If ``from_state`` or ``from_store`` is not a list, a dict or None.
        """
        self.messages = messages
        self._from_state_keys = []
        self._from_store_keys = []

        if from_state is None:
            self._from_state = {}
        elif isinstance(from_state, list):
            self._from_state = {k: f"_out_{i}" for i, k in enumerate(from_state)}
            self._from_state_keys.extend(from_state)
        elif isinstance(from_state, dict):
            self._from_state = from_state
            self._from_state_keys.extend(from_state.keys())
        else:
            raise TypeError(
                f"from_state must be a list or dict, got {type(from_state)}"
            )

        if from_store is None:
            self._from_store = {}
        elif isinstance(from_store, list):
            # NOTE(review): list entries map to "_out_{i}", which contains no ".",
            # so forward() raises ValueError for them -- confirm the intended usage.
            self._from_store = {k: f"_out_{i}" for i, k in enumerate(from_store)}
            self._from_store_keys.extend(from_store)
        elif isinstance(from_store, dict):
            self._from_store = from_store
            # A list (like _from_state_keys) rather than a live dict_keys view.
            self._from_store_keys.extend(from_store.keys())
        else:
            raise TypeError(
                f"from_store must be a list or dict, got {type(from_store)}"
            )

        self._placeholder_names = self._get_placeholder_names()

    def _get_placeholder_names(self) -> set:
        """Return the top-level names of all format fields in the templates.

        ``string.Formatter`` is used instead of a regex so that escaped braces
        (``{{name}}``) are not reported as placeholders. For compound fields
        (``{user.name}``, ``{items[0]}``) the top-level name is returned, since
        that is the keyword ``str.format`` looks up. Fields nested in a format
        spec (``{value:{width}}``) are included as well.
        """
        from string import Formatter

        formatter = Formatter()

        def field_names(template: str):
            for _, field_name, format_spec, _ in formatter.parse(template):
                if field_name:
                    base = re.split(r"[.\[]", field_name, maxsplit=1)[0]
                    if base:
                        yield base
                if format_spec and "{" in format_spec:
                    yield from field_names(format_spec)

        placeholder_names = set()
        for message in self.messages:
            try:
                names = list(field_names(message.content))
            except ValueError:
                # Malformed template (e.g. an unmatched brace): fall back to a
                # plain ``{name}`` scan; formatting it in forward() raises anyway.
                names = re.findall(r"\{(\w+)\}", message.content)
            placeholder_names.update(names)
        return placeholder_names

    def forward(
        self, state: Optional[KVData], context: Context, store: Store
    ) -> ChatPrompt:
        """Fill every placeholder and return the formatted messages as a ChatPrompt.

        Each placeholder is looked up in ``from_state`` first, then in
        ``from_store``, then in the context.

        Raises
        ------
        ValueError
            If a store key is not formatted as ``"<data key>.<field>"``.
        KeyError
            If a placeholder cannot be resolved from any source.
        """
        values = {}
        for n in self._placeholder_names:
            if n in self._from_state.keys():
                values[n] = state[self._from_state[n]]
            elif n in self._from_store.keys():
                store_key = self._from_store[n]
                if "." not in store_key:
                    raise ValueError(
                        f"Store key must be formatted as '<data key>.<field>', got {store_key!r}"
                    )
                # Split once so that field names may themselves contain dots.
                outer, inner = store_key.split(".", 1)
                values[n] = store.get_data(outer)[inner]
            elif n in context.keys():
                values[n] = context[n]
            else:
                raise KeyError(f"Could not find value for placeholder {n}")
        updated_messages = [type(m)(m.content.format(**values)) for m in self.messages]
        return ChatPrompt(updated_messages)

    async def async_forward(self, state: KVData, context: Context, store: Store):
        """Formatting is synchronous and cheap, so this delegates to ``forward``."""
        return self.forward(state, context, store)

    def get_required_context_keys(self) -> List[str]:
        """Placeholders that are served by neither the state nor the store."""
        return [
            n
            for n in self._placeholder_names
            if n not in self._from_state_keys and n not in self._from_store_keys
        ]
272 |
273 |
274 | ######### Prompt Modifiers #########
275 |
276 |
class BasePromptModifier(Block):
    """Block that rewrites one ChatPrompt into another ChatPrompt."""

    @abstractmethod
    def forward(self, state: ChatPrompt, context: Context, store: Store) -> ChatPrompt:
        """Return the modified prompt; implemented by subclasses."""
283 |
284 |
285 | ######### Special Blocks #########
286 |
287 |
class Pipeline(Block):
    def __init__(self, output_parser: Optional[BaseOutputParser] = None):
        """
        A pipeline is a sequence of blocks that are executed in order.
        The pipeline itself is a block that can be used in other pipelines.

        Parameters
        ----------
        output_parser : Optional[BaseOutputParser], optional
            custom output parser of the last step, by default None
            (a DefaultOutputParser is used)
        """
        self.output_parser: BaseOutputParser = output_parser or DefaultOutputParser()
        self._blocks = []

    def add_block(self, block: Block):
        """
        Add a block to the pipeline.

        Parameters
        ----------
        block : Block
            Block to add to the pipeline.

        Raises
        ------
        TypeError
            If ``block`` is not a Block.
        """
        if not isinstance(block, Block):
            raise TypeError(f"Expected a Block, got {type(block)}")
        self._blocks.append(block)

    def forward(self, state: Optional[State], context: Context, store: Store) -> State:
        """Feed ``state`` through every block in order; each output is the next input."""
        running_state = state
        for block in self._blocks:
            running_state = block.forward(
                state=running_state, context=context, store=store
            )
        return running_state

    async def async_forward(
        self, state: Optional[State], context: Context, store: Store
    ) -> State:
        """Async counterpart of ``forward``; blocks are still awaited one after another."""
        running_state = state
        for block in self._blocks:
            running_state = await block.async_forward(
                state=running_state, context=context, store=store
            )
        return running_state

    def run(self, _state: Optional[State] = None, **kwargs: Any) -> tuple:
        """
        Runs the pipeline with the given state and context (populated with kwargs).
        Each run initializes a new empty store.
        The output of the last block is parsed using the output_parser and returned.

        Parameters
        ----------
        _state : Optional[State], optional
            initial state, by default None
        **kwargs
            Values used to populate the context.

        Returns
        -------
        tuple
            ``(parsed_output, usage)``: the last block's output passed through
            ``output_parser.parse`` and ``store.usage_meter.get_usage()``.
        """
        context = Context(**kwargs)
        store = Store()
        out = self.forward(state=_state, context=context, store=store)
        return self.output_parser.parse(out), store.usage_meter.get_usage()

    async def async_run(self, _state: Optional[State] = None, **kwargs: Any) -> tuple:
        """Async counterpart of ``run``; returns the same ``(parsed_output, usage)`` tuple."""
        context = Context(**kwargs)
        store = Store()
        out = await self.async_forward(state=_state, context=context, store=store)
        return self.output_parser.parse(out), store.usage_meter.get_usage()

    def __rshift__(self, other: Block) -> Pipeline:
        """Append ``other`` to this pipeline in place and return the pipeline."""
        self.add_block(other)
        return self

    def get_required_context_keys(self) -> List[str]:
        """Concatenate the context keys required by all blocks (duplicates kept)."""
        keys = []
        for block in self._blocks:
            keys.extend(block.get_required_context_keys())
        return keys
366 |
367 |
class Parallel(Block):
    def __init__(self):
        """
        A parallel block executes multiple sub-blocks in parallel. The output of each block is stored as a separate key in the KVData object.

        Branch ``i`` must return a single-key KVData whose ``_out_0`` value is
        stored under ``_out_{i}``. As a special case, if the first branch
        returns a ChatPrompt, that prompt is returned as-is and the outputs of
        the other branches are ignored.
        """
        self.blocks = []

    def add_block(self, block: Block):
        """
        Add a block to the parallel block.

        Parameters
        ----------
        block : Block
            Block to add.

        Raises
        ------
        TypeError
            If ``block`` is not a Block.
        """
        if not isinstance(block, Block):
            raise TypeError(f"Expected a Block, got {type(block)}")
        self.blocks.append(block)

    @staticmethod
    def _merge(states) -> State:
        """Combine the branch outputs into one KVData (see the class docstring)."""
        out = {}
        for i, branch_state in enumerate(states):
            if i == 0 and isinstance(branch_state, ChatPrompt):
                # allow a special case where the first block returns a ChatPrompt;
                # the output of the remaining branches is ignored
                return branch_state
            if not isinstance(branch_state, KVData):
                raise TypeError(
                    f"Expected KVData, got {type(branch_state)} from block {i} of {len(states)}"
                )
            if len(branch_state.keys()) != 1:
                raise ValueError(
                    f"Expected KVData with one key `_out_0`, got {len(branch_state.keys())} keys from block {i} of {len(states)}"
                )
            out[f"_out_{i}"] = branch_state["_out_0"]
        return KVData(**out)

    def forward(self, state: Optional[State], context: Context, store: Store) -> State:
        """Run all branches in threads on the same input and merge their outputs."""
        if self.blocks:
            states = joblib.Parallel(n_jobs=len(self.blocks), backend="threading")(
                joblib.delayed(block.forward)(state=state, context=context, store=store)
                for block in self.blocks
            )
        else:
            # joblib rejects n_jobs=0; with no branches there is nothing to run.
            states = []
        return self._merge(states)

    async def async_forward(
        self, state: Optional[State], context: Context, store: Store
    ) -> State:
        """Await all branches concurrently on the same input and merge their outputs."""
        tasks = [block.async_forward(state, context, store) for block in self.blocks]
        states = await gather(*tasks)
        return self._merge(states)

    def __and__(self, other: Block) -> Parallel:
        """Add ``other`` as a new branch in place and return this block."""
        self.add_block(other)
        return self

    def get_required_context_keys(self) -> List[str]:
        """Concatenate the context keys required by all branches (duplicates kept)."""
        keys = []
        for block in self.blocks:
            keys.extend(block.get_required_context_keys())
        # Previously the list was built but never returned (implicit None),
        # which broke Pipeline.get_required_context_keys().
        return keys
441 |
class Identity(Block):
    """Pass-through (no-op) block: returns whatever state it receives, unchanged."""

    def get_required_context_keys(self) -> List[str]:
        return []

    def forward(self, state: Optional[State], context: Context, store: Store) -> State:
        unchanged = state
        return unchanged

    async def async_forward(
        self, state: Optional[State], context: Context, store: Store
    ) -> State:
        unchanged = state
        return unchanged
455 |
456 |
class SaveState(Block):
    def __init__(self, key: str):
        """Record the current state in the store and pass it on unchanged.

        Parameters
        ----------
        key : str
            Key to save the state under (via ``store.update``).
        """
        self.key = key

    def forward(self, state: State | None, context: Context, store: Store) -> State:
        """Store ``state`` under ``self.key`` and return it as-is."""
        key = self.key
        store.update(key, state)
        return state

    async def async_forward(
        self, state: State | None, context: Context, store: Store
    ) -> State:
        """Saving is synchronous, so this delegates to ``forward``."""
        return self.forward(state, context, store)

    def get_required_context_keys(self) -> List[str]:
        return []
479 |
480 |
class LoadState(Block):
    def __init__(self, from_: str, key: str):
        """
        Loads the state from the store, ignoring the incoming state.

        Parameters
        ----------
        from_ : str
            Defines whether to load from the Prompt (``"prompts"``) or KVData
            (``"data"``) section of the store.
        key : str
            Key to load the state from.

        Raises
        ------
        ValueError
            If ``from_`` is neither ``"prompts"`` nor ``"data"``.
        """
        if from_ not in ["prompts", "data"]:
            raise ValueError(f"from_ must be 'prompts' or 'data', got {from_}")
        self.from_ = from_
        self.key = key

    def get_required_context_keys(self) -> List[str]:
        return []

    def forward(self, state: State | None, context: Context, store: Store) -> State:
        """Return the entry stored under ``self.key`` in the selected store section."""
        if self.from_ == "prompts":
            return store.get_prompt(self.key)
        elif self.from_ == "data":
            return store.get_data(self.key)
        else:
            # Only reachable if ``from_`` was modified after construction.
            raise ValueError(f"from_ must be 'prompts' or 'data', got {self.from_}")

    async def async_forward(
        self, state: State | None, context: Context, store: Store
    ) -> State:
        """Loading is synchronous, so this delegates to ``forward``."""
        return self.forward(state, context, store)
513 |
514 |
class InlineBlock(Block):
    def __init__(self, required_context_keys: Optional[List[str]] = None):
        """A decorator to convert a function into an inline block.

        Apply as ``@InlineBlock(...)`` to a (sync or async) function taking
        ``(state, context, store)``. The decorated name is then bound to this
        block instance.

        Parameters
        ----------
        required_context_keys : Optional[List[str]], optional
            specifies the context keys required by the function, by default None
        """
        self.required_context_keys = required_context_keys or []
        self.func = None

    def __call__(self, func):
        """Attach ``func`` and return the block itself (decorator protocol)."""
        self.func = func
        return self

    def get_required_context_keys(self) -> List[str]:
        return self.required_context_keys

    def _get_output(self, out) -> State:
        """Coerce the wrapped function's return value into a State.

        States pass through; dicts become KVData; a string becomes ``_out_0``;
        a list/tuple becomes ``_out_0``, ``_out_1``, ... in order.
        """
        if isinstance(out, State.__args__):
            return out
        if isinstance(out, dict):
            return KVData(**out)
        if isinstance(out, str):
            return KVData(_out_0=out)
        if isinstance(out, (list, tuple)):
            return KVData(**{f"_out_{i}": v for i, v in enumerate(out)})
        raise TypeError(f"Expected a State, dict, str, or list, got {type(out)}")

    def forward(self, state: State | None, context: Context, store: Store) -> State:
        """Call the wrapped function synchronously (driving coroutines to completion)."""
        if not inspect.iscoroutinefunction(self.func):
            return self._get_output(self.func(state, context, store))
        warnings.warn("Called forward on an async inline block.")
        return self._get_output(asyncio_run(self.func(state, context, store)))

    async def async_forward(
        self, state: State | None, context: Context, store: Store
    ) -> State:
        """Await the wrapped function; sync functions run in a worker thread."""
        if not inspect.iscoroutinefunction(self.func):
            warnings.warn("Called async_forward on a non-async inline block.")
            return self._get_output(await to_thread(self.func, state, context, store))
        return self._get_output(await self.func(state, context, store))
562 |
--------------------------------------------------------------------------------
/agent_dingo/core/message.py:
--------------------------------------------------------------------------------
class Message:
    """Base class for a single chat message; subclasses override ``role``."""

    role: str = "undefined"

    def __init__(self, content: str):
        self.content = content

    def __repr__(self):
        return 'Message(role="{}" content="{}")'.format(self.role, self.content)

    @property
    def dict(self):
        """The message as a ``{"role": ..., "content": ...}`` mapping."""
        payload = {"role": self.role, "content": self.content}
        return payload
15 |
16 |
class UserMessage(Message):
    """A message with the ``user`` role."""

    role = "user"
20 |
21 |
class SystemMessage(Message):
    """A message with the ``system`` role."""

    role = "system"
25 |
26 |
class AssistantMessage(Message):
    """A message with the ``assistant`` role."""

    role = "assistant"
30 |
--------------------------------------------------------------------------------
/agent_dingo/core/output_parser.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from agent_dingo.core.state import State, ChatPrompt, KVData
3 |
4 |
class BaseOutputParser(ABC):
    """Base class for output parsers.

    Subclasses turn the final ``State`` into a string by implementing one
    handler per concrete state type.
    """

    def parse(self, output: State) -> str:
        """Delegates the parsing to the appropriate method based on the type of output.

        Parameters
        ----------
        output : State
            The state object to parse.

        Returns
        -------
        str
            parsed output

        Raises
        ------
        TypeError
            If ``output`` is neither a ``KVData`` nor a ``ChatPrompt``.
        """
        if not isinstance(output, State.__args__):
            # Both state types are accepted, so name both in the message.
            raise TypeError(f"Expected KVData or ChatPrompt, got {type(output)}")
        if isinstance(output, KVData):
            return self._parse_kvdata(output)
        return self._parse_chat(output)

    @abstractmethod
    def _parse_chat(self, output: ChatPrompt) -> str:
        """Convert a ``ChatPrompt`` state into the final string."""
        pass

    @abstractmethod
    def _parse_kvdata(self, output: KVData) -> str:
        """Convert a ``KVData`` state into the final string."""
        pass
35 |
36 |
class DefaultOutputParser(BaseOutputParser):
    """Default output parser. Expected output is a KVData with the key `_out_0`."""

    def _parse_chat(self, output: ChatPrompt) -> str:
        """Always fails: this parser only understands KVData.

        Raises
        ------
        RuntimeError
            Unconditionally.
        """
        message = "Cannot parse chat output. The output must be an instance of KVData."
        raise RuntimeError(message)

    def _parse_kvdata(self, output: KVData) -> str:
        """Return the value stored under ``_out_0``.

        Raises
        ------
        KeyError
            If ``_out_0`` is not present.
        """
        if "_out_0" not in output.keys():
            raise KeyError("Could not find output in KVData.")
        return output["_out_0"]
50 |
--------------------------------------------------------------------------------
/agent_dingo/core/state.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Any
2 | from agent_dingo.core.message import Message
3 | from threading import Lock
4 |
5 |
class ChatPrompt:
    def __init__(self, messages: List[Message]):
        """An ordered collection of messages that are sent to the model.

        Parameters
        ----------
        messages : List[Message]
            The messages to send to the model; stored as given (not copied).
        """
        self.messages = messages

    @property
    def dict(self):
        """The messages as a list of their ``dict`` representations, in order."""
        return [entry.dict for entry in self.messages]

    def __repr__(self):
        return "ChatPrompt({})".format(self.messages)
23 |
24 |
class KVData:
    """A dictionary-like, append-only mapping of keys to values.

    New keys can be added with ``update``; existing keys can never be
    overwritten.
    """

    def __init__(self, **kwargs):
        """A dictionary-like object that stores key-value pairs.

        Initial pairs are given as keyword arguments, e.g. ``KVData(a="1")``.
        """
        self._dict = dict(kwargs)

    def update(self, key: str, value: str):
        """Add a new key-value pair.

        Despite the name, existing keys are immutable: this only inserts new
        keys.

        Parameters
        ----------
        key : str
            key to add
        value : str
            value to set

        Raises
        ------
        TypeError
            If ``key`` or ``value`` is not a string.
        KeyError
            If ``key`` already exists.
        """
        if not isinstance(key, str) or not isinstance(value, str):
            raise TypeError("Both key and value must be strings.")
        # make existing keys immutable
        if key in self._dict:
            raise KeyError(f"Key {key} already exists.")
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key) -> bool:
        # Without this, ``key in kv`` falls back to the sequence protocol,
        # calls __getitem__(0), and raises KeyError instead of answering.
        return key in self._dict

    def __repr__(self):
        return "KVData({0})".format(str(self._dict))

    def __dict__(self):
        # NOTE(review): defining ``__dict__`` as a method appears to shadow the
        # instance ``__dict__`` attribute, so ``kv.__dict__`` is a bound method
        # rather than a mapping. Kept for backward compatibility; prefer the
        # ``dict`` property.
        return self._dict

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()

    def items(self):
        """Key-value pairs, mirroring ``dict.items``."""
        return self._dict.items()

    @property
    def dict(self):
        """A shallow copy of the stored pairs; mutating it does not affect this object."""
        return self._dict.copy()
67 |
68 |
# A value that flows between blocks: either a chat prompt or key-value data.
# ``State.__args__`` (the tuple of member types) is used for isinstance checks.
State = Union[ChatPrompt, KVData]
70 |
71 |
class Context(KVData):
    """Read-only key-value data passed to blocks alongside the state.

    Pairs are fixed at construction; any attempt to add one fails.
    """

    def update(self, key, value):
        """Reject every modification.

        Raises
        ------
        RuntimeError
            Always; a context is immutable.
        """
        raise RuntimeError("Context is immutable.")
75 |
76 |
class UsageMeter:
    def __init__(self):
        """An object that resides in the store and keeps track of the usage. It is updated by the LLMs."""
        self.prompt_tokens = 0
        self.completion_tokens = 0
        # Not updated by this class; presumably set by LLM wrappers elsewhere.
        self.last_finish_reason = None
        self._lock = Lock()

    def increment(self, prompt_tokens: int, completion_tokens: int) -> None:
        """Add the token counts of one LLM call, under the lock."""
        with self._lock:
            self.prompt_tokens += prompt_tokens
            self.completion_tokens += completion_tokens

    def get_usage(self) -> dict:
        """Return a consistent snapshot of the accumulated token counts.

        Returns
        -------
        dict
            ``prompt_tokens``, ``completion_tokens`` and their sum ``total_tokens``.
        """
        # Read both counters under the lock, so a concurrent ``increment``
        # cannot yield a snapshot that mixes two different updates.
        with self._lock:
            prompt_tokens = self.prompt_tokens
            completion_tokens = self.completion_tokens
        return {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
96 |
97 |
class Store:
    def __init__(self):
        """A simple key-value store that stores prompts, data, and other miscellaneous objects for the duration of a single pipeline run."""
        self._data = {}
        self._prompts = {}
        self._misc = {}
        self.usage_meter = UsageMeter()
        self._lock = Lock()  # probably not really needed

    def _update(self, key: str, item):
        """Put ``item`` into the bucket matching its type.

        Called by ``update`` with the lock held.

        Raises
        ------
        TypeError
            If ``key`` is not a string.
        """
        if not isinstance(key, str):
            raise TypeError("Key must be a string.")
        if isinstance(item, ChatPrompt):
            bucket = self._prompts
        elif isinstance(item, KVData):
            bucket = self._data
        else:
            bucket = self._misc
        bucket[key] = item

    def update(self, key: str, item: Any):
        """Update the store with a new item.

        ChatPrompts, KVData and anything else are kept in separate namespaces.
        An existing entry in the same namespace is replaced.

        Parameters
        ----------
        key : str
            A key to store the item under.
        item : Any
            The item to store.
        """
        with self._lock:
            self._update(key, item)

    def get_misc(self, key: str):
        """Return the non-state object stored under ``key`` (KeyError if absent)."""
        with self._lock:
            found = self._misc[key]
        return found

    def get_data(self, key: str) -> KVData:
        """Return the KVData stored under ``key`` (KeyError if absent)."""
        with self._lock:
            found = self._data[key]
        return found

    def get_prompt(self, key: str) -> ChatPrompt:
        """Return the ChatPrompt stored under ``key`` (KeyError if absent)."""
        with self._lock:
            found = self._prompts[key]
        return found
141 |
--------------------------------------------------------------------------------
/agent_dingo/llm/gemini.py:
--------------------------------------------------------------------------------
1 | try:
2 | from vertexai import init as _vertex_init
3 | from vertexai.generative_models import (
4 | Content,
5 | FunctionDeclaration,
6 | GenerativeModel,
7 | Part,
8 | Tool,
9 | )
10 | except ImportError:
11 | raise ImportError(
12 | "VertexAI is not installed. Please install it using `pip install agent-dingo[vertexai]`"
13 | )
14 | from typing import Optional, List
15 | from agent_dingo.core.blocks import BaseLLM
16 | from agent_dingo.core.state import UsageMeter
17 | import json
18 |
# Maps OpenAI-style message roles to Gemini ``Content`` roles.
# System messages are sent as USER turns, presumably because the Content
# roles used here only cover user/model turns.
# NOTE(review): confirm against the VertexAI SDK version in use.
_ROLES_MAP = {
    "user": "USER",
    "system": "USER",
    "assistant": "MODEL",
}
24 |
25 |
class Gemini(BaseLLM):
    """VertexAI Gemini client that speaks the OpenAI message format.

    Incoming messages are OpenAI-style dicts. Responses are returned as an
    OpenAI-style assistant message dict with an extra ``_cache`` entry. That
    entry holds the raw Gemini ``Content``, so it can be replayed verbatim on
    the next turn.
    """

    def __init__(
        self, model: str, project: str, location: str, temperature: float = 0.7
    ):
        """
        VertexAI Gemini LLM.

        Parameters
        ----------
        model : str
            model to use
        project : str
            project id to use
        location : str
            location to use
        temperature : float, optional
            generation temperature, by default 0.7
        """
        _vertex_init(project=project, location=location)
        self._model = GenerativeModel(model)
        self.supports_function_calls = True
        self.temperature = temperature

    def _resolve_temperature(self, temperature: Optional[float]) -> float:
        """Return the per-call temperature, falling back to the instance default."""
        # ``temperature or self.temperature`` would silently discard an explicit 0.0.
        return self.temperature if temperature is None else temperature

    def _get_tools(self, functions: Optional[List]) -> List[Tool]:
        """Wrap OpenAI-style function descriptors into a single VertexAI ``Tool``.

        Returns an empty list when there are no functions (``None`` or empty),
        rather than sending a Tool without declarations.
        """
        if not functions:
            return []
        declarations: List[FunctionDeclaration] = [
            FunctionDeclaration(
                name=f["name"],
                description=f["description"],
                parameters=f["parameters"],
            )
            for f in functions
        ]
        return [Tool(function_declarations=declarations)]

    def send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """Send ``messages`` synchronously and return an OpenAI-style assistant message.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        converted = self._openai_to_gemini(messages)
        out = self._model.generate_content(
            contents=converted,
            tools=self._get_tools(functions),
            generation_config={"temperature": self._resolve_temperature(temperature)},
        )
        return self._postprocess_response(out, usage_meter)

    async def async_send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature=None,
        **kwargs,
    ):
        """Async counterpart of ``send_message``."""
        converted = self._openai_to_gemini(messages)
        response = await self._model.generate_content_async(
            contents=converted,
            tools=self._get_tools(functions),
            generation_config={"temperature": self._resolve_temperature(temperature)},
        )
        return self._postprocess_response(response, usage_meter)

    def _postprocess_response(self, response, usage_meter: UsageMeter = None):
        """Record token usage (when a meter is given) and convert the response."""
        # NOTE(review): relies on the SDK's private ``_raw_response`` attribute.
        n_prompt_tokens = response._raw_response.usage_metadata.prompt_token_count
        n_completion_tokens = (
            response._raw_response.usage_metadata.candidates_token_count
        )
        if usage_meter:
            usage_meter.increment(
                prompt_tokens=n_prompt_tokens,
                completion_tokens=n_completion_tokens,
            )
        return self._gemini_to_openai(response)

    def _openai_to_gemini(self, messages):
        """Converts OpenAI messages to VertexAI messages.

        Messages that carry a ``_cache`` entry (previous Gemini responses) are
        replayed as-is. Tool results become function responses, keyed by
        ``tool_call_id``, which ``_gemini_to_openai`` sets to the function name.

        Raises
        ------
        ValueError
            For a message whose role is not recognised.
        """
        converted = []

        for message in messages:
            if "_cache" in message.keys():
                converted.append(message["_cache"])
            elif message["role"] in _ROLES_MAP.keys():
                content = Content(
                    role=_ROLES_MAP[message["role"]],
                    parts=[
                        Part.from_text(message["content"]),
                    ],
                )
                converted.append(content)
            elif message["role"] == "tool":
                content = Content(
                    role="function",
                    parts=[
                        Part.from_function_response(
                            name=message["tool_call_id"],
                            response={
                                "content": message["content"],
                            },
                        )
                    ],
                )
                converted.append(content)
            else:
                raise ValueError(f"Invalid message {message}")
        return converted

    def _gemini_to_openai(self, response):
        """Converts the Gemini response to OpenAI response."""
        candidate = response.candidates[0]
        try:
            content = candidate.content.parts[0].text
        except (AttributeError, IndexError):
            # A part without text (presumably a function-call part) raises
            # AttributeError; a response with no parts at all raises IndexError.
            content = None

        transformed_calls = []
        for call in candidate.function_calls:
            args = {arg: call.args[arg] for arg in call.args}
            transformed_calls.append(
                {
                    # Gemini calls carry no id; the function name doubles as
                    # the id and is mapped back in _openai_to_gemini.
                    "id": call.name,
                    "type": "function",
                    "function": {
                        "name": call.name,
                        "arguments": json.dumps(args),
                    },
                }
            )

        return {
            "content": content,
            "role": "assistant",
            "tool_calls": transformed_calls,
            "_cache": candidate.content,
        }
171 |
--------------------------------------------------------------------------------
/agent_dingo/llm/litellm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Dict
2 | from agent_dingo.core.blocks import BaseLLM
3 | from agent_dingo.core.state import UsageMeter
4 |
5 | try:
6 | from litellm import completion, acompletion
7 | except ImportError:
8 | raise ImportError(
9 | "litellm is not installed. Please install it using `pip install agent-dingo[litellm]`"
10 | )
11 |
12 |
class LiteLLM(BaseLLM):
    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        completion_extra_kwargs: Optional[Dict] = None,
    ):
        """
        Lite LLM client to interact with various LLM providers.

        Parameters
        ----------
        model : str
            model to use
        temperature : float, optional
            generation temperature, by default 0.7
        completion_extra_kwargs : Optional[Dict], optional
            additional arguments to be passed to a completion method, by default None
        """

        self.temperature = temperature
        self.model = model
        self.completion_extra_kwargs = completion_extra_kwargs or {}

    def _resolve_temperature(self, temperature: Optional[float]) -> float:
        """Return the per-call temperature, falling back to the instance default."""
        # ``temperature or self.temperature`` would silently discard an explicit 0.0.
        return self.temperature if temperature is None else temperature

    def send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """Send ``messages`` and return the first choice's message.

        ``functions`` and extra ``kwargs`` are accepted for interface
        compatibility but are not forwarded to litellm.
        """
        response = completion(
            messages=messages,
            model=self.model,
            temperature=self._resolve_temperature(temperature),
            **self.completion_extra_kwargs,
        )
        self._log_usage(response, usage_meter)
        return response["choices"][0]["message"]

    async def async_send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature=None,
        **kwargs,
    ):
        """Async counterpart of ``send_message``."""
        response = await acompletion(
            messages=messages,
            model=self.model,
            temperature=self._resolve_temperature(temperature),
            **self.completion_extra_kwargs,
        )
        self._log_usage(response, usage_meter)
        return response["choices"][0]["message"]

    def _log_usage(self, response, usage_meter: UsageMeter = None):
        """Add the response's token counts to ``usage_meter`` when one is given."""
        if usage_meter:
            usage_meter.increment(
                prompt_tokens=response["usage"]["prompt_tokens"],
                completion_tokens=response["usage"]["completion_tokens"],
            )
77 |
--------------------------------------------------------------------------------
/agent_dingo/llm/llama_cpp.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Any
2 | from agent_dingo.core.blocks import BaseLLM
3 | from agent_dingo.core.state import UsageMeter
4 |
5 | try:
6 | from llama_cpp import Llama as _Llama
7 | except ImportError:
8 | raise ImportError(
9 | "Llama.cpp is not installed. Please install it using `pip install agent-dingo[llama-cpp]`"
10 | )
11 | import threading
12 | from concurrent.futures import ThreadPoolExecutor
13 | import asyncio
14 |
15 |
class LlamaCPP(BaseLLM):
    def __init__(
        self, model: str, temperature: float = 0.7, verbose: bool = False, **kwargs: Any
    ):
        """
        Llama.cpp client to run LLMs in a GGUF format locally.

        Parameters
        ----------
        model : str
            path to the model file
        temperature : float, optional
            model temperature, by default 0.7
        verbose : bool, optional
            flag to enable verbosity, by default False
        **kwargs : Any
            additional arguments to be passed to the model constructor
        """

        self.model = _Llama(model, verbose=verbose, **kwargs)
        self.temperature = temperature
        # Serialises every call into the model, from both the sync and the async path.
        self._lock = threading.Lock()
        self._executor: Optional[ThreadPoolExecutor] = None

    def _create_completion(self, messages, temperature: Optional[float]):
        """Run one chat completion while holding the model lock."""
        # ``temperature or self.temperature`` would silently discard an explicit 0.0.
        resolved = self.temperature if temperature is None else temperature
        with self._lock:
            return self.model.create_chat_completion(messages, temperature=resolved)

    def send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature: Optional[float] = None,
        **kwargs,
    ):
        """Send ``messages`` synchronously and return the first choice's message.

        ``functions`` and extra ``kwargs`` are accepted for interface
        compatibility but are not used.
        """
        response = self._create_completion(messages, temperature)
        self._log_usage(response, usage_meter)
        return response["choices"][0]["message"]

    async def async_send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature=None,
        **kwargs,
    ):
        """Run the completion on a worker thread so the event loop is not blocked."""
        loop = asyncio.get_running_loop()
        # run_in_executor forwards positional arguments only, so no keyword
        # arguments can be passed here; the helper also takes the model lock.
        response = await loop.run_in_executor(
            self._get_executor(),
            self._create_completion,
            messages,
            temperature,
        )
        self._log_usage(response, usage_meter)
        return response["choices"][0]["message"]

    def _get_executor(self):
        """Lazily create the single-worker executor used by the async path."""
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=1)
        return self._executor

    def _log_usage(self, response, usage_meter: UsageMeter = None):
        """Add the response's token counts to ``usage_meter`` when one is given."""
        if usage_meter:
            usage_meter.increment(
                prompt_tokens=response["usage"]["prompt_tokens"],
                completion_tokens=response["usage"]["completion_tokens"],
            )
84 |
--------------------------------------------------------------------------------
/agent_dingo/llm/openai.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List
2 | from agent_dingo.core.blocks import BaseLLM
3 | from agent_dingo.core.state import UsageMeter
4 | import openai
5 | from tenacity import retry, stop_after_attempt, wait_fixed
6 |
7 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(3))
def _send_message(
    client: openai.OpenAI,
    messages: List[dict],
    model: str = "gpt-3.5-turbo-0613",
    functions: Optional[List] = None,
    temperature: float = 1.0,
) -> tuple:
    """Sends messages to the LLM and returns the response.

    Retried by tenacity: up to 3 attempts, 3 seconds apart.

    Parameters
    ----------
    client : openai.OpenAI
        Client used to issue the request.
    messages : List[dict]
        OpenAI-format chat messages to send.
    model : str, optional
        Model to use, by default "gpt-3.5-turbo-0613"
    functions : Optional[List], optional
        Function descriptors exposed as tools, by default None
    temperature : float, optional
        Temperature to use, by default 1.

    Returns
    -------
    tuple
        ``(message, response)``: the first choice's message and the full
        response object.
    """
    tool_kwargs = {}
    if functions:
        # Only send tools when there are some; an empty ``tools`` list is
        # presumably rejected by the API.
        tool_kwargs["tools"] = [{"type": "function", "function": fn} for fn in functions]
        tool_kwargs["tool_choice"] = "auto"
    response = client.chat.completions.create(
        model=model, messages=messages, temperature=temperature, **tool_kwargs
    )
    return response.choices[0].message, response
44 |
45 |
@retry(stop=stop_after_attempt(3), wait=wait_fixed(3))
async def _async_send_message(
    client: openai.AsyncOpenAI,
    messages: List[dict],
    model: str = "gpt-3.5-turbo-0613",
    functions: Optional[List] = None,
    temperature: float = 1.0,
) -> tuple:
    """Async counterpart of ``_send_message`` (same retry policy).

    Parameters
    ----------
    client : openai.AsyncOpenAI
        Async client used to issue the request.
    messages : List[dict]
        OpenAI-format chat messages to send.
    model : str, optional
        Model to use, by default "gpt-3.5-turbo-0613"
    functions : Optional[List], optional
        Function descriptors exposed as tools, by default None
    temperature : float, optional
        Temperature to use, by default 1.

    Returns
    -------
    tuple
        ``(message, response)``: the first choice's message and the full
        response object.
    """
    tool_kwargs = {}
    if functions:
        # Only send tools when there are some; an empty ``tools`` list is
        # presumably rejected by the API.
        tool_kwargs["tools"] = [{"type": "function", "function": fn} for fn in functions]
        tool_kwargs["tool_choice"] = "auto"
    response = await client.chat.completions.create(
        model=model, messages=messages, temperature=temperature, **tool_kwargs
    )
    return response.choices[0].message, response
62 |
63 |
def to_dict(obj):
    """Recursively convert ``obj`` into plain dicts and lists.

    Dicts and lists are converted element-wise. Any other object with a
    ``__dict__`` becomes a dict of its public (non-underscore) attributes.
    Everything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: to_dict(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return list(map(to_dict, obj))
    if not hasattr(obj, "__dict__"):
        return obj
    public = {}
    for name, value in obj.__dict__.items():
        if name.startswith("_"):
            continue
        public[name] = to_dict(value)
    return public
73 |
74 |
class OpenAI(BaseLLM):
    def __init__(
        self,
        model: str,
        temperature: float = 0.7,
        base_url: Optional[str] = None,
        # TODO: Add per instance API key
        # TODO: Add remaining generation parameters
    ):
        """
        OpenAI client to interact with compatible APIs.

        Parameters
        ----------
        model : str
            model to use
        temperature : float, optional
            generation temperature, by default 0.7
        base_url : Optional[str], optional
            base URL of an OpenAI-compatible API, by default None (the openai
            client's default endpoint)
        """
        self.model = model
        self.temperature = temperature
        self.client = openai.OpenAI(base_url=base_url)
        self.async_client = openai.AsyncOpenAI(base_url=base_url)
        if base_url is None:
            # Tool calling is only enabled for the default endpoint.
            self.supports_function_calls = True

    def _resolve_temperature(self, temperature: Optional[float]) -> float:
        """Return the per-call temperature, falling back to the instance default."""
        # ``temperature or self.temperature`` would silently discard an explicit 0.0.
        return self.temperature if temperature is None else temperature

    def send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature=None,
        **kwargs,
    ):
        """Send ``messages`` and return the assistant message as a plain dict.

        Extra ``kwargs`` are accepted for interface compatibility and ignored.
        """
        response = _send_message(
            client=self.client,
            messages=messages,
            model=self.model,
            functions=functions,
            temperature=self._resolve_temperature(temperature),
        )
        return self._postprocess_response(response, usage_meter)

    async def async_send_message(
        self,
        messages,
        functions=None,
        usage_meter: UsageMeter = None,
        temperature=None,
        **kwargs,
    ):
        """Async counterpart of ``send_message``."""
        response = await _async_send_message(
            client=self.async_client,
            messages=messages,
            model=self.model,
            functions=functions,
            temperature=self._resolve_temperature(temperature),
        )
        return self._postprocess_response(response, usage_meter)

    def _postprocess_response(self, response, usage_meter: UsageMeter = None):
        """Convert a ``(message, full_response)`` pair and record token usage.

        The legacy ``function_call`` field is removed from the message dict,
        so only ``tool_calls`` remains.
        """
        res, full_res = to_dict(response[0]), to_dict(response[1])
        if usage_meter:
            usage_meter.increment(
                prompt_tokens=full_res["usage"]["prompt_tokens"],
                completion_tokens=full_res["usage"]["completion_tokens"],
            )
        res.pop("function_call", None)
        return res
147 |
--------------------------------------------------------------------------------
/agent_dingo/rag/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, Optional
3 | from dataclasses import dataclass
4 | import hashlib
5 | import json
6 | from typing import Union
7 |
8 |
@dataclass
class Document:
    """A source document: raw text plus JSON-serialisable metadata."""

    content: str
    metadata: dict

    @property
    def hash(self) -> str:
        """SHA-256 hex digest of the content followed by the key-sorted JSON metadata."""
        serialized_metadata = json.dumps(self.metadata, sort_keys=True)
        payload = "".join((self.content, serialized_metadata)).encode()
        return hashlib.sha256(payload).hexdigest()
18 |
19 |
@dataclass
class Chunk:
    """A piece of a ``Document``, optionally carrying its embedding vector."""

    content: str
    parent: Document
    # Set by embedders (one element of the List[List[float]] that ``embed``
    # returns), so the element type is float, not str.
    embedding: Optional[List[float]] = None

    @property
    def payload(self):
        """The chunk text together with its parent document's metadata."""
        return {"content": self.content, "document_metadata": self.parent.metadata}

    @property
    def hash(self) -> str:
        """SHA-256 hex digest combining the parent, embedding and content hashes.

        A missing (or empty) embedding contributes an empty string.
        """
        parent_hash = self.parent.hash
        embedding_hash = (
            hashlib.sha256(json.dumps(self.embedding).encode()).hexdigest()
            if self.embedding
            else ""
        )
        content_hash = hashlib.sha256(self.content.encode()).hexdigest()
        return hashlib.sha256(
            (parent_hash + embedding_hash + content_hash).encode()
        ).hexdigest()
42 |
43 |
@dataclass
class RetrievedChunk:
    """A chunk returned by a vector-store search, together with its score."""

    content: str  # text of the retrieved chunk
    document_metadata: dict  # metadata of the chunk's parent document
    score: float  # relevance score as reported by the vector store
49 |
50 |
class BaseReader(ABC):
    """Interface for loaders that produce ``Document`` objects."""

    @abstractmethod
    def read(self, *args, **kwargs) -> List[Document]:
        """Load and return documents; the arguments are reader-specific."""
55 |
56 |
class BaseChunker(ABC):
    """Interface for splitting documents into chunks."""

    @abstractmethod
    def chunk(self, documents: List[Document]) -> List[Chunk]:
        """Split every document in ``documents`` into chunks.

        The signature matches the concrete chunkers, which take a list of
        documents rather than a single one.

        Parameters
        ----------
        documents : List[Document]
            Documents to split.

        Returns
        -------
        List[Chunk]
            All chunks, each referencing its parent document.
        """
61 |
62 |
class BaseEmbedder(ABC):
    """Interface for text embedders, with a batched helper for chunks."""

    # Number of chunks passed to ``embed`` per call in ``embed_chunks``.
    batch_size: int = 1

    def embed_chunks(self, chunks: List[Chunk]):
        """Embed ``chunks`` in place, ``batch_size`` at a time.

        Each chunk's ``embedding`` attribute is set to its vector.
        """
        for start in range(0, len(chunks), self.batch_size):
            batch = chunks[start : start + self.batch_size]
            vectors = self.embed([item.content for item in batch])
            for item, vector in zip(batch, vectors):
                item.embedding = vector

    @abstractmethod
    async def async_embed(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Asynchronously embed one text or a list of texts."""

    @abstractmethod
    def embed(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Embed one text or a list of texts, returning one vector per text."""
81 |
82 |
class BaseVectorStore(ABC):
    """Interface for stores that persist embedded chunks and search them."""

    @abstractmethod
    def upsert_chunks(self, chunks: List[Chunk]):
        """Insert or update the given (already embedded) chunks."""

    @abstractmethod
    def retrieve(self, k: int, embedding: List[float]) -> List[RetrievedChunk]:
        """Return up to ``k`` chunks closest to ``embedding``."""

    @abstractmethod
    async def async_retrieve(
        self, k: int, embedding: List[float]
    ) -> List[RetrievedChunk]:
        """Async counterpart of ``retrieve``."""
97 |
--------------------------------------------------------------------------------
/agent_dingo/rag/chunkers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/agent_dingo/rag/chunkers/__init__.py
--------------------------------------------------------------------------------
/agent_dingo/rag/chunkers/recursive.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.rag.base import BaseChunker, Document, Chunk
2 | from typing import List
3 | import re
4 |
5 |
class RecursiveChunker(BaseChunker):
    """Split text on a hierarchy of separators, then pack pieces up to ``chunk_size``.

    Text is split on the first separator. Any piece longer than
    ``chunk_size`` is split again on the next separator, and so on. The
    resulting pieces are then greedily joined with ``merge_separator`` while
    the merged length stays within ``chunk_size``.
    """

    def __init__(
        self, separators=None, chunk_size=512, keep_separator=False, merge_separator=" "
    ):
        """
        Parameters
        ----------
        separators : list of str, optional
            Separators tried in order. By default: paragraph break, newline,
            space, then the empty string, which splits into single characters.
        chunk_size : int, optional
            Maximum chunk length in characters, by default 512. A piece may
            still exceed it if the separators run out first.
        keep_separator : bool, optional
            If True, re-attach the separator to the start of every piece
            except the first, by default False.
        merge_separator : str, optional
            String inserted between pieces when merging, by default " ".
        """
        if separators is None:
            separators = ["\n\n", "\n", " ", ""]
        self.separators = separators
        self.chunk_size = chunk_size
        self.keep_separator = keep_separator
        self.merge_separator = merge_separator

    def chunk(self, documents: List[Document]) -> List[Chunk]:
        """Chunk every document and return all chunks in document order."""
        all_chunks = []
        for doc in documents:
            chunks = self._split_text_recursive(doc.content, self.separators)
            chunks = self._merge_small_chunks(chunks)
            all_chunks.extend([Chunk(content=chunk, parent=doc) for chunk in chunks])
        return all_chunks

    def _split_text_recursive(self, text, separators):
        """Split ``text`` into non-empty pieces of at most ``chunk_size`` where possible."""
        if not separators:
            return [text]

        separator, remaining = separators[0], separators[1:]
        pieces = re.split(re.escape(separator), text)

        final_chunks = []
        for i, piece in enumerate(pieces):
            if self.keep_separator and i > 0:
                piece = separator + piece
            if not piece:
                # Consecutive separators and the edges of an empty-separator
                # split yield empty pieces. They carry no text, and merging
                # them would only inject stray merge separators.
                continue
            if len(piece) <= self.chunk_size:
                final_chunks.append(piece)
            else:
                final_chunks.extend(self._split_text_recursive(piece, remaining))

        return final_chunks

    def _merge_small_chunks(self, chunks):
        """Greedily join consecutive pieces while the result fits in ``chunk_size``."""
        merged_chunks = []
        current_chunk = ""

        for chunk in chunks:
            new_chunk = current_chunk + (
                self.merge_separator + chunk if current_chunk else chunk
            )
            if len(new_chunk) <= self.chunk_size:
                current_chunk = new_chunk
            else:
                if current_chunk:
                    merged_chunks.append(current_chunk)
                current_chunk = chunk

        if current_chunk:
            merged_chunks.append(current_chunk)

        return merged_chunks
66 |
--------------------------------------------------------------------------------
/agent_dingo/rag/embedders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/agent_dingo/rag/embedders/__init__.py
--------------------------------------------------------------------------------
/agent_dingo/rag/embedders/openai.py:
--------------------------------------------------------------------------------
1 | import openai
2 | from agent_dingo.rag.base import BaseEmbedder
3 | from typing import Optional, List
4 |
5 |
class OpenAIEmbedder(BaseEmbedder):
    """Embedder backed by the OpenAI (or a compatible) embeddings API."""

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        base_url: Optional[str] = None,
        dimensions: Optional[int] = None,
        batch_size: int = 1,
    ):
        """
        Parameters
        ----------
        model : str, optional
            embedding model, by default "text-embedding-3-small"
        base_url : Optional[str], optional
            base URL of an OpenAI-compatible API, by default None
        dimensions : Optional[int], optional
            requested vector size; only sent when truthy, by default None
        batch_size : int, optional
            number of chunks sent per request by ``embed_chunks``, by default 1
        """
        self.model = model
        self.batch_size = batch_size
        self.client = openai.OpenAI(base_url=base_url)
        self.async_client = openai.AsyncOpenAI(base_url=base_url)
        self.params = {
            "model": self.model,
        }
        if dimensions:
            self.params["dimensions"] = dimensions

    def embed(self, texts: str | List[str]) -> List[List[float]]:
        """Embed one text or a list of texts; always returns a list of vectors."""
        if isinstance(texts, str):
            texts = [texts]
        response = self.client.embeddings.create(**self.params, input=texts)
        return [item.embedding for item in response.data]

    async def async_embed(self, texts: str | List[str]) -> List[List[float]]:
        """Async counterpart of ``embed``."""
        if isinstance(texts, str):
            texts = [texts]
        res = await self.async_client.embeddings.create(**self.params, input=texts)
        return [item.embedding for item in res.data]
37 |
--------------------------------------------------------------------------------
/agent_dingo/rag/embedders/sentence_transformer.py:
--------------------------------------------------------------------------------
1 | try:
2 | from sentence_transformers import SentenceTransformer as _SentenceTransformer
3 | except ImportError:
4 | raise ImportError(
5 | "SentenceTransformers is not installed. Please install it using `pip install agent-dingo[sentence-transformers]`"
6 | )
7 | from agent_dingo.rag.base import BaseEmbedder
8 | from typing import List, Union
9 | import hashlib
10 | import concurrent.futures
11 | import asyncio
12 | import os
13 |
14 |
class SentenceTransformer(BaseEmbedder):
    """Local embedder backed by a ``sentence_transformers`` model."""

    def __init__(
        self, model_name: str = "paraphrase-MiniLM-L6-v2", batch_size: int = 128
    ):
        """
        Parameters
        ----------
        model_name : str, optional
            model to load, by default "paraphrase-MiniLM-L6-v2"
        batch_size : int, optional
            number of chunks passed to ``embed`` per call in ``embed_chunks``,
            by default 128
        """
        # Default TOKENIZERS_PARALLELISM to "false" unless the user already set
        # it. Done before loading the model, presumably so its tokenizer sees it.
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
        self.model = _SentenceTransformer(model_name)
        self.model_name = model_name
        self._executor = None
        self.batch_size = batch_size

    def _prepare_executor(self) -> None:
        """Lazily create the single-worker executor used by ``async_embed``."""
        if self._executor is None:
            self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

    async def async_embed(self, text: Union[str, List[str]]) -> List[List[float]]:
        """Run ``embed`` on a worker thread so the event loop is not blocked."""
        if isinstance(text, str):
            text = [text]
        self._prepare_executor()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self.embed, text)

    def embed(self, texts: Union[str, List[str]]) -> List[List[float]]:
        """Embed one text or a list of texts; vectors are returned as plain lists."""
        if isinstance(texts, str):
            texts = [texts]
        embeddings = self.model.encode(texts)
        return [embedding.tolist() for embedding in embeddings]

    def hash(self) -> str:
        """Stable identifier of this embedder, derived from the model name."""
        return hashlib.sha256(
            ("SentenceTransformer::" + self.model_name).encode()
        ).hexdigest()
48 |
--------------------------------------------------------------------------------
/agent_dingo/rag/prompt_modifiers.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.core.state import ChatPrompt, Context, Store
2 | from agent_dingo.core.message import UserMessage, SystemMessage
3 | from agent_dingo.core.blocks import BasePromptModifier as _BasePromptModifier
4 | from agent_dingo.rag.base import BaseEmbedder, BaseVectorStore, RetrievedChunk
5 | from typing import List, Optional
6 | from warnings import warn
7 |
# Default template for injecting retrieved chunks into a message.
# ``{original_message}`` receives the target message's content and
# ``{documents}`` the retrieved chunks, one per line.
_DEFAULT_RAG_TEMPLATE = """
{original_message}

Relevant documents:
{documents}
"""
14 |
15 |
class RAGPromptModifier(_BasePromptModifier):
    """Augments a chat prompt with chunks retrieved from a vector store.

    The content of the last message is embedded and used as the query. The
    retrieved chunks are inserted into the first system (or user) message
    via ``rag_template``.
    """

    def __init__(
        self,
        embedder: BaseEmbedder,
        vector_store: BaseVectorStore,
        n_chunks_to_retrieve: int = 5,
        retrieved_data_location: str = "system",
        rag_template: Optional[str] = None,
    ):
        """
        Parameters
        ----------
        embedder : BaseEmbedder
            embeds the query text
        vector_store : BaseVectorStore
            searched for relevant chunks
        n_chunks_to_retrieve : int, optional
            number of chunks requested from the store, by default 5
        retrieved_data_location : str, optional
            "system" or "user": which message type receives the chunks,
            by default "system"
        rag_template : Optional[str], optional
            format string with ``{original_message}`` and ``{documents}``
            placeholders, by default the module's default template

        Raises
        ------
        ValueError
            If ``retrieved_data_location`` is not "system" or "user".
        """
        if retrieved_data_location not in ["system", "user"]:
            raise ValueError(
                "retrieved_data_location must be one of 'system' or 'user'"
            )
        self.embedder = embedder
        self.vector_store = vector_store
        self.retrieved_data_location = retrieved_data_location
        self.rag_template = rag_template or _DEFAULT_RAG_TEMPLATE
        self.n_chunks_to_retrieve = n_chunks_to_retrieve

    def _get_query(self, state: ChatPrompt) -> str:
        """Validate ``state`` and return the content of its last message.

        Raises
        ------
        ValueError
            If ``state`` is not a ChatPrompt or has no messages.
        """
        if not isinstance(state, ChatPrompt):
            raise ValueError("state must be a ChatPrompt")
        if not state.messages:
            raise ValueError("state must contain at least one message")
        return state.messages[-1].content

    def forward(self, state: ChatPrompt, context: Context, store: Store) -> ChatPrompt:
        """Retrieve chunks for the last message and inject them into the prompt."""
        query = self._get_query(state)
        query_embedding = self.embedder.embed(query)[0]
        try:
            retrieved_data = self.vector_store.retrieve(
                self.n_chunks_to_retrieve,
                query_embedding,
            )
        except Exception as e:
            # Retrieval is best-effort: fall back to the unmodified prompt,
            # but surface the cause instead of discarding it.
            warn(f"No data was retrieved: {e!r}")
            retrieved_data = []
        return self._forward(state, retrieved_data)

    async def async_forward(
        self, state: ChatPrompt, context: Context, store: Store
    ) -> ChatPrompt:
        """Async counterpart of ``forward``."""
        query = self._get_query(state)
        query_embedding = (await self.embedder.async_embed(query))[0]
        try:
            retrieved_data = await self.vector_store.async_retrieve(
                self.n_chunks_to_retrieve,
                query_embedding,
            )
        except Exception as e:
            # Retrieval is best-effort: fall back to the unmodified prompt,
            # but surface the cause instead of discarding it.
            warn(f"No data was retrieved: {e!r}")
            retrieved_data = []
        return self._forward(state, retrieved_data)

    def _forward(
        self, state: ChatPrompt, retrieved_data: List[RetrievedChunk]
    ) -> ChatPrompt:
        """Build a new prompt with the retrieved chunks injected.

        The first message of the target type is rewritten with
        ``rag_template``; every other message is copied. ``state`` is
        returned unchanged when nothing was retrieved.

        Raises
        ------
        ValueError
            If the prompt has no message of the target type.
        """
        if len(retrieved_data) < 1:
            return state
        target_message_type = (
            SystemMessage if self.retrieved_data_location == "system" else UserMessage
        )
        # Loop-invariant: render the retrieved chunks once.
        documents = "\n".join([str(i.__dict__) for i in retrieved_data])
        modified = False
        messages = []
        for message in state.messages:
            if isinstance(message, target_message_type) and not modified:
                modified_message = target_message_type(
                    self.rag_template.format(
                        original_message=message.content,
                        documents=documents,
                    )
                )
                modified = True
            else:
                modified_message = message.__class__(message.content)
            messages.append(modified_message)
        if not modified:
            raise ValueError(
                f"Could not find a {target_message_type.__name__} message to modify"
            )
        return ChatPrompt(messages)

    def get_required_context_keys(self) -> List[str]:
        """This modifier reads nothing from the context."""
        return []
97 |
--------------------------------------------------------------------------------
/agent_dingo/rag/readers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/agent_dingo/rag/readers/__init__.py
--------------------------------------------------------------------------------
/agent_dingo/rag/readers/list.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document
2 | from typing import List, Optional
3 |
4 |
class ListReader(_BaseReader):
    """Reader that wraps in-memory strings as documents."""

    def read(self, inputs: List[str]) -> List[Document]:
        """Convert each string into a ``Document`` with source ``"memory"``.

        Parameters
        ----------
        inputs : List[str]
            texts to wrap; a single ``str`` is treated as a one-element list

        Returns
        -------
        List[Document]
            one document per input string, in input order

        Raises
        ------
        ValueError
            if any element is not a string
        """
        if isinstance(inputs, str):
            # Iterating a bare string would silently yield one document per
            # character.
            inputs = [inputs]
        docs = []
        for text in inputs:
            if not isinstance(text, str):
                raise ValueError("ListReader only accepts lists of strings")
            docs.append(Document(text, {"source": "memory"}))
        return docs
13 |
--------------------------------------------------------------------------------
/agent_dingo/rag/readers/pdf.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | try:
4 | from PyPDF2 import PdfReader
5 | except ImportError:
6 | raise ImportError(
7 | "PyPDF2 is not installed. Please install it using `pip install agent-dingo[rag_default]`"
8 | )
9 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document
10 |
11 |
class PDFReader(_BaseReader):
    """Reader that produces one ``Document`` per page of a PDF file."""

    def read(self, file_path: str) -> List[Document]:
        """Extract the text of every page of ``file_path``.

        Each document's metadata records the file path and the zero-based
        page index.
        """
        pdf = PdfReader(file_path)
        return [
            Document(page.extract_text(), {"source": file_path, "page": page_number})
            for page_number, page in enumerate(pdf.pages)
        ]
20 |
--------------------------------------------------------------------------------
/agent_dingo/rag/readers/web.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document
3 | try:
4 | import requests
5 | from bs4 import BeautifulSoup
6 | except ImportError:
7 | raise ImportError(
8 | "requests or BeautifulSoup4 are not installed. Please install it using `pip install agent-dingo[rag_default]`"
9 | )
10 |
11 |
class WebpageReader(_BaseReader):
    """Reader that downloads a web page and extracts its text."""

    def read(self, url: str, timeout: float = 30.0) -> List[Document]:
        """Fetch ``url`` and return its text as a single document.

        Parameters
        ----------
        url : str
            address of the page to fetch
        timeout : float, optional
            seconds to wait for the server (passed to ``requests.get``); pass
            None to wait indefinitely, by default 30.0

        Returns
        -------
        List[Document]
            a one-element list whose metadata records the source URL

        Raises
        ------
        ValueError
            if the response status code is not 200
        """
        docs = []
        # Without a timeout, requests can block forever on an unresponsive server.
        response = requests.get(url, timeout=timeout)
        if response.status_code == 200:
            soup = BeautifulSoup(response.content, "html.parser")
            # Join text nodes with spaces so adjacent elements do not run together.
            text = soup.get_text(" ")
            docs.append(Document(text, {"source": url}))
        else:
            raise ValueError(f"Error fetching {url}: {response.status_code}")
        return docs
23 |
--------------------------------------------------------------------------------
/agent_dingo/rag/readers/word.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.rag.base import BaseReader as _BaseReader, Document
2 | from typing import List
3 | try:
4 | import docx # python-docx
5 | except ImportError:
6 | raise ImportError(
7 | "python-docx is not installed. Please install it using `pip install agent-dingo[rag_default]`"
8 | )
9 |
10 |
class WordDocumentReader(_BaseReader):
    """Reader that produces one ``Document`` per paragraph of a .docx file."""

    def read(self, file_path: str) -> List[Document]:
        """Return the text of every paragraph, including empty ones.

        Each document's metadata records the file path and the zero-based
        paragraph index.
        """
        word_file = docx.Document(file_path)
        return [
            Document(paragraph.text, {"source": file_path, "paragraph": index})
            for index, paragraph in enumerate(word_file.paragraphs)
        ]
19 |
--------------------------------------------------------------------------------
/agent_dingo/rag/vector_stores/chromadb.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.rag.base import (
2 | BaseVectorStore as _BaseVectorStore,
3 | Chunk,
4 | RetrievedChunk,
5 | )
6 | from agent_dingo.utils import sha256_to_uuid
7 | from typing import Optional, List
8 |
9 | try:
10 | import chromadb
11 | except ImportError:
12 | raise ImportError(
13 | "Chromadb is not installed. Please install it using `pip install agent-dingo[chromadb]`"
14 | )
15 |
16 | import asyncio
17 | from concurrent.futures import ThreadPoolExecutor
18 |
19 |
class ChromaDB(_BaseVectorStore):
    def __init__(
        self,
        collection_name: str,
        path: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        recreate_collection: bool = False,
        upsert_batch_size: int = 32,
    ):
        """
        ChromaDB vector store.

        Parameters
        ----------
        collection_name : str
            name of the collection to store the vectors
        path : Optional[str], optional
            path to the local database, by default None
        host : Optional[str], optional
            host, by default None (chromadb's HttpClient default is used)
        port : Optional[int], optional
            port, by default None (chromadb's HttpClient default is used)
        recreate_collection : bool, optional
            flag to control whether the collection should be recreated on init, by default False
        upsert_batch_size : int, optional
            batch size for upserting the documents, by default 32

        Raises
        ------
        ValueError
            if ``path`` is combined with ``host`` or ``port``
        """
        if path is not None and (host is not None or port is not None):
            raise ValueError("Either path or host/port must be specified, not both")
        if path is not None:
            self.client = chromadb.PersistentClient(path=path)
        else:
            # Forward only the values the caller set: an explicit None would
            # override HttpClient's own defaults for host/port.
            http_kwargs = {
                key: value
                for key, value in {"host": host, "port": port}.items()
                if value is not None
            }
            self.client = chromadb.HttpClient(**http_kwargs)

        if recreate_collection:
            try:
                self.client.delete_collection(collection_name)
            except ValueError:
                # Collection did not exist yet.
                # NOTE(review): newer chromadb releases may raise a different
                # exception type here -- confirm against the pinned version.
                pass

        self.collection = self.client.get_or_create_collection(collection_name)

        self.upsert_batch_size = upsert_batch_size

        # Created lazily by _get_executor for async_retrieve.
        self._executor = None

    def upsert_chunks(self, chunks: List[Chunk]):
        """Upsert embedded chunks in batches of ``upsert_batch_size``.

        Ids are derived from ``chunk.hash``, so the same chunk always maps to
        the same id. Metadata comes from ``chunk.payload["document_metadata"]``.

        Raises
        ------
        ValueError
            if any chunk has no embedding; checked before any batch is written
        """
        # Validate up front (as the Qdrant store does) so a bad chunk cannot
        # leave the collection partially updated.
        if any(chunk.embedding is None for chunk in chunks):
            raise ValueError("Chunk must be embedded before upserting")
        for start in range(0, len(chunks), self.upsert_batch_size):
            batch = chunks[start : start + self.upsert_batch_size]
            self.collection.upsert(
                ids=[sha256_to_uuid(chunk.hash) for chunk in batch],
                documents=[chunk.content for chunk in batch],
                embeddings=[chunk.embedding for chunk in batch],
                metadatas=[chunk.payload["document_metadata"] for chunk in batch],
            )

    def retrieve(self, k: int, query: List[float]) -> List[RetrievedChunk]:
        """Return up to ``k`` chunks nearest to the ``query`` embedding.

        The score of each RetrievedChunk is the distance reported by Chroma
        (``"distances"``), not a similarity.
        """
        search_result = self.collection.query(
            query_embeddings=query,
            n_results=k,
            include=["metadatas", "documents", "distances"],
        )
        # Results are per query embedding; only one query is sent, so take index 0.
        return [
            RetrievedChunk(content, metadata, score)
            for content, metadata, score in zip(
                search_result["documents"][0],
                search_result["metadatas"][0],
                search_result["distances"][0],
            )
        ]

    async def async_retrieve(self, k: int, query: List[float]) -> List[RetrievedChunk]:
        """Run ``retrieve`` on a worker thread so the event loop is not blocked."""
        # Inside a coroutine the running loop is the correct one;
        # get_event_loop() is the deprecated way to obtain it.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._get_executor(), self.retrieve, k, query)

    def _get_executor(self) -> ThreadPoolExecutor:
        # Single-worker pool created on first async use.
        if self._executor is None:
            self._executor = ThreadPoolExecutor(max_workers=1)
        return self._executor
104 |
--------------------------------------------------------------------------------
/agent_dingo/rag/vector_stores/qdrant.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.rag.base import (
2 | BaseVectorStore as _BaseVectorStore,
3 | Chunk,
4 | RetrievedChunk,
5 | )
6 | from agent_dingo.utils import sha256_to_uuid
7 |
try:
    from qdrant_client import QdrantClient, AsyncQdrantClient
    from qdrant_client.http import models
except ImportError as e:
    # Chain the original error so the missing module name stays visible.
    raise ImportError(
        "Qdrant is not installed. Please install it using `pip install agent-dingo[qdrant]`"
    ) from e
15 | from typing import Optional, List
16 | from warnings import warn
17 |
18 |
class Qdrant(_BaseVectorStore):
    def __init__(
        self,
        collection_name: str,
        embedding_size: int,
        path: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[int] = None,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        recreate_collection: bool = False,
        upsert_batch_size: int = 32,
        try_init: bool = True,
    ):
        """
        Qdrant vector store.

        Parameters
        ----------
        collection_name : str
            collection name
        embedding_size : int
            size of the vector embeddings
        path : Optional[str], optional
            path to the local database, does not support concurrent clients, by default None
        host : Optional[str], optional
            host, by default None
        port : Optional[int], optional
            port, by default None
        url : Optional[str], optional
            base url of qdrant provider, by default None
        api_key : Optional[str], optional
            api key of qdrant provider, by default None
        recreate_collection : bool, optional
            flag to control whether the collection should be recreated on init, by default False
        upsert_batch_size : int, optional
            batch size for upserting the documents, by default 32
        try_init : bool, optional
            flag to control whether the collection should be created on object initialization, by default True

        Notes
        -----
        Clients are created lazily on first use. When none of ``path``,
        ``host``, ``port``, ``url`` or ``api_key`` is given, both the sync
        and the async client are created with ``":memory:"`` as independent
        objects.
        """
        candidate_params = {
            "host": host,
            "path": path,
            "port": port,
            "url": url,
            "api_key": api_key,
        }
        self.client_params = {
            k: v for k, v in candidate_params.items() if v is not None
        }
        self._client = None
        self._async_client = None
        self.collection_name = collection_name
        self.embedding_size = embedding_size
        self.recreate_collection = recreate_collection
        self.upsert_batch_size = upsert_batch_size
        # Local storage (on-disk path or in-memory) is owned by whichever
        # client opens it first, so the separately created async client
        # cannot share it.
        uses_local_storage = bool(path) or not self.client_params
        if uses_local_storage and (try_init or recreate_collection):
            warn(
                "Using local Qdrant storage will only work in a synchronous environment. Trying to call async methods will raise an error."
            )
        if try_init or recreate_collection:
            self._init_collection()

    def make_sync_client(self):
        """(Re)create the synchronous client from ``client_params``."""
        self._client = (
            QdrantClient(**self.client_params)
            if self.client_params
            else QdrantClient(":memory:")
        )

    def make_async_client(self):
        """(Re)create the asynchronous client from ``client_params``."""
        self._async_client = (
            AsyncQdrantClient(**self.client_params)
            if self.client_params
            else AsyncQdrantClient(":memory:")
        )

    @property
    def client(self):
        """Synchronous client, created on first access."""
        if self._client is None:
            self.make_sync_client()
        return self._client

    @property
    def async_client(self):
        """Asynchronous client, created on first access."""
        if self._async_client is None:
            self.make_async_client()
        return self._async_client

    def _init_collection(self):
        """Create the collection using cosine distance.

        An existing collection is kept as-is unless ``recreate_collection`` is
        set, in which case it is dropped first. The existence check replaces
        catching ``ValueError``, which only local mode raises for a duplicate
        collection; remote servers reported it differently, so re-running
        against a server crashed.
        """
        if self.client.collection_exists(self.collection_name):
            if not self.recreate_collection:
                return
            self.client.delete_collection(collection_name=self.collection_name)
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.embedding_size, distance=models.Distance.COSINE
            ),
        )

    def upsert_chunks(self, chunks: List[Chunk]):
        """Upsert embedded chunks in batches of ``upsert_batch_size``.

        Point ids are derived from ``chunk.hash``, so the same chunk always
        maps to the same point id.

        Raises
        ------
        ValueError
            if any chunk has no embedding; checked before any batch is written
        """
        if any(chunk.embedding is None for chunk in chunks):
            raise ValueError("Chunk must be embedded before upserting")
        for start in range(0, len(chunks), self.upsert_batch_size):
            points = [
                models.PointStruct(
                    vector=chunk.embedding,
                    payload=chunk.payload,
                    id=sha256_to_uuid(chunk.hash),
                )
                for chunk in chunks[start : start + self.upsert_batch_size]
            ]
            self.client.upsert(points=points, collection_name=self.collection_name)

    def retrieve(self, k: int, query: List[float]) -> List[RetrievedChunk]:
        """Return up to ``k`` chunks most similar to the ``query`` embedding."""
        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=k,
        )
        return self._process_search_results(search_result)

    async def async_retrieve(self, k: int, query: List[float]) -> List[RetrievedChunk]:
        """Async counterpart of ``retrieve`` using the async client."""
        search_result = await self.async_client.search(
            collection_name=self.collection_name,
            query_vector=query,
            limit=k,
        )
        return self._process_search_results(search_result)

    def _process_search_results(self, search_result: List) -> List[RetrievedChunk]:
        # Payload keys mirror those written by upsert_chunks via chunk.payload.
        return [
            RetrievedChunk(
                r.payload["content"], r.payload["document_metadata"], r.score
            )
            for r in search_result
        ]

    # Backward-compatible alias for the original (misspelled) helper name.
    _process_search_reults = _process_search_results
169 |
--------------------------------------------------------------------------------
/agent_dingo/serve.py:
--------------------------------------------------------------------------------
1 | from agent_dingo.core.state import State, Store, Context, ChatPrompt
2 | from agent_dingo.core.blocks import Pipeline
3 | from agent_dingo.core.message import UserMessage, SystemMessage, AssistantMessage
4 | from fastapi import FastAPI, HTTPException
5 | from pydantic import BaseModel
6 | import uvicorn
7 | from typing import List, Dict, Optional, Tuple, Union
8 | from uuid import uuid4
9 | import time
10 |
11 |
class Message(BaseModel):
    """One chat message of an OpenAI-style request or response.

    ``_construct_pipeline_input`` treats roles that start with ``context_``
    as context values instead of chat messages.
    """

    role: str
    content: str


class PipelineRunRequest(BaseModel):
    """Request body of ``POST /chat/completions``."""

    # name of the pipeline to run (key of the app's pipeline mapping)
    model: str
    messages: List[Message]


class Usage(BaseModel):
    """Token usage reported back to the client."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class Model(BaseModel):
    """One entry of the ``GET /models`` listing (one per served pipeline)."""

    id: str
    object: str = "model"
    # unix timestamp (seconds)
    created: int
    owned_by: str = "dingo"


class Models(BaseModel):
    """Response body of ``GET /models``.

    NOTE(review): OpenAI's list-models response puts the entries under
    ``data``; clients expecting that shape will not find ``models``. Confirm
    before changing, since renaming alters the public API.
    """

    models: List[Model]
    object: str = "list"


class Choice(BaseModel):
    """One completion choice of a chat completion response."""

    index: int
    message: Message
    logprobs: Optional[Dict] = None
    finish_reason: str = "stop"


class PipelineOutputResponse(BaseModel):
    """Response body of ``POST /chat/completions`` (OpenAI ``chat.completion`` shape)."""

    id: str
    object: str
    # unix timestamp (seconds)
    created: int
    model: str
    usage: Usage
    choices: List[Choice]
54 |
55 |
# Maps the OpenAI-style role of an incoming message to the dingo message class.
_role_to_message_type = dict(
    user=UserMessage,
    system=SystemMessage,
    assistant=AssistantMessage,
)
61 |
62 |
def _construct_response(
    output: str, usage: Usage, model: str
) -> PipelineOutputResponse:
    """Wrap pipeline output in an OpenAI ``chat.completion``-shaped response.

    The response gets a fresh random UUID as its id, the current unix time as
    its creation timestamp, and a single assistant choice that finished with
    ``"stop"``.
    """
    only_choice = Choice(
        index=0,
        message=Message(role="assistant", content=output),
        finish_reason="stop",
    )
    return PipelineOutputResponse(
        id=str(uuid4()),
        object="chat.completion",
        created=int(time.time()),
        usage=usage,
        model=model,
        choices=[only_choice],
    )
82 |
83 |
def _construct_pipeline_input(
    input_: List[Message],
) -> Tuple[ChatPrompt, Dict[str, str]]:
    """Split request messages into a chat prompt and pipeline context.

    Messages whose role starts with ``context_`` become context entries keyed
    by the rest of the role (e.g. ``context_lang`` -> ``lang``). All other
    messages are converted via ``_role_to_message_type``, keeping their order.

    Raises
    ------
    ValueError
        on an unsupported role, an empty context key, or a duplicated
        context key
    """
    context_prefix = "context_"
    messages = []
    context: Dict[str, str] = {}
    for m in input_:
        if m.role.startswith(context_prefix):
            key = m.role[len(context_prefix) :]
            if not key:
                raise ValueError("Context key must not be empty.")
            if key in context:
                raise ValueError(f"Context key {key} already exists.")
            context[key] = m.content
            continue
        message_type = _role_to_message_type.get(m.role)
        if message_type is None:
            # A clear error instead of a bare KeyError on the lookup.
            raise ValueError(f"Unsupported message role: {m.role}.")
        messages.append(message_type(m.content))
    return ChatPrompt(messages), context
100 |
101 |
def make_app(pipeline: Union[Pipeline, Dict[str, Pipeline]], is_async: bool = False):
    """Create a FastAPI app serving pipelines through an OpenAI-style API.

    Exposes ``POST /chat/completions`` (runs the pipeline named by the
    request's ``model``) and ``GET /models`` (lists the served pipelines).

    Parameters
    ----------
    pipeline : Union[Pipeline, Dict[str, Pipeline]]
        a single pipeline, served under the name ``"dingo"``, or a mapping
        from model name to pipeline
    is_async : bool, optional
        register an ``async`` endpoint that awaits ``Pipeline.async_run``
        instead of calling ``Pipeline.run``, by default False

    Returns
    -------
    FastAPI
        the configured application

    Raises
    ------
    ValueError
        if a value of the ``pipeline`` mapping is not a ``Pipeline``
    """
    app = FastAPI()
    created_at = int(time.time())
    if isinstance(pipeline, Pipeline):
        available_pipelines = {"dingo": pipeline}
    else:
        for k, v in pipeline.items():
            if not isinstance(v, Pipeline):
                raise ValueError(f"Pipeline {k} is not an instance of Pipeline.")
        available_pipelines = pipeline

    def _prepare_run(body: PipelineRunRequest):
        # Resolve the pipeline and convert the request. Client mistakes are
        # reported as 4xx responses instead of escaping as HTTP 500 errors.
        selected_pipeline = available_pipelines.get(body.model)
        if selected_pipeline is None:
            raise HTTPException(status_code=404, detail=f"Model {body.model} not found.")
        try:
            state, context = _construct_pipeline_input(body.messages)
        except KeyError as e:
            # Raised by the role -> message class lookup for an unknown role.
            raise HTTPException(
                status_code=400, detail=f"Unsupported message role: {e.args[0]}"
            ) from e
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        return selected_pipeline, state, context

    if is_async:

        @app.post("/chat/completions")
        async def run_pipeline(body: PipelineRunRequest) -> PipelineOutputResponse:
            selected_pipeline, state, context = _prepare_run(body)
            output, usage = await selected_pipeline.async_run(_state=state, **context)
            return _construct_response(output, Usage(**usage), model=body.model)

    else:

        @app.post("/chat/completions")
        def run_pipeline(body: PipelineRunRequest) -> PipelineOutputResponse:
            selected_pipeline, state, context = _prepare_run(body)
            output, usage = selected_pipeline.run(_state=state, **context)
            return _construct_response(output, Usage(**usage), model=body.model)

    @app.get("/models")
    async def get_models() -> Models:
        return Models(
            models=[Model(id=name, created=created_at) for name in available_pipelines]
        )

    return app
139 |
140 |
def serve_pipeline(
    pipeline: Union[Pipeline, Dict[str, Pipeline]],
    is_async: bool = False,
    host: str = "0.0.0.0",
    port: int = 8000,
):
    """Build the app with ``make_app`` and serve it with ``uvicorn.run``.

    Parameters
    ----------
    pipeline : Union[Pipeline, Dict[str, Pipeline]]
        forwarded to ``make_app``
    is_async : bool, optional
        forwarded to ``make_app``, by default False
    host : str, optional
        interface to bind; the default "0.0.0.0" listens on all interfaces
    port : int, optional
        TCP port, by default 8000
    """
    uvicorn.run(make_app(pipeline, is_async), host=host, port=port)
149 |
--------------------------------------------------------------------------------
/agent_dingo/utils.py:
--------------------------------------------------------------------------------
def sha256_to_uuid(sha256_hash: str) -> str:
    """Format the first 32 hex digits of a hex digest as a UUID string.

    The output uses the canonical 8-4-4-4-12 layout, is deterministic, and
    preserves the case of the input digits.

    Parameters
    ----------
    sha256_hash : str
        hexadecimal digest; only its first 32 characters are used

    Returns
    -------
    str
        UUID-formatted string

    Raises
    ------
    ValueError
        if fewer than 32 characters are available or they are not all
        hexadecimal digits
    """
    short_hash = sha256_hash[:32]
    # Without this check a short or non-hex input silently produces a
    # malformed id that may only be rejected later, far from the cause.
    if len(short_hash) != 32 or any(
        c not in "0123456789abcdefABCDEF" for c in short_hash
    ):
        raise ValueError(
            f"Expected at least 32 hexadecimal characters, got {sha256_hash!r}"
        )
    return "-".join(
        (
            short_hash[:8],
            short_hash[8:12],
            short_hash[12:16],
            short_hash[16:20],
            short_hash[20:32],
        )
    )
5 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | dependencies = [
7 | "openai>=1.25.0,<2.0.0",
8 | "docstring_parser>=0.15.0,<1.0.0",
9 | "tenacity>=8.2.0,<9.0.0",
10 | ]
11 | name = "agent_dingo"
12 | version = "1.0.0"
13 | authors = [
14 | { name="Oleh Kostromin", email="kostromin97@gmail.com" },
15 | { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" },
16 | ]
17 | description = "A microframework for creating simple AI agents."
18 | readme = "README.md"
19 | license = {text = "MIT"}
20 | requires-python = ">=3.9"
21 | classifiers = [
22 | "Programming Language :: Python :: 3",
23 | "License :: OSI Approved :: MIT License",
24 | "Operating System :: OS Independent",
25 | ]
26 |
27 | [project.optional-dependencies]
28 | server = ["fastapi>=0.105.0,<1.0.0", "uvicorn>=0.20.0,<1.0.0"]
29 | langchain = ["langchain>=0.1.0,<0.2.0"]
30 | qdrant = ["qdrant-client>=1.9.0,<2.0.0"]
31 | chromadb = ["chromadb>=0.5.0,<1.0.0"]
32 | sentence-transformers = ["sentence-transformers>=2.3.0,<3.0.0"]
33 | rag-default = ["PyPDF2>=3.0.0,<4.0.0", "beautifulsoup4>=4.12.0,<5.0.0", "requests>=2.26.0,<3.0.0", "python-docx>=1.0.0,<2.0.0"]
34 | vertexai = ["google-cloud-aiplatform>=1.40.0,<2.0.0"]
35 | litellm = ["litellm>=1.30.0,<2.0.0"]
36 | llama-cpp = ["llama-cpp-python>=0.2.20,<0.3.0"]
37 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fake_llm.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Callable
2 | from agent_dingo.core.blocks import BaseLLM
3 | from agent_dingo.core.state import ChatPrompt, UsageMeter
4 |
5 | import openai
6 | from tenacity import retry, stop_after_attempt, wait_fixed
7 |
8 |
class FakeLLM(BaseLLM):
    """Offline test double for an OpenAI-style LLM.

    Every call answers with the assistant message ``"Fake response"``. When a
    usage meter is supplied, 9 prompt tokens and 12 completion tokens are
    recorded.
    """

    def __init__(
        self,
        model: str = "123",
        temperature: float = 0.7,
        base_url: Optional[str] = None,
    ):
        self.model = model
        self.temperature = temperature
        # The fake methods below never call these clients. An explicit
        # placeholder key keeps construction from failing when
        # OPENAI_API_KEY is not set (e.g. on CI).
        placeholder_key = "fake-api-key"
        self.client = openai.OpenAI(base_url=base_url, api_key=placeholder_key)
        self.async_client = openai.AsyncOpenAI(
            base_url=base_url, api_key=placeholder_key
        )
        if base_url is None:
            self.supports_function_calls = True

    def send_message(
        self,
        messages,
        functions=None,
        usage_meter: Optional[UsageMeter] = None,
        temperature=None,
        **kwargs,
    ):
        """Return the canned assistant message as a dict with ``role`` and ``content``.

        All inputs except ``usage_meter`` are ignored.
        """
        full_res = {
            "id": "chatcmpl-123",
            "object": "chat.completion",
            "created": 1677652288,
            "model": "gpt-3.5-turbo-0613",
            "system_fingerprint": "fp_44709d6fcb",
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": "Fake response",
                    },
                    "logprobs": None,
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 9,
                "completion_tokens": 12,
                "total_tokens": 21,
            },
        }
        if usage_meter:
            usage_meter.increment(
                prompt_tokens=full_res["usage"]["prompt_tokens"],
                completion_tokens=full_res["usage"]["completion_tokens"],
            )
        # Return the message dict rather than the bare text: callers index it
        # (process_prompt reads ["content"]), and calling .keys() on a str
        # raised AttributeError.
        res = full_res["choices"][0]["message"]
        res.pop("function_call", None)
        return res

    async def async_send_message(
        self,
        messages,
        functions=None,
        usage_meter: Optional[UsageMeter] = None,
        temperature=None,
        **kwargs,
    ):
        """Async wrapper around ``send_message``."""
        return self.send_message(
            messages, functions, usage_meter, temperature, **kwargs
        )

    def process_prompt(
        self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs
    ):
        """Return the text content of the canned response."""
        return self.send_message(prompt.dict, None, usage_meter)["content"]

    async def async_process_prompt(
        self, prompt: ChatPrompt, usage_meter: Optional[UsageMeter] = None, **kwargs
    ):
        """Async counterpart of ``process_prompt``."""
        return (await self.async_send_message(prompt.dict, None, usage_meter))[
            "content"
        ]
89 |
--------------------------------------------------------------------------------
/tests/test_agent/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/tests/test_agent/__init__.py
--------------------------------------------------------------------------------
/tests/test_agent/test_agent.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch
3 | from agent_dingo.agent import Agent
4 | from agent_dingo.agent.function_descriptor import FunctionDescriptor
5 | from tests.fake_llm import FakeLLM
6 |
7 |
class TestAgentDingo(unittest.TestCase):
    """Function-registration tests for ``Agent`` backed by ``FakeLLM``."""

    def setUp(self):
        self.agent = Agent(FakeLLM())

    def _registered(self):
        # Name-mangled private storage of the agent's registry.
        return self.agent._registry._Registry__functions

    def test_register_function(self):
        def func(arg: str):
            """_summary_

            Parameters
            ----------
            arg : str
                _description_
            """

        self.agent.register_function(func)
        self.assertEqual(len(self._registered()), 1)

    def test_register_descriptor(self):
        descriptor = FunctionDescriptor(
            name="function_from_descriptor",
            func=lambda arg: None,
            json_repr={},
            requires_context=False,
        )
        self.agent.register_descriptor(descriptor)
        registered = self._registered()
        self.assertEqual(len(registered), 1)
        self.assertIn("function_from_descriptor", registered.keys())

    def test_function_decorator(self):
        @self.agent.function
        def func(arg: str):
            """_summary_

            Parameters
            ----------
            arg : str
                _description_
            """

        self.assertEqual(len(self._registered()), 1)


if __name__ == "__main__":
    unittest.main()
57 |
--------------------------------------------------------------------------------
/tests/test_agent/test_extract_substr.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.docgen import extract_substr
3 |
4 |
class TestExtractSubstr(unittest.TestCase):
    """``extract_substr`` keeps the description and Args, dropping Returns."""

    def test_extract_substr(self):
        docstring = """Extracts the desription and args from a docstring.

        Args:
            input_string (str): The docstring to extract the description and args from.

        Returns:
            str: Reduced docstring containing only the description and the args.
        """
        expected = """Extracts the desription and args from a docstring.

        Args:
            input_string (str): The docstring to extract the description and args from."""
        self.assertEqual(extract_substr(docstring), expected)
20 |
--------------------------------------------------------------------------------
/tests/test_agent/test_get_required_args.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.helpers import get_required_args
3 | from agent_dingo.agent.chat_context import ChatContext
4 |
5 |
class TestGetRequiredArgs(unittest.TestCase):
    """``get_required_args`` lists parameters without defaults and skips a
    parameter named ``chat_context`` annotated as ``ChatContext``."""

    def _assert_required(self, func, expected):
        # Shared assertion so each test only states a signature and its result.
        self.assertEqual(get_required_args(func), expected)

    def test_no_args(self):
        def func():
            ...

        self._assert_required(func, [])

    def test_required_args(self):
        def func(a, b, c):
            ...

        self._assert_required(func, ["a", "b", "c"])

    def test_optional_args(self):
        def func(a, b, c=None):
            ...

        self._assert_required(func, ["a", "b"])

    def test_mixed_args(self):
        def func(a, b, c=None, d=None):
            ...

        self._assert_required(func, ["a", "b"])

    def test_with_chat_context(self):
        def func(a, b, chat_context: ChatContext):
            ...

        self._assert_required(func, ["a", "b"])

    def test_wrong_chat_context_type(self):
        def func(a, b, chat_context: str):
            ...

        self._assert_required(func, ["a", "b", "chat_context"])

    def test_wrong_chat_context_name(self):
        def func(a, b, chat_context_: ChatContext):
            ...

        self._assert_required(func, ["a", "b", "chat_context_"])


if __name__ == "__main__":
    unittest.main()
52 |
--------------------------------------------------------------------------------
/tests/test_agent/test_parser.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.parser import parse
3 |
4 |
class TestParser(unittest.TestCase):
    """Checks the schema ``parse`` builds from Google and NumPy docstrings."""

    @staticmethod
    def _prop(type_name, description, **extra):
        # One expected "properties" entry: JSON type, description, and any
        # extra keys such as ``enum``.
        return {"type": type_name, "description": description, **extra}

    @staticmethod
    def _expected(properties, requires_context=False):
        # Every fixture shares this description; the second tuple item is
        # True only in the chat_context case below.
        return (
            {"description": "Parses a docstring.", "properties": properties},
            requires_context,
        )

    def test_parse_google(self):
        docstring = """Parses a docstring.

        Args:
            docstring (str): The docstring to parse.

        Returns:
            dict: A dictionary containing the description and the arguments of the function.
        """
        expected = self._expected(
            {"docstring": self._prop("string", "The docstring to parse.")}
        )
        self.assertEqual(parse(docstring), expected)

    def test_parse_numpy(self):
        docstring = """Parses a docstring.

        Parameters
        ----------
        docstring : str
            The docstring to parse.

        Returns
        -------
        dict
            A dictionary containing the description and the arguments of the function.
        """
        expected = self._expected(
            {"docstring": self._prop("string", "The docstring to parse.")}
        )
        self.assertEqual(parse(docstring), expected)

    def test_parse_with_enum(self):
        docstring = """Parses a docstring.

        Parameters
        ----------
        arg1 : str
            The first argument.
        arg2 : int
            The second argument.
        arg3 : float
            The third argument.
        arg4 : bool
            The fourth argument.
        arg5 : list
            The fifth argument.
        arg6 : dict
            The sixth argument.
        arg7 : str
            The seventh argument. Enum: ['value1', 'value2', 'value3']

        Returns
        -------
        dict
            A dictionary containing the description and the arguments of the function.
        """
        expected = self._expected(
            {
                "arg1": self._prop("string", "The first argument."),
                "arg2": self._prop("integer", "The second argument."),
                "arg3": self._prop("number", "The third argument."),
                "arg4": self._prop("boolean", "The fourth argument."),
                "arg5": self._prop("array", "The fifth argument."),
                "arg6": self._prop("object", "The sixth argument."),
                "arg7": self._prop(
                    "string",
                    "The seventh argument.",
                    enum=["value1", "value2", "value3"],
                ),
            }
        )
        self.assertEqual(parse(docstring), expected)

    def test_parse_with_context(self):
        docstring = """Parses a docstring.

        Parameters
        ----------
        arg1 : str
            The first argument.
        chat_context : ChatContext
            The chat context.

        Returns
        -------
        dict
            A dictionary containing the description and the arguments of the function.
        """
        expected = self._expected(
            {"arg1": self._prop("string", "The first argument.")},
            requires_context=True,
        )
        self.assertEqual(parse(docstring), expected)

    def test_parse_without_description(self):
        docstring = """Parameters
        ----------
        arg1 : str
            The first argument.

        Returns
        -------
        dict
            A dictionary containing the description and the arguments of the function.
        """
        with self.assertRaises(ValueError):
            parse(docstring)


if __name__ == "__main__":
    unittest.main()
145 |
--------------------------------------------------------------------------------
/tests/test_agent/test_registry.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.agent.registry import Registry as _Registry
3 |
4 |
class TestRegistry(unittest.TestCase):
    """Unit tests for the function ``Registry``."""

    def setUp(self):
        self.registry = _Registry()

    def test_add_function(self):
        def func():
            pass

        self.registry.add("func", func, {"name": "func"}, False)
        self.assertEqual(len(self.registry._Registry__functions), 1)

    def test_get_function(self):
        def func():
            pass

        self.registry.add("func", func, {"name": "func"}, False)
        retrieved, requires_context = self.registry.get_function("func")
        self.assertTrue(callable(retrieved))
        self.assertEqual(requires_context, False)

    def test_get_available_functions(self):
        def func1():
            pass

        def func2():
            pass

        self.registry.add("func1", func1, {"name": "func1"}, False)
        self.registry.add(
            "func2", func2, {"name": "func2"}, True, required_context_keys=["any"]
        )
        # Registration order is expected to be preserved.
        names = [entry["name"] for entry in self.registry.get_available_functions()]
        self.assertEqual(names, ["func1", "func2"])


if __name__ == "__main__":
    unittest.main()
48 |
--------------------------------------------------------------------------------
/tests/test_core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/agent_dingo/087a26a5096df5cb663fc20bd58c89ae89aadcb0/tests/test_core/__init__.py
--------------------------------------------------------------------------------
/tests/test_core/test_blocks.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.core.blocks import (
3 | Squash,
4 | PromptBuilder,
5 | Pipeline,
6 | Parallel,
7 | Identity,
8 | SaveState,
9 | LoadState,
10 | InlineBlock,
11 | )
12 | from agent_dingo.core.state import State, ChatPrompt, KVData, Context, Store, UsageMeter
13 | from agent_dingo.core.message import Message
14 |
15 |
class TestBlocks(unittest.TestCase):
    """Tests for the pipeline building blocks in ``agent_dingo.core.blocks``."""

    @staticmethod
    def _blank_env():
        """Return a freshly constructed ``(Context, Store)`` pair."""
        return Context(), Store()

    def test_squash(self):
        squash = Squash("{0} {1}")
        data = KVData(_out_0="Hello", _out_1="World")
        ctx, st = self._blank_env()
        squashed = squash.forward(data, ctx, st)
        self.assertEqual(squashed["_out_0"], "Hello World")

    def test_prompt_builder(self):
        builder = PromptBuilder([Message("Hello {name}")], from_state=["_out_0"])
        data = KVData(_out_0="World")
        ctx = Context(name="World")
        st = Store()
        prompt = builder.forward(data, ctx, st)
        self.assertEqual(prompt.messages[0].content, "Hello World")

    def test_pipeline(self):
        pipeline = Pipeline()
        pipeline.add_block(Squash("{0} {1}"))
        data = KVData(_out_0="Hello", _out_1="World")
        ctx, st = self._blank_env()
        forwarded = pipeline.forward(data, ctx, st)
        self.assertEqual(forwarded["_out_0"], "Hello World")
        # ``run`` is checked after ``forward``, against the same input state.
        self.assertEqual(pipeline.run(data)[0], "Hello World")

    def test_pipeline_build(self):
        first, second = Identity(), Identity()
        pipeline = first >> second
        self.assertIsInstance(pipeline, Pipeline)
        self.assertEqual(len(pipeline._blocks), 2)
        for position, block in enumerate((first, second)):
            self.assertIs(pipeline._blocks[position], block)

    def test_pipeline_build_with_parallel(self):
        left, right = Identity(), Identity()
        squash = Squash("{0} {1}")
        # ``&`` groups the two identities; ``>>`` chains the group into squash.
        pipeline = (left & right) >> squash
        self.assertIsInstance(pipeline, Pipeline)
        self.assertEqual(len(pipeline._blocks), 2)
        self.assertIsInstance(pipeline._blocks[0], Parallel)
        self.assertIs(pipeline._blocks[1], squash)

    def test_identity(self):
        identity = Identity()
        data = KVData(_out_0="Hello")
        ctx, st = self._blank_env()
        self.assertIs(identity.forward(data, ctx, st), data)

    def test_save_state(self):
        saver = SaveState("key")
        data = KVData(_out_0="Hello")
        ctx, st = self._blank_env()
        passed_on = saver.forward(data, ctx, st)
        self.assertEqual(passed_on["_out_0"], "Hello")
        # The very same state object (not a copy) is expected under "key".
        self.assertIs(st.get_data("key"), data)

    def test_load_state(self):
        loader = LoadState("data", "key")
        saved = KVData(_out_0="Hello")
        ctx, st = self._blank_env()
        st.update("key", saved)
        self.assertIs(loader.forward(saved, ctx, st), saved)

    def test_inline_block(self):
        block = InlineBlock()
        canned = KVData(_out_0="Hello")

        @block
        def func(state, context, store):
            # Ignore the incoming state and always hand back ``canned``.
            return canned

        data = KVData(_out_0="World")
        ctx, st = self._blank_env()
        self.assertIs(func.forward(data, ctx, st), canned)


if __name__ == "__main__":
    unittest.main()
100 |
--------------------------------------------------------------------------------
/tests/test_core/test_message.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.core.message import (
3 | Message,
4 | UserMessage,
5 | SystemMessage,
6 | AssistantMessage,
7 | )
8 |
9 |
class TestMessage(unittest.TestCase):
    """Tests for ``Message`` and its role-specific subclasses."""

    def _check(self, msg, role):
        """Assert that ``msg`` holds "Hello" under ``role``, both as
        attributes and in its ``dict`` form."""
        self.assertEqual(msg.content, "Hello")
        self.assertEqual(msg.role, role)
        self.assertEqual(msg.dict, {"role": role, "content": "Hello"})

    def test_message(self):
        # The base class is expected to report the placeholder role "undefined".
        self._check(Message("Hello"), "undefined")

    def test_user_message(self):
        self._check(UserMessage("Hello"), "user")

    def test_system_message(self):
        self._check(SystemMessage("Hello"), "system")

    def test_assistant_message(self):
        self._check(AssistantMessage("Hello"), "assistant")


if __name__ == "__main__":
    unittest.main()
38 |
--------------------------------------------------------------------------------
/tests/test_core/test_output_parser.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.core.output_parser import BaseOutputParser, DefaultOutputParser
3 | from agent_dingo.core.state import State, ChatPrompt, KVData
4 | from agent_dingo.core.message import Message
5 |
6 |
class TestOutputParser(unittest.TestCase):
    """Tests for ``BaseOutputParser`` dispatch and ``DefaultOutputParser``."""

    def test_base_output_parser(self):
        # Minimal concrete parser whose hooks return a marker naming the hook
        # that ran, so the dispatch performed by ``parse`` can be observed.
        # (Named distinctly so it does not shadow the enclosing TestCase.)
        class _MarkerParser(BaseOutputParser):
            def _parse_chat(self, output: ChatPrompt) -> str:
                return "chat"

            def _parse_kvdata(self, output: KVData) -> str:
                return "kvdata"

        parser = _MarkerParser()
        self.assertEqual(parser.parse(KVData(a="1")), "kvdata")
        self.assertEqual(parser.parse(ChatPrompt([Message("Hello")])), "chat")
        # A plain string is expected to be rejected with TypeError.
        with self.assertRaises(TypeError):
            parser.parse("invalid")

    def test_default_output_parser(self):
        parser = DefaultOutputParser()
        chat = ChatPrompt([Message("Hello")])
        with_output = KVData(_out_0="output")
        without_output = KVData(a="1")
        # Chat prompts are expected to be refused with RuntimeError.
        with self.assertRaises(RuntimeError):
            parser.parse(chat)
        # KVData is expected to yield its "_out_0" value ...
        self.assertEqual(parser.parse(with_output), "output")
        # ... and to raise KeyError when that key is missing.
        with self.assertRaises(KeyError):
            parser.parse(without_output)


if __name__ == "__main__":
    unittest.main()
33 |
--------------------------------------------------------------------------------
/tests/test_core/test_state.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from agent_dingo.core.state import ChatPrompt, KVData, Context, UsageMeter, Store
3 | from agent_dingo.core.message import Message
4 |
5 |
class TestState(unittest.TestCase):
    """Tests for the state containers in ``agent_dingo.core.state``."""

    def test_chat_prompt(self):
        prompt = ChatPrompt([Message("Hello")])
        expected = [{"role": "undefined", "content": "Hello"}]
        self.assertEqual(prompt.dict, expected)

    def test_kvdata(self):
        data = KVData(a="1", b="2")
        for key, value in (("a", "1"), ("b", "2")):
            self.assertEqual(data[key], value)
        self.assertEqual(data.dict, {"a": "1", "b": "2"})
        # Updating a key that is already present is expected to raise KeyError.
        with self.assertRaises(KeyError):
            data.update("a", "3")

    def test_context(self):
        ctx = Context(a="1", b="2")
        # Context is expected to refuse updates with RuntimeError.
        with self.assertRaises(RuntimeError):
            ctx.update("a", "3")

    def test_usage_meter(self):
        meter = UsageMeter()
        meter.increment(10, 20)
        expected = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
        self.assertEqual(meter.get_usage(), expected)

    def test_store(self):
        store = Store()
        store.update("data", KVData(a="1"))
        store.update("prompt", ChatPrompt([Message("Hello")]))
        store.update("misc", "misc")
        # Each entry is read back through the getter matching its kind.
        self.assertEqual(store.get_data("data").dict, {"a": "1"})
        self.assertEqual(
            store.get_prompt("prompt").dict,
            [{"role": "undefined", "content": "Hello"}],
        )
        self.assertEqual(store.get_misc("misc"), "misc")


if __name__ == "__main__":
    unittest.main()
46 |
--------------------------------------------------------------------------------