├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── DagentLogo.png ├── LICENSE ├── README.md ├── dagent ├── .gitignore ├── .python-version ├── README.md ├── examples │ ├── add_two_nums.json │ ├── multiply_two_nums.json │ ├── quickstart_local_simple_agent.py │ ├── quickstart_simple_agent.py │ └── sql_agent_local.py ├── pyproject.toml ├── requirements-dev.lock ├── requirements.lock ├── src │ └── dagent │ │ ├── DagNode.py │ │ ├── DecisionNode.py │ │ ├── FunctionNode.py │ │ ├── __init__.py │ │ └── base_functions.py └── tests │ ├── __init__.py │ ├── test_DecisionNode.py │ ├── test_FunctionNode.py │ ├── test_base_functions.py │ └── test_linking.py └── notes.md /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | permissions: 16 | contents: read 17 | 18 | jobs: 19 | deploy: 20 | 21 | runs-on: ubuntu-latest 22 | 23 | steps: 24 | - uses: actions/checkout@v4 25 | - name: Set up Python 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: '3.x' 29 | - name: Install dependencies 30 | run: | 31 | cd dagent 32 | python -m pip install --upgrade pip 33 | pip install build 34 | - name: Build package 35 | # Each step starts in a fresh shell at the repo root (the `cd dagent` above does not persist), so build the dagent/ project explicitly; artifacts land in dagent/dist. 36 | run: python -m build dagent 37 | - name: Publish package 38 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 39 | with: 40 | user: __token__ 41 | password: ${{ secrets.PYPI_API_TOKEN }} 42 | packages-dir: dagent/dist 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | Tool_JSON/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /DagentLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Extensible-AI/DAGent/8dd76ffe7c1bdd57ebfcf67a859f25d196632233/DagentLogo.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Parth Sareen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DAGent - Directed Acyclic Graphs (DAGs) as AI Agents 2 | 3 |

4 | Dagent Logo 5 |

6 | 7 | *DAGent is an opinionated Python library to create AI Agents quickly without overhead* 8 | 9 | 10 | [![Downloads](https://static.pepy.tech/badge/dagent)](https://pepy.tech/project/dagent) 11 | 12 | ## Quickstart 13 | 14 | ### Installing 15 | - `pip install dagent` or `rye add dagent` 16 | - Make sure the API key for your chosen provider is available in your environment. The default is `OPENAI_API_KEY` 17 | 18 | 19 | ### Get right into it 20 | See [dagent/examples/quickstart_simple_agent.py](dagent/examples/quickstart_simple_agent.py) for a quickstart example 21 | 22 | 23 | ### DAGent basics 24 | 25 | The idea behind dagent is to structure AI agents into a workflow. This is done by setting each function up as a node in a graph. 26 | 27 | The agentic behavior comes from using an LLM to infer which function to run; this is abstracted by a "Decision Node". 28 | 29 | `Tool` 30 | - A tool is just a function which the LLM can use. 31 | - It is helpful to have docstrings and annotations to help the LLM infer what the function does. This is recommended for larger functions/tools. 32 | 33 | `FunctionNode` 34 | - Runs a Python function 35 | - Can be attached to a `DecisionNode` to be treated as a tool and allow an LLM to choose which function to run 36 | 37 | `DecisionNode` 38 | - This is where the LLM picks a function to run from the given options 39 | - The `.compile()` method autogenerates and saves a tool description for each attached function. Run with param `force_load=True` if there are errors or if a tool's definition changes 40 | - These tool/function descriptions get generated under a `Tool_JSON` folder. Feel free to edit tool descriptions if the agent is unreliable. 41 | 42 | `prev_output` param for functions: 43 | - If passing data from one function to another, make sure this param is in the function signature.
44 | - If unexpected params get passed in or something behaves strangely, add `**kwargs` to the function signature to see whether any hidden params were passed that need to be handled 45 | 46 | 47 | ### DAGent Diagram 48 | ```mermaid 49 | graph TD 50 | A[Function Node] --> B[Decision Node] 51 | B --> C[Function Node] 52 | B --> E[Function Node] 53 | D --> F[Function Node] 54 | E --> G[Decision Node] 55 | F --> H[Function Node] 56 | G --> K[Function Node] 57 | G -- "Pick Function to Run" --> I[Function Node] 58 | G --> J[Function Node] 59 | I --> L[Function Node] 60 | J --> M[Function Node] 61 | K --> N[Function Node] 62 | K -- "Run Both " --> S[Function Node] 63 | 64 | %% Additional annotations 65 | B -- "Use a Function as a tool" --> D[Function Node] 66 | 67 | ``` 68 | 69 | ## Using Different Models 70 | 71 | DAGent supports using different LLM models for inference and tool description generation. You can specify the model when calling `call_llm` or `call_llm_tool`, or via the `model` argument when constructing a `DecisionNode`. 72 | 73 | For example, to use the `groq/llama3-70b-8192` model: 74 | 75 | ```python 76 | 77 | # Using groq with decision node 78 | decision_node1 = DecisionNode(model='groq/llama3-70b-8192') 79 | 80 | # Using ollama with decision node 81 | decision_node2 = DecisionNode(model='ollama_chat/llama3.1', api_base="http://localhost:11434") 82 | 83 | # Run the node: it needs next_nodes linked and compile() called first (run() raises ValueError when next_nodes is empty) 84 | output = decision_node2.run(messages=[{'role': 'user', 'content': 'add the numbers 2 and 3'}]) 85 | 86 | ``` 87 | 88 | ### Other things to know 89 | 90 | - `prev_output` is needed in the function signature if you want to use the prior function's return value.
The prior function must return a value for this to work. 91 | - If there are errors with too many params being passed into a function node, add `**kwargs` to your function 92 | - Args can be overridden at any time using the following (they are merged with the node's kwargs in the background, with the user-supplied values taking priority): 93 | 94 | ```python 95 | add_two_nums_node.user_params = { 96 | # 'param_name': value 97 | 'a': 10 98 | } 99 | ``` 100 | 101 | ## Feedback 102 | 103 | If you are interested in providing feedback, please use [this form](https://docs.google.com/forms/d/14EHPUEYGVV-eNj6HyUaQYokHUtouYOeWfriwHaHOARE). 104 | 105 | ## Acknowledgements 106 | Shoutout to: 107 | - [@omkizzy](https://x.com/omkizzy) 108 | - [@kaelan](https://github.com/Oasixer) 109 | -------------------------------------------------------------------------------- /dagent/.gitignore: -------------------------------------------------------------------------------- 1 | # python generated files 2 | __pycache__/ 3 | *.py[oc] 4 | build/ 5 | dist/ 6 | wheels/ 7 | *.egg-info 8 | Tool_JSON/ 9 | # venv 10 | .venv 11 | -------------------------------------------------------------------------------- /dagent/.python-version: -------------------------------------------------------------------------------- 1 | 3.12.2 2 | -------------------------------------------------------------------------------- /dagent/README.md: -------------------------------------------------------------------------------- 1 | # dagent -------------------------------------------------------------------------------- /dagent/examples/add_two_nums.json: -------------------------------------------------------------------------------- 1 | {"type": "function", "function": {"name": "add_two_nums", "description": "Add two integer numbers", "parameters": {"type": "object", "properties": {"a": {"type": "integer", "description": "The first number to add"}, "b": {"type": "integer", "description": "The second number to add"}}, "required": ["a",
"b"]}}} -------------------------------------------------------------------------------- /dagent/examples/multiply_two_nums.json: -------------------------------------------------------------------------------- 1 | {"type": "function", "function": {"name": "multiply_two_nums", "description": "Multiply two integers", "parameters": {"type": "object", "properties": {"a": {"type": "integer", "description": "The first integer to multiply"}, "b": {"type": "integer", "description": "The second integer to multiply"}}, "required": ["a", "b"]}}} -------------------------------------------------------------------------------- /dagent/examples/quickstart_local_simple_agent.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | This example demonstrates the main concepts of the dagent library using a local model: 4 | 1. Function Nodes: Represent individual operations in the workflow. 5 | 2. Decision Nodes: Use AI models to make decisions and route the workflow. 6 | 3. Node Linking: Connect nodes to create a directed acyclic graph (DAG). 7 | 4. Compilation: Prepare the DAG for execution. 8 | 5. Execution: Run the workflow starting from an entry point. 9 | """ 10 | 11 | from dagent import DecisionNode, FunctionNode 12 | import argparse 13 | 14 | def add_two_nums(a: int, b: int) -> int: 15 | """A simple function to add two numbers.""" 16 | return a + b 17 | 18 | def multiply_two_nums(a: int, b: int) -> int: 19 | """A simple function to multiply two numbers.""" 20 | return a * b 21 | 22 | def print_result(prev_output: int) -> None: 23 | """ 24 | Print the result from a previous node. 25 | 26 | Note: `prev_output` is automatically passed from the previous node. 
27 | """ 28 | print(prev_output) 29 | return prev_output 30 | 31 | def entry_func(input: str) -> str: 32 | """Entry point function for the workflow.""" 33 | return input 34 | 35 | def main(): 36 | 37 | # Initialize Function Nodes for basic arithmetic operations and result printing 38 | add_two_nums_node = FunctionNode(func=add_two_nums) 39 | multiply_two_nums_node = FunctionNode(func=multiply_two_nums) 40 | print_result_node = FunctionNode(func=print_result) 41 | 42 | # Initialize the entry point of the workflow 43 | entry_node = FunctionNode(func=entry_func) 44 | 45 | # Initialize a Decision Node configured to use a local AI model for decision-making 46 | decision_node = DecisionNode(model='ollama_chat/llama3.1', api_base="http://localhost:11434") 47 | 48 | # Link Nodes to define the workflow structure 49 | entry_node.next_nodes = [decision_node] 50 | decision_node.next_nodes = [add_two_nums_node, multiply_two_nums_node] 51 | add_two_nums_node.next_nodes = [print_result_node] 52 | multiply_two_nums_node.next_nodes = [print_result_node] 53 | 54 | # Compile the DAG to prepare it for execution 55 | entry_node.compile(force_load=False) 56 | 57 | # Execute the DAG in a loop to process user input dynamically 58 | while True: 59 | user_input = input("Enter your command: ") 60 | entry_node.run(input=user_input) 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /dagent/examples/quickstart_simple_agent.py: -------------------------------------------------------------------------------- 1 | 2 | from dagent import DecisionNode, FunctionNode 3 | import logging 4 | 5 | """ 6 | This example demonstrates the main concepts of the dagent library: 7 | 1. Function Nodes: Represent individual operations in the workflow. 8 | 2. Decision Nodes: Use AI models to make decisions and route the workflow. 9 | 3. Node Linking: Connect nodes to create a directed acyclic graph (DAG). 10 | 4. 
Compilation: Prepare the DAG for execution. 11 | 5. Execution: Run the workflow starting from an entry point. 12 | """ 13 | 14 | 15 | # Can enable logging below to save logs to file 16 | # logging.basicConfig(level=logging.INFO, 17 | # format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 18 | # handlers=[logging.FileHandler('dagent_logs.log'), logging.StreamHandler()]) 19 | 20 | def add_two_nums(a: int, b: int) -> int: 21 | """A simple function to add two numbers.""" 22 | return a + b 23 | 24 | 25 | def multiply_two_nums(a: int, b: int) -> int: 26 | """A simple function to multiply two numbers.""" 27 | return a * b 28 | 29 | 30 | def print_result(prev_output: int) -> None: 31 | """ 32 | Print the result from a previous node. 33 | 34 | Note: `prev_output` is automatically passed from the previous node. 35 | """ 36 | print(prev_output) 37 | return prev_output 38 | 39 | 40 | def entry_func(input: str) -> str: 41 | """Entry point function for the workflow.""" 42 | return input 43 | 44 | 45 | def main(): 46 | # Setup Function Nodes 47 | """ 48 | FunctionNodes wrap regular Python functions, allowing them to be used in the DAG. 49 | """ 50 | add_two_nums_node = FunctionNode(func=add_two_nums) 51 | multiply_two_nums_node = FunctionNode(func=multiply_two_nums) 52 | print_result_node = FunctionNode(func=print_result) 53 | entry_node = FunctionNode(func=entry_func) 54 | 55 | # Setup Decision Node 56 | """ 57 | DecisionNodes use AI models to make routing decisions in the workflow. 58 | """ 59 | decision_node = DecisionNode(model='gpt-4-0125-preview', api_base=None) 60 | 61 | # Link Nodes 62 | """ 63 | Nodes are connected by setting their `next_nodes` attribute. 64 | This creates the structure of the directed acyclic graph (DAG). 
65 | """ 66 | entry_node.next_nodes = [decision_node] 67 | 68 | decision_node.next_nodes = [ 69 | add_two_nums_node, 70 | multiply_two_nums_node, 71 | ] 72 | 73 | add_two_nums_node.next_nodes = [print_result_node] 74 | multiply_two_nums_node.next_nodes = [print_result_node] 75 | 76 | # Compile the DAG 77 | """ 78 | Compilation prepares the DAG for execution, ensuring all nodes are properly linked. 79 | """ 80 | entry_node.compile(force_load=True) 81 | 82 | # Execute the DAG 83 | while True: 84 | user_input = input("Enter your command: ") 85 | entry_node.run(input=user_input) 86 | 87 | if __name__ == "__main__": 88 | main() 89 | -------------------------------------------------------------------------------- /dagent/examples/sql_agent_local.py: -------------------------------------------------------------------------------- 1 | from dagent import DecisionNode, FunctionNode, call_llm 2 | import logging 3 | 4 | example_schema = """ 5 | CREATE TABLE Users ( 6 | user_id INTEGER PRIMARY KEY, 7 | username TEXT NOT NULL, 8 | email TEXT NOT NULL, 9 | registration_date DATE 10 | ); 11 | 12 | CREATE TABLE Products ( 13 | product_id INTEGER PRIMARY KEY, 14 | product_name TEXT NOT NULL, 15 | price DECIMAL(10, 2) NOT NULL, 16 | stock_quantity INTEGER 17 | ); 18 | 19 | CREATE TABLE Orders ( 20 | order_id INTEGER PRIMARY KEY, 21 | user_id INTEGER, 22 | order_date DATE, 23 | total_amount DECIMAL(10, 2), 24 | FOREIGN KEY (user_id) REFERENCES Users(user_id) 25 | ); 26 | 27 | CREATE TABLE OrderItems ( 28 | order_item_id INTEGER PRIMARY KEY, 29 | order_id INTEGER, 30 | product_id INTEGER, 31 | quantity INTEGER, 32 | price_per_unit DECIMAL(10, 2), 33 | FOREIGN KEY (order_id) REFERENCES Orders(order_id), 34 | FOREIGN KEY (product_id) REFERENCES Products(product_id) 35 | ); 36 | 37 | INSERT INTO Users (username, email, registration_date) VALUES 38 | ('john_doe', 'john@example.com', '2023-01-15'), 39 | ('jane_smith', 'jane@example.com', '2023-02-20'), 40 | ('bob_johnson', 
'bob@example.com', '2023-03-10'); 41 | 42 | INSERT INTO Products (product_name, price, stock_quantity) VALUES 43 | ('Laptop', 999.99, 50), 44 | ('Smartphone', 599.99, 100), 45 | ('Headphones', 79.99, 200); 46 | 47 | INSERT INTO Orders (user_id, order_date, total_amount) VALUES 48 | (1, '2023-04-01', 1079.98), 49 | (2, '2023-04-15', 599.99), 50 | (3, '2023-04-30', 159.98); 51 | 52 | INSERT INTO OrderItems (order_id, product_id, quantity, price_per_unit) VALUES 53 | (1, 1, 1, 999.99), 54 | (1, 3, 1, 79.99), 55 | (2, 2, 1, 599.99), 56 | (3, 3, 2, 79.99); 57 | """ 58 | 59 | 60 | logging.basicConfig(level=logging.INFO, 61 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', 62 | handlers=[logging.FileHandler('dagent_logs.log'), logging.StreamHandler()]) 63 | 64 | def get_user_input() -> str: 65 | """Get user input from the command line.""" 66 | user_input = input("Enter your command: ") 67 | return user_input 68 | 69 | def get_database_schema(prev_output: str) -> str: 70 | return example_schema 71 | 72 | def generate_sql(prev_output: str, database_schema: str) -> str: 73 | """Generate SQL from the user input.""" 74 | sql = call_llm(model='ollama_chat/llama3.1', api_base="http://localhost:11434", messages=[{"role": "user", "content": f"Generate SQL from the user input: {prev_output} and the following database schema: {database_schema}"}]) 75 | print('generated sql: ', sql) 76 | return sql 77 | 78 | def show_results(prev_output: str) -> str: 79 | print('results: ', prev_output) 80 | 81 | 82 | get_user_input_node = FunctionNode(func=get_user_input) 83 | get_database_schema_node = FunctionNode(func=get_database_schema) 84 | generate_sql_node = FunctionNode(func=generate_sql, user_params={"database_schema": example_schema}) 85 | show_results_node = FunctionNode(func=show_results) 86 | 87 | decision_node = DecisionNode(model='ollama_chat/llama3.1', api_base="http://localhost:11434") 88 | 89 | get_user_input_node.next_nodes = [decision_node] 90 | 
decision_node.next_nodes = [get_database_schema_node, generate_sql_node] 91 | get_database_schema_node.next_nodes = [generate_sql_node] 92 | generate_sql_node.next_nodes = [show_results_node] 93 | 94 | 95 | if __name__ == "__main__": 96 | get_user_input_node.compile(force_load=False) 97 | get_user_input_node.run() -------------------------------------------------------------------------------- /dagent/pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dagent" 3 | version = "0.0.9" 4 | description = "AI Agents as DAGs - Directed Acyclic Graphs" 5 | authors = [ 6 | { name = "Parth Sareen", email = "parth@extensible.dev" } 7 | ] 8 | dependencies = [ 9 | "litellm>=1.44.19", 10 | ] 11 | readme = "../README.md" 12 | requires-python = ">= 3.10" 13 | 14 | [build-system] 15 | requires = ["hatchling"] 16 | build-backend = "hatchling.build" 17 | 18 | [tool.rye] 19 | managed = true 20 | dev-dependencies = [] 21 | 22 | [tool.hatch.metadata] 23 | allow-direct-references = true 24 | 25 | [tool.hatch.build.targets.wheel] 26 | packages = ["src/dagent"] 27 | -------------------------------------------------------------------------------- /dagent/requirements-dev.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 
12 | aiohttp==3.9.5 13 | # via litellm 14 | aiosignal==1.3.1 15 | # via aiohttp 16 | annotated-types==0.7.0 17 | # via pydantic 18 | anyio==4.4.0 19 | # via httpx 20 | # via openai 21 | attrs==23.2.0 22 | # via aiohttp 23 | # via jsonschema 24 | # via referencing 25 | certifi==2024.6.2 26 | # via httpcore 27 | # via httpx 28 | # via requests 29 | charset-normalizer==3.3.2 30 | # via requests 31 | click==8.1.7 32 | # via litellm 33 | distro==1.9.0 34 | # via openai 35 | filelock==3.14.0 36 | # via huggingface-hub 37 | frozenlist==1.4.1 38 | # via aiohttp 39 | # via aiosignal 40 | fsspec==2024.6.0 41 | # via huggingface-hub 42 | h11==0.14.0 43 | # via httpcore 44 | httpcore==1.0.5 45 | # via httpx 46 | httpx==0.27.0 47 | # via openai 48 | huggingface-hub==0.23.3 49 | # via tokenizers 50 | idna==3.7 51 | # via anyio 52 | # via httpx 53 | # via requests 54 | # via yarl 55 | importlib-metadata==7.1.0 56 | # via litellm 57 | jinja2==3.1.4 58 | # via litellm 59 | jiter==0.5.0 60 | # via openai 61 | jsonschema==4.23.0 62 | # via litellm 63 | jsonschema-specifications==2023.12.1 64 | # via jsonschema 65 | litellm==1.44.19 66 | # via dagent 67 | markupsafe==2.1.5 68 | # via jinja2 69 | multidict==6.0.5 70 | # via aiohttp 71 | # via yarl 72 | openai==1.42.0 73 | # via litellm 74 | packaging==24.0 75 | # via huggingface-hub 76 | pydantic==2.7.3 77 | # via litellm 78 | # via openai 79 | pydantic-core==2.18.4 80 | # via pydantic 81 | python-dotenv==1.0.1 82 | # via litellm 83 | pyyaml==6.0.1 84 | # via huggingface-hub 85 | referencing==0.35.1 86 | # via jsonschema 87 | # via jsonschema-specifications 88 | regex==2024.5.15 89 | # via tiktoken 90 | requests==2.32.3 91 | # via huggingface-hub 92 | # via litellm 93 | # via tiktoken 94 | rpds-py==0.20.0 95 | # via jsonschema 96 | # via referencing 97 | sniffio==1.3.1 98 | # via anyio 99 | # via httpx 100 | # via openai 101 | tiktoken==0.7.0 102 | # via litellm 103 | tokenizers==0.19.1 104 | # via litellm 105 | tqdm==4.66.4 106 | # 
via huggingface-hub 107 | # via openai 108 | typing-extensions==4.12.2 109 | # via huggingface-hub 110 | # via openai 111 | # via pydantic 112 | # via pydantic-core 113 | urllib3==2.2.1 114 | # via requests 115 | yarl==1.9.4 116 | # via aiohttp 117 | zipp==3.19.2 118 | # via importlib-metadata 119 | -------------------------------------------------------------------------------- /dagent/requirements.lock: -------------------------------------------------------------------------------- 1 | # generated by rye 2 | # use `rye lock` or `rye sync` to update this lockfile 3 | # 4 | # last locked with the following flags: 5 | # pre: false 6 | # features: [] 7 | # all-features: false 8 | # with-sources: false 9 | # generate-hashes: false 10 | 11 | -e file:. 12 | aiohttp==3.9.5 13 | # via litellm 14 | aiosignal==1.3.1 15 | # via aiohttp 16 | annotated-types==0.7.0 17 | # via pydantic 18 | anyio==4.4.0 19 | # via httpx 20 | # via openai 21 | attrs==23.2.0 22 | # via aiohttp 23 | # via jsonschema 24 | # via referencing 25 | certifi==2024.6.2 26 | # via httpcore 27 | # via httpx 28 | # via requests 29 | charset-normalizer==3.3.2 30 | # via requests 31 | click==8.1.7 32 | # via litellm 33 | distro==1.9.0 34 | # via openai 35 | filelock==3.14.0 36 | # via huggingface-hub 37 | frozenlist==1.4.1 38 | # via aiohttp 39 | # via aiosignal 40 | fsspec==2024.6.0 41 | # via huggingface-hub 42 | h11==0.14.0 43 | # via httpcore 44 | httpcore==1.0.5 45 | # via httpx 46 | httpx==0.27.0 47 | # via openai 48 | huggingface-hub==0.23.3 49 | # via tokenizers 50 | idna==3.7 51 | # via anyio 52 | # via httpx 53 | # via requests 54 | # via yarl 55 | importlib-metadata==7.1.0 56 | # via litellm 57 | jinja2==3.1.4 58 | # via litellm 59 | jiter==0.5.0 60 | # via openai 61 | jsonschema==4.23.0 62 | # via litellm 63 | jsonschema-specifications==2023.12.1 64 | # via jsonschema 65 | litellm==1.44.19 66 | # via dagent 67 | markupsafe==2.1.5 68 | # via jinja2 69 | multidict==6.0.5 70 | # via aiohttp 71 | # 
via yarl 72 | openai==1.42.0 73 | # via litellm 74 | packaging==24.0 75 | # via huggingface-hub 76 | pydantic==2.7.3 77 | # via litellm 78 | # via openai 79 | pydantic-core==2.18.4 80 | # via pydantic 81 | python-dotenv==1.0.1 82 | # via litellm 83 | pyyaml==6.0.1 84 | # via huggingface-hub 85 | referencing==0.35.1 86 | # via jsonschema 87 | # via jsonschema-specifications 88 | regex==2024.5.15 89 | # via tiktoken 90 | requests==2.32.3 91 | # via huggingface-hub 92 | # via litellm 93 | # via tiktoken 94 | rpds-py==0.20.0 95 | # via jsonschema 96 | # via referencing 97 | sniffio==1.3.1 98 | # via anyio 99 | # via httpx 100 | # via openai 101 | tiktoken==0.7.0 102 | # via litellm 103 | tokenizers==0.19.1 104 | # via litellm 105 | tqdm==4.66.4 106 | # via huggingface-hub 107 | # via openai 108 | typing-extensions==4.12.2 109 | # via huggingface-hub 110 | # via openai 111 | # via pydantic 112 | # via pydantic-core 113 | urllib3==2.2.1 114 | # via requests 115 | yarl==1.9.4 116 | # via aiohttp 117 | zipp==3.19.2 118 | # via importlib-metadata 119 | -------------------------------------------------------------------------------- /dagent/src/dagent/DagNode.py: -------------------------------------------------------------------------------- 1 | 2 | class DagNode: 3 | def __init__( 4 | self, 5 | func: callable, 6 | next_nodes: dict[str, 'DagNode'] = None 7 | ): 8 | self.func = func 9 | self.next_nodes = next_nodes or {} 10 | self.node_result = None 11 | 12 | def compile(self): 13 | """ 14 | Use an LLM to generate tool descriptions from functions to run the DAG 15 | """ 16 | return NotImplemented 17 | 18 | def run(self, **kwargs) -> any: 19 | raise NotImplementedError("Subclasses should implement this method") -------------------------------------------------------------------------------- /dagent/src/dagent/DecisionNode.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import inspect 4 | from .DagNode 
import DagNode 5 | 6 | from .base_functions import call_llm_tool, create_tool_desc 7 | import logging 8 | 9 | 10 | class DecisionNode(DagNode): 11 | def __init__( 12 | self, 13 | func: callable = call_llm_tool, 14 | next_nodes: list | None = None, 15 | user_params: dict | None = None, 16 | model: str = 'gpt-4-0125-preview', 17 | api_base: str | None = None, 18 | tool_json_dir: str = 'Tool_JSON', 19 | retry_json_count: int = 3, 20 | max_tool_calls: int | None = None 21 | 22 | ): 23 | super().__init__(func, next_nodes) 24 | self.user_params = user_params or {} 25 | self.logger = logging.getLogger(__name__) 26 | self.compiled = False 27 | self.api_base = api_base 28 | self.model = model 29 | self.tool_json_dir = tool_json_dir 30 | self.retry_json_count = retry_json_count 31 | self.max_tool_calls = max_tool_calls 32 | self.logger.info(f"DecisionNode initialized with model: {model}, api_base: {api_base}, max_tool_calls: {max_tool_calls}") 33 | 34 | 35 | def compile(self, force_load=False) -> None: 36 | self.logger.info("Starting compilation process") 37 | self.compiled = True 38 | 39 | if isinstance(self.next_nodes, list): 40 | self.next_nodes = {node.func.__name__: node for node in self.next_nodes} 41 | self.logger.debug(f"Converted next_nodes list to dictionary: {self.next_nodes.keys()}") 42 | 43 | for _, next_node in self.next_nodes.items(): 44 | func_name = os.path.join(self.tool_json_dir, next_node.func.__name__ + '.json') 45 | self.logger.info(f"Compiling tool description for function: {next_node.func.__name__}") 46 | 47 | if force_load or not os.path.exists(func_name): 48 | self.logger.debug(f"Creating new tool description for {next_node.func.__name__}") 49 | os.makedirs(self.tool_json_dir, exist_ok=True) 50 | try: 51 | current_retry_count = 0 52 | tool_desc = create_tool_desc(model=self.model, function_desc=inspect.getsource(next_node.func), api_base=self.api_base) 53 | 54 | while not tool_desc and current_retry_count < self.retry_json_count: 55 | 
self.logger.warning(f"Retry {current_retry_count + 1} for creating tool description of {next_node.func.__name__}") 56 | tool_desc = create_tool_desc(model=self.model, function_desc=inspect.getsource(next_node.func), api_base=self.api_base) 57 | current_retry_count += 1 58 | 59 | if not tool_desc: 60 | error_msg = f"Tool description for {next_node.func.__name__} could not be generated, recommend generating manually and storing under {func_name}.json in {self.tool_json_dir} directory" 61 | self.logger.error(error_msg) 62 | raise ValueError(error_msg) 63 | 64 | tool_desc_json = json.loads(tool_desc) 65 | self.logger.debug(f"Successfully created tool description for {next_node.func.__name__}") 66 | except Exception as e: 67 | self.logger.error(f"Error creating tool description for {next_node.func.__name__}: {e}") 68 | raise e 69 | with open(func_name, 'w') as f: 70 | json.dump(tool_desc_json, f) 71 | self.logger.info(f"Saved tool description for {next_node.func.__name__} to {func_name}") 72 | else: 73 | self.logger.info(f"Loading existing tool description for {next_node.func.__name__} from {func_name}") 74 | with open(func_name, 'r') as f: 75 | tool_desc_json = json.load(f) 76 | 77 | next_node.tool_description = tool_desc_json 78 | next_node.compile() 79 | self.logger.info("Compilation process completed successfully for DecisionNode") 80 | 81 | 82 | def run(self, **kwargs) -> any: 83 | self.logger.info("Starting DecisionNode run") 84 | if not self.next_nodes: 85 | error_msg = "Next nodes not specified for LLM call" 86 | self.logger.error(error_msg) 87 | raise ValueError(error_msg) 88 | 89 | if not self.compiled: 90 | error_msg = "Node not compiled. 
Please run compile() method from the entry node first" 91 | self.logger.error(error_msg) 92 | raise ValueError(error_msg) 93 | 94 | if not kwargs.get('prev_output') and not kwargs.get('messages'): 95 | error_msg = "No input data provided for LLM call" 96 | self.logger.error(error_msg) 97 | raise ValueError(error_msg) 98 | 99 | # Get existing messages or create an empty list 100 | messages = self.user_params.get('messages', kwargs.get('messages', [])) 101 | # Add previous output as a user message if available 102 | if 'prev_output' in kwargs: 103 | messages.append({'role': 'user', 'content': str(kwargs['prev_output'])}) 104 | 105 | # Update kwargs with the final messages list 106 | kwargs['messages'] = messages 107 | self.logger.info(f"Prepared messages for LLM call: {messages}") 108 | 109 | try: 110 | self.logger.info(f"Calling LLM tool with model: {self.model}") 111 | # The 'messages' param is passed in through the kwargs 112 | response = call_llm_tool(model=self.model, tools=[node.tool_description for node in self.next_nodes.values()], api_base=self.api_base, **kwargs) 113 | tool_calls = getattr(response, 'tool_calls', None) 114 | if not tool_calls: 115 | error_msg = "No tool calls received from LLM tool response" 116 | self.logger.error(error_msg) 117 | raise ValueError(error_msg) 118 | 119 | self.logger.info(f"Received {len(tool_calls)} tool call(s) from LLM") 120 | self.logger.info(f"Tool calls: {tool_calls}") 121 | if self.logger.getEffectiveLevel() == logging.DEBUG: 122 | proceed = input("Debug mode is active. Do you want to proceed with the tool calls? 
(yes(y)/no(n)): ") 123 | print('proceed: ',proceed) 124 | if proceed.lower() not in ['yes', 'y']: 125 | self.logger.debug("\nUser chose not to proceed with the tool calls.") 126 | return 127 | 128 | # Apply max_tool_calls constraint if set 129 | if self.max_tool_calls is not None: 130 | tool_calls = tool_calls[:self.max_tool_calls] 131 | self.logger.info(f"Constrained to {len(tool_calls)} tool call(s) due to max_tool_calls setting") 132 | 133 | for tool_call in tool_calls: 134 | function_name = tool_call.function.name 135 | function_args = json.loads(tool_call.function.arguments) 136 | self.logger.debug(f"Processing tool call for function: {function_name} with arguments: {function_args}") 137 | 138 | next_node = self.next_nodes.get(function_name) 139 | if not next_node: 140 | error_msg = f"Function name '{function_name}' not found in next_nodes. Something went wrong" 141 | self.logger.error(error_msg) 142 | raise KeyError(error_msg) 143 | 144 | # Merge user_params with function_args, giving precedence to user_params 145 | merged_args = {**function_args, **self.user_params} 146 | # Print kwargs for debugging 147 | self.logger.info(f"Current kwargs: {kwargs}") 148 | print(f"Current kwargs: {kwargs}") 149 | 150 | if 'prev_output' in kwargs: 151 | merged_args['prev_output'] = kwargs['prev_output'] 152 | 153 | func_signature = inspect.signature(next_node.func) 154 | 155 | # TODO: Manage through derived data models 156 | filtered_args = {k: v for k, v in merged_args.items() if k in func_signature.parameters} 157 | self.logger.info(f"Filtered arguments for {function_name}: {filtered_args}") 158 | 159 | self.logger.info(f"Executing next node: {function_name}") 160 | next_node.run(**filtered_args) 161 | 162 | except (AttributeError, json.JSONDecodeError) as e: 163 | error_msg = f"Error parsing tool call: {e}" 164 | self.logger.error(error_msg) 165 | raise ValueError(error_msg) 166 | except Exception as e: 167 | error_msg = f"LLM tool call failed: {e}" 168 | 
self.logger.error(error_msg) 169 | raise RuntimeError(error_msg) 170 | self.logger.info("DecisionNode run completed successfully") 171 | -------------------------------------------------------------------------------- /dagent/src/dagent/FunctionNode.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from .DagNode import DagNode 3 | 4 | class FunctionNode(DagNode): 5 | def __init__(self, func: callable, tool_description = dict | None, next_nodes: list | None = None, user_params: dict | None = None): 6 | super().__init__(func, next_nodes) 7 | self.tool_description = tool_description 8 | self.user_params = user_params or {} 9 | self.compiled = False 10 | self.node_result = None 11 | self.logger = logging.getLogger(__name__) 12 | 13 | def compile(self, force_load=False) -> None: 14 | self.logger.info(f"Compiling FunctionNode for function: {self.func.__name__}") 15 | self.compiled = True 16 | if isinstance(self.next_nodes, list): 17 | self.logger.debug("Converting next_nodes from list to dictionary") 18 | self.next_nodes = {node.func.__name__: node for node in self.next_nodes} 19 | for node_name, next_node in self.next_nodes.items(): 20 | self.logger.debug(f"Compiling next node: {node_name}") 21 | next_node.compile(force_load=force_load) 22 | 23 | def run(self, **kwargs) -> any: 24 | if not self.compiled: 25 | self.logger.error("Attempted to run uncompiled node") 26 | raise ValueError("Node not compiled. 
Please run compile() method from the entry node first") 27 | 28 | self.logger.info(f"Running FunctionNode for function: {self.func.__name__}") 29 | merged_params = {**self.user_params, **kwargs} 30 | self.logger.debug(f"Merged parameters: {merged_params}") 31 | 32 | try: 33 | self.node_result = self.func(**merged_params) 34 | self.logger.debug(f"Function result: {self.node_result}") 35 | except Exception as e: 36 | self.logger.error(f"Error executing function {self.func.__name__}: {str(e)}") 37 | raise 38 | 39 | # Pass the result to the next nodes if any 40 | # TODO: figure out param logic pattern 41 | if not self.next_nodes: 42 | self.logger.info(f"No next nodes after {self.func.__name__}, returning result") 43 | return self.node_result 44 | for node_name, next_node in self.next_nodes.items(): 45 | self.logger.info(f"Passing result to next node: {node_name}") 46 | # TODO: creating data models for passing info between nodes 47 | params = {'prev_output': self.node_result, **next_node.user_params} 48 | self.logger.debug(f"Parameters for next node: {params}") 49 | next_node.run(**params) 50 | -------------------------------------------------------------------------------- /dagent/src/dagent/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_functions import call_llm_tool, call_llm 2 | from .DecisionNode import DecisionNode 3 | from .FunctionNode import FunctionNode -------------------------------------------------------------------------------- /dagent/src/dagent/base_functions.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from litellm import completion 3 | # For more info: https://litellm.vercel.app/docs/completion/input 4 | 5 | 6 | def call_llm_tool(model, messages, tools, api_base=None, **kwargs): 7 | response = completion( 8 | model=model, 9 | messages=messages, 10 | tools=tools, 11 | api_base=api_base, 12 | ) 13 | return 
response.choices[0].message 14 | 15 | 16 | def create_tool_desc(model, function_desc, api_base=None): 17 | example = { 18 | "type": "function", 19 | "function": { 20 | "name": "get_calendar_events", 21 | "description": "Get calendar events within a specified time range", 22 | "parameters": { 23 | "type": "object", 24 | "properties": { 25 | "start_time": { 26 | "type": "string", 27 | "description": "The start time for the event search, in ISO format", 28 | }, 29 | "end_time": { 30 | "type": "string", 31 | "description": "The end time for the event search, in ISO format", 32 | }, 33 | }, 34 | "required": ["start_time", "end_time"], 35 | }, 36 | } 37 | } 38 | messages = [{"role": "user", "content": "Create a json for the attached function: {} using the following pattern for the json: {}. Don't add anything extra. Make sure everything follows a valid json format".format(function_desc, example)}] 39 | response = completion( 40 | model=model, 41 | response_format={"type":"json_object"}, 42 | messages=messages, 43 | api_base=api_base 44 | ) 45 | return response.choices[0].message.content 46 | 47 | 48 | def call_llm(model, messages, api_base=None, **kwargs): 49 | response = completion( 50 | model=model, 51 | messages=messages, 52 | api_base=api_base, 53 | **kwargs 54 | ) 55 | return response.choices[0].message.content 56 | 57 | -------------------------------------------------------------------------------- /dagent/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Extensible-AI/DAGent/8dd76ffe7c1bdd57ebfcf67a859f25d196632233/dagent/tests/__init__.py -------------------------------------------------------------------------------- /dagent/tests/test_DecisionNode.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import shutil 4 | from unittest.mock import patch, MagicMock 5 | from dagent.DecisionNode import DecisionNode 6 
| from dagent.FunctionNode import FunctionNode 7 | 8 | class TestDecisionNode(unittest.TestCase): 9 | def setUp(self): 10 | self.test_dir = 'test_tool_json' 11 | os.makedirs(self.test_dir, exist_ok=True) 12 | 13 | def tearDown(self): 14 | if os.path.exists(self.test_dir): 15 | shutil.rmtree(self.test_dir) 16 | 17 | @patch('dagent.DecisionNode.call_llm_tool') 18 | @patch('dagent.DecisionNode.create_tool_desc') 19 | def test_decision_node(self, mock_create_tool_desc, mock_call_llm_tool): 20 | # Mock function for testing 21 | def test_function(arg1, arg2): 22 | return f"Result: {arg1}, {arg2}" 23 | 24 | # Mock create_tool_desc to return a valid tool description 25 | mock_create_tool_desc.return_value = '{"type": "function", "function": {"name": "test_function", "parameters": {"type": "object", "properties": {"arg1": {"type": "string"}, "arg2": {"type": "string"}}, "required": ["arg1", "arg2"]}}}' 26 | 27 | # Mock call_llm_tool to return a valid response 28 | mock_response = MagicMock() 29 | mock_tool_call = MagicMock() 30 | mock_tool_call.function = MagicMock() 31 | mock_tool_call.function.name = 'test_function' 32 | mock_tool_call.function.arguments = '{"arg1": "hello", "arg2": "world"}' 33 | mock_response.tool_calls = [mock_tool_call] 34 | mock_call_llm_tool.return_value = mock_response 35 | 36 | # Create DecisionNode with FunctionNode as next node 37 | function_node = FunctionNode(func=test_function) 38 | decision_node = DecisionNode(next_nodes={'test_function': function_node}) 39 | 40 | # Compile the decision node 41 | decision_node.compile(tool_json_dir=self.test_dir) 42 | 43 | # Check if the tool description file was created 44 | self.assertTrue(os.path.exists(os.path.join(self.test_dir, 'test_function.json'))) 45 | 46 | # Run the decision node 47 | with patch.object(function_node, 'run') as mock_function_run: 48 | decision_node.run(messages=[{'role': 'user', 'content': 'Test message'}]) 49 | 50 | # Assert that the function node's run method was called with 
correct arguments 51 | mock_function_run.assert_called_once_with(arg1='hello', arg2='world') 52 | 53 | # Assert that create_tool_desc and call_llm_tool were called 54 | mock_create_tool_desc.assert_called_once() 55 | mock_call_llm_tool.assert_called_once() 56 | 57 | if __name__ == '__main__': 58 | unittest.main() -------------------------------------------------------------------------------- /dagent/tests/test_FunctionNode.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import MagicMock 3 | from dagent.FunctionNode import FunctionNode 4 | 5 | class TestFunctionNode(unittest.TestCase): 6 | def setUp(self): 7 | def test_func(a, b): 8 | return a + b 9 | self.function_node = FunctionNode(func=test_func) 10 | self.function_node.tool_description = None 11 | 12 | 13 | def test_init(self): 14 | self.assertIsInstance(self.function_node, FunctionNode) 15 | self.assertEqual(self.function_node.func.__name__, 'test_func') 16 | # Check if tool_description is either None or an empty dict 17 | self.assertTrue(self.function_node.tool_description is None or self.function_node.tool_description == {}) 18 | self.assertEqual(self.function_node.user_params, {}) 19 | self.assertFalse(self.function_node.compiled) 20 | self.assertIsNone(self.function_node.node_result) 21 | 22 | def test_compile(self): 23 | def identity(x): 24 | return x 25 | next_node = FunctionNode(func=identity) 26 | self.function_node.next_nodes = [next_node] 27 | self.function_node.compile() 28 | self.assertTrue(self.function_node.compiled) 29 | self.assertIsInstance(self.function_node.next_nodes, dict) 30 | # Changed assertion for next_nodes key 31 | self.assertIn('identity', self.function_node.next_nodes) 32 | 33 | def test_compile(self): 34 | def identity(x): 35 | return x 36 | next_node = FunctionNode(func=identity) 37 | self.function_node.next_nodes = [next_node] 38 | self.function_node.compile() 39 | 
self.assertTrue(self.function_node.compiled) 40 | self.assertIsInstance(self.function_node.next_nodes, dict) 41 | self.assertIn('identity', self.function_node.next_nodes) 42 | 43 | def test_run_without_compile(self): 44 | with self.assertRaises(ValueError): 45 | self.function_node.run(a=1, b=2) 46 | 47 | def test_run_with_compile(self): 48 | self.function_node.compile() 49 | result = self.function_node.run(a=1, b=2) 50 | self.assertEqual(result, 3) 51 | self.assertEqual(self.function_node.node_result, 3) 52 | 53 | def test_run_with_user_params(self): 54 | self.function_node.user_params = {'b': 5} 55 | self.function_node.compile() 56 | result = self.function_node.run(a=1) 57 | self.assertEqual(result, 6) 58 | 59 | def test_run_with_next_nodes(self): 60 | def double(x): 61 | return x * 2 62 | next_node = FunctionNode(func=double) 63 | self.function_node.next_nodes = [next_node] 64 | self.function_node.compile() 65 | 66 | next_node.run = MagicMock() 67 | self.function_node.run(a=1, b=2) 68 | 69 | next_node.run.assert_called_once_with(prev_output=3) 70 | 71 | if __name__ == '__main__': 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /dagent/tests/test_base_functions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import inspect 3 | import argparse 4 | from dagent.base_functions import * 5 | 6 | def add_two_nums(a: int, b: int) -> int: 7 | return a + b 8 | 9 | def run_llm(model, api_base=None): 10 | # Run `call_llm` 11 | output = call_llm(model, [{'role': 'user', 'content': 'add the numbers 2 and 3'}], api_base=api_base) 12 | print(f'{model} output:', output) 13 | 14 | # Create tool description for `add_two_nums` function 15 | desc = create_tool_desc(model=model, function_desc=inspect.getsource(add_two_nums), api_base=api_base) 16 | print(f'{model} tool desc:', desc, end='\n\n') 17 | 18 | tool_desc_json = json.loads(desc) 19 | 20 | # Run `call_llm_tool` 21 
| output = call_llm_tool(model, [{'role': 'user', 'content': 'add the numbers 2 and 3 using the provided tool'}], tools=[tool_desc_json], api_base=api_base) 22 | 23 | tool_calls = getattr(output, 'tool_calls', None) 24 | if not tool_calls: 25 | raise ValueError("No tool calls received from LLM tool response") 26 | 27 | function_name = tool_calls[0].function.name 28 | print(f'{model} output func name:', function_name, end='\n\n') 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser(description="Run LLM models") 32 | parser.add_argument('model', choices=['groq', 'ollama', 'gpt4'], help="Select the model to run") 33 | args = parser.parse_args() 34 | 35 | if args.model == 'groq': 36 | run_llm('groq/llama3-70b-8192') 37 | elif args.model == 'ollama': 38 | run_llm('ollama_chat/llama3.1', api_base="http://localhost:11434") 39 | elif args.model == 'gpt4': 40 | run_llm('gpt-4-0125-preview') 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /dagent/tests/test_linking.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch, MagicMock 3 | from dagent import DecisionNode, FunctionNode 4 | 5 | def add_two_nums(a: int, b: int) -> int: 6 | return a + b 7 | 8 | def multiply_two_nums(a: int, b: int) -> int: 9 | return a * b 10 | 11 | class TestLinking(unittest.TestCase): 12 | def setUp(self): 13 | self.add_node = FunctionNode(func=add_two_nums) 14 | self.multiply_node = FunctionNode(func=multiply_two_nums) 15 | self.decision_node = DecisionNode(model='gpt-4-0125-preview', api_base=None) 16 | 17 | self.decision_node.next_nodes = { 18 | 'add_two_nums': self.add_node, 19 | 'multiply_two_nums': self.multiply_node 20 | } 21 | 22 | @patch('dagent.DecisionNode.call_llm_tool') 23 | def test_decision_node_linking(self, mock_call_llm_tool): 24 | # Mock the LLM response to simulate choosing the add_two_nums function 25 
| mock_response = MagicMock() 26 | mock_function = MagicMock() 27 | mock_function.name = 'add_two_nums' 28 | mock_function.arguments = '{"a": 2, "b": 3}' 29 | mock_response.tool_calls = [MagicMock(function=mock_function)] 30 | mock_call_llm_tool.return_value = mock_response 31 | 32 | # Compile the decision node 33 | self.decision_node.compile() 34 | 35 | # Run the decision node 36 | with patch.object(self.add_node, 'run') as mock_add_run: 37 | self.decision_node.run(messages=[{'role': 'user', 'content': 'Add 2 and 3'}]) 38 | 39 | # Assert that the add_two_nums function was called 40 | mock_add_run.assert_called_once_with(a=2, b=3) 41 | 42 | @patch('dagent.DecisionNode.call_llm_tool') 43 | def test_decision_node_linking_multiply(self, mock_call_llm_tool): 44 | # Mock the LLM response to simulate choosing the multiply_two_nums function 45 | mock_response = MagicMock() 46 | mock_function = MagicMock() 47 | mock_function.name = 'multiply_two_nums' 48 | mock_function.arguments = '{"a": 4, "b": 5}' 49 | mock_response.tool_calls = [MagicMock(function=mock_function)] 50 | mock_call_llm_tool.return_value = mock_response 51 | 52 | # Compile the decision node 53 | self.decision_node.compile() 54 | 55 | # Run the decision node 56 | with patch.object(self.multiply_node, 'run') as mock_multiply_run: 57 | self.decision_node.run(messages=[{'role': 'user', 'content': 'Multiply 4 and 5'}]) 58 | 59 | # Assert that the multiply_two_nums function was called 60 | mock_multiply_run.assert_called_once_with(a=4, b=5) 61 | 62 | if __name__ == '__main__': 63 | unittest.main() -------------------------------------------------------------------------------- /notes.md: -------------------------------------------------------------------------------- 1 | ## Notes 2 | - [ ] Look if things run in memory and how to isolate for large workflows -> e.g. funcA(funcB(...)) -> funcA(...) -> funcB(...) 
3 | - [ ] Side effects/mutations 4 | - [ ] Creating a data model for communication between functions + schema validation -> autogenerate? 5 | - [ ] Logging 6 | - [ ] Alerting on error 7 | - [x] Add a compile method to derive data models and tool descriptions 8 | - [ ] Docker 9 | - [ ] simple chat memory 10 | --------------------------------------------------------------------------------