├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .gitignore ├── Cargo.toml ├── Dockerfile ├── LICENSE ├── README.md ├── examples ├── cli.rs └── function_calling_agent.rs └── src ├── agents.rs ├── bin └── main.rs ├── errors.rs ├── lib.rs ├── local_python_interpreter.rs ├── logger.rs ├── models ├── mod.rs ├── model_traits.rs ├── ollama.rs ├── openai.rs └── types.rs ├── prompts.rs └── tools ├── base.rs ├── ddg_search.rs ├── final_answer.rs ├── google_search.rs ├── mod.rs ├── python_interpreter.rs ├── tool_traits.rs └── visit_website.rs /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "image": "mcr.microsoft.com/devcontainers/universal:2", 3 | "features": { 4 | "ghcr.io/devcontainers/features/rust:1": {} 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # RustRover 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Generated by Cargo 2 | # will have compiled files and executables 3 | debug/ 4 | target/ 5 | 6 | # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries 7 | # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html 8 | Cargo.lock 9 | 10 | # These are backup files 
generated by rustfmt 11 | **/*.rs.bk 12 | 13 | # MSVC Windows builds of rustc generate these, which store debugging information 14 | *.pdb 15 | 16 | # RustRover 17 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 18 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 19 | # and can be added to the global gitignore or merged into this file. For a more nuclear 20 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 21 | #.idea/ 22 | *.ipynb -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "smolagents-rs" 3 | version = "0.1.2" 4 | edition = "2021" 5 | description = "A Rust port of the HuggingFace smolagents library. Build LLM agents with tools and code execution." 6 | license = "Apache-2.0" 7 | authors = ["Akshay Ballal "] 8 | repository = "https://github.com/akshayballal95/smolagents-rs" 9 | 10 | [dependencies] 11 | htmd = "0.1.6" 12 | reqwest = {version = "0.12.12", features = ['blocking', 'json']} 13 | anyhow = "1.0.95" 14 | serde = {version = "1.0.217", features = ["derive"]} 15 | serde_json = "1.0.135" 16 | log = "0.4" 17 | colored = "3.0.0" 18 | scraper = "0.22.0" 19 | terminal_size = "0.4.1" 20 | schemars = "0.8.21" 21 | chrono = "0.4.39" 22 | rustpython-parser = { version = "0.4.0", optional = true } 23 | pyo3 = { version = "0.19", features = ["auto-initialize"], optional = true } 24 | regex = "1.11.0" 25 | 26 | [dev-dependencies] 27 | clap = { version = "4.5.1", features = ["derive"] } 28 | textwrap = "0.16.0" 29 | 30 | 31 | [[bin]] 32 | name = "smolagents-rs" 33 | path = "src/bin/main.rs" 34 | required-features = ["cli"] 35 | 36 | [features] 37 | default = ["cli", "code-agent"] 38 | cli = ["dep:clap"] 39 | code-agent = ["dep:rustpython-parser", "dep:pyo3"] 40 | all = ["cli", "code-agent"] 
41 | 42 | [dependencies.clap] 43 | version = "4.5.1" 44 | features = ["derive"] 45 | optional = true 46 | 47 | [[example]] 48 | name = "cli" 49 | required-features = ["cli", "code-agent"] 50 | 51 | [package.metadata.docs.rs] 52 | all-features = true 53 | rustdoc-args = ["--cfg", "docsrs"] -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Rust builder with cargo-chef 2 | FROM lukemathwalker/cargo-chef:latest AS chef 3 | WORKDIR /app 4 | 5 | FROM chef AS planner 6 | COPY . . 7 | RUN cargo chef prepare --recipe-path recipe.json 8 | 9 | FROM chef AS builder 10 | COPY --from=planner /app/recipe.json recipe.json 11 | # Build dependencies - this layer is cached as long as dependencies don't change 12 | RUN cargo chef cook --release --recipe-path recipe.json 13 | # Build application 14 | COPY . . 15 | # Build with minimal features and optimize for size 16 | RUN cargo build --release --bin smolagents-rs --features cli-deps \ 17 | && strip /app/target/release/smolagents-rs 18 | 19 | # Use distroless as runtime image 20 | FROM gcr.io/distroless/cc-debian12 AS runtime 21 | WORKDIR /app 22 | # Copy only the binary 23 | COPY --from=builder /app/target/release/smolagents-rs /usr/local/bin/ 24 | # Create config directory 25 | WORKDIR /root/.config/smolagents-rs 26 | 27 | ENTRYPOINT ["/usr/local/bin/smolagents-rs"] 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🤖 smolagents-rs 2 | 3 | This is a rust implementation of HF [smolagents](https://github.com/huggingface/smolagents) library. It provides a powerful autonomous agent framework written in Rust that solves complex tasks using tools and LLM models. 4 | 5 | --- 6 | 7 | ## ✨ Features 8 | 9 | - 🧠 **Function-Calling Agent Architecture**: Implements the ReAct framework for advanced reasoning and action. 10 | - 🔍 **Built-in Tools**: 11 | - Google Search 12 | - DuckDuckGo Search 13 | - Website Visit & Scraping 14 | - 🤝 **OpenAI Integration**: Works seamlessly with GPT models. 15 | - 🎯 **Task Execution**: Enables autonomous completion of complex tasks. 16 | - 🔄 **State Management**: Maintains persistent state across steps. 17 | - 📊 **Beautiful Logging**: Offers colored terminal output for easy debugging. 
18 | 19 | --- 20 | 21 | ![demo](https://res.cloudinary.com/dltwftrgc/image/upload/v1737485304/smolagents-small_fmaikq.gif) 22 | 23 | ## ✅ Feature Checklist 24 | 25 | ### Models 26 | 27 | - [x] OpenAI Models (e.g., GPT-4o, GPT-4o-mini) 28 | - [x] Ollama Integration 29 | - [ ] Hugging Face API support 30 | - [ ] Open-source model integration via Candle 31 | - [ ] Light LLM integration 32 | 33 | ### Agents 34 | 35 | - [x] Tool-Calling Agent 36 | - [x] CodeAgent 37 | - [ ] Planning Agent 38 | 39 | The code agent is still in development, so there might be Python code that is not yet supported and may cause errors. Try using the tool-calling agent for now. 40 | 41 | ### Tools 42 | 43 | - [x] Google Search Tool 44 | - [x] DuckDuckGo Tool 45 | - [x] Website Visit & Scraping Tool 46 | - [ ] RAG Tool 47 | - More tools to come... 48 | 49 | ### Other 50 | 51 | - [ ] Sandbox environment 52 | - [ ] Streaming output 53 | - [ ] Improve logging 54 | - [ ] Parallel execution 55 | 56 | --- 57 | 58 | ## 🚀 Quick Start 59 | 60 | Warning: Since there is no implementation of a Sandbox environment, be careful with the tools you use. Preferably run the agent in a controlled environment using a Docker container. 61 | 62 | ### Using Docker 63 | 64 | ```bash 65 | # Pull the image 66 | docker pull akshayballal95/smolagents-rs:latest 67 | 68 | # Run with your OpenAI API key 69 | docker run -e OPENAI_API_KEY=your-key-here akshayballal95/smolagents-rs:latest -t "What is the latest news about Rust programming?" 
70 | ``` 71 | 72 | ### Building from Source 73 | 74 | ```bash 75 | # Clone the repository 76 | git clone https://github.com/akshayballal95/smolagents-rs.git 77 | cd smolagents-rs 78 | 79 | # Build the project 80 | cargo build --release --features cli 81 | 82 | # Run the agent 83 | OPENAI_API_KEY=your-key-here ./target/release/smolagents-rs -t "Your task here" 84 | ``` 85 | 86 | --- 87 | 88 | ## 🛠️ Usage 89 | 90 | ```bash 91 | smolagents-rs [OPTIONS] -t TASK 92 | 93 | Options: 94 | -t, --task The task to execute 95 | -a, --agent-type Agent type [default: function-calling] 96 | -l, --tools Comma-separated list of tools [default: duckduckgo,visit-website] 97 | -m, --model Model type [default: open-ai] 98 | -k, --api-key OpenAI API key (only required for OpenAI model) 99 | --model-id Model ID (e.g., "gpt-4" for OpenAI or "qwen2.5" for Ollama) [default: gpt-4o-mini] 100 | -u, --ollama-url Ollama server URL [default: http://localhost:11434] 101 | -s, --stream Enable streaming output 102 | -h, --help Print help 103 | ``` 104 | 105 | --- 106 | 107 | ## 🌟 Examples 108 | 109 | ```bash 110 | # Simple search task 111 | smolagents-rs -t "What are the main features of Rust 1.75?" 112 | 113 | # Research with multiple tools 114 | smolagents-rs -t "Compare Rust and Go performance" -l duckduckgo,google-search,visit-website 115 | 116 | # Stream output for real-time updates 117 | smolagents-rs -t "Analyze the latest crypto trends" -s 118 | ``` 119 | 120 | --- 121 | 122 | ## 🔧 Configuration 123 | 124 | ### Environment Variables 125 | 126 | - `OPENAI_API_KEY`: Your OpenAI API key (required when using OpenAI models). 127 | - `SERPAPI_API_KEY`: Google Search API key (optional). 128 | 129 | --- 130 | 131 | ## 🏗️ Architecture 132 | 133 | The project follows a modular architecture with the following components: 134 | 135 | 1. **Agent System**: Implements the ReAct framework. 136 | 137 | 2. **Tool System**: An extensible tool framework for seamless integration of new tools. 138 | 139 | 3. 
**Model Integration**: Robust OpenAI API integration for powerful LLM capabilities. 140 | 141 | --- 142 | 143 | ## 🚀 Why port to Rust? 144 | Rust provides critical advantages that make it the ideal choice for smolagents-rs: 145 | 146 | 1. ⚡ **High Performance**:
147 | Zero-cost abstractions and no garbage collector overhead enable smolagents-rs to handle complex agent tasks with near-native performance. This is crucial for running multiple agents and processing large amounts of data efficiently. 148 | 149 | 2. 🛡️ **Memory Safety & Security**:
150 | Rust's compile-time guarantees prevent memory-related vulnerabilities and data races - essential for an agent system that handles sensitive API keys and interacts with external resources. The ownership model ensures thread-safe concurrent operations without runtime overhead. 151 | 152 | 3. 🔄 **Powerful Concurrency**:
153 | Fearless concurrency through the ownership system enables smolagents-rs to efficiently manage multiple agents and tools in parallel, maximizing resource utilization. 154 | 155 | 4. 💻 **Universal Deployment**:
156 | Compile once, run anywhere - from high-performance servers to WebAssembly in browsers. This allows smolagents-rs to run natively on any platform or be embedded in web applications with near-native performance. 157 | 158 | Apart from this, it's essential to push new technologies around agentic systems to the Rust ecosystem, and this library aims to do so. 159 | 160 | --- 161 | 162 | ## 🤝 Contributing 163 | 164 | Contributions are welcome! To contribute: 165 | 166 | 1. Fork the repository. 167 | 2. Create your feature branch (`git checkout -b feature/amazing-feature`). 168 | 3. Commit your changes (`git commit -m 'Add some amazing feature'`). 169 | 4. Push to the branch (`git push origin feature/amazing-feature`). 170 | 5. Open a Pull Request. 171 | 172 | 173 | --- 174 | 175 | ## ⭐ Show Your Support 176 | 177 | Give a ⭐️ if this project helps you or inspires your work! 178 | 179 | -------------------------------------------------------------------------------- /examples/cli.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use anyhow::Result; 4 | use clap::{Parser, ValueEnum}; 5 | use serde_json; 6 | use smolagents_rs::agents::Step; 7 | use smolagents_rs::agents::{Agent, CodeAgent, FunctionCallingAgent}; 8 | use smolagents_rs::errors::AgentError; 9 | use smolagents_rs::models::model_traits::{Model, ModelResponse}; 10 | use smolagents_rs::models::ollama::{OllamaModel, OllamaModelBuilder}; 11 | use smolagents_rs::models::openai::OpenAIServerModel; 12 | use smolagents_rs::models::types::Message; 13 | use smolagents_rs::tools::{ 14 | AnyTool, DuckDuckGoSearchTool, GoogleSearchTool, ToolInfo, VisitWebsiteTool, 15 | }; 16 | use std::fs::File; 17 | 18 | #[derive(Debug, Clone, ValueEnum)] 19 | enum AgentType { 20 | FunctionCalling, 21 | Code, 22 | } 23 | 24 | #[derive(Debug, Clone, ValueEnum)] 25 | enum ToolType { 26 | DuckDuckGo, 27 | VisitWebsite, 28 | GoogleSearchTool, 29 | } 30 | 31 | 
#[derive(Debug, Clone, ValueEnum)] 32 | enum ModelType { 33 | OpenAI, 34 | Ollama, 35 | } 36 | 37 | #[derive(Debug)] 38 | enum ModelWrapper { 39 | OpenAI(OpenAIServerModel), 40 | Ollama(OllamaModel), 41 | } 42 | 43 | enum AgentWrapper { 44 | FunctionCalling(FunctionCallingAgent), 45 | Code(CodeAgent), 46 | } 47 | 48 | impl AgentWrapper { 49 | fn run(&mut self, task: &str, stream: bool, reset: bool) -> Result { 50 | match self { 51 | AgentWrapper::FunctionCalling(agent) => agent.run(task, stream, reset), 52 | AgentWrapper::Code(agent) => agent.run(task, stream, reset), 53 | } 54 | } 55 | fn get_logs_mut(&mut self) -> &mut Vec { 56 | match self { 57 | AgentWrapper::FunctionCalling(agent) => agent.get_logs_mut(), 58 | AgentWrapper::Code(agent) => agent.get_logs_mut(), 59 | } 60 | } 61 | } 62 | impl Model for ModelWrapper { 63 | fn run( 64 | &self, 65 | messages: Vec, 66 | tools: Vec, 67 | max_tokens: Option, 68 | args: Option>>, 69 | ) -> Result, AgentError> { 70 | match self { 71 | ModelWrapper::OpenAI(m) => Ok(m.run(messages, tools, max_tokens, args)?), 72 | ModelWrapper::Ollama(m) => Ok(m.run(messages, tools, max_tokens, args)?), 73 | } 74 | } 75 | } 76 | 77 | #[derive(Parser, Debug)] 78 | #[command(author, version, about, long_about = None)] 79 | struct Args { 80 | /// The type of agent to use 81 | #[arg(short = 'a', long, value_enum, default_value = "function-calling")] 82 | agent_type: AgentType, 83 | 84 | /// List of tools to use 85 | #[arg(short = 'l', long = "tools", value_enum, num_args = 1.., value_delimiter = ',', default_values_t = [ToolType::DuckDuckGo, ToolType::VisitWebsite])] 86 | tools: Vec, 87 | 88 | /// The type of model to use 89 | #[arg(short = 'm', long, value_enum, default_value = "open-ai")] 90 | model_type: ModelType, 91 | 92 | /// OpenAI API key (only required for OpenAI model) 93 | #[arg(short = 'k', long)] 94 | api_key: Option, 95 | 96 | /// Model ID (e.g., "gpt-4" for OpenAI or "qwen2.5" for Ollama) 97 | #[arg(long, default_value = 
"gpt-4o-mini")] 98 | model_id: String, 99 | 100 | /// Whether to stream the output 101 | #[arg(short, long, default_value = "false")] 102 | stream: bool, 103 | 104 | /// Whether to reset the agent 105 | #[arg(short, long, default_value = "false")] 106 | reset: bool, 107 | 108 | /// The task to execute 109 | #[arg(short, long)] 110 | task: String, 111 | 112 | /// Base URL for the API 113 | #[arg(short, long)] 114 | base_url: Option, 115 | } 116 | 117 | fn create_tool(tool_type: &ToolType) -> Box { 118 | match tool_type { 119 | ToolType::DuckDuckGo => Box::new(DuckDuckGoSearchTool::new()), 120 | ToolType::VisitWebsite => Box::new(VisitWebsiteTool::new()), 121 | ToolType::GoogleSearchTool => Box::new(GoogleSearchTool::new(None)), 122 | } 123 | } 124 | 125 | fn main() -> Result<()> { 126 | let args = Args::parse(); 127 | 128 | let tools: Vec> = args.tools.iter().map(create_tool).collect(); 129 | 130 | // Create model based on type 131 | let model = match args.model_type { 132 | ModelType::OpenAI => ModelWrapper::OpenAI(OpenAIServerModel::new( 133 | args.base_url.as_deref(), 134 | Some(&args.model_id), 135 | None, 136 | args.api_key, 137 | )), 138 | ModelType::Ollama => ModelWrapper::Ollama( 139 | OllamaModelBuilder::new() 140 | .model_id(&args.model_id) 141 | .ctx_length(8000) 142 | .build(), 143 | ), 144 | }; 145 | 146 | // Create agent based on type 147 | let mut agent = match args.agent_type { 148 | AgentType::FunctionCalling => AgentWrapper::FunctionCalling(FunctionCallingAgent::new( 149 | model, 150 | tools, 151 | None, 152 | None, 153 | Some("CLI Agent"), 154 | None, 155 | )?), 156 | AgentType::Code => AgentWrapper::Code(CodeAgent::new( 157 | model, 158 | tools, 159 | None, 160 | None, 161 | Some("CLI Agent"), 162 | None, 163 | )?), 164 | }; 165 | 166 | // Run the agent with the task from stdin 167 | let _result = agent.run(&args.task, args.stream, args.reset)?; 168 | let logs = agent.get_logs_mut(); 169 | 170 | // store logs in a file 171 | let mut file = 
File::create("logs.txt")?; 172 | 173 | // Get the last log entry and serialize it in a controlled way 174 | for log in logs { 175 | // Serialize to JSON with pretty printing 176 | serde_json::to_writer_pretty(&mut file, &log)?; 177 | } 178 | 179 | Ok(()) 180 | } 181 | -------------------------------------------------------------------------------- /examples/function_calling_agent.rs: -------------------------------------------------------------------------------- 1 | use smolagents_rs::agents::{Agent, FunctionCallingAgent}; 2 | use smolagents_rs::models::openai::OpenAIServerModel; 3 | use smolagents_rs::tools::{AnyTool, DuckDuckGoSearchTool, VisitWebsiteTool}; 4 | 5 | fn main() { 6 | let tools: Vec> = vec![ 7 | Box::new(DuckDuckGoSearchTool::new()), 8 | Box::new(VisitWebsiteTool::new()), 9 | ]; 10 | let model = OpenAIServerModel::new( 11 | Some("https://api.openai.com/v1/chat/completions"), 12 | Some("gpt-4o-mini"), 13 | None, 14 | None, 15 | ); 16 | let mut agent = FunctionCallingAgent::new(model, tools, None, None, None, None).unwrap(); 17 | let _result = agent 18 | .run("Who has the most followers on Twitter?", false, false) 19 | .unwrap(); 20 | } 21 | -------------------------------------------------------------------------------- /src/agents.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the agents that can be used to solve tasks. 2 | //! 3 | //! Currently, there are two agents: 4 | //! - The function calling agent. This agent is used for models that have tool calling capabilities. 5 | //! - The code agent. This agent takes tools and can write simple python code that is executed to solve the task. 6 | //! 7 | //! To use this agent you need to enable the `code-agent` feature. 8 | //! 9 | //! You can also implement your own agents by implementing the `Agent` trait. 10 | //! 11 | //! Planning agent is not implemented yet and will be added in the future. 12 | //! 
13 | use crate::errors::AgentError; 14 | use crate::models::model_traits::Model; 15 | use crate::models::openai::ToolCall; 16 | use crate::models::types::Message; 17 | use crate::models::types::MessageRole; 18 | use crate::prompts::{ 19 | user_prompt_plan, SYSTEM_PROMPT_FACTS, SYSTEM_PROMPT_PLAN, TOOL_CALLING_SYSTEM_PROMPT, 20 | }; 21 | use crate::tools::{AnyTool, FinalAnswerTool, ToolGroup, ToolInfo}; 22 | use std::collections::HashMap; 23 | 24 | use crate::logger::LOGGER; 25 | use anyhow::Result; 26 | use colored::Colorize; 27 | use log::info; 28 | 29 | use serde::Serialize; 30 | use serde_json::json; 31 | #[cfg(feature = "code-agent")] 32 | use { 33 | crate::errors::InterpreterError, crate::local_python_interpreter::LocalPythonInterpreter, 34 | crate::models::openai::FunctionCall, crate::prompts::CODE_SYSTEM_PROMPT, regex::Regex, 35 | }; 36 | 37 | const DEFAULT_TOOL_DESCRIPTION_TEMPLATE: &str = r#" 38 | {{ tool.name }}: {{ tool.description }} 39 | Takes inputs: {{tool.inputs}} 40 | "#; 41 | 42 | use std::fmt::Debug; 43 | 44 | pub fn get_tool_description_with_args(tool: &ToolInfo) -> String { 45 | let mut description = DEFAULT_TOOL_DESCRIPTION_TEMPLATE.to_string(); 46 | description = description.replace("{{ tool.name }}", tool.function.name); 47 | description = description.replace("{{ tool.description }}", tool.function.description); 48 | description = description.replace( 49 | "{{tool.inputs}}", 50 | json!(&tool.function.parameters.schema)["properties"] 51 | .to_string() 52 | .as_str(), 53 | ); 54 | 55 | description 56 | } 57 | 58 | pub fn get_tool_descriptions(tools: &[ToolInfo]) -> Vec { 59 | tools.iter().map(get_tool_description_with_args).collect() 60 | } 61 | pub fn format_prompt_with_tools(tools: Vec, prompt_template: &str) -> String { 62 | let tool_descriptions = get_tool_descriptions(&tools); 63 | let mut prompt = prompt_template.to_string(); 64 | prompt = prompt.replace("{{tool_descriptions}}", &tool_descriptions.join("\n")); 65 | if 
prompt.contains("{{tool_names}}") { 66 | let tool_names: Vec = tools 67 | .iter() 68 | .map(|tool| tool.function.name.to_string()) 69 | .collect(); 70 | prompt = prompt.replace("{{tool_names}}", &tool_names.join(", ")); 71 | } 72 | prompt 73 | } 74 | 75 | pub fn show_agents_description(managed_agents: &HashMap>) -> String { 76 | let mut managed_agent_description = r#"You can also give requests to team members. 77 | Calling a team member works the same as for calling a tool: simply, the only argument you can give in the call is 'request', a long string explaining your request. 78 | Given that this team member is a real human, you should be very verbose in your request. 79 | Here is a list of the team members that you can call:"#.to_string(); 80 | 81 | for (name, agent) in managed_agents.iter() { 82 | managed_agent_description.push_str(&format!("{}: {:?}\n", name, agent.description())); 83 | } 84 | 85 | managed_agent_description 86 | } 87 | 88 | pub fn format_prompt_with_managed_agent_description( 89 | prompt_template: String, 90 | managed_agents: &HashMap>, 91 | agent_descriptions_placeholder: Option<&str>, 92 | ) -> Result { 93 | let agent_descriptions_placeholder = 94 | agent_descriptions_placeholder.unwrap_or("{{managed_agents_descriptions}}"); 95 | 96 | if managed_agents.keys().len() > 0 { 97 | Ok(prompt_template.replace( 98 | agent_descriptions_placeholder, 99 | &show_agents_description(managed_agents), 100 | )) 101 | } else { 102 | Ok(prompt_template.replace(agent_descriptions_placeholder, "")) 103 | } 104 | } 105 | 106 | pub trait Agent { 107 | fn name(&self) -> &'static str; 108 | fn get_max_steps(&self) -> usize; 109 | fn get_step_number(&self) -> usize; 110 | fn reset_step_number(&mut self); 111 | fn increment_step_number(&mut self); 112 | fn get_logs_mut(&mut self) -> &mut Vec; 113 | fn set_task(&mut self, task: &str); 114 | fn get_system_prompt(&self) -> &str; 115 | fn description(&self) -> String { 116 | "".to_string() 117 | } 118 | fn model(&self) -> 
&dyn Model; 119 | fn step(&mut self, log_entry: &mut Step) -> Result>; 120 | fn direct_run(&mut self, _task: &str) -> Result { 121 | let mut final_answer: Option = None; 122 | while final_answer.is_none() && self.get_step_number() < self.get_max_steps() { 123 | println!("Step number: {:?}", self.get_step_number()); 124 | let mut step_log = Step::ActionStep(AgentStep { 125 | agent_memory: None, 126 | llm_output: None, 127 | tool_call: None, 128 | error: None, 129 | observations: None, 130 | _step: self.get_step_number(), 131 | }); 132 | 133 | final_answer = self.step(&mut step_log)?; 134 | self.get_logs_mut().push(step_log); 135 | self.increment_step_number(); 136 | } 137 | 138 | if final_answer.is_none() && self.get_step_number() >= self.get_max_steps() { 139 | final_answer = self.provide_final_answer(_task)?; 140 | } 141 | info!( 142 | "Final answer: {}", 143 | final_answer 144 | .clone() 145 | .unwrap_or("Could not find answer".to_string()) 146 | ); 147 | Ok(final_answer.unwrap_or_else(|| "Max steps reached without final answer".to_string())) 148 | } 149 | fn stream_run(&mut self, _task: &str) -> Result { 150 | todo!() 151 | } 152 | fn run(&mut self, task: &str, stream: bool, reset: bool) -> Result { 153 | // self.task = task.to_string(); 154 | self.set_task(task); 155 | 156 | let system_prompt_step = Step::SystemPromptStep(self.get_system_prompt().to_string()); 157 | if reset { 158 | self.get_logs_mut().clear(); 159 | self.get_logs_mut().push(system_prompt_step); 160 | self.reset_step_number(); 161 | } else if self.get_logs_mut().is_empty() { 162 | self.get_logs_mut().push(system_prompt_step); 163 | } else { 164 | self.get_logs_mut()[0] = system_prompt_step; 165 | } 166 | self.get_logs_mut().push(Step::TaskStep(task.to_string())); 167 | match stream { 168 | true => self.stream_run(task), 169 | false => self.direct_run(task), 170 | } 171 | } 172 | fn provide_final_answer(&mut self, task: &str) -> Result> { 173 | let mut input_messages = vec![Message { 174 | role: 
MessageRole::System, 175 | content: "An agent tried to answer a user query but it got stuck and failed to do so. You are tasked with providing an answer instead. Here is the agent's memory:".to_string(), 176 | }]; 177 | 178 | input_messages.extend(self.write_inner_memory_from_logs(Some(true))?[1..].to_vec()); 179 | input_messages.push(Message { 180 | role: MessageRole::User, 181 | content: format!("Based on the above, please provide an answer to the following user request: \n```\n{}", task), 182 | }); 183 | let response = self 184 | .model() 185 | .run(input_messages, vec![], None, None)? 186 | .get_response()?; 187 | Ok(Some(response)) 188 | } 189 | 190 | fn write_inner_memory_from_logs(&mut self, summary_mode: Option) -> Result> { 191 | let mut memory = Vec::new(); 192 | let summary_mode = summary_mode.unwrap_or(false); 193 | for log in self.get_logs_mut() { 194 | match log { 195 | Step::ToolCall(_) => {} 196 | Step::PlanningStep(plan, facts) => { 197 | memory.push(Message { 198 | role: MessageRole::Assistant, 199 | content: "[PLAN]:\n".to_owned() + plan.as_str(), 200 | }); 201 | 202 | if !summary_mode { 203 | memory.push(Message { 204 | role: MessageRole::Assistant, 205 | content: "[FACTS]:\n".to_owned() + facts.as_str(), 206 | }); 207 | } 208 | } 209 | Step::TaskStep(task) => { 210 | memory.push(Message { 211 | role: MessageRole::User, 212 | content: "New Task: ".to_owned() + task.as_str(), 213 | }); 214 | } 215 | Step::SystemPromptStep(prompt) => { 216 | memory.push(Message { 217 | role: MessageRole::System, 218 | content: prompt.to_string(), 219 | }); 220 | } 221 | Step::ActionStep(step_log) => { 222 | if step_log.llm_output.is_some() && !summary_mode { 223 | memory.push(Message { 224 | role: MessageRole::Assistant, 225 | content: step_log.llm_output.clone().unwrap_or_default(), 226 | }); 227 | } 228 | if step_log.tool_call.is_some() { 229 | let tool_call_message = step_log 230 | .tool_call 231 | .clone() 232 | .unwrap() 233 | .iter() 234 | .map(|tool_call| 
-> Message { 235 | Message { 236 | role: MessageRole::Assistant, 237 | content: serde_json::to_string_pretty(&tool_call) 238 | .unwrap_or_default(), 239 | } 240 | }) 241 | .collect::>(); 242 | memory.extend(tool_call_message); 243 | } 244 | 245 | if let (Some(tool_calls), Some(observations)) = 246 | (&step_log.tool_call, &step_log.observations) 247 | { 248 | for (i, tool_call) in tool_calls.iter().enumerate() { 249 | let message_content = format!( 250 | "Call id: {}\nObservation: {}", 251 | tool_call.id.as_deref().unwrap_or_default(), 252 | observations[i] 253 | ); 254 | 255 | memory.push(Message { 256 | role: MessageRole::User, 257 | content: message_content, 258 | }); 259 | } 260 | } else if let Some(observations) = &step_log.observations { 261 | memory.push(Message { 262 | role: MessageRole::User, 263 | content: format!("Observations: {}", observations.join("\n")), 264 | }); 265 | } 266 | if step_log.error.is_some() { 267 | let error_string = 268 | "Error: ".to_owned() + step_log.error.clone().unwrap().message(); // Its fine to unwrap because we check for None above 269 | 270 | let error_string = error_string + "\nNow let's retry: take care not to repeat previous errors! 
If you have retried several times, try a completely different approach.\n"; 271 | memory.push(Message { 272 | role: MessageRole::User, 273 | content: error_string, 274 | }); 275 | } 276 | } 277 | } 278 | } 279 | Ok(memory) 280 | } 281 | } 282 | 283 | #[derive(Debug, Serialize)] 284 | pub enum Step { 285 | PlanningStep(String, String), 286 | TaskStep(String), 287 | SystemPromptStep(String), 288 | ActionStep(AgentStep), 289 | ToolCall(ToolCall), 290 | } 291 | 292 | impl std::fmt::Display for Step { 293 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 294 | match self { 295 | Step::PlanningStep(plan, facts) => { 296 | write!(f, "PlanningStep(plan: {}, facts: {})", plan, facts) 297 | } 298 | Step::TaskStep(task) => write!(f, "TaskStep({})", task), 299 | Step::SystemPromptStep(prompt) => write!(f, "SystemPromptStep({})", prompt), 300 | Step::ActionStep(step) => write!(f, "ActionStep({})", step), 301 | Step::ToolCall(tool_call) => write!(f, "ToolCall({:?})", tool_call), 302 | } 303 | } 304 | } 305 | 306 | #[derive(Debug, Clone, Serialize)] 307 | pub struct AgentStep { 308 | agent_memory: Option>, 309 | llm_output: Option, 310 | tool_call: Option>, 311 | error: Option, 312 | observations: Option>, 313 | _step: usize, 314 | } 315 | 316 | impl std::fmt::Display for AgentStep { 317 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 318 | write!(f, "AgentStep({:?})", self) 319 | } 320 | } 321 | 322 | // Define a trait for the parent functionality 323 | 324 | pub struct MultiStepAgent { 325 | pub model: M, 326 | pub tools: Vec>, 327 | pub system_prompt_template: String, 328 | pub name: &'static str, 329 | pub managed_agents: Option>>, 330 | pub description: String, 331 | pub max_steps: usize, 332 | pub step_number: usize, 333 | pub task: String, 334 | pub input_messages: Option>, 335 | pub logs: Vec, 336 | } 337 | 338 | impl Agent for MultiStepAgent { 339 | fn name(&self) -> &'static str { 340 | self.name 341 | } 342 | fn 
get_max_steps(&self) -> usize { 343 | self.max_steps 344 | } 345 | fn get_step_number(&self) -> usize { 346 | self.step_number 347 | } 348 | fn set_task(&mut self, task: &str) { 349 | self.task = task.to_string(); 350 | } 351 | fn get_system_prompt(&self) -> &str { 352 | &self.system_prompt_template 353 | } 354 | fn increment_step_number(&mut self) { 355 | self.step_number += 1; 356 | } 357 | fn reset_step_number(&mut self) { 358 | self.step_number = 0; 359 | } 360 | fn get_logs_mut(&mut self) -> &mut Vec { 361 | &mut self.logs 362 | } 363 | fn description(&self) -> String { 364 | self.description.clone() 365 | } 366 | fn model(&self) -> &dyn Model { 367 | &self.model 368 | } 369 | 370 | /// Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. 371 | /// 372 | /// Returns None if the step is not final. 373 | fn step(&mut self, _: &mut Step) -> Result> { 374 | todo!() 375 | } 376 | } 377 | 378 | impl MultiStepAgent { 379 | pub fn new( 380 | model: M, 381 | mut tools: Vec>, 382 | system_prompt: Option<&str>, 383 | managed_agents: Option>>, 384 | description: Option<&str>, 385 | max_steps: Option, 386 | ) -> Result { 387 | // Initialize logger 388 | log::set_logger(&LOGGER).unwrap(); 389 | log::set_max_level(log::LevelFilter::Info); 390 | 391 | let name = "MultiStepAgent"; 392 | 393 | let system_prompt_template = match system_prompt { 394 | Some(prompt) => prompt.to_string(), 395 | None => TOOL_CALLING_SYSTEM_PROMPT.to_string(), 396 | }; 397 | let description = match description { 398 | Some(desc) => desc.to_string(), 399 | None => "A multi-step agent that can solve tasks using a series of tools".to_string(), 400 | }; 401 | 402 | let final_answer_tool = FinalAnswerTool::new(); 403 | tools.push(Box::new(final_answer_tool)); 404 | 405 | let mut agent = MultiStepAgent { 406 | model, 407 | tools, 408 | system_prompt_template, 409 | name, 410 | managed_agents, 411 | description, 412 | max_steps: max_steps.unwrap_or(10), 413 | 
step_number: 0, 414 | task: "".to_string(), 415 | logs: Vec::new(), 416 | input_messages: None, 417 | }; 418 | 419 | agent.initialize_system_prompt()?; 420 | Ok(agent) 421 | } 422 | 423 | fn initialize_system_prompt(&mut self) -> Result { 424 | let tools = self.tools.tool_info(); 425 | self.system_prompt_template = format_prompt_with_tools(tools, &self.system_prompt_template); 426 | match &self.managed_agents { 427 | Some(managed_agents) => { 428 | self.system_prompt_template = format_prompt_with_managed_agent_description( 429 | self.system_prompt_template.clone(), 430 | managed_agents, 431 | None, 432 | )?; 433 | } 434 | None => { 435 | self.system_prompt_template = format_prompt_with_managed_agent_description( 436 | self.system_prompt_template.clone(), 437 | &HashMap::new(), 438 | None, 439 | )?; 440 | } 441 | } 442 | self.system_prompt_template = self 443 | .system_prompt_template 444 | .replace("{{current_time}}", &chrono::Local::now().to_string()); 445 | Ok(self.system_prompt_template.clone()) 446 | } 447 | 448 | pub fn planning_step(&mut self, task: &str, is_first_step: bool, _step: usize) { 449 | if is_first_step { 450 | let message_prompt_facts = Message { 451 | role: MessageRole::System, 452 | content: SYSTEM_PROMPT_FACTS.to_string(), 453 | }; 454 | let message_prompt_task = Message { 455 | role: MessageRole::User, 456 | content: format!( 457 | "Here is the task: ``` 458 | {} 459 | ``` 460 | Now Begin! 
461 | ", 462 | task 463 | ), 464 | }; 465 | 466 | let answer_facts = self 467 | .model 468 | .run( 469 | vec![message_prompt_facts, message_prompt_task], 470 | vec![], 471 | None, 472 | None, 473 | ) 474 | .unwrap() 475 | .get_response() 476 | .unwrap_or("".to_string()); 477 | let message_system_prompt_plan = Message { 478 | role: MessageRole::System, 479 | content: SYSTEM_PROMPT_PLAN.to_string(), 480 | }; 481 | let tool_descriptions = serde_json::to_string( 482 | &self 483 | .tools 484 | .iter() 485 | .map(|tool| tool.tool_info()) 486 | .collect::>(), 487 | ) 488 | .unwrap(); 489 | let message_user_prompt_plan = Message { 490 | role: MessageRole::User, 491 | content: user_prompt_plan( 492 | task, 493 | &tool_descriptions, 494 | &show_agents_description( 495 | self.managed_agents.as_ref().unwrap_or(&HashMap::new()), 496 | ), 497 | &answer_facts, 498 | ), 499 | }; 500 | let answer_plan = self 501 | .model 502 | .run( 503 | vec![message_system_prompt_plan, message_user_prompt_plan], 504 | vec![], 505 | None, 506 | Some(HashMap::from([( 507 | "stop".to_string(), 508 | vec!["Observation:".to_string()], 509 | )])), 510 | ) 511 | .unwrap() 512 | .get_response() 513 | .unwrap(); 514 | let final_plan_redaction = format!( 515 | "Here is the plan of action that I will follow for the task: \n{}", 516 | answer_plan 517 | ); 518 | let final_facts_redaction = 519 | format!("Here are the facts that I know so far: \n{}", answer_facts); 520 | self.logs.push(Step::PlanningStep( 521 | final_plan_redaction.clone(), 522 | final_facts_redaction, 523 | )); 524 | info!("Plan: {}", final_plan_redaction.blue().bold()); 525 | } 526 | } 527 | } 528 | 529 | pub struct FunctionCallingAgent { 530 | base_agent: MultiStepAgent, 531 | } 532 | 533 | impl FunctionCallingAgent { 534 | pub fn new( 535 | model: M, 536 | tools: Vec>, 537 | system_prompt: Option<&str>, 538 | managed_agents: Option>>, 539 | description: Option<&str>, 540 | max_steps: Option, 541 | ) -> Result { 542 | let system_prompt = 
system_prompt.unwrap_or(TOOL_CALLING_SYSTEM_PROMPT); 543 | let base_agent = MultiStepAgent::new( 544 | model, 545 | tools, 546 | Some(system_prompt), 547 | managed_agents, 548 | description, 549 | max_steps, 550 | )?; 551 | Ok(Self { base_agent }) 552 | } 553 | } 554 | 555 | impl Agent for FunctionCallingAgent { 556 | fn name(&self) -> &'static str { 557 | self.base_agent.name() 558 | } 559 | fn set_task(&mut self, task: &str) { 560 | self.base_agent.set_task(task); 561 | } 562 | fn get_system_prompt(&self) -> &str { 563 | self.base_agent.get_system_prompt() 564 | } 565 | fn get_max_steps(&self) -> usize { 566 | self.base_agent.get_max_steps() 567 | } 568 | fn get_step_number(&self) -> usize { 569 | self.base_agent.get_step_number() 570 | } 571 | fn reset_step_number(&mut self) { 572 | self.base_agent.reset_step_number(); 573 | } 574 | fn increment_step_number(&mut self) { 575 | self.base_agent.increment_step_number(); 576 | } 577 | fn get_logs_mut(&mut self) -> &mut Vec { 578 | self.base_agent.get_logs_mut() 579 | } 580 | fn model(&self) -> &dyn Model { 581 | self.base_agent.model() 582 | } 583 | 584 | /// Perform one step in the ReAct framework: the agent thinks, acts, and observes the result. 585 | /// 586 | /// Returns None if the step is not final. 
587 | fn step(&mut self, log_entry: &mut Step) -> Result> { 588 | match log_entry { 589 | Step::ActionStep(step_log) => { 590 | let agent_memory = self.base_agent.write_inner_memory_from_logs(None)?; 591 | self.base_agent.input_messages = Some(agent_memory.clone()); 592 | step_log.agent_memory = Some(agent_memory.clone()); 593 | let tools = self 594 | .base_agent 595 | .tools 596 | .iter() 597 | .map(|tool| tool.tool_info()) 598 | .collect::>(); 599 | let model_message = self 600 | .base_agent 601 | .model 602 | .run( 603 | self.base_agent.input_messages.as_ref().unwrap().clone(), 604 | tools, 605 | None, 606 | Some(HashMap::from([( 607 | "stop".to_string(), 608 | vec!["Observation:".to_string()], 609 | )])), 610 | )?; 611 | 612 | let mut observations = Vec::new(); 613 | let tools = model_message.get_tools_used()?; 614 | step_log.tool_call = Some(tools.clone()); 615 | 616 | if let Ok(response) = model_message.get_response() { 617 | if !response.trim().is_empty() { 618 | observations.push(response.clone()); 619 | } 620 | if tools.is_empty() { 621 | return Ok(Some(response)); 622 | } 623 | } 624 | for tool in tools { 625 | let function_name = tool.clone().function.name; 626 | 627 | match function_name.as_str() { 628 | "final_answer" => { 629 | info!("Executing tool call: {}", function_name); 630 | let answer = self.base_agent.tools.call(&tool.function)?; 631 | self.base_agent.write_inner_memory_from_logs(None)?; 632 | return Ok(Some(answer)); 633 | } 634 | _ => { 635 | info!( 636 | "Executing tool call: {} with arguments: {:?}", 637 | function_name, tool.function.arguments 638 | ); 639 | let observation = self.base_agent.tools.call(&tool.function); 640 | match observation { 641 | Ok(observation) => { 642 | observations.push(format!( 643 | "Observation from {}: {}", 644 | function_name, 645 | observation.chars().take(30000).collect::() 646 | )); 647 | } 648 | Err(e) => { 649 | observations.push(e.to_string()); 650 | info!("Error: {}", e); 651 | } 652 | } 653 | } 654 | 
} 655 | } 656 | step_log.observations = Some(observations); 657 | 658 | info!( 659 | "Observation: {} \n ....This content has been truncated due to the 30000 character limit.....", 660 | step_log.observations.clone().unwrap_or_default().join("\n").trim().chars().take(30000).collect::() 661 | ); 662 | Ok(None) 663 | } 664 | _ => { 665 | todo!() 666 | } 667 | } 668 | } 669 | } 670 | 671 | #[cfg(feature = "code-agent")] 672 | pub struct CodeAgent { 673 | base_agent: MultiStepAgent, 674 | local_python_interpreter: LocalPythonInterpreter, 675 | } 676 | 677 | #[cfg(feature = "code-agent")] 678 | impl CodeAgent { 679 | pub fn new( 680 | model: M, 681 | tools: Vec>, 682 | system_prompt: Option<&str>, 683 | managed_agents: Option>>, 684 | description: Option<&str>, 685 | max_steps: Option, 686 | ) -> Result { 687 | let system_prompt = system_prompt.unwrap_or(CODE_SYSTEM_PROMPT); 688 | 689 | let base_agent = MultiStepAgent::new( 690 | model, 691 | tools, 692 | Some(system_prompt), 693 | managed_agents, 694 | description, 695 | max_steps, 696 | )?; 697 | let local_python_interpreter = LocalPythonInterpreter::new( 698 | base_agent 699 | .tools 700 | .iter() 701 | .map(|tool| tool.clone_box()) 702 | .collect(), 703 | ); 704 | 705 | Ok(Self { 706 | base_agent, 707 | local_python_interpreter, 708 | }) 709 | } 710 | } 711 | 712 | #[cfg(feature = "code-agent")] 713 | impl Agent for CodeAgent { 714 | fn name(&self) -> &'static str { 715 | self.base_agent.name() 716 | } 717 | fn get_max_steps(&self) -> usize { 718 | self.base_agent.get_max_steps() 719 | } 720 | fn get_step_number(&self) -> usize { 721 | self.base_agent.get_step_number() 722 | } 723 | fn increment_step_number(&mut self) { 724 | self.base_agent.increment_step_number() 725 | } 726 | fn get_logs_mut(&mut self) -> &mut Vec { 727 | self.base_agent.get_logs_mut() 728 | } 729 | fn reset_step_number(&mut self) { 730 | self.base_agent.reset_step_number() 731 | } 732 | fn set_task(&mut self, task: &str) { 733 | 
self.base_agent.set_task(task); 734 | } 735 | fn get_system_prompt(&self) -> &str { 736 | self.base_agent.get_system_prompt() 737 | } 738 | fn model(&self) -> &dyn Model { 739 | self.base_agent.model() 740 | } 741 | fn step(&mut self, log_entry: &mut Step) -> Result> { 742 | match log_entry { 743 | Step::ActionStep(step_log) => { 744 | let agent_memory = self.base_agent.write_inner_memory_from_logs(None)?; 745 | self.base_agent.input_messages = Some(agent_memory.clone()); 746 | step_log.agent_memory = Some(agent_memory); 747 | 748 | let llm_output = self.base_agent.model.run( 749 | self.base_agent.input_messages.as_ref().unwrap().clone(), 750 | vec![], 751 | None, 752 | Some(HashMap::from([( 753 | "stop".to_string(), 754 | vec!["Observation:".to_string(), "".to_string()], 755 | )])), 756 | )?; 757 | 758 | let response = llm_output.get_response()?; 759 | step_log.llm_output = Some(response.clone()); 760 | 761 | let code = match parse_code_blobs(&response) { 762 | Ok(code) => code, 763 | Err(e) => { 764 | step_log.error = Some(e.clone()); 765 | info!("Error: {}", response + "\n" + &e.to_string()); 766 | return Ok(None); 767 | } 768 | }; 769 | 770 | info!("Code: {}", code); 771 | step_log.tool_call = Some(vec![ToolCall { 772 | id: None, 773 | call_type: Some("function".to_string()), 774 | function: FunctionCall { 775 | name: "python_interpreter".to_string(), 776 | arguments: serde_json::json!({ "code": code }), 777 | }, 778 | }]); 779 | let result = self.local_python_interpreter.forward(&code); 780 | match result { 781 | Ok(result) => { 782 | let (result, execution_logs) = result; 783 | let mut observation = if !execution_logs.is_empty() { 784 | format!("Execution logs: {}", execution_logs) 785 | } else { 786 | format!("Observation: {}", result) 787 | }; 788 | if observation.len() > 30000 { 789 | observation = observation.chars().take(30000).collect::(); 790 | observation = format!("{} \n....This content has been truncated due to the 30000 character limit.....", 
/// Extracts the contents of every fenced code block (with an optional `py` or
/// `python` tag) from `code_blob` and joins them with a blank line.
///
/// # Errors
/// Returns `AgentError::Parsing` when no fenced block is found; the message
/// differs when the text mentions both "final" and "answer". Returns
/// `AgentError::Execution` if the pattern fails to compile.
#[cfg(feature = "code-agent")]
pub fn parse_code_blobs(code_blob: &str) -> Result<String, AgentError> {
    let fence = Regex::new(r"```(?:py|python)?\n([\s\S]*?)\n```")
        .map_err(|e| AgentError::Execution(e.to_string()))?;

    let mut snippets: Vec<String> = Vec::new();
    for capture in fence.captures_iter(code_blob) {
        snippets.push(capture[1].trim().to_string());
    }

    if !snippets.is_empty() {
        return Ok(snippets.join("\n\n"));
    }

    // No fenced block: pick the hint that matches what the model seems to have
    // attempted.
    let mentions_final_answer = code_blob.contains("final") && code_blob.contains("answer");
    let hint = if mentions_final_answer {
        "The code blob is invalid. It seems like you're trying to return the final answer. Use:\n\
         Code:\n\
         ```py\n\
         final_answer(\"YOUR FINAL ANSWER HERE\")\n\
         ```"
    } else {
        "The code blob is invalid. Make sure to include code with the correct pattern, for instance:\n\
         Thoughts: Your thoughts\n\
         Code:\n\
         ```py\n\
         # Your python code here\n\
         ```"
    };
    Err(AgentError::Parsing(hint.to_string()))
}
self) -> &mut Vec { 56 | match self { 57 | AgentWrapper::FunctionCalling(agent) => agent.get_logs_mut(), 58 | AgentWrapper::Code(agent) => agent.get_logs_mut(), 59 | } 60 | } 61 | } 62 | impl Model for ModelWrapper { 63 | fn run( 64 | &self, 65 | messages: Vec, 66 | tools: Vec, 67 | max_tokens: Option, 68 | args: Option>>, 69 | ) -> Result, AgentError> { 70 | match self { 71 | ModelWrapper::OpenAI(m) => Ok(m.run(messages, tools, max_tokens, args)?), 72 | ModelWrapper::Ollama(m) => Ok(m.run(messages, tools, max_tokens, args)?), 73 | } 74 | } 75 | } 76 | 77 | #[derive(Parser, Debug)] 78 | #[command(author, version, about, long_about = None)] 79 | struct Args { 80 | /// The type of agent to use 81 | #[arg(short = 'a', long, value_enum, default_value = "function-calling")] 82 | agent_type: AgentType, 83 | 84 | /// List of tools to use 85 | #[arg(short = 'l', long = "tools", value_enum, num_args = 1.., value_delimiter = ',', default_values_t = [ToolType::DuckDuckGo, ToolType::VisitWebsite])] 86 | tools: Vec, 87 | 88 | /// The type of model to use 89 | #[arg(short = 'm', long, value_enum, default_value = "open-ai")] 90 | model_type: ModelType, 91 | 92 | /// OpenAI API key (only required for OpenAI model) 93 | #[arg(short = 'k', long)] 94 | api_key: Option, 95 | 96 | /// Model ID (e.g., "gpt-4" for OpenAI or "qwen2.5" for Ollama) 97 | #[arg(long, default_value = "gpt-4o-mini")] 98 | model_id: String, 99 | 100 | /// Whether to stream the output 101 | #[arg(short, long, default_value = "false")] 102 | stream: bool, 103 | 104 | /// Base URL for the API 105 | #[arg(short, long)] 106 | base_url: Option, 107 | } 108 | 109 | fn create_tool(tool_type: &ToolType) -> Box { 110 | match tool_type { 111 | ToolType::DuckDuckGo => Box::new(DuckDuckGoSearchTool::new()), 112 | ToolType::VisitWebsite => Box::new(VisitWebsiteTool::new()), 113 | ToolType::GoogleSearchTool => Box::new(GoogleSearchTool::new(None)), 114 | } 115 | } 116 | 117 | fn main() -> Result<()> { 118 | let args = 
Args::parse(); 119 | 120 | let tools: Vec> = args.tools.iter().map(create_tool).collect(); 121 | 122 | // Create model based on type 123 | let model = match args.model_type { 124 | ModelType::OpenAI => ModelWrapper::OpenAI(OpenAIServerModel::new( 125 | args.base_url.as_deref(), 126 | Some(&args.model_id), 127 | None, 128 | args.api_key, 129 | )), 130 | ModelType::Ollama => ModelWrapper::Ollama( 131 | OllamaModelBuilder::new() 132 | .model_id(&args.model_id) 133 | .ctx_length(8000) 134 | .build(), 135 | ), 136 | }; 137 | 138 | // Create agent based on type 139 | let mut agent = match args.agent_type { 140 | AgentType::FunctionCalling => AgentWrapper::FunctionCalling(FunctionCallingAgent::new( 141 | model, 142 | tools, 143 | None, 144 | None, 145 | Some("CLI Agent"), 146 | None, 147 | )?), 148 | AgentType::Code => AgentWrapper::Code(CodeAgent::new( 149 | model, 150 | tools, 151 | None, 152 | None, 153 | Some("CLI Agent"), 154 | None, 155 | )?), 156 | }; 157 | 158 | let mut file: File = File::create("logs.txt")?; 159 | 160 | loop { 161 | print!("{}", "User: ".yellow().bold()); 162 | io::stdout().flush()?; 163 | 164 | let mut task = String::new(); 165 | io::stdin().read_line(&mut task)?; 166 | let task = task.trim(); 167 | 168 | // Exit if user enters empty line or Ctrl+D 169 | if task.is_empty() { 170 | println!("Enter a task to execute"); 171 | continue; 172 | } 173 | if task == "exit" { 174 | break; 175 | } 176 | 177 | // Run the agent with the task from stdin 178 | let _result = agent.run(task, args.stream, true)?; 179 | // Get the last log entry and serialize it in a controlled way 180 | 181 | let logs = agent.get_logs_mut(); 182 | for log in logs { 183 | // Serialize to JSON with pretty printing 184 | serde_json::to_writer_pretty(&mut file, &log)?; 185 | } 186 | } 187 | Ok(()) 188 | } 189 | -------------------------------------------------------------------------------- /src/errors.rs: 
/// Errors produced while interpreting code.
///
/// `FinalAnswer` carries the answer text rather than a failure description.
#[derive(Debug, PartialEq)]
pub enum InterpreterError {
    SyntaxError(String),
    RuntimeError(String),
    FinalAnswer(String),
    OperationLimitExceeded,
    UnauthorizedImport(String),
    UnsupportedOperation(String),
}

impl fmt::Display for InterpreterError {
    /// Writes a short label followed by the payload of the variant.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            InterpreterError::SyntaxError(msg) => write!(f, "Syntax Error: {}", msg),
            InterpreterError::RuntimeError(msg) => write!(f, "Runtime Error: {}", msg),
            InterpreterError::FinalAnswer(msg) => write!(f, "Final Answer: {}", msg),
            InterpreterError::OperationLimitExceeded => write!(
                f,
                "Operation limit exceeded. Possible infinite loop detected."
            ),
            InterpreterError::UnauthorizedImport(module) => {
                write!(f, "Unauthorized import of module: {}", module)
            }
            InterpreterError::UnsupportedOperation(op) => {
                write!(f, "Unsupported operation: {}", op)
            }
        }
    }
}

// Mirrors `AgentError`: implementing the standard error trait lets this type
// be boxed as `dyn Error` and propagated with `?` into generic error types.
impl std::error::Error for InterpreterError {}
Box::new(DuckDuckGoSearchTool::new()), 35 | //! Box::new(VisitWebsiteTool::new()), 36 | //! ]; 37 | //! let model = OpenAIServerModel::new(Some("https://api.openai.com/v1/chat/completions"), Some("gpt-4o-mini"), None, None); 38 | //! let mut agent = CodeAgent::new(model, tools, None, None, None, None).unwrap(); 39 | //! let _result = agent 40 | //! .run("Who has the most followers on Twitter?", false, true) 41 | //! .unwrap(); 42 | 43 | //! ``` 44 | pub mod agents; 45 | pub mod errors; 46 | 47 | #[cfg(feature = "code-agent")] 48 | pub mod local_python_interpreter; 49 | pub(crate) mod logger; 50 | pub mod models; 51 | pub mod prompts; 52 | pub mod tools; 53 | 54 | pub use agents::*; 55 | -------------------------------------------------------------------------------- /src/local_python_interpreter.rs: -------------------------------------------------------------------------------- 1 | use crate::errors::InterpreterError; 2 | use crate::tools::AnyTool; 3 | use anyhow::Result; 4 | use pyo3::prelude::*; 5 | use pyo3::types::{PyDict, PyModule, PyTuple}; 6 | use rustpython_parser::{ 7 | ast::{ 8 | self, 9 | bigint::{BigInt, Sign}, 10 | Constant, Expr, Operator, Stmt, UnaryOp, 11 | }, 12 | Parse, 13 | }; 14 | use serde_json::{self, json}; 15 | use std::{any::Any, collections::HashMap}; 16 | 17 | pub fn get_base_python_tools() -> HashMap<&'static str, &'static str> { 18 | [ 19 | ("print", "custom_print"), 20 | ("isinstance", "isinstance"), 21 | ("range", "range"), 22 | ("float", "float"), 23 | ("int", "int"), 24 | ("bool", "bool"), 25 | ("str", "str"), 26 | ("set", "set"), 27 | ("list", "list"), 28 | ("dict", "dict"), 29 | ("tuple", "tuple"), 30 | ("round", "round"), 31 | ("ceil", "math.ceil"), 32 | ("floor", "math.floor"), 33 | ("log", "math.log"), 34 | ("exp", "math.exp"), 35 | ("sin", "math.sin"), 36 | ("cos", "math.cos"), 37 | ("tan", "math.tan"), 38 | ("asin", "math.asin"), 39 | ("acos", "math.acos"), 40 | ("atan", "math.atan"), 41 | ("atan2", "math.atan2"), 42 | 
("degrees", "math.degrees"), 43 | ("radians", "math.radians"), 44 | ("pow", "math.pow"), 45 | ("sqrt", "math.sqrt"), 46 | ("len", "len"), 47 | ("sum", "sum"), 48 | ("max", "max"), 49 | ("min", "min"), 50 | ("abs", "abs"), 51 | ("enumerate", "enumerate"), 52 | ("zip", "zip"), 53 | ("reversed", "reversed"), 54 | ("sorted", "sorted"), 55 | ("all", "all"), 56 | ("any", "any"), 57 | ("map", "map"), 58 | ("filter", "filter"), 59 | ("ord", "ord"), 60 | ("chr", "chr"), 61 | ("next", "next"), 62 | ("iter", "iter"), 63 | ("divmod", "divmod"), 64 | ("callable", "callable"), 65 | ("getattr", "getattr"), 66 | ("hasattr", "hasattr"), 67 | ("setattr", "setattr"), 68 | ("issubclass", "issubclass"), 69 | ("type", "type"), 70 | ("complex", "complex"), 71 | ] 72 | .iter() 73 | .cloned() 74 | .collect() 75 | } 76 | 77 | impl From for InterpreterError { 78 | fn from(err: PyErr) -> Self { 79 | InterpreterError::RuntimeError(err.to_string()) 80 | } 81 | } 82 | 83 | #[derive(Clone, Debug)] 84 | pub enum CustomConstant { 85 | Int(BigInt), 86 | Float(f64), 87 | Str(String), 88 | Bool(bool), 89 | Tuple(Vec), 90 | PyObj(PyObject), 91 | Dict(Vec, Vec), 92 | } 93 | 94 | impl CustomConstant { 95 | pub fn float(&self) -> Option { 96 | match self { 97 | CustomConstant::Float(f) => Some(*f), 98 | _ => None, 99 | } 100 | } 101 | pub fn str(&self) -> String { 102 | match self { 103 | CustomConstant::Str(s) => s.clone(), 104 | CustomConstant::Float(f) => f.to_string(), 105 | CustomConstant::Int(i) => i.to_string(), 106 | CustomConstant::Tuple(t) => { 107 | let mut result = String::new(); 108 | result.push('['); 109 | for (i, item) in t.iter().enumerate() { 110 | if i > 0 { 111 | result.push_str(", "); 112 | } 113 | result.push_str(&item.str()); 114 | } 115 | result.push(']'); 116 | result 117 | } 118 | CustomConstant::Dict(keys, values) => { 119 | let mut result = String::new(); 120 | result.push('{'); 121 | for (i, key) in keys.iter().enumerate() { 122 | if i > 0 { 123 | result.push_str(", "); 124 | 
} 125 | result.push_str(&format!("'{}': {}", key, values[i].str())); 126 | } 127 | result.push('}'); 128 | 129 | for (i, item) in values.iter().enumerate() { 130 | if i > 0 { 131 | result.push_str(", "); 132 | } 133 | result.push_str(&item.str()); 134 | } 135 | result.push('}'); 136 | result 137 | } 138 | CustomConstant::PyObj(obj) => obj.to_string(), 139 | CustomConstant::Bool(b) => b.to_string(), 140 | } 141 | } 142 | pub fn tuple(&self) -> Option> { 143 | match self { 144 | CustomConstant::Tuple(t) => Some(t.clone()), 145 | _ => None, 146 | } 147 | } 148 | } 149 | 150 | impl From for Constant { 151 | fn from(custom: CustomConstant) -> Self { 152 | match custom { 153 | CustomConstant::Int(i) => Constant::Int(i), 154 | CustomConstant::Float(f) => Constant::Float(f), 155 | CustomConstant::Str(s) => Constant::Str(s), 156 | CustomConstant::Bool(b) => Constant::Bool(b), 157 | CustomConstant::PyObj(obj) => Constant::Str(obj.to_string()), 158 | CustomConstant::Tuple(t) => { 159 | let tuple_items = t 160 | .iter() 161 | .map(|c| Constant::from(c.clone())) 162 | .collect::>(); 163 | Constant::Tuple(tuple_items) 164 | } 165 | CustomConstant::Dict(keys, values) => { 166 | let tuple_items = keys 167 | .iter() 168 | .zip(values.iter()) 169 | .map(|(k, v)| { 170 | Constant::Tuple(vec![Constant::Str(k.clone()), Constant::from(v.clone())]) 171 | }) 172 | .collect::>(); 173 | Constant::Tuple(tuple_items) 174 | } 175 | } 176 | } 177 | } 178 | 179 | impl From for CustomConstant { 180 | fn from(constant: Constant) -> Self { 181 | match constant { 182 | Constant::Int(i) => CustomConstant::Int(i), 183 | Constant::Float(f) => CustomConstant::Float(f), 184 | Constant::Str(s) => CustomConstant::Str(s), 185 | Constant::Bool(b) => CustomConstant::Bool(b), 186 | Constant::None => CustomConstant::Str("None".to_string()), 187 | Constant::Tuple(t) => { 188 | CustomConstant::Tuple(t.iter().map(|c| c.clone().into()).collect()) 189 | } 190 | _ => panic!("Unsupported constant type"), 191 | } 192 | 
} 193 | } 194 | 195 | impl IntoPy for CustomConstant { 196 | fn into_py(self, py: Python<'_>) -> PyObject { 197 | match self { 198 | CustomConstant::Int(i) => convert_bigint_to_i64(&i).into_py(py), 199 | CustomConstant::Float(f) => f.into_py(py), 200 | CustomConstant::Str(s) => s.into_py(py), 201 | CustomConstant::Bool(b) => b.into_py(py), 202 | CustomConstant::Tuple(t) => { 203 | let py_list = t 204 | .iter() 205 | .map(|x| x.clone().into_py(py)) 206 | .collect::>(); 207 | py_list.into_py(py) 208 | } 209 | CustomConstant::PyObj(obj) => obj, 210 | CustomConstant::Dict(keys, values) => { 211 | let dict = PyDict::new(py); 212 | for (key, value) in keys.iter().zip(values.iter()) { 213 | dict.set_item(key, value.clone().into_py(py)) 214 | .unwrap_or_default(); 215 | } 216 | dict.into_py(py) 217 | } 218 | } 219 | } 220 | } 221 | 222 | type ToolFunction = Box) -> Result>; 223 | type CustomToolFunction = 224 | Box, HashMap) -> Result>; 225 | 226 | fn setup_custom_tools(tools: Vec>) -> HashMap { 227 | let mut tools_map = HashMap::new(); 228 | for tool in tools { 229 | let tool_info = tool.tool_info(); 230 | tools_map.insert( 231 | tool.name().to_string(), 232 | Box::new( 233 | move |args: Vec, kwargs: HashMap| { 234 | //merge args and kwargs 235 | let tool_parameter_names = tool_info.get_parameter_names(); 236 | 237 | let mut new_args = HashMap::new(); 238 | for (i, arg) in args.iter().enumerate() { 239 | new_args 240 | .insert(tool_parameter_names[i].clone(), arg.clone().str().unwrap()); 241 | } 242 | for (key, value) in kwargs { 243 | new_args.insert(key, value); 244 | } 245 | match tool.forward_json(json!(new_args)) { 246 | Ok(results) => Ok(CustomConstant::Str(results)), 247 | Err(e) => Ok(CustomConstant::Str(format!("Error: {}", e))), 248 | } 249 | }, 250 | ) as CustomToolFunction, 251 | ); 252 | } 253 | tools_map 254 | } 255 | 256 | pub fn setup_static_tools( 257 | static_tools: HashMap<&'static str, &'static str>, 258 | ) -> HashMap { 259 | let mut tools = 
HashMap::new(); 260 | let static_tools_clone = static_tools.clone(); 261 | let eval_py = move |func: &str, args: Vec| { 262 | Python::with_gil(|py| { 263 | let locals = PyDict::new(py); 264 | 265 | // Import required modules 266 | let math = PyModule::import(py, "math")?; 267 | locals.set_item("math", math)?; 268 | 269 | for (i, arg) in args.iter().enumerate() { 270 | match arg { 271 | Constant::Float(f) => locals.set_item(format!("arg{}", i), f)?, 272 | Constant::Int(int) => { 273 | locals.set_item(format!("arg{}", i), convert_bigint_to_i64(int))?; 274 | } 275 | Constant::Str(s) => locals.set_item(format!("arg{}", i), s)?, 276 | Constant::Tuple(t) => { 277 | let py_list: Vec = t 278 | .iter() 279 | .map(|x| match x { 280 | Constant::Float(f) => *f, 281 | Constant::Int(i) => convert_bigint_to_f64(i), 282 | _ => 0.0, 283 | }) 284 | .collect(); 285 | locals.set_item(format!("arg{}", i), py_list)? 286 | } 287 | _ => locals.set_item(format!("arg{}", i), 0.0)?, 288 | } 289 | } 290 | 291 | let arg_names: Vec = (0..args.len()).map(|i| format!("arg{}", i)).collect(); 292 | let func_path = static_tools.get(func).unwrap_or(&"builtins.float"); 293 | let expr = format!("{}({})", func_path, arg_names.join(",")); 294 | 295 | let result = py.eval(&expr, None, Some(locals))?; 296 | // Handle different return types 297 | if let Ok(float_val) = result.extract::() { 298 | Ok(CustomConstant::Float(float_val)) 299 | } else if let Ok(list_val) = result.extract::>() { 300 | Ok(CustomConstant::Tuple( 301 | list_val.into_iter().map(CustomConstant::Str).collect(), 302 | )) 303 | } else if let Ok(string_val) = result.extract::() { 304 | Ok(CustomConstant::Str(string_val)) 305 | } else if let Ok(bool_val) = result.extract::() { 306 | Ok(CustomConstant::Bool(bool_val)) 307 | } else if let Ok(int_val) = result.extract::() { 308 | Ok(CustomConstant::Int(BigInt::from(int_val))) 309 | } else { 310 | Ok(CustomConstant::PyObj(result.into_py(py))) 311 | } 312 | }) 313 | }; 314 | 315 | // Register 
tools after eval_py is defined 316 | for func in static_tools_clone.keys() { 317 | let func = func.to_string(); // Create owned String 318 | let eval_py = eval_py.clone(); // Clone the closure 319 | tools.insert( 320 | func.clone(), 321 | Box::new(move |args| eval_py(&func, args)) as ToolFunction, 322 | ); 323 | } 324 | 325 | tools 326 | } 327 | 328 | fn evaluate_stmt( 329 | node: &ast::Stmt, 330 | state: &mut HashMap>, 331 | static_tools: &HashMap, 332 | custom_tools: &HashMap, 333 | ) -> Result { 334 | match node { 335 | Stmt::FunctionDef(func) => Ok(CustomConstant::Str(format!("Function: {:?}", func.name))), 336 | Stmt::Expr(expr) => { 337 | let result = evaluate_expr(&expr.value, state, static_tools, custom_tools)?; 338 | Ok(result) 339 | } 340 | Stmt::For(for_stmt) => { 341 | let iter = evaluate_expr(&for_stmt.iter.clone(), state, static_tools, custom_tools)?; 342 | // Convert PyObj iterator into a vector of values 343 | let values = match iter { 344 | CustomConstant::PyObj(obj) => { 345 | Python::with_gil(|py| -> Result, InterpreterError> { 346 | let iter = obj.as_ref(py).iter()?; 347 | let mut values = Vec::new(); 348 | 349 | for item in iter { 350 | let item = item?; 351 | if let Ok(num) = item.extract::() { 352 | values.push(CustomConstant::Int(BigInt::from(num))); 353 | } else if let Ok(float) = item.extract::() { 354 | values.push(CustomConstant::Float(float)); 355 | } else if let Ok(string) = item.extract::() { 356 | values.push(CustomConstant::Str(string)); 357 | } else { 358 | return Err(InterpreterError::RuntimeError( 359 | "Unsupported type in iterator".to_string(), 360 | )); 361 | } 362 | } 363 | Ok(values) 364 | })? 
365 | } 366 | CustomConstant::Tuple(items) => items, 367 | _ => { 368 | return Err(InterpreterError::RuntimeError( 369 | "Expected iterable".to_string(), 370 | )) 371 | } 372 | }; 373 | // Get the target variable name 374 | let target_name = match &*for_stmt.target { 375 | ast::Expr::Name(name) => name.id.to_string(), 376 | _ => { 377 | return Err(InterpreterError::RuntimeError( 378 | "Expected name as loop target".to_string(), 379 | )) 380 | } 381 | }; 382 | let mut for_loop_result = CustomConstant::Str(String::new()); 383 | // Iterate over the values and execute the body for each iteration 384 | for value in values { 385 | // Update the loop variable in the state 386 | state.insert(target_name.clone(), Box::new(value)); 387 | 388 | // Execute each statement in the loop body 389 | for stmt in &for_stmt.body { 390 | for_loop_result = evaluate_stmt(stmt, state, static_tools, custom_tools)?; 391 | } 392 | } 393 | Ok(for_loop_result) 394 | } 395 | 396 | Stmt::Assign(assign) => { 397 | for target in assign.targets.iter() { 398 | // let target = evaluate_expr(&Box::new(target.clone()), state, static_tools)?; 399 | match target { 400 | ast::Expr::Name(name) => { 401 | let value = 402 | evaluate_expr(&assign.value, state, static_tools, custom_tools)?; 403 | state.insert(name.id.to_string(), Box::new(value)); 404 | } 405 | ast::Expr::Tuple(target_names) => { 406 | let value = 407 | evaluate_expr(&assign.value, state, static_tools, custom_tools)?; 408 | let values = value.tuple().ok_or_else(|| { 409 | InterpreterError::RuntimeError( 410 | "Tuple unpacking failed. Expected values of type tuple".to_string(), 411 | ) 412 | })?; 413 | if target_names.elts.len() != values.len() { 414 | return Err(InterpreterError::RuntimeError(format!( 415 | "Tuple unpacking failed. 
Expected {} values, got {}", 416 | target_names.elts.len(), 417 | values.len() 418 | ))); 419 | } 420 | for (i, target_name) in target_names.elts.iter().enumerate() { 421 | match target_name { 422 | ast::Expr::Name(name) => { 423 | state.insert(name.id.to_string(), Box::new(values[i].clone())); 424 | } 425 | _ => panic!("Expected string"), 426 | } 427 | } 428 | } 429 | _ => panic!("Expected string"), 430 | } 431 | } 432 | Ok(CustomConstant::Str(String::new())) 433 | } 434 | 435 | _ => Err(InterpreterError::RuntimeError(format!( 436 | "Unsupported statement {:?}", 437 | node 438 | ))), 439 | } 440 | } 441 | 442 | fn evaluate_ast( 443 | ast: &ast::Suite, 444 | state: &mut HashMap>, 445 | static_tools: &HashMap, 446 | custom_tools: &HashMap, 447 | ) -> Result { 448 | let mut result = CustomConstant::Str(String::new()); 449 | for node in ast.iter() { 450 | result = evaluate_stmt(node, state, static_tools, custom_tools)?; 451 | } 452 | Ok(result) 453 | } 454 | 455 | fn convert_bigint_to_f64(i: &BigInt) -> f64 { 456 | let i = i.to_u32_digits(); 457 | let num = i.1.iter().fold(0i64, |acc, &d| acc * (1 << 32) + d as i64); 458 | match i.0 { 459 | Sign::Minus => -num as f64, 460 | Sign::NoSign | Sign::Plus => num as f64, 461 | } 462 | } 463 | fn convert_bigint_to_i64(i: &BigInt) -> i64 { 464 | let i = i.to_u32_digits(); 465 | let num = i.1.iter().fold(0i64, |acc, &d| acc * (1 << 32) + d as i64); 466 | match i.0 { 467 | Sign::Minus => -num, 468 | Sign::NoSign | Sign::Plus => num, 469 | } 470 | } 471 | 472 | type StaticTool = Box) -> Result>; 473 | type CustomTool = 474 | Box, HashMap) -> Result>; 475 | 476 | fn evaluate_expr( 477 | expr: &Expr, 478 | state: &mut HashMap>, 479 | static_tools: &HashMap, 480 | custom_tools: &HashMap, 481 | ) -> Result { 482 | match &expr { 483 | ast::Expr::Dict(dict) => { 484 | let keys = dict 485 | .keys 486 | .iter() 487 | .map(|e| { 488 | evaluate_expr( 489 | &Box::new(e.clone().ok_or_else(|| { 490 | InterpreterError::RuntimeError( 491 | 
"Dictionary key cannot be None".to_string(), 492 | ) 493 | })?), 494 | state, 495 | static_tools, 496 | custom_tools, 497 | ) 498 | .map(|c| c.str()) 499 | }) 500 | .collect::, _>>()?; 501 | let values = dict 502 | .values 503 | .iter() 504 | .map(|e| evaluate_expr(&Box::new(e.clone()), state, static_tools, custom_tools)) 505 | .collect::, _>>()?; 506 | Ok(CustomConstant::Dict(keys, values)) 507 | } 508 | ast::Expr::ListComp(list_comp) => { 509 | let iter = evaluate_expr( 510 | &list_comp.generators[0].iter, 511 | state, 512 | static_tools, 513 | custom_tools, 514 | )?; 515 | let result = Python::with_gil(|py| -> Result, InterpreterError> { 516 | let iter = iter.into_py(py); 517 | let iter = iter.as_ref(py).iter()?; 518 | let mut result = Vec::new(); 519 | for item in iter { 520 | let target = match &list_comp.generators[0].target { 521 | ast::Expr::Name(name) => name.id.to_string(), 522 | _ => panic!("Expected string"), 523 | }; 524 | let item = item?; 525 | let item = extract_constant_from_pyobject(item, py)?; 526 | state.insert(target, Box::new(item)); 527 | let eval_expr = 528 | evaluate_expr(&list_comp.elt, state, static_tools, custom_tools)?; 529 | result.push(eval_expr); 530 | } 531 | Ok(result) 532 | }); 533 | let result = result?; 534 | Ok(CustomConstant::Tuple(result)) 535 | } 536 | ast::Expr::Call(call) => { 537 | let args = call 538 | .args 539 | .iter() 540 | .map(|e| evaluate_expr(&Box::new(e.clone()), state, static_tools, custom_tools)) 541 | .collect::, InterpreterError>>()?; 542 | let func = match &*call.func { 543 | ast::Expr::Name(name) => name.id.to_string(), 544 | ast::Expr::Attribute(attr) => { 545 | let obj = evaluate_expr( 546 | &Box::new(*attr.value.clone()), 547 | state, 548 | static_tools, 549 | custom_tools, 550 | )?; 551 | 552 | let func_name = attr.attr.to_string(); 553 | let output = 554 | Python::with_gil(|py| -> Result { 555 | let obj = obj.into_py(py); 556 | let func = obj.getattr(py, func_name.as_str())?; 557 | let py_args = args 
558 | .iter() 559 | .map(|a| match a { 560 | // Convert numeric types to strings when calling string methods 561 | CustomConstant::Float(f) => f.into_py(py), 562 | CustomConstant::Int(i) => convert_bigint_to_i64(i).into_py(py), 563 | _ => a.clone().into_py(py), 564 | }) 565 | .collect::>(); 566 | let py_tuple = PyTuple::new(py, py_args); 567 | let result = func.call1(py, py_tuple)?; 568 | 569 | // For methods that modify in place (like append), return the original object 570 | if func_name == "append" 571 | || func_name == "extend" 572 | || func_name == "insert" 573 | { 574 | let target = match &*attr.value { 575 | ast::Expr::Name(name) => name.id.to_string(), 576 | _ => panic!("Expected name"), 577 | }; 578 | let out = extract_constant_from_pyobject(obj.as_ref(py), py)?; 579 | state.insert(target, Box::new(out.clone())); 580 | return Ok(out); 581 | } 582 | 583 | extract_constant_from_pyobject(result.as_ref(py), py) 584 | }); 585 | return output; 586 | } 587 | _ => panic!("Expected function name"), 588 | }; 589 | 590 | let keywords = call 591 | .keywords 592 | .iter() 593 | .map(|k| { 594 | let value = evaluate_expr( 595 | &Box::new(k.value.clone()), 596 | state, 597 | static_tools, 598 | custom_tools, 599 | )?; 600 | Ok((k.arg.as_ref().unwrap().to_string(), value.str())) 601 | }) 602 | .collect::, InterpreterError>>()?; 603 | if func == "final_answer" { 604 | if let Some(answer) = keywords.get("answer") { 605 | return Err(InterpreterError::FinalAnswer(answer.to_string())); 606 | } else { 607 | return Err(InterpreterError::FinalAnswer( 608 | args.iter() 609 | .map(|c| c.str()) 610 | .collect::>() 611 | .join(" "), 612 | )); 613 | } 614 | } 615 | if func == "print" { 616 | match state.get_mut("print_logs") { 617 | Some(logs) => { 618 | if let Some(logs) = logs.downcast_mut::>() { 619 | logs.push( 620 | args.iter() 621 | .map(|c| c.str()) 622 | .collect::>() 623 | .join(" "), 624 | ); 625 | } else { 626 | return Err(InterpreterError::RuntimeError( 627 | "print_logs 
is not a list".to_string(), 628 | )); 629 | } 630 | } 631 | None => { 632 | state.insert( 633 | "print_logs".to_string(), 634 | Box::new(args.iter().map(|c| c.str()).collect::>()), 635 | ); 636 | } 637 | } 638 | return Ok(CustomConstant::Str( 639 | args.iter() 640 | .map(|c| c.str()) 641 | .collect::>() 642 | .join(" "), 643 | )); 644 | } 645 | if static_tools.contains_key(&func) { 646 | let result = 647 | static_tools[&func](args.iter().map(|c| Constant::from(c.clone())).collect()); 648 | result 649 | } else if custom_tools.contains_key(&func) { 650 | let result = custom_tools[&func]( 651 | args.iter().map(|c| Constant::from(c.clone())).collect(), 652 | keywords, 653 | ); 654 | result 655 | } else { 656 | Err(InterpreterError::RuntimeError(format!( 657 | "Function '{}' not found", 658 | func 659 | ))) 660 | } 661 | } 662 | ast::Expr::BinOp(binop) => { 663 | let left_val_exp = 664 | evaluate_expr(&binop.left.clone(), state, static_tools, custom_tools)?; 665 | let right_val_exp: CustomConstant = 666 | evaluate_expr(&binop.right.clone(), state, static_tools, custom_tools)?; 667 | 668 | match binop.op { 669 | Operator::Add => match (left_val_exp.clone(), right_val_exp.clone()) { 670 | (CustomConstant::Str(s), CustomConstant::Str(s2)) => { 671 | return Ok(CustomConstant::Str(s + &s2)); 672 | } 673 | (CustomConstant::Str(s), CustomConstant::Int(i)) => { 674 | return Ok(CustomConstant::Str(s + &i.to_string())); 675 | } 676 | (CustomConstant::Int(i), CustomConstant::Str(s)) => { 677 | return Ok(CustomConstant::Str(i.to_string() + &s)); 678 | } 679 | _ => {} 680 | }, 681 | Operator::Mult => match (left_val_exp.clone(), right_val_exp.clone()) { 682 | (CustomConstant::Str(s), CustomConstant::Int(i)) => { 683 | return Ok(CustomConstant::Str( 684 | s.repeat(convert_bigint_to_i64(&i) as usize), 685 | )); 686 | } 687 | (CustomConstant::Int(i), CustomConstant::Str(s)) => { 688 | return Ok(CustomConstant::Str( 689 | s.repeat(convert_bigint_to_i64(&i) as usize), 690 | )); 691 | } 
692 | _ => {} 693 | }, 694 | _ => {} 695 | } 696 | let left_val = match left_val_exp.clone() { 697 | CustomConstant::Float(f) => f, 698 | CustomConstant::Int(i) => convert_bigint_to_f64(&i), 699 | _ => panic!("Expected float or int"), 700 | }; 701 | let right_val = match right_val_exp.clone() { 702 | CustomConstant::Float(f) => f, 703 | CustomConstant::Int(i) => convert_bigint_to_f64(&i), 704 | _ => panic!("Expected float or int"), 705 | }; 706 | 707 | match &binop.op { 708 | Operator::Add => Ok(CustomConstant::Float(left_val + right_val)), 709 | Operator::Sub => Ok(CustomConstant::Float(left_val - right_val)), 710 | Operator::Mult => Ok(CustomConstant::Float(left_val * right_val)), 711 | Operator::Div => Ok(CustomConstant::Float(left_val / right_val)), 712 | Operator::FloorDiv => Ok(CustomConstant::Float(left_val / right_val)), 713 | Operator::Mod => Ok(CustomConstant::Float(left_val % right_val)), 714 | Operator::Pow => Ok(CustomConstant::Float(left_val.powf(right_val))), 715 | Operator::BitOr => Ok(CustomConstant::Int(BigInt::from( 716 | left_val as i64 | right_val as i64, 717 | ))), 718 | Operator::BitXor => Ok(CustomConstant::Int(BigInt::from( 719 | left_val as i64 ^ right_val as i64, 720 | ))), 721 | Operator::BitAnd => Ok(CustomConstant::Int(BigInt::from( 722 | left_val as i64 & right_val as i64, 723 | ))), 724 | Operator::LShift => { 725 | let left_val = left_val as i64; 726 | let right_val = right_val as i64; 727 | Ok(CustomConstant::Int(BigInt::from(left_val << right_val))) 728 | } 729 | Operator::RShift => { 730 | let left_val = left_val as i64; 731 | let right_val = right_val as i64; 732 | Ok(CustomConstant::Int(BigInt::from(left_val >> right_val))) 733 | } 734 | Operator::MatMult => Ok(CustomConstant::Float(left_val * right_val)), 735 | } 736 | } 737 | ast::Expr::UnaryOp(unaryop) => { 738 | let operand = evaluate_expr(&unaryop.operand, state, static_tools, custom_tools)?; 739 | match &unaryop.op { 740 | UnaryOp::USub => match operand { 741 | 
CustomConstant::Float(f) => Ok(CustomConstant::Float(-f)), 742 | CustomConstant::Int(i) => Ok(CustomConstant::Int(-i)), 743 | _ => panic!("Expected float or int"), 744 | }, 745 | UnaryOp::UAdd => Ok(operand), 746 | UnaryOp::Not => { 747 | if let CustomConstant::Bool(b) = operand { 748 | Ok(CustomConstant::Bool(!b)) 749 | } else { 750 | panic!("Expected boolean") 751 | } 752 | } 753 | UnaryOp::Invert => { 754 | if let CustomConstant::Float(f) = operand { 755 | Ok(CustomConstant::Float(-(f as i64) as f64)) 756 | } else { 757 | panic!("Expected float") 758 | } 759 | } 760 | } 761 | } 762 | ast::Expr::Constant(constant) => match &constant.value { 763 | Constant::Int(i) => Ok(CustomConstant::Int(i.clone())), 764 | _ => Ok(constant.value.clone().into()), 765 | }, 766 | ast::Expr::List(list) => Ok(CustomConstant::Tuple( 767 | list.elts 768 | .iter() 769 | .map(|e| evaluate_expr(&Box::new(e.clone()), state, static_tools, custom_tools)) 770 | .collect::, _>>()?, 771 | )), 772 | ast::Expr::Name(name) => { 773 | if let Some(value) = state.get(name.id.as_str()) { 774 | if let Some(constant) = value.downcast_ref::() { 775 | Ok(constant.clone()) 776 | } else { 777 | Err(InterpreterError::RuntimeError(format!( 778 | "Error in downcasting constant {}", 779 | name.id 780 | ))) 781 | } 782 | } else { 783 | Err(InterpreterError::RuntimeError(format!( 784 | "Variable '{}' used before assignment", 785 | name.id 786 | ))) 787 | } 788 | } 789 | ast::Expr::Tuple(tuple) => Ok(CustomConstant::Tuple( 790 | tuple 791 | .elts 792 | .iter() 793 | .map(|e| evaluate_expr(&Box::new(e.clone()), state, static_tools, custom_tools)) 794 | .collect::, _>>()?, 795 | )), 796 | ast::Expr::JoinedStr(joinedstr) => Ok(CustomConstant::Str( 797 | joinedstr 798 | .values 799 | .iter() 800 | .map(|e| { 801 | evaluate_expr(&Box::new(e.clone()), state, static_tools, custom_tools) 802 | .map(|result| result.str()) 803 | }) 804 | .collect::, _>>()? 
805 | .join(""), 806 | )), 807 | ast::Expr::FormattedValue(formattedvalue) => { 808 | let result = evaluate_expr(&formattedvalue.value, state, static_tools, custom_tools)?; 809 | 810 | Ok(CustomConstant::Str(result.str())) 811 | } 812 | ast::Expr::Subscript(subscript) => { 813 | let result = Python::with_gil(|py| { 814 | // Get the value being subscripted (e.g., the list/string) 815 | let value = evaluate_expr(&subscript.value, state, static_tools, custom_tools)?; 816 | let value_obj = value.into_py(py); 817 | 818 | let slice = Constant::from(evaluate_expr( 819 | &subscript.slice, 820 | state, 821 | static_tools, 822 | custom_tools, 823 | )?); 824 | 825 | // Handle integer indices for lists/sequences 826 | if let Constant::Int(i) = slice { 827 | let index = convert_bigint_to_i64(&i); 828 | let result = value_obj.as_ref(py).get_item(index); 829 | match result { 830 | Ok(result) => return extract_constant_from_pyobject(result, py), 831 | Err(e) => return Err(InterpreterError::RuntimeError(e.to_string())), 832 | } 833 | } 834 | 835 | // Handle string keys for dictionaries 836 | if let Constant::Str(s) = slice { 837 | // Try to extract as dictionary first 838 | if let Ok(dict) = value_obj.as_ref(py).downcast::() { 839 | let result = dict.get_item(s.clone()); 840 | match result { 841 | Some(value) => return extract_constant_from_pyobject(value, py), 842 | None => { 843 | return Err(InterpreterError::RuntimeError(format!( 844 | "KeyError: '{}'", 845 | s 846 | ))) 847 | } 848 | } 849 | } 850 | } 851 | 852 | // Handle both simple indexing and slicing 853 | let result = match &*subscript.slice { 854 | // For slice operations like num[1:3:2] 855 | ast::Expr::Slice(slice) => { 856 | let start = match &slice.lower { 857 | Some(lower) => { 858 | evaluate_expr(lower, state, static_tools, custom_tools)?.into() 859 | } 860 | None => None, 861 | }; 862 | let start = start 863 | .map(|start| { 864 | let constant = Constant::from(start); 865 | constant 866 | .int() 867 | .map(|i| 
convert_bigint_to_i64(&i)) 868 | .ok_or_else(|| { 869 | InterpreterError::RuntimeError( 870 | "Invalid start value in slice".to_string(), 871 | ) 872 | }) 873 | }) 874 | .transpose()?; 875 | 876 | let stop = match &slice.upper { 877 | Some(upper) => { 878 | evaluate_expr(upper, state, static_tools, custom_tools)?.into() 879 | } 880 | None => None, 881 | }; 882 | let stop = stop 883 | .map(|stop| { 884 | let constant = Constant::from(stop); 885 | constant 886 | .int() 887 | .map(|i| convert_bigint_to_i64(&i)) 888 | .ok_or_else(|| { 889 | InterpreterError::RuntimeError( 890 | "Invalid stop value in slice".to_string(), 891 | ) 892 | }) 893 | }) 894 | .transpose()?; 895 | 896 | let step = match &slice.step { 897 | Some(step) => { 898 | evaluate_expr(step, state, static_tools, custom_tools)?.into() 899 | } 900 | None => None, 901 | }; 902 | let step = step 903 | .map(|step| { 904 | let constant = Constant::from(step); 905 | constant 906 | .int() 907 | .map(|i| convert_bigint_to_i64(&i)) 908 | .ok_or_else(|| { 909 | InterpreterError::RuntimeError( 910 | "Invalid step value in slice".to_string(), 911 | ) 912 | }) 913 | }) 914 | .transpose()?; 915 | 916 | let slice_obj = py 917 | .eval("slice", None, None)? 918 | .call1((start, stop, step))? 919 | .into_py(py); 920 | value_obj.as_ref(py).get_item(slice_obj)? 
921 | } 922 | _ => return Err(InterpreterError::RuntimeError("Invalid slice".to_string())), 923 | }; 924 | 925 | // Convert the result back to our CustomConstant type 926 | extract_constant_from_pyobject(result, py) 927 | }); 928 | result 929 | } 930 | ast::Expr::Slice(slice) => { 931 | let start = match &slice.lower { 932 | Some(lower) => evaluate_expr(lower, state, static_tools, custom_tools)?, 933 | None => CustomConstant::Int(BigInt::from(0)), 934 | }; 935 | let end = match &slice.upper { 936 | Some(upper) => evaluate_expr(upper, state, static_tools, custom_tools)?, 937 | None => CustomConstant::Int(BigInt::from(0)), 938 | }; 939 | let step = match &slice.step { 940 | Some(step) => evaluate_expr(step, state, static_tools, custom_tools)?, 941 | None => CustomConstant::Int(BigInt::from(1)), 942 | }; 943 | Ok(CustomConstant::Tuple(vec![start, end, step])) 944 | } 945 | _ => { 946 | panic!("Unsupported expression: {:?}", expr); 947 | } 948 | } 949 | } 950 | 951 | fn extract_constant_from_pyobject( 952 | obj: &PyAny, 953 | py: Python<'_>, 954 | ) -> Result { 955 | if let Ok(float_val) = obj.extract::() { 956 | Ok(CustomConstant::Float(float_val)) 957 | } else if let Ok(string_val) = obj.extract::() { 958 | Ok(CustomConstant::Str(string_val)) 959 | } else if let Ok(bool_val) = obj.extract::() { 960 | Ok(CustomConstant::Bool(bool_val)) 961 | } else if let Ok(int_val) = obj.extract::() { 962 | Ok(CustomConstant::Int(BigInt::from(int_val))) 963 | } else if let Ok(list_val) = obj.extract::>() { 964 | Ok(CustomConstant::Tuple( 965 | list_val.into_iter().map(CustomConstant::Str).collect(), 966 | )) 967 | } else if let Ok(list_val) = obj.extract::>() { 968 | Ok(CustomConstant::Tuple( 969 | list_val 970 | .into_iter() 971 | .map(|i| CustomConstant::Int(BigInt::from(i))) 972 | .collect(), 973 | )) 974 | } else if let Ok(list_val) = obj.extract::>() { 975 | Ok(CustomConstant::Tuple( 976 | list_val.into_iter().map(CustomConstant::Float).collect(), 977 | )) 978 | } else if let 
Ok(dict_value) = obj.extract::<&PyDict>() { 979 | let keys = dict_value 980 | .keys() 981 | .iter() 982 | .map(|key| key.extract::()) 983 | .collect::, _>>()?; 984 | let values = dict_value 985 | .values() 986 | .iter() 987 | .map(|value| extract_constant_from_pyobject(value, py)) 988 | .collect::, _>>()?; 989 | Ok(CustomConstant::Dict(keys, values)) 990 | } else { 991 | Ok(CustomConstant::PyObj(obj.into_py(py))) 992 | } 993 | } 994 | pub fn evaluate_python_code( 995 | code: &str, 996 | custom_tools: Vec>, 997 | state: &mut HashMap>, 998 | ) -> Result { 999 | let base_tools = get_base_python_tools(); 1000 | let static_tools = setup_static_tools(base_tools); 1001 | let custom_tools = setup_custom_tools(custom_tools); 1002 | let ast = ast::Suite::parse(code, "") 1003 | .map_err(|e| InterpreterError::SyntaxError(e.to_string()))?; 1004 | 1005 | let result = evaluate_ast(&ast, state, &static_tools, &custom_tools)?; 1006 | Ok(result.str()) 1007 | } 1008 | 1009 | pub struct LocalPythonInterpreter { 1010 | static_tools: HashMap, 1011 | custom_tools: HashMap, 1012 | state: HashMap>, 1013 | } 1014 | 1015 | impl LocalPythonInterpreter { 1016 | pub fn new(custom_tools: Vec>) -> Self { 1017 | let custom_tools = setup_custom_tools(custom_tools); 1018 | let base_tools = get_base_python_tools(); 1019 | let static_tools = setup_static_tools(base_tools); 1020 | Self { 1021 | static_tools, 1022 | custom_tools, 1023 | state: HashMap::new(), 1024 | } 1025 | } 1026 | pub fn forward(&mut self, code: &str) -> Result<(String, String), InterpreterError> { 1027 | let ast = ast::Suite::parse(code, "") 1028 | .map_err(|e| InterpreterError::SyntaxError(e.to_string()))?; 1029 | let state = &mut self.state; 1030 | let result = evaluate_ast(&ast, state, &self.static_tools, &self.custom_tools)?; 1031 | 1032 | let mut empty_string = Vec::new(); 1033 | let execution_logs = state 1034 | .get_mut("print_logs") 1035 | .and_then(|logs| logs.downcast_mut::>()) 1036 | .unwrap_or(&mut empty_string) 1037 | 
.join("\n"); 1038 | Ok((result.str(), execution_logs)) 1039 | } 1040 | } 1041 | #[cfg(test)] 1042 | mod tests { 1043 | use super::*; 1044 | use crate::tools::{DuckDuckGoSearchTool, FinalAnswerTool, VisitWebsiteTool}; 1045 | use std::collections::HashMap; 1046 | 1047 | #[test] 1048 | fn test_evaluate_python_code() { 1049 | let code = "print('Hello, world!')"; 1050 | let mut state = HashMap::new(); 1051 | let result = evaluate_python_code(code, vec![], &mut state).unwrap(); 1052 | assert_eq!(result, "Hello, world!"); 1053 | } 1054 | 1055 | #[test] 1056 | fn test_evaluate_python_code_with_joined_str() { 1057 | let code = r#"word = 'strawberry' 1058 | r_count = word.count('r') 1059 | print(f"The letter 'r' appears {r_count} times in the word '{word}'.")"#; 1060 | let mut state = HashMap::new(); 1061 | let result = evaluate_python_code(code, vec![], &mut state).unwrap(); 1062 | assert_eq!( 1063 | result, 1064 | "The letter 'r' appears 3 times in the word 'strawberry'." 1065 | ); 1066 | } 1067 | 1068 | #[test] 1069 | fn test_final_answer_execution() { 1070 | let tools: Vec> = vec![Box::new(FinalAnswerTool::new())]; 1071 | let mut state = HashMap::new(); 1072 | let result = 1073 | evaluate_python_code("final_answer(answer='Hello, world!')", tools, &mut state); 1074 | assert_eq!( 1075 | result, 1076 | Err(InterpreterError::FinalAnswer("Hello, world!".to_string())) 1077 | ); 1078 | } 1079 | 1080 | #[test] 1081 | fn test_evaluate_python_code_with_subscript() { 1082 | let code = textwrap::dedent( 1083 | r#" 1084 | word = 'strawberry' 1085 | print(word[3])"#, 1086 | ); 1087 | let mut state = HashMap::new(); 1088 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1089 | assert_eq!(result, "a"); 1090 | 1091 | let code = textwrap::dedent( 1092 | r#" 1093 | word = 'strawberry' 1094 | print(word[-3])"#, 1095 | ); 1096 | let mut state = HashMap::new(); 1097 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1098 | assert_eq!(result, "r"); 
1099 | 1100 | let code = textwrap::dedent( 1101 | r#" 1102 | word = 'strawberry' 1103 | print(word[9])"#, 1104 | ); 1105 | let mut state = HashMap::new(); 1106 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1107 | assert_eq!(result, "y"); 1108 | 1109 | let code = textwrap::dedent( 1110 | r#" 1111 | word = 'strawberry' 1112 | print(word[10])"#, 1113 | ); 1114 | let mut state = HashMap::new(); 1115 | let result = evaluate_python_code(&code, vec![], &mut state); 1116 | assert_eq!( 1117 | result, 1118 | Err(InterpreterError::RuntimeError( 1119 | "IndexError: string index out of range".to_string() 1120 | )) 1121 | ); 1122 | 1123 | let code = textwrap::dedent( 1124 | r#" 1125 | numbers = [1, 2, 3, 4, 5] 1126 | print(numbers[1])"#, 1127 | ); 1128 | let mut state = HashMap::new(); 1129 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1130 | assert_eq!(result, "2"); 1131 | 1132 | let code = textwrap::dedent( 1133 | r#" 1134 | numbers = [1, 2, 3, 4, 5] 1135 | print(numbers[-5])"#, 1136 | ); 1137 | let mut state = HashMap::new(); 1138 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1139 | assert_eq!(result, "1"); 1140 | 1141 | let code = textwrap::dedent( 1142 | r#" 1143 | numbers = [1, 2, 3, 4, 5] 1144 | print(numbers[-6])"#, 1145 | ); 1146 | let mut state = HashMap::new(); 1147 | let result = evaluate_python_code(&code, vec![], &mut state); 1148 | assert_eq!( 1149 | result, 1150 | Err(InterpreterError::RuntimeError( 1151 | "IndexError: list index out of range".to_string() 1152 | )) 1153 | ); 1154 | } 1155 | 1156 | #[test] 1157 | fn test_evaluate_python_code_with_slice() { 1158 | let code = textwrap::dedent( 1159 | r#" 1160 | numbers = [1, 2, 3, 4, 5] 1161 | print(numbers[1:3])"#, 1162 | ); 1163 | let mut state = HashMap::new(); 1164 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1165 | assert_eq!(result, "[2, 3]"); 1166 | 1167 | let code = textwrap::dedent( 1168 | r#" 1169 | 
numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1170 | print(numbers[1:5:2])"#, 1171 | ); 1172 | let mut state = HashMap::new(); 1173 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1174 | assert_eq!(result, "[2, 4]"); 1175 | 1176 | let code = textwrap::dedent( 1177 | r#" 1178 | numbers = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1179 | print(numbers[5:1:-2])"#, 1180 | ); 1181 | let mut state = HashMap::new(); 1182 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1183 | assert_eq!(result, "[6, 4]"); 1184 | 1185 | let code = textwrap::dedent( 1186 | r#" 1187 | word = 'strawberry' 1188 | print(word[::-1])"#, 1189 | ); 1190 | let mut state = HashMap::new(); 1191 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1192 | assert_eq!(result, "yrrebwarts"); 1193 | 1194 | let code = textwrap::dedent( 1195 | r#" 1196 | numbers = [1, 2, 3, 4, 5] 1197 | print(numbers[::-1])"#, 1198 | ); 1199 | let mut state = HashMap::new(); 1200 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1201 | assert_eq!(result, "[5, 4, 3, 2, 1]"); 1202 | } 1203 | 1204 | #[test] 1205 | fn test_for_loop() { 1206 | let code = textwrap::dedent( 1207 | r#" 1208 | for i in range(5): 1209 | print(i) 1210 | "#, 1211 | ); 1212 | let mut state = HashMap::new(); 1213 | let _ = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1214 | assert_eq!( 1215 | state 1216 | .get("print_logs") 1217 | .unwrap() 1218 | .downcast_ref::>() 1219 | .unwrap(), 1220 | &vec!["0", "1", "2", "3", "4"] 1221 | ); 1222 | } 1223 | 1224 | #[test] 1225 | fn test_for_loop_with_tools() { 1226 | let code = textwrap::dedent( 1227 | r#" 1228 | for i in range(5): 1229 | search = duckduckgo_search(query=i) 1230 | print(search) 1231 | "#, 1232 | ); 1233 | let mut state = HashMap::new(); 1234 | let tools: Vec> = vec![Box::new(DuckDuckGoSearchTool::new())]; 1235 | let _ = evaluate_python_code(&code, tools, &mut state).unwrap(); 1236 | } 1237 | 1238 | #[test] 1239 
| fn test_evaluate_python_code_with_dict() { 1240 | let code = textwrap::dedent( 1241 | r#" 1242 | my_dict = {'a': "1", 'b': "2", 'c': "3"} 1243 | print(f"my_dict['a'] is {my_dict['a']}") 1244 | "#, 1245 | ); 1246 | let mut state = HashMap::new(); 1247 | let result = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1248 | assert_eq!(result, "my_dict['a'] is 1"); 1249 | 1250 | let code = textwrap::dedent( 1251 | r#" 1252 | dinner_places = [ 1253 | { 1254 | "title": "25 Best Restaurants in Berlin, By Local Foodies", 1255 | "url": "https://www.timeout.com/berlin/restaurants/best-restaurants-in-berlin" 1256 | }, 1257 | { 1258 | "title": "The 38 Best Berlin Restaurants - Eater", 1259 | "url": "https://www.eater.com/maps/best-restaurants-berlin" 1260 | }, 1261 | { 1262 | "title": "THE 10 BEST Restaurants in Berlin - Tripadvisor", 1263 | "url": "https://www.tripadvisor.com/Restaurants-g187323-Berlin.html" 1264 | }, 1265 | { 1266 | "title": "12 Unique Restaurants in Berlin", 1267 | "url": "https://www.myglobalviewpoint.com/unique-restaurants-in-berlin/" 1268 | }, 1269 | { 1270 | "title": "Berlin's best restaurants: 101 places to eat right now", 1271 | "url": "https://www.the-berliner.com/food/best-restaurants-berlin-101-places-to-eat/" 1272 | } 1273 | ] 1274 | 1275 | for place in dinner_places: 1276 | print(f"{place['title']}: {place['url']}") 1277 | "#, 1278 | ); 1279 | let mut local_python_interpreter = LocalPythonInterpreter::new(vec![]); 1280 | let (_, execution_logs) = local_python_interpreter.forward(&code).unwrap(); 1281 | assert_eq!(execution_logs, "25 Best Restaurants in Berlin, By Local Foodies: https://www.timeout.com/berlin/restaurants/best-restaurants-in-berlin\nThe 38 Best Berlin Restaurants - Eater: https://www.eater.com/maps/best-restaurants-berlin\nTHE 10 BEST Restaurants in Berlin - Tripadvisor: https://www.tripadvisor.com/Restaurants-g187323-Berlin.html\n12 Unique Restaurants in Berlin: 
https://www.myglobalviewpoint.com/unique-restaurants-in-berlin/\nBerlin's best restaurants: 101 places to eat right now: https://www.the-berliner.com/food/best-restaurants-berlin-101-places-to-eat/"); 1282 | 1283 | let code = textwrap::dedent( 1284 | r#" 1285 | movies = [ 1286 | {"title": "Babygirl", "showtimes": ["12:50 pm", "6:20 pm"]}, 1287 | {"title": "Better Man", "showtimes": ["9:20 pm"]}, 1288 | {"title": "La acompañante", "showtimes": ["3:40 pm", "6:30 pm", "9:10 pm"]}, 1289 | {"title": "Amenaza en el aire", "showtimes": ["9:30 pm"]}, 1290 | {"title": "Juf Braaksel en de Geniale Ontsnapping", "showtimes": ["12:30 pm"]}, 1291 | {"title": "Juffrouw Pots", "showtimes": ["10:35 am", "3:50 pm"]}, 1292 | {"title": "K3 en Het Lied van de Zeemeermin", "showtimes": ["10:00 am"]}, 1293 | {"title": "Marked Men", "showtimes": ["2:50 pm", "6:50 pm"]}, 1294 | {"title": "Vaiana 2", "showtimes": ["11:10 am", "12:40 pm"]}, 1295 | {"title": "Mufasa: El rey león", "showtimes": ["10:20 am", "3:10 pm", "9:00 pm"]}, 1296 | {"title": "Paddington: Aventura en la selva", "showtimes": ["12:20 pm", "3:30 pm", "6:10 pm"]}, 1297 | {"title": "Royal Opera House: The Tales of Hoffmann", "showtimes": ["1:30 pm"]}, 1298 | {"title": "The Growcodile", "showtimes": ["10:10 am"]}, 1299 | {"title": "Vivir el momento", "showtimes": ["5:20 pm"]}, 1300 | {"title": "Wicked", "showtimes": ["7:00 pm"]}, 1301 | {"title": "Woezel & Pip op Avontuur in de Tovertuin", "showtimes": ["10:30 am", "1:50 pm"]} 1302 | ] 1303 | 1304 | for movie in movies: 1305 | print(f"{movie['title']}: {', '.join(movie['showtimes'])}") 1306 | 1307 | "#, 1308 | ); 1309 | let mut local_python_interpreter = LocalPythonInterpreter::new(vec![]); 1310 | let (_, _) = local_python_interpreter.forward(&code).unwrap(); 1311 | 1312 | let code = textwrap::dedent( 1313 | r#" 1314 | urls = [ 1315 | "https://www.tripadvisor.com/Restaurants-g187323-Berlin.html", 1316 | "https://www.timeout.com/berlin/restaurants/best-restaurants-in-berlin" 
1317 | ] 1318 | 1319 | for url in urls: 1320 | page_content = duckduckgo_search(url) 1321 | print(page_content) 1322 | print("\n" + "="*80 + "\n") # Print separator between pages 1323 | "#, 1324 | ); 1325 | let mut state = HashMap::new(); 1326 | let tools: Vec> = vec![Box::new(DuckDuckGoSearchTool::new())]; 1327 | let _ = evaluate_python_code(&code, tools, &mut state).unwrap(); 1328 | } 1329 | 1330 | #[test] 1331 | fn test_evaluate_python_code_with_list_comprehension() { 1332 | let code = textwrap::dedent( 1333 | r#" 1334 | a = [1,2,3] 1335 | print([x for x in a]) 1336 | "#, 1337 | ); 1338 | let mut state = HashMap::new(); 1339 | let _ = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1340 | assert_eq!( 1341 | state 1342 | .get("print_logs") 1343 | .unwrap() 1344 | .downcast_ref::>() 1345 | .unwrap(), 1346 | &vec!["[1, 2, 3]"] 1347 | ); 1348 | } 1349 | 1350 | #[test] 1351 | fn test_evaluate_python_code_append_to_list() { 1352 | let code = textwrap::dedent( 1353 | r#" 1354 | a = [1,2,3] 1355 | a.append(4) 1356 | print(a) 1357 | "#, 1358 | ); 1359 | let mut state = HashMap::new(); 1360 | let _ = evaluate_python_code(&code, vec![], &mut state).unwrap(); 1361 | assert_eq!( 1362 | state 1363 | .get("print_logs") 1364 | .unwrap() 1365 | .downcast_ref::>() 1366 | .unwrap(), 1367 | &vec!["[1, 2, 3, 4]"] 1368 | ); 1369 | 1370 | let code = textwrap::dedent( 1371 | r#" 1372 | urls = [ 1373 | "https://www.imdb.com/showtimes/cinema/ES/ci1028808/ES/08520", 1374 | "https://en.pathe.nl/bioscoopagenda", 1375 | "https://www.filmvandaag.nl/bioscoop?filter=64" 1376 | ] 1377 | movies = [] 1378 | for url in urls: 1379 | page_content = url 1380 | movies.append(page_content) 1381 | 1382 | print(movies) 1383 | "#, 1384 | ); 1385 | let mut state = HashMap::new(); 1386 | let tools: Vec> = vec![Box::new(VisitWebsiteTool::new())]; 1387 | let _ = evaluate_python_code(&code, tools, &mut state).unwrap(); 1388 | assert_eq!( 1389 | state 1390 | .get("print_logs") 1391 | .unwrap() 1392 | 
.downcast_ref::>() 1393 | .unwrap(), 1394 | &vec!["[https://www.imdb.com/showtimes/cinema/ES/ci1028808/ES/08520, https://en.pathe.nl/bioscoopagenda, https://www.filmvandaag.nl/bioscoop?filter=64]"] 1395 | ); 1396 | } 1397 | 1398 | #[test] 1399 | fn test_evaluate_python_code_with_error() { 1400 | let code = textwrap::dedent( 1401 | r#" 1402 | guidelines = ( 1403 | "To avoid being blocked by websites, use the following guidelines for user agent strings:\n" 1404 | "1. Use a valid browser user agent to mimic a real web browser.\n" 1405 | "2. Rotate User-Agent headers for each outgoing request to prevent identification as a bot.\n" 1406 | "3. Avoid using generic user-agent strings like 'Python Requests Library' or an empty UA string.\n" 1407 | "4. Use a user agent string that includes information about the browser, operating system, and other parameters.\n" 1408 | "5. Understand that websites use user agent strings to organize protection against malicious actions, including parsing blocks." 1409 | ) 1410 | 1411 | "#, 1412 | ); 1413 | let code_2 = textwrap::dedent( 1414 | r#" 1415 | print(guidelines) 1416 | "#, 1417 | ); 1418 | let tools: Vec> = vec![Box::new(VisitWebsiteTool::new())]; 1419 | let mut local_python_interpreter = LocalPythonInterpreter::new(tools); 1420 | let (_, logs) = local_python_interpreter.forward(&code).unwrap(); 1421 | println!("logs: {:?}", logs); 1422 | let (_, logs_2) = local_python_interpreter.forward(&code_2).unwrap(); 1423 | println!("logs_2: {:?}", logs_2); 1424 | } 1425 | } 1426 | -------------------------------------------------------------------------------- /src/logger.rs: -------------------------------------------------------------------------------- 1 | use colored::Colorize; 2 | use log::{Level, Metadata, Record}; 3 | use std::io::Write; 4 | use terminal_size::{self, Width}; 5 | 6 | pub struct ColoredLogger; 7 | 8 | impl log::Log for ColoredLogger { 9 | fn enabled(&self, metadata: &Metadata) -> bool { 10 | metadata.level() <= 
Level::Info 11 | } 12 | 13 | fn log(&self, record: &Record) { 14 | if self.enabled(record.metadata()) { 15 | let mut stdout = std::io::stdout(); 16 | let msg = record.args().to_string(); 17 | 18 | // Add a newline before each message for spacing 19 | writeln!(stdout).unwrap(); 20 | 21 | // Get terminal width 22 | let width = if let Some((Width(w), _)) = terminal_size::terminal_size() { 23 | w as usize - 2 // Subtract 2 for the side borders 24 | } else { 25 | 78 // fallback width if terminal size cannot be determined 26 | }; 27 | 28 | // Box drawing characters 29 | let top_border = format!("╔{}═", "═".repeat(width)); 30 | let bottom_border = format!("╚{}═", "═".repeat(width)); 31 | let side_border = "║ "; 32 | 33 | // Check for specific prefixes and apply different colors 34 | if msg.starts_with("Observation:") { 35 | let (prefix, content) = msg.split_at(12); 36 | writeln!(stdout, "{}", top_border.yellow()).unwrap(); 37 | writeln!( 38 | stdout, 39 | "{}{}{}", 40 | side_border.yellow(), 41 | prefix.yellow().bold(), 42 | content.green() 43 | ) 44 | .unwrap(); 45 | writeln!(stdout, "{}", bottom_border.yellow()).unwrap(); 46 | } else if msg.starts_with("Error:") { 47 | let (prefix, content) = msg.split_at(6); 48 | writeln!(stdout, "{}", top_border.red()).unwrap(); 49 | writeln!( 50 | stdout, 51 | "{}{}{}", 52 | side_border.red(), 53 | prefix.red().bold(), 54 | content.white().bold() 55 | ) 56 | .unwrap(); 57 | } else if msg.starts_with("Executing tool call:") { 58 | let (prefix, content) = msg.split_at(21); 59 | writeln!(stdout, "{}", top_border.magenta()).unwrap(); 60 | writeln!( 61 | stdout, 62 | "{}{}{}", 63 | side_border.magenta(), 64 | prefix.magenta().bold(), 65 | content.cyan() 66 | ) 67 | .unwrap(); 68 | writeln!(stdout, "{}", bottom_border.magenta()).unwrap(); 69 | } else if msg.starts_with("Plan:") { 70 | let (prefix, content) = msg.split_at(5); 71 | writeln!(stdout, "{}", top_border.red()).unwrap(); 72 | writeln!( 73 | stdout, 74 | "{}{}{}", 75 | 
side_border.red(), 76 | prefix.red().bold(), 77 | content.blue().italic() 78 | ) 79 | .unwrap(); 80 | writeln!(stdout, "{}", bottom_border.red()).unwrap(); 81 | } else if msg.starts_with("Final answer:") { 82 | let (prefix, content) = msg.split_at(13); 83 | writeln!(stdout, "{}", top_border.green()).unwrap(); 84 | writeln!( 85 | stdout, 86 | "{}{}{}", 87 | side_border.green(), 88 | prefix.green().bold(), 89 | content.white().bold() 90 | ) 91 | .unwrap(); 92 | writeln!(stdout, "{}", bottom_border.green()).unwrap(); 93 | } else if msg.starts_with("Code:") { 94 | let (prefix, content) = msg.split_at(5); 95 | writeln!(stdout, "{}", top_border.yellow()).unwrap(); 96 | writeln!( 97 | stdout, 98 | "{}{}{}", 99 | side_border.yellow(), 100 | prefix.yellow().bold(), 101 | content.magenta().bold() 102 | ) 103 | .unwrap(); 104 | writeln!(stdout, "{}", bottom_border.yellow()).unwrap(); 105 | } else { 106 | writeln!(stdout, "{}", top_border.blue()).unwrap(); 107 | writeln!(stdout, "{}{}", side_border.blue(), msg.blue()).unwrap(); 108 | writeln!(stdout, "{}", bottom_border.blue()).unwrap(); 109 | } 110 | } 111 | } 112 | 113 | fn flush(&self) {} 114 | } 115 | 116 | pub static LOGGER: ColoredLogger = ColoredLogger; 117 | -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod model_traits; 2 | pub mod ollama; 3 | pub mod openai; 4 | pub mod types; 5 | -------------------------------------------------------------------------------- /src/models/model_traits.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::{ 4 | errors::AgentError, 5 | models::{openai::ToolCall, types::Message}, 6 | tools::tool_traits::ToolInfo, 7 | }; 8 | use anyhow::Result; 9 | pub trait ModelResponse { 10 | fn get_response(&self) -> Result; 11 | fn get_tools_used(&self) -> Result, 
AgentError>; 12 | } 13 | 14 | pub trait Model { 15 | fn run( 16 | &self, 17 | input_messages: Vec, 18 | tools: Vec, 19 | max_tokens: Option, 20 | args: Option>>, 21 | ) -> Result, AgentError>; 22 | } 23 | -------------------------------------------------------------------------------- /src/models/ollama.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use serde::Deserialize; 4 | use serde_json::json; 5 | 6 | use crate::{errors::AgentError, tools::ToolInfo}; 7 | use anyhow::Result; 8 | 9 | use super::{ 10 | model_traits::{Model, ModelResponse}, 11 | openai::ToolCall, 12 | types::{Message, MessageRole}, 13 | }; 14 | 15 | #[derive(Debug, Deserialize)] 16 | pub struct OllamaResponse { 17 | pub message: AssistantMessage, 18 | } 19 | 20 | #[derive(Debug, Deserialize)] 21 | pub struct AssistantMessage { 22 | pub role: MessageRole, 23 | pub content: Option, 24 | pub tool_calls: Option>, 25 | } 26 | 27 | impl ModelResponse for OllamaResponse { 28 | fn get_response(&self) -> Result { 29 | Ok(self.message.content.clone().unwrap_or_default()) 30 | } 31 | 32 | fn get_tools_used(&self) -> Result, AgentError> { 33 | Ok(self.message.tool_calls.clone().unwrap_or_default()) 34 | } 35 | } 36 | 37 | #[derive(Debug, Clone)] 38 | pub struct OllamaModel { 39 | model_id: String, 40 | temperature: f32, 41 | url: String, 42 | client: reqwest::blocking::Client, 43 | ctx_length: usize, 44 | } 45 | 46 | #[derive(Default)] 47 | pub struct OllamaModelBuilder { 48 | model_id: String, 49 | temperature: Option, 50 | client: Option, 51 | url: Option, 52 | ctx_length: Option, 53 | } 54 | 55 | impl OllamaModelBuilder { 56 | pub fn new() -> Self { 57 | let client = reqwest::blocking::Client::new(); 58 | Self { 59 | model_id: "llama3.2".to_string(), 60 | temperature: Some(0.5), 61 | client: Some(client), 62 | url: Some("http://localhost:11434".to_string()), 63 | ctx_length: Some(2048), 64 | } 65 | } 66 | 67 | pub fn 
model_id(mut self, model_id: &str) -> Self { 68 | self.model_id = model_id.to_string(); 69 | self 70 | } 71 | 72 | pub fn temperature(mut self, temperature: Option) -> Self { 73 | self.temperature = temperature; 74 | self 75 | } 76 | 77 | pub fn url(mut self, url: String) -> Self { 78 | self.url = Some(url); 79 | self 80 | } 81 | 82 | pub fn ctx_length(mut self, ctx_length: usize) -> Self { 83 | self.ctx_length = Some(ctx_length); 84 | self 85 | } 86 | 87 | pub fn build(self) -> OllamaModel { 88 | OllamaModel { 89 | model_id: self.model_id, 90 | temperature: self.temperature.unwrap_or(0.5), 91 | url: self.url.unwrap_or("http://localhost:11434".to_string()), 92 | client: self.client.unwrap_or_default(), 93 | ctx_length: self.ctx_length.unwrap_or(2048), 94 | } 95 | } 96 | } 97 | 98 | impl Model for OllamaModel { 99 | fn run( 100 | &self, 101 | messages: Vec, 102 | tools_to_call_from: Vec, 103 | max_tokens: Option, 104 | args: Option>>, 105 | ) -> Result, AgentError> { 106 | let messages = messages 107 | .iter() 108 | .map(|message| { 109 | json!({ 110 | "role": message.role, 111 | "content": message.content 112 | }) 113 | }) 114 | .collect::>(); 115 | 116 | let tools = json!(tools_to_call_from); 117 | 118 | let mut body = json!({ 119 | "model": self.model_id, 120 | "messages": messages, 121 | "temperature": self.temperature, 122 | "stream": false, 123 | "options": json!({ 124 | "num_ctx": self.ctx_length, 125 | }), 126 | "tools": tools, 127 | "max_tokens": max_tokens.unwrap_or(1500), 128 | }); 129 | if let Some(args) = args { 130 | for (key, value) in args { 131 | body["options"][key] = json!(value); 132 | } 133 | } 134 | 135 | let response = self 136 | .client 137 | .post(format!("{}/api/chat", self.url)) 138 | .json(&body) 139 | .send() 140 | .map_err(|e| { 141 | AgentError::Generation(format!("Failed to get response from Ollama: {}", e)) 142 | })?; 143 | let output = response.json::().unwrap(); 144 | Ok(Box::new(output)) 145 | } 146 | } 147 | 
-------------------------------------------------------------------------------- /src/models/openai.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::errors::AgentError; 4 | use crate::models::model_traits::{Model, ModelResponse}; 5 | use crate::models::types::{Message, MessageRole}; 6 | use crate::tools::ToolInfo; 7 | use anyhow::Result; 8 | use reqwest::blocking::Client; 9 | use serde::{Deserialize, Serialize}; 10 | use serde_json::{json, Value}; 11 | 12 | #[derive(Debug, Deserialize)] 13 | pub struct OpenAIResponse { 14 | pub choices: Vec, 15 | } 16 | 17 | #[derive(Debug, Deserialize)] 18 | pub struct Choice { 19 | pub message: AssistantMessage, 20 | } 21 | 22 | #[derive(Debug, Deserialize)] 23 | pub struct AssistantMessage { 24 | pub role: MessageRole, 25 | pub content: Option, 26 | pub tool_calls: Option>, 27 | pub refusal: Option, 28 | } 29 | 30 | #[derive(Debug, Serialize, Deserialize, Clone)] 31 | pub struct ToolCall { 32 | pub id: Option, 33 | #[serde(rename = "type")] 34 | pub call_type: Option, 35 | pub function: FunctionCall, 36 | } 37 | 38 | #[derive(Debug, Serialize, Deserialize, Clone)] 39 | pub struct FunctionCall { 40 | pub name: String, 41 | #[serde(deserialize_with = "deserialize_arguments")] 42 | pub arguments: Value, 43 | } 44 | 45 | // Add this function to handle argument deserialization 46 | fn deserialize_arguments<'de, D>(deserializer: D) -> Result 47 | where 48 | D: serde::Deserializer<'de>, 49 | { 50 | let value = Value::deserialize(deserializer)?; 51 | 52 | // If it's a string, try to parse it as JSON 53 | if let Value::String(s) = &value { 54 | if let Ok(parsed) = serde_json::from_str(s) { 55 | return Ok(parsed); 56 | } 57 | } 58 | 59 | Ok(value) 60 | } 61 | 62 | impl FunctionCall { 63 | pub fn get_arguments(&self) -> Result> { 64 | // First try to parse as a HashMap directly 65 | if let Ok(map) = serde_json::from_value(self.arguments.clone()) { 66 
| return Ok(map); 67 | } 68 | 69 | // If that fails, try to parse as a string and then parse that string as JSON 70 | if let Value::String(arg_str) = &self.arguments { 71 | if let Ok(parsed) = serde_json::from_str(arg_str) { 72 | return Ok(parsed); 73 | } 74 | } 75 | 76 | // If all parsing attempts fail, return the original error 77 | Err(anyhow::anyhow!( 78 | "Failed to parse arguments as HashMap or JSON string" 79 | )) 80 | } 81 | } 82 | 83 | impl ModelResponse for OpenAIResponse { 84 | fn get_response(&self) -> Result { 85 | Ok(self 86 | .choices 87 | .first() 88 | .ok_or(AgentError::Generation( 89 | "No message returned from OpenAI".to_string(), 90 | ))? 91 | .message 92 | .content 93 | .clone() 94 | .unwrap_or_default()) 95 | } 96 | 97 | fn get_tools_used(&self) -> Result, AgentError> { 98 | Ok(self 99 | .choices 100 | .first() 101 | .ok_or(AgentError::Generation( 102 | "No message returned from OpenAI".to_string(), 103 | ))? 104 | .message 105 | .tool_calls 106 | .clone() 107 | .unwrap_or_default()) 108 | } 109 | } 110 | 111 | #[derive(Debug)] 112 | pub struct OpenAIServerModel { 113 | pub base_url: String, 114 | pub model_id: String, 115 | pub client: Client, 116 | pub temperature: f32, 117 | pub api_key: String, 118 | } 119 | 120 | impl OpenAIServerModel { 121 | pub fn new( 122 | base_url: Option<&str>, 123 | model_id: Option<&str>, 124 | temperature: Option, 125 | api_key: Option, 126 | ) -> Self { 127 | let api_key = api_key.unwrap_or_else(|| { 128 | std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY must be set") 129 | }); 130 | let model_id = model_id.unwrap_or("gpt-4o-mini").to_string(); 131 | let base_url = base_url.unwrap_or("https://api.openai.com/v1/chat/completions"); 132 | let client = Client::new(); 133 | 134 | OpenAIServerModel { 135 | base_url: base_url.to_string(), 136 | model_id, 137 | client, 138 | temperature: temperature.unwrap_or(0.5), 139 | api_key, 140 | } 141 | } 142 | } 143 | 144 | impl Model for OpenAIServerModel { 145 | fn run( 
146 | &self, 147 | messages: Vec, 148 | tools_to_call_from: Vec, 149 | max_tokens: Option, 150 | args: Option>>, 151 | ) -> Result, AgentError> { 152 | let max_tokens = max_tokens.unwrap_or(1500); 153 | 154 | let messages = messages 155 | .iter() 156 | .map(|message| { 157 | json!({ 158 | "role": message.role, 159 | "content": message.content 160 | }) 161 | }) 162 | .collect::>(); 163 | let mut body = json!({ 164 | "model": self.model_id, 165 | "messages": messages, 166 | "temperature": self.temperature, 167 | "max_tokens": max_tokens, 168 | }); 169 | 170 | if !tools_to_call_from.is_empty() { 171 | body["tools"] = json!(tools_to_call_from); 172 | body["tool_choice"] = json!("required"); 173 | } 174 | 175 | if let Some(args) = args { 176 | let body_map = body.as_object_mut().unwrap(); 177 | for (key, value) in args { 178 | body_map.insert(key, json!(value)); 179 | } 180 | } 181 | 182 | let response = self 183 | .client 184 | .post(&self.base_url) 185 | .header("Authorization", format!("Bearer {}", self.api_key)) 186 | .json(&body) 187 | .send() 188 | .map_err(|e| { 189 | AgentError::Generation(format!("Failed to get response from OpenAI: {}", e)) 190 | })?; 191 | 192 | match response.status() { 193 | reqwest::StatusCode::OK => { 194 | let response = response.json::().unwrap(); 195 | Ok(Box::new(response)) 196 | } 197 | _ => Err(AgentError::Generation(format!( 198 | "Failed to get response from OpenAI: {}", 199 | response.text().unwrap() 200 | ))), 201 | } 202 | } 203 | } 204 | -------------------------------------------------------------------------------- /src/models/types.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | use std::fmt::Debug; 3 | 4 | #[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq)] 5 | #[serde(rename_all = "lowercase")] 6 | pub enum MessageRole { 7 | User, 8 | Assistant, 9 | System, 10 | #[serde(rename = "tool")] 11 | ToolCall, 12 | #[serde(rename = 
"tool_response")] 13 | ToolResponse, 14 | } 15 | 16 | impl std::fmt::Display for MessageRole { 17 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 18 | match self { 19 | MessageRole::User => write!(f, "User"), 20 | MessageRole::Assistant => write!(f, "Assistant"), 21 | MessageRole::System => write!(f, "System"), 22 | MessageRole::ToolCall => write!(f, "ToolCall"), 23 | MessageRole::ToolResponse => write!(f, "ToolResponse"), 24 | } 25 | } 26 | } 27 | 28 | #[derive(Debug, Serialize, Clone)] 29 | pub struct Message { 30 | pub role: MessageRole, 31 | pub content: String, 32 | } 33 | 34 | impl std::fmt::Display for Message { 35 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 36 | write!(f, "Message(role: {}, content: {})", self.role, self.content) 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /src/prompts.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the prompts for the agents. 2 | 3 | /// The system prompt for the code agent. 4 | pub const CODE_SYSTEM_PROMPT: &str = r#"You are an expert assistant who can solve any task using code blobs. You will be given a task to solve as best you can. 5 | To do so, you have been given access to a list of tools: these tools are basically Python functions which you can call with code. 6 | To solve the task, you must plan forward to proceed in a series of steps, in a cycle of 'Thought:', 'Code:', and 'Observation:' sequences. 7 | 8 | At each step, in the 'Thought:' sequence, you should first explain your reasoning towards solving the task and the tools that you want to use. 9 | Then in the 'Code:' sequence, you should write the code in simple Python. The code sequence must end with '' sequence. 10 | During each intermediate step, you can use 'print()' to save whatever important information you will then need. 
11 | These print outputs will then appear in the 'Observation:' field, which will be available as input for the next step. 12 | In the end you have to return a final answer using the `final_answer` tool. 13 | 14 | Here are a few examples using notional tools: 15 | --- 16 | Task: "Generate an image of the oldest person in this document." 17 | 18 | Thought: I will proceed step by step and use the following tools: `document_qa` to find the oldest person in the document, then `image_generator` to generate an image according to the answer. 19 | Code: 20 | ```py 21 | answer = document_qa(document=document, question="Who is the oldest person mentioned?") 22 | print(answer) 23 | ``` 24 | Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." 25 | 26 | Thought: I will now generate an image showcasing the oldest person. 27 | Code: 28 | ```py 29 | image = image_generator("A portrait of John Doe, a 55-year-old man living in Canada.") 30 | final_answer(image) 31 | ``` 32 | 33 | --- 34 | Task: "What is the result of the following operation: 5 + 3 + 1294.678?" 35 | 36 | Thought: I will use python code to compute the result of the operation and then return the final answer using the `final_answer` tool 37 | Code: 38 | ```py 39 | result = 5 + 3 + 1294.678 40 | final_answer(result) 41 | ``` 42 | 43 | --- 44 | Task: 45 | "Answer the question in the variable `question` about the image stored in the variable `image`. The question is in French. 46 | You have been provided with these additional arguments, that you can access using the keys as variables in your python code: 47 | {'question': 'Quel est l'animal sur l'image?', 'image': 'path/to/image.jpg'}" 48 | 49 | Thought: I will use the following tools: `translator` to translate the question into English and then `image_qa` to answer the question on the input image. 
50 | Code: 51 | ```py 52 | translated_question = translator(question=question, src_lang="French", tgt_lang="English") 53 | print(f"The translated question is {translated_question}.") 54 | answer = image_qa(image=image, question=translated_question) 55 | final_answer(f"The answer is {answer}") 56 | ``` 57 | 58 | --- 59 | Task: 60 | In a 1979 interview, Stanislaus Ulam discusses with Martin Sherwin about other great physicists of his time, including Oppenheimer. 61 | What does he say was the consequence of Einstein learning too much math on his creativity, in one word? 62 | 63 | Thought: I need to find and read the 1979 interview of Stanislaus Ulam with Martin Sherwin. 64 | Code: 65 | ```py 66 | pages = search(query="1979 interview Stanislaus Ulam Martin Sherwin physicists Einstein") 67 | print(pages) 68 | ``` 69 | Observation: 70 | No result found for query "1979 interview Stanislaus Ulam Martin Sherwin physicists Einstein". 71 | 72 | Thought: The query was maybe too restrictive and did not find any results. Let's try again with a broader query. 73 | Code: 74 | ```py 75 | pages = search(query="1979 interview Stanislaus Ulam") 76 | print(pages) 77 | ``` 78 | Observation: 79 | Found 6 pages: 80 | [Stanislaus Ulam 1979 interview](https://ahf.nuclearmuseum.org/voices/oral-histories/stanislaus-ulams-interview-1979/) 81 | 82 | [Ulam discusses Manhattan Project](https://ahf.nuclearmuseum.org/manhattan-project/ulam-manhattan-project/) 83 | 84 | (truncated) 85 | 86 | Thought: I will read the first 2 pages to know more. 
87 | Code: 88 | ```py 89 | for url in ["https://ahf.nuclearmuseum.org/voices/oral-histories/stanislaus-ulams-interview-1979/", "https://ahf.nuclearmuseum.org/manhattan-project/ulam-manhattan-project/"]: 90 | whole_page = visit_webpage(url) 91 | print(whole_page) 92 | print("\n" + "="*80 + "\n") # Print separator between pages 93 | ``` 94 | Observation: 95 | Manhattan Project Locations: 96 | Los Alamos, NM 97 | Stanislaus Ulam was a Polish-American mathematician. He worked on the Manhattan Project at Los Alamos and later helped design the hydrogen bomb. In this interview, he discusses his work at 98 | (truncated) 99 | 100 | Thought: I now have the final answer: from the webpages visited, Stanislaus Ulam says of Einstein: "He learned too much mathematics and sort of diminished, it seems to me personally, it seems to me his purely physics creativity." Let's answer in one word. 101 | Code: 102 | ```py 103 | final_answer("diminished") 104 | ``` 105 | 106 | --- 107 | Task: "Which city has the highest population: Guangzhou or Shanghai?" 108 | 109 | Thought: I need to get the populations for both cities and compare them: I will use the tool `search` to get the population of both cities. 110 | Code: 111 | ```py 112 | for city in ["Guangzhou", "Shanghai"]: 113 | print(f"Population {city}:", search(f"{city} population") 114 | ``` 115 | Observation: 116 | Population Guangzhou: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] 117 | Population Shanghai: '26 million (2019)' 118 | 119 | Thought: Now I know that Shanghai has the highest population. 120 | Code: 121 | ```py 122 | final_answer("Shanghai") 123 | ``` 124 | 125 | --- 126 | Task: "What is the current age of the pope, raised to the power 0.36?" 127 | 128 | Thought: I will use the tool `wiki` to get the age of the pope, and confirm that with a web search. 
129 | Code: 130 | ```py 131 | pope_age_wiki = wiki(query="current pope age") 132 | print("Pope age as per wikipedia:", pope_age_wiki) 133 | pope_age_search = web_search(query="current pope age") 134 | print("Pope age as per google search:", pope_age_search) 135 | ``` 136 | Observation: 137 | Pope age: "The pope Francis is currently 88 years old." 138 | 139 | Thought: I know that the pope is 88 years old. Let's compute the result using python code. 140 | Code: 141 | ```py 142 | pope_current_age = 88 ** 0.36 143 | final_answer(pope_current_age) 144 | ``` 145 | 146 | Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you only have access to these tools: 147 | 148 | {{tool_descriptions}} 149 | 150 | {{managed_agents_descriptions}} 151 | 152 | Here are the rules you should always follow to solve your task: 153 | 1. Always provide a 'Thought:' sequence, and a 'Code:\n```py' sequence ending with '```' sequence, else you will fail. 154 | 2. Use only variables that you have defined! 155 | 3. Always use the right arguments for the tools. DO NOT pass the arguments as a dict as in 'answer = wiki({'query': "What is the place where James Bond lives?"})', but use the arguments directly as in 'answer = wiki(query="What is the place where James Bond lives?")'. 156 | 4. Take care to not chain too many sequential tool calls in the same code block, especially when the output format is unpredictable. For instance, a call to search has an unpredictable return format, so do not have another tool call that depends on its output in the same block: rather output results with print() to use them in the next block. 157 | 5. Call a tool only when needed, and never re-do a tool call that you previously did with the exact same parameters. 158 | 6. Don't name any new variable with the same name as a tool: for instance don't name a variable 'final_answer'. 159 | 7. 
Never create any notional variables in our code, as having these in your logs will derail you from the true variables. 160 | 8. You can use imports in your code, but only from the following list of modules: {{authorized_imports}} 161 | 9. The state persists between code executions: so if in one step you've created variables or imported modules, these will all persist. 162 | 10. Don't give up! You're in charge of solving the task, not providing directions to solve it. 163 | 164 | Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. 165 | 166 | "#; 167 | 168 | /// The system prompt for the facts agent. This prompt is used to build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need. 169 | pub const SYSTEM_PROMPT_FACTS: &str = r#"Below I will present you a task. 170 | 171 | You will now build a comprehensive preparatory survey of which facts we have at our disposal and which ones we still need. 172 | To do so, you will have to read the task and identify things that must be discovered in order to successfully complete it. 173 | Don't make any assumptions. For each item, provide a thorough reasoning. Here is how you will structure this survey: 174 | 175 | --- 176 | ### 1. Facts given in the task 177 | List here the specific facts given in the task that could help you (there might be nothing here). 178 | 179 | ### 2. Facts to look up 180 | List here any facts that we may need to look up. 181 | Also list where to find each of these, for instance a website, a file... - maybe the task contains some sources that you should re-use here. 182 | 183 | ### 3. Facts to derive 184 | List here anything that we want to derive from the above by logical reasoning, for instance computation or simulation. 185 | 186 | Keep in mind that "facts" will typically be specific names, dates, values, etc. Your answer should use the below headings: 187 | ### 1. Facts given in the task 188 | ### 2. 
Facts to look up 189 | ### 3. Facts to derive 190 | Do not add anything else."#; 191 | 192 | /// The system prompt for the plan agent. This prompt is used to develop a step-by-step high-level plan to solve a task. 193 | pub const SYSTEM_PROMPT_PLAN: &str = r#"You are a world expert at making efficient plans to solve any task using a set of carefully crafted tools. 194 | 195 | Now for the given task, develop a step-by-step high-level plan taking into account the above inputs and list of facts. 196 | This plan should involve individual tasks based on the available tools, that if executed correctly will yield the correct answer. 197 | Do not skip steps, do not add any superfluous steps. Only write the high-level plan, DO NOT DETAIL INDIVIDUAL TOOL CALLS. 198 | After writing the final step of the plan, write the '\n' tag and stop there."#; 199 | 200 | /// The user prompt for the plan agent. This prompt is used to develop a step-by-step high-level plan to solve a task. 201 | pub fn user_prompt_plan( 202 | task: &str, 203 | tool_descriptions: &str, 204 | managed_agent_descriptions: &str, 205 | answer_facts: &str, 206 | ) -> String { 207 | format!( 208 | "Here is your task: 209 | 210 | Task: 211 | ``` 212 | {} 213 | ``` 214 | 215 | Your plan can leverage any of these tools: 216 | {} 217 | 218 | {} 219 | 220 | List of facts that you know: 221 | ``` 222 | {} 223 | ``` 224 | 225 | Now begin! Write your plan below", 226 | task, tool_descriptions, managed_agent_descriptions, answer_facts 227 | ) 228 | } 229 | 230 | /// The system prompt for the tool calling agent. This prompt is used for models that do not have tool calling capabilities. 231 | pub const TOOL_CALLING_SYSTEM_PROMPT: &str = r#"You are an expert assistant who can solve any task using tool calls. You will be given a task to solve as best you can. 
232 | To do so, you have been given access to the following tools: {{tool_names}} 233 | 234 | The tool call you write is an action: after the tool is executed, you will get the result of the tool call as an "observation". 235 | This Action/Observation can repeat N times, you should take several steps when needed. 236 | 237 | You can use the result of the previous action as input for the next action. 238 | The observation will always be a string: it can represent a file, like "image_1.jpg". 239 | Then you can use it as input for the next action. You can do it for instance as follows: 240 | 241 | Observation: "image_1.jpg" 242 | 243 | Action: 244 | { 245 | "tool_name": "image_transformer", 246 | "tool_arguments": {"image": "image_1.jpg"} 247 | } 248 | 249 | To provide the final answer to the task, use an action blob with "tool_name": "final_answer" tool. It is the only way to complete the task, else you will be stuck on a loop. So your final output should look like this: 250 | Action: 251 | { 252 | "tool_name": "final_answer", 253 | "tool_arguments": {"answer": "insert your final answer here"} 254 | } 255 | 256 | 257 | Here are a few examples using notional tools: 258 | --- 259 | Task: "Generate an image of the oldest person in this document." 260 | 261 | Action: 262 | { 263 | "tool_name": "document_qa", 264 | "tool_arguments": {"document": "document.pdf", "question": "Who is the oldest person mentioned?"} 265 | } 266 | Observation: "The oldest person in the document is John Doe, a 55 year old lumberjack living in Newfoundland." 267 | 268 | Action: 269 | { 270 | "tool_name": "image_generator", 271 | "tool_arguments": {"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."} 272 | } 273 | Observation: "image.png" 274 | 275 | Action: 276 | { 277 | "tool_name": "final_answer", 278 | "tool_arguments": "image.png" 279 | } 280 | 281 | --- 282 | Task: "What is the result of the following operation: 5 + 3 + 1294.678?" 
283 | 284 | Action: 285 | { 286 | "tool_name": "python_interpreter", 287 | "tool_arguments": {"code": "5 + 3 + 1294.678"} 288 | } 289 | Observation: 1302.678 290 | 291 | Action: 292 | { 293 | "tool_name": "final_answer", 294 | "tool_arguments": "1302.678" 295 | } 296 | 297 | --- 298 | Task: "Which city has the highest population , Guangzhou or Shanghai?" 299 | 300 | Action: 301 | { 302 | "tool_name": "search", 303 | "tool_arguments": "Population Guangzhou" 304 | } 305 | Observation: ['Guangzhou has a population of 15 million inhabitants as of 2021.'] 306 | 307 | 308 | Action: 309 | { 310 | "tool_name": "search", 311 | "tool_arguments": "Population Shanghai" 312 | } 313 | Observation: '26 million (2019)' 314 | 315 | Action: 316 | { 317 | "tool_name": "final_answer", 318 | "tool_arguments": "Shanghai" 319 | } 320 | 321 | 322 | Above example were using notional tools that might not exist for you. You only have access to these tools: 323 | 324 | {{tool_descriptions}} 325 | 326 | {{managed_agents_descriptions}} 327 | 328 | Here are the rules you should always follow to solve your task: 329 | 1. ALWAYS provide a tool call, else you will fail. 330 | 2. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead. 331 | 3. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself. 332 | If no tool call is needed, use final_answer tool to return your answer. 333 | 4. Never re-do a tool call that you previously did with the exact same parameters. 334 | 5. The current time is {{current_time}}. 335 | 336 | Now Begin! If you solve the task correctly and call the final_answer tool to give your answer, you will receive a reward of $1,000,000. 337 | "#; 338 | 339 | /// The system prompt for the function calling agent. This prompt is used for models that have tool calling capabilities. 
340 | pub const FUNCTION_CALLING_SYSTEM_PROMPT: &str = r#"You are an expert assistant who can solve any task. You will be given a task to solve as best you can. 341 | 342 | 1. The current time is {{current_time}}. 343 | 2. DO NOT INCLUDE THE TOOL CALL IN YOUR RESPONSE, JUST RETURN THE ANSWER. 344 | 345 | Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000. 346 | "#; 347 | -------------------------------------------------------------------------------- /src/tools/base.rs: -------------------------------------------------------------------------------- 1 | use schemars::JsonSchema; 2 | use serde::{de::DeserializeOwned, Deserialize, Serialize}; 3 | 4 | use super::tool_traits::{Parameters, Tool}; 5 | use anyhow::Result; 6 | 7 | #[derive(Deserialize, JsonSchema)] 8 | #[schemars(title = "BaseParams")] 9 | pub struct BaseParams { 10 | #[schemars(description = "The name of the tool")] 11 | _name: String, 12 | } 13 | 14 | impl Parameters for P where P: JsonSchema {} 15 | 16 | #[derive(Debug, Serialize, Default, Clone)] 17 | pub struct BaseTool { 18 | pub name: &'static str, 19 | pub description: &'static str, 20 | } 21 | 22 | impl Tool for BaseTool { 23 | type Params = serde_json::Value; 24 | fn name(&self) -> &'static str { 25 | self.name 26 | } 27 | 28 | fn description(&self) -> &'static str { 29 | self.description 30 | } 31 | fn forward(&self, _arguments: serde_json::Value) -> Result { 32 | Ok("Not implemented".to_string()) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /src/tools/ddg_search.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the DuckDuckGo search tool. 
2 | 3 | use schemars::JsonSchema; 4 | use scraper::Selector; 5 | use serde::{Deserialize, Serialize}; 6 | 7 | use super::base::BaseTool; 8 | use super::tool_traits::Tool; 9 | use anyhow::Result; 10 | 11 | #[derive(Deserialize, JsonSchema)] 12 | #[schemars(title = "DuckDuckGoSearchToolParams")] 13 | pub struct DuckDuckGoSearchToolParams { 14 | #[schemars(description = "The query to search for")] 15 | query: String, 16 | } 17 | 18 | #[derive(Debug, Serialize, Default)] 19 | pub struct SearchResult { 20 | pub title: String, 21 | pub snippet: String, 22 | pub url: String, 23 | } 24 | 25 | #[derive(Debug, Serialize, Default, Clone)] 26 | pub struct DuckDuckGoSearchTool { 27 | pub tool: BaseTool, 28 | } 29 | 30 | impl DuckDuckGoSearchTool { 31 | pub fn new() -> Self { 32 | DuckDuckGoSearchTool { 33 | tool: BaseTool { 34 | name: "duckduckgo_search", 35 | description: "Performs a duckduckgo web search for your query then returns a string of the top search results.", 36 | }, 37 | } 38 | } 39 | 40 | pub fn forward(&self, query: &str) -> Result> { 41 | let client = reqwest::blocking::Client::builder() 42 | .user_agent("Mozilla/5.0 (compatible; MyRustTool/1.0)") 43 | .build()?; 44 | let response = client 45 | .get(format!("https://html.duckduckgo.com/html/?q={}", query)) 46 | .send()?; 47 | let html = response.text().unwrap(); 48 | let document = scraper::Html::parse_document(&html); 49 | let result_selector = Selector::parse(".result") 50 | .map_err(|e| anyhow::anyhow!("Failed to parse result selector: {}", e))?; 51 | let title_selector = Selector::parse(".result__title a") 52 | .map_err(|e| anyhow::anyhow!("Failed to parse title selector: {}", e))?; 53 | let snippet_selector = Selector::parse(".result__snippet") 54 | .map_err(|e| anyhow::anyhow!("Failed to parse snippet selector: {}", e))?; 55 | let url_selector = Selector::parse(".result__url") 56 | .map_err(|e| anyhow::anyhow!("Failed to parse url selector: {}", e))?; 57 | let mut results = Vec::new(); 58 | 59 | for result 
in document.select(&result_selector) { 60 | let title_element = result.select(&title_selector).next(); 61 | let snippet_element = result.select(&snippet_selector).next(); 62 | if let (Some(title), Some(snippet)) = (title_element, snippet_element) { 63 | let title_text = title.text().collect::<String>().trim().to_string(); 64 | let snippet_text = snippet.text().collect::<String>().trim().to_string(); 65 | let url = result 66 | .select(&url_selector) 67 | .next() 68 | .unwrap() 69 | .text() 70 | .collect::<Vec<_>>() 71 | .join("") 72 | .trim() 73 | .to_string(); 74 | if !title_text.is_empty() && !url.is_empty() { 75 | results.push(SearchResult { 76 | title: title_text, 77 | snippet: snippet_text, 78 | url, 79 | }); 80 | } 81 | } 82 | } 83 | Ok(results) 84 | } 85 | } 86 | 87 | impl Tool for DuckDuckGoSearchTool { 88 | type Params = DuckDuckGoSearchToolParams; 89 | fn name(&self) -> &'static str { 90 | self.tool.name 91 | } 92 | fn description(&self) -> &'static str { 93 | self.tool.description 94 | } 95 | fn forward(&self, arguments: DuckDuckGoSearchToolParams) -> Result<String> { 96 | let query = arguments.query; 97 | let results = self.forward(&query)?; 98 | let results_string = results 99 | .iter() 100 | .map(|r| format!("[{}]({}) \n{}", r.title, r.url, r.snippet)) 101 | .collect::<Vec<_>>() 102 | .join("\n\n"); 103 | Ok(results_string) 104 | } 105 | } 106 | 107 | #[cfg(test)] 108 | mod tests { 109 | use super::*; 110 | 111 | #[test] 112 | fn test_duckduckgo_search_tool() { 113 | let tool = DuckDuckGoSearchTool::new(); 114 | let query = "What is the capital of France?"; 115 | let result = tool.forward(query).unwrap(); 116 | assert!(result.iter().any(|r| r.snippet.contains("Paris"))); 117 | } 118 | } 119 | -------------------------------------------------------------------------------- /src/tools/final_answer.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the final answer tool.
The model uses this tool to provide a final answer to the problem. 2 | 3 | use schemars::JsonSchema; 4 | use serde::{Deserialize, Serialize}; 5 | 6 | use super::base::BaseTool; 7 | use super::tool_traits::Tool; 8 | use anyhow::Result; 9 | 10 | #[derive(Debug, Deserialize, JsonSchema)] 11 | #[schemars(title = "FinalAnswerToolParams")] 12 | pub struct FinalAnswerToolParams { 13 | #[schemars(description = "The final answer to the problem")] 14 | answer: String, 15 | } 16 | 17 | #[derive(Debug, Serialize, Default, Clone)] 18 | pub struct FinalAnswerTool { 19 | pub tool: BaseTool, 20 | } 21 | 22 | impl FinalAnswerTool { 23 | pub fn new() -> Self { 24 | FinalAnswerTool { 25 | tool: BaseTool { 26 | name: "final_answer", 27 | description: "Provides a final answer to the given problem.", 28 | }, 29 | } 30 | } 31 | } 32 | 33 | impl Tool for FinalAnswerTool { 34 | type Params = FinalAnswerToolParams; 35 | fn name(&self) -> &'static str { 36 | self.tool.name 37 | } 38 | fn description(&self) -> &'static str { 39 | self.tool.description 40 | } 41 | 42 | fn forward(&self, arguments: FinalAnswerToolParams) -> Result { 43 | Ok(arguments.answer) 44 | } 45 | } 46 | 47 | #[cfg(test)] 48 | mod tests { 49 | use super::*; 50 | 51 | #[test] 52 | fn test_final_answer_tool() { 53 | let tool = FinalAnswerTool::new(); 54 | let arguments = FinalAnswerToolParams { 55 | answer: "The answer is 42".to_string(), 56 | }; 57 | let result = tool.forward(arguments).unwrap(); 58 | assert_eq!(result, "The answer is 42"); 59 | } 60 | } 61 | -------------------------------------------------------------------------------- /src/tools/google_search.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the Google search tool. 
2 | 3 | use schemars::JsonSchema; 4 | use serde::{Deserialize, Serialize}; 5 | use serde_json::json; 6 | 7 | use super::base::BaseTool; 8 | use super::tool_traits::Tool; 9 | use anyhow::Result; 10 | 11 | #[derive(Deserialize, JsonSchema)] 12 | #[schemars(title = "GoogleSearchToolParams")] 13 | pub struct GoogleSearchToolParams { 14 | #[schemars(description = "The query to search for")] 15 | query: String, 16 | #[schemars(description = "Optionally restrict results to a certain year")] 17 | filter_year: Option<String>, 18 | } 19 | 20 | #[derive(Debug, Serialize, Default, Clone)] 21 | pub struct GoogleSearchTool { 22 | pub tool: BaseTool, 23 | pub api_key: String, 24 | } 25 | 26 | impl GoogleSearchTool { 27 | pub fn new(api_key: Option<String>) -> Self { 28 | let api_key = api_key.unwrap_or(std::env::var("SERPAPI_API_KEY").unwrap()); 29 | 30 | GoogleSearchTool { 31 | tool: BaseTool { 32 | name: "google_search", 33 | description: "Performs a google web search for your query then returns a string of the top search results.", 34 | }, 35 | api_key, 36 | } 37 | } 38 | 39 | fn forward(&self, query: &str, filter_year: Option<&str>) -> String { 40 | let params = { 41 | let mut params = json!({ 42 | "engine": "google", 43 | "q": query, 44 | "api_key": self.api_key, 45 | "google_domain": "google.com", 46 | }); 47 | 48 | if let Some(year) = filter_year { 49 | params["tbs"] = json!(format!("cdr:1,cd_min:01/01/{},cd_max:12/31/{}", year, year)); 50 | } 51 | 52 | params 53 | }; 54 | 55 | let client = reqwest::blocking::Client::new(); 56 | let response = client 57 | .get("https://serpapi.com/search.json") 58 | .query(&params) 59 | .send(); 60 | match response { 61 | Ok(resp) => { 62 | if resp.status().is_success() { 63 | let results: serde_json::Value = resp.json().unwrap(); 64 | if results.get("organic_results").is_none() { 65 | if filter_year.is_some() { 66 | return format!("'organic_results' key not found for query: '{}' with filtering on year={}.
Use a less restrictive query or do not filter on year.", query, filter_year.unwrap()); 67 | } else { 68 | return format!("'organic_results' key not found for query: '{}'. Use a less restrictive query.", query); 69 | } 70 | } 71 | 72 | let organic_results = 73 | results.get("organic_results").unwrap().as_array().unwrap(); 74 | if organic_results.is_empty() { 75 | let _ = if filter_year.is_some() { 76 | format!(" with filter year={}", filter_year.unwrap()) 77 | } else { 78 | "".to_string() 79 | }; 80 | return format!("No results found for '{}'. Try with a more general query, or remove the year filter.", query); 81 | } 82 | 83 | let mut web_snippets = Vec::new(); 84 | for (idx, page) in organic_results.iter().enumerate() { 85 | let date_published = page.get("date").map_or("".to_string(), |d| { 86 | format!("\nDate published: {}", d.as_str().unwrap_or("")) 87 | }); 88 | let source = page.get("source").map_or("".to_string(), |s| { 89 | format!("\nSource: {}", s.as_str().unwrap_or("")) 90 | }); 91 | let snippet = page.get("snippet").map_or("".to_string(), |s| { 92 | format!("\n{}", s.as_str().unwrap_or("")) 93 | }); 94 | 95 | let redacted_version = format!( 96 | "{}. 
[{}]({}){}{}\n{}", 97 | idx, 98 | page.get("title").unwrap().as_str().unwrap(), 99 | page.get("link").unwrap().as_str().unwrap(), 100 | date_published, 101 | source, 102 | snippet 103 | ); 104 | let redacted_version = 105 | redacted_version.replace("Your browser can't play this video.", ""); 106 | web_snippets.push(redacted_version); 107 | } 108 | 109 | format!("## Search Results\n{}", web_snippets.join("\n\n")) 110 | } else { 111 | format!( 112 | "Failed to fetch search results: HTTP {}, Error: {}", 113 | resp.status(), 114 | resp.text().unwrap() 115 | ) 116 | } 117 | } 118 | Err(e) => format!("Failed to make the request: {}", e), 119 | } 120 | } 121 | } 122 | 123 | impl Tool for GoogleSearchTool { 124 | type Params = GoogleSearchToolParams; 125 | fn name(&self) -> &'static str { 126 | self.tool.name 127 | } 128 | fn description(&self) -> &'static str { 129 | self.tool.description 130 | } 131 | 132 | fn forward(&self, arguments: GoogleSearchToolParams) -> Result { 133 | let query = arguments.query; 134 | let filter_year = arguments.filter_year; 135 | Ok(self.forward(&query, filter_year.as_deref())) 136 | } 137 | } 138 | 139 | #[cfg(test)] 140 | mod tests { 141 | use super::*; 142 | 143 | #[test] 144 | fn test_google_search_tool() { 145 | let tool = GoogleSearchTool::new(None); 146 | let query = "What is the capital of France?"; 147 | let result = tool.forward(query, None); 148 | assert!(result.contains("Paris")); 149 | } 150 | } 151 | -------------------------------------------------------------------------------- /src/tools/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the tools that can be used in an agent. These are the default tools that are available. 2 | //! You can also implement your own tools by implementing the `Tool` trait. 
3 | 4 | pub mod base; 5 | pub mod ddg_search; 6 | pub mod final_answer; 7 | pub mod google_search; 8 | pub mod tool_traits; 9 | pub mod visit_website; 10 | 11 | #[cfg(feature = "code-agent")] 12 | pub mod python_interpreter; 13 | 14 | pub use base::*; 15 | pub use ddg_search::*; 16 | pub use final_answer::*; 17 | pub use google_search::*; 18 | pub use tool_traits::*; 19 | pub use visit_website::*; 20 | 21 | #[cfg(feature = "code-agent")] 22 | pub use python_interpreter::*; 23 | -------------------------------------------------------------------------------- /src/tools/python_interpreter.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the Python interpreter tool. The model uses this tool to evaluate python code. 2 | 3 | use std::collections::HashMap; 4 | 5 | use schemars::JsonSchema; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | use super::base::BaseTool; 9 | use super::tool_traits::Tool; 10 | use crate::local_python_interpreter::evaluate_python_code; 11 | use anyhow::Result; 12 | 13 | #[derive(Deserialize, JsonSchema)] 14 | #[schemars(title = "PythonInterpreterToolParams")] 15 | pub struct PythonInterpreterToolParams { 16 | #[schemars( 17 | description = "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, 18 | else you will get an error. 19 | This code can only import the following python libraries: 20 | collections, datetime, itertools, math, queue, random, re, stat, statistics, time, unicodedata" 21 | )] 22 | code: String, 23 | } 24 | #[derive(Debug, Serialize, Default, Clone)] 25 | pub struct PythonInterpreterTool { 26 | pub tool: BaseTool, 27 | } 28 | 29 | impl PythonInterpreterTool { 30 | pub fn new() -> Self { 31 | PythonInterpreterTool { 32 | tool: BaseTool { 33 | name: "python_interpreter", 34 | description: "This is a tool that evaluates python code. It can be used to perform calculations." 
35 | }} 36 | } 37 | } 38 | 39 | impl Tool for PythonInterpreterTool { 40 | type Params = PythonInterpreterToolParams; 41 | fn name(&self) -> &'static str { 42 | self.tool.name 43 | } 44 | fn description(&self) -> &'static str { 45 | self.tool.description 46 | } 47 | fn forward(&self, arguments: PythonInterpreterToolParams) -> Result { 48 | let result = evaluate_python_code(&arguments.code, vec![], &mut HashMap::new()); 49 | match result { 50 | Ok(result) => Ok(format!("Evaluation Result: {}", result)), 51 | Err(e) => Err(anyhow::anyhow!("Error evaluating code: {}", e)), 52 | } 53 | } 54 | } 55 | -------------------------------------------------------------------------------- /src/tools/tool_traits.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the traits for tools that can be used in an agent. 2 | 3 | use anyhow::Result; 4 | use schemars::gen::SchemaSettings; 5 | use schemars::schema::RootSchema; 6 | use schemars::JsonSchema; 7 | use serde::de::DeserializeOwned; 8 | use serde::Serialize; 9 | use serde_json::json; 10 | use std::fmt::Debug; 11 | 12 | use crate::errors::{AgentError, AgentExecutionError}; 13 | use crate::models::openai::FunctionCall; 14 | 15 | /// A trait for parameters that can be used in a tool. This defines the arguments that can be passed to the tool. 16 | pub trait Parameters: DeserializeOwned + JsonSchema {} 17 | 18 | /// A trait for tools that can be used in an agent. 19 | pub trait Tool: Debug { 20 | type Params: Parameters; 21 | /// The name of the tool. 22 | fn name(&self) -> &'static str; 23 | /// The description of the tool. 24 | fn description(&self) -> &'static str; 25 | /// The function to call when the tool is used. 
26 | fn forward(&self, arguments: Self::Params) -> Result<String>; 27 | } 28 | 29 | #[derive(serde::Serialize, serde::Deserialize, Debug)] 30 | pub enum ToolType { 31 | #[serde(rename = "function")] 32 | Function, 33 | } 34 | 35 | /// A struct that contains information about a tool. This is used to serialize the tool for the API. 36 | #[derive(Serialize, Debug)] 37 | pub struct ToolInfo { 38 | #[serde(rename = "type")] 39 | tool_type: ToolType, 40 | pub function: ToolFunctionInfo, 41 | } 42 | /// This struct contains information about the function to call when the tool is used. 43 | #[derive(Serialize, Debug)] 44 | pub struct ToolFunctionInfo { 45 | pub name: &'static str, 46 | pub description: &'static str, 47 | pub parameters: RootSchema, 48 | } 49 | 50 | impl ToolInfo { 51 | pub fn new<P: Parameters, T: AnyTool>(tool: &T) -> Self { 52 | let mut settings = SchemaSettings::draft07(); 53 | settings.inline_subschemas = true; 54 | let generator = settings.into_generator(); 55 | 56 | let parameters = generator.into_root_schema_for::<P>
(); 57 | 58 | Self { 59 | tool_type: ToolType::Function, 60 | function: ToolFunctionInfo { 61 | name: tool.name(), 62 | description: tool.description(), 63 | parameters, 64 | }, 65 | } 66 | } 67 | 68 | pub fn get_parameter_names(&self) -> Vec { 69 | if let Some(schema) = &self.function.parameters.schema.object { 70 | return schema.properties.keys().cloned().collect(); 71 | } 72 | Vec::new() 73 | } 74 | } 75 | 76 | pub fn get_json_schema(tool: &ToolInfo) -> serde_json::Value { 77 | json!(tool) 78 | } 79 | 80 | pub trait ToolGroup: Debug { 81 | fn call(&self, arguments: &FunctionCall) -> Result; 82 | fn tool_info(&self) -> Vec; 83 | } 84 | 85 | impl ToolGroup for Vec> { 86 | fn call(&self, arguments: &FunctionCall) -> Result { 87 | let tool = self.iter().find(|tool| tool.name() == arguments.name); 88 | if let Some(tool) = tool { 89 | let p = arguments.arguments.clone(); 90 | return tool.forward_json(p); 91 | } 92 | Err(AgentError::Execution("Tool not found".to_string())) 93 | } 94 | fn tool_info(&self) -> Vec { 95 | self.iter().map(|tool| tool.tool_info()).collect() 96 | } 97 | } 98 | 99 | pub trait AnyTool: Debug { 100 | fn name(&self) -> &'static str; 101 | fn description(&self) -> &'static str; 102 | fn forward_json(&self, json_args: serde_json::Value) -> Result; 103 | fn tool_info(&self) -> ToolInfo; 104 | fn clone_box(&self) -> Box; 105 | } 106 | 107 | impl AnyTool for T { 108 | fn name(&self) -> &'static str { 109 | Tool::name(self) 110 | } 111 | 112 | fn description(&self) -> &'static str { 113 | Tool::description(self) 114 | } 115 | 116 | fn forward_json(&self, json_args: serde_json::Value) -> Result { 117 | let params = serde_json::from_value::(json_args.clone()).map_err(|e| { 118 | AgentError::Parsing(format!( 119 | "Error when executing tool with arguments: {:?}: {}. 
As a reminder, this tool's description is: {} and takes inputs: {}", 120 | json_args, 121 | e.to_string(), 122 | self.description(), 123 | json!(&self.tool_info().function.parameters.schema)["properties"].to_string() 124 | )) 125 | })?; 126 | Tool::forward(self, params).map_err(|e| AgentError::Execution(e.to_string())) 127 | } 128 | 129 | fn tool_info(&self) -> ToolInfo { 130 | ToolInfo::new::(self) 131 | } 132 | 133 | fn clone_box(&self) -> Box { 134 | Box::new(self.clone()) 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /src/tools/visit_website.rs: -------------------------------------------------------------------------------- 1 | //! This module contains the visit website tool. The model uses this tool to visit a webpage and read its content as a markdown string. 2 | 3 | use htmd::HtmlToMarkdown; 4 | use reqwest::Url; 5 | use schemars::JsonSchema; 6 | use serde::{Deserialize, Serialize}; 7 | 8 | use super::{base::BaseTool, tool_traits::Tool}; 9 | use anyhow::Result; 10 | 11 | #[derive(Debug, Serialize, Default, Clone)] 12 | pub struct VisitWebsiteTool { 13 | pub tool: BaseTool, 14 | } 15 | 16 | impl VisitWebsiteTool { 17 | pub fn new() -> Self { 18 | VisitWebsiteTool { 19 | tool: BaseTool { 20 | name: "visit_website", 21 | description: "Visits a webpage at the given url and reads its content as a markdown string. 
Use this to browse webpages", 22 | }, 23 | } 24 | } 25 | 26 | pub fn forward(&self, url: &str) -> String { 27 | let client = reqwest::blocking::Client::builder() 28 | .user_agent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36") 29 | .build() 30 | .unwrap_or_else(|_| reqwest::blocking::Client::new()); 31 | let url = match Url::parse(url) { 32 | Ok(url) => url, 33 | Err(_) => Url::parse(&format!("https://{}", url)).unwrap(), 34 | }; 35 | 36 | let response = client.get(url.clone()).send(); 37 | 38 | match response { 39 | Ok(resp) => { 40 | if resp.status().is_success() { 41 | match resp.text() { 42 | Ok(text) => { 43 | let converter = HtmlToMarkdown::builder() 44 | .skip_tags(vec!["script", "style", "header", "nav", "footer"]) 45 | .build(); 46 | converter.convert(&text).unwrap() 47 | } 48 | Err(_) => "Failed to read response text".to_string(), 49 | } 50 | } else if resp.status().as_u16() == 999 { 51 | "The website appears to be blocking automated access. 
Try visiting the URL directly in your browser.".to_string() 52 | } else { 53 | format!( 54 | "Failed to fetch the webpage {}: HTTP {} - {}", 55 | url, 56 | resp.status(), 57 | resp.status().canonical_reason().unwrap_or("Unknown Error") 58 | ) 59 | } 60 | } 61 | Err(e) => format!("Failed to make the request to {}: {}", url, e), 62 | } 63 | } 64 | } 65 | 66 | #[derive(Deserialize, JsonSchema)] 67 | #[schemars(title = "VisitWebsiteToolParams")] 68 | pub struct VisitWebsiteToolParams { 69 | #[schemars(description = "The url of the website to visit")] 70 | url: String, 71 | } 72 | 73 | impl Tool for VisitWebsiteTool { 74 | type Params = VisitWebsiteToolParams; 75 | fn name(&self) -> &'static str { 76 | self.tool.name 77 | } 78 | 79 | fn description(&self) -> &'static str { 80 | self.tool.description 81 | } 82 | 83 | fn forward(&self, arguments: VisitWebsiteToolParams) -> Result { 84 | let url = arguments.url; 85 | Ok(self.forward(&url)) 86 | } 87 | } 88 | 89 | #[cfg(test)] 90 | mod tests { 91 | use super::*; 92 | 93 | #[test] 94 | fn test_visit_website_tool() { 95 | let tool = VisitWebsiteTool::new(); 96 | let url = "https://finance.yahoo.com/quote/NVDA"; 97 | let _result = tool.forward(&url); 98 | println!("{}", _result); 99 | } 100 | } 101 | --------------------------------------------------------------------------------