├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── DagentLogo.png
├── LICENSE
├── README.md
├── dagent
│   ├── .gitignore
│   ├── .python-version
│   ├── README.md
│   ├── examples
│   │   ├── add_two_nums.json
│   │   ├── multiply_two_nums.json
│   │   ├── quickstart_local_simple_agent.py
│   │   ├── quickstart_simple_agent.py
│   │   └── sql_agent_local.py
│   ├── pyproject.toml
│   ├── requirements-dev.lock
│   ├── requirements.lock
│   ├── src
│   │   └── dagent
│   │       ├── DagNode.py
│   │       ├── DecisionNode.py
│   │       ├── FunctionNode.py
│   │       ├── __init__.py
│   │       └── base_functions.py
│   └── tests
│       ├── __init__.py
│       ├── test_DecisionNode.py
│       ├── test_FunctionNode.py
│       ├── test_base_functions.py
│       └── test_linking.py
└── notes.md
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | permissions:
16 | contents: read
17 |
18 | jobs:
19 | deploy:
20 |
21 | runs-on: ubuntu-latest
22 |
23 | steps:
24 | - uses: actions/checkout@v4
25 | - name: Set up Python
26 | uses: actions/setup-python@v3
27 | with:
28 | python-version: '3.x'
29 | - name: Install dependencies
30 | run: |
31 | cd dagent
32 | python -m pip install --upgrade pip
33 | pip install build
34 | - name: Build package
35 |       run: python -m build --outdir dist dagent
36 | - name: Publish package
37 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
38 | with:
39 | user: __token__
40 | password: ${{ secrets.PYPI_API_TOKEN }}
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | Tool_JSON/
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | cover/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | .pybuilder/
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | # For a library or package, you might want to ignore these files since the code is
88 | # intended to run in multiple environments; otherwise, check them in:
89 | # .python-version
90 |
91 | # pipenv
92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
95 | # install all needed dependencies.
96 | #Pipfile.lock
97 |
98 | # poetry
99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100 | # This is especially recommended for binary packages to ensure reproducibility, and is more
101 | # commonly ignored for libraries.
102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103 | #poetry.lock
104 |
105 | # pdm
106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107 | #pdm.lock
108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109 | # in version control.
110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
111 | .pdm.toml
112 | .pdm-python
113 | .pdm-build/
114 |
115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116 | __pypackages__/
117 |
118 | # Celery stuff
119 | celerybeat-schedule
120 | celerybeat.pid
121 |
122 | # SageMath parsed files
123 | *.sage.py
124 |
125 | # Environments
126 | .env
127 | .venv
128 | env/
129 | venv/
130 | ENV/
131 | env.bak/
132 | venv.bak/
133 |
134 | # Spyder project settings
135 | .spyderproject
136 | .spyproject
137 |
138 | # Rope project settings
139 | .ropeproject
140 |
141 | # mkdocs documentation
142 | /site
143 |
144 | # mypy
145 | .mypy_cache/
146 | .dmypy.json
147 | dmypy.json
148 |
149 | # Pyre type checker
150 | .pyre/
151 |
152 | # pytype static type analyzer
153 | .pytype/
154 |
155 | # Cython debug symbols
156 | cython_debug/
157 |
158 | # PyCharm
159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161 | # and can be added to the global gitignore or merged into this file. For a more nuclear
162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163 | #.idea/
164 |
--------------------------------------------------------------------------------
/DagentLogo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Extensible-AI/DAGent/8dd76ffe7c1bdd57ebfcf67a859f25d196632233/DagentLogo.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Parth Sareen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DAGent - Directed Acyclic Graphs (DAGs) as AI Agents
2 |
3 |
4 |
5 |
6 |
7 | *DAGent is an opinionated Python library to create AI Agents quickly without overhead*
8 |
9 |
10 | [![Downloads](https://static.pepy.tech/badge/dagent)](https://pepy.tech/project/dagent)
11 |
12 | ## Quickstart
13 |
14 | ### Installing
15 | - `pip install dagent` or `rye add dagent`
16 | - Make sure the API key for your provider of choice is available in your environment. The default is `OPENAI_API_KEY`
17 |
18 |
19 | ### Get right into it
20 | See [dagent/examples/quickstart_simple_agent.py](dagent/examples/quickstart_simple_agent.py) for a quickstart example
21 |
22 |
23 | ### DAGent basics
24 |
25 | The idea behind dagent is to structure AI agents as a workflow. This is done by setting up each function as a node in a graph.
26 |
27 | The agentic behavior comes from using an LLM to infer which function to run, which is abstracted by a "Decision Node".
28 |
29 | `Tool`
30 | - A tool is just a function which the LLM can use.
31 | - Docstrings and type annotations help the LLM infer what the function does. This is especially recommended for larger functions/tools.
32 |
33 | `FunctionNode`
34 | - Runs a Python function
35 | - Can be attached to a `DecisionNode` to be treated as a tool and allow an LLM to choose which function to run
36 |
37 | `DecisionNode`
38 | - This is where the LLM picks a function to run from the given options
39 | - The `.compile()` method autogenerates and saves tool descriptions for the available tools. Run it with the param `force_load=True` if there are errors or if one of the tool functions changes
40 | - These tool/function descriptions get generated under a `Tool_JSON` folder. Feel free to edit the tool descriptions if the agent is unreliable.
41 |
42 | `prev_output` param for functions:
43 | - If you're passing data from one function to another, make sure this param is in the receiving function's signature.
44 | - If extra params get passed in or something unexpected happens, add `**kwargs` to the function to catch any hidden params that were passed and need to be handled (see the sketch below).
45 |
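For example, a minimal sketch of a node function that accepts the prior node's output and tolerates extra params (the function name and body here are illustrative, not part of the library):

```python
from dagent import FunctionNode

def report_result(prev_output: int, **kwargs) -> str:
    """Receive the previous node's return value via `prev_output`.

    Any unexpected params end up in `kwargs` instead of raising a TypeError,
    which makes it easier to see what was actually passed in.
    """
    if kwargs:
        print("Unhandled params:", kwargs)
    return f"The result was {prev_output}"

report_result_node = FunctionNode(func=report_result)
```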
46 |
47 | ### DAGent Diagram
48 | ```mermaid
49 | graph TD
50 | A[Function Node] --> B[Decision Node]
51 | B --> C[Function Node]
52 | B --> E[Function Node]
53 | D --> F[Function Node]
54 | E --> G[Decision Node]
55 | F --> H[Function Node]
56 | G --> K[Function Node]
57 | G -- "Pick Function to Run" --> I[Function Node]
58 | G --> J[Function Node]
59 | I --> L[Function Node]
60 | J --> M[Function Node]
61 | K --> N[Function Node]
62 | K -- "Run Both " --> S[Function Node]
63 |
64 | %% Additional annotations
65 | B -- "Use a Function as a tool" --> D[Function Node]
66 |
67 | ```
68 |
69 | ## Using Different Models
70 |
71 | DAGent supports using different LLM models for inference and tool description generation. You can specify the model when calling `call_llm` or `call_llm_tool`, or when constructing the `DecisionNode`.
72 |
73 | For example, to use the `groq/llama3-70b-8192` model:
74 |
75 | ```python
76 |
77 | # Using groq with decision node
78 | decision_node1 = DecisionNode(model='groq/llama3-70b-8192')
79 |
80 | # Using ollama with decision node
81 | decision_node2 = DecisionNode(model='ollama_chat/llama3.1', api_base="http://localhost:11434")
82 |
83 | # Call the LLM directly with `call_llm`
84 | output = call_llm(model='ollama_chat/llama3.1', api_base="http://localhost:11434", messages=[{'role': 'user', 'content': 'add the numbers 2 and 3'}])
85 |
86 | ```
87 |
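`call_llm_tool` works the same way but also takes a `tools` list of tool-description dicts (like the ones `compile()` writes under `Tool_JSON`) and returns the model's message, from which tool calls can be read. A minimal sketch, assuming a tool description has already been generated (the file path is illustrative):

```python
import json

from dagent import call_llm_tool

# Load a tool description previously generated by `.compile()`
with open('Tool_JSON/add_two_nums.json') as f:
    add_tool = json.load(f)

message = call_llm_tool(
    model='groq/llama3-70b-8192',
    messages=[{'role': 'user', 'content': 'add the numbers 2 and 3 using the provided tool'}],
    tools=[add_tool],
)

# Inspect which function the model chose and with what arguments
for tool_call in message.tool_calls or []:
    print(tool_call.function.name, tool_call.function.arguments)
```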
88 | ### Other things to know
89 |
90 | - `prev_output` is needed in the function signature if you want to use the prior function's return value. The prior function must have returned something for this to work
91 | - If there are errors with too many params being passed into a function node, add `**kwargs` to your function
92 | - Args can be overridden at any time using the following (the kwargs are merged in the background, with priority given to the user):
93 |
94 | ```python
95 | add_two_nums_node.user_params = {
96 |     # 'param_name': value
97 |     'a': 10
98 | }
99 | ```
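`user_params` can also be supplied up front when constructing a node, as the SQL example does; a minimal sketch reusing the quickstart names:

```python
add_two_nums_node = FunctionNode(func=add_two_nums, user_params={'a': 10})
```

As with the snippet above, these values are merged with whatever gets passed in at run time.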
100 |
101 | ## Feedback
102 |
103 | If you are interested in providing feedback, please use [this form](https://docs.google.com/forms/d/14EHPUEYGVV-eNj6HyUaQYokHUtouYOeWfriwHaHOARE).
104 |
105 | ## Acknowledgements
106 | Shoutout to:
107 | - [@omkizzy](https://x.com/omkizzy)
108 | - [@kaelan](https://github.com/Oasixer)
109 |
--------------------------------------------------------------------------------
/dagent/.gitignore:
--------------------------------------------------------------------------------
1 | # python generated files
2 | __pycache__/
3 | *.py[oc]
4 | build/
5 | dist/
6 | wheels/
7 | *.egg-info
8 | Tool_JSON/
9 | # venv
10 | .venv
11 |
--------------------------------------------------------------------------------
/dagent/.python-version:
--------------------------------------------------------------------------------
1 | 3.12.2
2 |
--------------------------------------------------------------------------------
/dagent/README.md:
--------------------------------------------------------------------------------
1 | # dagent
--------------------------------------------------------------------------------
/dagent/examples/add_two_nums.json:
--------------------------------------------------------------------------------
1 | {"type": "function", "function": {"name": "add_two_nums", "description": "Add two integer numbers", "parameters": {"type": "object", "properties": {"a": {"type": "integer", "description": "The first number to add"}, "b": {"type": "integer", "description": "The second number to add"}}, "required": ["a", "b"]}}}
--------------------------------------------------------------------------------
/dagent/examples/multiply_two_nums.json:
--------------------------------------------------------------------------------
1 | {"type": "function", "function": {"name": "multiply_two_nums", "description": "Multiply two integers", "parameters": {"type": "object", "properties": {"a": {"type": "integer", "description": "The first integer to multiply"}, "b": {"type": "integer", "description": "The second integer to multiply"}}, "required": ["a", "b"]}}}
--------------------------------------------------------------------------------
/dagent/examples/quickstart_local_simple_agent.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | This example demonstrates the main concepts of the dagent library using a local model:
4 | 1. Function Nodes: Represent individual operations in the workflow.
5 | 2. Decision Nodes: Use AI models to make decisions and route the workflow.
6 | 3. Node Linking: Connect nodes to create a directed acyclic graph (DAG).
7 | 4. Compilation: Prepare the DAG for execution.
8 | 5. Execution: Run the workflow starting from an entry point.
9 | """
10 |
11 | from dagent import DecisionNode, FunctionNode
13 |
14 | def add_two_nums(a: int, b: int) -> int:
15 | """A simple function to add two numbers."""
16 | return a + b
17 |
18 | def multiply_two_nums(a: int, b: int) -> int:
19 | """A simple function to multiply two numbers."""
20 | return a * b
21 |
22 | def print_result(prev_output: int) -> int:
23 | """
24 | Print the result from a previous node.
25 |
26 | Note: `prev_output` is automatically passed from the previous node.
27 | """
28 | print(prev_output)
29 | return prev_output
30 |
31 | def entry_func(input: str) -> str:
32 | """Entry point function for the workflow."""
33 | return input
34 |
35 | def main():
36 |
37 | # Initialize Function Nodes for basic arithmetic operations and result printing
38 | add_two_nums_node = FunctionNode(func=add_two_nums)
39 | multiply_two_nums_node = FunctionNode(func=multiply_two_nums)
40 | print_result_node = FunctionNode(func=print_result)
41 |
42 | # Initialize the entry point of the workflow
43 | entry_node = FunctionNode(func=entry_func)
44 |
45 | # Initialize a Decision Node configured to use a local AI model for decision-making
46 | decision_node = DecisionNode(model='ollama_chat/llama3.1', api_base="http://localhost:11434")
47 |
48 | # Link Nodes to define the workflow structure
49 | entry_node.next_nodes = [decision_node]
50 | decision_node.next_nodes = [add_two_nums_node, multiply_two_nums_node]
51 | add_two_nums_node.next_nodes = [print_result_node]
52 | multiply_two_nums_node.next_nodes = [print_result_node]
53 |
54 | # Compile the DAG to prepare it for execution
55 | entry_node.compile(force_load=False)
56 |
57 | # Execute the DAG in a loop to process user input dynamically
58 | while True:
59 | user_input = input("Enter your command: ")
60 | entry_node.run(input=user_input)
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
--------------------------------------------------------------------------------
/dagent/examples/quickstart_simple_agent.py:
--------------------------------------------------------------------------------
1 |
2 | from dagent import DecisionNode, FunctionNode
3 | import logging
4 |
5 | """
6 | This example demonstrates the main concepts of the dagent library:
7 | 1. Function Nodes: Represent individual operations in the workflow.
8 | 2. Decision Nodes: Use AI models to make decisions and route the workflow.
9 | 3. Node Linking: Connect nodes to create a directed acyclic graph (DAG).
10 | 4. Compilation: Prepare the DAG for execution.
11 | 5. Execution: Run the workflow starting from an entry point.
12 | """
13 |
14 |
15 | # Can enable logging below to save logs to file
16 | # logging.basicConfig(level=logging.INFO,
17 | # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
18 | # handlers=[logging.FileHandler('dagent_logs.log'), logging.StreamHandler()])
19 |
20 | def add_two_nums(a: int, b: int) -> int:
21 | """A simple function to add two numbers."""
22 | return a + b
23 |
24 |
25 | def multiply_two_nums(a: int, b: int) -> int:
26 | """A simple function to multiply two numbers."""
27 | return a * b
28 |
29 |
30 | def print_result(prev_output: int) -> int:
31 | """
32 | Print the result from a previous node.
33 |
34 | Note: `prev_output` is automatically passed from the previous node.
35 | """
36 | print(prev_output)
37 | return prev_output
38 |
39 |
40 | def entry_func(input: str) -> str:
41 | """Entry point function for the workflow."""
42 | return input
43 |
44 |
45 | def main():
46 | # Setup Function Nodes
47 | """
48 | FunctionNodes wrap regular Python functions, allowing them to be used in the DAG.
49 | """
50 | add_two_nums_node = FunctionNode(func=add_two_nums)
51 | multiply_two_nums_node = FunctionNode(func=multiply_two_nums)
52 | print_result_node = FunctionNode(func=print_result)
53 | entry_node = FunctionNode(func=entry_func)
54 |
55 | # Setup Decision Node
56 | """
57 | DecisionNodes use AI models to make routing decisions in the workflow.
58 | """
59 | decision_node = DecisionNode(model='gpt-4-0125-preview', api_base=None)
60 |
61 | # Link Nodes
62 | """
63 | Nodes are connected by setting their `next_nodes` attribute.
64 | This creates the structure of the directed acyclic graph (DAG).
65 | """
66 | entry_node.next_nodes = [decision_node]
67 |
68 | decision_node.next_nodes = [
69 | add_two_nums_node,
70 | multiply_two_nums_node,
71 | ]
72 |
73 | add_two_nums_node.next_nodes = [print_result_node]
74 | multiply_two_nums_node.next_nodes = [print_result_node]
75 |
76 | # Compile the DAG
77 | """
78 | Compilation prepares the DAG for execution, ensuring all nodes are properly linked.
79 | """
80 | entry_node.compile(force_load=True)
81 |
82 | # Execute the DAG
83 | while True:
84 | user_input = input("Enter your command: ")
85 | entry_node.run(input=user_input)
86 |
87 | if __name__ == "__main__":
88 | main()
89 |
--------------------------------------------------------------------------------
/dagent/examples/sql_agent_local.py:
--------------------------------------------------------------------------------
1 | from dagent import DecisionNode, FunctionNode, call_llm
2 | import logging
3 |
4 | example_schema = """
5 | CREATE TABLE Users (
6 | user_id INTEGER PRIMARY KEY,
7 | username TEXT NOT NULL,
8 | email TEXT NOT NULL,
9 | registration_date DATE
10 | );
11 |
12 | CREATE TABLE Products (
13 | product_id INTEGER PRIMARY KEY,
14 | product_name TEXT NOT NULL,
15 | price DECIMAL(10, 2) NOT NULL,
16 | stock_quantity INTEGER
17 | );
18 |
19 | CREATE TABLE Orders (
20 | order_id INTEGER PRIMARY KEY,
21 | user_id INTEGER,
22 | order_date DATE,
23 | total_amount DECIMAL(10, 2),
24 | FOREIGN KEY (user_id) REFERENCES Users(user_id)
25 | );
26 |
27 | CREATE TABLE OrderItems (
28 | order_item_id INTEGER PRIMARY KEY,
29 | order_id INTEGER,
30 | product_id INTEGER,
31 | quantity INTEGER,
32 | price_per_unit DECIMAL(10, 2),
33 | FOREIGN KEY (order_id) REFERENCES Orders(order_id),
34 | FOREIGN KEY (product_id) REFERENCES Products(product_id)
35 | );
36 |
37 | INSERT INTO Users (username, email, registration_date) VALUES
38 | ('john_doe', 'john@example.com', '2023-01-15'),
39 | ('jane_smith', 'jane@example.com', '2023-02-20'),
40 | ('bob_johnson', 'bob@example.com', '2023-03-10');
41 |
42 | INSERT INTO Products (product_name, price, stock_quantity) VALUES
43 | ('Laptop', 999.99, 50),
44 | ('Smartphone', 599.99, 100),
45 | ('Headphones', 79.99, 200);
46 |
47 | INSERT INTO Orders (user_id, order_date, total_amount) VALUES
48 | (1, '2023-04-01', 1079.98),
49 | (2, '2023-04-15', 599.99),
50 | (3, '2023-04-30', 159.98);
51 |
52 | INSERT INTO OrderItems (order_id, product_id, quantity, price_per_unit) VALUES
53 | (1, 1, 1, 999.99),
54 | (1, 3, 1, 79.99),
55 | (2, 2, 1, 599.99),
56 | (3, 3, 2, 79.99);
57 | """
58 |
59 |
60 | logging.basicConfig(level=logging.INFO,
61 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
62 | handlers=[logging.FileHandler('dagent_logs.log'), logging.StreamHandler()])
63 |
64 | def get_user_input() -> str:
65 | """Get user input from the command line."""
66 | user_input = input("Enter your command: ")
67 | return user_input
68 |
69 | def get_database_schema(prev_output: str) -> str:
70 | return example_schema
71 |
72 | def generate_sql(prev_output: str, database_schema: str) -> str:
73 | """Generate SQL from the user input."""
74 | sql = call_llm(model='ollama_chat/llama3.1', api_base="http://localhost:11434", messages=[{"role": "user", "content": f"Generate SQL from the user input: {prev_output} and the following database schema: {database_schema}"}])
75 | print('generated sql: ', sql)
76 | return sql
77 |
78 | def show_results(prev_output: str) -> None:
79 | print('results: ', prev_output)
80 |
81 |
82 | get_user_input_node = FunctionNode(func=get_user_input)
83 | get_database_schema_node = FunctionNode(func=get_database_schema)
84 | generate_sql_node = FunctionNode(func=generate_sql, user_params={"database_schema": example_schema})
85 | show_results_node = FunctionNode(func=show_results)
86 |
87 | decision_node = DecisionNode(model='ollama_chat/llama3.1', api_base="http://localhost:11434")
88 |
89 | get_user_input_node.next_nodes = [decision_node]
90 | decision_node.next_nodes = [get_database_schema_node, generate_sql_node]
91 | get_database_schema_node.next_nodes = [generate_sql_node]
92 | generate_sql_node.next_nodes = [show_results_node]
93 |
94 |
95 | if __name__ == "__main__":
96 | get_user_input_node.compile(force_load=False)
97 | get_user_input_node.run()
--------------------------------------------------------------------------------
/dagent/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "dagent"
3 | version = "0.0.9"
4 | description = "AI Agents as DAGs - Directed Acyclic Graphs"
5 | authors = [
6 | { name = "Parth Sareen", email = "parth@extensible.dev" }
7 | ]
8 | dependencies = [
9 | "litellm>=1.44.19",
10 | ]
11 | readme = "../README.md"
12 | requires-python = ">= 3.10"
13 |
14 | [build-system]
15 | requires = ["hatchling"]
16 | build-backend = "hatchling.build"
17 |
18 | [tool.rye]
19 | managed = true
20 | dev-dependencies = []
21 |
22 | [tool.hatch.metadata]
23 | allow-direct-references = true
24 |
25 | [tool.hatch.build.targets.wheel]
26 | packages = ["src/dagent"]
27 |
--------------------------------------------------------------------------------
/dagent/requirements-dev.lock:
--------------------------------------------------------------------------------
1 | # generated by rye
2 | # use `rye lock` or `rye sync` to update this lockfile
3 | #
4 | # last locked with the following flags:
5 | # pre: false
6 | # features: []
7 | # all-features: false
8 | # with-sources: false
9 | # generate-hashes: false
10 |
11 | -e file:.
12 | aiohttp==3.9.5
13 | # via litellm
14 | aiosignal==1.3.1
15 | # via aiohttp
16 | annotated-types==0.7.0
17 | # via pydantic
18 | anyio==4.4.0
19 | # via httpx
20 | # via openai
21 | attrs==23.2.0
22 | # via aiohttp
23 | # via jsonschema
24 | # via referencing
25 | certifi==2024.6.2
26 | # via httpcore
27 | # via httpx
28 | # via requests
29 | charset-normalizer==3.3.2
30 | # via requests
31 | click==8.1.7
32 | # via litellm
33 | distro==1.9.0
34 | # via openai
35 | filelock==3.14.0
36 | # via huggingface-hub
37 | frozenlist==1.4.1
38 | # via aiohttp
39 | # via aiosignal
40 | fsspec==2024.6.0
41 | # via huggingface-hub
42 | h11==0.14.0
43 | # via httpcore
44 | httpcore==1.0.5
45 | # via httpx
46 | httpx==0.27.0
47 | # via openai
48 | huggingface-hub==0.23.3
49 | # via tokenizers
50 | idna==3.7
51 | # via anyio
52 | # via httpx
53 | # via requests
54 | # via yarl
55 | importlib-metadata==7.1.0
56 | # via litellm
57 | jinja2==3.1.4
58 | # via litellm
59 | jiter==0.5.0
60 | # via openai
61 | jsonschema==4.23.0
62 | # via litellm
63 | jsonschema-specifications==2023.12.1
64 | # via jsonschema
65 | litellm==1.44.19
66 | # via dagent
67 | markupsafe==2.1.5
68 | # via jinja2
69 | multidict==6.0.5
70 | # via aiohttp
71 | # via yarl
72 | openai==1.42.0
73 | # via litellm
74 | packaging==24.0
75 | # via huggingface-hub
76 | pydantic==2.7.3
77 | # via litellm
78 | # via openai
79 | pydantic-core==2.18.4
80 | # via pydantic
81 | python-dotenv==1.0.1
82 | # via litellm
83 | pyyaml==6.0.1
84 | # via huggingface-hub
85 | referencing==0.35.1
86 | # via jsonschema
87 | # via jsonschema-specifications
88 | regex==2024.5.15
89 | # via tiktoken
90 | requests==2.32.3
91 | # via huggingface-hub
92 | # via litellm
93 | # via tiktoken
94 | rpds-py==0.20.0
95 | # via jsonschema
96 | # via referencing
97 | sniffio==1.3.1
98 | # via anyio
99 | # via httpx
100 | # via openai
101 | tiktoken==0.7.0
102 | # via litellm
103 | tokenizers==0.19.1
104 | # via litellm
105 | tqdm==4.66.4
106 | # via huggingface-hub
107 | # via openai
108 | typing-extensions==4.12.2
109 | # via huggingface-hub
110 | # via openai
111 | # via pydantic
112 | # via pydantic-core
113 | urllib3==2.2.1
114 | # via requests
115 | yarl==1.9.4
116 | # via aiohttp
117 | zipp==3.19.2
118 | # via importlib-metadata
119 |
--------------------------------------------------------------------------------
/dagent/requirements.lock:
--------------------------------------------------------------------------------
1 | # generated by rye
2 | # use `rye lock` or `rye sync` to update this lockfile
3 | #
4 | # last locked with the following flags:
5 | # pre: false
6 | # features: []
7 | # all-features: false
8 | # with-sources: false
9 | # generate-hashes: false
10 |
11 | -e file:.
12 | aiohttp==3.9.5
13 | # via litellm
14 | aiosignal==1.3.1
15 | # via aiohttp
16 | annotated-types==0.7.0
17 | # via pydantic
18 | anyio==4.4.0
19 | # via httpx
20 | # via openai
21 | attrs==23.2.0
22 | # via aiohttp
23 | # via jsonschema
24 | # via referencing
25 | certifi==2024.6.2
26 | # via httpcore
27 | # via httpx
28 | # via requests
29 | charset-normalizer==3.3.2
30 | # via requests
31 | click==8.1.7
32 | # via litellm
33 | distro==1.9.0
34 | # via openai
35 | filelock==3.14.0
36 | # via huggingface-hub
37 | frozenlist==1.4.1
38 | # via aiohttp
39 | # via aiosignal
40 | fsspec==2024.6.0
41 | # via huggingface-hub
42 | h11==0.14.0
43 | # via httpcore
44 | httpcore==1.0.5
45 | # via httpx
46 | httpx==0.27.0
47 | # via openai
48 | huggingface-hub==0.23.3
49 | # via tokenizers
50 | idna==3.7
51 | # via anyio
52 | # via httpx
53 | # via requests
54 | # via yarl
55 | importlib-metadata==7.1.0
56 | # via litellm
57 | jinja2==3.1.4
58 | # via litellm
59 | jiter==0.5.0
60 | # via openai
61 | jsonschema==4.23.0
62 | # via litellm
63 | jsonschema-specifications==2023.12.1
64 | # via jsonschema
65 | litellm==1.44.19
66 | # via dagent
67 | markupsafe==2.1.5
68 | # via jinja2
69 | multidict==6.0.5
70 | # via aiohttp
71 | # via yarl
72 | openai==1.42.0
73 | # via litellm
74 | packaging==24.0
75 | # via huggingface-hub
76 | pydantic==2.7.3
77 | # via litellm
78 | # via openai
79 | pydantic-core==2.18.4
80 | # via pydantic
81 | python-dotenv==1.0.1
82 | # via litellm
83 | pyyaml==6.0.1
84 | # via huggingface-hub
85 | referencing==0.35.1
86 | # via jsonschema
87 | # via jsonschema-specifications
88 | regex==2024.5.15
89 | # via tiktoken
90 | requests==2.32.3
91 | # via huggingface-hub
92 | # via litellm
93 | # via tiktoken
94 | rpds-py==0.20.0
95 | # via jsonschema
96 | # via referencing
97 | sniffio==1.3.1
98 | # via anyio
99 | # via httpx
100 | # via openai
101 | tiktoken==0.7.0
102 | # via litellm
103 | tokenizers==0.19.1
104 | # via litellm
105 | tqdm==4.66.4
106 | # via huggingface-hub
107 | # via openai
108 | typing-extensions==4.12.2
109 | # via huggingface-hub
110 | # via openai
111 | # via pydantic
112 | # via pydantic-core
113 | urllib3==2.2.1
114 | # via requests
115 | yarl==1.9.4
116 | # via aiohttp
117 | zipp==3.19.2
118 | # via importlib-metadata
119 |
--------------------------------------------------------------------------------
/dagent/src/dagent/DagNode.py:
--------------------------------------------------------------------------------
1 |
2 | class DagNode:
3 | def __init__(
4 | self,
5 | func: callable,
6 | next_nodes: dict[str, 'DagNode'] = None
7 | ):
8 | self.func = func
9 | self.next_nodes = next_nodes or {}
10 | self.node_result = None
11 |
12 | def compile(self):
13 | """
14 | Use an LLM to generate tool descriptions from functions to run the DAG
15 | """
16 | return NotImplemented
17 |
18 | def run(self, **kwargs) -> any:
19 | raise NotImplementedError("Subclasses should implement this method")
--------------------------------------------------------------------------------
/dagent/src/dagent/DecisionNode.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import inspect
4 | from .DagNode import DagNode
5 |
6 | from .base_functions import call_llm_tool, create_tool_desc
7 | import logging
8 |
9 |
10 | class DecisionNode(DagNode):
11 | def __init__(
12 | self,
13 | func: callable = call_llm_tool,
14 | next_nodes: list | None = None,
15 | user_params: dict | None = None,
16 | model: str = 'gpt-4-0125-preview',
17 | api_base: str | None = None,
18 | tool_json_dir: str = 'Tool_JSON',
19 | retry_json_count: int = 3,
20 |         max_tool_calls: int | None = None
21 |     ):
23 | super().__init__(func, next_nodes)
24 | self.user_params = user_params or {}
25 | self.logger = logging.getLogger(__name__)
26 | self.compiled = False
27 | self.api_base = api_base
28 | self.model = model
29 | self.tool_json_dir = tool_json_dir
30 | self.retry_json_count = retry_json_count
31 | self.max_tool_calls = max_tool_calls
32 | self.logger.info(f"DecisionNode initialized with model: {model}, api_base: {api_base}, max_tool_calls: {max_tool_calls}")
33 |
34 |
35 | def compile(self, force_load=False) -> None:
36 | self.logger.info("Starting compilation process")
37 | self.compiled = True
38 |
39 | if isinstance(self.next_nodes, list):
40 | self.next_nodes = {node.func.__name__: node for node in self.next_nodes}
41 | self.logger.debug(f"Converted next_nodes list to dictionary: {self.next_nodes.keys()}")
42 |
43 | for _, next_node in self.next_nodes.items():
44 | func_name = os.path.join(self.tool_json_dir, next_node.func.__name__ + '.json')
45 | self.logger.info(f"Compiling tool description for function: {next_node.func.__name__}")
46 |
47 | if force_load or not os.path.exists(func_name):
48 | self.logger.debug(f"Creating new tool description for {next_node.func.__name__}")
49 | os.makedirs(self.tool_json_dir, exist_ok=True)
50 | try:
51 | current_retry_count = 0
52 | tool_desc = create_tool_desc(model=self.model, function_desc=inspect.getsource(next_node.func), api_base=self.api_base)
53 |
54 | while not tool_desc and current_retry_count < self.retry_json_count:
55 | self.logger.warning(f"Retry {current_retry_count + 1} for creating tool description of {next_node.func.__name__}")
56 | tool_desc = create_tool_desc(model=self.model, function_desc=inspect.getsource(next_node.func), api_base=self.api_base)
57 | current_retry_count += 1
58 |
59 | if not tool_desc:
60 |                         error_msg = f"Tool description for {next_node.func.__name__} could not be generated; recommend generating it manually and saving it to {func_name}"
61 | self.logger.error(error_msg)
62 | raise ValueError(error_msg)
63 |
64 | tool_desc_json = json.loads(tool_desc)
65 | self.logger.debug(f"Successfully created tool description for {next_node.func.__name__}")
66 | except Exception as e:
67 | self.logger.error(f"Error creating tool description for {next_node.func.__name__}: {e}")
68 | raise e
69 | with open(func_name, 'w') as f:
70 | json.dump(tool_desc_json, f)
71 | self.logger.info(f"Saved tool description for {next_node.func.__name__} to {func_name}")
72 | else:
73 | self.logger.info(f"Loading existing tool description for {next_node.func.__name__} from {func_name}")
74 | with open(func_name, 'r') as f:
75 | tool_desc_json = json.load(f)
76 |
77 | next_node.tool_description = tool_desc_json
78 | next_node.compile()
79 | self.logger.info("Compilation process completed successfully for DecisionNode")
80 |
81 |
82 | def run(self, **kwargs) -> any:
83 | self.logger.info("Starting DecisionNode run")
84 | if not self.next_nodes:
85 | error_msg = "Next nodes not specified for LLM call"
86 | self.logger.error(error_msg)
87 | raise ValueError(error_msg)
88 |
89 | if not self.compiled:
90 | error_msg = "Node not compiled. Please run compile() method from the entry node first"
91 | self.logger.error(error_msg)
92 | raise ValueError(error_msg)
93 |
94 | if not kwargs.get('prev_output') and not kwargs.get('messages'):
95 | error_msg = "No input data provided for LLM call"
96 | self.logger.error(error_msg)
97 | raise ValueError(error_msg)
98 |
99 | # Get existing messages or create an empty list
100 | messages = self.user_params.get('messages', kwargs.get('messages', []))
101 | # Add previous output as a user message if available
102 | if 'prev_output' in kwargs:
103 | messages.append({'role': 'user', 'content': str(kwargs['prev_output'])})
104 |
105 | # Update kwargs with the final messages list
106 | kwargs['messages'] = messages
107 | self.logger.info(f"Prepared messages for LLM call: {messages}")
108 |
109 | try:
110 | self.logger.info(f"Calling LLM tool with model: {self.model}")
111 | # The 'messages' param is passed in through the kwargs
112 | response = call_llm_tool(model=self.model, tools=[node.tool_description for node in self.next_nodes.values()], api_base=self.api_base, **kwargs)
113 | tool_calls = getattr(response, 'tool_calls', None)
114 | if not tool_calls:
115 | error_msg = "No tool calls received from LLM tool response"
116 | self.logger.error(error_msg)
117 | raise ValueError(error_msg)
118 |
119 | self.logger.info(f"Received {len(tool_calls)} tool call(s) from LLM")
120 | self.logger.info(f"Tool calls: {tool_calls}")
121 | if self.logger.getEffectiveLevel() == logging.DEBUG:
122 | proceed = input("Debug mode is active. Do you want to proceed with the tool calls? (yes(y)/no(n)): ")
123 | print('proceed: ',proceed)
124 | if proceed.lower() not in ['yes', 'y']:
125 | self.logger.debug("\nUser chose not to proceed with the tool calls.")
126 | return
127 |
128 | # Apply max_tool_calls constraint if set
129 | if self.max_tool_calls is not None:
130 | tool_calls = tool_calls[:self.max_tool_calls]
131 | self.logger.info(f"Constrained to {len(tool_calls)} tool call(s) due to max_tool_calls setting")
132 |
133 | for tool_call in tool_calls:
134 | function_name = tool_call.function.name
135 | function_args = json.loads(tool_call.function.arguments)
136 | self.logger.debug(f"Processing tool call for function: {function_name} with arguments: {function_args}")
137 |
138 | next_node = self.next_nodes.get(function_name)
139 | if not next_node:
140 | error_msg = f"Function name '{function_name}' not found in next_nodes. Something went wrong"
141 | self.logger.error(error_msg)
142 | raise KeyError(error_msg)
143 |
144 | # Merge user_params with function_args, giving precedence to user_params
145 | merged_args = {**function_args, **self.user_params}
146 | # Print kwargs for debugging
147 | self.logger.info(f"Current kwargs: {kwargs}")
148 | print(f"Current kwargs: {kwargs}")
149 |
150 | if 'prev_output' in kwargs:
151 | merged_args['prev_output'] = kwargs['prev_output']
152 |
153 | func_signature = inspect.signature(next_node.func)
154 |
155 | # TODO: Manage through derived data models
156 | filtered_args = {k: v for k, v in merged_args.items() if k in func_signature.parameters}
157 | self.logger.info(f"Filtered arguments for {function_name}: {filtered_args}")
158 |
159 | self.logger.info(f"Executing next node: {function_name}")
160 | next_node.run(**filtered_args)
161 |
162 | except (AttributeError, json.JSONDecodeError) as e:
163 | error_msg = f"Error parsing tool call: {e}"
164 | self.logger.error(error_msg)
165 | raise ValueError(error_msg)
166 | except Exception as e:
167 | error_msg = f"LLM tool call failed: {e}"
168 | self.logger.error(error_msg)
169 | raise RuntimeError(error_msg)
170 | self.logger.info("DecisionNode run completed successfully")
171 |
--------------------------------------------------------------------------------
/dagent/src/dagent/FunctionNode.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from .DagNode import DagNode
3 |
4 | class FunctionNode(DagNode):
5 |     def __init__(self, func: callable, tool_description: dict | None = None, next_nodes: list | None = None, user_params: dict | None = None):
6 | super().__init__(func, next_nodes)
7 | self.tool_description = tool_description
8 | self.user_params = user_params or {}
9 | self.compiled = False
10 | self.node_result = None
11 | self.logger = logging.getLogger(__name__)
12 |
13 | def compile(self, force_load=False) -> None:
14 | self.logger.info(f"Compiling FunctionNode for function: {self.func.__name__}")
15 | self.compiled = True
16 | if isinstance(self.next_nodes, list):
17 | self.logger.debug("Converting next_nodes from list to dictionary")
18 | self.next_nodes = {node.func.__name__: node for node in self.next_nodes}
19 | for node_name, next_node in self.next_nodes.items():
20 | self.logger.debug(f"Compiling next node: {node_name}")
21 | next_node.compile(force_load=force_load)
22 |
23 | def run(self, **kwargs) -> any:
24 | if not self.compiled:
25 | self.logger.error("Attempted to run uncompiled node")
26 | raise ValueError("Node not compiled. Please run compile() method from the entry node first")
27 |
28 | self.logger.info(f"Running FunctionNode for function: {self.func.__name__}")
29 | merged_params = {**self.user_params, **kwargs}
30 | self.logger.debug(f"Merged parameters: {merged_params}")
31 |
32 | try:
33 | self.node_result = self.func(**merged_params)
34 | self.logger.debug(f"Function result: {self.node_result}")
35 | except Exception as e:
36 | self.logger.error(f"Error executing function {self.func.__name__}: {str(e)}")
37 | raise
38 |
39 | # Pass the result to the next nodes if any
40 | # TODO: figure out param logic pattern
41 | if not self.next_nodes:
42 | self.logger.info(f"No next nodes after {self.func.__name__}, returning result")
43 | return self.node_result
44 | for node_name, next_node in self.next_nodes.items():
45 | self.logger.info(f"Passing result to next node: {node_name}")
46 | # TODO: creating data models for passing info between nodes
47 | params = {'prev_output': self.node_result, **next_node.user_params}
48 | self.logger.debug(f"Parameters for next node: {params}")
49 | next_node.run(**params)
50 |
--------------------------------------------------------------------------------
/dagent/src/dagent/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_functions import call_llm_tool, call_llm
2 | from .DecisionNode import DecisionNode
3 | from .FunctionNode import FunctionNode
--------------------------------------------------------------------------------
/dagent/src/dagent/base_functions.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from litellm import completion
3 | # For more info: https://litellm.vercel.app/docs/completion/input
4 |
5 |
6 | def call_llm_tool(model, messages, tools, api_base=None, **kwargs):
7 | response = completion(
8 | model=model,
9 | messages=messages,
10 | tools=tools,
11 | api_base=api_base,
12 | )
13 | return response.choices[0].message
14 |
15 |
16 | def create_tool_desc(model, function_desc, api_base=None):
17 | example = {
18 | "type": "function",
19 | "function": {
20 | "name": "get_calendar_events",
21 | "description": "Get calendar events within a specified time range",
22 | "parameters": {
23 | "type": "object",
24 | "properties": {
25 | "start_time": {
26 | "type": "string",
27 | "description": "The start time for the event search, in ISO format",
28 | },
29 | "end_time": {
30 | "type": "string",
31 | "description": "The end time for the event search, in ISO format",
32 | },
33 | },
34 | "required": ["start_time", "end_time"],
35 | },
36 | }
37 | }
38 | messages = [{"role": "user", "content": "Create a json for the attached function: {} using the following pattern for the json: {}. Don't add anything extra. Make sure everything follows a valid json format".format(function_desc, example)}]
39 | response = completion(
40 | model=model,
41 | response_format={"type":"json_object"},
42 | messages=messages,
43 | api_base=api_base
44 | )
45 | return response.choices[0].message.content
46 |
47 |
48 | def call_llm(model, messages, api_base=None, **kwargs):
49 | response = completion(
50 | model=model,
51 | messages=messages,
52 | api_base=api_base,
53 | **kwargs
54 | )
55 | return response.choices[0].message.content
56 |
57 |
--------------------------------------------------------------------------------
/dagent/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Extensible-AI/DAGent/8dd76ffe7c1bdd57ebfcf67a859f25d196632233/dagent/tests/__init__.py
--------------------------------------------------------------------------------
/dagent/tests/test_DecisionNode.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | import shutil
4 | from unittest.mock import patch, MagicMock
5 | from dagent.DecisionNode import DecisionNode
6 | from dagent.FunctionNode import FunctionNode
7 |
8 | class TestDecisionNode(unittest.TestCase):
9 | def setUp(self):
10 | self.test_dir = 'test_tool_json'
11 | os.makedirs(self.test_dir, exist_ok=True)
12 |
13 | def tearDown(self):
14 | if os.path.exists(self.test_dir):
15 | shutil.rmtree(self.test_dir)
16 |
17 | @patch('dagent.DecisionNode.call_llm_tool')
18 | @patch('dagent.DecisionNode.create_tool_desc')
19 | def test_decision_node(self, mock_create_tool_desc, mock_call_llm_tool):
20 | # Mock function for testing
21 | def test_function(arg1, arg2):
22 | return f"Result: {arg1}, {arg2}"
23 |
24 | # Mock create_tool_desc to return a valid tool description
25 | mock_create_tool_desc.return_value = '{"type": "function", "function": {"name": "test_function", "parameters": {"type": "object", "properties": {"arg1": {"type": "string"}, "arg2": {"type": "string"}}, "required": ["arg1", "arg2"]}}}'
26 |
27 | # Mock call_llm_tool to return a valid response
28 | mock_response = MagicMock()
29 | mock_tool_call = MagicMock()
30 | mock_tool_call.function = MagicMock()
31 | mock_tool_call.function.name = 'test_function'
32 | mock_tool_call.function.arguments = '{"arg1": "hello", "arg2": "world"}'
33 | mock_response.tool_calls = [mock_tool_call]
34 | mock_call_llm_tool.return_value = mock_response
35 |
36 | # Create DecisionNode with FunctionNode as next node
37 | function_node = FunctionNode(func=test_function)
38 |         decision_node = DecisionNode(next_nodes={'test_function': function_node}, tool_json_dir=self.test_dir)
39 |
40 | # Compile the decision node
41 |         decision_node.compile()
42 |
43 | # Check if the tool description file was created
44 | self.assertTrue(os.path.exists(os.path.join(self.test_dir, 'test_function.json')))
45 |
46 | # Run the decision node
47 | with patch.object(function_node, 'run') as mock_function_run:
48 | decision_node.run(messages=[{'role': 'user', 'content': 'Test message'}])
49 |
50 | # Assert that the function node's run method was called with correct arguments
51 | mock_function_run.assert_called_once_with(arg1='hello', arg2='world')
52 |
53 | # Assert that create_tool_desc and call_llm_tool were called
54 | mock_create_tool_desc.assert_called_once()
55 | mock_call_llm_tool.assert_called_once()
56 |
57 | if __name__ == '__main__':
58 | unittest.main()
--------------------------------------------------------------------------------
/dagent/tests/test_FunctionNode.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import MagicMock
3 | from dagent.FunctionNode import FunctionNode
4 |
5 | class TestFunctionNode(unittest.TestCase):
6 | def setUp(self):
7 | def test_func(a, b):
8 | return a + b
9 | self.function_node = FunctionNode(func=test_func)
10 | self.function_node.tool_description = None
11 |
12 |
13 | def test_init(self):
14 | self.assertIsInstance(self.function_node, FunctionNode)
15 | self.assertEqual(self.function_node.func.__name__, 'test_func')
16 | # Check if tool_description is either None or an empty dict
17 | self.assertTrue(self.function_node.tool_description is None or self.function_node.tool_description == {})
18 | self.assertEqual(self.function_node.user_params, {})
19 | self.assertFalse(self.function_node.compiled)
20 | self.assertIsNone(self.function_node.node_result)
21 |
33 | def test_compile(self):
34 | def identity(x):
35 | return x
36 | next_node = FunctionNode(func=identity)
37 | self.function_node.next_nodes = [next_node]
38 | self.function_node.compile()
39 | self.assertTrue(self.function_node.compiled)
40 | self.assertIsInstance(self.function_node.next_nodes, dict)
41 | self.assertIn('identity', self.function_node.next_nodes)
42 |
43 | def test_run_without_compile(self):
44 | with self.assertRaises(ValueError):
45 | self.function_node.run(a=1, b=2)
46 |
47 | def test_run_with_compile(self):
48 | self.function_node.compile()
49 | result = self.function_node.run(a=1, b=2)
50 | self.assertEqual(result, 3)
51 | self.assertEqual(self.function_node.node_result, 3)
52 |
53 | def test_run_with_user_params(self):
54 | self.function_node.user_params = {'b': 5}
55 | self.function_node.compile()
56 | result = self.function_node.run(a=1)
57 | self.assertEqual(result, 6)
58 |
59 | def test_run_with_next_nodes(self):
60 | def double(x):
61 | return x * 2
62 | next_node = FunctionNode(func=double)
63 | self.function_node.next_nodes = [next_node]
64 | self.function_node.compile()
65 |
66 | next_node.run = MagicMock()
67 | self.function_node.run(a=1, b=2)
68 |
69 | next_node.run.assert_called_once_with(prev_output=3)
70 |
71 | if __name__ == '__main__':
72 | unittest.main()
73 |
--------------------------------------------------------------------------------
/dagent/tests/test_base_functions.py:
--------------------------------------------------------------------------------
1 | import json
2 | import inspect
3 | import argparse
4 | from dagent.base_functions import *
5 |
6 | def add_two_nums(a: int, b: int) -> int:
7 | return a + b
8 |
9 | def run_llm(model, api_base=None):
10 | # Run `call_llm`
11 | output = call_llm(model, [{'role': 'user', 'content': 'add the numbers 2 and 3'}], api_base=api_base)
12 | print(f'{model} output:', output)
13 |
14 | # Create tool description for `add_two_nums` function
15 | desc = create_tool_desc(model=model, function_desc=inspect.getsource(add_two_nums), api_base=api_base)
16 | print(f'{model} tool desc:', desc, end='\n\n')
17 |
18 | tool_desc_json = json.loads(desc)
19 |
20 | # Run `call_llm_tool`
21 | output = call_llm_tool(model, [{'role': 'user', 'content': 'add the numbers 2 and 3 using the provided tool'}], tools=[tool_desc_json], api_base=api_base)
22 |
23 | tool_calls = getattr(output, 'tool_calls', None)
24 | if not tool_calls:
25 | raise ValueError("No tool calls received from LLM tool response")
26 |
27 | function_name = tool_calls[0].function.name
28 | print(f'{model} output func name:', function_name, end='\n\n')
29 |
30 | def main():
31 | parser = argparse.ArgumentParser(description="Run LLM models")
32 | parser.add_argument('model', choices=['groq', 'ollama', 'gpt4'], help="Select the model to run")
33 | args = parser.parse_args()
34 |
35 | if args.model == 'groq':
36 | run_llm('groq/llama3-70b-8192')
37 | elif args.model == 'ollama':
38 | run_llm('ollama_chat/llama3.1', api_base="http://localhost:11434")
39 | elif args.model == 'gpt4':
40 | run_llm('gpt-4-0125-preview')
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/dagent/tests/test_linking.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from unittest.mock import patch, MagicMock
3 | from dagent import DecisionNode, FunctionNode
4 |
5 | def add_two_nums(a: int, b: int) -> int:
6 | return a + b
7 |
8 | def multiply_two_nums(a: int, b: int) -> int:
9 | return a * b
10 |
11 | class TestLinking(unittest.TestCase):
12 | def setUp(self):
13 | self.add_node = FunctionNode(func=add_two_nums)
14 | self.multiply_node = FunctionNode(func=multiply_two_nums)
15 | self.decision_node = DecisionNode(model='gpt-4-0125-preview', api_base=None)
16 |
17 | self.decision_node.next_nodes = {
18 | 'add_two_nums': self.add_node,
19 | 'multiply_two_nums': self.multiply_node
20 | }
21 |
22 | @patch('dagent.DecisionNode.call_llm_tool')
23 | def test_decision_node_linking(self, mock_call_llm_tool):
24 | # Mock the LLM response to simulate choosing the add_two_nums function
25 | mock_response = MagicMock()
26 | mock_function = MagicMock()
27 | mock_function.name = 'add_two_nums'
28 | mock_function.arguments = '{"a": 2, "b": 3}'
29 | mock_response.tool_calls = [MagicMock(function=mock_function)]
30 | mock_call_llm_tool.return_value = mock_response
31 |
32 | # Compile the decision node
33 | self.decision_node.compile()
34 |
35 | # Run the decision node
36 | with patch.object(self.add_node, 'run') as mock_add_run:
37 | self.decision_node.run(messages=[{'role': 'user', 'content': 'Add 2 and 3'}])
38 |
39 | # Assert that the add_two_nums function was called
40 | mock_add_run.assert_called_once_with(a=2, b=3)
41 |
42 | @patch('dagent.DecisionNode.call_llm_tool')
43 | def test_decision_node_linking_multiply(self, mock_call_llm_tool):
44 | # Mock the LLM response to simulate choosing the multiply_two_nums function
45 | mock_response = MagicMock()
46 | mock_function = MagicMock()
47 | mock_function.name = 'multiply_two_nums'
48 | mock_function.arguments = '{"a": 4, "b": 5}'
49 | mock_response.tool_calls = [MagicMock(function=mock_function)]
50 | mock_call_llm_tool.return_value = mock_response
51 |
52 | # Compile the decision node
53 | self.decision_node.compile()
54 |
55 | # Run the decision node
56 | with patch.object(self.multiply_node, 'run') as mock_multiply_run:
57 | self.decision_node.run(messages=[{'role': 'user', 'content': 'Multiply 4 and 5'}])
58 |
59 | # Assert that the multiply_two_nums function was called
60 | mock_multiply_run.assert_called_once_with(a=4, b=5)
61 |
62 | if __name__ == '__main__':
63 | unittest.main()
--------------------------------------------------------------------------------
/notes.md:
--------------------------------------------------------------------------------
1 | ## Notes
2 | - [ ] Look if things run in memory and how to isolate for large workflows -> e.g. funcA(funcB(...)) -> funcA(...) -> funcB(...)
3 | - [ ] Side effects/mutations
4 | - [ ] Creating a data model for communication between functions + schema validation -> autogenerate?
5 | - [ ] Logging
6 | - [ ] Alerting on error
7 | - [x] Add a compile method to derive data models and tool descriptions
8 | - [ ] Docker
9 | - [ ] simple chat memory
10 |
--------------------------------------------------------------------------------