20 |
21 |
22 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Logging and temporary files
2 | logs/
3 | tmp/
4 | wandb/
5 |
6 | # Data and outputs
7 | data/
8 | outputs/
9 |
10 | # macOS-specific files
11 | .DS_Store
12 |
13 | # IDE and editor files
14 | .vscode/
15 |
16 | # Python-specific ignores
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # Compiled files
22 | *.so
23 |
24 | # Build and packaging artifacts
25 | .Python
26 | build/
27 | dist/
28 | *.egg-info/
29 | .wheels/
30 | MANIFEST
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Test and coverage
37 | .tox/
38 | .nox/
39 | .coverage*
40 | .pytest_cache/
41 |
42 | # Jupyter Notebook
43 | .ipynb_checkpoints/
44 |
45 | # Environment files
46 | .env
47 | .venv/
48 | ENV/
49 |
50 | # Documentation builds
51 | docs/_build/
52 |
53 | # Type checkers
54 | .mypy_cache/
55 | .pyre/
56 | .pytype/
57 |
58 | # PyCharm
59 | # .idea/
60 |
61 | # Archive and temporary directories
62 | archive/
63 | savedir/
64 | output/
65 | tool_output/
66 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 |
5 | import unittest
6 | from prime import models, tool
7 | from typing import Optional
8 |
9 |
class ModelTests(unittest.TestCase):
    """Tests for the JSON-schema generation helpers in `prime.models`."""

    def test_get_json_schema_has_nullable_args(self):
        # An Optional[...] parameter must be advertised as nullable in the generated schema.
        @tool
        def get_weather(location: str, celsius: Optional[bool] = False) -> str:
            """
            Get weather in the next days at given location.
            Secretly this tool does not care about the location, it hates the weather everywhere.

            Args:
                location: the location
                celsius: the temperature type
            """
            return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

        schema = models.get_json_schema(get_weather)
        celsius_schema = schema["function"]["parameters"]["properties"]["celsius"]
        self.assertIn("nullable", celsius_schema)
30 |
--------------------------------------------------------------------------------
/examples/tool_calling_agent_from_any_llm.py:
--------------------------------------------------------------------------------
1 | from prime.agents import ToolCallingAgent
2 | from prime import tool, HfApiModel, TransformersModel, LiteLLMModel
3 | from typing import Optional
4 |
# Choose which LLM engine to use!
# model = HfApiModel(model_id="meta-llama/Llama-3.3-70B-Instruct")
# model = TransformersModel(model_id="meta-llama/Llama-3.2-3B-Instruct")

# For anthropic: change model_id below to 'anthropic/claude-3-5-sonnet-20240620'
model = LiteLLMModel(model_id="gpt-4o")
11 |
@tool
def get_weather(location: str, celsius: Optional[bool] = False) -> str:
    """
    Get weather in the next days at given location.
    Secretly this tool does not care about the location, it hates the weather everywhere.

    Args:
        location: the location
        celsius: the temperature
    """
    # Both inputs are deliberately ignored: this demo tool always reports the same forecast.
    forecast = "The weather is UNGODLY with torrential rains and temperatures below -10°C"
    return forecast
23 |
agent = ToolCallingAgent(tools=[get_weather], model=model)

# Run a single task and show the agent's final answer.
answer = agent.run("What's the weather like in Paris?")
print(answer)
--------------------------------------------------------------------------------
/tests/test_search.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 |
18 | from prime import load_tool
19 |
20 | from .test_tools import ToolTesterMixin
21 |
22 |
class DuckDuckGoSearchToolTester(unittest.TestCase, ToolTesterMixin):
    """Smoke tests for the default "web_search" tool.

    NOTE(review): presumably performs a live web request — may be flaky offline.
    """

    def setUp(self):
        # Resolve the tool by its registered name, then run its one-time setup.
        self.tool = load_tool("web_search")
        self.tool.setup()

    def test_exact_match_arg(self):
        search_result = self.tool("Agents")
        assert isinstance(search_result, str)
31 |
--------------------------------------------------------------------------------
/docs/source/en/_toctree.yml:
--------------------------------------------------------------------------------
1 | - title: Get started
2 | sections:
3 | - local: index
4 | title: 🤗 Agents
5 | - local: guided_tour
6 | title: Guided tour
7 | - title: Tutorials
8 | sections:
9 | - local: tutorials/building_good_agents
10 | title: ✨ Building good agents
11 | - local: tutorials/tools
12 | title: 🛠️ Tools - in-depth guide
13 | - local: tutorials/secure_code_execution
14 | title: 🛡️ Secure your code execution with E2B
15 | - title: Conceptual guides
16 | sections:
17 | - local: conceptual_guides/intro_agents
18 | title: 🤖 An introduction to agentic systems
19 | - local: conceptual_guides/react
20 | title: 🤔 How do Multi-step agents work?
21 | - title: Examples
22 | sections:
23 | - local: examples/text_to_sql
24 | title: Self-correcting Text-to-SQL
25 | - local: examples/rag
26 | title: Master your knowledge base with agentic RAG
27 | - local: examples/multiagents
28 | title: Orchestrate a multi-agent system
29 | - title: Reference
30 | sections:
31 | - local: reference/agents
32 | title: Agent-related objects
33 | - local: reference/tools
34 | title: Tool-related objects
35 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "prime-agents-py"
7 | version = "0.2.0"
8 | description = "Prime is a powerful, modular, and highly extensible framework designed for developers building applications that require a seamless blend of performance, flexibility, and scalability. With Prime, you can harness the power of modern tools to craft cutting-edge solutions across industries."
9 | authors = [
10 | { name="primengine inc", email="support@primengine.ai" }, { name="Prime Inc"},
11 | ]
12 | readme = "README.md"
13 | requires-python = ">=3.10"
14 | dependencies = [
15 | "torch",
16 | "torchaudio",
17 | "torchvision",
18 | "transformers>=4.0.0",
19 | "requests>=2.32.3",
20 | "rich>=13.9.4",
21 | "pandas>=2.2.3",
22 | "jinja2>=3.1.4",
23 | "pillow>=11.0.0",
24 | "markdownify>=0.14.1",
25 | "gradio>=5.8.0",
26 | "duckduckgo-search>=6.3.7",
27 | "python-dotenv>=1.0.1",
28 | "e2b-code-interpreter>=1.0.3",
29 | "litellm>=1.55.10",
30 | ]
31 |
32 | [project.optional-dependencies]
33 | test = [
34 | "pytest>=8.1.0",
35 | "sqlalchemy"
36 | ]
37 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | name: Upload Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 |
7 | permissions:
8 | contents: read
9 |
10 | jobs:
11 | release-build:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v4
16 |
17 | - uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.x"
20 |
21 | - name: Build release distributions
22 | run: |
23 | python -m pip install build
24 | python -m build
25 | - name: Upload distributions
26 | uses: actions/upload-artifact@v4
27 | with:
28 | name: release-dists
29 | path: dist/
30 |
31 | pypi-publish:
32 | runs-on: ubuntu-latest
33 | needs:
34 | - release-build
35 |
36 | steps:
37 | - name: Retrieve release distributions
38 | uses: actions/download-artifact@v4
39 | with:
40 | name: release-dists
41 | path: dist/
42 |
43 | - name: Publish release distributions to PyPI
44 | uses: pypa/gh-action-pypi-publish@release/v1
45 | with:
46 | packages-dir: dist/
47 | password: ${{ secrets.PYPI_API_TOKEN }}
48 |
--------------------------------------------------------------------------------
/server.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import sys
3 | import traceback
4 | import io
5 |
# Persistent namespace shared by every execute_code() call, so variables,
# functions and imports created by one request stay visible to later ones.
exec_globals = {}
# Retained for backward compatibility with any code that references it;
# execute_code() no longer passes a separate locals mapping to exec().
exec_locals = {}


def execute_code(code):
    """Execute ``code`` in the shared namespace and return everything it printed.

    stdout and stderr are captured while the code runs. If the code raises an
    ``Exception``, its traceback is written to the captured stderr instead of
    propagating. The streams that were active before the call are always
    restored, even if something other than an ``Exception`` escapes.

    Args:
        code: Python source text to execute.

    Returns:
        The captured stdout followed by the captured stderr (including any traceback).
    """
    # SECURITY NOTE: this runs arbitrary source supplied by the caller. Expose it
    # only to trusted clients or inside an isolated sandbox.
    stdout = io.StringIO()
    stderr = io.StringIO()
    # Remember the current streams (not sys.__stdout__/__stderr__) so that an
    # outer redirection, e.g. a test runner's capture, is put back exactly.
    previous_stdout, previous_stderr = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = stdout, stderr
    try:
        # One namespace on purpose: with separate globals/locals dicts, top-level
        # names land in locals, and functions defined by the same code (which
        # resolve free names through globals) could not see them.
        exec(code, exec_globals)
    except Exception:
        traceback.print_exc(file=stderr)
    finally:
        # Restore even on BaseException (e.g. SystemExit from exit()).
        sys.stdout, sys.stderr = previous_stdout, previous_stderr

    return stdout.getvalue() + stderr.getvalue()
28 |
def start_server(host='0.0.0.0', port=65432):
    """Serve execute_code() over TCP, one request per connection.

    Each client sends Python source as UTF-8 text. The server executes it and
    replies with the captured output, then closes the connection. The loop runs
    until the process is stopped.

    Args:
        host: Interface to bind. The default ``'0.0.0.0'`` listens on all interfaces.
        port: TCP port to listen on.
    """
    # SECURITY NOTE: any client that can reach host:port can execute arbitrary
    # code, with no authentication. Run this only inside an isolated sandbox or
    # container, or bind to 127.0.0.1.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, port))
        s.listen()
        print(f"Server listening on {host}:{port}")
        while True:
            conn, addr = s.accept()
            with conn:
                print(f"Connected by {addr}")
                try:
                    # NOTE(review): a single recv() returns at most 1024 bytes, so
                    # longer programs are truncated. Fixing this needs an agreed
                    # framing (length prefix or client-side shutdown) — confirm
                    # with the client before changing the protocol.
                    data = conn.recv(1024)
                    if not data:
                        # A client that connects and closes without sending
                        # anything (e.g. a health check) must not stop the server.
                        continue
                    # Decode leniently: a multi-byte character split at the
                    # 1024-byte boundary would otherwise raise and kill the server.
                    code = data.decode('utf-8', errors='replace')
                    output = execute_code(code)
                    conn.sendall(output.encode('utf-8', errors='replace'))
                except OSError as exc:
                    # One misbehaving client (reset, broken pipe) must not take
                    # the whole server down.
                    print(f"Connection with {addr} failed: {exc}")


if __name__ == "__main__":
    start_server()
--------------------------------------------------------------------------------
/examples/e2b_example.py:
--------------------------------------------------------------------------------
1 | from prime import Tool, CodeAgent, HfApiModel
2 | from prime.default_tools import VisitWebpageTool
3 | from dotenv import load_dotenv
4 |
5 |
6 | load_dotenv()
7 |
8 |
class GetCatImageTool(Tool):
    """Tool that downloads the image at `self.url` and returns it as a PIL image.

    NOTE(review): despite the tool's name, the URL points to a robot-face emoji
    PNG, not a cat picture.
    """

    name = "get_cat_image"
    description = "Get a cat image"
    inputs = {}
    output_type = "image"

    def __init__(self):
        super().__init__()
        self.url = "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png"

    def forward(self):
        # Imports stay local so the method is self-contained (presumably needed
        # when the tool is shipped to a remote executor such as E2B — confirm).
        from PIL import Image
        import requests
        from io import BytesIO

        # The timeout prevents a stalled download from hanging the agent forever.
        # raise_for_status() reports HTTP errors clearly instead of letting
        # Image.open() choke on an HTML error page.
        response = requests.get(self.url, timeout=30)
        response.raise_for_status()

        return Image.open(BytesIO(response.content))
27 |
28 |
get_cat_image = GetCatImageTool()

agent = CodeAgent(
    tools=[get_cat_image, VisitWebpageTool()],
    model=HfApiModel(),
    # Additional imports the agent's generated code is authorized to use.
    additional_authorized_imports=["Pillow", "requests", "markdownify"],
    use_e2b_executor=True,
)

# Asking the agent to return the image straight from its state checks that
# additional_args are properly sent to the (remote) executor.
cat_image = get_cat_image()
agent.run(
    "Return me an image of a cat. Directly use the image provided in your state.",
    additional_args={"cat_image": cat_image},
)

# Try the agent in a Gradio UI
from prime import GradioUI

GradioUI(agent).launch()
--------------------------------------------------------------------------------
/src/prime/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
# NOTE(review): this differs from `version = "0.2.0"` in pyproject.toml and from
# the 1.0.0 mentioned in the README — confirm which version is authoritative.
__version__ = "1.1.0.dev0"

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule
from transformers.utils.import_utils import define_import_structure


if TYPE_CHECKING:
    # Only static type checkers / IDEs take this branch: the eager star-imports
    # let them resolve the public names without paying the import cost at runtime.
    from .agents import *
    from .default_tools import *
    from .gradio_ui import *
    from .models import *
    from .local_python_executor import *
    from .e2b_executor import *
    from .monitoring import *
    from .prompts import *
    from .tools import *
    from .types import *
    from .utils import *


else:
    import sys

    # At runtime this package module is replaced in sys.modules by a transformers
    # _LazyModule, which imports submodules only when their names are accessed.
    _file = globals()["__file__"]
    # Map of submodules to the names they export, built from the package files
    # (presumably from each submodule's `__all__`, as monitoring.py defines — verify
    # against transformers' define_import_structure).
    import_structure = define_import_structure(_file)
    # Register __version__ under the package's own (empty-named) entry; it is also
    # passed via extra_objects below so it is available without any submodule import.
    import_structure[""] = {"__version__": __version__}
    sys.modules[__name__] = _LazyModule(
        __name__,
        _file,
        import_structure,
        module_spec=__spec__,
        extra_objects={"__version__": __version__},
    )
52 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
.PHONY: quality style extra_quality_checks test test_core test_cli test_big_modeling test_examples test_prod test_rest

check_dirs := .

# Quality checks
# -----------------------------
# Ensure the source code meets quality standards
# The package lives under src/ (see src/prime/__init__.py), so doc-builder is pointed at src/prime.

extra_quality_checks:
	python utils/check_copies.py
	python utils/check_dummies.py
	python utils/check_repo.py
	doc-builder style src/prime docs/source --max_len 119

quality:
	ruff check $(check_dirs)
	ruff format --check $(check_dirs)
	doc-builder style src/prime docs/source --max_len 119 --check_only

style:
	ruff check $(check_dirs) --fix
	ruff format $(check_dirs)
	doc-builder style src/prime docs/source --max_len 119

# Testing
# -----------------------------
# Run specific tests or all tests for the library
# When IS_GITHUB_CI is set, each target also writes a pytest report log.

test_big_modeling:
	python -m pytest -s -v ./tests/test_big_modeling.py ./tests/test_modeling_utils.py \
	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_big_modeling.log",)

test_core:
	python -m pytest -s -v ./tests/ --ignore=./tests/test_examples.py \
	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_core.log",)

test_cli:
	python -m pytest -s -v ./tests/test_cli.py \
	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_cli.log",)

test_examples:
	python -m pytest -s -v ./tests/test_examples.py \
	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_examples.log",)

# Run the main suites. The previous test_deepspeed / test_fsdp steps were removed:
# no such targets are defined in this Makefile, so `make test` always failed on them.
# NOTE(review): test_cli and test_big_modeling point at tests/test_cli.py,
# tests/test_big_modeling.py and tests/test_modeling_utils.py — confirm these files exist.
test:
	$(MAKE) test_core
	$(MAKE) test_cli
	$(MAKE) test_big_modeling

test_prod:
	$(MAKE) test_core

test_rest:
	python -m pytest -s -v ./tests/test_examples.py::FeatureExamplesTests \
	$(if $(IS_GITHUB_CI),--report-log "$(PYTORCH_VERSION)_rest.log",)
58 |
--------------------------------------------------------------------------------
/src/prime/monitoring.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | from .utils import console
18 | from rich.text import Text
19 |
20 |
class Monitor:
    """Track per-step durations and cumulative token usage of a model.

    Token totals are accumulated only when ``tracked_model`` exposes a non-None
    ``last_input_token_count`` at the moment a step is recorded.
    """

    def __init__(self, tracked_model):
        self.step_durations = []
        self.tracked_model = tracked_model
        # Always initialise the totals, so get_total_token_counts() and
        # update_metrics() also work for models that expose token counts only
        # after their first call, or never.
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def get_total_token_counts(self):
        """Return the cumulative token counts as ``{"input": int, "output": int}``."""
        return {
            "input": self.total_input_token_count,
            "output": self.total_output_token_count,
        }

    def reset(self):
        """Forget all recorded step durations and token totals."""
        self.step_durations = []
        self.total_input_token_count = 0
        self.total_output_token_count = 0

    def update_metrics(self, step_log):
        """Record one step and print a dimmed one-line summary to the console.

        Args:
            step_log: object with a numeric ``duration`` attribute (seconds).
        """
        step_duration = step_log.duration
        self.step_durations.append(step_duration)
        console_outputs = (
            f"[Step {len(self.step_durations) - 1}: Duration {step_duration:.2f} seconds"
        )

        input_tokens = getattr(self.tracked_model, "last_input_token_count", None)
        if input_tokens is not None:
            # Treat a missing/None output count as 0 rather than crashing on `+= None`.
            output_tokens = (
                getattr(self.tracked_model, "last_output_token_count", None) or 0
            )
            self.total_input_token_count += input_tokens
            self.total_output_token_count += output_tokens
            console_outputs += f" | Input tokens: {self.total_input_token_count:,} | Output tokens: {self.total_output_token_count:,}"
        console_outputs += "]"
        console.print(Text(console_outputs, style="dim"))


__all__ = ["Monitor"]
59 |
--------------------------------------------------------------------------------
/tests/test_final_answer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 | from pathlib import Path
18 |
19 | import numpy as np
20 | from PIL import Image
21 |
22 | from transformers import is_torch_available
23 | from transformers.testing_utils import get_tests_dir, require_torch
24 | from prime.types import AGENT_TYPE_MAPPING
25 |
26 | from prime.default_tools import FinalAnswerTool
27 |
28 | from .test_tools import ToolTesterMixin
29 |
30 |
31 | if is_torch_available():
32 | import torch
33 |
34 |
class FinalAnswerToolTester(unittest.TestCase, ToolTesterMixin):
    """Tests for FinalAnswerTool: echoing answers and wrapping them in agent types."""

    def setUp(self):
        self.inputs = {"answer": "Final answer"}
        self.tool = FinalAnswerTool()

    def test_exact_match_arg(self):
        result = self.tool("Final answer")
        self.assertEqual(result, "Final answer")

    def test_exact_match_kwarg(self):
        result = self.tool(answer=self.inputs["answer"])
        self.assertEqual(result, "Final answer")

    def create_inputs(self):
        """Build one sample ``answer`` per output type, keyed like AGENT_TYPE_MAPPING.

        The audio sample is a torch tensor, so only call this from tests
        decorated with ``@require_torch``.
        """
        inputs_text = {"answer": "Text input"}
        inputs_image = {
            "answer": Image.open(
                Path(get_tests_dir("fixtures")) / "000000039769.png"
            ).resize((512, 512))
        }
        inputs_audio = {"answer": torch.Tensor(np.ones(3000))}
        return {"string": inputs_text, "image": inputs_image, "audio": inputs_audio}

    @require_torch
    def test_agent_type_output(self):
        inputs = self.create_inputs()
        # `tool_kwargs` avoids shadowing the builtin `input`.
        for input_type, tool_kwargs in inputs.items():
            # subTest reports exactly which input type failed.
            with self.subTest(input_type=input_type):
                output = self.tool(**tool_kwargs, sanitize_inputs_outputs=True)
                self.assertIsInstance(output, AGENT_TYPE_MAPPING[input_type])
65 |
--------------------------------------------------------------------------------
/examples/text_to_sql.py:
--------------------------------------------------------------------------------
1 |
2 | from sqlalchemy import (
3 | create_engine,
4 | MetaData,
5 | Table,
6 | Column,
7 | String,
8 | Integer,
9 | Float,
10 | insert,
11 | inspect,
12 | text,
13 | )
14 |
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# create the "receipts" SQL table
table_name = "receipts"
receipts = Table(
    table_name,
    metadata_obj,
    Column("receipt_id", Integer, primary_key=True),
    Column("customer_name", String(16), primary_key=True),
    Column("price", Float),
    Column("tip", Float),
)
metadata_obj.create_all(engine)

rows = [
    {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
    {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
    {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
    {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
]
# Insert every row in one transaction; a list of parameter dicts runs as an executemany.
with engine.begin() as connection:
    connection.execute(insert(receipts), rows)

# Describe the table schema (this is the text the sql_engine tool docstring mirrors).
inspector = inspect(engine)
columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table_name)]

table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
print(table_description)
46 |
47 | from prime import tool
48 |
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    The table is named 'receipts'. Its description is as follows:
        Columns:
        - receipt_id: INTEGER
        - customer_name: VARCHAR(16)
        - price: FLOAT
        - tip: FLOAT

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE: the query text comes from the LLM and is executed as-is. That is
    # acceptable only because the database is a throwaway in-memory demo.
    with engine.connect() as con:
        result = con.execute(text(query))
        # Each row becomes "\n" + str(row); an empty result set yields "".
        # Joining replaces repeated string concatenation, and `result` avoids
        # shadowing the module-level `rows` list.
        return "".join("\n" + str(row) for row in result)
69 |
70 | from prime import CodeAgent, HfApiModel
71 |
# A small instruct model is enough for this single-table question.
model = HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct")
agent = CodeAgent(tools=[sql_engine], model=model)
agent.run("Can you give me the name of the client who got the most expensive receipt?")
--------------------------------------------------------------------------------
/examples/rag.py:
--------------------------------------------------------------------------------
1 | # from huggingface_hub import login
2 |
3 | # login()
4 | import datasets
5 | from langchain.docstore.document import Document
6 | from langchain.text_splitter import RecursiveCharacterTextSplitter
7 | from langchain_community.retrievers import BM25Retriever
8 |
knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
# Keep only the pages coming from the transformers documentation.
knowledge_base = knowledge_base.filter(
    lambda entry: entry["source"].startswith("huggingface/transformers")
)

# One LangChain Document per page; metadata keeps the second path component of
# the source (e.g. "transformers").
source_docs = [
    Document(
        page_content=page["text"],
        metadata={"source": page["source"].split("/")[1]},
    )
    for page in knowledge_base
]

# Chunk settings: ~500 characters per chunk with a 50-character overlap,
# splitting preferably on paragraphs, then lines, sentences and words.
splitter_settings = dict(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
text_splitter = RecursiveCharacterTextSplitter(**splitter_settings)
docs_processed = text_splitter.split_documents(source_docs)
25 |
# Import from the installed package (like the other examples). Importing through
# `src.prime` only works from the repo root and loads a second copy of the
# package, whose Tool class is distinct from the one prime's agents expect.
from prime import Tool


class RetrieverTool(Tool):
    """Agent tool returning the documentation chunks that best match a query."""

    name = "retriever"
    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): BM25 is a lexical (keyword-based) ranking, although the
        # description above advertises "semantic search". Returns the top 10 hits.
        self.retriever = BM25Retriever.from_documents(docs, k=10)

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"

        docs = self.retriever.invoke(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )
58 |
59 |
60 | from prime import HfApiModel, CodeAgent
61 |
retriever_tool = RetrieverTool(docs_processed)
agent = CodeAgent(
    tools=[retriever_tool],
    model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"),
    max_iterations=4,
    verbose=True,
)

question = "For a transformers model training, which is slower, the forward or the backward pass?"
agent_output = agent.run(question)

print("Final output:")
print(agent_output)
71 |
--------------------------------------------------------------------------------
/docs/source/en/conceptual_guides/react.md:
--------------------------------------------------------------------------------
1 |
16 | # How do multi-step agents work?
17 |
18 | The ReAct framework ([Yao et al., 2022](https://huggingface.co/papers/2210.03629)) is currently the main approach to building agents.
19 |
20 | The name is based on the concatenation of two words, "Reason" and "Act." Indeed, agents following this architecture will solve their task in as many steps as needed, each step consisting of a Reasoning step, then an Action step where it formulates tool calls that will bring it closer to solving the task at hand.
21 |
22 | The ReAct process involves keeping a memory of past steps.
23 |
24 | > [!TIP]
25 | > Read [Open-source LLMs as LangChain Agents](https://huggingface.co/blog/open-source-llms-as-agents) blog post to learn more about multi-step agents.
26 |
27 | Here is a video overview of how that works:
28 |
29 |
30 |
34 |
38 |
39 |
40 | 
41 |
42 | We implement two types of multi-step agents:
43 | - [`ToolCallingAgent`] generates its tool calls as JSON in its output.
44 | - [`CodeAgent`] is a newer type of multi-step agent that generates its tool calls as blobs of code, which works really well for LLMs that have strong coding performance.
45 |
46 | > [!TIP]
47 | > We also provide an option to run agents in one-shot: just pass `single_step=True` when launching the agent, like `agent.run(your_task, single_step=True)`
--------------------------------------------------------------------------------
/docs/source/en/reference/tools.md:
--------------------------------------------------------------------------------
1 |
16 | # Tools
17 |
18 |
19 |
20 | Prime is an experimental API which is subject to change at any time. Results returned by the agents
21 | can vary, as the APIs or underlying models are prone to change.
22 |
23 |
24 |
25 | To learn more about agents and tools make sure to read the [introductory guide](../index). This page
26 | contains the API docs for the underlying classes.
27 |
28 | ## Tools
29 |
30 | ### load_tool
31 |
32 | [[autodoc]] load_tool
33 |
34 | ### tool
35 |
36 | [[autodoc]] tool
37 |
38 | ### Tool
39 |
40 | [[autodoc]] Tool
41 |
42 | ### Toolbox
43 |
44 | [[autodoc]] Toolbox
45 |
46 | ### launch_gradio_demo
47 |
48 | [[autodoc]] launch_gradio_demo
49 |
50 |
51 | ### ToolCollection
52 |
53 | [[autodoc]] ToolCollection
54 |
55 | ## Agent Types
56 |
57 | Agents can handle any type of object in-between tools; tools, being completely multimodal, can accept and return
58 | text, image, audio, video, among other types. In order to increase compatibility between tools, as well as to
59 | correctly render these returns in ipython (jupyter, colab, ipython notebooks, ...), we implement wrapper classes
60 | around these types.
61 |
62 | The wrapped objects should continue behaving as initially; a text object should still behave as a string, an image
63 | object should still behave as a `PIL.Image`.
64 |
65 | These types have three specific purposes:
66 |
67 | - Calling `to_raw` on the type should return the underlying object
68 | - Calling `to_string` on the type should return the object as a string: that can be the string in case of an `AgentText`
69 | but will be the path of the serialized version of the object in other instances
70 | - Displaying it in an ipython kernel should display the object correctly
71 |
72 | ### AgentText
73 |
74 | [[autodoc]] prime.types.AgentText
75 |
76 | ### AgentImage
77 |
78 | [[autodoc]] prime.types.AgentImage
79 |
80 | ### AgentAudio
81 |
82 | [[autodoc]] prime.types.AgentAudio
83 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 | import shutil
4 | import tempfile
5 |
6 | from pathlib import Path
7 |
8 |
def str_to_bool(value) -> int:
    """
    Interpret a textual truth value, ignoring case.

    Returns 1 for `y`, `yes`, `t`, `true`, `on` and `1`; returns 0 for `n`, `no`, `f`, `false`, `off` and `0`.
    (An int rather than a bool is returned.)

    Raises:
        ValueError: if `value` matches none of the spellings above.
    """
    normalized = value.lower()
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")
    if normalized in truthy:
        return 1
    if normalized in falsy:
        return 0
    raise ValueError(f"invalid truth value {normalized}")
22 |
23 |
def get_int_from_env(env_keys, default):
    """Return the first non-negative integer value among the environment variables named in `env_keys`.

    Keys are checked in order. Unset variables and variables holding a negative integer are skipped (so `0` is
    accepted). If no key qualifies, `default` is returned.

    Raises:
        ValueError: if a set variable does not parse as an integer.
    """
    for key in env_keys:
        # An unset key falls back to -1, which the check below skips.
        val = int(os.environ.get(key, -1))
        if val >= 0:
            return val
    return default
31 |
32 |
def parse_flag_from_env(key, default=False):
    """Read the environment variable `key` as a boolean flag.

    The variable's text is interpreted with `str_to_bool`; when it is unset, `str(default)` is interpreted instead.
    """
    raw = os.environ.get(key, str(default))
    # `str_to_bool` yields 1/0 rather than a bool, hence the explicit comparison.
    flag = str_to_bool(raw) == 1
    return flag


# Evaluated once at import time: whether tests decorated with `@slow` should run.
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
42 |
43 |
def skip(test_case):
    """Unconditionally mark `test_case` (a test function or class) as skipped."""
    skipper = unittest.skip("Test was skipped")
    return skipper(test_case)
47 |
48 |
def slow(test_case):
    """
    Mark `test_case` as slow. Slow tests are skipped unless the RUN_SLOW environment variable held a truthy value
    when this module was imported.
    """
    decorator = unittest.skipUnless(_run_slow_tests, "test is slow")
    return decorator(test_case)
55 |
56 |
class TempDirTestCase(unittest.TestCase):
    """
    A TestCase class that keeps a single temporary directory (created with `tempfile.mkdtemp`) for the duration of
    the class, wipes its contents at the start of each test, and deletes it at the end of the TestCase.

    Useful when a class or API requires a single constant folder throughout its use, such as Weights and Biases.

    The temporary directory location is stored in `self.tmpdir`. Set `clear_on_setup = False` to keep its contents
    between tests.
    """

    clear_on_setup = True

    @classmethod
    def setUpClass(cls):
        "Create a temporary directory and store its path in `cls.tmpdir`"
        cls.tmpdir = Path(tempfile.mkdtemp())

    @classmethod
    def tearDownClass(cls):
        "Remove `cls.tmpdir` after test suite has finished"
        if os.path.exists(cls.tmpdir):
            shutil.rmtree(cls.tmpdir)

    def setUp(self):
        "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`"
        if self.clear_on_setup:
            # Only the direct children need handling, because rmtree removes
            # everything below a directory. The listing is materialised first,
            # so nothing is deleted while a directory is still being iterated.
            for path in list(self.tmpdir.iterdir()):
                if path.is_dir() and not path.is_symlink():
                    shutil.rmtree(path)
                else:
                    # Files and symlinks (including links to directories):
                    # remove the entry itself, never a link's target.
                    path.unlink()
88 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Welcome to the Prime Engine 1.0.0 Open Beta Release!
2 | ### Improved AI Engine speed
3 | ### Improved AI models structure
4 | ### Refactored old LLM models
5 | ### Improved AI agents
6 | ### Major and minor bug fixes.
7 |
8 | # Prime Engine - BNB Chain AI Agents Creating Framework
9 |
10 | Prime is a powerful, modular, and highly extensible framework designed for developers building applications that require a seamless blend of performance, flexibility, and scalability. With Prime, you can harness the power of modern tools to craft cutting-edge solutions across industries.
11 |
12 | ---
13 |
14 | ## Contract Address and Development Funding
15 | This project is powered by a single, official smart contract address. Please ensure that you interact only with the following address to avoid fraud or misrepresentation:
16 |
17 | ## Contract Address $PRIME:
18 | `soon`
19 |
20 | ## DEX: https://dexscreener.com/soon
21 | ## WEB: https://primengine.ai/
22 | ## X: https://x.com/primengineai
23 | ## TG: https://t.me/primengineai
24 |
25 | All development and maintenance of this project are funded exclusively through the creator's wallet associated with the token.
26 |
27 | ## Why PRIME?
28 | - Developer-Centric: Designed with developers in mind, PRIME simplifies complex processes and accelerates development.
29 | - Open-Source: Fully open-source, ensuring transparency and community-driven growth.
30 | - Adaptable: Built to accommodate a wide range of industries and applications, from startups to enterprises.
31 |
32 | ## Key Features
33 |
34 | ### 🚀 High Performance
35 | Prime is optimized for speed and efficiency, ensuring your applications run smoothly even under heavy workloads.
36 |
37 | ### ⚙️ Modularity
38 | With a modular design, Prime allows you to integrate only the components you need, keeping your applications lightweight and focused.
39 |
40 | ### 📈 Scalability
41 | Whether you're starting small or building enterprise-grade solutions, Prime adapts to your needs, growing with your application.
42 |
43 | ### 🔒 Security First
44 | Built with security in mind, Prime provides robust tools to safeguard your application and its users.
45 |
46 | ### 🧩 Extensibility
47 | Easily customize and expand Prime's functionality with plugins, libraries, and APIs tailored to your project's needs.
48 |
49 |
50 | ## Getting Started
51 |
52 | First, install the package from PyPI.
53 | ```bash
54 | pip install prime-agents-py
55 | ```
56 | Then define your agent, give it the tools it needs and run it!
57 | ```py
58 | from prime import CodeAgent, DuckDuckGoSearchTool, HfApiModel
59 |
60 | agent = CodeAgent(tools=[DuckDuckGoSearchTool()], model=HfApiModel())
61 |
62 | agent.run("Tell me BNB Chain and Bitcoin price in March 2025")
63 | ```
64 |
65 |
66 | ## Contributing
67 | We welcome contributions! If you want to report a bug, suggest a feature, or contribute code, please:
68 |
69 | 1. Fork the repository.
70 | 2. Create a new branch for your feature or bugfix.
71 | 3. Submit a pull request.
72 |
73 | ## How strong are open models for agentic workflows?
74 |
75 | We've created `CodeAgent` instances with some leading models, and compared them on [this benchmark](https://huggingface.co/datasets/m-ric/agents_medium_benchmark_2) that gathers questions from a few different benchmarks to propose a varied blend of challenges.
76 |
77 | ## Citing prime
78 |
79 | If you use `prime` in your publication, please cite it by using the following BibTeX entry.
80 |
81 | ```bibtex
82 | @Misc{prime,
83 | title = {`prime`: The easiest way to build efficient BNB Chain agentic systems.},
84 | author = {primengine inc.},
85 | howpublished = {\url{https://github.com/primengine/prime}},
86 | year = {2025}
87 | }
88 | ```
89 |
--------------------------------------------------------------------------------
/docs/source/en/index.md:
--------------------------------------------------------------------------------
1 |
15 |
16 | # `prime`
17 |
18 | This library is the simplest framework out there to build powerful agents! By the way, wtf are "agents"? We provide our definition [in this page](conceptual_guides/intro_agents), where you'll also find tips for when to use them or not (spoilers: you'll often be better off without agents).
19 |
20 | This library offers:
21 |
22 | ✨ **Simplicity**: the logic for agents fits in ~thousand lines of code. We kept abstractions to their minimal shape above raw code!
23 |
24 | 🌐 **Support for any LLM**: it supports models hosted on the Hub loaded in their `transformers` version or through our inference API, but also models from OpenAI, Anthropic... it's really easy to power an agent with any LLM.
25 |
26 | 🧑‍💻 **First-class support for Code Agents**, i.e. agents that write their actions in code (as opposed to "agents being used to write code"), [read more here](tutorials/secure_code_execution).
27 |
28 | 🤗 **Hub integrations**: you can share and load tools to/from the Hub, and more is to come!
29 |
30 |
50 |
--------------------------------------------------------------------------------
/tests/test_types.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | import tempfile
17 | import unittest
18 | import uuid
19 | from pathlib import Path
20 |
21 | from prime.types import AgentAudio, AgentImage, AgentText
22 | from transformers.testing_utils import (
23 | require_soundfile,
24 | require_torch,
25 | require_vision,
26 | )
27 | from transformers.utils import (
28 | is_soundfile_availble,
29 | )
30 |
31 | import torch
32 | from PIL import Image
33 |
34 |
35 | if is_soundfile_availble():
36 | import soundfile as sf
37 |
38 |
def get_new_path(suffix="") -> str:
    """Return a unique, not-yet-created file path inside a freshly made temporary directory.

    The file name is a random UUID followed by `suffix` (e.g. ".wav").
    """
    filename = f"{uuid.uuid4()}{suffix}"
    return os.path.join(tempfile.mkdtemp(), filename)
42 |
43 |
@require_soundfile
@require_torch
class AgentAudioTests(unittest.TestCase):
    """Round-trip checks for `AgentAudio` built from a tensor or from a file path."""

    def test_from_tensor(self):
        original = torch.rand(12, dtype=torch.float64) - 0.5
        audio = AgentAudio(original)
        saved_path = str(audio.to_string())

        # The raw value exposed by the wrapper matches the tensor it was built from.
        self.assertTrue(torch.allclose(original, audio.to_raw(), atol=1e-4))

        # Dropping the wrapper must not delete the file it serialized to.
        del audio
        self.assertTrue(os.path.exists(saved_path))

        # Reading the file back yields the original samples.
        read_back, _ = sf.read(saved_path)
        self.assertTrue(torch.allclose(original, torch.tensor(read_back), atol=1e-4))

    def test_from_string(self):
        original = torch.rand(12, dtype=torch.float64) - 0.5
        wav_path = get_new_path(suffix=".wav")
        sf.write(wav_path, original, 16000)

        audio = AgentAudio(wav_path)

        # Built from a path, the wrapper loads the same samples and reports that same path.
        self.assertTrue(torch.allclose(original, audio.to_raw(), atol=1e-4))
        self.assertEqual(audio.to_string(), wav_path)
73 |
74 |
@require_vision
@require_torch
class AgentImageTests(unittest.TestCase):
    """Checks that `AgentImage` wraps tensors, paths and PIL images and keeps the backing file on disk."""

    def test_from_tensor(self):
        pixels = torch.randint(0, 256, (64, 64, 3))
        image_type = AgentImage(pixels)
        saved_path = str(image_type.to_string())

        # The stored tensor is the one that was passed in.
        self.assertTrue(torch.allclose(pixels, image_type._tensor, atol=1e-4))
        # The raw form of an AgentImage is a PIL image.
        self.assertIsInstance(image_type.to_raw(), Image.Image)

        # The serialized file outlives the wrapper object.
        del image_type
        self.assertTrue(os.path.exists(saved_path))

    def test_from_string(self):
        fixture = Path("tests/fixtures/000000039769.png")
        expected = Image.open(fixture)
        image_type = AgentImage(fixture)

        # Built from a path, the wrapper points at that very file and loads the same pixels.
        self.assertTrue(fixture.samefile(image_type.to_string()))
        self.assertTrue(expected == image_type.to_raw())

        del image_type
        self.assertTrue(os.path.exists(fixture))

    def test_from_image(self):
        fixture = Path("tests/fixtures/000000039769.png")
        expected = Image.open(fixture)
        image_type = AgentImage(expected)

        # Built from an in-memory image, the wrapper reports a different file than the fixture.
        self.assertFalse(fixture.samefile(image_type.to_string()))
        self.assertTrue(expected == image_type.to_raw())

        del image_type
        self.assertTrue(os.path.exists(fixture))
115 |
116 |
class AgentTextTests(unittest.TestCase):
    """`AgentText` behaves like the plain string it wraps."""

    def test_from_string(self):
        text = "Hey!"
        wrapped = AgentText(text)

        # The string form, the raw form and the wrapper itself all compare equal to the input.
        self.assertEqual(text, wrapped.to_string())
        self.assertEqual(text, wrapped.to_raw())
        self.assertEqual(text, wrapped)
125 |
--------------------------------------------------------------------------------
/src/prime/gradio_ui.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | from .types import AgentAudio, AgentImage, AgentText, handle_agent_output_types
18 | from .agents import MultiStepAgent, AgentStep, ActionStep
19 | import gradio as gr
20 |
21 |
def pull_messages_from_step(step_log: AgentStep, test_mode: bool = True):
    """Yield the gradio ChatMessages describing a single agent step.

    Only `ActionStep` logs produce messages. In order, they are:
    - the raw LLM output;
    - the tool call, if any;
    - the observations, if any;
    - the error, if any.

    Any other step type yields nothing. `test_mode` is accepted for API compatibility but is unused.
    """
    if not isinstance(step_log, ActionStep):
        return

    yield gr.ChatMessage(role="assistant", content=step_log.llm_output)

    if step_log.tool_call is not None:
        is_code = step_log.tool_call.name == "code interpreter"
        arguments = step_log.tool_call.arguments
        # Code-interpreter arguments are source code, so render them as a Python code block.
        shown = f"```py\n{arguments}\n```" if is_code else arguments
        yield gr.ChatMessage(
            role="assistant",
            metadata={"title": f"🛠️ Used tool {step_log.tool_call.name}"},
            content=str(shown),
        )

    if step_log.observations is not None:
        yield gr.ChatMessage(role="assistant", content=f"```\n{step_log.observations}\n```")

    if step_log.error is not None:
        yield gr.ChatMessage(
            role="assistant",
            content=str(step_log.error),
            metadata={"title": "💥 Error"},
        )
46 |
47 |
def stream_to_gradio(
    agent,
    task: str,
    test_mode: bool = False,
    reset_agent_memory: bool = False,
    **kwargs,
):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages.

    Args:
        agent: the agent to run; `agent.run(task, stream=True, reset=..., **kwargs)` must yield step logs.
        task: the task passed to `agent.run`.
        test_mode: forwarded to `pull_messages_from_step`.
        reset_agent_memory: forwarded to `agent.run` as `reset`.
        **kwargs: extra keyword arguments forwarded to `agent.run`.

    Yields:
        `gr.ChatMessage` objects for every step, then one message holding the final answer.
        If the run yields no step at all, nothing is yielded.
    """
    no_step = object()  # sentinel: a run may legitimately end with a `None` final answer
    last_step = no_step
    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, **kwargs):
        yield from pull_messages_from_step(step_log, test_mode=test_mode)
        last_step = step_log

    if last_step is no_step:
        # Nothing was streamed. Bail out instead of touching an unbound loop variable.
        return

    final_answer = last_step  # Last log is the run's final_answer
    final_answer = handle_agent_output_types(final_answer)

    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n```\n{final_answer.to_string()}\n```",
        )
    elif isinstance(final_answer, AgentImage):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield gr.ChatMessage(role="assistant", content=str(final_answer))
81 |
82 |
class GradioUI:
    """A one-line interface to launch your agent in Gradio"""

    def __init__(self, agent: MultiStepAgent):
        self.agent = agent

    def interact_with_agent(self, prompt, messages):
        """Append the user's prompt to the chat history, then stream the agent's messages into it.

        Yields the updated `messages` list after every addition so the Chatbot re-renders incrementally.
        The agent's memory is kept between prompts (`reset_agent_memory=False`).
        """
        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages
        for msg in stream_to_gradio(self.agent, task=prompt, reset_agent_memory=False):
            messages.append(msg)
            yield messages

    def launch(self, **kwargs):
        """Build the chat UI and start the Gradio server.

        Args:
            **kwargs: forwarded to `demo.launch` (e.g. `share=True`, `server_port=7860`).
        """
        with gr.Blocks() as demo:
            # Holds the submitted prompt so the textbox can be cleared before the agent starts running.
            stored_message = gr.State([])
            chatbot = gr.Chatbot(
                label="Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://em-content.zobj.net/source/twitter/53/robot-face_1f916.png",
                ),
            )
            text_input = gr.Textbox(lines=1, label="Chat Message")
            text_input.submit(
                lambda s: (s, ""), [text_input], [stored_message, text_input]
            ).then(self.interact_with_agent, [stored_message, chatbot], [chatbot])

        demo.launch(**kwargs)
114 |
115 |
116 | __all__ = ["stream_to_gradio", "GradioUI"]
117 |
--------------------------------------------------------------------------------
/docs/source/en/tutorials/secure_code_execution.md:
--------------------------------------------------------------------------------
1 |
16 | # Secure code execution
17 |
18 | [[open-in-colab]]
19 |
20 | > [!TIP]
21 | > If you're new to building agents, make sure to first read the [intro to agents](../conceptual_guides/intro_agents) and the [guided tour of prime](../guided_tour).
22 |
23 | ### Code agents
24 |
25 | [Multiple](https://huggingface.co/papers/2402.01030) [research](https://huggingface.co/papers/2411.01747) [papers](https://huggingface.co/papers/2401.00812) have shown that having the LLM write its actions (the tool calls) in code is much better than the current standard format for tool calling, which, across the industry, comes in different shades of "writing actions as a JSON of tool names and arguments to use".
26 |
27 | Why is code better? Well, because we crafted our code languages specifically to be great at expressing actions performed by a computer. If JSON snippets were a better way, this package would have been written in JSON snippets and the devil would be laughing at us.
28 |
29 | Code is just a better way to express actions on a computer. It has better:
30 | - **Composability:** could you nest JSON actions within each other, or define a set of JSON actions to re-use later, the same way you could just define a python function?
31 | - **Object management:** how do you store the output of an action like `generate_image` in JSON?
32 | - **Generality:** code is built to express simply anything you can have a computer do.
33 | - **Representation in LLM training corpuses:** why not leverage this benediction of the sky that plenty of quality actions have already been included in LLM training corpuses?
34 |
35 | This is illustrated on the figure below, taken from [Executable Code Actions Elicit Better LLM Agents](https://huggingface.co/papers/2402.01030).
36 |
37 |
38 |
39 | This is why we put emphasis on proposing code agents, in this case python agents, which meant putting higher effort on building secure python interpreters.
40 |
41 | ### Local python interpreter
42 |
43 | By default, the `CodeAgent` runs LLM-generated code in your environment.
44 | This execution is not done by the vanilla Python interpreter: we've re-built a more secure `LocalPythonInterpreter` from the ground up.
45 | This interpreter is designed for security by:
46 | - Restricting the imports to a list explicitly passed by the user
47 | - Capping the number of operations to prevent infinite loops and resource bloating.
48 | - Will not perform any operation that's not pre-defined.
49 |
50 | We've used this in many use cases, without ever observing any damage to the environment.
51 |
52 | However this solution is not watertight: one could imagine occasions where LLMs fine-tuned for malignant actions could still hurt your environment. For instance if you've allowed an innocuous package like `Pillow` to process images, the LLM could generate thousands of saves of images to bloat your hard drive.
53 | It's certainly not likely if you've chosen the LLM engine yourself, but it could happen.
54 |
55 | So if you want to be extra cautious, you can use the remote code execution option described below.
56 |
57 | ### E2B code executor
58 |
59 | For maximum security, you can use our integration with E2B to run code in a sandboxed environment. This is a remote execution service that runs your code in an isolated container, making it impossible for the code to affect your local environment.
60 |
61 | For this, you will need to set up your E2B account and set your `E2B_API_KEY` in your environment variables. Head to [E2B's quickstart documentation](https://e2b.dev/docs/quickstart) for more information.
62 |
63 | Then you can install it with `pip install e2b-code-interpreter python-dotenv`.
64 |
65 | Now you're set!
66 |
67 | To set the code executor to E2B, simply pass the flag `use_e2b_executor=True` when initializing your `CodeAgent`.
68 | Note that you should add all the tool's dependencies in `additional_authorized_imports`, so that the executor installs them.
69 |
70 | ```py
71 | from prime import CodeAgent, VisitWebpageTool, HfApiModel
72 | agent = CodeAgent(
73 | tools = [VisitWebpageTool()],
74 | model=HfApiModel(),
75 | additional_authorized_imports=["requests", "markdownify"],
76 | use_e2b_executor=True
77 | )
78 |
79 | agent.run("What was Abraham Lincoln's preferred pet?")
80 | ```
81 |
82 | E2B code execution is not compatible with multi-agents at the moment - because having an agent call in a code blob that should be executed remotely is a mess. But we're working on adding it!
--------------------------------------------------------------------------------
/docs/source/en/reference/agents.md:
--------------------------------------------------------------------------------
1 |
16 | # Agents
17 |
18 |
19 |
20 | `prime` is an experimental API that is subject to change at any time. Results returned by the agents
21 | can vary as the APIs or underlying models are prone to change.
22 |
23 |
24 |
25 | To learn more about agents and tools make sure to read the [introductory guide](../index). This page
26 | contains the API docs for the underlying classes.
27 |
28 | ## Agents
29 |
30 | Our agents inherit from [`MultiStepAgent`], which means they can act in multiple steps, each step consisting of one thought, then one tool call and execution. Read more in [this conceptual guide](../conceptual_guides/react).
31 |
32 | We provide two types of agents, based on the main [`MultiStepAgent`] class.
33 | - [`CodeAgent`] is the default agent, it writes its tool calls in Python code.
34 | - [`ToolCallingAgent`] writes its tool calls in JSON.
35 |
36 | Both require arguments `model` and list of tools `tools` at initialization.
37 |
38 |
39 | ### Classes of agents
40 |
41 | [[autodoc]] MultiStepAgent
42 |
43 | [[autodoc]] CodeAgent
44 |
45 | [[autodoc]] ToolCallingAgent
46 |
47 |
48 | ### ManagedAgent
49 |
50 | [[autodoc]] ManagedAgent
51 |
52 | ### stream_to_gradio
53 |
54 | [[autodoc]] stream_to_gradio
55 |
56 | ### GradioUI
57 |
58 | [[autodoc]] GradioUI
59 |
60 | ## Models
61 |
62 | You're free to create and use your own models to power your agent.
63 |
64 | You could use any `model` callable for your agent, as long as:
65 | 1. It follows the [messages format](./chat_templating) (`List[Dict[str, str]]`) for its input `messages`, and it returns a `str`.
66 | 2. It stops generating outputs *before* the sequences passed in the argument `stop_sequences`
67 |
68 | For defining your LLM, you can make a `custom_model` method which accepts a list of [messages](./chat_templating) and returns text. This callable also needs to accept a `stop_sequences` argument that indicates when to stop generating.
69 |
70 | ```python
71 | from huggingface_hub import login, InferenceClient
72 |
73 | login("")
74 |
75 | model_id = "meta-llama/Llama-3.3-70B-Instruct"
76 |
77 | client = InferenceClient(model=model_id)
78 |
79 | def custom_model(messages, stop_sequences=["Task"]) -> str:
80 | response = client.chat_completion(messages, stop=stop_sequences, max_tokens=1000)
81 | answer = response.choices[0].message.content
82 | return answer
83 | ```
84 |
85 | Additionally, `custom_model` can also take a `grammar` argument. In the case where you specify a `grammar` upon agent initialization, this argument will be passed to the calls to model, with the `grammar` that you defined upon initialization, to allow [constrained generation](https://huggingface.co/docs/text-generation-inference/conceptual/guidance) in order to force properly-formatted agent outputs.
86 |
87 | ### TransformersModel
88 |
89 | For convenience, we have added a `TransformersModel` that implements the points above by building a local `transformers` pipeline for the model_id given at initialization.
90 |
91 | ```python
92 | from prime import TransformersModel
93 |
94 | model = TransformersModel(model_id="HuggingFaceTB/SmolLM-135M-Instruct")
95 |
96 | print(model([{"role": "user", "content": "Ok!"}], stop_sequences=["great"]))
97 | ```
98 | ```text
99 | >>> What a
100 | ```
101 |
102 | [[autodoc]] TransformersModel
103 |
104 | ### HfApiModel
105 |
106 | The `HfApiModel` wraps an [HF Inference API](https://huggingface.co/docs/api-inference/index) client for the execution of the LLM.
107 |
108 | ```python
109 | from prime import HfApiModel
110 |
111 | messages = [
112 | {"role": "user", "content": "Hello, how are you?"},
113 | {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
114 | {"role": "user", "content": "No need to help, take it easy."},
115 | ]
116 |
117 | model = HfApiModel()
118 | print(model(messages))
119 | ```
120 | ```text
121 | >>> Of course! If you change your mind, feel free to reach out. Take care!
122 | ```
123 | [[autodoc]] HfApiModel
124 |
125 | ### LiteLLMModel
126 |
127 | The `LiteLLMModel` leverages [LiteLLM](https://www.litellm.ai/) to support 100+ LLMs from various providers.
128 |
129 | ```python
130 | from prime import LiteLLMModel
131 |
132 | messages = [
133 | {"role": "user", "content": "Hello, how are you?"},
134 | {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
135 | {"role": "user", "content": "No need to help, take it easy."},
136 | ]
137 |
138 | model = LiteLLMModel("anthropic/claude-3-5-sonnet-latest")
139 | print(model(messages))
140 | ```
141 |
142 | [[autodoc]] LiteLLMModel
--------------------------------------------------------------------------------
/src/prime/e2b_executor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | from dotenv import load_dotenv
18 | import textwrap
19 | import base64
20 | import pickle
21 | from io import BytesIO
22 | from PIL import Image
23 |
24 | from e2b_code_interpreter import Sandbox
25 | from typing import List, Tuple, Any
26 | from .tool_validation import validate_tool_attributes
27 | from .utils import instance_to_source, BASE_BUILTIN_MODULES, console
28 | from .tools import Tool
29 |
30 | load_dotenv()
31 |
32 |
class E2BExecutor:
    """Run agent-generated Python code remotely inside an E2B sandbox.

    On construction, the sandbox pip-installs `additional_imports` (plus `pickle5`). It then defines every tool
    in `tools`, so that code later sent through `__call__` can call them by name.
    """

    def __init__(self, additional_imports: List[str], tools: List[Tool]):
        self.custom_tools = {}
        self.sbx = Sandbox()  # NOTE(review): a template id ("qywp2ctmu2q7jzprcf4j") was once passed here
        # TODO: decide whether the agents package itself should be pip-installed on the remote executor.

        # `pickle5` is always appended, so there is always something to install.
        additional_imports = additional_imports + ["pickle5"]
        execution = self.sbx.commands.run("pip install " + " ".join(additional_imports))
        if execution.error:
            raise Exception(f"Error installing dependencies: {execution.error}")
        console.print(f"Installation of {additional_imports} succeeded!")

        tool_codes = []
        for tool in tools:
            validate_tool_attributes(tool.__class__, check_imports=False)
            tool_code = instance_to_source(tool, base_cls=Tool)
            # The sandbox receives the minimal `Tool` stub defined below, so drop the package import.
            tool_code = tool_code.replace("from prime.tools import Tool", "")
            # Instantiate the tool under its public name so remote code can call it directly.
            tool_code += f"\n{tool.name} = {tool.__class__.__name__}()\n"
            tool_codes.append(tool_code)

        tool_definition_code = "\n".join(f"import {module}" for module in BASE_BUILTIN_MODULES)
        tool_definition_code += textwrap.dedent(
            """
        class Tool:
            def __call__(self, *args, **kwargs):
                return self.forward(*args, **kwargs)

            def forward(self, *args, **kwargs):
                pass # to be implemented in child class
        """
        )
        tool_definition_code += "\n\n".join(tool_codes)

        tool_definition_execution = self.run_code_raise_errors(tool_definition_code)
        console.print(tool_definition_execution.logs)

    def run_code_raise_errors(self, code: str):
        """Execute `code` in the sandbox and return the execution object.

        Raises:
            ValueError: if the remote execution reported an error. The message contains the captured stdout,
                then the remote error's name, value and traceback, each on its own line.
        """
        execution = self.sbx.run_code(code)
        if execution.error:
            stdout = "\n".join(str(log) for log in execution.logs.stdout)
            error = execution.error
            # Separate the parts with newlines; concatenating them directly produced one unreadable line.
            raise ValueError(
                f"{stdout}\nExecuting code yielded an error:\n"
                f"{error.name}\n{error.value}\n{error.traceback}"
            )
        return execution

    def __call__(self, code_action: str, additional_args: dict) -> Tuple[Any, Any]:
        """Run `code_action` remotely and return `(result, stdout_logs)`.

        If `additional_args` is non-empty, it is pickled, uploaded to `/home/state.pkl` and loaded into the
        remote namespace first. `result` is one of:
        - `None` if the code produced no results;
        - a PIL image for a JPEG/PNG main result;
        - otherwise the first populated representation of the main result.

        Raises:
            ValueError: if the remote code errors, or if results exist but none is the main result.
        """
        if additional_args:
            # Upload the pickled state from memory. Re-opening a NamedTemporaryFile by name (the previous
            # approach) is not portable: it fails on Windows while the file is still open.
            self.sbx.files.write("/home/state.pkl", BytesIO(pickle.dumps(additional_args)))
            # NOTE(review): `locals().update` only takes effect because this runs at the kernel's top level,
            # where locals() is the global namespace — confirm if the execution model changes.
            remote_unloading_code = """import pickle
import os
print("File size", os.path.getsize('/home/state.pkl'))
with open('/home/state.pkl', 'rb') as f:
    pickle_dict = pickle.load(f)
locals().update({key: value for key, value in pickle_dict.items()})
"""
            execution = self.run_code_raise_errors(remote_unloading_code)
            console.print("\n".join(str(log) for log in execution.logs.stdout))

        execution = self.run_code_raise_errors(code_action)
        execution_logs = "\n".join(str(log) for log in execution.logs.stdout)
        if not execution.results:
            return None, execution_logs

        for result in execution.results:
            if not result.is_main_result:
                continue
            # Images come back base64-encoded; decode them into PIL images.
            for attribute_name in ("jpeg", "png"):
                image_output = getattr(result, attribute_name)
                if image_output is not None:
                    decoded_bytes = base64.b64decode(image_output.encode("utf-8"))
                    return Image.open(BytesIO(decoded_bytes)), execution_logs
            # Otherwise return the first populated representation, in this priority order.
            for attribute_name in (
                "chart",
                "data",
                "html",
                "javascript",
                "json",
                "latex",
                "markdown",
                "pdf",
                "svg",
                "text",
            ):
                value = getattr(result, attribute_name)
                if value is not None:
                    return value, execution_logs
        raise ValueError("No main result returned by executor!")
142 |
143 |
144 | __all__ = ["E2BExecutor"]
145 |
--------------------------------------------------------------------------------
/tests/test_all_docs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import ast
17 | import os
18 | import re
19 | import shutil
20 | import tempfile
21 | import subprocess
22 | import traceback
23 | import pytest
24 | from pathlib import Path
25 | from typing import List
26 | from dotenv import load_dotenv
27 |
28 |
class SubprocessCallException(Exception):
    """Raised by `run_command` when the launched command exits with a non-zero status."""
31 |
32 |
def run_command(command: List[str], return_stdout=False, env=None):
    """
    Runs command with subprocess.check_output and returns stdout if requested.

    Args:
        command: program and arguments. `pathlib.Path` items are converted to `str`; the caller's list is
            left untouched.
        return_stdout: when True, return the combined stdout/stderr decoded as UTF-8.
        env: environment for the child process; defaults to a copy of `os.environ`.

    Returns:
        The decoded output when `return_stdout` is True, otherwise None.

    Raises:
        SubprocessCallException: if the command exits with a non-zero status. The message includes its output.
    """
    # Build a new list rather than rewriting the caller's list in place.
    command = [str(part) if isinstance(part, Path) else part for part in command]

    if env is None:
        env = os.environ.copy()

    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env)
    except subprocess.CalledProcessError as e:
        # Decode leniently so non-UTF-8 output cannot hide the real failure behind a UnicodeDecodeError.
        details = e.output.decode("utf-8", errors="replace") if e.output else ""
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{details}"
        ) from e

    if return_stdout:
        return output.decode("utf-8")
    return None
55 |
56 |
class DocCodeExtractor:
    """Pulls Python snippets out of markdown and bundles them into a single runnable script."""

    @staticmethod
    def extract_python_code(content: str) -> List[str]:
        """Return the stripped bodies of every ```python / ```py fenced block in `content`, in document order."""
        # DOTALL lets a body span lines; the lazy `.*?` stops at the first closing fence.
        bodies = re.findall(r"```(?:python|py)\n(.*?)\n```", content, flags=re.DOTALL)
        return [body.strip() for body in bodies]

    @staticmethod
    def create_test_script(code_blocks: List[str], tmp_dir: str) -> Path:
        """Write `code_blocks`, separated by blank lines, to `<tmp_dir>/test_script.py` and return that path."""
        script = "\n\n".join(code_blocks)
        assert len(script) > 0, "Code is empty!"
        target = Path(tmp_dir) / "test_script.py"
        target.write_text(script, encoding="utf-8")
        return target
78 |
79 |
class TestDocs:
    """Test case for documentation code testing."""

    # Placeholder the docs use where a Hugging Face token is expected; replaced with `HF_TOKEN` when it is set.
    HF_TOKEN_PLACEHOLDER = "<YOUR_HUGGINGFACEHUB_API_TOKEN>"

    @classmethod
    def setup_class(cls):
        """Create the scratch directory, locate the markdown docs and read `HF_TOKEN` (also from `.env`)."""
        cls._tmpdir = tempfile.mkdtemp()
        cls.launch_args = ["python3"]
        cls.docs_dir = Path(__file__).parent.parent / "docs" / "source"
        cls.extractor = DocCodeExtractor()

        if not cls.docs_dir.exists():
            raise ValueError(f"Docs directory not found at {cls.docs_dir}")

        load_dotenv()
        cls.hf_token = os.getenv("HF_TOKEN")

        cls.md_files = list(cls.docs_dir.rglob("*.md"))
        if not cls.md_files:
            raise ValueError(f"No markdown files found in {cls.docs_dir}")

    @classmethod
    def teardown_class(cls):
        """Delete the scratch directory."""
        shutil.rmtree(cls._tmpdir)

    def _fill_placeholders(self, block: str) -> str:
        """Substitute the token and username placeholders used by doc snippets."""
        # Use an explicit, non-empty placeholder: `str.replace("", token)` would insert the token between
        # every character and corrupt the snippet. Skip the token when unset, since `replace(x, None)` raises.
        if self.hf_token:
            block = block.replace(self.HF_TOKEN_PLACEHOLDER, self.hf_token)
        return block.replace("{your_username}", "m-ric")

    @pytest.mark.timeout(100)
    def test_single_doc(self, doc_path: Path):
        """Test a single documentation file: syntax-check its Python snippets, then run them as one script."""
        with open(doc_path, "r", encoding="utf-8") as f:
            content = f.read()

        code_blocks = self.extractor.extract_python_code(content)
        # Exclude snippets that take longer to run or need extra dependencies.
        excluded_snippets = [
            "ToolCollection",
            "image_generation_tool",
            "from_langchain",
            "while llm_should_continue(memory):",
        ]
        code_blocks = [
            block
            for block in code_blocks
            if not any(snippet in block for snippet in excluded_snippets)
        ]
        if not code_blocks:
            pytest.skip(f"No Python code blocks found in {doc_path.name}")

        # Validate the syntax of each block individually so the failure names the offending block.
        for i, block in enumerate(code_blocks, 1):
            try:
                ast.parse(block)
            except SyntaxError as e:
                pytest.fail(f"\nSyntax error in code block {i} of {doc_path.name}:\n{e}")

        # Create and execute test script
        try:
            code_blocks = [self._fill_placeholders(block) for block in code_blocks]
            test_script = self.extractor.create_test_script(code_blocks, self._tmpdir)
            run_command(self.launch_args + [str(test_script)])

        except SubprocessCallException as e:
            pytest.fail(f"\nError while testing {doc_path.name}:\n{str(e)}")
        except Exception:
            pytest.fail(
                f"\nUnexpected error while testing {doc_path.name}:\n{traceback.format_exc()}"
            )

    @pytest.fixture(autouse=True)
    def _setup(self):
        """Fixture to ensure temporary directory exists for each test."""
        os.makedirs(self._tmpdir, exist_ok=True)
        yield
        # Clean up test files after each test; subdirectories need `rmtree`, which `unlink` cannot replace.
        for entry in Path(self._tmpdir).glob("*"):
            if entry.is_dir() and not entry.is_symlink():
                shutil.rmtree(entry)
            else:
                entry.unlink()
157 |
158 |
def pytest_generate_tests(metafunc):
    """Parametrize ``doc_path`` with every markdown file known to the test class.

    pytest calls this hook for each collected test function. Functions that do
    not request ``doc_path`` are left untouched.
    """
    if "doc_path" not in metafunc.fixturenames:
        return

    owner = metafunc.cls
    # Collection happens before pytest runs setup_class, so populate the
    # markdown file list here if it is not there yet.
    if not hasattr(owner, "md_files"):
        owner.setup_class()

    markdown_paths = owner.md_files
    metafunc.parametrize(
        "doc_path", markdown_paths, ids=[path.stem for path in markdown_paths]
    )
172 |
--------------------------------------------------------------------------------
/tests/test_monitoring.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import unittest
17 |
18 | from prime import (
19 | AgentImage,
20 | AgentError,
21 | CodeAgent,
22 | ToolCallingAgent,
23 | stream_to_gradio,
24 | )
25 |
26 |
class MonitoringTester(unittest.TestCase):
    """Checks agent token monitoring and the messages yielded by ``stream_to_gradio``.

    Every agent is built with no tools and ``max_iterations=1``. The models are
    small fakes with fixed replies and token counts, so expected totals are known.
    """

    # Model reply whose ``py`` code block calls final_answer with a fixed text.
    FINAL_ANSWER_REPLY = "\nCode:\n```py\nfinal_answer('This is the final answer.')\n```"

    @staticmethod
    def _make_agent(agent_class, model):
        """Instantiate ``agent_class`` with no tools, ``model`` and one iteration."""
        return agent_class(tools=[], model=model, max_iterations=1)

    def test_code_agent_metrics(self):
        reply = self.FINAL_ANSWER_REPLY

        class FakeLLMModel:
            def __init__(self):
                self.last_input_token_count = 10
                self.last_output_token_count = 20

            def __call__(self, prompt, **kwargs):
                return reply

        agent = self._make_agent(CodeAgent, FakeLLMModel())
        agent.run("Fake task")

        monitor = agent.monitor
        self.assertEqual(monitor.total_input_token_count, 10)
        self.assertEqual(monitor.total_output_token_count, 20)

    def test_json_agent_metrics(self):
        class FakeLLMModel:
            def __init__(self):
                self.last_input_token_count = 10
                self.last_output_token_count = 20

            def get_tool_call(self, prompt, **kwargs):
                return "final_answer", {"answer": "image"}, "fake_id"

        agent = self._make_agent(ToolCallingAgent, FakeLLMModel())
        agent.run("Fake task")

        monitor = agent.monitor
        self.assertEqual(monitor.total_input_token_count, 10)
        self.assertEqual(monitor.total_output_token_count, 20)

    def test_code_agent_metrics_max_iterations(self):
        class FakeLLMModel:
            def __init__(self):
                self.last_input_token_count = 10
                self.last_output_token_count = 20

            def __call__(self, prompt, **kwargs):
                return "Malformed answer"

        agent = self._make_agent(CodeAgent, FakeLLMModel())
        agent.run("Fake task")

        # Expected totals are twice the per-call counts (two monitored calls).
        monitor = agent.monitor
        self.assertEqual(monitor.total_input_token_count, 20)
        self.assertEqual(monitor.total_output_token_count, 40)

    def test_code_agent_metrics_generation_error(self):
        class FakeLLMModel:
            def __init__(self):
                self.last_input_token_count = 10
                self.last_output_token_count = 20

            def __call__(self, prompt, **kwargs):
                self.last_input_token_count = 10
                self.last_output_token_count = 0
                raise Exception("Cannot generate")

        agent = self._make_agent(CodeAgent, FakeLLMModel())
        agent.run("Fake task")

        # Should have done two monitoring callbacks
        monitor = agent.monitor
        self.assertEqual(monitor.total_input_token_count, 20)
        self.assertEqual(monitor.total_output_token_count, 0)

    def test_streaming_agent_text_output(self):
        reply = self.FINAL_ANSWER_REPLY

        def dummy_model(prompt, **kwargs):
            return reply

        agent = self._make_agent(CodeAgent, dummy_model)

        # Collect every message that stream_to_gradio yields.
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

        self.assertEqual(len(outputs), 4)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("This is the final answer.", final_message.content)

    def test_streaming_agent_image_output(self):
        class FakeLLM:
            def get_tool_call(self, messages, **kwargs):
                return "final_answer", {"answer": "image"}, "fake_id"

        agent = self._make_agent(ToolCallingAgent, FakeLLM())

        # Collect every message that stream_to_gradio yields.
        outputs = list(
            stream_to_gradio(
                agent,
                task="Test task",
                additional_args={"image": AgentImage(value="path.png")},
                test_mode=True,
            )
        )

        self.assertEqual(len(outputs), 3)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIsInstance(final_message.content, dict)
        self.assertEqual(final_message.content["path"], "path.png")
        self.assertEqual(final_message.content["mime_type"], "image/png")

    def test_streaming_with_agent_error(self):
        def dummy_model(prompt, **kwargs):
            raise AgentError("Simulated agent error")

        agent = self._make_agent(CodeAgent, dummy_model)

        # Collect every message that stream_to_gradio yields.
        outputs = list(stream_to_gradio(agent, task="Test task", test_mode=True))

        self.assertEqual(len(outputs), 5)
        final_message = outputs[-1]
        self.assertEqual(final_message.role, "assistant")
        self.assertIn("Simulated agent error", final_message.content)
185 |
--------------------------------------------------------------------------------
/docs/source/en/examples/text_to_sql.md:
--------------------------------------------------------------------------------
1 |
16 | # Text-to-SQL
17 |
18 | [[open-in-colab]]
19 |
20 | In this tutorial, we’ll see how to implement an agent that leverages SQL using `prime`.
21 |
22 | > Let's start with the golden question: why not keep it simple and use a standard text-to-SQL pipeline?
23 |
24 | A standard text-to-SQL pipeline is brittle, since the generated SQL query can be incorrect. Even worse, an incorrect query may not raise any error, silently returning incorrect or useless outputs instead.
25 |
26 | 👉 Instead, an agent system is able to critically inspect outputs and decide if the query needs to be changed or not, thus giving it a huge performance boost.
27 |
28 | Let’s build this agent! 💪
29 |
30 | First, we set up the SQL environment:
31 | ```py
32 | from sqlalchemy import (
33 | create_engine,
34 | MetaData,
35 | Table,
36 | Column,
37 | String,
38 | Integer,
39 | Float,
40 | insert,
41 | inspect,
42 | text,
43 | )
44 |
45 | engine = create_engine("sqlite:///:memory:")
46 | metadata_obj = MetaData()
47 |
48 | # create city SQL table
49 | table_name = "receipts"
50 | receipts = Table(
51 | table_name,
52 | metadata_obj,
53 | Column("receipt_id", Integer, primary_key=True),
54 | Column("customer_name", String(16), primary_key=True),
55 | Column("price", Float),
56 | Column("tip", Float),
57 | )
58 | metadata_obj.create_all(engine)
59 |
60 | rows = [
61 | {"receipt_id": 1, "customer_name": "Alan Payne", "price": 12.06, "tip": 1.20},
62 | {"receipt_id": 2, "customer_name": "Alex Mason", "price": 23.86, "tip": 0.24},
63 | {"receipt_id": 3, "customer_name": "Woodrow Wilson", "price": 53.43, "tip": 5.43},
64 | {"receipt_id": 4, "customer_name": "Margaret James", "price": 21.11, "tip": 1.00},
65 | ]
66 | for row in rows:
67 | stmt = insert(receipts).values(**row)
68 | with engine.begin() as connection:
69 | cursor = connection.execute(stmt)
70 | ```
71 |
72 | ### Build our agent
73 |
74 | Now let’s make our SQL table retrievable by a tool.
75 |
76 | The tool’s description attribute will be embedded in the LLM’s prompt by the agent system: it gives the LLM information about how to use the tool. This is where we want to describe the SQL table.
77 |
78 | ```py
79 | inspector = inspect(engine)
80 | columns_info = [(col["name"], col["type"]) for col in inspector.get_columns("receipts")]
81 |
82 | table_description = "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
83 | print(table_description)
84 | ```
85 |
86 | ```text
87 | Columns:
88 | - receipt_id: INTEGER
89 | - customer_name: VARCHAR(16)
90 | - price: FLOAT
91 | - tip: FLOAT
92 | ```
93 |
94 | Now let’s build our tool. It needs the following: (read [the tool doc](../tutorials/tools) for more detail)
95 | - A docstring with an `Args:` part listing arguments.
96 | - Type hints on both inputs and output.
97 |
98 | ```py
99 | from prime import tool
100 |
101 | @tool
102 | def sql_engine(query: str) -> str:
103 | """
104 | Allows you to perform SQL queries on the table. Returns a string representation of the result.
105 | The table is named 'receipts'. Its description is as follows:
106 | Columns:
107 | - receipt_id: INTEGER
108 | - customer_name: VARCHAR(16)
109 | - price: FLOAT
110 | - tip: FLOAT
111 |
112 | Args:
113 | query: The query to perform. This should be correct SQL.
114 | """
115 | output = ""
116 | with engine.connect() as con:
117 | rows = con.execute(text(query))
118 | for row in rows:
119 | output += "\n" + str(row)
120 | return output
121 | ```
122 |
123 | Now let us create an agent that leverages this tool.
124 |
125 | We use the `CodeAgent`, which is `prime`’s main agent class: an agent that writes its actions in code and can iterate on previous outputs according to the ReAct framework.
126 |
127 | The model is the LLM that powers the agent system. `HfApiModel` allows you to call LLMs through HF’s Inference API, via either a Serverless or a Dedicated endpoint, but you could also use any proprietary API.
128 |
129 | ```py
130 | from prime import CodeAgent, HfApiModel
131 |
132 | agent = CodeAgent(
133 | tools=[sql_engine],
134 | model=HfApiModel("meta-llama/Meta-Llama-3.1-8B-Instruct"),
135 | )
136 | agent.run("Can you give me the name of the client who got the most expensive receipt?")
137 | ```
138 |
139 | ### Level 2: Table joins
140 |
141 | Now let’s make it more challenging! We want our agent to handle joins across multiple tables.
142 |
143 | So let’s make a second table recording the names of waiters for each receipt_id!
144 |
145 | ```py
146 | table_name = "waiters"
147 | receipts = Table(
148 | table_name,
149 | metadata_obj,
150 | Column("receipt_id", Integer, primary_key=True),
151 | Column("waiter_name", String(16), primary_key=True),
152 | )
153 | metadata_obj.create_all(engine)
154 |
155 | rows = [
156 | {"receipt_id": 1, "waiter_name": "Corey Johnson"},
157 | {"receipt_id": 2, "waiter_name": "Michael Watts"},
158 | {"receipt_id": 3, "waiter_name": "Michael Watts"},
159 | {"receipt_id": 4, "waiter_name": "Margaret James"},
160 | ]
161 | for row in rows:
162 | stmt = insert(receipts).values(**row)
163 | with engine.begin() as connection:
164 | cursor = connection.execute(stmt)
165 | ```
166 | Since we added a table, we update the description of our `sql_engine` tool to cover both tables, so that the LLM can properly leverage information from them.
167 |
168 | ```py
169 | updated_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
170 | It can use the following tables:"""
171 |
172 | inspector = inspect(engine)
173 | for table in ["receipts", "waiters"]:
174 | columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]
175 |
176 | table_description = f"Table '{table}':\n"
177 |
178 | table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
179 | updated_description += "\n\n" + table_description
180 |
181 | print(updated_description)
182 | ```
183 | Since this request is a bit harder than the previous one, we’ll switch the LLM engine to use the more powerful [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct)!
184 |
185 | ```py
186 | sql_engine.description = updated_description
187 |
188 | agent = CodeAgent(
189 | tools=[sql_engine],
190 | model=HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct"),
191 | )
192 |
193 | agent.run("Which waiter got more total money from tips?")
194 | ```
195 | It directly works! The setup was surprisingly simple, wasn’t it?
196 |
197 | This example is done! We've touched upon these concepts:
198 | - Building new tools.
199 | - Updating a tool's description.
200 | - Switching to a stronger LLM helps agent reasoning.
201 |
202 | ✅ Now you can go build this text-to-SQL system you’ve always dreamt of! ✨
--------------------------------------------------------------------------------
/docs/source/en/examples/rag.md:
--------------------------------------------------------------------------------
1 |
16 | # Agentic RAG
17 |
18 | [[open-in-colab]]
19 |
20 | Retrieval-Augmented-Generation (RAG) is “using an LLM to answer a user query, but basing the answer on information retrieved from a knowledge base”. It has many advantages over using a vanilla or fine-tuned LLM: to name a few, it grounds the answer in true facts and reduces confabulations, it provides the LLM with domain-specific knowledge, and it allows fine-grained control of access to information from the knowledge base.
21 |
22 | But vanilla RAG has limitations, most importantly these two:
23 | - It performs only one retrieval step: if the results are bad, the generation in turn will be bad.
24 | - Semantic similarity is computed with the user query as a reference, which might be suboptimal: for instance, the user query will often be a question and the document containing the true answer will be in affirmative voice, so its similarity score will be downgraded compared to other source documents in the interrogative form, leading to a risk of missing the relevant information.
25 |
26 | We can alleviate these problems by making a RAG agent: very simply, an agent armed with a retriever tool!
27 |
28 | This agent will: ✅ Formulate the query itself and ✅ Critique to re-retrieve if needed.
29 |
30 | So it should naively recover some advanced RAG techniques!
31 | - Instead of directly using the user query as the reference in semantic search, the agent itself formulates a reference sentence that can be closer to the targeted documents, as in [HyDE](https://huggingface.co/papers/2212.10496).
32 | - The agent can critique the generated snippets and re-retrieve if needed, as in [Self-Query](https://docs.llamaindex.ai/en/stable/examples/evaluation/RetryQuery/).
33 |
34 | Let's build this system. 🛠️
35 |
36 | Run the line below to install required dependencies:
37 | ```bash
38 | !pip install prime pandas langchain langchain-community sentence-transformers faiss-cpu --upgrade -q
39 | ```
40 | To call the HF Inference API, you will need a valid token as your environment variable `HF_TOKEN`.
41 | We use python-dotenv to load it.
42 | ```py
43 | from dotenv import load_dotenv
44 | load_dotenv()
45 | ```
46 |
47 | We first load a knowledge base on which we want to perform RAG: this dataset is a compilation of the documentation pages for many Hugging Face libraries, stored as markdown. We will keep only the documentation for the `transformers` library.
48 |
49 | Then we prepare the knowledge base by processing the dataset and splitting it into chunks that the retriever will search over.
50 |
51 | We use [LangChain](https://python.langchain.com/docs/introduction/) for its excellent document-processing and retrieval utilities.
52 |
53 | ```py
54 | import datasets
55 | from langchain.docstore.document import Document
56 | from langchain.text_splitter import RecursiveCharacterTextSplitter
57 | from langchain_community.retrievers import BM25Retriever
58 |
59 | knowledge_base = datasets.load_dataset("m-ric/huggingface_doc", split="train")
60 | knowledge_base = knowledge_base.filter(lambda row: row["source"].startswith("huggingface/transformers"))
61 |
62 | source_docs = [
63 | Document(page_content=doc["text"], metadata={"source": doc["source"].split("/")[1]})
64 | for doc in knowledge_base
65 | ]
66 |
67 | text_splitter = RecursiveCharacterTextSplitter(
68 | chunk_size=500,
69 | chunk_overlap=50,
70 | add_start_index=True,
71 | strip_whitespace=True,
72 | separators=["\n\n", "\n", ".", " ", ""],
73 | )
74 | docs_processed = text_splitter.split_documents(source_docs)
75 | ```
76 |
77 | Now the documents are ready.
78 |
79 | So let’s build our agentic RAG system!
80 |
81 | 👉 We only need a RetrieverTool that our agent can leverage to retrieve information from the knowledge base.
82 |
83 | Since we need to store a retriever as an attribute of the tool, we cannot simply use the `@tool` decorator: instead, we will follow the advanced setup highlighted in the [tools tutorial](../tutorials/tools).
84 |
85 | ```py
86 | from prime import Tool
87 |
88 | class RetrieverTool(Tool):
89 | name = "retriever"
90 | description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
91 | inputs = {
92 | "query": {
93 | "type": "string",
94 | "description": "The query to perform. This should be semantically close to your target documents. Use the affirmative form rather than a question.",
95 | }
96 | }
97 | output_type = "string"
98 |
99 | def __init__(self, docs, **kwargs):
100 | super().__init__(**kwargs)
101 | self.retriever = BM25Retriever.from_documents(
102 | docs, k=10
103 | )
104 |
105 | def forward(self, query: str) -> str:
106 | assert isinstance(query, str), "Your search query must be a string"
107 |
108 | docs = self.retriever.invoke(
109 | query,
110 | )
111 | return "\nRetrieved documents:\n" + "".join(
112 | [
113 | f"\n\n===== Document {str(i)} =====\n" + doc.page_content
114 | for i, doc in enumerate(docs)
115 | ]
116 | )
117 |
118 | retriever_tool = RetrieverTool(docs_processed)
119 | ```
120 | We have used BM25, a classic retrieval method, because it's lightning fast to set up.
121 | To improve retrieval accuracy, you could replace BM25 with semantic search over vector representations of the documents: head to the [MTEB Leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select a good embedding model.
122 |
123 | Now it’s straightforward to create an agent that leverages this `retriever_tool`!
124 |
125 | The agent will need these arguments upon initialization:
126 | - `tools`: a list of tools that the agent will be able to call.
127 | - `model`: the LLM that powers the agent.
128 | Our `model` must be a callable that takes as input a list of messages and returns text. It also needs to accept a `stop_sequences` argument that indicates when to stop its generation. For convenience, we directly use the `HfApiModel` class provided in the package to get an LLM engine that calls Hugging Face's Inference API.
129 |
130 | And we use [meta-llama/Llama-3.3-70B-Instruct](https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct) as the LLM engine because:
131 | - It has a long 128k context, which is helpful for processing long source documents
132 | - It is served for free at all times on HF's Inference API!
133 |
134 | _Note:_ The Inference API hosts models based on various criteria, and deployed models may be updated or replaced without prior notice. Learn more about it [here](https://huggingface.co/docs/api-inference/supported-models).
135 |
136 | ```py
137 | from prime import HfApiModel, CodeAgent
138 |
139 | agent = CodeAgent(
140 | tools=[retriever_tool], model=HfApiModel("meta-llama/Llama-3.3-70B-Instruct"), max_iterations=4, verbose=True
141 | )
142 | ```
143 |
144 | Upon initializing the CodeAgent, it has been automatically given a default system prompt that tells the LLM engine to process step-by-step and generate tool calls as code snippets, but you could replace this prompt template with your own as needed.
145 |
146 | Then, when its `.run()` method is launched, the agent takes care of calling the LLM engine and executing the tool calls, all in a loop that ends when the `final_answer` tool is called with the final answer as its argument (or when `max_iterations` is reached).
147 |
148 | ```py
149 | agent_output = agent.run("For a transformers model training, which is slower, the forward or the backward pass?")
150 |
151 | print("Final output:")
152 | print(agent_output)
153 | ```
154 |
155 |
156 |
157 |
--------------------------------------------------------------------------------
/src/prime/tool_validation.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import inspect
3 | import builtins
4 | from typing import Set
5 | import textwrap
6 | from .utils import BASE_BUILTIN_MODULES
7 |
# Every name defined in the builtins module (len, print, Exception, ...).
_BUILTIN_NAMES = set(builtins.__dict__)
9 |
10 |
class MethodChecker(ast.NodeVisitor):
    """
    Collects undefined-name errors for one method of a Tool class.

    A name read in load context is accepted when it is a builtin, listed in
    ``BASE_BUILTIN_MODULES``, a parameter of the method or of a nested
    function/lambda, ``self``, a class attribute, an imported name, or a name
    bound earlier in traversal order. Bindings include assignment targets
    (with unpacking), for/with/except targets, comprehension variables, walrus
    targets and nested ``def`` names. Every other loaded name is reported once
    in ``self.errors``.

    Imports are recorded in ``imports`` / ``from_imports`` but not validated
    here; ``check_imports`` is stored and not consulted by this visitor.
    """

    def __init__(self, class_attributes: Set[str], check_imports: bool = True):
        self.undefined_names = set()  # not populated by this visitor
        self.imports = {}  # alias -> module, from `import module as alias`
        self.from_imports = {}  # alias -> (module, name), from `from module import name`
        self.assigned_names = set()
        self.arg_names = set()
        self.class_attributes = class_attributes
        self.errors = []
        self.check_imports = check_imports
        self._function_depth = 0  # > 0 while inside the checked method

    @staticmethod
    def _bound_names(target) -> Set[str]:
        """Return the plain names bound by an assignment-like target.

        Nested tuple/list unpacking and starred targets are handled. Attribute
        and subscript targets bind no new name: their inner names are in load
        context.
        """
        return {
            node.id
            for node in ast.walk(target)
            if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Store)
        }

    def _is_defined(self, name: str) -> bool:
        """Whether ``name`` resolves to something known at this point."""
        return (
            name in _BUILTIN_NAMES
            or name in BASE_BUILTIN_MODULES
            or name in self.arg_names
            or name == "self"
            or name in self.class_attributes
            or name in self.imports
            or name in self.from_imports
            or name in self.assigned_names
        )

    def visit_FunctionDef(self, node):
        """Check a function body; a nested ``def`` also binds its name locally."""
        # Only nested functions bind a local name. The checked method's own
        # name is not in scope inside its body.
        if self._function_depth:
            self.assigned_names.add(node.name)
        self._function_depth += 1
        self.generic_visit(node)
        self._function_depth -= 1

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_arguments(self, node):
        """Collect function arguments (of the method, nested defs and lambdas)."""
        # Union rather than overwrite, so a nested function or lambda does not
        # hide the enclosing method's parameters. Positional-only and
        # keyword-only parameters are names too.
        for arg in (*node.posonlyargs, *node.args, *node.kwonlyargs):
            self.arg_names.add(arg.arg)
        if node.kwarg:
            self.arg_names.add(node.kwarg.arg)
        if node.vararg:
            self.arg_names.add(node.vararg.arg)

    def visit_Import(self, node):
        for name in node.names:
            actual_name = name.asname or name.name
            self.imports[actual_name] = name.name

    def visit_ImportFrom(self, node):
        module = node.module or ""
        for name in node.names:
            actual_name = name.asname or name.name
            self.from_imports[actual_name] = (module, name.name)

    def visit_Assign(self, node):
        """Track assigned names (including unpacked ones) and check the value."""
        for target in node.targets:
            self.assigned_names |= self._bound_names(target)
        self.visit(node.value)

    def visit_With(self, node):
        """Track aliases in 'with' statements (the 'y' in 'with X as y')"""
        for item in node.items:
            if item.optional_vars is not None:
                self.assigned_names |= self._bound_names(item.optional_vars)
        self.generic_visit(node)

    visit_AsyncWith = visit_With

    def visit_ExceptHandler(self, node):
        """Track exception aliases (the 'e' in 'except Exception as e')"""
        if node.name:
            self.assigned_names.add(node.name)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        """Track annotated assignments (the annotation itself is not checked)."""
        self.assigned_names |= self._bound_names(node.target)
        if node.value:
            self.visit(node.value)

    def visit_For(self, node):
        """Track loop variables, including unpacked ones, then check the loop."""
        self.assigned_names |= self._bound_names(node.target)
        self.generic_visit(node)

    visit_AsyncFor = visit_For

    def visit_NamedExpr(self, node):
        """Check the value of ``name := value`` and record ``name`` as bound."""
        self.visit(node.value)
        self.assigned_names |= self._bound_names(node.target)

    def _visit_comprehension(self, node):
        """Bind comprehension variables before checking the produced expression(s)."""
        # The ast fields list `elt` (or `key`/`value`) before `generators`, so
        # generic_visit would check the element before its variables are bound.
        for generator in node.generators:
            self.visit(generator.iter)
            self.assigned_names |= self._bound_names(generator.target)
            for condition in generator.ifs:
                self.visit(condition)
        for field in ("elt", "key", "value"):
            child = getattr(node, field, None)
            if child is not None:
                self.visit(child)

    visit_ListComp = _visit_comprehension
    visit_SetComp = _visit_comprehension
    visit_GeneratorExp = _visit_comprehension
    visit_DictComp = _visit_comprehension

    def visit_Attribute(self, node):
        # Attribute accesses on `self` are not checked.
        if not (isinstance(node.value, ast.Name) and node.value.id == "self"):
            self.generic_visit(node)

    def visit_Name(self, node):
        if isinstance(node.ctx, ast.Load) and not self._is_defined(node.id):
            self.errors.append(f"Name '{node.id}' is undefined.")

    def visit_Call(self, node):
        """Check a call's callee and arguments.

        A plain-name callee is an ``ast.Name`` in load context, so ``visit_Name``
        already reports it when undefined. Checking it here as well would
        report each undefined callee twice.
        """
        self.generic_visit(node)
116 |
117 |
def validate_tool_attributes(cls, check_imports: bool = True) -> None:
    """
    Validates that a Tool class follows the proper patterns:
    0. __init__ takes no argument (args chosen at init are not traceable so we cannot rebuild the source code for them, make them class attributes!).
    1. About the class:
        - Class attribute values must be literals: constants (str, numbers, ...),
          dicts, lists and sets, possibly nested
        - Anything else (calls, names, comprehensions, ...) is reported as a complex attribute
    2. About all class (non-async) methods:
        - Every name a method reads must be defined (checked with MethodChecker,
          which receives `check_imports`)

    Raises:
        ValueError: If the source does not start with a class definition, or
            listing every problem found. Returns None when the class is valid.
    """
    errors = []

    source = textwrap.dedent(inspect.getsource(cls))
    tree = ast.parse(source)

    if not tree.body or not isinstance(tree.body[0], ast.ClassDef):
        raise ValueError("Source code must define a class")

    # Inheriting Tool.__init__ is fine; any other __init__ must take only `self`.
    if cls.__init__.__qualname__ != "Tool.__init__":
        sig = inspect.signature(cls.__init__)
        non_self_params = [name for name in sig.parameters if name != "self"]
        if non_self_params:
            errors.append(
                f"This tool has additional args specified in __init__(self): {non_self_params}. Make sure it does not, all values should be hardcoded!"
            )

    class_node = tree.body[0]

    # Node types allowed anywhere inside a class attribute's value.
    # - ast.Constant covers str/bytes/numbers/bools/None. The ast.Str and
    #   ast.Num aliases are deprecated since Python 3.8 and removed in 3.14.
    # - ast.expr_context covers the Load node that ast.walk yields for list
    #   literals, which would otherwise make every list attribute "complex".
    simple_literal_nodes = (ast.Constant, ast.Dict, ast.List, ast.Set, ast.expr_context)

    class ClassLevelChecker(ast.NodeVisitor):
        """Collects class-level attribute names and flags non-literal values."""

        def __init__(self):
            self.imported_names = set()  # not populated by this visitor
            self.complex_attributes = set()
            self.class_attributes = set()
            self.in_method = False

        def visit_FunctionDef(self, node):
            # Assignments inside (async) methods are not class attributes.
            old_context = self.in_method
            self.in_method = True
            self.generic_visit(node)
            self.in_method = old_context

        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Assign(self, node):
            if self.in_method:
                return
            # Track class attributes
            target_names = [target.id for target in node.targets if isinstance(target, ast.Name)]
            self.class_attributes.update(target_names)

            # Check if the assignment is more complex than simple literals
            if not all(isinstance(val, simple_literal_nodes) for val in ast.walk(node.value)):
                self.complex_attributes.update(target_names)

    class_level_checker = ClassLevelChecker()
    class_level_checker.visit(class_node)

    if class_level_checker.complex_attributes:
        # Sorted so the message is deterministic (sets have no stable order).
        errors.append(
            f"Complex attributes should be defined in __init__, not as class attributes: "
            f"{', '.join(sorted(class_level_checker.complex_attributes))}"
        )

    # Run checks on all methods
    for node in class_node.body:
        if isinstance(node, ast.FunctionDef):
            method_checker = MethodChecker(
                class_level_checker.class_attributes, check_imports=check_imports
            )
            method_checker.visit(node)
            errors += [f"- {node.name}: {error}" for error in method_checker.errors]

    if errors:
        raise ValueError("Tool validation failed:\n" + "\n".join(errors))
206 |
--------------------------------------------------------------------------------
/docs/source/en/examples/multiagents.md:
--------------------------------------------------------------------------------
1 |
16 | # Orchestrate a multi-agent system 🤖🤝🤖
17 |
18 | [[open-in-colab]]
19 |
20 | In this notebook we will make a **multi-agent web browser: an agentic system with several agents collaborating to solve problems using the web!**
21 |
22 | It will be a simple hierarchy, using a `ManagedAgent` object to wrap the managed web search agent:
23 |
24 | ```
25 | +----------------+
26 | | Manager agent |
27 | +----------------+
28 | |
29 | _______________|______________
30 | | |
31 | Code interpreter +--------------------------------+
32 | tool | Managed agent |
33 | | +------------------+ |
34 | | | Web Search agent | |
35 | | +------------------+ |
36 | | | | |
37 | | Web Search tool | |
38 | | Visit webpage tool |
39 | +--------------------------------+
40 | ```
41 | Let's set up this system.
42 |
43 | Run the line below to install the required dependencies:
44 |
45 | ```
46 | !pip install markdownify duckduckgo-search prime --upgrade -q
47 | ```
48 |
49 | Let's log in to call the HF Inference API:
50 |
51 | ```py
52 | from huggingface_hub import notebook_login
53 |
54 | notebook_login()
55 | ```
56 |
57 | ⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using the `HfApiModel` class, which calls HF's Inference API: the Inference API lets you quickly and easily run any open-source model.
58 |
59 | _Note:_ The Inference API hosts models based on various criteria, and deployed models may be updated or replaced without prior notice. Learn more about it [here](https://huggingface.co/docs/api-inference/supported-models).
60 |
61 | ```py
62 | model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
63 | ```
64 |
65 | ## 🔍 Create a web search tool
66 |
67 | For web browsing, we can already use our pre-existing [`DuckDuckGoSearchTool`](https://github.com/huggingface/prime/blob/main/src/prime/default_tools/search.py) tool to provide a Google search equivalent.
68 |
69 | But then we will also need to be able to peek into the pages found by the `DuckDuckGoSearchTool`.
70 | To do so, we could import the library's built-in `VisitWebpageTool`, but we will build it again to see how it's done.
71 |
72 | So let's create our `VisitWebpageTool` tool from scratch using `markdownify`.
73 |
74 | ```py
75 | import re
76 | import requests
77 | from markdownify import markdownify
78 | from requests.exceptions import RequestException
79 | from prime import tool
80 |
81 |
@tool
def visit_webpage(url: str) -> str:
    """Visits a webpage at the given URL and returns its content as a markdown string.

    Args:
        url: The URL of the webpage to visit.

    Returns:
        The content of the webpage converted to Markdown, or an error message if the request fails.
    """
    try:
        # Send a GET request to the URL. The timeout keeps an unresponsive server from
        # blocking the agent forever; a timeout is reported like any other request error.
        response = requests.get(url, timeout=20)
        response.raise_for_status()  # Raise an exception for bad status codes

        # Convert the HTML content to Markdown
        markdown_content = markdownify(response.text).strip()

        # Remove multiple line breaks
        markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)

        return markdown_content

    except RequestException as e:
        return f"Error fetching the webpage: {str(e)}"
    except Exception as e:
        return f"An unexpected error occurred: {str(e)}"
109 | ```
110 |
111 | Ok, now let's initialize and test our tool!
112 |
113 | ```py
114 | print(visit_webpage("https://en.wikipedia.org/wiki/Hugging_Face")[:500])
115 | ```
116 |
117 | ## Build our multi-agent system 🤖🤝🤖
118 |
119 | Now that we have both tools, `DuckDuckGoSearchTool` and `visit_webpage`, we can use them to create the web agent.
120 |
121 | Which configuration to choose for this agent?
122 | - Web browsing is a single-timeline task that does not require parallel tool calls, so JSON tool calling works well for that. We thus choose a `ToolCallingAgent`.
123 | - Also, since web search sometimes requires exploring many pages before finding the correct answer, we increase `max_iterations` to 10.
124 |
125 | ```py
126 | from prime import (
127 | CodeAgent,
128 | ToolCallingAgent,
129 | HfApiModel,
130 | ManagedAgent,
131 | DuckDuckGoSearchTool,
132 | LiteLLMModel,
133 | )
134 |
135 | model = HfApiModel(model_id)
136 |
137 | web_agent = ToolCallingAgent(
138 | tools=[DuckDuckGoSearchTool(), visit_webpage],
139 | model=model,
140 | max_iterations=10,
141 | )
142 | ```
143 |
144 | We then wrap this agent into a `ManagedAgent` that will make it callable by its manager agent.
145 |
146 | ```py
147 | managed_web_agent = ManagedAgent(
148 | agent=web_agent,
149 | name="search",
150 | description="Runs web searches for you. Give it your query as an argument.",
151 | )
152 | ```
153 |
154 | Finally we create a manager agent, and upon initialization we pass our managed agent to it in its `managed_agents` argument.
155 |
156 | Since this agent is the one tasked with the planning and thinking, advanced reasoning will be beneficial, so a `CodeAgent` will be the best choice.
157 |
158 | Also, we want to ask a question that involves the current year and requires additional data calculations, so let us add `additional_authorized_imports=["time", "numpy", "pandas"]`, just in case the agent needs these packages.
159 |
160 | ```py
161 | manager_agent = CodeAgent(
162 | tools=[],
163 | model=model,
164 | managed_agents=[managed_web_agent],
165 | additional_authorized_imports=["time", "numpy", "pandas"],
166 | )
167 | ```
168 |
169 | That's all! Now let's run our system! We select a question that requires both some calculation and research:
170 |
171 | ```py
172 | answer = manager_agent.run("If LLM trainings continue to scale up at the current rythm until 2030, what would be the electric power in GW required to power the biggest training runs by 2030? What does that correspond to, compared to some contries? Please provide a source for any number used.")
173 | ```
174 |
175 | We get this report as the answer:
176 | ```
177 | Based on current growth projections and energy consumption estimates, if LLM trainings continue to scale up at the
178 | current rhythm until 2030:
179 |
180 | 1. The electric power required to power the biggest training runs by 2030 would be approximately 303.74 GW, which
181 | translates to about 2,660,762 GWh/year.
182 |
183 | 2. Comparing this to countries' electricity consumption:
184 | - It would be equivalent to about 34% of China's total electricity consumption.
185 | - It would exceed the total electricity consumption of India (184%), Russia (267%), and Japan (291%).
186 | - It would be nearly 9 times the electricity consumption of countries like Italy or Mexico.
187 |
188 | 3. Source of numbers:
189 | - The initial estimate of 5 GW for future LLM training comes from AWS CEO Matt Garman.
190 | - The growth projection used a CAGR of 79.80% from market research by Springs.
191 | - Country electricity consumption data is from the U.S. Energy Information Administration, primarily for the year
192 | 2021.
193 | ```
194 |
195 | Seems like we'll need some sizeable powerplants if the [scaling hypothesis](https://gwern.net/scaling-hypothesis) continues to hold true.
196 |
197 | Our agents managed to efficiently collaborate towards solving the task! ✅
198 |
199 | 💡 You can easily extend this orchestration to more agents: one handles the code execution, one the web search, one the file loading...
--------------------------------------------------------------------------------
/docs/source/en/conceptual_guides/intro_agents.md:
--------------------------------------------------------------------------------
1 |
16 | # Introduction to Agents
17 |
18 | ## 🤔 What are agents?
19 |
20 | Any efficient system using AI will need to provide LLMs some kind of access to the real world: for instance the possibility to call a search tool to get external information, or to act on certain programs in order to solve a task. In other words, LLMs should have ***agency***. Agentic programs are the gateway to the outside world for LLMs.
21 |
22 | > [!TIP]
23 | > AI Agents are **programs where LLM outputs control the workflow**.
24 |
25 | Any system leveraging LLMs will integrate the LLM outputs into code. The influence of the LLM's output on the code workflow is the level of agency of LLMs in the system.
26 |
27 | Note that with this definition, "agent" is not a discrete, 0 or 1 definition: instead, "agency" evolves on a continuous spectrum, as you give more or less power to the LLM on your workflow.
28 |
29 | See in the table below how agency can vary across systems:
30 |
31 | | Agency Level | Description | How that's called | Example Pattern |
32 | | ------------ | ------------------------------------------------------- | ----------------- | -------------------------------------------------- |
33 | | ☆☆☆ | LLM output has no impact on program flow | Simple Processor | `process_llm_output(llm_response)` |
34 | | ★☆☆ | LLM output determines an if/else switch | Router | `if llm_decision(): path_a() else: path_b()` |
35 | | ★★☆ | LLM output determines function execution | Tool Caller | `run_function(llm_chosen_tool, llm_chosen_args)` |
36 | | ★★★ | LLM output controls iteration and program continuation | Multi-step Agent | `while llm_should_continue(): execute_next_step()` |
37 | | ★★★ | One agentic workflow can start another agentic workflow | Multi-Agent | `if llm_trigger(): execute_agent()` |
38 |
39 | The multi-step agent has this code structure:
40 |
41 | ```python
42 | memory = [user_defined_task]
43 | while llm_should_continue(memory): # this loop is the multi-step part
44 | action = llm_get_next_action(memory) # this is the tool-calling part
45 | observations = execute_action(action)
46 | memory += [action, observations]
47 | ```
48 |
49 | This agentic system runs in a loop, executing a new action at each step (the action can involve calling some pre-determined *tools* that are just functions), until its observations make it apparent that a satisfactory state has been reached to solve the given task. Here’s an example of how a multi-step agent can solve a simple math question:
50 |
51 |
52 |
53 |
54 |
55 |
56 | ## ✅ When to use agents / ⛔ when to avoid them
57 |
58 | Agents are useful when you need an LLM to determine the workflow of an app. But they’re often overkill. The question is: do I really need flexibility in the workflow to efficiently solve the task at hand?
59 | If the pre-determined workflow falls short too often, that means you need more flexibility.
60 | Let's take an example: say you're making an app that handles customer requests on a surfing trip website.
61 |
62 | You could know in advance that the requests will belong to one of 2 buckets (based on user choice), and you have a predefined workflow for each of these 2 cases.
63 |
64 | 1. Wants some knowledge about the trips? ⇒ give them access to a search bar to search your knowledge base
65 | 2. Wants to talk to sales? ⇒ let them type in a contact form.
66 |
67 | If that deterministic workflow fits all queries, by all means just code everything! This will give you a 100% reliable system with no risk of error introduced by letting unpredictable LLMs meddle in your workflow. For the sake of simplicity and robustness, it's advised to regularize towards not using any agentic behaviour.
68 |
69 | But what if the workflow can't be determined that well in advance?
70 |
71 | For instance, a user wants to ask: `"I can come on Monday, but I forgot my passport so risk being delayed to Wednesday, is it possible to take me and my stuff to surf on Tuesday morning, with a cancellation insurance?"` This question hinges on many factors, and probably none of the predetermined criteria above will suffice for this request.
72 |
73 | If the pre-determined workflow falls short too often, that means you need more flexibility.
74 |
75 | That is where an agentic setup helps.
76 |
77 | In the above example, you could just make a multi-step agent that has access to a weather API for weather forecasts, Google Maps API to compute travel distance, an employee availability dashboard and a RAG system on your knowledge base.
78 |
79 | Until recently, computer programs were restricted to pre-determined workflows, trying to handle complexity by piling up if/else switches. They focused on extremely narrow tasks, like "compute the sum of these numbers" or "find the shortest path in this graph". But actually, most real-life tasks, like our trip example above, do not fit in pre-determined workflows. Agentic systems open up the vast world of real-world tasks to programs!
80 |
81 | ## Why `prime`?
82 |
83 | For some low-level agentic use cases, like chains or routers, you can write all the code yourself. You'll be much better off that way, since it will let you control and understand your system better.
84 |
85 | But once you start going for more complicated behaviours like letting an LLM call a function (that's "tool calling") or letting an LLM run a while loop ("multi-step agent"), some abstractions become necessary:
86 | - for tool calling, you need to parse the agent's output, so this output needs a predefined format like "Thought: I should call tool 'get_weather'. Action: get_weather(Paris).", which you parse with a predefined function, and the system prompt given to the LLM should inform it about this format.
87 | - for a multi-step agent where the LLM output determines the loop, you need to give a different prompt to the LLM based on what happened in the last loop iteration: so you need some kind of memory.
88 |
89 | See? With these two examples, we already found the need for a few items to help us:
90 |
91 | - Of course, an LLM that acts as the engine powering the system
92 | - A list of tools that the agent can access
93 | - A parser that extracts tool calls from the LLM output
94 | - A system prompt synced with the parser
95 | - A memory
96 |
97 | But wait, since we give room to LLMs in decisions, surely they will make mistakes: so we need error logging and retry mechanisms.
98 |
99 | All these elements need tight coupling to make a well-functioning system. That's why we decided we needed to make basic building blocks to make all this stuff work together.
100 |
101 | ## Code agents
102 |
103 | In a multi-step agent, at each step, the LLM can write an action, in the form of some calls to external tools. A common format (used by Anthropic, OpenAI, and many others) for writing these actions is generally different shades of "writing actions as a JSON of tools names and arguments to use, which you then parse to know which tool to execute and with which arguments".
104 |
105 | [Multiple](https://huggingface.co/papers/2402.01030) [research](https://huggingface.co/papers/2411.01747) [papers](https://huggingface.co/papers/2401.00812) have shown that having LLMs write their tool calls in code works much better.
106 |
107 | The reason for this is simply that *we crafted our code languages specifically to be the best possible way to express actions performed by a computer*. If JSON snippets were a better expression, JSON would be the top programming language and programming would be hell on earth.
108 |
109 | The figure below, taken from [Executable Code Actions Elicit Better LLM Agents](https://huggingface.co/papers/2402.01030), illustrates some advantages of writing actions in code:
110 |
111 |
112 |
113 | Writing actions in code rather than JSON-like snippets provides better:
114 |
115 | - **Composability:** could you nest JSON actions within each other, or define a set of JSON actions to re-use later, the same way you could just define a python function?
116 | - **Object management:** how do you store the output of an action like `generate_image` in JSON?
117 | - **Generality:** code is built to express simply anything you can have a computer do.
118 | - **Representation in LLM training data:** plenty of quality code actions are already included in LLMs’ training data, which means they’re already trained for this!
119 |
--------------------------------------------------------------------------------
/src/prime/types.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 |
5 | import os
6 | import pathlib
7 | import tempfile
8 | import uuid
9 | from io import BytesIO
10 | import requests
11 | import numpy as np
12 |
13 | from transformers.utils import (
14 | is_soundfile_availble,
15 | is_torch_available,
16 | is_vision_available,
17 | )
18 | import logging
19 |
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 | if is_vision_available():
24 | from PIL import Image
25 | from PIL.Image import Image as ImageType
26 | else:
27 | ImageType = object
28 |
29 | if is_torch_available():
30 | import torch
31 | from torch import Tensor
32 | else:
33 | Tensor = object
34 |
35 | if is_soundfile_availble():
36 | import soundfile as sf
37 |
38 |
class AgentType:
    """
    Base wrapper for values returned by agents; concrete subclasses specialise it per data type.

    Agent types are meant to:

    - behave as if they were the type they wrap, e.g. a string for text or a PIL.Image for images,
    - be convertible with str(object), which returns a string describing the object,
    - display correctly in ipython notebooks/colab/jupyter.
    """

    def __init__(self, value):
        self._value = value

    def __str__(self):
        return self.to_string()

    @staticmethod
    def _log_unknown_type():
        # This generic base has no type-specific conversion, so every access is reported as unreliable.
        logger.error(
            "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
        )

    def to_raw(self):
        """Log an error, then hand back the wrapped value untouched."""
        self._log_unknown_type()
        return self._value

    def to_string(self) -> str:
        """Log an error, then return ``str()`` of the wrapped value."""
        self._log_unknown_type()
        return str(self._value)
67 |
68 |
class AgentText(AgentType, str):
    """
    Text produced by an agent. It is a genuine ``str`` subclass, so it can be used wherever a string is expected.
    """

    def to_raw(self):
        """Hand back the wrapped value exactly as it was given."""
        return self._value

    def to_string(self):
        """Return the wrapped value passed through ``str()``."""
        return str(self._value)
79 |
80 |
class AgentImage(AgentType, ImageType):
    """
    Image type returned by the agent. Behaves as a PIL.Image.

    It can be built from another AgentImage, a PIL image, encoded image bytes, a path
    (``str`` or ``pathlib.Path``), a ``torch.Tensor`` or a ``numpy.ndarray`` (the last two
    require torch). The other representations are derived lazily by ``to_raw``/``to_string``.

    Raises:
        ImportError: if PIL is not installed, or a numpy array is given while torch is not installed.
        TypeError: if ``value`` has an unsupported type.
    """

    def __init__(self, value):
        # Check for PIL before running the PIL.Image base initializer.
        if not is_vision_available():
            raise ImportError("PIL must be installed in order to handle images.")

        AgentType.__init__(self, value)
        ImageType.__init__(self)

        self._path = None
        self._raw = None
        self._tensor = None

        if isinstance(value, AgentImage):
            self._raw, self._path, self._tensor = value._raw, value._path, value._tensor
        elif isinstance(value, ImageType):
            self._raw = value
        elif isinstance(value, bytes):
            self._raw = Image.open(BytesIO(value))
        elif isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif is_torch_available() and isinstance(value, torch.Tensor):
            # `torch` only exists in this module when it is installed, hence the guard.
            self._tensor = value
        elif isinstance(value, np.ndarray):
            if not is_torch_available():
                raise ImportError(
                    "torch must be installed in order to handle numpy array images."
                )
            self._tensor = torch.from_numpy(value)
        else:
            raise TypeError(
                f"Unsupported type for {self.__class__.__name__}: {type(value)}"
            )

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Image, display

        display(Image(self.to_string()))

    def _tensor_to_pil(self):
        """Convert the backing tensor into a PIL image."""
        array = self._tensor.cpu().detach().numpy()
        # NOTE(review): maps v -> 255 - 255 * v, i.e. assumes values in [0, 1] and inverts
        # them; kept unchanged from the original behaviour — confirm this is intended.
        return Image.fromarray((255 - array * 255).astype(np.uint8))

    @staticmethod
    def _save_to_temp_png(img):
        """Save ``img`` as a PNG file in a fresh temporary directory and return its path."""
        directory = tempfile.mkdtemp()
        path = os.path.join(directory, str(uuid.uuid4()) + ".png")
        img.save(path, format="png")
        return path

    def to_raw(self):
        """
        Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
        """
        if self._raw is not None:
            return self._raw

        if self._path is not None:
            self._raw = Image.open(self._path)
            return self._raw

        if self._tensor is not None:
            return self._tensor_to_pil()

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
        version of the image (a PNG in a temporary directory when the image was not created from a path).
        """
        if self._path is not None:
            # Always return a str (the path may be a pathlib.Path), so that __str__ works.
            return str(self._path)

        if self._raw is not None:
            self._path = self._save_to_temp_png(self._raw)
            return self._path

        if self._tensor is not None:
            self._path = self._save_to_temp_png(self._tensor_to_pil())
            return self._path

    def save(self, output_bytes, format: str = None, **params):
        """
        Saves the image to a file.

        Args:
            output_bytes: Destination forwarded to PIL.Image.save (a filename, path-like or file object).
            format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
            **params: Additional parameters to pass to PIL.Image.save.
        """
        img = self.to_raw()
        img.save(output_bytes, format=format, **params)
173 |
174 |
class AgentAudio(AgentType, str):
    """
    Audio type returned by the agent.

    It can be built from a path or URL (``str`` / ``pathlib.Path``), a ``torch.Tensor``, or a
    ``(samplerate, data)`` tuple where ``data`` is a ``numpy.ndarray`` or anything accepted by
    ``torch.tensor``.

    Raises:
        ImportError: if soundfile is not installed, or a tuple is given while torch is not installed.
        ValueError: if ``value`` has an unsupported type.
    """

    def __new__(cls, value, samplerate=16_000):
        # str.__new__ rejects extra arguments, so `samplerate` must not be forwarded to it;
        # otherwise AgentAudio(value, samplerate=...) raises TypeError before __init__ runs.
        return super().__new__(cls, value)

    def __init__(self, value, samplerate=16_000):
        super().__init__(value)

        if not is_soundfile_availble():
            raise ImportError("soundfile must be installed in order to handle audio.")

        self._path = None
        self._tensor = None

        self.samplerate = samplerate
        if isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif is_torch_available() and isinstance(value, torch.Tensor):
            self._tensor = value
        elif isinstance(value, tuple):
            # `torch` only exists in this module when it is installed.
            if not is_torch_available():
                raise ImportError(
                    "torch must be installed in order to handle audio arrays."
                )
            self.samplerate = value[0]
            if isinstance(value[1], np.ndarray):
                self._tensor = torch.from_numpy(value[1])
            else:
                self._tensor = torch.tensor(value[1])
        else:
            raise ValueError(f"Unsupported audio type: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        from IPython.display import Audio, display

        display(Audio(self.to_string(), rate=self.samplerate))

    def to_raw(self):
        """
        Returns the "raw" version of that object. It is a `torch.Tensor` object.

        Audio given as a path is read with soundfile (downloaded first when the path contains "://"),
        which also updates ``self.samplerate``.
        """
        if self._tensor is not None:
            return self._tensor

        if self._path is not None:
            if "://" in str(self._path):
                response = requests.get(self._path)
                response.raise_for_status()
                tensor, self.samplerate = sf.read(BytesIO(response.content))
            else:
                tensor, self.samplerate = sf.read(self._path)
            self._tensor = torch.tensor(tensor)
            return self._tensor

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
        version of the audio (a WAV file in a temporary directory when created from a tensor).
        """
        if self._path is not None:
            # Always return a str (the path may be a pathlib.Path), so that __str__ works.
            return str(self._path)

        if self._tensor is not None:
            directory = tempfile.mkdtemp()
            path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
            # soundfile works on numpy arrays; detach()/cpu() make GPU or grad-tracking tensors safe.
            sf.write(
                path,
                self._tensor.detach().cpu().numpy(),
                samplerate=self.samplerate,
            )
            self._path = path
            return self._path
241 |
242 |
# Maps a declared tool `output_type` to the agent type that wraps the output.
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}

# Maps runtime Python types to agent types, for outputs without a declared type.
# Optional backends are registered only when installed: `torch` is not importable without
# torch, and without PIL `ImageType` is the `object` placeholder, which would match
# *every* value and route it to AgentImage (which then raises ImportError).
INSTANCE_TYPE_MAPPING = {str: AgentText}

if is_vision_available():
    INSTANCE_TYPE_MAPPING[ImageType] = AgentImage

if is_torch_available():
    INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
252 |
253 |
def handle_agent_input_types(*args, **kwargs):
    """
    Unwrap agent types found among call arguments into their raw values.

    Every positional or keyword argument that is an AgentType is replaced by its ``to_raw()``
    result; all other arguments pass through untouched.

    Returns:
        A ``(args, kwargs)`` pair: a list of positional arguments and a dict of keyword arguments.
    """

    def _unwrap(candidate):
        return candidate.to_raw() if isinstance(candidate, AgentType) else candidate

    raw_args = list(map(_unwrap, args))
    raw_kwargs = {name: _unwrap(val) for name, val in kwargs.items()}
    return raw_args, raw_kwargs
260 |
261 |
def handle_agent_output_types(output, output_type=None):
    """
    Wrap an agent/tool output into the matching agent type, when one applies.

    A declared ``output_type`` present in AGENT_TYPE_MAPPING wins. Otherwise the first entry of
    INSTANCE_TYPE_MAPPING whose Python type matches ``output`` is used. Unmatched outputs are
    returned unchanged.
    """
    if output_type in AGENT_TYPE_MAPPING:
        # Declared output type: map directly according to the class definition.
        return AGENT_TYPE_MAPPING[output_type](output)

    # No declared type: pick the wrapper from the runtime type of the value.
    wrapper = next(
        (
            agent_cls
            for py_type, agent_cls in INSTANCE_TYPE_MAPPING.items()
            if isinstance(output, py_type)
        ),
        None,
    )
    return output if wrapper is None else wrapper(output)
273 |
274 |
# Names exported by `from <this module> import *`; the type mappings and handle_* helpers are not listed.
__all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]
276 |
--------------------------------------------------------------------------------
/src/prime/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 |
5 | import json
6 | import re
7 | from typing import Tuple, Dict, Union
8 | import ast
9 | from rich.console import Console
10 | import inspect
11 | import types
12 |
13 | from transformers.utils.import_utils import _is_package_available
14 |
# Probed once, at import time; the accessor below only reports the cached result.
_pygments_available = _is_package_available("pygments")


def is_pygments_available():
    """Tell whether the optional ``pygments`` package was found when this module was imported."""
    return _pygments_available
20 |
21 |
# Shared `rich` console; AgentError prints every error message through it in bold red.
console = Console()
23 |
# Baseline list of standard-library module names, kept in alphabetical order.
# NOTE(review): presumably the default import allow-list for agent-generated code — confirm at its call sites.
BASE_BUILTIN_MODULES = [
    "collections", "datetime", "itertools", "math",
    "queue", "random", "re", "stat",
    "statistics", "time", "unicodedata",
]
37 |
38 |
class AgentError(Exception):
    """Root of the agent exception hierarchy; echoes its message to the shared console when created."""

    def __init__(self, message):
        # Keep the raw message available to code that formats errors itself.
        self.message = message
        super().__init__(message)
        # Side effect: the error is printed as soon as it is constructed, even if later caught.
        console.print(f"[bold red]{self.message}[/bold red]")
46 |
47 |
class AgentParsingError(AgentError):
    """Exception raised for errors in parsing in the agent"""


class AgentExecutionError(AgentError):
    """Exception raised for errors in execution in the agent"""


class AgentMaxIterationsError(AgentError):
    """Exception raised when the agent reaches its iteration limit (see `max_iterations`)"""


class AgentGenerationError(AgentError):
    """Exception raised for errors in generation in the agent"""
70 |
71 |
def parse_json_blob(json_blob: str) -> Dict[str, str]:
    """
    Parse the JSON object embedded in an LLM output.

    Everything outside the span from the first "{" to the last "}" is discarded, and escaped
    double quotes (\\") are turned into single quotes before decoding (leniently, strict=False).

    Args:
        json_blob: Raw text expected to contain one JSON object.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: if no braces are found or the JSON cannot be decoded. The message is meant to
            be shown back to the LLM: it points at the failing part of the blob and flags the
            common "several tool calls in one action" mistake.
    """
    try:
        first_accolade_index = json_blob.find("{")
        last_accolade_index = json_blob.rfind("}")
        if first_accolade_index == -1 or last_accolade_index == -1:
            raise ValueError(
                f"no JSON object (enclosed in '{{' and '}}') was found in: {json_blob}"
            )
        json_blob = json_blob[first_accolade_index : last_accolade_index + 1].replace(
            '\\"', "'"
        )
        return json.loads(json_blob, strict=False)
    except json.JSONDecodeError as e:
        place = e.pos
        # Clamp slice starts at 0: a negative start would wrap around to the end of the string.
        if json_blob[max(place - 1, 0) : place + 2] == "},\n":
            raise ValueError(
                "JSON is invalid: you probably tried to provide multiple tool calls in one action. PROVIDE ONLY ONE TOOL CALL."
            ) from e
        raise ValueError(
            f"The JSON blob you used is invalid due to the following error: {e}.\n"
            f"JSON blob was: {json_blob}, decoding failed on that specific part of the blob:\n"
            f"'{json_blob[max(place - 4, 0) : place + 5]}'."
        ) from e
    except Exception as e:
        raise ValueError(f"Error in parsing the JSON blob: {e}") from e
94 |
95 |
def parse_code_blob(code_blob: str) -> str:
    """
    Extract the code from the first fenced code block of an LLM output.

    The fence may be opened with ```py, ```python or a bare ```; the opening fence must be
    followed by a newline and the closing fence preceded by one.

    Args:
        code_blob: Raw LLM output expected to contain a fenced code block.

    Returns:
        The code inside the first matching fence, with surrounding whitespace stripped.

    Raises:
        ValueError: if no fenced block is found (or the input cannot be searched). The message
            explains the expected format so it can be fed back to the LLM.
    """
    pattern = r"```(?:py|python)?\n(.*?)\n```"
    try:
        match = re.search(pattern, code_blob, re.DOTALL)
        if match is None:
            raise ValueError(
                f"No match found for regex pattern {pattern} in {code_blob=}."
            )
        return match.group(1).strip()

    except Exception as e:
        raise ValueError(
            f"""
The code blob you used is invalid due to the following error: {e}
This means that the regex pattern {pattern} was not respected: make sure to include code with the correct pattern, for instance:
Thoughts: Your thoughts
Code:
```py
# Your python code here
```"""
        ) from e
117 |
118 |
def parse_json_tool_call(json_blob: str) -> Tuple[str, Union[str, None]]:
    """
    Extract ``(tool_name, tool_arguments)`` from an LLM output holding a JSON tool call.

    Markdown ```json / ``` fences are removed before parsing. Several spellings are accepted for
    both keys; when more than one candidate key is present, the candidate listed last wins.
    ``tool_arguments`` is None when no argument key is present.

    Raises:
        ValueError: if the JSON cannot be parsed.
        AgentParsingError: if no tool name key is present.
    """
    name_keys = ("action", "tool_name", "tool", "name", "function")
    argument_keys = ("action_input", "tool_arguments", "tool_args", "parameters")

    cleaned = json_blob.replace("```json", "").replace("```", "")
    tool_call = parse_json_blob(cleaned)

    # Searching from the end reproduces "last matching candidate wins".
    name_key = next((key for key in reversed(name_keys) if key in tool_call), None)
    argument_key = next(
        (key for key in reversed(argument_keys) if key in tool_call), None
    )

    if name_key is None:
        raise AgentParsingError(
            "No tool name key found in tool call!" + f" Tool call: {cleaned}"
        )
    arguments = None if argument_key is None else tool_call[argument_key]
    return tool_call[name_key], arguments
141 |
142 |
MAX_LENGTH_TRUNCATE_CONTENT = 20000


def truncate_content(
    content: str, max_length: int = MAX_LENGTH_TRUNCATE_CONTENT
) -> str:
    """
    Shorten ``content`` to ``max_length`` characters by keeping its beginning and end.

    Content no longer than ``max_length`` is returned unchanged. Otherwise the first
    ``max_length // 2`` and the last ``max_length - max_length // 2`` characters are kept, with a
    truncation notice inserted between them.

    Args:
        content: Text to shorten.
        max_length: Number of original characters to keep.

    Returns:
        ``content`` itself, or its truncated version.
    """
    if len(content) <= max_length:
        return content
    # Size both halves from the caller's limit (not the module default). The tail is sliced from
    # an absolute index so that a zero-length tail yields "" rather than the whole string.
    head_length = max_length // 2
    tail_length = max_length - head_length
    return (
        content[:head_length]
        + f"\n..._This content has been truncated to stay below {max_length} characters_...\n"
        + content[len(content) - tail_length :]
    )
157 |
158 |
class ImportFinder(ast.NodeVisitor):
    """AST visitor that records the top-level package of every import it visits."""

    def __init__(self):
        self.packages = set()

    @staticmethod
    def _root(dotted_name):
        # "pkg.sub.mod" -> "pkg"
        return dotted_name.split(".")[0]

    def visit_Import(self, node):
        self.packages.update(self._root(alias.name) for alias in node.names)

    def visit_ImportFrom(self, node):
        # `from . import x` has no module name and is ignored.
        if not node.module:
            return
        self.packages.add(self._root(node.module))
174 |
175 |
def get_method_source(method):
    """Return the whitespace-stripped source of ``method``; bound methods are unwrapped to their function."""
    target = method.__func__ if isinstance(method, types.MethodType) else method
    return inspect.getsource(target).strip()
181 |
182 |
def is_same_method(method1, method2):
    """
    Tell whether two methods have the same source code, ignoring decorator lines.

    Returns False when the source of either method cannot be retrieved.
    """

    def _without_decorators(source):
        # Drop every line whose first non-blank character is "@".
        kept = [line for line in source.split("\n") if not line.strip().startswith("@")]
        return "\n".join(kept)

    try:
        first = _without_decorators(get_method_source(method1))
        second = _without_decorators(get_method_source(method2))
    except (TypeError, OSError):
        return False
    return first == second
200 |
201 |
def is_same_item(item1, item2):
    """Compare two class members: callables by source code (decorators ignored), everything else with ``==``."""
    both_callable = callable(item1) and callable(item2)
    return is_same_method(item1, item2) if both_callable else item1 == item2
208 |
209 |
def _string_literal(value):
    """Render ``value`` as a Python string literal, keeping plain double/triple quotes when that is safe."""
    if "\\" not in value and "\r" not in value:
        if "\n" in value:
            # Triple quotes break if the text contains `"""` or ends with a quote character.
            if '"""' not in value and not value.endswith('"'):
                return f'"""{value}"""'
        elif '"' not in value:
            return f'"{value}"'
    # repr() always produces a valid literal (escaping quotes, backslashes, ...).
    return repr(value)


def instance_to_source(instance, base_cls=None):
    """
    Convert an instance to the source code of its class.

    Emitted, in order:
    - the class docstring, unless it equals the one of ``base_cls``;
    - non-dunder, non-callable class attributes that differ from ``base_cls``;
    - callables defined on the class whose bytecode differs from ``base_cls``.

    Packages imported by the class body are detected with ImportFinder and emitted as
    ``import <package>`` lines, in sorted order so the output is deterministic.

    Args:
        instance: Object whose class should be rendered.
        base_cls: Optional parent class. It becomes the generated class's base, is imported from
            its module, and members identical to its own are omitted.

    Returns:
        The generated source code as a single string.
    """
    cls = instance.__class__
    class_name = cls.__name__

    # Start building class lines
    class_lines = []
    if base_cls:
        class_lines.append(f"class {class_name}({base_cls.__name__}):")
    else:
        class_lines.append(f"class {class_name}:")

    # Add docstring if it exists and differs from base
    if cls.__doc__ and (not base_cls or cls.__doc__ != base_cls.__doc__):
        class_lines.append(f'    """{cls.__doc__}"""')

    # Add class-level attributes
    class_attrs = {
        name: value
        for name, value in cls.__dict__.items()
        if not name.startswith("__")
        and not callable(value)
        and not (
            base_cls and hasattr(base_cls, name) and getattr(base_cls, name) == value
        )
    }

    for name, value in class_attrs.items():
        if isinstance(value, str):
            class_lines.append(f"    {name} = {_string_literal(value)}")
        else:
            class_lines.append(f"    {name} = {repr(value)}")

    if class_attrs:
        class_lines.append("")

    def _matches_base(name, func):
        # Callables whose bytecode equals the base class's are inherited unchanged and skipped.
        # Objects without __code__ (e.g. staticmethod wrappers) are treated as different
        # instead of raising AttributeError.
        if not (base_cls and hasattr(base_cls, name)):
            return False
        base_code = getattr(getattr(base_cls, name), "__code__", None)
        own_code = getattr(func, "__code__", None)
        if base_code is None or own_code is None:
            return False
        return base_code.co_code == own_code.co_code

    # Add methods
    methods = {
        name: func
        for name, func in cls.__dict__.items()
        if callable(func) and not _matches_base(name, func)
    }

    for method in methods.values():
        method_lines = inspect.getsource(method).split("\n")
        # Dedent by the first line's indentation, then re-indent one level inside the class.
        first_line = method_lines[0]
        indent = len(first_line) - len(first_line.lstrip())
        method_lines = [line[indent:] for line in method_lines]
        class_lines.append(
            "\n".join("    " + line if line.strip() else line for line in method_lines)
        )
        class_lines.append("")

    # A class header with no body is a syntax error (ast.parse below would fail).
    if len(class_lines) == 1:
        class_lines.append("    pass")

    # Find required imports using ImportFinder
    import_finder = ImportFinder()
    import_finder.visit(ast.parse("\n".join(class_lines)))
    required_imports = import_finder.packages

    # Build final code with imports
    final_lines = []

    # Add base class import if needed
    if base_cls:
        final_lines.append(f"from {base_cls.__module__} import {base_cls.__name__}")

    # Add discovered imports (sorted: set iteration order would vary between runs)
    for package in sorted(required_imports):
        final_lines.append(f"import {package}")

    if final_lines:  # Add empty line after imports
        final_lines.append("")

    # Add the class code
    final_lines.extend(class_lines)

    return "\n".join(final_lines)
297 |
298 |
# Only AgentError is exported by `from <this module> import *`; other helpers must be imported by name.
__all__ = ["AgentError"]
300 |
--------------------------------------------------------------------------------
/docs/source/en/tutorials/tools.md:
--------------------------------------------------------------------------------
1 |
16 | # Tools
17 |
18 | [[open-in-colab]]
19 |
20 | Here, we're going to see advanced tool usage.
21 |
22 | > [!TIP]
23 | > If you're new to building agents, make sure to first read the [intro to agents](../conceptual_guides/intro_agents) and the [guided tour of prime](../guided_tour).
24 |
25 | - [Tools](#tools)
26 | - [What is a tool, and how to build one?](#what-is-a-tool-and-how-to-build-one)
27 | - [Share your tool to the Hub](#share-your-tool-to-the-hub)
28 | - [Import a Space as a tool](#import-a-space-as-a-tool)
29 | - [Use LangChain tools](#use-langchain-tools)
30 | - [Manage your agent's toolbox](#manage-your-agents-toolbox)
31 | - [Use a collection of tools](#use-a-collection-of-tools)
32 |
33 | ### What is a tool, and how to build one?
34 |
35 | A tool is mostly a function that an LLM can use in an agentic system.
36 |
37 | But to use it, the LLM will need to be given an API: name, tool description, input types and descriptions, output type.
38 |
39 | So it cannot be only a function. It should be a class.
40 |
41 | So at core, the tool is a class that wraps a function with metadata that helps the LLM understand how to use it.
42 |
43 | Here's how it looks:
44 |
45 | ```python
46 | from prime import Tool
47 |
48 | class HFModelDownloadsTool(Tool):
49 | name = "model_download_counter"
50 | description = """
51 | This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
52 | It returns the name of the checkpoint."""
53 | inputs = {
54 | "task": {
55 | "type": "string",
56 | "description": "the task category (such as text-classification, depth-estimation, etc)",
57 | }
58 | }
59 | output_type = "string"
60 |
61 | def forward(self, task: str):
62 | from huggingface_hub import list_models
63 |
64 | model = next(iter(list_models(filter=task, sort="downloads", direction=-1)))
65 | return model.id
66 |
67 | model_downloads_tool = HFModelDownloadsTool()
68 | ```
69 |
70 | The custom tool subclasses [`Tool`] to inherit useful methods. The child class also defines:
71 | - An attribute `name`, which corresponds to the name of the tool itself. The name usually describes what the tool does. Since the code returns the model with the most downloads for a task, let's name it `model_download_counter`.
72 | - An attribute `description` is used to populate the agent's system prompt.
73 | - An `inputs` attribute, which is a dictionary mapping each input name to a dictionary with keys `"type"` and `"description"`. It contains information that helps the LLM make educated choices about the input.
74 | - An `output_type` attribute, which specifies the output type. The types for both `inputs` and `output_type` should be [Pydantic formats](https://docs.pydantic.dev/latest/concepts/json_schema/#generating-json-schema), they can be any of these: [`~AUTHORIZED_TYPES`].
75 | - A `forward` method which contains the inference code to be executed.
76 |
77 | And that's all it needs to be used in an agent!
78 |
79 | There's another way to build a tool. In the [guided_tour](../guided_tour), we implemented a tool using the `@tool` decorator. The [`tool`] decorator is the recommended way to define simple tools, but sometimes you need more than this: using several methods in a class for more clarity, or using additional class attributes.
80 |
81 | In this case, you can build your tool by subclassing [`Tool`] as described above.
82 |
83 | ### Share your tool to the Hub
84 |
85 | You can share your custom tool to the Hub by calling [`~Tool.push_to_hub`] on the tool. Make sure you've created a repository for it on the Hub and are using a token with write access.
86 |
87 | ```python
88 | model_downloads_tool.push_to_hub("{your_username}/hf-model-downloads", token="")
89 | ```
90 |
91 | For the push to Hub to work, your tool will need to respect some rules:
92 | - All methods are self-contained, i.e. they only use variables that come from their arguments or from the tool's own attributes.
93 | - As per the above point, **all imports should be defined directly within the tool's functions**, else you will get an error when trying to call [`~Tool.save`] or [`~Tool.push_to_hub`] with your custom tool.
94 | - If you subclass the `__init__` method, you can give it no other argument than `self`. This is because arguments set during a specific tool instance's initialization are hard to track, which prevents sharing them properly to the Hub. And anyway, the idea of making a specific class is that you can already set class attributes for anything you need to hard-code (just set `your_variable=(...)` directly under the `class YourTool(Tool):` line). And of course you can still create an instance attribute anywhere in your methods by assigning to `self.your_variable`.
95 |
96 |
97 | Once your tool is pushed to Hub, you can visualize it. [Here](https://huggingface.co/spaces/m-ric/hf-model-downloads) is the `model_downloads_tool` that I've pushed. It has a nice gradio interface.
98 |
99 | When diving into the tool files, you can find that all the tool's logic is under [tool.py](https://huggingface.co/spaces/m-ric/hf-model-downloads/blob/main/tool.py). That is where you can inspect a tool shared by someone else.
100 |
101 | Then you can load the tool with [`load_tool`] or create it with [`~Tool.from_hub`] and pass it to the `tools` parameter in your agent.
102 | Since running tools means running custom code, you need to make sure you trust the repository, so you must pass `trust_remote_code=True` to load a tool from the Hub.
103 |
104 | ```python
105 | from prime import load_tool, CodeAgent
106 |
107 | model_download_tool = load_tool(
108 | "{your_username}/hf-model-downloads",
109 | trust_remote_code=True
110 | )
111 | ```
112 |
113 | ### Import a Space as a tool
114 |
115 | You can directly import a Space from the Hub as a tool using the [`Tool.from_space`] method!
116 |
117 | You only need to provide the id of the Space on the Hub, its name, and a description that will help your agent understand what the tool does. Under the hood, this will use the [`gradio-client`](https://pypi.org/project/gradio-client/) library to call the Space.
118 |
119 | For instance, let's import the [FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) Space from the Hub and use it to generate an image.
120 |
121 | ```python
122 | image_generation_tool = Tool.from_space(
123 | "black-forest-labs/FLUX.1-schnell",
124 | name="image_generator",
125 | description="Generate an image from a prompt"
126 | )
127 |
128 | image_generation_tool("A sunny beach")
129 | ```
130 | And voilà, here's your image! 🏖️
131 |
132 |
133 |
134 | Then you can use this tool just like any other tool. For example, let's improve the prompt `a rabbit wearing a space suit` and generate an image of it.
135 |
136 | ```python
137 | from prime import CodeAgent, HfApiModel
138 |
139 | model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
140 | agent = CodeAgent(tools=[image_generation_tool], model=model)
141 |
142 | agent.run(
143 | "Improve this prompt, then generate an image of it.", prompt='A rabbit wearing a space suit'
144 | )
145 | ```
146 |
147 | ```text
148 | === Agent thoughts:
149 | improved_prompt could be "A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background"
150 |
151 | Now that I have improved the prompt, I can use the image generator tool to generate an image based on this prompt.
152 | >>> Agent is executing the code below:
153 | image = image_generator(prompt="A bright blue space suit wearing rabbit, on the surface of the moon, under a bright orange sunset, with the Earth visible in the background")
154 | final_answer(image)
155 | ```
156 |
157 |
158 |
159 | How cool is this? 🤩
160 |
161 | ### Use LangChain tools
162 |
163 | We love LangChain and think it has a very compelling suite of tools.
164 | To import a tool from LangChain, use the `from_langchain()` method.
165 |
166 | Here is how you can use it to recreate the intro's search result using a LangChain web search tool.
167 | This tool will need `pip install langchain google-search-results -q` to work properly.
168 | ```python
169 | from langchain.agents import load_tools
170 |
171 | search_tool = Tool.from_langchain(load_tools(["serpapi"])[0])
172 |
173 | agent = CodeAgent(tools=[search_tool], model=model)
174 |
175 | agent.run("How many more blocks (also denoted as layers) are in BERT base encoder compared to the encoder from the architecture proposed in Attention is All You Need?")
176 | ```
177 |
178 | ### Manage your agent's toolbox
179 |
180 | You can manage an agent's toolbox by adding or replacing a tool.
181 |
182 | Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
183 |
184 | ```python
185 | from prime import HfApiModel
186 |
187 | model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
188 |
189 | agent = CodeAgent(tools=[], model=model, add_base_tools=True)
190 | agent.toolbox.add_tool(model_download_tool)
191 | ```
192 | Now we can leverage the new tool:
193 |
194 | ```python
195 | agent.run(
196 | "Can you give me the name of the model that has the most downloads in the 'text-to-video' task on the Hugging Face Hub but reverse the letters?"
197 | )
198 | ```
199 |
200 |
201 | > [!TIP]
202 | > Beware of adding too many tools to an agent: this can overwhelm weaker LLM engines.
203 |
204 |
205 | Use the `agent.toolbox.update_tool()` method to replace an existing tool in the agent's toolbox.
206 | This is useful if your new tool is a one-to-one replacement of the existing tool because the agent already knows how to perform that specific task.
207 | Just make sure the new tool follows the same API as the replaced tool or adapt the system prompt template to ensure all examples using the replaced tool are updated.
208 |
209 |
210 | ### Use a collection of tools
211 |
212 | You can leverage tool collections by using the [`ToolCollection`] object, with the slug of the collection you want to use.
213 | Then pass them as a list to initialize your agent, and start using them!
214 |
215 | ```py
216 | from transformers import ToolCollection, CodeAgent
217 |
218 | image_tool_collection = ToolCollection(
219 | collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f",
220 | token=""
221 | )
222 | agent = CodeAgent(tools=[*image_tool_collection.tools], model=model, add_base_tools=True)
223 |
224 | agent.run("Please draw me a picture of rivers and lakes.")
225 | ```
226 |
227 | To speed up the start, tools are loaded only if called by the agent.
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 |
16 |
17 | # Generating the documentation
18 |
19 | To generate the documentation, you have to build it. Several packages are necessary to build the doc.
20 |
21 | First, you need to install the project itself by running the following command at the root of the code repository:
22 |
23 | ```bash
24 | pip install -e .
25 | ```
26 |
27 | You also need to install 2 extra packages:
28 |
29 | ```bash
30 | # `hf-doc-builder` to build the docs
31 | pip install git+https://github.com/huggingface/doc-builder@main
32 | # `watchdog` for live reloads
33 | pip install watchdog
34 | ```
35 |
36 | ---
37 | **NOTE**
38 |
39 | You only need to generate the documentation to inspect it locally (if you're planning changes and want to
40 | check how they look before committing for instance). You don't have to commit the built documentation.
41 |
42 | ---
43 |
44 | ## Building the documentation
45 |
46 | Once you have set up the `doc-builder` and additional packages with the pip install command above,
47 | you can generate the documentation by typing the following command:
48 |
49 | ```bash
50 | doc-builder build prime docs/source/en/ --build_dir ~/tmp/test-build
51 | ```
52 |
53 | You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate
54 | the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite
55 | Markdown editor.
56 |
57 | ## Previewing the documentation
58 |
59 | To preview the docs, run the following command:
60 |
61 | ```bash
62 | doc-builder preview prime docs/source/en/
63 | ```
64 |
65 | The docs will be viewable at [http://localhost:5173](http://localhost:5173). You can also preview the docs once you
66 | have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives.
67 |
68 | ---
69 | **NOTE**
70 |
71 | The `preview` command only works with existing doc files. When you add a completely new file, you need to update
72 | `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again).
73 |
74 | ---
75 |
76 | ## Adding a new element to the navigation bar
77 |
78 | Accepted files are Markdown (.md).
79 |
80 | Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting
81 | the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/prime/blob/main/docs/source/_toctree.yml) file.
82 |
83 | ## Renaming section headers and moving sections
84 |
85 | It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much better user experience if users reading those months later could still easily navigate to the originally intended information.
86 |
87 | Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor.
88 |
89 | So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file:
90 |
91 | ```
92 | Sections that were moved:
93 |
94 | [ <a href="#section-b">Section A</a><a id="section-a"></a> ]
95 | ```
96 | and of course, if you moved it to another file, then:
97 |
98 | ```
99 | Sections that were moved:
100 |
101 | [ <a href="../new-file#section-b">Section A</a><a id="section-a"></a> ]
102 | ```
103 |
104 | Use the relative style to link to the new file so that the versioned docs continue to work.
105 |
106 | For an example of a rich moved section set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md).
107 |
108 |
109 | ## Writing Documentation - Specification
110 |
111 | The `huggingface/prime` documentation follows the
112 | [Google documentation](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) style for docstrings,
113 | although we can write them directly in Markdown.
114 |
115 | ### Adding a new tutorial
116 |
117 | Adding a new tutorial or section is done in two steps:
118 |
119 | - Add a new Markdown (.md) file under `./source`.
120 | - Link that file in `./source/_toctree.yml` on the correct toc-tree.
121 |
122 | Make sure to put your new file under the proper section. If you have a doubt, feel free to ask in a GitHub Issue or PR.
123 |
124 | ### Translating
125 |
126 | When translating, refer to the guide at [./TRANSLATING.md](https://github.com/huggingface/prime/blob/main/docs/TRANSLATING.md).
127 |
128 | ### Writing source documentation
129 |
130 | Values that should be put in `code` should be surrounded by backticks: \`like so\`. Note that argument names
131 | and objects like True, None, or any strings should usually be put in `code`.
132 |
133 | When mentioning a class, function, or method, it is recommended to use our syntax for internal links so that our tool
134 | adds a link to its documentation with this syntax: \[\`XXXClass\`\] or \[\`function\`\]. This requires the class or
135 | function to be in the main package.
136 |
137 | If you want to create a link to some internal class or function, you need to
138 | provide its path. For instance: \[\`utils.ModelOutput\`\]. This will be converted into a link with
139 | `utils.ModelOutput` in the description. To get rid of the path and only keep the name of the object you are
140 | linking to in the description, add a ~: \[\`~utils.ModelOutput\`\] will generate a link with `ModelOutput` in the description.
141 |
142 | The same works for methods so you can either use \[\`XXXClass.method\`\] or \[\`~XXXClass.method\`\].
143 |
144 | #### Defining arguments in a method
145 |
146 | Arguments should be defined with the `Args:` (or `Arguments:` or `Parameters:`) prefix, followed by a line return and
147 | an indentation. The argument should be followed by its type, with its shape if it is a tensor, a colon, and its
148 | description:
149 |
150 | ```
151 | Args:
152 | n_layers (`int`): The number of layers of the model.
153 | ```
154 |
155 | If the description is too long to fit in one line, another indentation is necessary before writing the description
156 | after the argument.
157 |
158 | Here's an example showcasing everything so far:
159 |
160 | ```
161 | Args:
162 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
163 | Indices of input sequence tokens in the vocabulary.
164 |
165 | Indices can be obtained using [`AlbertTokenizer`]. See [`~PreTrainedTokenizer.encode`] and
166 | [`~PreTrainedTokenizer.__call__`] for details.
167 |
168 | [What are input IDs?](../glossary#input-ids)
169 | ```
170 |
171 | For optional arguments or arguments with defaults we follow the following syntax: imagine we have a function with the
172 | following signature:
173 |
174 | ```
175 | def my_function(x: str = None, a: float = 1):
176 | ```
177 |
178 | then its documentation should look like this:
179 |
180 | ```
181 | Args:
182 | x (`str`, *optional*):
183 | This argument controls ...
184 | a (`float`, *optional*, defaults to 1):
185 | This argument is used to ...
186 | ```
187 |
188 | Note that we always omit the "defaults to \`None\`" when None is the default for any argument. Also note that even
189 | if the first line describing your argument type and its default gets long, you can't break it on several lines. You can
190 | however write as many lines as you want in the indented description (see the example above with `input_ids`).
191 |
192 | #### Writing a multi-line code block
193 |
194 | Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown:
195 |
196 |
197 | ````
198 | ```
199 | # first line of code
200 | # second line
201 | # etc
202 | ```
203 | ````
204 |
205 | #### Writing a return block
206 |
207 | The return block should be introduced with the `Returns:` prefix, followed by a line return and an indentation.
208 | The first line should be the type of the return, followed by a line return. No need to indent further for the elements
209 | building the return.
210 |
211 | Here's an example of a single value return:
212 |
213 | ```
214 | Returns:
215 | `List[int]`: A list of integers in the range [0, 1] --- 1 for a special token, 0 for a sequence token.
216 | ```
217 |
218 | Here's an example of a tuple return, comprising several objects:
219 |
220 | ```
221 | Returns:
222 | `tuple(torch.FloatTensor)` comprising various elements depending on the configuration ([`BertConfig`]) and inputs:
223 | - **loss** (*optional*, returned when `masked_lm_labels` is provided) `torch.FloatTensor` of shape `(1,)` --
224 | Total loss is the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
225 | - **prediction_scores** (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`) --
226 | Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
227 | ```
228 |
229 | #### Adding an image
230 |
231 | Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
232 | the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference
233 | them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images).
234 | If you are an external contributor, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
235 | to this dataset.
236 |
237 | #### Writing documentation examples
238 |
239 | The syntax for Example docstrings can look as follows:
240 |
241 | ```
242 | Example:
243 |
244 | ```python
245 | >>> from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
246 | >>> from datasets import load_dataset
247 | >>> import torch
248 |
249 | >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
250 | >>> dataset = dataset.sort("id")
251 | >>> sampling_rate = dataset.features["audio"].sampling_rate
252 |
253 | >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
254 | >>> model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
255 |
256 | >>> # audio file is decoded on the fly
257 | >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
258 | >>> with torch.no_grad():
259 | ... logits = model(**inputs).logits
260 | >>> predicted_ids = torch.argmax(logits, dim=-1)
261 |
262 | >>> # transcribe speech
263 | >>> transcription = processor.batch_decode(predicted_ids)
264 | >>> transcription[0]
265 | 'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'
266 | ```
267 | ```
268 |
269 | The docstring should give a minimal, clear example of how the respective model
270 | is to be used in inference and also include the expected (ideally sensible)
271 | output.
272 | Often, readers will try out the example before even going through the function
273 | or class definitions. Therefore, it is of utmost importance that the example
274 | works as expected.
--------------------------------------------------------------------------------
/tests/test_agents.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | import tempfile
17 | import unittest
18 | import uuid
19 | import pytest
20 |
21 | from pathlib import Path
22 |
23 | from prime.types import AgentText, AgentImage
24 | from prime.agents import (
25 | AgentMaxIterationsError,
26 | ManagedAgent,
27 | CodeAgent,
28 | ToolCallingAgent,
29 | Toolbox,
30 | ToolCall,
31 | )
32 | from prime.tools import tool
33 | from prime.default_tools import PythonInterpreterTool
34 | from transformers.testing_utils import get_tests_dir
35 |
36 |
def get_new_path(suffix="") -> str:
    """Build a unique file path inside a freshly created temporary directory.

    Only the directory is created; the returned file path itself does not exist yet.

    Args:
        suffix: String appended to the random (UUID4-based) file name, e.g. ".png".

    Returns:
        The absolute path, as a string.
    """
    fresh_dir = tempfile.mkdtemp()
    file_name = str(uuid.uuid4()) + suffix
    return os.path.join(fresh_dir, file_name)
40 |
41 |
class FakeToolCallModel:
    """Scripted tool-calling model: one python_interpreter call, then a final answer.

    While the conversation has fewer than 3 messages it requests the
    ``python_interpreter`` tool; from then on it requests ``final_answer``.
    """

    def get_tool_call(
        self, messages, available_tools, stop_sequences=None, grammar=None
    ):
        # Returns a (tool_name, arguments, call_id) triple.
        if len(messages) >= 3:
            return "final_answer", {"answer": "7.2904"}, "call_1"
        return "python_interpreter", {"code": "2*3.6452"}, "call_0"
50 |
51 |
class FakeToolCallModelImage:
    """Scripted tool-calling model that first generates an image, then returns it.

    With fewer than 3 messages it calls ``fake_image_generation_tool``; afterwards
    it calls ``final_answer`` with the bare string ``"image.png"`` as arguments.
    """

    def get_tool_call(
        self, messages, available_tools, stop_sequences=None, grammar=None
    ):
        # Returns a (tool_name, arguments, call_id) triple.
        if len(messages) >= 3:  # second step: hand back the generated image
            return "final_answer", "image.png", "call_1"
        image_request = {"prompt": "An image of a cat"}
        return "fake_image_generation_tool", image_request, "call_0"
65 |
66 |
def fake_code_model(messages, stop_sequences=None, grammar=None) -> str:
    """Scripted two-step code model used to drive ``CodeAgent`` in tests.

    Args:
        messages: Conversation so far; only its ``str()`` form is inspected.
        stop_sequences: Ignored; accepted for interface compatibility.
        grammar: Ignored; accepted for interface compatibility.

    Returns:
        On the first step (no ``special_marker`` anywhere in ``messages``), a
        response tagged with ``special_marker`` whose code multiplies 2 by
        3.6452. On any later step, a response calling ``final_answer(7.2904)``.
    """
    prompt = str(messages)
    if "special_marker" not in prompt:
        # Use `*`, not `**`: the code must do what the Thought announces.
        return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = 2*3.6452
```
"""
    else:  # We're at step 2
        return """
Thought: I can now answer the initial question
Code:
```py
final_answer(7.2904)
```
"""
85 |
86 |
def fake_code_model_error(messages, stop_sequences=None, grammar=None) -> str:
    """Scripted code model whose first step fails at runtime.

    The first step shadows ``print`` with an int and then calls it, which raises
    during execution. The second step recovers with a final answer.

    Args:
        messages: Conversation so far; only its ``str()`` form is inspected.
        stop_sequences: Ignored; accepted for interface compatibility.
        grammar: Ignored. Accepted (like the other fake models) so the agent can
            forward it without a TypeError.

    Returns:
        The canned response text for the current step.
    """
    prompt = str(messages)
    if "special_marker" not in prompt:
        return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
a = 2
b = a * 2
print = 2
print("Ok, calculation done!")
```
"""
    else:  # We're at step 2
        return """
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```
"""
108 |
109 |
def fake_code_model_syntax_error(messages, stop_sequences=None, grammar=None) -> str:
    """Scripted code model whose first step contains an indentation error.

    The indented ``print`` line makes the first code snippet unparsable. The
    second step recovers with a final answer.

    Args:
        messages: Conversation so far; only its ``str()`` form is inspected.
        stop_sequences: Ignored; accepted for interface compatibility.
        grammar: Ignored. Accepted (like the other fake models) so the agent can
            forward it without a TypeError.

    Returns:
        The canned response text for the current step.
    """
    prompt = str(messages)
    if "special_marker" not in prompt:
        # The unexpected indent below is deliberate: it is the error under test.
        return """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
a = 2
b = a * 2
    print("Failing due to unexpected indent")
print("Ok, calculation done!")
```
"""
    else:  # We're at step 2
        return """
Thought: I can now answer the initial question
Code:
```py
final_answer("got an error")
```
"""
131 |
132 |
def fake_code_functiondef(messages, stop_sequences=None, grammar=None) -> str:
    """Scripted code model checking that functions persist across agent steps.

    Step 1 defines ``moving_average`` (using numpy). Step 2 calls it without
    redefining it and returns the result via ``final_answer``.

    Args:
        messages: Conversation so far; only its ``str()`` form is inspected.
        stop_sequences: Ignored; accepted for interface compatibility.
        grammar: Ignored. Accepted (like the other fake models) so the agent can
            forward it without a TypeError.

    Returns:
        The canned response text for the current step.
    """
    prompt = str(messages)
    if "special_marker" not in prompt:
        return """
Thought: Let's define the function. special_marker
Code:
```py
import numpy as np

def moving_average(x, w):
    return np.convolve(x, np.ones(w), 'valid') / w
```
"""
    else:  # We're at step 2
        return """
Thought: I can now answer the initial question
Code:
```py
x, w = [0, 1, 2, 3, 4, 5], 2
res = moving_average(x, w)
final_answer(res)
```
"""
156 |
157 |
def fake_code_model_single_step(messages, stop_sequences=None, grammar=None) -> str:
    """Return one canned response that computes via the interpreter and finishes.

    The conversation is ignored, so every call yields the same text.
    """
    canned_response = """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
final_answer(result)
```
"""
    return canned_response
167 |
168 |
def fake_code_model_no_return(messages, stop_sequences=None, grammar=None) -> str:
    """Return one canned response that never calls ``final_answer``.

    The conversation is ignored, so an agent driven by this model keeps
    stepping until it hits its iteration limit.
    """
    canned_response = """
Thought: I should multiply 2 by 3.6452. special_marker
Code:
```py
result = python_interpreter(code="2*3.6452")
print(result)
```
"""
    return canned_response
178 |
179 |
180 | class AgentTests(unittest.TestCase):
181 | def test_fake_single_step_code_agent(self):
182 | agent = CodeAgent(
183 | tools=[PythonInterpreterTool()], model=fake_code_model_single_step
184 | )
185 | output = agent.run("What is 2 multiplied by 3.6452?", single_step=True)
186 | assert isinstance(output, str)
187 | assert "7.2904" in output
188 |
189 | def test_fake_toolcalling_agent(self):
190 | agent = ToolCallingAgent(
191 | tools=[PythonInterpreterTool()], model=FakeToolCallModel()
192 | )
193 | output = agent.run("What is 2 multiplied by 3.6452?")
194 | assert isinstance(output, str)
195 | assert "7.2904" in output
196 | assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
197 | assert "7.2904" in agent.logs[2].observations
198 | assert agent.logs[3].llm_output is None
199 |
200 | def test_toolcalling_agent_handles_image_tool_outputs(self):
201 | from PIL import Image
202 |
203 | @tool
204 | def fake_image_generation_tool(prompt: str) -> Image.Image:
205 | """Tool that generates an image.
206 |
207 | Args:
208 | prompt: The prompt
209 | """
210 | return Image.open(Path(get_tests_dir("fixtures")) / "000000039769.png")
211 |
212 | agent = ToolCallingAgent(
213 | tools=[fake_image_generation_tool], model=FakeToolCallModelImage()
214 | )
215 | output = agent.run("Make me an image.")
216 | assert isinstance(output, AgentImage)
217 | assert isinstance(agent.state["image.png"], Image.Image)
218 |
219 | def test_fake_code_agent(self):
220 | agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model)
221 | output = agent.run("What is 2 multiplied by 3.6452?")
222 | assert isinstance(output, float)
223 | assert output == 7.2904
224 | assert agent.logs[1].task == "What is 2 multiplied by 3.6452?"
225 | assert agent.logs[3].tool_call == ToolCall(
226 | name="python_interpreter", arguments="final_answer(7.2904)", id="call_3"
227 | )
228 |
229 | def test_additional_args_added_to_task(self):
230 | agent = CodeAgent(tools=[], model=fake_code_model)
231 | agent.run(
232 | "What is 2 multiplied by 3.6452?",
233 | additional_args={"instruction": "Remember this."},
234 | )
235 | assert "Remember this" in agent.task
236 | assert "Remember this" in str(agent.input_messages)
237 |
238 | def test_reset_conversations(self):
239 | agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model)
240 | output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
241 | assert output == 7.2904
242 | assert len(agent.logs) == 4
243 |
244 | output = agent.run("What is 2 multiplied by 3.6452?", reset=False)
245 | assert output == 7.2904
246 | assert len(agent.logs) == 6
247 |
248 | output = agent.run("What is 2 multiplied by 3.6452?", reset=True)
249 | assert output == 7.2904
250 | assert len(agent.logs) == 4
251 |
252 | def test_code_agent_code_errors_show_offending_lines(self):
253 | agent = CodeAgent(tools=[PythonInterpreterTool()], model=fake_code_model_error)
254 | output = agent.run("What is 2 multiplied by 3.6452?")
255 | assert isinstance(output, AgentText)
256 | assert output == "got an error"
257 | assert "Code execution failed at line 'print = 2' because of" in str(agent.logs)
258 |
259 | def test_code_agent_syntax_error_show_offending_lines(self):
260 | agent = CodeAgent(
261 | tools=[PythonInterpreterTool()], model=fake_code_model_syntax_error
262 | )
263 | output = agent.run("What is 2 multiplied by 3.6452?")
264 | assert isinstance(output, AgentText)
265 | assert output == "got an error"
266 | assert ' print("Failing due to unexpected indent")' in str(agent.logs)
267 |
268 | def test_setup_agent_with_empty_toolbox(self):
269 | ToolCallingAgent(model=FakeToolCallModel(), tools=[])
270 |
271 | def test_fails_max_iterations(self):
272 | agent = CodeAgent(
273 | tools=[PythonInterpreterTool()],
274 | model=fake_code_model_no_return, # use this callable because it never ends
275 | max_iterations=5,
276 | )
277 | agent.run("What is 2 multiplied by 3.6452?")
278 | assert len(agent.logs) == 8
279 | assert type(agent.logs[-1].error) is AgentMaxIterationsError
280 |
def test_init_agent_with_different_toolsets(self):
    # No user tools: only the default final_answer tool is registered.
    agent = CodeAgent(tools=[], model=fake_code_model)
    assert len(agent.toolbox.tools) == 1

    # Duplicate tools are deduplicated: one python_interpreter + final_answer.
    duplicated_tools = [PythonInterpreterTool(), PythonInterpreterTool()]
    agent = CodeAgent(tools=duplicated_tools, model=fake_code_model)
    assert len(agent.toolbox.tools) == 2

    # A Toolbox built from the same list behaves like the list itself.
    prebuilt_toolbox = Toolbox(duplicated_tools)
    agent = CodeAgent(tools=prebuilt_toolbox, model=fake_code_model)
    assert len(agent.toolbox.tools) == 2

    # add_base_tools must not silently overwrite a tool already in the toolbox.
    with pytest.raises(KeyError) as e:
        ToolCallingAgent(
            tools=prebuilt_toolbox, model=FakeToolCallModel(), add_base_tools=True
        )
    assert "already exists in the toolbox" in str(e)

    # Code agents do not get the python_interpreter base tool:
    # final_answer + search + transcribe.
    agent = CodeAgent(tools=[], model=fake_code_model, add_base_tools=True)
    assert len(agent.toolbox.tools) == 3
312 |
def test_function_persistence_across_steps(self):
    # Presumably fake_code_functiondef defines a function in one step and calls it
    # in a later one; the run result must reflect that call.
    agent = CodeAgent(
        model=fake_code_functiondef,
        tools=[],
        additional_authorized_imports=["numpy"],
        max_iterations=2,
    )
    outcome = agent.run("ok")
    assert outcome[0] == 0.5
322 |
def test_init_managed_agent(self):
    inner_agent = CodeAgent(tools=[], model=fake_code_functiondef)
    wrapper = ManagedAgent(inner_agent, name="managed_agent", description="Empty")
    # The wrapper exposes the name and description it was given, unchanged.
    assert (wrapper.name, wrapper.description) == ("managed_agent", "Empty")
328 |
def test_agent_description_gets_correctly_inserted_in_system_prompt(self):
    """The team-member section only appears in a manager agent's system prompt,
    and the managed-agents placeholder never leaks into a rendered prompt."""
    agent = CodeAgent(tools=[], model=fake_code_functiondef)
    managed_agent = ManagedAgent(agent, name="managed_agent", description="Empty")
    manager_agent = CodeAgent(
        tools=[],
        model=fake_code_functiondef,
        managed_agents=[managed_agent],
    )
    # Plain agent: no team-member section and no unsubstituted placeholder.
    assert "You can also give requests to team members." not in agent.system_prompt
    assert "{{managed_agents_descriptions}}" not in agent.system_prompt
    # Manager agent: team-member section present and placeholder substituted.
    assert (
        "You can also give requests to team members." in manager_agent.system_prompt
    )
    assert "{{managed_agents_descriptions}}" not in manager_agent.system_prompt
343 |
--------------------------------------------------------------------------------
/src/prime/default_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 |
4 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
5 | #
6 | # Licensed under the Apache License, Version 2.0 (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at
9 | #
10 | # http://www.apache.org/licenses/LICENSE-2.0
11 | #
12 | # Unless required by applicable law or agreed to in writing, software
13 | # distributed under the License is distributed on an "AS IS" BASIS,
14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 | # See the License for the specific language governing permissions and
16 | # limitations under the License.
17 | import json
18 | import re
19 | from dataclasses import dataclass
20 | from typing import Dict, Optional
21 | from huggingface_hub import hf_hub_download, list_spaces
22 |
23 | from transformers.utils import is_offline_mode
24 | from transformers.models.whisper import (
25 | WhisperProcessor,
26 | WhisperForConditionalGeneration,
27 | )
28 |
29 | from .local_python_executor import (
30 | BASE_BUILTIN_MODULES,
31 | BASE_PYTHON_TOOLS,
32 | evaluate_python_code,
33 | )
34 | from .tools import TOOL_CONFIG_FILE, Tool, PipelineTool
35 | from .types import AgentAudio
36 |
37 |
@dataclass
class PreTool:
    """Lightweight record describing a remote tool found in a Hub Space.

    Instances are built by ``get_remote_tools`` from the Space's tool config file.
    """

    name: str
    # NOTE(review): the tool classes in this module declare inputs as
    # {name: {"type": ..., "description": ...}}; confirm whether remote configs
    # follow that nested shape, in which case this annotation is too narrow.
    inputs: Dict[str, str]
    # Type *name* such as "string" or "any", read from JSON config, never a Python type.
    output_type: str
    task: str
    description: str
    repo_id: str
46 |
47 |
def get_remote_tools(logger, organization="huggingface-tools"):
    """List the tools published as Spaces by ``organization`` on the Hub.

    Args:
        logger: logger used to report that remote tools are unavailable offline.
        organization: Hub author whose Spaces are scanned.

    Returns:
        A dict mapping each Space config's ``"name"`` to a ``PreTool``; an empty
        dict when running in offline mode.
    """
    if is_offline_mode():
        logger.info("You are in offline mode, so remote tools are not available.")
        return {}

    remote_tools = {}
    for space in list_spaces(author=organization):
        space_repo = space.id
        config_path = hf_hub_download(space_repo, TOOL_CONFIG_FILE, repo_type="space")
        with open(config_path, encoding="utf-8") as config_file:
            tool_config = json.load(config_file)
        # The task is the Space's repo name (part after "org/").
        space_task = space_repo.split("/")[-1]
        # NOTE(review): keyed by config["name"] while PreTool.name is the repo name;
        # confirm this asymmetry is intended.
        remote_tools[tool_config["name"]] = PreTool(
            task=space_task,
            description=tool_config["description"],
            repo_id=space_repo,
            name=space_task,
            inputs=tool_config["inputs"],
            output_type=tool_config["output_type"],
        )
    return remote_tools
73 |
74 |
class PythonInterpreterTool(Tool):
    """Tool that evaluates a Python snippet with the local restricted interpreter.

    The snippet may only import modules listed in ``self.authorized_imports``:
    ``BASE_BUILTIN_MODULES`` plus any extra ``authorized_imports`` given at init.
    """

    name = "python_interpreter"
    description = "This is a tool that evaluates python code. It can be used to perform calculations."
    inputs = {
        "code": {
            "type": "string",
            "description": "The python code to run in interpreter",
        }
    }
    output_type = "string"

    def __init__(self, *args, authorized_imports=None, **kwargs):
        """
        Args:
            authorized_imports: optional iterable of extra module names the snippet
                may import; merged with ``BASE_BUILTIN_MODULES``.
            *args, **kwargs: forwarded to ``Tool.__init__``.
        """
        extra_imports = set(authorized_imports) if authorized_imports is not None else set()
        # Sorted so the list, and the prompt text built from it, is deterministic
        # instead of depending on set iteration order.
        self.authorized_imports = sorted(set(BASE_BUILTIN_MODULES) | extra_imports)
        # Advertise the effective import list (base + extra). The raw argument is
        # None by default and would omit the base modules.
        self.inputs = {
            "code": {
                "type": "string",
                "description": (
                    "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
                    f"else you will get an error. This code can only import the following python libraries: {self.authorized_imports}."
                ),
            }
        }
        self.base_python_tools = BASE_PYTHON_TOOLS
        self.python_evaluator = evaluate_python_code
        super().__init__(*args, **kwargs)

    def forward(self, code: str) -> str:
        """Evaluate ``code`` in a fresh state and report its prints and final value.

        Any exception raised during evaluation is returned as an ``"Error: ..."``
        string instead of being propagated.
        """
        state = {}
        try:
            output = str(
                self.python_evaluator(
                    code,
                    state=state,
                    static_tools=self.base_python_tools,
                    authorized_imports=self.authorized_imports,
                )
            )
            return f"Stdout:\n{state['print_outputs']}\nOutput: {output}"
        except Exception as e:
            return f"Error: {str(e)}"
120 |
121 |
class FinalAnswerTool(Tool):
    """Tool the agent calls to deliver its final answer; it echoes the answer back."""

    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {
        "answer": {
            "type": "any",
            "description": "The final answer to the problem",
        },
    }
    output_type = "any"

    def forward(self, answer):
        # Identity: the answer is returned exactly as received.
        final = answer
        return final
132 |
133 |
class UserInputTool(Tool):
    """Ask the human a question via ``input()`` and return the typed reply."""

    name = "user_input"
    description = "Asks for user's input on a specific question"
    inputs = {
        "question": {
            "type": "string",
            "description": "The question to ask the user",
        },
    }
    output_type = "string"

    def forward(self, question):
        prompt = f"{question} => "
        return input(prompt)
145 |
146 |
class DuckDuckGoSearchTool(Tool):
    """Web search backed by the ``duckduckgo_search`` package (imported at init)."""

    name = "web_search"
    # Describe what forward() actually returns: a markdown string, not a list of dicts.
    description = """Performs a duckduckgo web search based on your query (think a Google search) then returns the top search results as a markdown string.
Each result is formatted as a markdown link '[title](href)' followed by the result's body text."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "any"

    def __init__(self, **kwargs):
        """
        Raises:
            ImportError: if ``duckduckgo_search`` is not installed.
        """
        # ``self`` is already bound by super(); only forward the keyword arguments.
        super().__init__(**kwargs)
        try:
            from duckduckgo_search import DDGS
        except ImportError as e:
            raise ImportError(
                "You must install package `duckduckgo_search` to run this tool: for instance run `pip install duckduckgo-search`."
            ) from e
        self.ddgs = DDGS()

    def forward(self, query: str) -> str:
        """Return up to 10 results for ``query`` formatted as a markdown string."""
        results = self.ddgs.text(query, max_results=10)
        postprocessed_results = [
            f"[{result['title']}]({result['href']})\n{result['body']}"
            for result in results
        ]
        return "## Search Results\n\n" + "\n\n".join(postprocessed_results)
173 |
174 |
class GoogleSearchTool(Tool):
    """Google web search through SerpAPI.

    The API key is read from the ``SERPAPI_API_KEY`` environment variable at init;
    ``forward`` raises ``ValueError`` when it is missing.
    """

    name = "web_search"
    description = """Performs a google web search for your query then returns a string of the top search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."},
        "filter_year": {
            "type": "integer",
            "description": "Optionally restrict results to a certain year",
            "nullable": True,
        },
    }
    output_type = "string"

    def __init__(self):
        # ``self`` is already bound by super(); do not pass it again positionally.
        super().__init__()
        import os

        self.serpapi_key = os.getenv("SERPAPI_API_KEY")

    def forward(self, query: str, filter_year: Optional[int] = None) -> str:
        """Run the search and return numbered markdown snippets.

        Args:
            query: search query.
            filter_year: when given, restrict results to that calendar year.

        Raises:
            ValueError: missing API key, or a non-200 SerpAPI response.
            Exception: the response has no ``organic_results`` key.
        """
        import requests

        if self.serpapi_key is None:
            raise ValueError(
                "Missing SerpAPI key. Make sure you have 'SERPAPI_API_KEY' in your env variables."
            )

        params = {
            "engine": "google",
            "q": query,
            "api_key": self.serpapi_key,
            "google_domain": "google.com",
        }
        if filter_year is not None:
            # Custom date range covering Jan 1 - Dec 31 of the requested year.
            params["tbs"] = (
                f"cdr:1,cd_min:01/01/{filter_year},cd_max:12/31/{filter_year}"
            )

        response = requests.get("https://serpapi.com/search.json", params=params)
        if response.status_code != 200:
            raise ValueError(response.json())
        results = response.json()

        if "organic_results" not in results:
            if filter_year is not None:
                raise Exception(
                    f"'organic_results' key not found for query: '{query}' with filtering on year={filter_year}. Use a less restrictive query or do not filter on year."
                )
            raise Exception(
                f"'organic_results' key not found for query: '{query}'. Use a less restrictive query."
            )

        organic_results = results["organic_results"]
        if len(organic_results) == 0:
            year_filter_message = (
                f" with filter year={filter_year}" if filter_year is not None else ""
            )
            return f"No results found for '{query}'{year_filter_message}. Try with a more general query, or remove the year filter."

        # Presence of "organic_results" is guaranteed by the check above.
        web_snippets = []
        for idx, page in enumerate(organic_results):
            date_published = ""
            if "date" in page:
                date_published = "\nDate published: " + page["date"]

            source = ""
            if "source" in page:
                source = "\nSource: " + page["source"]

            snippet = ""
            if "snippet" in page:
                snippet = "\n" + page["snippet"]

            redacted_version = f"{idx}. [{page['title']}]({page['link']}){date_published}{source}\n{snippet}"

            redacted_version = redacted_version.replace(
                "Your browser can't play this video.", ""
            )
            web_snippets.append(redacted_version)

        return "## Search Results\n" + "\n\n".join(web_snippets)
258 |
259 |
class VisitWebpageTool(Tool):
    """Fetch a URL and return its content converted to markdown."""

    name = "visit_webpage"
    description = "Visits a webpage at the given url and reads its content as a markdown string. Use this to browse webpages."
    inputs = {
        "url": {
            "type": "string",
            "description": "The url of the webpage to visit.",
        }
    }
    output_type = "string"

    def forward(self, url: str) -> str:
        """Download ``url`` and return its HTML as markdown.

        Request failures (including timeouts) and any other exception are returned as
        error strings rather than raised.

        Raises:
            ImportError: if ``markdownify`` or ``requests`` is not installed.
        """
        try:
            from markdownify import markdownify
            import requests
            from requests.exceptions import RequestException
        except ImportError as e:
            raise ImportError(
                "You must install packages `markdownify` and `requests` to run this tool: for instance run `pip install markdownify requests`."
            ) from e
        try:
            # Bound the request so an unresponsive server cannot hang the agent;
            # a timeout surfaces as a RequestException handled below.
            response = requests.get(url, timeout=20)
            response.raise_for_status()  # Raise an exception for bad status codes

            # Convert the HTML content to Markdown
            markdown_content = markdownify(response.text).strip()

            # Collapse runs of 3+ newlines into a single blank line
            markdown_content = re.sub(r"\n{3,}", "\n\n", markdown_content)

            return markdown_content

        except RequestException as e:
            return f"Error fetching the webpage: {str(e)}"
        except Exception as e:
            return f"An unexpected error occurred: {str(e)}"
297 |
298 |
class SpeechToTextTool(PipelineTool):
    """Transcribe audio to text with a Whisper checkpoint (encode -> generate -> decode)."""

    default_checkpoint = "openai/whisper-large-v3-turbo"
    description = "This is a tool that transcribes an audio into text. It returns the transcribed text."
    name = "transcriber"
    pre_processor_class = WhisperProcessor
    model_class = WhisperForConditionalGeneration

    inputs = {
        "audio": {
            "type": "audio",
            "description": "The audio to transcribe. Can be a local path, an url, or a tensor.",
        }
    }
    output_type = "string"

    def encode(self, audio):
        # Wrap whatever form was passed in AgentAudio, then extract raw samples.
        raw_samples = AgentAudio(audio).to_raw()
        return self.pre_processor(raw_samples, return_tensors="pt")

    def forward(self, inputs):
        features = inputs["input_features"]
        return self.model.generate(features)

    def decode(self, outputs):
        # Only the first sequence of the batch is returned.
        transcripts = self.pre_processor.batch_decode(outputs, skip_special_tokens=True)
        return transcripts[0]
323 |
324 |
# Public tool classes exported by this module.
__all__ = [
    "PythonInterpreterTool",
    "FinalAnswerTool",
    "UserInputTool",
    "DuckDuckGoSearchTool",
    "GoogleSearchTool",
    "VisitWebpageTool",
    "SpeechToTextTool",
]
334 |
--------------------------------------------------------------------------------
/docs/source/en/tutorials/building_good_agents.md:
--------------------------------------------------------------------------------
1 |
16 | # Building good agents
17 |
18 | [[open-in-colab]]
19 |
20 | There's a world of difference between building an agent that works and one that doesn't.
21 | How do you make sure your agent falls into the former category?
22 | In this guide, we're going to see best practices for building agents.
23 |
24 | > [!TIP]
25 | > If you're new to building agents, make sure to first read the [intro to agents](../conceptual_guides/intro_agents) and the [guided tour of prime](../guided_tour).
26 |
27 | ### The best agentic systems are the simplest: simplify the workflow as much as you can
28 |
29 | Giving an LLM some agency in your workflow introduces some risk of errors.
30 |
31 | Well-programmed agentic systems have good error logging and retry mechanisms anyway, so the LLM engine has a chance to self-correct its mistakes. But to reduce the risk of LLM errors as much as possible, you should simplify your workflow!
32 |
33 | Let's revisit the example from the [intro to agents](../conceptual_guides/intro_agents): a bot that answers user queries for a surf trip company.
34 | Instead of letting the agent do 2 different calls for "travel distance API" and "weather API" each time they are asked about a new surf spot, you could just make one unified tool "return_spot_information", a function that calls both APIs at once and returns their concatenated outputs to the user.
35 |
36 | This will reduce costs, latency, and error risk!
37 |
38 | The main guideline is: Reduce the number of LLM calls as much as you can.
39 |
40 | This leads to a few takeaways:
41 | - Whenever possible, group 2 tools in one, like in our example of the two APIs.
42 | - Whenever possible, logic should be based on deterministic functions rather than agentic decisions.
43 |
44 | ### Improve the information flow to the LLM engine
45 |
46 | Remember that your LLM engine is like an ~intelligent~ robot, trapped in a room with the only communication with the outside world being notes passed under a door.
47 |
48 | It won't know of anything that happened if you don't explicitly put that into its prompt.
49 |
50 | So first start with making your task very clear!
51 | Since an agent is powered by an LLM, minor variations in your task formulation might yield completely different results.
52 |
53 | Then, improve the information flow towards your agent in tool use.
54 |
55 | Particular guidelines to follow:
56 | - Each tool should log (by simply using `print` statements inside the tool's `forward` method) everything that could be useful for the LLM engine.
57 | - In particular, logging detail on tool execution errors would help a lot!
58 |
59 | For instance, here's a tool that retrieves weather data for a given location and date-time.
60 |
61 | First, here's a poor version:
62 | ```python
63 | from datetime import datetime
64 | from prime import tool
65 |
66 | def get_weather_report_at_coordinates(coordinates, date_time):
67 | # Dummy function, returns a list of [temperature in °C, risk of rain on a scale 0-1, wave height in m]
68 | return [28.0, 0.35, 0.85]
69 |
70 | def convert_location_to_coordinates(location):
71 | # Returns dummy coordinates
72 | return [3.3, -42.0]
73 |
74 | @tool
75 | def get_weather_api(location: str, date_time: str) -> str:
76 | """
77 | Returns the weather report.
78 |
79 | Args:
80 | location: the name of the place that you want the weather for.
81 | date_time: the date and time for which you want the report.
82 | """
83 | lon, lat = convert_location_to_coordinates(location)
84 | date_time = datetime.strptime(date_time)
85 | return str(get_weather_report_at_coordinates((lon, lat), date_time))
86 | ```
87 |
88 | Why is it bad?
89 | - the format to be used for `date_time` is not specified
90 | - there's no detail on how `location` should be specified
91 | - there's no logging mechanism making failure cases explicit, like `location` not being in a proper format, or `date_time` not being properly formatted.
92 | - the output format is hard to understand
93 |
94 | If the tool call fails, the error trace logged in memory can help the LLM reverse engineer the tool to fix the errors. But why leave it so much heavy lifting?
95 |
96 | A better way to build this tool would have been the following:
97 | ```python
98 | @tool
99 | def get_weather_api(location: str, date_time: str) -> str:
100 | """
101 | Returns the weather report.
102 |
103 | Args:
104 | location: the name of the place that you want the weather for. Should be a place name, followed by possibly a city name, then a country, like "Anchor Point, Taghazout, Morocco".
105 | date_time: the date and time for which you want the report, formatted as '%m/%d/%y %H:%M:%S'.
106 | """
107 | lon, lat = convert_location_to_coordinates(location)
108 | try:
109 | date_time = datetime.strptime(date_time, '%m/%d/%y %H:%M:%S')
110 | except Exception as e:
111 | raise ValueError("Conversion of `date_time` to datetime format failed, make sure to provide a string in format '%m/%d/%y %H:%M:%S'. Full trace:" + str(e))
112 | temperature_celsius, risk_of_rain, wave_height = get_weather_report_at_coordinates((lon, lat), date_time)
113 | return f"Weather report for {location}, {date_time}: Temperature will be {temperature_celsius}°C, risk of rain is {risk_of_rain*100:.0f}%, wave height is {wave_height}m."
114 | ```
115 |
116 | In general, to ease the load on your LLM, the good question to ask yourself is: "How easy would it be for me, if I was dumb and using this tool for the first time ever, to program with this tool and correct my own errors?".
117 |
118 | ### Give more arguments to the agent
119 |
120 | To pass additional objects to your agent beyond the simple string describing the task to run, you can use the `additional_args` argument to pass any type of object:
121 |
122 | ```py
123 | from prime import CodeAgent, HfApiModel
124 |
125 | model_id = "meta-llama/Llama-3.3-70B-Instruct"
126 |
127 | agent = CodeAgent(tools=[], model=HfApiModel(model_id=model_id), add_base_tools=True)
128 |
129 | agent.run(
130 | "Why does Mike not know many people in New York?",
131 | additional_args={"mp3_sound_file_url":'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/recording.mp3'}
132 | )
133 | ```
134 | For instance, you can use this `additional_args` argument to pass images or strings that you want your agent to leverage.
135 |
136 |
137 |
138 | ## How to debug your agent
139 |
140 | ### 1. Use a stronger LLM
141 |
142 | In agentic workflows, some errors are actual errors, while others are the fault of your LLM engine not reasoning properly.
143 | For instance, consider this trace for a `CodeAgent` that I asked to make me a car picture:
144 | ```
145 | ==================================================================================================== New task ====================================================================================================
146 | Make me a cool car picture
147 | ──────────────────────────────────────────────────────────────────────────────────────────────────── New step ────────────────────────────────────────────────────────────────────────────────────────────────────
148 | Agent is executing the code below: ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
149 | image_generator(prompt="A cool, futuristic sports car with LED headlights, aerodynamic design, and vibrant color, high-res, photorealistic")
150 | ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
151 |
152 | Last output from code snippet: ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
153 | /var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/tmpx09qfsdd/652f0007-3ee9-44e2-94ac-90dae6bb89a4.png
154 | Step 1:
155 |
156 | - Time taken: 16.35 seconds
157 | - Input tokens: 1,383
158 | - Output tokens: 77
159 | ──────────────────────────────────────────────────────────────────────────────────────────────────── New step ────────────────────────────────────────────────────────────────────────────────────────────────────
160 | Agent is executing the code below: ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
161 | final_answer("/var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/tmpx09qfsdd/652f0007-3ee9-44e2-94ac-90dae6bb89a4.png")
162 | ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
163 | Print outputs:
164 |
165 | Last output from code snippet: ───────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
166 | /var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/tmpx09qfsdd/652f0007-3ee9-44e2-94ac-90dae6bb89a4.png
167 | Final answer:
168 | /var/folders/6m/9b1tts6d5w960j80wbw9tx3m0000gn/T/tmpx09qfsdd/652f0007-3ee9-44e2-94ac-90dae6bb89a4.png
169 | ```
170 | The user sees, instead of an image being returned, a path being returned to them.
171 | It could look like a bug from the system, but actually the agentic system didn't cause the error: it's just that the LLM engine made the mistake of not saving the image output into a variable.
172 | Thus it cannot access the image again except by leveraging the path that was logged while saving the image, so it returns the path instead of an image.
173 |
174 | The first step to debugging your agent is thus "Use a more powerful LLM". Alternatives like `Qwen2.5-72B-Instruct` wouldn't have made that mistake.
175 |
176 | ### 2. Provide more guidance / more information
177 |
178 | Then you can also use less powerful models but guide them better.
179 |
180 | Put yourself in the shoes of your model: if you were the model solving the task, would you struggle with the information available to you (from the system prompt + task formulation + tool description)?
181 |
182 | Would you need some added clarifications?
183 |
184 | To provide extra information, we do not recommend changing the system prompt right away: the default system prompt has many adjustments that you do not want to mess up unless you understand the prompt very well.
185 | Better ways to guide your LLM engine are:
186 | - If it's about the task to solve: add all these details to the task. The task could be 100s of pages long.
187 | - If it's about how to use tools: the description attribute of your tools.
188 |
189 | If after trying the above, you still want to change the system prompt, your new system prompt passed to `system_prompt` upon agent initialization needs to contain the following placeholders that will be used to insert certain automatically generated descriptions when running the agent:
190 | - `"{{tool_descriptions}}"` to insert tool descriptions.
191 | - `"{{managed_agents_descriptions}}"` to insert the description for managed agents if there are any.
192 | - For `CodeAgent` only: `"{{authorized_imports}}"` to insert the list of authorized imports.
193 |
194 |
195 | ### 3. Extra planning
196 |
197 | We provide a model for a supplementary planning step, that an agent can run regularly in-between normal action steps. In this step, there is no tool call, the LLM is simply asked to update a list of facts it knows and to reflect on what steps it should take next based on those facts.
198 |
199 | ```py
200 | from prime import load_tool, CodeAgent, HfApiModel, DuckDuckGoSearchTool
201 | from dotenv import load_dotenv
202 |
203 | load_dotenv()
204 |
205 | # Import tool from Hub
206 | image_generation_tool = load_tool("m-ric/text-to-image", cache=False)
207 |
208 | search_tool = DuckDuckGoSearchTool()
209 |
210 | agent = CodeAgent(
211 | tools=[search_tool],
212 | model=HfApiModel("Qwen/Qwen2.5-72B-Instruct"),
213 | planning_interval=3 # This is where you activate planning!
214 | )
215 |
216 | # Run it!
217 | result = agent.run(
218 | "How long would a cheetah at full speed take to run the length of Pont Alexandre III?",
219 | )
220 | ```
221 |
--------------------------------------------------------------------------------
/tests/test_tools.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 HuggingFace Inc.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import unittest
16 | from pathlib import Path
17 | from typing import Dict, Union, Optional
18 |
19 | import numpy as np
20 | import pytest
21 |
22 | from transformers import is_torch_available, is_vision_available
23 | from prime.types import (
24 | AGENT_TYPE_MAPPING,
25 | AgentAudio,
26 | AgentImage,
27 | AgentText,
28 | )
29 | from prime.tools import Tool, tool, AUTHORIZED_TYPES
30 | from transformers.testing_utils import get_tests_dir
31 |
32 |
33 | if is_torch_available():
34 | import torch
35 |
36 | if is_vision_available():
37 | from PIL import Image
38 |
39 |
def create_inputs(tool_inputs: Dict[str, Dict[Union[str, type], str]]):
    """Build one dummy argument value per declared tool input.

    Supported input types are "string", "image" and "audio"; any other type raises
    ``ValueError``.
    """

    def _dummy_value(kind):
        # Produce a placeholder value of the requested input type.
        if kind == "string":
            return "Text input"
        if kind == "image":
            fixture = Path(get_tests_dir("fixtures")) / "000000039769.png"
            return Image.open(fixture).resize((512, 512))
        if kind == "audio":
            return np.ones(3000)
        raise ValueError(f"Invalid type requested: {kind}")

    return {name: _dummy_value(spec["type"]) for name, spec in tool_inputs.items()}
58 |
59 |
def output_type(output):
    """Return the tool type name ("string", "image" or "audio") matching ``output``.

    Raises:
        TypeError: if ``output`` is none of the supported types.
    """
    if isinstance(output, (str, AgentText)):
        return "string"
    if isinstance(output, (Image.Image, AgentImage)):
        return "image"
    if isinstance(output, (torch.Tensor, AgentAudio)):
        return "audio"
    raise TypeError(f"Invalid output: {output}")
70 |
class ToolTesterMixin:
    """Shared checks for tool test cases.

    Mix into a ``unittest.TestCase`` whose setup assigns the tool under test to
    ``self.tool``.
    """

    def test_inputs_output(self):
        for attribute in ("inputs", "output_type"):
            self.assertTrue(hasattr(self.tool, attribute))

        declared_inputs = self.tool.inputs
        self.assertIsInstance(declared_inputs, dict)

        # Every input spec needs a string description and an authorized type.
        for spec in declared_inputs.values():
            self.assertIn("type", spec)
            self.assertIn("description", spec)
            self.assertIn(spec["type"], AUTHORIZED_TYPES)
            self.assertIsInstance(spec["description"], str)

        declared_output = self.tool.output_type
        self.assertIn(declared_output, AUTHORIZED_TYPES)

    def test_common_attributes(self):
        for attribute in ("description", "name", "inputs", "output_type"):
            self.assertTrue(hasattr(self.tool, attribute))

    def test_agent_type_output(self):
        sample_inputs = create_inputs(self.tool.inputs)
        result = self.tool(**sample_inputs, sanitize_inputs_outputs=True)
        if self.tool.output_type == "any":
            # No single agent type to compare against.
            return
        expected_cls = AGENT_TYPE_MAPPING[self.tool.output_type]
        self.assertIsInstance(result, expected_cls)
100 |
101 |
102 | class ToolTests(unittest.TestCase):
def test_tool_init_with_decorator(self):
    # output_type comes from the return annotation (float -> "number"),
    # regardless of what the body actually returns.
    @tool
    def coolfunc(a: str, b: int) -> float:
        """Cool function

        Args:
            a: The first argument
            b: The second one
        """
        return b + 2, a

    self.assertEqual(coolfunc.output_type, "number")
115 |
def test_tool_init_vanilla(self):
    # A plain Tool subclass keeps its declared inputs in declaration order.
    class HFModelDownloadsTool(Tool):
        name = "model_download_counter"
        description = """
        This is a tool that returns the most downloaded model of a given task on the Hugging Face Hub.
        It returns the name of the checkpoint."""

        inputs = {
            "task": {
                "type": "string",
                "description": "the task category (such as text-classification, depth-estimation, etc)",
            }
        }
        output_type = "string"

        def forward(self, task: str) -> str:
            return "best model"

    # Named to avoid shadowing the imported `tool` decorator.
    downloads_tool = HFModelDownloadsTool()
    first_input = next(iter(downloads_tool.inputs))
    assert first_input == "task"
136 |
def test_tool_init_decorator_raises_issues(self):
    """@tool must reject a missing return hint and an undocumented argument.

    The decorator raises at definition time, so nothing after the decorated
    definition inside a ``pytest.raises`` block would run; the message checks
    therefore live after each block.
    """
    # Missing return type annotation.
    with pytest.raises(Exception) as e:

        @tool
        def coolfunc(a: str, b: int):
            """Cool function

            Args:
                a: The first argument
                b: The second one
            """
            return a + b

    assert "Tool return type not found" in str(e.value)

    # Argument `b` has no description in the docstring.
    with pytest.raises(Exception) as e:

        @tool
        def coolfunc(a: str, b: int) -> int:
            """Cool function

            Args:
                a: The first argument
            """
            return b + a

    assert "docstring has no description for the argument" in str(e.value)
166 |
def test_saving_tool_raises_error_imports_outside_function(self):
    # Saving must fail when the tool body relies on `np` imported outside the tool.

    # Case 1: function-based tool built with @tool.
    with pytest.raises(Exception) as e:
        import numpy as np

        @tool
        def get_current_time() -> str:
            """
            Gets the current time.
            """
            return str(np.random.random())

        get_current_time.save("output")

    assert "np" in str(e)

    # Case 2: the same situation with a Tool subclass.
    with pytest.raises(Exception) as e:

        class GetCurrentTimeTool(Tool):
            name = "get_current_time_tool"
            description = "Gets the current time"
            inputs = {}
            output_type = "string"

            def forward(self):
                return str(np.random.random())

        time_tool = GetCurrentTimeTool()
        time_tool.save("output")

    assert "np" in str(e)
198 |
def test_tool_definition_raises_no_error_imports_in_function(self):
    # Imports performed inside the tool body are allowed: defining these must not raise.
    @tool
    def get_current_time() -> str:
        """
        Gets the current time.
        """
        import datetime

        return str(datetime.datetime.now())

    class GetCurrentTimeTool(Tool):
        name = "get_current_time_tool"
        description = "Gets the current time"
        inputs = {}
        output_type = "string"

        def forward(self):
            import datetime

            return str(datetime.datetime.now())
219 |
def test_saving_tool_allows_no_arg_in_init(self):
    # A tool whose __init__ requires extra arguments cannot be saved.
    class FailTool(Tool):
        name = "specific"
        description = "test description"
        inputs = {
            "string_input": {"type": "string", "description": "input description"}
        }
        output_type = "string"

        def __init__(self, url):
            super().__init__(self)
            self.url = "none"

        def forward(self, string_input: str) -> str:
            return self.url + string_input

    tool_with_init_args = FailTool("dummy_url")
    with pytest.raises(Exception) as e:
        tool_with_init_args.save("output")
    assert "__init__" in str(e)
241 |
def test_saving_tool_allows_no_imports_from_outside_methods(self):
    """Tool methods may only use modules they import themselves.

    A method relying on a module imported in the enclosing test scope must make
    ``save`` fail with an "undefined" error; the same import placed inside the
    method must save successfully.
    """
    # Test that using imports from outside functions fails
    import numpy as np

    class FailTool(Tool):
        name = "specific"
        description = "test description"
        inputs = {
            "string_input": {"type": "string", "description": "input description"}
        }
        output_type = "string"

        def useless_method(self):
            self.client = np.random.random()
            return ""

        def forward(self, string_input):
            return self.useless_method() + string_input

    fail_tool = FailTool()
    with pytest.raises(Exception) as e:
        fail_tool.save("output")
    # Match against the exception message itself: the ExceptionInfo repr may
    # escape the quotes around 'np' or truncate a long message.
    assert "'np' is undefined" in str(e.value)

    # Test that putting these imports inside functions works
    class SuccessTool(Tool):
        name = "specific"
        description = "test description"
        inputs = {
            "string_input": {"type": "string", "description": "input description"}
        }
        output_type = "string"

        def useless_method(self):
            import numpy as np

            self.client = np.random.random()
            return ""

        def forward(self, string_input):
            return self.useless_method() + string_input

    success_tool = SuccessTool()
    success_tool.save("output")
286 |
def test_tool_missing_class_attributes_raises_error(self):
    """A ``Tool`` subclass without ``output_type`` must be rejected on construction."""
    with pytest.raises(Exception) as e:

        class GetWeatherTool(Tool):
            name = "get_weather"
            description = "Get weather in the next days at given location."
            inputs = {
                "location": {"type": "string", "description": "the location"},
                "celsius": {
                    "type": "string",
                    "description": "the temperature type",
                },
            }
            # ``output_type`` intentionally omitted.

            def forward(
                self, location: str, celsius: Optional[bool] = False
            ) -> str:
                return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

        GetWeatherTool()
    # Check the exception message, not the (possibly truncated) ExceptionInfo repr.
    assert "You must set an attribute output_type" in str(e.value)
308 |
def test_tool_from_decorator_optional_args(self):
    """``@tool`` marks an ``Optional`` parameter with a default as nullable.

    The required ``location`` parameter must carry no ``nullable`` key at all.
    """

    @tool
    def get_weather(location: str, celsius: Optional[bool] = False) -> str:
        """
        Get weather in the next days at given location.
        Secretly this tool does not care about the location, it hates the weather everywhere.

        Args:
            location: the location
            celsius: the temperature type
        """
        return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

    schema = get_weather.inputs
    celsius_spec = schema["celsius"]
    location_spec = schema["location"]

    # The optional argument is flagged, and the flag is truthy.
    assert "nullable" in celsius_spec
    assert celsius_spec["nullable"]
    # The required argument is left untouched.
    assert "nullable" not in location_spec
325 |
def test_tool_mismatching_nullable_args_raises_error(self):
    """``inputs`` nullability must agree with ``forward``'s signature.

    Each case builds a ``Tool`` subclass whose ``celsius`` entry disagrees with
    the matching ``forward`` parameter; construction must raise an error that
    mentions "Nullable".
    """
    # Case 1: ``forward`` declares ``celsius: Optional[bool] = False`` but the
    # ``inputs`` entry carries no ``nullable`` flag.
    with pytest.raises(Exception) as e:

        class GetWeatherTool(Tool):
            name = "get_weather"
            description = "Get weather in the next days at given location."
            inputs = {
                "location": {"type": "string", "description": "the location"},
                "celsius": {
                    "type": "string",
                    "description": "the temperature type",
                },
            }
            output_type = "string"

            def forward(
                self, location: str, celsius: Optional[bool] = False
            ) -> str:
                return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

        GetWeatherTool()
    # Check the exception message, not the (possibly truncated) ExceptionInfo repr.
    assert "Nullable" in str(e.value)

    # Case 2: ``celsius`` has a default (no ``Optional``) while ``inputs``
    # still lacks ``nullable``.
    with pytest.raises(Exception) as e:

        class GetWeatherTool2(Tool):
            name = "get_weather"
            description = "Get weather in the next days at given location."
            inputs = {
                "location": {"type": "string", "description": "the location"},
                "celsius": {
                    "type": "string",
                    "description": "the temperature type",
                },
            }
            output_type = "string"

            def forward(self, location: str, celsius: bool = False) -> str:
                return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

        GetWeatherTool2()
    assert "Nullable" in str(e.value)

    # Case 3: ``inputs`` says ``nullable: True`` but ``forward`` requires
    # ``celsius`` (no default).
    with pytest.raises(Exception) as e:

        class GetWeatherTool3(Tool):
            name = "get_weather"
            description = "Get weather in the next days at given location."
            inputs = {
                "location": {"type": "string", "description": "the location"},
                "celsius": {
                    "type": "string",
                    "description": "the temperature type",
                    "nullable": True,
                },
            }
            output_type = "string"

            def forward(self, location, celsius: str) -> str:
                return "The weather is UNGODLY with torrential rains and temperatures below -10°C"

        GetWeatherTool3()
    assert "Nullable" in str(e.value)
389 |
--------------------------------------------------------------------------------