├── .gitignore ├── LICENSE ├── README.md ├── dist ├── taskflowai-0.5.13-py3-none-any.whl └── taskflowai-0.5.13.tar.gz ├── poetry.lock ├── pyproject.toml ├── research_agent.py ├── taskflowai-multi-agent-team.png └── taskflowai ├── __init__.py ├── agent.py ├── knowledgebases ├── __init__.py └── faiss_knowledgebase.py ├── llm.py ├── task.py ├── tools ├── __init__.py ├── amadeus_tools.py ├── audio_tools.py ├── calculator_tools.py ├── conversation_tools.py ├── embedding_tools.py ├── faiss_tools.py ├── file_tools.py ├── fred_tools.py ├── github_tools.py ├── langchain_tools.py ├── matplotlib_tools.py ├── pinecone_tools.py ├── semantic_splitter.py ├── sentence_splitter.py ├── text_splitters.py ├── web_tools.py ├── wikipedia_tools.py └── yahoo_finance_tools.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Virtual Environment 2 | venv/ 3 | env/ 4 | ENV/ 5 | .env 6 | .venv 7 | pythonenv* 8 | 9 | # Python cache files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | .pytest_cache/ 14 | .coverage 15 | htmlcov/ 16 | .tox/ 17 | .nox/ 18 | 19 | # Distribution / packaging 20 | dist/ 21 | build/ 22 | *.egg-info/ 23 | *.egg 24 | MANIFEST 25 | 26 | # IDEs and editors 27 | .idea/ 28 | .vscode/ 29 | *.swp 30 | *.swo 31 | *~ 32 | .DS_Store 33 | .project 34 | .pydevproject 35 | .settings/ 36 | *.sublime-workspace 37 | *.sublime-project 38 | 39 | # Jupyter Notebook 40 | .ipynb_checkpoints 41 | *.ipynb 42 | 43 | # Local development settings 44 | *.env 45 | .env.local 46 | .env.*.local 47 | .env.development 48 | .env.test 49 | .env.production 50 | 51 | # Logs and databases 52 | *.log 53 | *.sqlite 54 | *.db 55 | *.sql 56 | 57 | # Unit test / coverage reports 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | nosetests.xml 63 | 64 | # mypy 65 | .mypy_cache/ 66 | .dmypy.json 67 | dmypy.json 68 | 69 | # Rope project settings 70 | .ropeproject 71 | 72 | # mkdocs documentation 73 | /site 74 | 75 | # Celery 76 | celerybeat-schedule 77 | celerybeat.pid 78 | 79 | # Spyder project settings 80 | .spyderproject 81 | .spyproject -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2024 Philippe Andre Page and TaskflowAI Contributors 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | 15 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 16 | 17 | 1. Definitions. 18 | “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. 19 | 20 | “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 21 | 22 | “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License. 25 | 26 | “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 27 | 28 | “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 29 | 30 | “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). 31 | 32 | “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 33 | 34 | “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.” 35 | 36 | “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 37 | 38 | 2. Grant of Copyright License. 39 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 40 | 41 | 3. Grant of Patent License. 
42 | Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 43 | 44 | 4. Redistribution. 45 | You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: 46 | 47 | You must give any other recipients of the Work or Derivative Works a copy of this License; and 48 | You must cause any modified files to carry prominent notices stating that You changed the files; and 49 | You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and 50 | If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. 51 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 52 | 53 | 5. Submission of Contributions. 54 | Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 55 | 56 | 6. Trademarks. 
57 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 58 | 59 | 7. Disclaimer of Warranty. 60 | Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 61 | 62 | 8. Limitation of Liability. 63 | In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 64 | 65 | 9. Accepting Warranty or Additional Liability. 66 | While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 67 | 68 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) 2 | [![PRs Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat-square)](https://github.com/philippe-page/taskflowai/issues) 3 | [![Downloads](https://static.pepy.tech/badge/taskflowai)](https://pepy.tech/project/taskflowai) 4 | [![PyPI version](https://badge.fury.io/py/taskflowai.svg)](https://pypi.org/project/taskflowai/) 5 | [![Twitter](https://img.shields.io/twitter/follow/philippe__page?label=Follow%20@philippe__page&style=social)](https://twitter.com/philippe__page) 6 | 7 | 8 | # TaskflowAI: Task-Centric Framework for LLM-Driven Pipelines and Multi-Agent Teams 9 | 10 | 11 | TaskflowAI is a lightweight, intuitive, and flexible framework for creating AI-driven task pipelines and multi-agent teams. Centered around the concept of Tasks, rather than conversation patterns, it enables the design and orchestration of autonomous workflows while balancing flexibility and reliability. 
12 | 13 | ## Key Features 14 | #### 🧠 Task-centric design aligning closely with real-world operational processes 15 | 16 | #### 🧩 Modular architecture for easy building, extension, and integration 17 | 18 | #### 🌐 Flexible workflows allow you to design everything from deterministic pipelines to autonomous multi-agent teams 19 | 20 | #### 📈 The framework's complexity floor starts low, scaling from simple deterministic pipelines up to complex multi-agent teams 21 | 22 | #### 💬 Support for hundreds of language models (OpenAI, Anthropic, OpenRouter, Groq, and local models with Ollama) 23 | 24 | #### 🛠️ Comprehensive and extendable toolset for web interaction, file operations, embeddings generation, and more 25 | 26 | #### 🔍 Transparency through detailed logging and state exposure 27 | 28 | #### ⚡️ Lightweight core with minimal dependencies 29 | 30 | ## Installation 31 | 32 | Install TaskflowAI using pip: 33 | 34 | ```bash 35 | pip install taskflowai 36 | ``` 37 | 38 | ## Quick Start 39 | 40 | Here's a simple example to get you started: 41 | 42 | ```python 43 | from taskflowai import Agent, Task, OpenaiModels, WebTools, set_verbosity 44 | 45 | set_verbosity(1) 46 | 47 | research_agent = Agent( 48 | role="research assistant", 49 | goal="answer user queries", 50 | llm=OpenaiModels.gpt_4o, 51 | tools={WebTools.exa_search} 52 | ) 53 | 54 | def research_task(topic): 55 | return Task.create( 56 | agent=research_agent, 57 | instruction=f"Use your exa search tool to research {topic} and explain it in a way that is easy to understand.", 58 | ) 59 | 60 | result = research_task("quantum computing") 61 | print(result) 62 | ``` 63 | 64 | ## Core Components 65 | 66 | **Tasks**: Discrete units of work 67 | 68 | **Agents**: Personas that perform tasks and can be assigned tools 69 | 70 | **Tools**: Wrappers around external services or specific functionalities 71 | 72 | **Language Model Interfaces**: Consistent interface for various LLM providers 73 | 74 | ## Supported Language Models and Providers 75 | 76 | TaskflowAI supports a wide range of language models from a number of providers: 77 | 78 | ### OpenAI 79 | GPT-4 Turbo, GPT-3.5 Turbo, GPT-4, GPT-4o, GPT-4o Mini, & more 80 | 81 | ### Anthropic 82 | Claude 3 Haiku, Claude 3 Sonnet, Claude 3 Opus, Claude 3.5 Sonnet, & more 83 | 84 | ### OpenRouter 85 | GPT-4 Turbo, Claude 3 Opus, Mixtral 8x7B, Llama 3.1 405B, & more 86 | 87 | ### Ollama 88 | Mistral, Mixtral, Llama 3.1, Qwen, Gemma, & more 89 | 90 | ### Groq 91 | Mixtral 8x7B, Llama 3, Llama 3.1, Gemma, & more 92 | 93 | Each provider is accessible through a dedicated class (e.g., `OpenaiModels`, `AnthropicModels`, etc.) with methods corresponding to specific models. This structure allows for painless switching between models and providers, enabling users to leverage the most suitable LLM for their tasks. 94 | 95 | ## Tools 96 | 97 | TaskflowAI comes with a set of built-in tools that provide a wide range of functionalities, skills, actions, and knowledge for your agents to use in their task completion. 98 | 99 | - WebTools: For web scraping, searches, and data retrieval with Serper, Exa, WeatherAPI, etc. 100 | - FileTools: Handling various file operations like reading CSV, JSON, and XML files. 101 | - GitHubTools: Interacting with GitHub repositories, including listing contributors and fetching repository contents. 102 | - CalculatorTools: Performing basic math and date/time calculations. 103 | - EmbeddingsTools: Generating embeddings for text. 104 | - WikipediaTools: Searching and retrieving information from Wikipedia.
105 | - AmadeusTools: Searching for flight information. 106 | - LangchainTools: A wrapper that lets agents use tools from the Langchain catalog. 107 | - Custom Tools: You can also create your own custom tools to add any functionality you need (a short sketch is included in the appendix at the end of this README). 108 | 109 | ## Multi-Agent Teams 110 | 111 | TaskflowAI allows you to create multi-agent teams that can use tools to complete a series of tasks. Here's an example of a travel planning workflow that uses multiple agents to research and plan a trip: 112 | 113 | ```python 114 | from taskflowai import Agent, Task, WebTools, WikipediaTools, AmadeusTools, OpenaiModels, OpenrouterModels, set_verbosity 115 | 116 | set_verbosity(1) 117 | 118 | web_research_agent = Agent( 119 | role="web research agent", 120 | goal="search the web thoroughly for travel information", 121 | attributes="hardworking, diligent, thorough, comprehensive.", 122 | llm=OpenrouterModels.haiku, 123 | tools={WebTools.serper_search, WikipediaTools.search_articles, WikipediaTools.search_images} 124 | ) 125 | 126 | travel_agent = Agent( 127 | role="travel agent", 128 | goal="assist the traveller with their request", 129 | attributes="friendly, hardworking, and comprehensive and extensive in reporting back to users", 130 | llm=OpenrouterModels.haiku, 131 | tools={AmadeusTools.search_flights} 132 | ) 133 | 134 | def research_destination(destination, interests): 135 | destination_report = Task.create( 136 | agent=web_research_agent, 137 | context=f"User Destination: {destination}\nUser Interests: {interests}", 138 | instruction=f"Use your tools to search for relevant information about the given destination: {destination}. Use your serper web search tool to research information about the destination to write a comprehensive report. Use wikipedia tools to search the destination's wikipedia page, as well as images of the destination. In your final answer you should write a comprehensive report about the destination with images embedded in markdown." 139 | ) 140 | return destination_report 141 | 142 | def research_events(destination, dates, interests): 143 | events_report = Task.create( 144 | agent=web_research_agent, 145 | context=f"User's intended destination: {destination}\n\nUser's intended dates of travel: {dates}\nUser Interests: {interests}", 146 | instruction="Use your tools to research events in the given location for the given date span. Ensure your report is a comprehensive report on events in the area for that time period." 147 | ) 148 | return events_report 149 | 150 | def search_flights(current_location, destination, dates): 151 | flight_report = Task.create( 152 | agent=travel_agent, 153 | context=f"Current Location: {current_location}\n\nDestination: {destination}\nDate Range: {dates}", 154 | instruction="Search for flights in the given date range to collect a range of options, and return a report on the best options based on convenience and lowest price." 155 | ) 156 | return flight_report 157 | 158 | def write_travel_report(destination_report, events_report, flight_report): 159 | travel_report = Task.create( 160 | agent=travel_agent, 161 | context=f"Destination Report: {destination_report}\n--------\n\nEvents Report: {events_report}\n--------\n\nFlight Report: {flight_report}", 162 | instruction="Write a comprehensive travel plan and report given the information above. Ensure your report conveys all the detail in the given information, from flight options, to events, and image urls, etc.
Preserve detail and write your report in extensive length." 163 | ) 164 | return travel_report 165 | 166 | def main(): 167 | current_location = input("Where are you traveling from?\n") 168 | destination = input("Where are you traveling to?\n") 169 | dates = input("What are the dates for your trip?\n") 170 | interests = input("Do you have any particular interests?\n") 171 | 172 | destination_report = research_destination(destination, interests) 173 | print(destination_report) 174 | 175 | events_report = research_events(destination, dates, interests) 176 | print(events_report) 177 | 178 | flight_report = search_flights(current_location, destination, dates) 179 | print(flight_report) 180 | 181 | final_report = write_travel_report(destination_report, events_report, flight_report) 182 | print(final_report) 183 | 184 | if __name__ == "__main__": 185 | main() 186 | ``` 187 | 188 | By combining agents, tasks, tools, and language models, you can create a wide range of workflows, from simple pipelines to complex multi-agent teams. 189 | 190 | ## Documentation 191 | 192 | For more detailed information, tutorials, and advanced usage, visit our [documentation](https://taskflowai.org). 193 | 194 | ## Contributing 195 | 196 | TaskflowAI depends on and welcomes community contributions! Please review the contribution guidelines and submit a pull request if you'd like to contribute. 197 | 198 | ## License 199 | 200 | TaskflowAI is released under the Apache License 2.0. See the [LICENSE](LICENSE) file for details. 201 | 202 | ## Support 203 | 204 | For issues or questions, please file an issue on our [GitHub repository](https://github.com/philippe-page/taskflowai/issues). 205 | 206 | ⭐️ If you find TaskflowAI helpful, please consider giving it a star! 207 | 208 | Happy building!
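## Appendix: Writing a Custom Tool

An agent's `tools` parameter is simply a set of Python callables, so a custom tool can be any plain function with a descriptive docstring that tells the agent what it does and what arguments it expects. The sketch below is illustrative rather than a prescribed convention: the `get_word_count` function, the agent, and the task are invented for this example, and they assume the same `Agent` and `Task.create` usage shown in the Quick Start above (including an `OPENAI_API_KEY` in your environment).

```python
from taskflowai import Agent, Task, OpenaiModels

def get_word_count(text: str) -> str:
    """
    Count the number of words in a piece of text.

    Args:
        text (str): The text to analyze.

    Returns:
        str: A sentence reporting the word count.
    """
    return f"The text contains {len(text.split())} words."

editor_agent = Agent(
    role="editor",
    goal="review drafts and report simple statistics",
    llm=OpenaiModels.gpt_4o,
    tools={get_word_count}  # any callable can be registered as a tool
)

report = Task.create(
    agent=editor_agent,
    context="Draft: TaskflowAI keeps the complexity floor low and scales up to multi-agent teams.",
    instruction="Use your word count tool on the draft and report the result."
)
print(report)
```

As with the built-in tools, the docstring is what the agent sees when deciding whether and how to call the function, so keeping it short and explicit about arguments and return values pays off.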
209 | -------------------------------------------------------------------------------- /dist/taskflowai-0.5.13-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/philippe-page/taskflowai/e1c16b0832f10231c719dfcac768a1fc05f0ee53/dist/taskflowai-0.5.13-py3-none-any.whl -------------------------------------------------------------------------------- /dist/taskflowai-0.5.13.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/philippe-page/taskflowai/e1c16b0832f10231c719dfcac768a1fc05f0ee53/dist/taskflowai-0.5.13.tar.gz -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "taskflowai" 3 | version = "0.5.13" 4 | description = "TaskFlowAI is a lightweight, open-source python framework for building LLM based pipelines and multi-agent teams" 5 | authors = ["Philippe Andre Page "] 6 | readme = "README.md" 7 | packages = [{include = "taskflowai"}] 8 | license = "Apache 2.0" 9 | classifiers = [ 10 | "Development Status :: 4 - Beta", 11 | "Intended Audience :: Developers", 12 | "License :: OSI Approved :: Apache Software License", 13 | "Operating System :: OS Independent", 14 | "Programming Language :: Python :: 3", 15 | "Programming Language :: Python :: 3.10", 16 | "Programming Language :: Python :: 3.11", 17 | "Programming Language :: Python :: 3.12", 18 | ] 19 | 20 | [tool.poetry.dependencies] 21 | python = "^3.10" 22 | requests = "*" 23 | pydantic = ">=2.0" 24 | anthropic = "*" 25 | openai = ">=1.0" 26 | cohere = "*" 27 | beautifulsoup4 = "*" 28 | tqdm = "*" 29 | python-dotenv = "*" 30 | PyYAML = "*" 31 | ollama = "*" 32 | lxml = "*" 33 | halo = "*" 34 | groq = "*" 35 | numpy = "*" 36 | elevenlabs = "*" 37 | faiss-cpu = "*" 38 | pyyaml = "*" 39 | fredapi = "*" 40 | yfinance = "*" 41 | yahoofinance = "*" 42 | pinecone = "*" 43 | sentence_splitter = "*" 44 | igraph = "*" 45 | leidenalg = "*" 46 | fake-useragent = "*" 47 | waitress = "*" 48 | 49 | [tool.poetry.extras] 50 | langchain_tools = [ 51 | "langchain-core", 52 | "langchain-community", 53 | "langchain-openai" 54 | ] 55 | matplotlib_tools = [ 56 | "matplotlib" 57 | ] 58 | yahoo_finance_tools = [ 59 | "yfinance", 60 | "yahoofinance", 61 | "pandas" 62 | ] 63 | fred_tools = [ 64 | "fredapi", 65 | "pandas" 66 | ] 67 | 68 | [tool.poetry.dependencies.langchain-core] 69 | version = "*" 70 | optional = true 71 | 72 | [tool.poetry.dependencies.langchain-community] 73 | version = "*" 74 | optional = true 75 | 76 | [tool.poetry.dependencies.langchain-openai] 77 | version = "*" 78 | optional = true 79 | 80 | [tool.poetry.dependencies.matplotlib] 81 | version = "*" 82 | optional = true 83 | 84 | [tool.poetry.group.dev.dependencies] 85 | pytest = ">=8.0.0,<9.0.0" 86 | black = ">=23.0,<25.0" 87 | isort = ">=5.0,<6.0" 88 | mypy = ">=1.0,<2.0" 89 | 90 | [build-system] 91 | requires = ["poetry-core"] 92 | build-backend = "poetry.core.masonry.api" 93 | 94 | [tool.poetry.urls] 95 | "Homepage" = "https://github.com/philippe-page/taskflowai/" 96 | "Bug Tracker" = "https://github.com/philippe-page/taskflowai/issues" 97 | "Documentation" = "https://taskflowai.org" 98 | 99 | [tool.black] 100 | line-length = 100 101 | target-version = ['py310'] 102 | 103 | [tool.isort] 104 | profile = "black" 105 | line_length = 100 106 | 107 | [tool.mypy] 108 | python_version = 
"3.10" 109 | strict = true 110 | ignore_missing_imports = true 111 | 112 | [tool.pytest.ini_options] 113 | minversion = "6.0" 114 | addopts = "-ra -q" 115 | testpaths = [ 116 | "tests", 117 | ] -------------------------------------------------------------------------------- /research_agent.py: -------------------------------------------------------------------------------- 1 | from taskflowai import Task, Agent, OpenrouterModels, WebTools 2 | 3 | # Requires SERPER_API_KEY and OPENROUTER_API_KEY 4 | 5 | agent = Agent( 6 | role="research assistant", 7 | goal="answer user queries", 8 | attributes="you're thorough in your web research and you write extensive reports on your research", 9 | llm=OpenrouterModels.haiku, 10 | tools={WebTools.serper_search} 11 | ) 12 | 13 | def create_research_task(user_query): 14 | return Task.create( 15 | agent=agent, 16 | instruction=f"Answer the following query: {user_query}" 17 | ) 18 | 19 | def main(): 20 | user_query = input("Enter your research query: ") 21 | response = create_research_task(user_query) 22 | print(f"\nResearch Assistant's Response:\n{response}\n") 23 | 24 | if __name__ == "__main__": 25 | main() 26 | -------------------------------------------------------------------------------- /taskflowai-multi-agent-team.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/philippe-page/taskflowai/e1c16b0832f10231c719dfcac768a1fc05f0ee53/taskflowai-multi-agent-team.png -------------------------------------------------------------------------------- /taskflowai/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | TaskFlowAI: A lightweight Python framework for building and orchestrating multi-agent systems powered by LLMs. 3 | """ 4 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 
5 | 6 | __version__ = "0.5.13" 7 | 8 | # Import main classes and core tools 9 | from .task import Task 10 | from .agent import Agent 11 | from .utils import Utils 12 | from .llm import OpenaiModels, AnthropicModels, OpenrouterModels, OllamaModels, GroqModels, TogetheraiModels, set_verbosity 13 | from .knowledgebases import FaissKnowledgeBase 14 | from .tools import ( 15 | FileTools, 16 | EmbeddingsTools, 17 | WebTools, 18 | GitHubTools, 19 | TextToSpeechTools, 20 | WhisperTools, 21 | WikipediaTools, 22 | AmadeusTools, 23 | CalculatorTools, 24 | ConversationTools, 25 | FAISSTools, 26 | PineconeTools, 27 | ) 28 | 29 | # Conditional imports for optional dependencies 30 | import sys 31 | import importlib 32 | from typing import TYPE_CHECKING 33 | 34 | if TYPE_CHECKING: 35 | from .tools.langchain_tools import LangchainTools 36 | from .tools.matplotlib_tools import MatplotlibTools 37 | from .tools.yahoo_finance_tools import YahooFinanceTools 38 | from .tools.fred_tools import FredTools 39 | 40 | def __getattr__(name): 41 | package_map = { 42 | 'LangchainTools': ('langchain_tools', ['langchain-core', 'langchain-community', 'langchain-openai']), 43 | 'MatplotlibTools': ('matplotlib_tools', ['matplotlib']), 44 | 'YahooFinanceTools': ('yahoo_finance_tools', ['yfinance']), 45 | 'FredTools': ('fred_tools', ['fredapi']) 46 | } 47 | 48 | if name in package_map: 49 | module_name, required_packages = package_map[name] 50 | try: 51 | for package in required_packages: 52 | importlib.import_module(package) 53 | 54 | # If successful, import and return the tool 55 | module = __import__(f'taskflowai.tools.{module_name}', fromlist=[name]) 56 | return getattr(module, name) 57 | except ImportError as e: 58 | print(f"\033[95mError: The required packages for {name} are not installed. " 59 | f"Please install them using 'pip install {' '.join(required_packages)}'.\n" 60 | f"Specific error: {str(e)}\033[0m") 61 | sys.exit(1) 62 | else: 63 | raise AttributeError(f"Module '{__name__}' has no attribute '{name}'") 64 | 65 | # List of all public attributes 66 | __all__ = [ 67 | "Task", 68 | "Agent", 69 | "Utils", 70 | "OpenaiModels", 71 | "AnthropicModels", 72 | "OpenrouterModels", 73 | "OllamaModels", 74 | "GroqModels", 75 | "TogetheraiModels", 76 | "set_verbosity", 77 | # List core tools 78 | "FileTools", 79 | "EmbeddingsTools", 80 | "WebTools", 81 | "GitHubTools", 82 | "TextToSpeechTools", 83 | "WhisperTools", 84 | "WikipediaTools", 85 | "AmadeusTools", 86 | "CalculatorTools", 87 | "ConversationTools", 88 | "FAISSTools", 89 | "PineconeTools", 90 | "FaissKnowledgeBase", 91 | # Add optional tools here for IDE recognition 92 | "LangchainTools", 93 | "MatplotlibTools", 94 | "YahooFinanceTools", 95 | "FredTools", 96 | ] -------------------------------------------------------------------------------- /taskflowai/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 
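# An Agent is a declarative Pydantic model: it bundles the persona (role, goal, attributes) with the LLM callable it should use, an optional set of tool callables, and generation settings (temperature, max_tokens).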
2 | 3 | from pydantic import BaseModel, Field 4 | from typing import Optional, Callable, Set 5 | 6 | class Agent(BaseModel): 7 | role: str = Field(..., description="The role or type of agent performing tasks") 8 | goal: str = Field(..., description="The objective or purpose of the agent") 9 | attributes: Optional[str] = Field(None, description="Additional attributes or characteristics of the agent") 10 | llm: Optional[Callable] = Field(None, description="The language model function to be used by the agent") 11 | tools: Optional[Set[Callable]] = Field(default=None, description="Optional set of tool functions") 12 | temperature: Optional[float] = Field(default=0.7, description="The temperature for the language model") 13 | max_tokens: Optional[int] = Field(default=4000, description="The maximum number of tokens for the language model") 14 | 15 | model_config = { 16 | "arbitrary_types_allowed": True 17 | } -------------------------------------------------------------------------------- /taskflowai/knowledgebases/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | from .faiss_knowledgebase import FaissKnowledgeBase 4 | 5 | __all__ = ['FaissKnowledgeBase'] 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /taskflowai/knowledgebases/faiss_knowledgebase.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | from typing import List, Dict, Any, Optional 4 | from ..tools import EmbeddingsTools 5 | from ..tools import FAISSTools 6 | import os 7 | import json 8 | import uuid 9 | 10 | class FaissKnowledgeBase: 11 | """ 12 | A knowledge base that uses FAISS for efficient similarity search. 13 | 14 | Parameters: 15 | - kb_name (str): The name of the knowledge base. 16 | - embedding_provider (str): The provider of the embeddings. 17 | - embedding_model (str): The model used for the embeddings. 18 | - load_from_index (str, optional): The path to load the index from. 19 | - chunks (List[str], optional): The chunks of text to initialize the knowledge base with. 20 | - save_to_filepath (str, optional): The path to save the index to. 21 | - **kwargs: Additional keyword arguments. 
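At present the only option read from **kwargs is "metric", the FAISS index metric passed through to FAISSTools (defaults to "IP").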
22 | 23 | Examples: 24 | >>> kb = FaissKnowledgeBase("default", "openai", "text-embedding-3-small") 25 | >>> kb = FaissKnowledgeBase("default", "openai", "text-embedding-3-small", load_from_index="index.faiss") 26 | >>> kb = FaissKnowledgeBase("default", "openai", "text-embedding-3-small", chunks=["chunk1", "chunk2"]) 27 | >>> kb = FaissKnowledgeBase("default", "openai", "text-embedding-3-small", chunks=["chunk1", "chunk2"], save_to_filepath="index.faiss") 28 | """ 29 | def __init__(self, kb_name: str = "default", 30 | embedding_provider: str = "openai", 31 | embedding_model: str = "text-embedding-3-small", 32 | load_from_index: Optional[str] = None, 33 | chunks: Optional[List[str]] = None, 34 | save_to_filepath: Optional[str] = None, 35 | **kwargs): 36 | self.kb_name = kb_name 37 | self.embedding_provider = embedding_provider 38 | self.embedding_model = embedding_model 39 | self.index_tool = None 40 | self.memories = {} # Dictionary to store memories with their IDs 41 | self.save_filepath = save_to_filepath 42 | 43 | try: 44 | import faiss 45 | import numpy as np 46 | except ModuleNotFoundError as e: 47 | raise ImportError(f"{e.name} is required for KnowledgeBase. Install with `pip install {e.name}`") 48 | 49 | self.faiss = faiss 50 | self.np = np 51 | 52 | if load_from_index: 53 | self.load_from_index(load_from_index) 54 | elif chunks: 55 | self.initialize_from_chunks(chunks, **kwargs) 56 | else: 57 | self.initialize_empty(**kwargs) 58 | 59 | if self.save_filepath: 60 | self.save_index(self.save_filepath) 61 | print(f"Index saved to {self.save_filepath}") 62 | 63 | def _format_memory(self, content: str, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 64 | """Helper method to format a memory with ID and metadata.""" 65 | return { 66 | "id": str(len(self.memories) + 1), # Simple incremental ID 67 | "content": content, 68 | "metadata": metadata or {} 69 | } 70 | 71 | def initialize_from_chunks(self, chunks: List[str], 72 | **kwargs) -> None: 73 | try: 74 | formatted_chunks = [self._format_memory(chunk) for chunk in chunks] 75 | self.memories = {chunk['id']: chunk for chunk in formatted_chunks} 76 | chunk_texts = [chunk['content'] for chunk in formatted_chunks] 77 | 78 | embeddings, _ = EmbeddingsTools.get_embeddings(chunk_texts, provider=self.embedding_provider, model=self.embedding_model) 79 | 80 | self.index_tool = FAISSTools(dimension=len(embeddings[0]), metric=kwargs.get("metric", "IP")) 81 | self.index_tool.create_index() 82 | self.index_tool.set_embedding_info(self.embedding_provider, self.embedding_model) 83 | 84 | np_vectors = self.np.array(embeddings).astype('float32') 85 | self.index_tool.add_vectors(np_vectors) 86 | 87 | self.index_tool.set_metadata('memories', list(self.memories.values())) 88 | 89 | except Exception as e: 90 | raise Exception(f"Error initializing knowledgebase: {str(e)}") 91 | 92 | def initialize_empty(self, **kwargs): 93 | try: 94 | dimension = EmbeddingsTools.get_model_dimension(self.embedding_provider, self.embedding_model) 95 | if not dimension: 96 | raise ValueError(f"Unsupported embedding model: {self.embedding_model} for provider: {self.embedding_provider}") 97 | 98 | self.index_tool = FAISSTools(dimension=dimension, metric=kwargs.get("metric", "IP")) 99 | self.index_tool.create_index() 100 | self.index_tool.set_embedding_info(self.embedding_provider, self.embedding_model) 101 | self.index_tool.set_metadata('chunks', []) 102 | except Exception as e: 103 | raise Exception(f"Error initializing empty knowledgebase: {str(e)}") 104 | 105 | def 
load_from_index(self, index_path: str) -> None: 106 | try: 107 | index_file = index_path 108 | metadata_file = f"{index_path}.metadata" 109 | 110 | if not os.path.exists(index_file): 111 | raise FileNotFoundError(f"Index file not found at {index_file}") 112 | 113 | if not os.path.exists(metadata_file): 114 | raise FileNotFoundError(f"Metadata file not found at {metadata_file}") 115 | 116 | self.index_tool = FAISSTools(dimension=1) # Dimension will be updated when loading 117 | self.index_tool.load_index(index_file) 118 | 119 | with open(metadata_file, 'r') as f: 120 | metadata = json.load(f) 121 | 122 | self.memories = {memory['id']: memory for memory in metadata.get('memories', [])} 123 | if not self.memories: 124 | raise ValueError(f"No memories found in metadata for '{self.kb_name}'") 125 | 126 | self.embedding_provider = self.index_tool.embedding_provider 127 | self.embedding_model = self.index_tool.embedding_model 128 | 129 | if not self.embedding_provider or not self.embedding_model: 130 | raise ValueError(f"No embedding provider or model found in metadata for '{self.kb_name}'") 131 | 132 | except Exception as e: 133 | raise Exception(f"Error loading knowledgebase: {str(e)}") 134 | 135 | def query(self, query: str, top_k: int = 6) -> List[Dict[str, Any]]: 136 | """ 137 | Query the knowledge base for the most relevant unique chunks. 138 | 139 | Args: 140 | query (str): The query to search for. 141 | top_k (int): The number of unique results to return. 142 | 143 | Returns: 144 | List[Dict[str, Any]]: A list of dictionaries, each containing the id, score, and content of a relevant unique chunk. 145 | """ 146 | if not self.memories: 147 | return [] 148 | try: 149 | if self.index_tool is None: 150 | raise ValueError(f"Knowledgebase '{self.kb_name}' is not initialized. Please initialize or load the knowledgebase first.") 151 | 152 | query_embedding, _ = EmbeddingsTools.get_embeddings([query], provider=self.embedding_provider, model=self.embedding_model) 153 | 154 | query_vector = self.np.array(query_embedding).astype('float32') 155 | 156 | # Increase the number of results to search for to ensure we have enough unique results 157 | distances, indices = self.index_tool.search_vectors( 158 | query_vectors=query_vector, top_k=top_k) 159 | 160 | formatted_results = [] 161 | seen_contents = set() 162 | 163 | for idx, dist in zip(indices[0], distances[0]): 164 | memory = list(self.memories.values())[idx] 165 | content = memory['content'] 166 | 167 | # Skip if we've already seen this content 168 | if content in seen_contents: 169 | continue 170 | 171 | seen_contents.add(content) 172 | formatted_results.append({ 173 | "id": memory['id'], 174 | "score": round(float(dist), 4), 175 | "content": content 176 | }) 177 | 178 | # Break if we have enough unique results 179 | if len(formatted_results) == top_k: 180 | break 181 | 182 | return formatted_results 183 | 184 | except Exception as e: 185 | raise Exception(f"Error querying knowledgebase: {str(e)}") 186 | 187 | def add_memory(self, memory: str, metadata: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: 188 | """ 189 | Add a memory to the knowledge base and save the updated index. 190 | 191 | Args: 192 | memory (str): The memory to add. 193 | 194 | Returns: 195 | Dict[str, Any]: A dictionary containing the success status and message. 196 | """ 197 | try: 198 | if self.index_tool is None: 199 | raise ValueError(f"Knowledgebase '{self.kb_name}' is not initialized. 
Please initialize or load the knowledgebase first.") 200 | 201 | memory_obj = self._format_memory(memory, metadata) 202 | memory_id = memory_obj['id'] 203 | 204 | new_embedding, _ = EmbeddingsTools.get_embeddings([memory], provider=self.embedding_provider, model=self.embedding_model) 205 | new_vector = self.np.array(new_embedding).astype('float32') 206 | 207 | self.index_tool.add_vectors(new_vector) 208 | self.memories[memory_id] = memory_obj 209 | self.index_tool.set_metadata('memories', list(self.memories.values())) 210 | 211 | if self.save_filepath: 212 | self.save_index(self.save_filepath) 213 | else: 214 | print("Warning: No save path set for the index. Changes are only in memory.") 215 | 216 | return {"success": True, "message": "Memory added successfully and index saved", "id": memory_id} 217 | except Exception as e: 218 | error_message = f"Error adding memory to knowledgebase: {str(e)}" 219 | print(error_message) 220 | return {"success": False, "message": error_message} 221 | 222 | def save_index(self, save_to_filepath: Optional[str] = None) -> None: 223 | try: 224 | if self.index_tool is None: 225 | raise ValueError(f"Knowledgebase '{self.kb_name}' is not initialized or loaded.") 226 | 227 | filepath = save_to_filepath or self.save_filepath 228 | if not filepath: 229 | raise ValueError("No filepath provided and no known filepath for the index.") 230 | 231 | self.index_tool.save_index(filepath) 232 | self.save_filepath = filepath 233 | #print(f"Index saved to {filepath}") 234 | except Exception as e: 235 | raise Exception(f"Error saving knowledgebase index: {str(e)}") 236 | -------------------------------------------------------------------------------- /taskflowai/tools/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | # Core tools 4 | from .amadeus_tools import AmadeusTools 5 | from .audio_tools import TextToSpeechTools, WhisperTools 6 | from .calculator_tools import CalculatorTools 7 | from .conversation_tools import ConversationTools 8 | from .embedding_tools import EmbeddingsTools 9 | from .file_tools import FileTools 10 | from .github_tools import GitHubTools 11 | from .faiss_tools import FAISSTools 12 | from .pinecone_tools import PineconeTools 13 | from .web_tools import WebTools 14 | from .wikipedia_tools import WikipediaTools 15 | 16 | __all__ = [ 17 | 'AmadeusTools', 18 | 'TextToSpeechTools', 19 | 'WhisperTools', 20 | 'CalculatorTools', 21 | 'ConversationTools', 22 | 'EmbeddingsTools', 23 | 'FAISSTools', 24 | 'FileTools', 25 | 'GitHubTools', 26 | 'PineconeTools', 27 | 'WebTools', 28 | 'WikipediaTools' 29 | ] 30 | 31 | # Helper function for optional imports 32 | def _optional_import(tool_name, install_name): 33 | class OptionalTool: 34 | def __init__(self, *args, **kwargs): 35 | raise ImportError( 36 | f"The tool '{tool_name}' requires additional dependencies. 
" 37 | f"Please install them using: 'pip install taskflowai[{install_name}]'" 38 | ) 39 | return OptionalTool 40 | 41 | # Conditional imports or placeholders 42 | try: 43 | from .langchain_tools import LangchainTools 44 | __all__.append('LangchainTools') 45 | except ImportError: 46 | LangchainTools = _optional_import('LangchainTools', 'langchain_tools') 47 | 48 | try: 49 | from .matplotlib_tools import MatplotlibTools 50 | __all__.append('MatplotlibTools') 51 | except ImportError: 52 | MatplotlibTools = _optional_import('MatplotlibTools', 'matplotlib_tools') 53 | 54 | try: 55 | from .yahoo_finance_tools import YahooFinanceTools 56 | __all__.append('YahooFinanceTools') 57 | except ImportError: 58 | YahooFinanceTools = _optional_import('YahooFinanceTools', 'yahoo_finance_tools') 59 | 60 | try: 61 | from .fred_tools import FredTools 62 | __all__.append('FredTools') 63 | except ImportError: 64 | FredTools = _optional_import('FredTools', 'fred_tools') 65 | 66 | -------------------------------------------------------------------------------- /taskflowai/tools/amadeus_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | from typing import Dict, Any, Optional, Tuple, Union 5 | from dotenv import load_dotenv 6 | import requests 7 | from datetime import datetime, timedelta 8 | 9 | class AmadeusTools: 10 | @staticmethod 11 | def _get_access_token(): 12 | """Get Amadeus API access token.""" 13 | load_dotenv() 14 | api_key = os.getenv("AMADEUS_API_KEY") 15 | api_secret = os.getenv("AMADEUS_API_SECRET") 16 | 17 | if not api_key or not api_secret: 18 | raise ValueError("AMADEUS_API_KEY and AMADEUS_API_SECRET must be set in .env file") 19 | 20 | token_url = "https://test.api.amadeus.com/v1/security/oauth2/token" 21 | data = { 22 | "grant_type": "client_credentials", 23 | "client_id": api_key, 24 | "client_secret": api_secret 25 | } 26 | 27 | response = requests.post(token_url, data=data) 28 | response.raise_for_status() 29 | return response.json()["access_token"] 30 | 31 | @staticmethod 32 | def search_flights( 33 | origin: str, 34 | destination: str, 35 | departure_date: str, 36 | return_date: Optional[str] = None, 37 | adults: int = 1, 38 | children: int = 0, 39 | infants: int = 0, 40 | travel_class: Optional[str] = None, 41 | non_stop: bool = False, 42 | currency: str = "USD", 43 | max_price: Optional[int] = None, 44 | max_results: int = 10 45 | ) -> Dict[str, Any]: 46 | """ 47 | Search for flight offers using Amadeus API. 
48 | 49 | Args: 50 | origin (str): Origin airport (e.g., "YYZ") 51 | destination (str): Destination airport (e.g., "CDG") 52 | departure_date (str): Departure date in YYYY-MM-DD format 53 | return_date (Optional[str]): Return date in YYYY-MM-DD format for round trips 54 | adults (int): Number of adult travelers 55 | children (int): Number of child travelers 56 | infants (int): Number of infant travelers 57 | travel_class (Optional[str]): Preferred travel class (ECONOMY, PREMIUM_ECONOMY, BUSINESS, FIRST) 58 | non_stop (bool): If True, search for non-stop flights only 59 | currency (str): Currency code for pricing (default: USD) 60 | max_price (Optional[int]): Maximum price per traveler 61 | max_results (int): Maximum number of results to return (default: 10) 62 | 63 | Returns: 64 | Dict[str, Any]: Flight search results 65 | """ 66 | access_token = AmadeusTools._get_access_token() 67 | 68 | url = "https://test.api.amadeus.com/v2/shopping/flight-offers" 69 | 70 | headers = { 71 | "Authorization": f"Bearer {access_token}", 72 | "Content-Type": "application/json" 73 | } 74 | 75 | params = { 76 | "originLocationCode": origin, 77 | "destinationLocationCode": destination, 78 | "departureDate": departure_date, 79 | "adults": adults, 80 | "children": children, 81 | "infants": infants, 82 | "currencyCode": currency, 83 | "max": max_results 84 | } 85 | 86 | if return_date: 87 | params["returnDate"] = return_date 88 | 89 | if travel_class: 90 | params["travelClass"] = travel_class 91 | 92 | if non_stop: 93 | params["nonStop"] = "true" 94 | 95 | if max_price: 96 | params["maxPrice"] = max_price 97 | 98 | try: 99 | response = requests.get(url, headers=headers, params=params) 100 | response.raise_for_status() 101 | return response.json() 102 | except requests.exceptions.RequestException as e: 103 | error_message = f"Error searching for flight offers: {str(e)}" 104 | if hasattr(e, 'response') and e.response is not None: 105 | error_message += f"\nResponse status code: {e.response.status_code}" 106 | error_message += f"\nResponse content: {e.response.text}" 107 | return error_message 108 | 109 | @staticmethod 110 | def get_cheapest_date( 111 | origin: str, 112 | destination: str, 113 | departure_date: Union[str, Tuple[str, str]], 114 | return_date: Optional[Union[str, Tuple[str, str]]] = None, 115 | adults: int = 1 116 | ) -> Dict[str, Any]: 117 | """ 118 | Find the cheapest flight offer for a given route and date or date range using the Amadeus Flight Offers Search API. 119 | Max range of 7 days between start and end date. 120 | 121 | Args: 122 | origin (str): IATA code of the origin airport. 123 | destination (str): IATA code of the destination airport. 124 | departure_date (Union[str, Tuple[str, str]]): Departure date in YYYY-MM-DD format or a tuple of (start_date, end_date). 125 | return_date (Optional[Union[str, Tuple[str, str]]]): Return date in YYYY-MM-DD format or a tuple of (start_date, end_date) for round trips. Defaults to None. 126 | adults (int): Number of adult travelers. Defaults to 1. 127 | 128 | Returns: 129 | Dict[str, Any]: A dictionary containing the cheapest flight offer information. 130 | 131 | Raises: 132 | requests.exceptions.HTTPError: If the API request fails. 133 | ValueError: If the date range is more than 7 days. 
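Example (illustrative route and dates): get_cheapest_date("JFK", "CDG", ("2025-05-01", "2025-05-05")) queries each departure date in the range and returns the cheapest offer found, or a dict with an "error" key if no matching flights are found.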
134 | """ 135 | access_token = AmadeusTools._get_access_token() 136 | 137 | url = "https://test.api.amadeus.com/v2/shopping/flight-offers" 138 | 139 | headers = { 140 | "Authorization": f"Bearer {access_token}", 141 | "Content-Type": "application/json" 142 | } 143 | 144 | def date_range(start_date: str, end_date: str): 145 | start = datetime.fromisoformat(start_date) 146 | end = datetime.fromisoformat(end_date) 147 | if (end - start).days > 7: 148 | return {"error": "Date range cannot exceed 7 days"} 149 | date = start 150 | while date <= end: 151 | yield date.strftime("%Y-%m-%d") 152 | date += timedelta(days=1) 153 | 154 | departure_dates = [departure_date] if isinstance(departure_date, str) else list(date_range(*departure_date)) 155 | return_dates = [return_date] if return_date and isinstance(return_date, str) else (list(date_range(*return_date)) if return_date else [None]) 156 | 157 | cheapest_offer = None 158 | cheapest_price = float('inf') 159 | 160 | for dep_date in departure_dates: 161 | for ret_date in return_dates: 162 | params = { 163 | "originLocationCode": origin, 164 | "destinationLocationCode": destination, 165 | "departureDate": dep_date, 166 | "adults": adults, 167 | "max": 1, 168 | "currencyCode": "USD" 169 | } 170 | if ret_date: 171 | params["returnDate"] = ret_date 172 | 173 | response = requests.get(url, headers=headers, params=params) 174 | response.raise_for_status() 175 | 176 | data = response.json() 177 | if data.get('data'): 178 | offer = data['data'][0] 179 | price = float(offer['price']['total']) 180 | if price < cheapest_price: 181 | cheapest_price = price 182 | cheapest_offer = offer 183 | 184 | if not cheapest_offer: 185 | return {"error": "No flights found for the given criteria"} 186 | 187 | result = { 188 | "price": cheapest_offer['price']['total'], 189 | "departureDate": cheapest_offer['itineraries'][0]['segments'][0]['departure']['at'], 190 | "airline": cheapest_offer['validatingAirlineCodes'][0], 191 | "details": cheapest_offer 192 | } 193 | 194 | if return_date: 195 | result["returnDate"] = cheapest_offer['itineraries'][-1]['segments'][0]['departure']['at'] 196 | 197 | return result 198 | 199 | @staticmethod 200 | def get_flight_inspiration( 201 | origin: str, 202 | max_price: Optional[int] = None, 203 | currency: str = "EUR" 204 | ) -> Dict[str, Any]: 205 | """ 206 | Get flight inspiration using the Flight Inspiration Search API. 207 | 208 | This method uses the Amadeus Flight Inspiration Search API to find travel destinations 209 | based on the origin city and optional price constraints. 210 | 211 | Args: 212 | origin (str): IATA code of the origin city. 213 | max_price (Optional[int], optional): Maximum price for the flights. Defaults to None. 214 | currency (str, optional): Currency for the price. Defaults to "EUR". 215 | 216 | Returns: 217 | Dict[str, Any]: A dictionary containing flight inspiration results, including: 218 | - data: A list of dictionaries, each representing a destination with details such as: 219 | - type: The type of the result (usually "flight-destination"). 220 | - origin: The IATA code of the origin city. 221 | - destination: The IATA code of the destination city. 222 | - departureDate: The suggested departure date. 223 | - returnDate: The suggested return date. 224 | - price: The price information for the trip. 225 | 226 | Raises: 227 | requests.exceptions.HTTPError: If the API request fails. 228 | ValueError: If the required environment variables are not set. 
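Example (illustrative): get_flight_inspiration("MAD", max_price=200) returns destination suggestions for trips departing Madrid priced under 200 EUR.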
229 | 230 | Note: 231 | This method requires valid Amadeus API credentials to be set in the environment variables. 232 | """ 233 | access_token = AmadeusTools._get_access_token() 234 | 235 | url = "https://test.api.amadeus.com/v1/shopping/flight-destinations" 236 | 237 | headers = { 238 | "Authorization": f"Bearer {access_token}", 239 | "Content-Type": "application/json" 240 | } 241 | 242 | params = { 243 | "origin": origin, 244 | "currency": currency 245 | } 246 | 247 | if max_price: 248 | params["maxPrice"] = max_price 249 | 250 | response = requests.get(url, headers=headers, params=params) 251 | try: 252 | response.raise_for_status() 253 | except requests.exceptions.HTTPError as e: 254 | print(f"HTTP Error: {e}") 255 | print(f"Response content: {response.text}") 256 | raise 257 | return response.json() 258 | 259 | -------------------------------------------------------------------------------- /taskflowai/tools/audio_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | from typing import List, Dict, Optional, Literal, Union 5 | from dotenv import load_dotenv 6 | import requests 7 | from openai import OpenAI 8 | import time 9 | import io 10 | 11 | load_dotenv() 12 | 13 | class TextToSpeechTools: 14 | @staticmethod 15 | def elevenlabs_text_to_speech(text: str, voice: str = "Giovanni", output_file: str = None): 16 | """ 17 | Convert text to speech using the ElevenLabs API and either play the generated audio or save it to a file. 18 | 19 | Args: 20 | text (str): The text to convert to speech. 21 | voice (str, optional): The name of the voice to use. Defaults to "Giovanni". 22 | output_file (str, optional): The name of the file to save the generated audio. If None, the audio will be played aloud. 23 | 24 | Returns: 25 | None 26 | """ 27 | try: 28 | from elevenlabs import play 29 | from elevenlabs.client import ElevenLabs 30 | except ModuleNotFoundError as e: 31 | raise ImportError(f"{e.name} is required for text_to_speech. Install with `pip install {e.name}`") 32 | 33 | api_key = os.getenv('ELEVENLABS_API_KEY') 34 | 35 | if not api_key: 36 | raise ValueError("ELEVENLABS_API_KEY not found in environment variables") 37 | 38 | client = ElevenLabs(api_key=api_key) 39 | 40 | audio = client.generate( 41 | text=text, 42 | voice=voice, 43 | model="eleven_multilingual_v2" 44 | ) 45 | 46 | if output_file: 47 | with open(output_file, "wb") as file: 48 | file.write(audio) 49 | else: 50 | play(audio) 51 | return audio 52 | 53 | @staticmethod 54 | def openai_text_to_speech(text: str, voice: str = "onyx", output_file: str = None): 55 | """ 56 | Generate speech from text using the OpenAI API and either save it to a file or play it aloud. 57 | 58 | Args: 59 | text (str): The text to convert to speech. 60 | voice (str, optional): The name of the voice to use. Defaults to "onyx". 61 | output_file (str, optional): The name of the file to save the generated audio. If None, the audio will be played aloud. 62 | 63 | Returns: 64 | None 65 | """ 66 | try: 67 | import os 68 | os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" 69 | import pygame 70 | except ModuleNotFoundError as e: 71 | raise ImportError(f"pygame is required for audio playback in the openai_text_to_speech tool. 
Install with `pip install pygame`") 72 | 73 | client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) 74 | 75 | response = client.audio.speech.create( 76 | model="tts-1", 77 | voice=voice, 78 | input=text, 79 | speed=1.0 80 | ) 81 | 82 | if output_file: 83 | # Save the audio file using the recommended streaming method 84 | with open(output_file, "wb") as file: 85 | for chunk in response.iter_bytes(): 86 | file.write(chunk) 87 | else: 88 | time.sleep(0.7) 89 | # Play the audio directly using pygame 90 | pygame.mixer.init() 91 | audio_data = b''.join(chunk for chunk in response.iter_bytes()) 92 | audio_file = io.BytesIO(audio_data) 93 | pygame.mixer.music.load(audio_file) 94 | 95 | # Add a small delay before playing 96 | time.sleep(1) 97 | 98 | pygame.mixer.music.play() 99 | while pygame.mixer.music.get_busy(): 100 | time.sleep(0.1) 101 | pygame.mixer.quit() 102 | 103 | class WhisperTools: 104 | @staticmethod 105 | def whisper_transcribe_audio( 106 | audio_input: Union[str, List[str]], 107 | model: str = "whisper-1", 108 | language: Optional[str] = None, 109 | prompt: Optional[str] = None, 110 | response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] = "json", 111 | temperature: float = 0, 112 | timestamp_granularities: Optional[List[Literal["segment", "word"]]] = None 113 | ) -> Union[Dict, List[Dict]]: 114 | """ 115 | Transcribe audio using the OpenAI Whisper API. 116 | 117 | Args: 118 | audio_input (Union[str, List[str]]): Path to audio file(s) or list of paths. 119 | model (str): The model to use for transcription. Default is "whisper-1". 120 | language (Optional[str]): The language of the input audio. If None, Whisper will auto-detect. 121 | prompt (Optional[str]): An optional text to guide the model's style or continue a previous audio segment. 122 | response_format (str): The format of the transcript output. Default is "json". 123 | temperature (float): The sampling temperature, between 0 and 1. Default is 0. 124 | timestamp_granularities (Optional[List[str]]): List of timestamp granularities to include. 125 | 126 | Returns: 127 | Union[Dict, List[Dict]]: Transcription result(s) in the specified format. 128 | """ 129 | api_key = os.getenv("OPENAI_API_KEY") 130 | 131 | if not api_key: 132 | raise ValueError("OPENAI_API_KEY must be set in .env file") 133 | 134 | url = 'https://api.openai.com/v1/audio/transcriptions' 135 | headers = {'Authorization': f'Bearer {api_key}'} 136 | 137 | def process_single_file(file_path): 138 | with open(file_path, 'rb') as audio_file: 139 | files = {'file': audio_file} 140 | data = { 141 | 'model': model, 142 | 'response_format': response_format, 143 | 'temperature': temperature, 144 | } 145 | if language: 146 | data['language'] = language 147 | if prompt: 148 | data['prompt'] = prompt 149 | if timestamp_granularities: 150 | data['timestamp_granularities'] = timestamp_granularities 151 | 152 | response = requests.post(url, headers=headers, files=files, data=data) 153 | response.raise_for_status() 154 | 155 | if response_format == 'json' or response_format == 'verbose_json': 156 | return response.json() 157 | else: 158 | return response.text 159 | 160 | if isinstance(audio_input, str): 161 | return process_single_file(audio_input) 162 | elif isinstance(audio_input, list) and all(isinstance(file, str) for file in audio_input): 163 | return [process_single_file(file) for file in audio_input if os.path.isfile(file)] 164 | else: 165 | raise ValueError('Invalid input type. 
Expected string or list of strings.') 166 | 167 | @staticmethod 168 | def whisper_translate_audio( 169 | audio_input: Union[str, List[str]], 170 | model: str = "whisper-1", 171 | prompt: Optional[str] = None, 172 | response_format: Literal["json", "text", "srt", "verbose_json", "vtt"] = "json", 173 | temperature: float = 0, 174 | timestamp_granularities: Optional[List[Literal["segment", "word"]]] = None 175 | ) -> Union[Dict, List[Dict]]: 176 | """ 177 | Translate audio to English using the OpenAI Whisper API. 178 | 179 | Args: 180 | audio_input (Union[str, List[str]]): Path to audio file(s) or list of paths. 181 | model (str): The model to use for translation. Default is "whisper-1". 182 | prompt (Optional[str]): An optional text to guide the model's style or continue a previous audio segment. 183 | response_format (str): The format of the transcript output. Default is "json". 184 | temperature (float): The sampling temperature, between 0 and 1. Default is 0. 185 | timestamp_granularities (Optional[List[str]]): List of timestamp granularities to include. 186 | 187 | Returns: 188 | Union[Dict, List[Dict]]: Translation result(s) in the specified format. 189 | """ 190 | load_dotenv() 191 | api_key = os.getenv("OPENAI_API_KEY") 192 | 193 | if not api_key: 194 | raise ValueError("OPENAI_API_KEY must be set in .env file") 195 | 196 | url = 'https://api.openai.com/v1/audio/translations' 197 | headers = {'Authorization': f'Bearer {api_key}'} 198 | 199 | def process_single_file(file_path): 200 | with open(file_path, 'rb') as audio_file: 201 | files = {'file': audio_file} 202 | data = { 203 | 'model': model, 204 | 'response_format': response_format, 205 | 'temperature': temperature, 206 | } 207 | if prompt: 208 | data['prompt'] = prompt 209 | if timestamp_granularities: 210 | data['timestamp_granularities'] = timestamp_granularities 211 | 212 | response = requests.post(url, headers=headers, files=files, data=data) 213 | response.raise_for_status() 214 | 215 | if response_format == 'json' or response_format == 'verbose_json': 216 | return response.json() 217 | else: 218 | return response.text 219 | 220 | if isinstance(audio_input, str): 221 | return process_single_file(audio_input) 222 | elif isinstance(audio_input, list): 223 | return [process_single_file(file) for file in audio_input if os.path.isfile(file)] 224 | else: 225 | raise ValueError('Invalid input type. Expected string or list of strings.') 226 | -------------------------------------------------------------------------------- /taskflowai/tools/calculator_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | from datetime import timedelta, datetime 4 | 5 | class CalculatorTools: 6 | @staticmethod 7 | def basic_math(operation: str, args: list) -> float: 8 | """ 9 | Perform basic and advanced math operations on multiple numbers. 10 | 11 | Args: 12 | operation (str): One of 'add', 'subtract', 'multiply', 'divide', 'exponent', 'root', 'modulo', or 'factorial'. 13 | args (list): List of numbers to perform the operation on. 14 | 15 | Returns: 16 | float: Result of the operation. 17 | 18 | Raises: 19 | ValueError: If an invalid operation is provided, if dividing by zero, if fewer than required numbers are provided, or for invalid inputs. 20 | 21 | Note: 22 | This method does not take in letters or words. It only takes in numbers. 
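Example (illustrative; results are returned as strings):
    basic_math('add', [2, 3, 4]) -> '9.0'
    basic_math('factorial', [5]) -> '120'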
23 | """ 24 | if len(args) < 1: 25 | raise ValueError("At least one number is required for the operation.") 26 | 27 | # Convert all args to float, except for factorial which requires int 28 | if operation != 'factorial': 29 | args = [float(arg) for arg in args] 30 | 31 | result = args[0] 32 | 33 | if operation in ['add', 'subtract', 'multiply', 'divide']: 34 | if len(args) < 2: 35 | raise ValueError("At least two numbers are required for this operation.") 36 | 37 | if operation == 'add': 38 | for num in args[1:]: 39 | result += num 40 | elif operation == 'subtract': 41 | for num in args[1:]: 42 | result -= num 43 | elif operation == 'multiply': 44 | for num in args[1:]: 45 | result *= num 46 | elif operation == 'divide': 47 | for num in args[1:]: 48 | if num == 0: 49 | raise ValueError("Cannot divide by zero") 50 | result /= num 51 | elif operation == 'exponent': 52 | if len(args) != 2: 53 | raise ValueError("Exponent operation requires exactly two numbers.") 54 | result = args[0] ** args[1] 55 | elif operation == 'root': 56 | if len(args) != 2: 57 | raise ValueError("Root operation requires exactly two numbers.") 58 | if args[1] == 0: 59 | raise ValueError("Cannot calculate 0th root") 60 | result = args[0] ** (1 / args[1]) 61 | elif operation == 'modulo': 62 | if len(args) != 2: 63 | raise ValueError("Modulo operation requires exactly two numbers.") 64 | if args[1] == 0: 65 | raise ValueError("Cannot perform modulo with zero") 66 | result = args[0] % args[1] 67 | elif operation == 'factorial': 68 | if len(args) != 1 or args[0] < 0 or not isinstance(args[0], int): 69 | raise ValueError("Factorial operation requires exactly one non-negative integer.") 70 | result = 1 71 | for i in range(1, args[0] + 1): 72 | result *= i 73 | else: 74 | raise ValueError("Invalid operation. Choose 'add', 'subtract', 'multiply', 'divide', 'exponent', 'root', 'modulo', or 'factorial'.") 75 | 76 | # Convert the result to a string before returning 77 | return str(result) 78 | 79 | @staticmethod 80 | def get_current_time() -> str: 81 | """ 82 | Get the current UTC time. 83 | 84 | Returns: 85 | str: The current UTC time in the format 'YYYY-MM-DD HH:MM:SS'. 86 | """ 87 | return datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S") 88 | 89 | @staticmethod 90 | def add_days(date_str: str, days: int) -> str: 91 | """ 92 | Add a number of days to a given date. 93 | 94 | Args: 95 | date_str (str): The starting date in 'YYYY-MM-DD' format. 96 | days (int): The number of days to add (can be negative). 97 | 98 | Returns: 99 | str: The resulting date in 'YYYY-MM-DD' format. 100 | """ 101 | date = datetime.strptime(date_str, "%Y-%m-%d") 102 | new_date = date + timedelta(days=days) 103 | return new_date.strftime("%Y-%m-%d") 104 | 105 | @staticmethod 106 | def days_between(date1_str: str, date2_str: str) -> int: 107 | """ 108 | Calculate the number of days between two dates. 109 | 110 | Args: 111 | date1_str (str): The first date in 'YYYY-MM-DD' format. 112 | date2_str (str): The second date in 'YYYY-MM-DD' format. 113 | 114 | Returns: 115 | int: The number of days between the two dates. 116 | """ 117 | date1 = datetime.strptime(date1_str, "%Y-%m-%d") 118 | date2 = datetime.strptime(date2_str, "%Y-%m-%d") 119 | return abs((date2 - date1).days) 120 | 121 | @staticmethod 122 | def format_date(date_str: str, input_format: str, output_format: str) -> str: 123 | """ 124 | Convert a date string from one format to another. 125 | 126 | Args: 127 | date_str (str): The date string to format. 
128 | input_format (str): The current format of the date string. 129 | output_format (str): The desired output format. 130 | 131 | Returns: 132 | str: The formatted date string. 133 | 134 | Example: 135 | format_date("2023-05-15", "%Y-%m-%d", "%B %d, %Y") -> "May 15, 2023" 136 | """ 137 | date_obj = datetime.strptime(date_str, input_format) 138 | return date_obj.strftime(output_format) 139 | -------------------------------------------------------------------------------- /taskflowai/tools/conversation_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | from typing import Any, List, Callable 4 | from taskflowai import Agent, Task 5 | 6 | class ConversationTools: 7 | 8 | @staticmethod 9 | def ask_user(question: str) -> str: 10 | """ 11 | Prompt the human user with a question and return their answer. 12 | 13 | This method prints a question to the console, prefixed with "Agent: ", 14 | and then waits for the user to input their response. The user's input 15 | is captured and returned as a string. 16 | 17 | Args: 18 | question (str): The question to ask the human user. 19 | 20 | Returns: 21 | str: The user's response to the question. 22 | 23 | Note: 24 | - This method blocks execution until the user provides valid input. 25 | """ 26 | while True: 27 | print(f"Agent: {question}") 28 | answer = input("Human: ").strip() 29 | if answer.lower() == 'exit': 30 | return None 31 | if answer: 32 | return answer 33 | print("You entered an empty string. Please try again or type 'exit' to cancel.") 34 | 35 | @staticmethod 36 | def ask_agent(*agents: Agent) -> List[Callable]: 37 | """ 38 | Creates tool functions for inter-agent communication. 39 | 40 | Args: 41 | *agents: Variable length list of Agent instances. 42 | 43 | Returns: 44 | List of callable tool functions that can be used to communicate with the specified agents. 45 | """ 46 | def create_ask_tool(target_agent: Agent) -> Callable: 47 | def ask_tool(question: str) -> Any: 48 | """ 49 | Tool function to ask a question to a specific agent. 50 | 51 | Args: 52 | question (str): The question or instruction to send. 53 | 54 | Returns: 55 | Any: Response from the agent being asked. 56 | """ 57 | return Task.create( 58 | agent=target_agent, 59 | 60 | instruction=question, 61 | ) 62 | 63 | ask_tool.__name__ = f"ask_{target_agent.role.replace(' ', '_')}" 64 | return ask_tool 65 | 66 | return [create_ask_tool(agent) for agent in agents] 67 | -------------------------------------------------------------------------------- /taskflowai/tools/embedding_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | from typing import List, Dict, Any, Union, Literal, Tuple 5 | from dotenv import load_dotenv 6 | import requests 7 | import time 8 | 9 | # Load environment variables 10 | load_dotenv() 11 | 12 | class EmbeddingsTools: 13 | """ 14 | A class for generating embeddings using various models. 
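Supported providers are "openai", "cohere", and "mistral"; see MODEL_DIMENSIONS below for the available models and their output dimensions.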
15 | """ 16 | MODEL_DIMENSIONS = { 17 | # OpenAI 18 | "text-embedding-3-small": 1536, 19 | "text-embedding-3-large": 3072, 20 | "text-embedding-ada-002": 1536, 21 | # Cohere 22 | "embed-english-v3.0": 1024, 23 | "embed-english-light-v3.0": 384, 24 | "embed-english-v2.0": 4096, 25 | "embed-english-light-v2.0": 1024, 26 | "embed-multilingual-v3.0": 1024, 27 | "embed-multilingual-light-v3.0": 384, 28 | "embed-multilingual-v2.0": 768, 29 | # Mistral 30 | "mistral-embed": 1024 31 | } 32 | 33 | @staticmethod 34 | def get_model_dimension(provider: str, model: str) -> int: 35 | """ 36 | Get the dimension of the specified embedding model. 37 | 38 | Args: 39 | provider (str): The provider of the embedding model. 40 | model (str): The name of the embedding model. 41 | 42 | Returns: 43 | int: The dimension of the embedding model. 44 | 45 | Raises: 46 | ValueError: If the provider or model is not supported. 47 | """ 48 | if provider == "openai" or provider == "cohere" or provider == "mistral": 49 | if model in EmbeddingsTools.MODEL_DIMENSIONS: 50 | return EmbeddingsTools.MODEL_DIMENSIONS[model] 51 | 52 | raise ValueError(f"Unsupported embedding model: {model} for provider: {provider}") 53 | 54 | @staticmethod 55 | def get_openai_embeddings( 56 | input_text: Union[str, List[str]], 57 | model: Literal["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"] = "text-embedding-3-small" 58 | ) -> Tuple[List[List[float]], Dict[str, int]]: 59 | """ 60 | Generate embeddings for the given input text using OpenAI's API. 61 | 62 | Args: 63 | input_text (Union[str, List[str]]): The input text or list of texts to embed. 64 | model (Literal["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]): 65 | The model to use for generating embeddings. Default is "text-embedding-3-small". 66 | 67 | Returns: 68 | Tuple[List[List[float]], Dict[str, int]]: A tuple containing: 69 | - A list of embeddings. 70 | - A dictionary with the number of dimensions for the chosen model. 71 | 72 | Raises: 73 | ValueError: If the API key is not set or if there's an error with the API call. 74 | requests.exceptions.RequestException: If there's an error with the HTTP request. 75 | 76 | Note: 77 | This method requires a valid OpenAI API key to be set in the OPENAI_API_KEY environment variable. 
78 | """ 79 | 80 | # Get API key from environment variable 81 | api_key = os.getenv("OPENAI_API_KEY") 82 | if not api_key: 83 | raise ValueError("OPENAI_API_KEY environment variable is not set") 84 | 85 | # Prepare the API request 86 | url = "https://api.openai.com/v1/embeddings" 87 | headers = { 88 | "Authorization": f"Bearer {api_key}", 89 | "Content-Type": "application/json" 90 | } 91 | 92 | # Ensure input_text is a list and not empty 93 | if isinstance(input_text, str): 94 | input_text = [input_text] 95 | 96 | if not input_text or any(not text.strip() for text in input_text): 97 | raise ValueError("Input text cannot be empty") 98 | 99 | # Validate input length 100 | max_tokens = 8191 # OpenAI's max token limit 101 | if any(len(text) > max_tokens for text in input_text): 102 | raise ValueError(f"Input text exceeds maximum token limit of {max_tokens}") 103 | 104 | payload = { 105 | "input": input_text, 106 | "model": model, 107 | } 108 | 109 | try: 110 | response = requests.post(url, headers=headers, json=payload) 111 | response.raise_for_status() 112 | data = response.json() 113 | 114 | embeddings = [item['embedding'] for item in data['data']] 115 | 116 | return embeddings, {"dimensions": EmbeddingsTools.MODEL_DIMENSIONS[model]} 117 | 118 | except requests.exceptions.RequestException as e: 119 | raise requests.exceptions.RequestException(f"Error making request to OpenAI API: {str(e)}") 120 | 121 | @staticmethod 122 | def get_cohere_embeddings( 123 | input_text: Union[str, List[str]], 124 | model: str = "embed-english-v3.0", 125 | input_type: str = "search_document" 126 | ) -> Tuple[List[List[float]], Dict[str, int]]: 127 | """ 128 | Generate embeddings for the given input text using Cohere's API. 129 | 130 | Args: 131 | input_text (Union[str, List[str]]): The input text or list of texts to embed. 132 | model (str): The model to use for generating embeddings. Default is "embed-english-v3.0". 133 | input_type (str): The type of input. Default is "search_document". 134 | 135 | Returns: 136 | Tuple[List[List[float]], Dict[str, int]]: A tuple containing: 137 | - A list of embeddings. 138 | - A dictionary with the number of dimensions for the chosen model. 139 | 140 | Raises: 141 | ValueError: If the API key is not set. 142 | RuntimeError: If there's an error in generating embeddings from the Cohere API. 143 | """ 144 | # Get API key from environment variable 145 | api_key = os.getenv("COHERE_API_KEY") 146 | if not api_key: 147 | raise ValueError("COHERE_API_KEY environment variable is not set") 148 | 149 | # Check for cohere 150 | try: 151 | import cohere 152 | except ModuleNotFoundError: 153 | raise ImportError("cohere package is required for Cohere embedding tools. 
Install with `pip install cohere`") 154 | cohere_client = cohere.Client(api_key) 155 | 156 | # Ensure input_text is a list 157 | if isinstance(input_text, str): 158 | input_text = [input_text] 159 | 160 | try: 161 | time.sleep(1) # Rate limiting 162 | response = cohere_client.embed( 163 | texts=input_text, 164 | model=model, 165 | input_type=input_type 166 | ) 167 | embeddings = response.embeddings 168 | return embeddings, {"dimensions": EmbeddingsTools.MODEL_DIMENSIONS[model]} 169 | 170 | except Exception as e: 171 | raise RuntimeError(f"Failed to get embeddings from Cohere API: {str(e)}") 172 | 173 | @staticmethod 174 | def get_mistral_embeddings( 175 | input_text: Union[str, List[str]], 176 | model: str = "mistral-embed" 177 | ) -> Tuple[List[List[float]], Dict[str, int]]: 178 | """ 179 | Generate embeddings for the given input text using Mistral AI's API. 180 | 181 | Args: 182 | input_text (Union[str, List[str]]): The input text or list of texts to embed. 183 | model (str): The model to use for generating embeddings. Default is "mistral-embed". 184 | 185 | Returns: 186 | Tuple[List[List[float]], Dict[str, int]]: A tuple containing: 187 | - A list of embeddings. 188 | - A dictionary with the number of dimensions for the chosen model. 189 | 190 | Raises: 191 | ValueError: If the API key is not set or if there's an error with the API call. 192 | requests.exceptions.RequestException: If there's an error with the HTTP request. 193 | 194 | Note: 195 | This method requires a valid Mistral AI API key to be set in the MISTRAL_API_KEY environment variable. 196 | """ 197 | 198 | # Get API key from environment variable 199 | api_key = os.getenv("MISTRAL_API_KEY") 200 | if not api_key: 201 | raise ValueError("MISTRAL_API_KEY environment variable is not set") 202 | 203 | # Prepare the API request 204 | url = "https://api.mistral.ai/v1/embeddings" 205 | headers = { 206 | "Authorization": f"Bearer {api_key}", 207 | "Content-Type": "application/json" 208 | } 209 | 210 | # Ensure input_text is a list 211 | if isinstance(input_text, str): 212 | input_text = [input_text] 213 | 214 | payload = { 215 | "model": model, 216 | "input": input_text 217 | } 218 | 219 | try: 220 | response = requests.post(url, headers=headers, json=payload) 221 | response.raise_for_status() 222 | data = response.json() 223 | 224 | embeddings = [item['embedding'] for item in data['data']] 225 | dimensions = len(embeddings[0]) if embeddings else 0 226 | 227 | return embeddings, {"dimensions": dimensions} 228 | 229 | except requests.exceptions.RequestException as e: 230 | if hasattr(e.response, 'text'): 231 | error_details = e.response.text 232 | else: 233 | error_details = str(e) 234 | raise requests.exceptions.RequestException(f"Error making request to Mistral AI API: {error_details}") 235 | 236 | @staticmethod 237 | def get_embeddings(input_text: Union[str, List[str]], provider: str, model: str) -> Tuple[List[List[float]], Dict[str, int]]: 238 | """ 239 | Generate embeddings for the given input text using the specified provider and model. 240 | 241 | Args: 242 | input_text (Union[str, List[str]]): The input text or list of texts to embed. 243 | provider (str): The provider to use for generating embeddings. 244 | model (str): The model to use for generating embeddings. 245 | 246 | Returns: 247 | Tuple[List[List[float]], Dict[str, int]]: A tuple containing: 248 | - A list of embeddings. 249 | - A dictionary with the number of dimensions for the chosen model. 
250 | 251 | Raises: 252 | ValueError: If the provider or model is not supported. 253 | """ 254 | if provider == "openai": 255 | if model in ["text-embedding-3-small", "text-embedding-3-large", "text-embedding-ada-002"]: 256 | return EmbeddingsTools.get_openai_embeddings(input_text, model) 257 | else: 258 | raise ValueError(f"Unsupported OpenAI embedding model: {model}") 259 | elif provider == "cohere": 260 | if model in [ 261 | "embed-english-v3.0", "embed-english-light-v3.0", "embed-english-v2.0", 262 | "embed-english-light-v2.0", "embed-multilingual-v3.0", "embed-multilingual-light-v3.0", 263 | "embed-multilingual-v2.0" 264 | ]: 265 | return EmbeddingsTools.get_cohere_embeddings(input_text, model) 266 | else: 267 | raise ValueError(f"Unsupported Cohere embedding model: {model}") 268 | elif provider == "mistral": 269 | if model == "mistral-embed": 270 | return EmbeddingsTools.get_mistral_embeddings(input_text, model) 271 | else: 272 | raise ValueError(f"Unsupported Mistral embedding model: {model}") 273 | else: 274 | raise ValueError(f"Unsupported embedding provider: {provider}") 275 | 276 | -------------------------------------------------------------------------------- /taskflowai/tools/faiss_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | import json 5 | import numpy as np 6 | from typing import Tuple, Any 7 | 8 | class FAISSTools: 9 | def __init__(self, dimension: int, metric: str = "IP"): 10 | """ 11 | Initialize FAISSTools with the specified dimension and metric. 12 | 13 | Args: 14 | dimension (int): Dimension of the vectors to be stored in the index. 15 | metric (str, optional): Distance metric to use. Defaults to "IP" (Inner Product). 16 | """ 17 | # Load the faiss library 18 | try: 19 | import faiss 20 | except ModuleNotFoundError: 21 | raise ImportError("faiss is required for FAISSTools. Install with `pip install faiss-cpu` or `pip install faiss-gpu`") 22 | 23 | self.dimension = dimension 24 | self.metric = metric 25 | self.index = None 26 | self.embedding_model = None 27 | self.embedding_provider = None 28 | self.metadata = {} # Added to store metadata 29 | 30 | self.faiss = faiss 31 | 32 | def create_index(self, index_type: str = "Flat") -> None: 33 | """ 34 | Create a new FAISS index. 35 | 36 | Args: 37 | index_type (str, optional): Type of index to create. Defaults to "Flat". 38 | 39 | Raises: 40 | ValueError: If an unsupported index type is specified. 41 | """ 42 | if index_type == "Flat": 43 | if self.metric == "IP": 44 | self.index = self.faiss.IndexFlatIP(self.dimension) 45 | elif self.metric == "L2": 46 | self.index = self.faiss.IndexFlatL2(self.dimension) 47 | else: 48 | raise ValueError(f"Unsupported metric: {self.metric}") 49 | else: 50 | raise ValueError(f"Unsupported index type: {index_type}") 51 | 52 | def load_index(self, index_path: str) -> None: 53 | """ 54 | Load a FAISS index and metadata from files. 55 | 56 | Args: 57 | index_path (str): Path to the index file. 58 | 59 | Raises: 60 | FileNotFoundError: If the index file or metadata file is not found. 
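Example (illustrative; "my_index.faiss" and "my_index.faiss.metadata" are placeholder paths assumed to exist):
    faiss_tools = FAISSTools(dimension=1536)
    faiss_tools.load_index("my_index.faiss")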
61 | """ 62 | if not os.path.exists(index_path): 63 | raise FileNotFoundError(f"Index file not found: {index_path}") 64 | self.index = self.faiss.read_index(index_path) 65 | 66 | metadata_path = f"{index_path}.metadata" 67 | if not os.path.exists(metadata_path): 68 | raise FileNotFoundError(f"Metadata file not found: {metadata_path}") 69 | with open(metadata_path, 'r') as f: 70 | self.metadata = json.load(f) 71 | 72 | self.dimension = self.index.d 73 | self.embedding_model = self.metadata.get('embedding_model') 74 | 75 | def save_index(self, index_path: str) -> None: 76 | """ 77 | Save the FAISS index and metadata to files. 78 | 79 | Args: 80 | index_path (str): Path to save the index file. 81 | """ 82 | self.faiss.write_index(self.index, index_path) 83 | metadata_path = f"{index_path}.metadata" 84 | with open(metadata_path, 'w') as f: 85 | json.dump(self.metadata, f) 86 | 87 | def add_vectors(self, vectors: np.ndarray) -> None: 88 | """ 89 | Add vectors to the FAISS index. 90 | 91 | Args: 92 | vectors (np.ndarray): Array of vectors to add. 93 | 94 | Raises: 95 | ValueError: If the vector dimension does not match the index dimension. 96 | """ 97 | if vectors.shape[1] != self.dimension: 98 | raise ValueError(f"Vector dimension {vectors.shape[1]} does not match index dimension {self.dimension}") 99 | 100 | if self.metric == "IP": 101 | # Normalize vectors for Inner Product similarity 102 | vectors = np.apply_along_axis(self.normalize_vector, 1, vectors) 103 | 104 | self.index.add(vectors) 105 | 106 | def search_vectors(self, query_vectors: np.ndarray, top_k: int = 10) -> Tuple[np.ndarray, np.ndarray]: 107 | """ 108 | Search for similar vectors in the FAISS index. 109 | 110 | Args: 111 | query_vectors (np.ndarray): Array of query vectors. 112 | top_k (int, optional): Number of results to return for each query vector. Defaults to 10. 113 | 114 | Returns: 115 | Tuple[np.ndarray, np.ndarray]: A tuple containing the distances and indices of the top-k results. 116 | 117 | Raises: 118 | ValueError: If the query vector dimension does not match the index dimension. 119 | """ 120 | if query_vectors.shape[1] != self.dimension: 121 | raise ValueError(f"Query vector dimension {query_vectors.shape[1]} does not match index dimension {self.dimension}") 122 | 123 | if self.metric == "IP": 124 | # Normalize query vectors for Inner Product similarity 125 | query_vectors = np.apply_along_axis(self.normalize_vector, 1, query_vectors) 126 | 127 | distances, indices = self.index.search(query_vectors, top_k) 128 | return distances, indices 129 | 130 | def remove_vectors(self, ids: np.ndarray) -> None: 131 | """ 132 | Remove vectors from the FAISS index by their IDs. 133 | 134 | Args: 135 | ids (np.ndarray): Array of vector IDs to remove. 136 | """ 137 | self.index.remove_ids(ids) 138 | 139 | def get_vector_count(self) -> int: 140 | """ 141 | Get the number of vectors in the FAISS index. 142 | 143 | Returns: 144 | int: Number of vectors in the index. 145 | """ 146 | return self.index.ntotal 147 | 148 | @staticmethod 149 | def normalize_vector(vector: np.ndarray) -> np.ndarray: 150 | """ 151 | Normalize a vector to unit length. 152 | 153 | Args: 154 | vector (np.ndarray): The input vector. 155 | 156 | Returns: 157 | np.ndarray: The normalized vector. 158 | """ 159 | norm = np.linalg.norm(vector) 160 | return vector / norm if norm != 0 else vector 161 | 162 | def set_metadata(self, key: str, value: Any) -> None: 163 | """ 164 | Set metadata for the index. 165 | 166 | Args: 167 | key (str): Metadata key. 
168 | value (Any): Metadata value. 169 | """ 170 | self.metadata[key] = value 171 | 172 | def get_metadata(self, key: str) -> Any: 173 | """ 174 | Get metadata from the index. 175 | 176 | Args: 177 | key (str): Metadata key. 178 | 179 | Returns: 180 | Any: Metadata value. 181 | """ 182 | return self.metadata.get(key) 183 | 184 | def set_embedding_info(self, provider: str, model: str) -> None: 185 | """ 186 | Set the embedding provider and model information. 187 | 188 | Args: 189 | provider (str): The embedding provider (e.g., "openai"). 190 | model (str): The embedding model name. 191 | """ 192 | self.embedding_provider = provider 193 | self.embedding_model = model 194 | self.set_metadata('embedding_provider', provider) 195 | self.set_metadata('embedding_model', model) -------------------------------------------------------------------------------- /taskflowai/tools/file_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | import csv 5 | import json 6 | import xml.etree.ElementTree as ET 7 | from typing import Any, List, Dict, Union 8 | import yaml 9 | 10 | debug_mode = False 11 | 12 | class FileTools: 13 | @staticmethod 14 | def save_code_to_file(code: str, file_path: str): 15 | """ 16 | Save the given code to a file at the specified path. 17 | 18 | Args: 19 | code (str): The code to be saved. 20 | file_path (str): The path where the file should be saved. 21 | 22 | Raises: 23 | OSError: If there's an error creating the directory or writing the file. 24 | TypeError: If the input types are incorrect. 25 | """ 26 | try: 27 | #print(f"Attempting to save code to file: {file_path}") 28 | os.makedirs(os.path.dirname(file_path), exist_ok=True) 29 | with open(file_path, "w") as file: 30 | file.write(code) 31 | print(f"\033[95mSaved code to {file_path}\033[0m") 32 | print(f"Successfully saved code to file: {file_path}") 33 | except OSError as e: 34 | print(f"Error creating directory or writing file at FileTools.save_code_to_file: {e}") 35 | print(f"OSError occurred at FileTools.save_code_to_file: {str(e)}") 36 | except TypeError as e: 37 | print(f"Invalid input type: {e}") 38 | print(f"TypeError occurred at FileTools.save_code_to_file: {str(e)}") 39 | except Exception as e: 40 | print(f"An unexpected error occurred at FileTools.save_code_to_file: {e}") 41 | print(f"Unexpected error at FileTools.save_code_to_file: {str(e)}") 42 | 43 | @staticmethod 44 | def generate_directory_tree(base_path, additional_ignore=None): 45 | """ 46 | Recursively generate a file structure dictionary for the given base path. 47 | 48 | Args: 49 | base_path (str): The root directory path to start the file structure generation. 50 | additional_ignore (List[str], optional): Additional files or directories to ignore. 51 | 52 | Returns: 53 | dict: A nested dictionary representing the file structure, where each directory 54 | is represented by a dict with 'name', 'type', and 'children' keys, and each 55 | file is represented by a dict with 'name', 'type', and 'contents' keys. 56 | 57 | Raises: 58 | ValueError: If the specified path is not within the current working directory. 59 | PermissionError: If there's a permission error accessing the directory or its contents. 60 | FileNotFoundError: If the specified path does not exist. 61 | OSError: If there's an error accessing the directory or its contents. 
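Example (illustrative; the path must lie within the current working directory):
    tree = FileTools.generate_directory_tree("./src", additional_ignore=["dist"])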
62 | """ 63 | default_ignore_list = {".DS_Store", ".gitignore", ".env", "node_modules", "__pycache__"} 64 | 65 | if additional_ignore: 66 | ignore_list = default_ignore_list.union(set(additional_ignore)) 67 | else: 68 | ignore_list = default_ignore_list 69 | 70 | #print(f"Starting file structure generation for path: {base_path}") 71 | #print(f"Ignore list: {ignore_list}") 72 | 73 | try: 74 | # Convert both paths to absolute and normalize them 75 | abs_base_path = os.path.abspath(os.path.normpath(base_path)) 76 | abs_cwd = os.path.abspath(os.path.normpath(os.getcwd())) 77 | 78 | # Check if the base_path is within or equal to the current working directory 79 | if not abs_base_path.startswith(abs_cwd): 80 | raise ValueError(f"Access to the specified path is not allowed: {abs_base_path}") 81 | 82 | if not os.path.exists(abs_base_path): 83 | raise FileNotFoundError(f"The specified path does not exist: {abs_base_path}") 84 | 85 | if not os.path.isdir(abs_base_path): 86 | raise NotADirectoryError(f"The specified path is not a directory: {abs_base_path}") 87 | 88 | file_structure = { 89 | "name": os.path.basename(abs_base_path), 90 | "type": "directory", 91 | "children": [] 92 | } 93 | 94 | for item in os.listdir(abs_base_path): 95 | if item in ignore_list or item.startswith('.'): 96 | print(f"Skipping ignored or hidden item: {item}") 97 | continue # Skip ignored and hidden files/directories 98 | 99 | item_path = os.path.join(abs_base_path, item) 100 | print(f"Processing item: {item_path}") 101 | 102 | if os.path.isdir(item_path): 103 | try: 104 | file_structure["children"].append(FileTools.generate_directory_tree(item_path)) 105 | except PermissionError: 106 | print(f"Permission denied for directory: {item_path}") 107 | file_structure["children"].append({ 108 | "name": item, 109 | "type": "directory", 110 | "error": "Permission denied" 111 | }) 112 | else: 113 | try: 114 | with open(item_path, "r", encoding="utf-8") as file: 115 | file_contents = file.read() 116 | #print(f"Successfully read file contents: {item_path}") 117 | except UnicodeDecodeError: 118 | print(f"UTF-8 decoding failed for {item_path}, attempting ISO-8859-1") 119 | try: 120 | with open(item_path, "r", encoding="iso-8859-1") as file: 121 | file_contents = file.read() 122 | #print(f"Successfully read file contents with ISO-8859-1: {item_path}") 123 | except Exception as e: 124 | print(f"Failed to read file: {item_path}, Error: {str(e)}") 125 | file_contents = f"Error reading file: {str(e)}" 126 | except PermissionError: 127 | print(f"Permission denied for file: {item_path}") 128 | file_contents = "Permission denied" 129 | except Exception as e: 130 | print(f"Unexpected error reading file: {item_path}, Error: {str(e)}") 131 | file_contents = f"Unexpected error: {str(e)}" 132 | 133 | file_structure["children"].append({ 134 | "name": item, 135 | "type": "file", 136 | "contents": file_contents 137 | }) 138 | 139 | print(f"Completed file structure generation for path: {abs_base_path}") 140 | return file_structure 141 | 142 | except PermissionError as e: 143 | print(f"Permission error accessing directory or its contents: {str(e)}") 144 | raise 145 | except FileNotFoundError as e: 146 | print(f"File or directory not found: {str(e)}") 147 | raise 148 | except NotADirectoryError as e: 149 | print(f"Not a directory error: {str(e)}") 150 | raise 151 | except OSError as e: 152 | print(f"OS error accessing directory or its contents: {str(e)}") 153 | raise 154 | except Exception as e: 155 | print(f"Unexpected error in generate_directory_tree: 
{str(e)}") 156 | raise 157 | 158 | @staticmethod 159 | def read_file_contents(full_file_path): 160 | """ 161 | Retrieve the contents of a file at the specified path. 162 | 163 | Args: 164 | full_file_path (str): The full path to the file. 165 | 166 | Returns: 167 | str: The contents of the file if successfully read, None otherwise. 168 | 169 | Raises: 170 | IOError: If there's an error reading the file. 171 | """ 172 | print(f"Attempting to read file contents from: {full_file_path}") 173 | 174 | try: 175 | with open(full_file_path, 'r', encoding='utf-8') as file: 176 | file_contents = file.read() 177 | print("File contents successfully retrieved.") 178 | return file_contents 179 | except FileNotFoundError: 180 | print(f"Error: File not found at path: {full_file_path}") 181 | print(f"FileNotFoundError at FileTools.read_file_contents: {full_file_path}") 182 | return None 183 | except IOError as e: 184 | print(f"Error reading file: {e}") 185 | print(f"IOError while reading file at FileTools.read_file_contents: {full_file_path}. Error: {str(e)}") 186 | return None 187 | except UnicodeDecodeError: 188 | print(f"Error: Unable to decode file contents using UTF-8 encoding: {full_file_path}") 189 | print(f"UnicodeDecodeError at FileTools.read_file_contents: Attempting to read with ISO-8859-1 encoding") 190 | try: 191 | with open(full_file_path, 'r', encoding='iso-8859-1') as file: 192 | file_contents = file.read() 193 | #print("File contents successfully retrieved using ISO-8859-1 encoding.") 194 | return file_contents 195 | except Exception as e: 196 | print(f"Error: Failed to read file with ISO-8859-1 encoding: {e}") 197 | #print(f"Error reading file with ISO-8859-1 encoding: {full_file_path}. Error: {str(e)}") 198 | return None 199 | except Exception as e: 200 | print(f"Unexpected error occurred while reading file: {e}") 201 | print(f"Unexpected error in FileTools.read_file_contents: {full_file_path}. Error: {str(e)}") 202 | return None 203 | 204 | @staticmethod 205 | def read_csv(file_path: str) -> List[Dict[str, Any]]: 206 | """ 207 | Read a CSV file and return its contents as a list of dictionaries. 208 | 209 | Args: 210 | file_path (str): The path to the CSV file. 211 | 212 | Returns: 213 | List[Dict[str, Any]]: A list of dictionaries, where each dictionary represents a row in the CSV. 214 | 215 | Raises: 216 | FileNotFoundError: If the specified file is not found. 217 | csv.Error: If there's an error parsing the CSV file. 218 | """ 219 | try: 220 | with open(file_path, 'r', newline='', encoding='utf-8') as csvfile: 221 | reader = csv.DictReader(csvfile) 222 | return [row for row in reader] 223 | except FileNotFoundError: 224 | print(f"Error: CSV file not found at {file_path}") 225 | return (f"Error: CSV file not found at {file_path}") 226 | except csv.Error as e: 227 | print(f"Error parsing CSV file: {e}") 228 | return (f"Error parsing CSV file: {e}") 229 | 230 | @staticmethod 231 | def read_json(file_path: str) -> Union[Dict[str, Any], List[Any]]: 232 | """ 233 | Read a JSON file and return its contents. 234 | 235 | Args: 236 | file_path (str): The path to the JSON file. 237 | 238 | Returns: 239 | Union[Dict[str, Any], List[Any]]: The parsed JSON data. 240 | 241 | Raises: 242 | FileNotFoundError: If the specified file is not found. 243 | json.JSONDecodeError: If there's an error parsing the JSON file. 
244 | """ 245 | try: 246 | with open(file_path, 'r', encoding='utf-8') as jsonfile: 247 | return json.load(jsonfile) 248 | except FileNotFoundError: 249 | print(f"Error: JSON file not found at {file_path}") 250 | return (f"Error: JSON file not found at {file_path}") 251 | except json.JSONDecodeError as e: 252 | print(f"Error parsing JSON file: {e}") 253 | return (f"Error parsing JSON file: {e}") 254 | 255 | @staticmethod 256 | def read_xml(file_path: str) -> ET.Element: 257 | """ 258 | Read an XML file and return its contents as an ElementTree. 259 | 260 | Args: 261 | file_path (str): The path to the XML file. 262 | 263 | Returns: 264 | ET.Element: The root element of the parsed XML. 265 | 266 | Raises: 267 | FileNotFoundError: If the specified file is not found. 268 | ET.ParseError: If there's an error parsing the XML file. 269 | """ 270 | try: 271 | tree = ET.parse(file_path) 272 | return tree.getroot() 273 | except FileNotFoundError: 274 | print(f"Error: XML file not found at {file_path}") 275 | return (f"Error: XML file not found at {file_path}") 276 | except ET.ParseError as e: 277 | print(f"Error parsing XML file: {e}") 278 | return (f"Error parsing XML file: {e}") 279 | 280 | @staticmethod 281 | def read_yaml(file_path: str) -> Union[Dict[str, Any], List[Any]]: 282 | """ 283 | Read a YAML file and return its contents. 284 | 285 | Args: 286 | file_path (str): The path to the YAML file. 287 | 288 | Returns: 289 | Union[Dict[str, Any], List[Any]]: The parsed YAML data. 290 | 291 | Raises: 292 | FileNotFoundError: If the specified file is not found. 293 | yaml.YAMLError: If there's an error parsing the YAML file. 294 | """ 295 | try: 296 | with open(file_path, 'r', encoding='utf-8') as yamlfile: 297 | return yaml.safe_load(yamlfile) 298 | except FileNotFoundError: 299 | print(f"Error: YAML file not found at {file_path}") 300 | return (f"Error: YAML file not found at {file_path}") 301 | except yaml.YAMLError as e: 302 | print(f"Error parsing YAML file: {e}") 303 | return (f"Error parsing YAML file: {e}") 304 | 305 | @staticmethod 306 | def search_csv(file_path: str, search_column: str, search_value: Any) -> List[Dict[str, Any]]: 307 | """ 308 | Search for a specific value in a CSV file and return matching rows. 309 | 310 | Args: 311 | file_path (str): The path to the CSV file. 312 | search_column (str): The name of the column to search in. 313 | search_value (Any): The value to search for. 314 | 315 | Returns: 316 | List[Dict[str, Any]]: A list of dictionaries representing matching rows. 317 | 318 | Raises: 319 | FileNotFoundError: If the specified file is not found. 320 | KeyError: If the specified search column doesn't exist in the CSV. 
321 | """ 322 | try: 323 | with open(file_path, 'r', newline='', encoding='utf-8') as csvfile: 324 | reader = csv.DictReader(csvfile) 325 | 326 | # Check if the search_column exists 327 | if search_column not in reader.fieldnames: 328 | raise KeyError(f"Column '{search_column}' not found in the CSV file.") 329 | 330 | # Search for matching rows 331 | return [row for row in reader if row[search_column] == str(search_value)] 332 | except FileNotFoundError: 333 | print(f"Error: CSV file not found at {file_path}") 334 | return (f"Error: CSV file not found at {file_path}") 335 | except KeyError as e: 336 | print(f"Error: {e}") 337 | return (f"Error: {e}") 338 | 339 | @staticmethod 340 | def search_json(data: Union[Dict[str, Any], List[Any]], search_key: str, search_value: Any) -> List[Any]: 341 | """ 342 | Search for a specific key-value pair in a JSON structure and return matching items. 343 | 344 | Args: 345 | data (Union[Dict[str, Any], List[Any]]): The JSON data to search. 346 | search_key (str): The key to search for. 347 | search_value (Any): The value to match. 348 | 349 | Returns: 350 | List[Any]: A list of items that match the search criteria. 351 | """ 352 | results = [] 353 | 354 | def search_recursive(item): 355 | if isinstance(item, dict): 356 | if search_key in item and item[search_key] == search_value: 357 | results.append(item) 358 | for value in item.values(): 359 | search_recursive(value) 360 | elif isinstance(item, list): 361 | for element in item: 362 | search_recursive(element) 363 | 364 | search_recursive(data) 365 | return results 366 | 367 | @staticmethod 368 | def search_xml(root: ET.Element, tag: str, attribute: str = None, value: str = None) -> List[ET.Element]: 369 | """ 370 | Search for specific elements in an XML structure. 371 | 372 | Args: 373 | root (ET.Element): The root element of the XML to search. 374 | tag (str): The tag name to search for. 375 | attribute (str, optional): The attribute name to match. Defaults to None. 376 | value (str, optional): The attribute value to match. Defaults to None. 377 | 378 | Returns: 379 | List[ET.Element]: A list of matching XML elements. 380 | """ 381 | if attribute and value: 382 | return root.findall(f".//*{tag}[@{attribute}='{value}']") 383 | else: 384 | return root.findall(f".//*{tag}") 385 | 386 | @staticmethod 387 | def search_yaml(data: Union[Dict[str, Any], List[Any]], search_key: str, search_value: Any) -> List[Any]: 388 | """ 389 | Search for a specific key-value pair in a YAML structure and return matching items. 390 | 391 | Args: 392 | data (Union[Dict[str, Any], List[Any]]): The YAML data to search. 393 | search_key (str): The key to search for. 394 | search_value (Any): The value to match. 395 | 396 | Returns: 397 | List[Any]: A list of items that match the search criteria. 398 | """ 399 | # YAML is parsed into Python data structures, so we can reuse the JSON search method 400 | return FileTools.search_json(data, search_key, search_value) 401 | 402 | 403 | @staticmethod 404 | def write_markdown(file_path: str, content: str) -> str: 405 | """ 406 | Write content to a markdown file. 407 | 408 | Args: 409 | file_path (str): The path to the file to write to. 410 | content (str): The content to write to the file. 411 | Returns: 412 | str: Confirmation with the path to the file that was written to. 
413 | """ 414 | # Ensure the directory exists 415 | directory = os.path.dirname(file_path) 416 | if directory and not os.path.exists(directory): 417 | os.makedirs(directory) 418 | 419 | with open(file_path, 'w') as file: 420 | file.write(content) 421 | 422 | # Get the absolute path 423 | abs_path = os.path.abspath(file_path) 424 | return f"Markdown file written to {abs_path}" 425 | 426 | 427 | @staticmethod 428 | def write_csv(file_path: str, data: List[List[str]], delimiter: str = ',') -> Union[bool, str]: 429 | """ 430 | Write data to a CSV file. 431 | 432 | Args: 433 | file_path (str): The path to the CSV file. 434 | data (List[List[str]]): The data to write to the CSV file. 435 | delimiter (str, optional): The delimiter to use in the CSV file. Defaults to ','. 436 | 437 | Returns: 438 | Union[bool, str]: True if the data was successfully written, or an error message as a string. 439 | """ 440 | try: 441 | with open(file_path, 'w', newline='') as file: 442 | writer = csv.writer(file, delimiter=delimiter) 443 | writer.writerows(data) 444 | return f"Successfully wrote CSV file to {file_path}." 445 | except Exception as e: 446 | error_msg = f"Error: An unexpected error occurred while writing the CSV file: {e}" 447 | print(error_msg) 448 | return error_msg 449 | 450 | @staticmethod 451 | def get_column(data: List[List[str]], column_index: int) -> Union[List[str], str]: 452 | """ 453 | Extract a specific column from a list of lists representing CSV data. 454 | 455 | Args: 456 | data (List[List[str]]): The CSV data as a list of lists. 457 | column_index (int): The index of the column to extract (0-based). 458 | 459 | Returns: 460 | Union[List[str], str]: The extracted column as a list of strings, or an error message as a string. 461 | """ 462 | try: 463 | if not data: 464 | error_msg = "Error: Input data is empty." 465 | print(error_msg) 466 | return error_msg 467 | 468 | num_columns = len(data[0]) 469 | if column_index < 0 or column_index >= num_columns: 470 | error_msg = f"Error: Invalid column index. Must be between 0 and {num_columns - 1}." 471 | print(error_msg) 472 | return error_msg 473 | 474 | column = [row[column_index] for row in data] 475 | return column 476 | except IndexError: 477 | error_msg = "Error: Inconsistent number of columns in the input data." 478 | print(error_msg) 479 | return error_msg 480 | except Exception as e: 481 | error_msg = f"Error: An unexpected error occurred while extracting the column: {e}" 482 | print(error_msg) 483 | return error_msg 484 | 485 | @staticmethod 486 | def filter_rows(data: List[List[str]], column_index: int, value: str) -> Union[List[List[str]], str]: 487 | """ 488 | Filter rows in a list of lists representing CSV data based on a specific column value. 489 | 490 | Args: 491 | data (List[List[str]]): The CSV data as a list of lists. 492 | column_index (int): The index of the column to filter on (0-based). 493 | value (str): The value to match in the specified column. 494 | 495 | Returns: 496 | Union[List[List[str]], str]: The filtered rows as a list of lists, or an error message as a string. 497 | """ 498 | try: 499 | if not data: 500 | error_msg = "Error: Input data is empty." 501 | print(error_msg) 502 | return error_msg 503 | 504 | num_columns = len(data[0]) 505 | if column_index < 0 or column_index >= num_columns: 506 | error_msg = f"Error: Invalid column index. Must be between 0 and {num_columns - 1}." 
507 | print(error_msg) 508 | return error_msg 509 | 510 | filtered_rows = [row for row in data if row[column_index] == value] 511 | return filtered_rows 512 | except IndexError: 513 | error_msg = "Error: Inconsistent number of columns in the input data." 514 | print(error_msg) 515 | return error_msg 516 | except Exception as e: 517 | error_msg = f"Error: An unexpected error occurred while filtering rows: {e}" 518 | print(error_msg) 519 | return error_msg 520 | 521 | @staticmethod 522 | def peek_csv(file_path: str, num_lines: int = 5) -> Union[List[List[str]], str]: 523 | """ 524 | Peek at the first few lines of a CSV file. 525 | 526 | Args: 527 | file_path (str): The path to the CSV file. 528 | num_lines (int, optional): The number of lines to peek. Defaults to 5. 529 | 530 | Returns: 531 | Union[List[List[str]], str]: The first few lines of the CSV as a list of lists, or an error message as a string. 532 | """ 533 | try: 534 | with open(file_path, 'r', newline='') as csvfile: 535 | csv_reader = csv.reader(csvfile) 536 | peeked_data = [next(csv_reader) for _ in range(num_lines)] 537 | return peeked_data 538 | except FileNotFoundError: 539 | error_msg = f"Error: File not found at {file_path}" 540 | print(error_msg) 541 | return error_msg 542 | except csv.Error as e: 543 | error_msg = f"Error: CSV parsing error - {str(e)}" 544 | print(error_msg) 545 | return error_msg 546 | except Exception as e: 547 | error_msg = f"Error: An unexpected error occurred while peeking at the CSV: {str(e)}" 548 | print(error_msg) 549 | return error_msg -------------------------------------------------------------------------------- /taskflowai/tools/fred_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | from typing import Dict, Any, List 5 | 6 | def check_pandas(): 7 | try: 8 | import pandas as pd 9 | return pd 10 | except ImportError: 11 | raise ImportError("pandas is required for FredTools. Install with `pip install taskflowai[fred_tools]`") 12 | 13 | def check_fredapi(): 14 | try: 15 | from fredapi import Fred 16 | return Fred 17 | except ImportError: 18 | raise ImportError("fredapi is required for FredTools. Install with `pip install taskflowai[fred_tools]`") 19 | 20 | 21 | class FredTools: 22 | @staticmethod 23 | def economic_indicator_analysis(indicator_ids: List[str], start_date: str, end_date: str) -> Dict[str, Any]: 24 | """ 25 | Perform a comprehensive analysis of economic indicators. 26 | 27 | Args: 28 | indicator_ids (List[str]): List of economic indicator series IDs. 29 | start_date (str): Start date for the analysis (YYYY-MM-DD). 30 | end_date (str): End date for the analysis (YYYY-MM-DD). 31 | 32 | Returns: 33 | Dict[str, Any]: A dictionary containing the analysis results for each indicator. 
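Example (illustrative; requires FRED_API_KEY and the optional fred_tools dependencies):
    results = FredTools.economic_indicator_analysis(["GDP", "UNRATE"], "2020-01-01", "2023-12-31")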
34 | """ 35 | pd = check_pandas() 36 | Fred = check_fredapi() 37 | fred = Fred(api_key=os.getenv('FRED_API_KEY')) 38 | 39 | results = {} 40 | 41 | for indicator_id in indicator_ids: 42 | series = fred.get_series(indicator_id, observation_start=start_date, observation_end=end_date) 43 | series = series.dropna() 44 | 45 | if len(series) > 0: 46 | pct_change = series.pct_change() 47 | annual_change = series.resample('YE').last().pct_change() 48 | 49 | results[indicator_id] = { 50 | "indicator": indicator_id, 51 | "title": fred.get_series_info(indicator_id).title, 52 | "start_date": start_date, 53 | "end_date": end_date, 54 | "min_value": series.min(), 55 | "max_value": series.max(), 56 | "mean_value": series.mean(), 57 | "std_dev": series.std(), 58 | "pct_change_mean": pct_change.mean(), 59 | "pct_change_std": pct_change.std(), 60 | "annual_change_mean": annual_change.mean(), 61 | "annual_change_std": annual_change.std(), 62 | "last_value": series.iloc[-1], 63 | "last_pct_change": pct_change.iloc[-1], 64 | "last_annual_change": annual_change.iloc[-1] 65 | } 66 | else: 67 | results[indicator_id] = None 68 | 69 | return results 70 | 71 | @staticmethod 72 | def yield_curve_analysis(treasury_maturities: List[str], start_date: str, end_date: str) -> Dict[str, Any]: 73 | """ 74 | Perform an analysis of the US Treasury yield curve. 75 | 76 | Args: 77 | treasury_maturities (List[str]): List of Treasury maturity series IDs. 78 | start_date (str): Start date for the analysis (YYYY-MM-DD). 79 | end_date (str): End date for the analysis (YYYY-MM-DD). 80 | 81 | Returns: 82 | Dict[str, Any]: A dictionary containing the yield curve analysis results. 83 | """ 84 | pd = check_pandas() 85 | Fred = check_fredapi() 86 | fred = Fred(api_key=os.getenv('FRED_API_KEY')) 87 | 88 | yield_data = {} 89 | 90 | for maturity in treasury_maturities: 91 | series = fred.get_series(maturity, observation_start=start_date, observation_end=end_date) 92 | yield_data[maturity] = series 93 | 94 | yield_df = pd.DataFrame(yield_data) 95 | yield_df = yield_df.dropna() 96 | 97 | if len(yield_df) > 0: 98 | yield_curve_slopes = {} 99 | for i in range(len(treasury_maturities) - 1): 100 | short_maturity = treasury_maturities[i] 101 | long_maturity = treasury_maturities[i + 1] 102 | slope = yield_df[long_maturity] - yield_df[short_maturity] 103 | yield_curve_slopes[f"{short_maturity}_to_{long_maturity}"] = slope 104 | 105 | yield_curve_slopes_df = pd.DataFrame(yield_curve_slopes) 106 | 107 | results = { 108 | "start_date": start_date, 109 | "end_date": end_date, 110 | "yield_data": yield_df, 111 | "yield_curve_slopes": yield_curve_slopes_df, 112 | "inverted_yield_curve": yield_curve_slopes_df.min().min() < 0 113 | } 114 | else: 115 | results = None 116 | 117 | return results 118 | 119 | @staticmethod 120 | def economic_news_sentiment_analysis(news_series_id: str, start_date: str, end_date: str) -> Dict[str, Any]: 121 | """ 122 | Perform sentiment analysis on economic news series. 123 | 124 | Args: 125 | news_series_id (str): Economic news series ID. 126 | start_date (str): Start date for the analysis (YYYY-MM-DD). 127 | end_date (str): End date for the analysis (YYYY-MM-DD). 128 | 129 | Returns: 130 | Dict[str, Any]: A dictionary containing the sentiment analysis results. 
131 | """ 132 | pd = check_pandas() 133 | Fred = check_fredapi() 134 | fred = Fred(api_key=os.getenv('FRED_API_KEY')) 135 | 136 | series = fred.get_series(news_series_id, observation_start=start_date, observation_end=end_date) 137 | series = series.dropna() 138 | 139 | if len(series) > 0: 140 | sentiment_scores = series.apply(lambda x: 1 if x > 0 else (-1 if x < 0 else 0)) 141 | sentiment_counts = sentiment_scores.value_counts() 142 | 143 | results = { 144 | "series_id": news_series_id, 145 | "title": fred.get_series_info(news_series_id).title, 146 | "start_date": start_date, 147 | "end_date": end_date, 148 | "positive_sentiment_count": sentiment_counts.get(1, 0), 149 | "negative_sentiment_count": sentiment_counts.get(-1, 0), 150 | "neutral_sentiment_count": sentiment_counts.get(0, 0), 151 | "net_sentiment_score": sentiment_scores.sum() 152 | } 153 | else: 154 | results = None 155 | 156 | return results 157 | -------------------------------------------------------------------------------- /taskflowai/tools/github_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import requests 4 | from typing import Dict, Any, List 5 | import base64 6 | 7 | class GitHubTools: 8 | 9 | @staticmethod 10 | def get_user_info(username: str) -> Dict[str, Any]: 11 | """ 12 | Get public information about a GitHub user. 13 | 14 | Args: 15 | username (str): The GitHub username of the user. 16 | 17 | Returns: 18 | Dict[str, Any]: A dictionary containing the user's public information. 19 | Keys may include 'login', 'id', 'name', 'company', 'blog', 20 | 'location', 'email', 'hireable', 'bio', 'public_repos', 21 | 'public_gists', 'followers', 'following', etc. 22 | 23 | Raises: 24 | requests.exceptions.HTTPError: If the API request fails. 25 | """ 26 | base_url = "https://api.github.com" 27 | headers = {"Accept": "application/vnd.github+json"} 28 | url = f"{base_url}/users/{username}" 29 | response = requests.get(url, headers=headers) 30 | response.raise_for_status() 31 | return response.json() 32 | 33 | @staticmethod 34 | def list_user_repos(username: str) -> List[Dict[str, Any]]: 35 | """ 36 | List public repositories for the specified user. 37 | 38 | Args: 39 | username (str): The GitHub username of the user. 40 | 41 | Returns: 42 | List[Dict[str, Any]]: A list of dictionaries, each containing information 43 | about a public repository. Keys may include 'id', 44 | 'node_id', 'name', 'full_name', 'private', 'owner', 45 | 'html_url', 'description', 'fork', 'url', 'created_at', 46 | 'updated_at', 'pushed_at', 'homepage', 'size', 47 | 'stargazers_count', 'watchers_count', 'language', 48 | 'forks_count', 'open_issues_count', etc. 49 | 50 | Raises: 51 | requests.exceptions.HTTPError: If the API request fails. 52 | """ 53 | base_url = "https://api.github.com" 54 | headers = {"Accept": "application/vnd.github+json"} 55 | url = f"{base_url}/users/{username}/repos" 56 | response = requests.get(url, headers=headers) 57 | response.raise_for_status() 58 | return response.json() 59 | 60 | @staticmethod 61 | def list_repo_issues(owner: str, repo: str, state: str = "open") -> List[Dict[str, Any]]: 62 | """ 63 | List issues in the specified public repository. 64 | 65 | Args: 66 | owner (str): The owner (user or organization) of the repository. 67 | repo (str): The name of the repository. 68 | state (str, optional): The state of the issues to return. Can be either 'open', 'closed', or 'all'. 
Defaults to 'open'. 69 | 70 | Returns: 71 | List[Dict[str, Any]]: A list of dictionaries, each containing essential information about an issue. 72 | 73 | Raises: 74 | requests.exceptions.HTTPError: If the API request fails. 75 | """ 76 | base_url = "https://api.github.com" 77 | headers = {"Accept": "application/vnd.github+json"} 78 | url = f"{base_url}/repos/{owner}/{repo}/issues" 79 | params = {"state": state} 80 | response = requests.get(url, headers=headers, params=params) 81 | response.raise_for_status() 82 | 83 | def simplify_issue(issue: Dict[str, Any]) -> Dict[str, Any]: 84 | return { 85 | "number": issue["number"], 86 | "title": issue["title"], 87 | "state": issue["state"], 88 | "created_at": issue["created_at"], 89 | "updated_at": issue["updated_at"], 90 | "html_url": issue["html_url"], 91 | "user": { 92 | "login": issue["user"]["login"], 93 | "id": issue["user"]["id"] 94 | }, 95 | "comments": issue["comments"], 96 | "pull_request": "pull_request" in issue 97 | } 98 | 99 | return [simplify_issue(issue) for issue in response.json()] 100 | 101 | @staticmethod 102 | def get_issue_comments(owner: str, repo: str, issue_number: int) -> List[Dict[str, Any]]: 103 | """ 104 | Get essential information about an issue and its comments in a repository. 105 | 106 | Args: 107 | owner (str): The owner (user or organization) of the repository. 108 | repo (str): The name of the repository. 109 | issue_number (int): The number of the issue. 110 | 111 | Returns: 112 | List[Dict[str, Any]]: A list of dictionaries, containing the issue description and all comments. 113 | 114 | Raises: 115 | requests.exceptions.HTTPError: If the API request fails. 116 | """ 117 | base_url = "https://api.github.com" 118 | headers = {"Accept": "application/vnd.github+json"} 119 | 120 | # Get issue details 121 | issue_url = f"{base_url}/repos/{owner}/{repo}/issues/{issue_number}" 122 | issue_response = requests.get(issue_url, headers=headers) 123 | issue_response.raise_for_status() 124 | issue_data = issue_response.json() 125 | 126 | # Get comments 127 | comments_url = f"{issue_url}/comments" 128 | comments_response = requests.get(comments_url, headers=headers) 129 | comments_response.raise_for_status() 130 | comments_data = comments_response.json() 131 | 132 | def simplify_data(data: Dict[str, Any], is_issue: bool = False) -> Dict[str, Any]: 133 | return { 134 | "id": data["id"], 135 | "user": { 136 | "login": data["user"]["login"], 137 | "id": data["user"]["id"] 138 | }, 139 | "created_at": data["created_at"], 140 | "updated_at": data["updated_at"], 141 | "body": data["body"], 142 | "type": "issue" if is_issue else "comment" 143 | } 144 | 145 | result = [simplify_data(issue_data, is_issue=True)] 146 | result.extend([simplify_data(comment) for comment in comments_data]) 147 | 148 | return result 149 | 150 | @staticmethod 151 | def get_repo_details(owner: str, repo: str) -> Dict[str, Any]: 152 | """ 153 | Get detailed information about a specific GitHub repository. 154 | 155 | Args: 156 | owner (str): The username or organization name that owns the repository. 157 | repo (str): The name of the repository. 158 | 159 | Returns: 160 | Dict[str, Any]: A dictionary containing detailed information about the repository. 
161 | Keys may include 'id', 'node_id', 'name', 'full_name', 'private', 162 | 'owner', 'html_url', 'description', 'fork', 'url', 'created_at', 163 | 'updated_at', 'pushed_at', 'homepage', 'size', 'stargazers_count', 164 | 'watchers_count', 'language', 'forks_count', 'open_issues_count', 165 | 'master_branch', 'default_branch', 'topics', 'has_issues', 'has_projects', 166 | 'has_wiki', 'has_pages', 'has_downloads', 'archived', 'disabled', etc. 167 | 168 | Raises: 169 | requests.exceptions.HTTPError: If the API request fails. 170 | """ 171 | base_url = "https://api.github.com" 172 | headers = {"Accept": "application/vnd.github+json"} 173 | url = f"{base_url}/repos/{owner}/{repo}" 174 | response = requests.get(url, headers=headers) 175 | response.raise_for_status() 176 | return response.json() 177 | 178 | @staticmethod 179 | def list_repo_contributors(owner: str, repo: str) -> List[Dict[str, Any]]: 180 | """ 181 | List contributors to a specific GitHub repository. 182 | 183 | Args: 184 | owner (str): The username or organization name that owns the repository. 185 | repo (str): The name of the repository. 186 | 187 | Returns: 188 | List[Dict[str, Any]]: A list of dictionaries, each containing information about a contributor. 189 | Keys may include 'login', 'id', 'node_id', 'avatar_url', 'gravatar_id', 190 | 'url', 'html_url', 'followers_url', 'following_url', 'gists_url', 191 | 'starred_url', 'subscriptions_url', 'organizations_url', 'repos_url', 192 | 'events_url', 'received_events_url', 'type', 'site_admin', 193 | 'contributions', etc. 194 | 195 | Raises: 196 | requests.exceptions.HTTPError: If the API request fails. 197 | """ 198 | base_url = "https://api.github.com" 199 | headers = {"Accept": "application/vnd.github+json"} 200 | url = f"{base_url}/repos/{owner}/{repo}/contributors" 201 | response = requests.get(url, headers=headers) 202 | response.raise_for_status() 203 | return response.json() 204 | 205 | @staticmethod 206 | def get_repo_readme(owner: str, repo: str) -> Dict[str, str]: 207 | """ 208 | Get the README content of a GitHub repository. 209 | 210 | Args: 211 | owner (str): The username or organization name that owns the repository. 212 | repo (str): The name of the repository. 213 | 214 | Returns: 215 | Dict[str, str]: A dictionary containing the README content. 216 | The key is 'content' and the value is the raw text of the README file. 217 | 218 | Raises: 219 | requests.exceptions.HTTPError: If the API request fails. 220 | 221 | Note: 222 | This method retrieves the raw content of the README file, regardless of its format 223 | (e.g., .md, .rst, .txt). The content is not rendered or processed in any way. 224 | """ 225 | base_url = "https://api.github.com" 226 | headers = {"Accept": "application/vnd.github+json"} 227 | url = f"{base_url}/repos/{owner}/{repo}/readme" 228 | response = requests.get(url, headers=headers) 229 | response.raise_for_status() 230 | return {"content": response.text} 231 | 232 | @staticmethod 233 | def search_repositories(query: str, sort: str = "stars", max_results: int = 10) -> Dict[str, Any]: 234 | """ 235 | Search for repositories on GitHub with a maximum number of results. 236 | 237 | Args: 238 | query (str): Search keywords and qualifiers. 239 | sort (str): Can be one of: stars, forks, help-wanted-issues, updated. Default: stars 240 | max_results (int): Maximum number of results to return. Default: 10 241 | 242 | Returns: 243 | Dict[str, Any]: Dictionary containing search results and metadata. 
244 | 245 | Raises: 246 | requests.exceptions.HTTPError: If the API request fails. 247 | """ 248 | base_url = "https://api.github.com" 249 | headers = {"Accept": "application/vnd.github+json"} 250 | url = f"{base_url}/search/repositories" 251 | params = { 252 | "q": query, 253 | "sort": sort, 254 | "order": "desc", 255 | "per_page": min(max_results, 100) # GitHub API allows max 100 items per page 256 | } 257 | 258 | results = [] 259 | while len(results) < max_results: 260 | response = requests.get(url, headers=headers, params=params) 261 | response.raise_for_status() 262 | data = response.json() 263 | 264 | results.extend(data['items'][:max_results - len(results)]) 265 | 266 | if 'next' not in response.links: 267 | break 268 | 269 | url = response.links['next']['url'] 270 | params = {} # Clear params as they're included in the next URL 271 | 272 | def simplify_repo(repo: Dict[str, Any]) -> Dict[str, Any]: 273 | return { 274 | "id": repo["id"], 275 | "name": repo["name"], 276 | "full_name": repo["full_name"], 277 | "owner": { 278 | "login": repo["owner"]["login"], 279 | "id": repo["owner"]["id"] 280 | }, 281 | "html_url": repo["html_url"], 282 | "description": repo["description"], 283 | "created_at": repo["created_at"], 284 | "updated_at": repo["updated_at"], 285 | "stargazers_count": repo["stargazers_count"], 286 | "forks_count": repo["forks_count"], 287 | "language": repo["language"], 288 | "topics": repo["topics"], 289 | "license": repo["license"]["name"] if repo["license"] else None, 290 | "open_issues_count": repo["open_issues_count"] 291 | } 292 | 293 | simplified_results = [simplify_repo(repo) for repo in results] 294 | 295 | return { 296 | "total_count": data['total_count'], 297 | "incomplete_results": data['incomplete_results'], 298 | "items": simplified_results[:max_results] 299 | } 300 | 301 | 302 | @staticmethod 303 | def get_repo_contents(owner: str, repo: str, path: str = "") -> List[Dict[str, Any]]: 304 | """ 305 | Get contents of a repository directory or file. 306 | 307 | Args: 308 | owner (str): The owner (user or organization) of the repository. 309 | repo (str): The name of the repository. 310 | path (str, optional): The directory or file path. Defaults to root directory. 311 | 312 | Returns: 313 | List[Dict[str, Any]]: A list of dictionaries containing information about the contents. 314 | 315 | Raises: 316 | requests.exceptions.HTTPError: If the API request fails. 317 | """ 318 | base_url = "https://api.github.com" 319 | headers = {"Accept": "application/vnd.github+json"} 320 | url = f"{base_url}/repos/{owner}/{repo}/contents/{path}" 321 | response = requests.get(url, headers=headers) 322 | response.raise_for_status() 323 | return response.json() 324 | 325 | @staticmethod 326 | def get_file_content(owner: str, repo: str, path: str) -> str: 327 | """ 328 | Get the content of a specific file in the repository. 329 | 330 | Args: 331 | owner (str): The owner (user or organization) of the repository. 332 | repo (str): The name of the repository. 333 | path (str): The file path within the repository. 334 | 335 | Returns: 336 | str: The content of the file. 337 | 338 | Raises: 339 | requests.exceptions.HTTPError: If the API request fails. 
340 | """ 341 | base_url = "https://api.github.com" 342 | headers = {"Accept": "application/vnd.github+json"} 343 | url = f"{base_url}/repos/{owner}/{repo}/contents/{path}" 344 | response = requests.get(url, headers=headers) 345 | response.raise_for_status() 346 | content = response.json()["content"] 347 | return base64.b64decode(content).decode('utf-8') 348 | 349 | @staticmethod 350 | def get_directory_structure(owner: str, repo: str, path: str = "") -> Dict[str, Any]: 351 | """ 352 | Get the directory structure of a repository. 353 | 354 | Args: 355 | owner (str): The owner (user or organization) of the repository. 356 | repo (str): The name of the repository. 357 | path (str, optional): The directory path. Defaults to root directory. 358 | 359 | Returns: 360 | Dict[str, Any]: A nested dictionary representing the directory structure. 361 | 362 | Raises: 363 | requests.exceptions.HTTPError: If the API request fails. 364 | """ 365 | contents = GitHubTools.get_repo_contents(owner, repo, path) 366 | structure = {} 367 | for item in contents: 368 | if item['type'] == 'dir': 369 | structure[item['name']] = GitHubTools.get_directory_structure(owner, repo, item['path']) 370 | else: 371 | structure[item['name']] = item['type'] 372 | return structure 373 | 374 | @staticmethod 375 | def search_code(query: str, owner: str, repo: str, max_results: int = 10) -> Dict[str, Any]: 376 | """ 377 | Search for code within a specific repository. 378 | 379 | Args: 380 | query (str): The search query. 381 | owner (str): The owner (user or organization) of the repository. 382 | repo (str): The name of the repository. 383 | max_results (int, optional): Maximum number of results to return. Defaults to 10. 384 | 385 | Returns: 386 | Dict[str, Any]: A dictionary containing search results and metadata. 387 | 388 | Raises: 389 | requests.exceptions.HTTPError: If the API request fails. 390 | """ 391 | base_url = "https://api.github.com" 392 | headers = {"Accept": "application/vnd.github+json"} 393 | url = f"{base_url}/search/code" 394 | params = { 395 | "q": f"{query} repo:{owner}/{repo}", 396 | "per_page": min(max_results, 100) 397 | } 398 | 399 | response = requests.get(url, headers=headers, params=params) 400 | response.raise_for_status() 401 | data = response.json() 402 | 403 | def simplify_code_result(item: Dict[str, Any]) -> Dict[str, Any]: 404 | return { 405 | "name": item["name"], 406 | "path": item["path"], 407 | "sha": item["sha"], 408 | "url": item["html_url"], 409 | "repository": { 410 | "name": item["repository"]["name"], 411 | "full_name": item["repository"]["full_name"], 412 | "owner": item["repository"]["owner"]["login"] 413 | } 414 | } 415 | 416 | simplified_results = [simplify_code_result(item) for item in data['items'][:max_results]] 417 | 418 | return { 419 | "total_count": data['total_count'], 420 | "incomplete_results": data['incomplete_results'], 421 | "items": simplified_results 422 | } 423 | 424 | -------------------------------------------------------------------------------- /taskflowai/tools/langchain_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 
2 |
3 | from typing import List
4 | # Optional Langchain imports are deferred to _check_dependencies() below.
5 |
6 | class LangchainTools:
7 | @staticmethod
8 | def _check_dependencies():
9 | try:
10 | from langchain_core.tools import BaseTool
11 | from langchain_community.tools import _module_lookup
12 | except ImportError as e:
13 | raise ImportError(
14 | "Langchain dependencies are not installed. "
15 | "To use LangchainTools, install the required packages with "
16 | "'pip install taskflowai[langchain_tools]'\n"
17 | f"Original error: {e}"
18 | )
19 |
20 | @staticmethod
21 | def _wrap(langchain_tool):
22 | LangchainTools._check_dependencies()
23 | # Import optional dependencies inside the method
24 | from typing import Any, Callable, Type
25 | from pydantic.v1 import BaseModel
26 | import json
27 | from langchain_core.tools import BaseTool
28 |
29 | # Now proceed with the implementation
30 | def wrapped_tool(**kwargs: Any) -> str:
31 | tool_instance = langchain_tool()
32 | # Convert kwargs to a single string input
33 | tool_input = json.dumps(kwargs)
34 | return tool_instance.run(tool_input)
35 |
36 | tool_instance = langchain_tool()
37 | name = getattr(tool_instance, 'name', langchain_tool.__name__)
38 | description = getattr(tool_instance, 'description', "No description available")
39 |
40 | # Build the docstring dynamically
41 | doc_parts = [
42 | f"- {name}:",
43 | f" Description: {description}",
44 | ]
45 |
46 | args_schema = getattr(langchain_tool, 'args_schema', None) or getattr(tool_instance, 'args_schema', None)
47 | if args_schema and issubclass(args_schema, BaseModel):
48 | doc_parts.append(" Arguments:")
49 | for field_name, field in args_schema.__fields__.items():
50 | field_desc = field.field_info.description or "No description"
51 | doc_parts.append(f" - {field_name}: {field_desc}")
52 |
53 | wrapped_tool.__name__ = name
54 | wrapped_tool.__doc__ = "\n".join(doc_parts)
55 | return wrapped_tool
56 |
57 | @classmethod
58 | def get_tool(cls, tool_name: str):
59 | cls._check_dependencies()
60 | from langchain_community.tools import _module_lookup
61 | import importlib
62 |
63 | if tool_name not in _module_lookup:
64 | raise ValueError(f"Unknown Langchain tool: {tool_name}")
65 |
66 | module_path = _module_lookup[tool_name]
67 | module = importlib.import_module(module_path)
68 | tool_class = getattr(module, tool_name)
69 |
70 | wrapped_tool = LangchainTools._wrap(tool_class)
71 | return wrapped_tool
72 |
73 | @classmethod
74 | def list_available_tools(cls) -> List[str]:
75 | """
76 | List all available Langchain tools.
77 |
78 | Returns:
79 | List[str]: A list of names of all available Langchain tools.
80 |
81 | Raises:
82 | ImportError: If langchain-community is not installed.
83 |
84 | Example:
85 | >>> tools = LangchainTools.list_available_tools()
86 | >>> "WikipediaQueryRun" in tools
87 | True
88 | """
89 | try:
90 | from langchain_community.tools import _module_lookup
91 | except ImportError:
92 | print("Error: langchain-community is not installed. Please install it using 'pip install langchain-community'.")
93 | return []
94 |
95 | return list(_module_lookup.keys())
96 |
97 | @classmethod
98 | def get_tool_info(cls, tool_name: str) -> dict:
99 | """
100 | Retrieve information about a specific Langchain tool.
101 |
102 | Args:
103 | tool_name (str): The name of the Langchain tool.
104 |
105 | Returns:
106 | dict: A dictionary containing the tool's name, description, and module path.
107 |
108 | Raises:
109 | ValueError: If an unknown tool name is provided.
110 | ImportError: If langchain-community is not installed. 111 | 112 | Example: 113 | >>> info = LangchainTools.get_tool_info("WikipediaQueryRun") 114 | >>> "name" in info and "description" in info and "module_path" in info 115 | True 116 | """ 117 | cls._check_dependencies() 118 | try: 119 | from langchain_community.tools import _module_lookup 120 | except ImportError: 121 | raise ImportError("langchain-community is not installed. Please install it using 'pip install langchain-community'.") 122 | 123 | if tool_name not in _module_lookup: 124 | raise ValueError(f"Unknown Langchain tool: {tool_name}") 125 | 126 | module_path = _module_lookup[tool_name] 127 | import importlib 128 | module = importlib.import_module(module_path) 129 | tool_class = getattr(module, tool_name) 130 | 131 | tool_instance = tool_class() 132 | name = getattr(tool_instance, 'name', tool_class.__name__) 133 | description = getattr(tool_instance, 'description', "No description available") 134 | 135 | return { 136 | "name": name, 137 | "description": description, 138 | "module_path": module_path 139 | } 140 | -------------------------------------------------------------------------------- /taskflowai/tools/matplotlib_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import warnings 4 | from typing import List, Union, Any 5 | 6 | def check_matplotlib_dependencies(): 7 | try: 8 | import numpy as np 9 | import matplotlib.pyplot as plt 10 | import matplotlib.dates as mdates 11 | return np, plt, mdates 12 | except ImportError: 13 | raise ImportError("Matplotlib dependencies are not installed. To use MatplotlibTools, install the required packages with 'pip install taskflowai[matplotlib_tools]'") 14 | 15 | class MatplotlibTools: 16 | @staticmethod 17 | def _check_dependencies(): 18 | if 'np' not in globals() or 'plt' not in globals(): 19 | raise ImportError("Matplotlib is not installed. To use MatplotlibTools, install the required packages with 'pip install taskflowai[matplotlib_tools]'") 20 | 21 | class MatplotlibTools: 22 | @staticmethod 23 | def create_line_plot(x: List[List[Union[float, str]]], y: List[List[float]], title: str = None, xlabel: str = "X", ylabel: str = "Y", 24 | output_file: str = "line_plot.png") -> Union[Any, str]: 25 | """ 26 | Create a line plot using the provided x and y data. 27 | 28 | Args: 29 | x (List[List[Union[float, str]]]): The x-coordinates of the data points for each line. 30 | y (List[List[float]]): The y-coordinates of the data points for each line. 31 | title (str, optional): The title of the plot. Defaults to None. 32 | xlabel (str, optional): The label for the x-axis. Defaults to "X". 33 | ylabel (str, optional): The label for the y-axis. Defaults to "Y". 34 | output_file (str, optional): The output file name. Defaults to "line_plot.png". 35 | 36 | Returns: 37 | Union[Any, str]: The matplotlib figure object, or an error message as a string. 38 | """ 39 | fig = None 40 | try: 41 | np, plt, mdates = check_matplotlib_dependencies() 42 | if len(x) != len(y): 43 | raise ValueError(f"The number of x and y lists must be equal. Got {len(x)} x-lists and {len(y)} y-lists. Check your data and try again.") 44 | 45 | for i, (xi, yi) in enumerate(zip(x, y)): 46 | if not isinstance(xi, list) or not isinstance(yi, list): 47 | raise TypeError(f"Both x[{i}] and y[{i}] must be lists. 
Check your data and try again.") 48 | if len(xi) != len(yi): 49 | raise ValueError(f"The lengths of x[{i}] and y[{i}] must be equal. Got lengths {len(xi)} and {len(yi)}. Check your data and try again.") 50 | if not all(isinstance(val, (int, float, str)) for val in xi): 51 | raise TypeError(f"All values in x[{i}] must be numbers or strings. Check your data and try again.") 52 | if not all(isinstance(val, (int, float)) for val in yi): 53 | raise TypeError(f"All values in y[{i}] must be numbers. Check your data and try again.") 54 | 55 | fig, ax = plt.subplots(figsize=(10, 6)) 56 | for xi, yi in zip(x, y): 57 | # Convert dates to numerical values 58 | if all(isinstance(val, str) for val in xi): 59 | xi = [mdates.datestr2num(val) for val in xi] 60 | ax.plot(xi, yi) 61 | 62 | ax.set_title(title) 63 | ax.set_xlabel(xlabel) 64 | ax.set_ylabel(ylabel) 65 | 66 | # Format x-axis as dates 67 | ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d')) 68 | ax.xaxis.set_major_locator(mdates.MonthLocator()) 69 | plt.gcf().autofmt_xdate() # Rotate and align the tick labels 70 | 71 | fig.tight_layout() 72 | plt.savefig(output_file) 73 | return fig 74 | 75 | except ImportError as e: 76 | return str(e) 77 | except Exception as e: 78 | error_msg = f"Error: An unexpected error occurred while creating the line plot: {str(e)}" 79 | print(error_msg) 80 | return error_msg 81 | finally: 82 | if fig is not None: 83 | plt.close(fig) # Ensure the figure is closed to free up memory 84 | 85 | @staticmethod 86 | def create_scatter_plot(x: List[float], y: List[float], title: str = None, xlabel: str = None, ylabel: str = None) -> Union[str, None]: 87 | """ 88 | Create a scatter plot using the provided x and y data. 89 | 90 | Args: 91 | x (List[float]): The x-coordinates of the data points. 92 | y (List[float]): The y-coordinates of the data points. 93 | title (str, optional): The title of the plot. Defaults to None. 94 | xlabel (str, optional): The label for the x-axis. Defaults to None. 95 | ylabel (str, optional): The label for the y-axis. Defaults to None. 96 | 97 | Returns: 98 | Union[str, None]: The path to the saved plot image file, or an error message as a string. 99 | """ 100 | try: 101 | np, plt, _ = check_matplotlib_dependencies() 102 | if len(x) != len(y): 103 | raise ValueError("The lengths of x and y must be equal.") 104 | 105 | plt.figure(figsize=(8, 6)) 106 | plt.scatter(x, y) 107 | 108 | if title: 109 | plt.title(title) 110 | if xlabel: 111 | plt.xlabel(xlabel) 112 | if ylabel: 113 | plt.ylabel(ylabel) 114 | 115 | plt.tight_layout() 116 | plot_path = "scatter_plot.png" 117 | plt.savefig(plot_path) 118 | plt.close() 119 | 120 | return plot_path 121 | 122 | except ImportError as e: 123 | return str(e) 124 | 125 | except ValueError as e: 126 | error_msg = f"Error: {str(e)} Please ensure x and y have the same length." 127 | print(error_msg) 128 | return error_msg 129 | except Exception as e: 130 | error_msg = f"Error: An unexpected error occurred while creating the scatter plot: {str(e)}" 131 | print(error_msg) 132 | return error_msg 133 | 134 | @staticmethod 135 | def create_bar_plot(x: List[str], y: List[float], title: str = None, xlabel: str = None, ylabel: str = None) -> Union[str, None]: 136 | """ 137 | Create a bar plot using the provided x and y data. 138 | 139 | Args: 140 | x (List[str]): The categories for the x-axis. 141 | y (List[float]): The values for each category. 142 | title (str, optional): The title of the plot. Defaults to None. 143 | xlabel (str, optional): The label for the x-axis. 
Defaults to None. 144 | ylabel (str, optional): The label for the y-axis. Defaults to None. 145 | 146 | Returns: 147 | Union[str, None]: The path to the saved plot image file, or an error message as a string. 148 | """ 149 | try: 150 | np, plt, _ = check_matplotlib_dependencies() 151 | if len(x) != len(y): 152 | raise ValueError("The lengths of x and y must be equal.") 153 | 154 | plt.figure(figsize=(8, 6)) 155 | plt.bar(x, y) 156 | 157 | if title: 158 | plt.title(title) 159 | if xlabel: 160 | plt.xlabel(xlabel) 161 | if ylabel: 162 | plt.ylabel(ylabel) 163 | 164 | plt.tight_layout() 165 | plot_path = "bar_plot.png" 166 | plt.savefig(plot_path) 167 | plt.close() 168 | 169 | return plot_path 170 | except ImportError as e: 171 | return str(e) 172 | 173 | except ValueError as e: 174 | error_msg = f"Error: {str(e)} Please ensure x and y have the same length." 175 | print(error_msg) 176 | return error_msg 177 | except Exception as e: 178 | error_msg = f"Error: An unexpected error occurred while creating the bar plot: {str(e)}" 179 | print(error_msg) 180 | return error_msg 181 | 182 | @staticmethod 183 | def create_histogram(data: List[float], bins: int = 10, title: str = None, xlabel: str = None, ylabel: str = None) -> Union[str, None]: 184 | """ 185 | Create a histogram using the provided data. 186 | 187 | Args: 188 | data (List[float]): The data to plot in the histogram. 189 | bins (int, optional): The number of bins for the histogram. Defaults to 10. 190 | title (str, optional): The title of the plot. Defaults to None. 191 | xlabel (str, optional): The label for the x-axis. Defaults to None. 192 | ylabel (str, optional): The label for the y-axis. Defaults to None. 193 | 194 | Returns: 195 | Union[str, None]: The path to the saved plot image file, or an error message as a string. 196 | """ 197 | try: 198 | np, plt, _ = check_matplotlib_dependencies() 199 | plt.figure(figsize=(8, 6)) 200 | plt.hist(data, bins=bins) 201 | 202 | if title: 203 | plt.title(title) 204 | if xlabel: 205 | plt.xlabel(xlabel) 206 | if ylabel: 207 | plt.ylabel(ylabel) 208 | 209 | plt.tight_layout() 210 | plot_path = "histogram.png" 211 | plt.savefig(plot_path) 212 | plt.close() 213 | 214 | return plot_path 215 | 216 | except ImportError as e: 217 | return str(e) 218 | except Exception as e: 219 | error_msg = f"Error: An unexpected error occurred while creating the histogram: {str(e)}" 220 | print(error_msg) 221 | return error_msg 222 | 223 | @staticmethod 224 | def create_heatmap(data: List[List[float]], title: str = None, xlabel: str = None, ylabel: str = None) -> Union[str, None]: 225 | """ 226 | Create a heatmap using the provided 2D data. 227 | 228 | Args: 229 | data (List[List[float]]): The 2D data to plot in the heatmap. 230 | title (str, optional): The title of the plot. Defaults to None. 231 | xlabel (str, optional): The label for the x-axis. Defaults to None. 232 | ylabel (str, optional): The label for the y-axis. Defaults to None. 233 | 234 | Returns: 235 | Union[str, None]: The path to the saved plot image file, or an error message as a string. 
236 | """ 237 | try: 238 | np, plt, _ = check_matplotlib_dependencies() 239 | data = np.array(data) 240 | 241 | plt.figure(figsize=(8, 6)) 242 | plt.imshow(data, cmap='viridis') 243 | plt.colorbar() 244 | 245 | if title: 246 | plt.title(title) 247 | if xlabel: 248 | plt.xlabel(xlabel) 249 | if ylabel: 250 | plt.ylabel(ylabel) 251 | 252 | plt.tight_layout() 253 | plot_path = "heatmap.png" 254 | plt.savefig(plot_path) 255 | plt.close() 256 | 257 | return plot_path 258 | 259 | except ImportError as e: 260 | return str(e) 261 | except Exception as e: 262 | error_msg = f"Error: An unexpected error occurred while creating the heatmap: {str(e)}" 263 | print(error_msg) 264 | return error_msg -------------------------------------------------------------------------------- /taskflowai/tools/pinecone_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | from typing import List, Dict, Any 5 | import numpy as np 6 | 7 | def check_pinecone(): 8 | try: 9 | import pinecone 10 | except ModuleNotFoundError: 11 | raise ImportError("pinecone is required for Pinecone tools. Install with `pip install pinecone`") 12 | return pinecone 13 | 14 | class PineconeTools: 15 | def __init__(self, api_key: str = None): 16 | """ 17 | Initialize PineconeTools with the Pinecone API key. 18 | 19 | Args: 20 | api_key (str, optional): Pinecone API key. If not provided, it will try to use the PINECONE_API_KEY environment variable. 21 | """ 22 | self.api_key = api_key or os.getenv("PINECONE_API_KEY") 23 | if not self.api_key: 24 | raise ValueError("Pinecone API key is required. Please provide it or set the PINECONE_API_KEY environment variable.") 25 | self.pc = check_pinecone().Pinecone(api_key=self.api_key) 26 | 27 | def get_pinecone_index(self, name: str): 28 | pinecone_client = check_pinecone().Pinecone(api_key=os.getenv("PINECONE_API_KEY")) 29 | return pinecone_client.Index(name) 30 | 31 | def create_index(self, name: str, dimension: int, metric: str = "cosine", cloud: str = "aws", region: str = "us-east-1") -> None: 32 | """ 33 | Create a new Pinecone index. 34 | 35 | Args: 36 | name (str): Name of the index to create. 37 | dimension (int): Dimension of the vectors to be stored in the index. 38 | metric (str, optional): Distance metric to use. Defaults to "cosine". 39 | cloud (str, optional): Cloud provider. Defaults to "aws". 40 | region (str, optional): Cloud region. Defaults to "us-east-1". 41 | 42 | Raises: 43 | Exception: If there's an error creating the index. 44 | """ 45 | try: 46 | self.pc.create_index( 47 | name=name, 48 | dimension=dimension, 49 | metric=metric, 50 | spec=check_pinecone().ServerlessSpec(cloud=cloud, region=region) 51 | ) 52 | print(f"Index '{name}' created successfully.") 53 | except Exception as e: 54 | raise Exception(f"Error creating index: {str(e)}") 55 | 56 | def delete_index(self, name: str) -> None: 57 | """ 58 | Delete a Pinecone index. 59 | 60 | Args: 61 | name (str): Name of the index to delete. 62 | 63 | Raises: 64 | Exception: If there's an error deleting the index. 65 | """ 66 | try: 67 | self.pc.delete_index(name) 68 | print(f"Index '{name}' deleted successfully.") 69 | except Exception as e: 70 | raise Exception(f"Error deleting index: {str(e)}") 71 | 72 | def list_indexes(self) -> List[str]: 73 | """ 74 | List all available Pinecone indexes. 75 | 76 | Returns: 77 | List[str]: List of index names. 
78 | 79 | Raises: 80 | Exception: If there's an error listing the indexes. 81 | """ 82 | try: 83 | return self.pc.list_indexes() 84 | except Exception as e: 85 | raise Exception(f"Error listing indexes: {str(e)}") 86 | 87 | def upsert_vectors(self, index_name: str, vectors: List[Dict[str, Any]]) -> None: 88 | """ 89 | Upsert vectors into a Pinecone index. 90 | 91 | Args: 92 | index_name (str): Name of the index to upsert vectors into. 93 | vectors (List[Dict[str, Any]]): List of vectors to upsert. Each vector should be a dictionary with 'id', 'values', and optionally 'metadata'. 94 | 95 | Raises: 96 | Exception: If there's an error upserting vectors. 97 | """ 98 | try: 99 | index = self.pc.Index(index_name) 100 | index.upsert(vectors=vectors) 101 | print(f"Vectors upserted successfully into index '{index_name}'.") 102 | except Exception as e: 103 | raise Exception(f"Error upserting vectors: {str(e)}") 104 | 105 | def query_index(self, index_name: str, query_vector: List[float], top_k: int = 10, filter: Dict = None, include_metadata: bool = True) -> Dict[str, Any]: 106 | """ 107 | Query a Pinecone index for similar vectors. 108 | 109 | Args: 110 | index_name (str): Name of the index to query. 111 | query_vector (List[float]): The query vector. 112 | top_k (int, optional): Number of results to return. Defaults to 10. 113 | filter (Dict, optional): Metadata filter to apply to the query. Defaults to None. 114 | include_metadata (bool, optional): Whether to include metadata in the results. Defaults to True. 115 | 116 | Returns: 117 | Dict[str, Any]: Query results containing matches and their scores. 118 | 119 | Raises: 120 | Exception: If there's an error querying the index. 121 | """ 122 | try: 123 | index = self.pc.Index(index_name) 124 | results = index.query( 125 | vector=query_vector, 126 | top_k=top_k, 127 | include_metadata=include_metadata, 128 | filter=filter 129 | ) 130 | return results 131 | except Exception as e: 132 | raise Exception(f"Error querying index: {str(e)}") 133 | 134 | def delete_vectors(self, index_name: str, ids: List[str]) -> None: 135 | """ 136 | Delete vectors from a Pinecone index by their IDs. 137 | 138 | Args: 139 | index_name (str): Name of the index to delete vectors from. 140 | ids (List[str]): List of vector IDs to delete. 141 | 142 | Raises: 143 | Exception: If there's an error deleting vectors. 144 | """ 145 | try: 146 | index = self.pc.Index(index_name) 147 | index.delete(ids=ids) 148 | print(f"Vectors deleted successfully from index '{index_name}'.") 149 | except Exception as e: 150 | raise Exception(f"Error deleting vectors: {str(e)}") 151 | 152 | def update_vector_metadata(self, index_name: str, id: str, metadata: Dict[str, Any]) -> None: 153 | """ 154 | Update the metadata of a vector in a Pinecone index. 155 | 156 | Args: 157 | index_name (str): Name of the index containing the vector. 158 | id (str): ID of the vector to update. 159 | metadata (Dict[str, Any]): New metadata to assign to the vector. 160 | 161 | Raises: 162 | Exception: If there's an error updating the vector metadata. 163 | """ 164 | try: 165 | index = self.pc.Index(index_name) 166 | index.update(id=id, set_metadata=metadata) 167 | print(f"Metadata updated successfully for vector '{id}' in index '{index_name}'.") 168 | except Exception as e: 169 | raise Exception(f"Error updating vector metadata: {str(e)}") 170 | 171 | def describe_index_stats(self, index_name: str) -> Dict[str, Any]: 172 | """ 173 | Get statistics about a Pinecone index. 
174 | 175 | Args: 176 | index_name (str): Name of the index to describe. 177 | 178 | Returns: 179 | Dict[str, Any]: Statistics about the index. 180 | 181 | Raises: 182 | Exception: If there's an error describing the index stats. 183 | """ 184 | try: 185 | index = self.pc.Index(index_name) 186 | return index.describe_index_stats() 187 | except Exception as e: 188 | raise Exception(f"Error describing index stats: {str(e)}") 189 | 190 | @staticmethod 191 | def normalize_vector(vector: List[float]) -> List[float]: 192 | """ 193 | Normalize a vector to unit length. 194 | 195 | Args: 196 | vector (List[float]): The input vector. 197 | 198 | Returns: 199 | List[float]: The normalized vector. 200 | """ 201 | norm = np.linalg.norm(vector) 202 | return (np.array(vector) / norm).tolist() if norm != 0 else vector -------------------------------------------------------------------------------- /taskflowai/tools/semantic_splitter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import numpy as np 4 | from dotenv import load_dotenv 5 | from .embedding_tools import EmbeddingsTools 6 | import igraph as ig 7 | import leidenalg as la 8 | from sentence_splitter import SentenceSplitter 9 | load_dotenv() 10 | 11 | class SemanticSplitter: 12 | def __init__(self, embedding_provider: str = "openai", embedding_model: str = "text-embedding-3-small"): 13 | self.embedding_provider = embedding_provider 14 | self.embedding_model = embedding_model 15 | self.splitter = SentenceSplitter(language='en') 16 | 17 | @staticmethod 18 | def chunk_text(text: str, resolution: float = 0.35, similarity_threshold: float = 0.5, rearrange: bool = False, 19 | embedding_provider: str = "openai", embedding_model: str = "text-embedding-3-small") -> List[str]: 20 | splitter = SemanticSplitter(embedding_provider, embedding_model) 21 | segments = splitter._create_sentence_segments(text) 22 | embeddings = splitter._embed_segments(segments) 23 | communities = splitter._detect_communities(embeddings, resolution, similarity_threshold) 24 | chunks = splitter._create_chunks_from_communities(segments, communities, rearrange) 25 | 26 | print(f"Created {len(chunks)} non-empty chunks") 27 | return chunks 28 | 29 | def _create_sentence_segments(self, text: str) -> List[str]: 30 | sentences = self.splitter.split(text) 31 | segments = [sentence.strip() for sentence in sentences] 32 | print(f"Created {len(segments)} segments") 33 | return segments 34 | 35 | def _embed_segments(self, segments: List[str]) -> np.ndarray: 36 | embeddings, _ = EmbeddingsTools.get_embeddings(segments, self.embedding_provider, self.embedding_model) 37 | return np.array(embeddings) 38 | 39 | def _detect_communities(self, embeddings: np.ndarray, resolution: float, similarity_threshold: float) -> List[int]: 40 | if embeddings.shape[0] < 2: 41 | return [0] 42 | 43 | G = self._create_similarity_graph(embeddings, similarity_threshold) 44 | 45 | partition = self._find_optimal_partition(G, resolution) 46 | 47 | communities = partition.membership 48 | 49 | num_communities = len(set(communities)) 50 | print(f"Resolution: {resolution}, Similarity Threshold: {similarity_threshold}, Communities: {num_communities}") 51 | 52 | return communities 53 | 54 | def _create_chunks_from_communities(self, segments: List[str], communities: List[int], rearrange: bool) -> List[str]: 55 | if rearrange: 56 | # Group segments by community 57 | community_groups = {} 58 | for segment, community in zip(segments, communities): 59 | 
if community not in community_groups: 60 | community_groups[community] = [] 61 | community_groups[community].append(segment) 62 | 63 | # Create chunks from rearranged communities 64 | chunks = [' '.join(group).strip() for group in community_groups.values() if group] 65 | else: 66 | # Create chunks respecting original order 67 | chunks = [] 68 | current_community = communities[0] 69 | current_chunk = [] 70 | 71 | for segment, community in zip(segments, communities): 72 | if community != current_community: 73 | chunks.append(' '.join(current_chunk).strip()) 74 | current_chunk = [] 75 | current_community = community 76 | current_chunk.append(segment) 77 | 78 | # Add the last chunk 79 | if current_chunk: 80 | chunks.append(' '.join(current_chunk).strip()) 81 | 82 | return [chunk for chunk in chunks if chunk] # Remove any empty chunks 83 | 84 | def _identify_breakpoints(self, communities: List[int]) -> List[int]: 85 | breakpoints = [] 86 | for i in range(1, len(communities)): 87 | if communities[i] != communities[i-1]: 88 | breakpoints.append(i) 89 | return breakpoints 90 | 91 | def _create_similarity_graph(self, embeddings: np.ndarray, similarity_threshold: float) -> ig.Graph: 92 | similarities = np.dot(embeddings, embeddings.T) 93 | np.fill_diagonal(similarities, 0) 94 | similarities = np.maximum(similarities, 0) 95 | similarities = (similarities - np.min(similarities)) / (np.max(similarities) - np.min(similarities)) 96 | 97 | # Apply similarity threshold 98 | adjacency_matrix = (similarities >= similarity_threshold).astype(int) 99 | 100 | G = ig.Graph.Adjacency(adjacency_matrix.tolist()) 101 | G.es['weight'] = similarities[np.where(adjacency_matrix)] 102 | return G 103 | 104 | def _find_optimal_partition(self, G: ig.Graph, resolution: float) -> la.VertexPartition: 105 | return la.find_partition( 106 | G, 107 | la.CPMVertexPartition, 108 | weights='weight', 109 | resolution_parameter=resolution 110 | ) 111 | 112 | def _split_oversized_communities(self, membership: List[int], max_size: int) -> List[int]: 113 | community_sizes = {} 114 | for comm in membership: 115 | community_sizes[comm] = community_sizes.get(comm, 0) + 1 116 | 117 | new_membership = [] 118 | current_comm = max(membership) + 1 119 | for i, comm in enumerate(membership): 120 | if community_sizes[comm] > max_size: 121 | if i % max_size == 0: 122 | current_comm += 1 123 | new_membership.append(current_comm) 124 | else: 125 | new_membership.append(comm) 126 | 127 | return new_membership 128 | 129 | # Example usage 130 | #text = "This is a test text to demonstrate the semantic splitter. It should be split into meaningful chunks based on the content and similarity threshold." 131 | 132 | # Using OpenAI (default) 133 | #chunks = SemanticSplitter.chunk_text(text) 134 | 135 | # Using Cohere 136 | #chunks = SemanticSplitter.chunk_text(text, embedding_provider="cohere", embedding_model="embed-english-v3.0") 137 | 138 | # Using Mistral 139 | #chunks = SemanticSplitter.chunk_text(text, embedding_provider="mistral", embedding_model="mistral-embed") -------------------------------------------------------------------------------- /taskflowai/tools/sentence_splitter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 
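The TextSplitter class that follows chunks text into overlapping windows of whole sentences. A small sketch of the intended call, using made-up sample text; chunk_size and overlap are the parameters of the signature below.

from taskflowai.tools.sentence_splitter import TextSplitter

sample = (
    "TaskflowAI agents pass documents between tasks. "
    "Long documents are split before embedding. "
    "Each chunk holds a few consecutive sentences. "
    "Overlap keeps neighbouring chunks coherent. "
    "This text exists only for the example."
)

# 3 sentences per chunk, advancing by chunk_size - overlap = 2 sentences each step.
chunks = TextSplitter.split_text_by_sentences(sample, chunk_size=3, overlap=1)
for i, chunk in enumerate(chunks, start=1):
    print(i, chunk)

Because the window advances by chunk_size - overlap sentences, overlap must stay smaller than chunk_size for the loop to make progress.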
2 | 3 | from typing import List 4 | from sentence_splitter import SentenceSplitter 5 | 6 | class TextSplitter: 7 | @staticmethod 8 | def split_text_by_sentences(text: str, chunk_size: int = 5, overlap: int = 1, language: str = 'en') -> List[str]: 9 | """ 10 | Split the text into chunks of sentences with overlap. 11 | 12 | :param text: The input text to split. 13 | :param chunk_size: The number of sentences per chunk. 14 | :param overlap: The number of sentences to overlap between chunks. 15 | :param language: The language of the text (default: 'en'). 16 | :return: A list of text chunks. 17 | """ 18 | splitter = SentenceSplitter(language) 19 | sentences = splitter.split(text) 20 | chunks = [] 21 | 22 | for i in range(0, len(sentences), chunk_size - overlap): 23 | chunk = ' '.join(sentences[i:i + chunk_size]) 24 | chunks.append(chunk.strip()) 25 | 26 | print(f"Created {len(chunks)} chunks with {chunk_size} sentences each and {overlap} sentence overlap") 27 | return chunks 28 | -------------------------------------------------------------------------------- /taskflowai/tools/text_splitters.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import os 4 | from typing import List, Union 5 | import numpy as np 6 | from dotenv import load_dotenv 7 | from .embedding_tools import EmbeddingsTools 8 | import igraph as ig 9 | import leidenalg as la 10 | from sentence_splitter import SentenceSplitter as ExternalSentenceSplitter 11 | load_dotenv() 12 | 13 | class SemanticSplitter: 14 | def __init__(self, embedding_provider: str = "openai", embedding_model: str = "text-embedding-3-small"): 15 | self.embedding_provider = embedding_provider 16 | self.embedding_model = embedding_model 17 | 18 | @staticmethod 19 | def chunk_text(text: Union[str, List[str]], rearrange: bool = False, 20 | embedding_provider: str = "openai", embedding_model: str = "text-embedding-3-small") -> List[str]: 21 | splitter = SemanticSplitter(embedding_provider, embedding_model) 22 | 23 | if isinstance(text, str): 24 | return splitter._process_single_text(text, rearrange) 25 | elif isinstance(text, list): 26 | all_chunks = [] 27 | for doc in text: 28 | all_chunks.extend(splitter._process_single_text(doc, rearrange)) 29 | return all_chunks 30 | else: 31 | raise ValueError("Input must be either a string or a list of strings") 32 | 33 | def _process_single_text(self, text: str, rearrange: bool) -> List[str]: 34 | segments = self._create_sentence_segments(text) 35 | embeddings = self._embed_segments(segments) 36 | communities = self._detect_communities(embeddings) 37 | chunks = self._create_chunks_from_communities(segments, communities, rearrange) 38 | 39 | print(f"Created {len(chunks)} non-empty chunks for this document") 40 | return chunks 41 | 42 | def _create_sentence_segments(self, text: str) -> List[str]: 43 | sentences = SentenceSplitter.split_text_by_sentences(text) 44 | segments = [sentence.strip() for sentence in sentences] 45 | print(f"Created {len(segments)} segments") 46 | return segments 47 | 48 | def _embed_segments(self, segments: List[str]) -> np.ndarray: 49 | embeddings, _ = EmbeddingsTools.get_embeddings(segments, self.embedding_provider, self.embedding_model) 50 | return np.array(embeddings) 51 | 52 | def _detect_communities(self, embeddings: np.ndarray) -> List[int]: 53 | if embeddings.shape[0] < 2: 54 | return [0] 55 | 56 | G = self._create_similarity_graph(embeddings, similarity_threshold=0.55) 
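# NOTE: unlike SemanticSplitter.chunk_text in semantic_splitter.py, which exposes
# similarity_threshold and resolution as parameters, this version hardcodes them at
# 0.55 and 0.35; a lower CPM resolution merges sentences into fewer, larger chunks.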
57 | 58 | partition = self._find_optimal_partition(G, resolution=0.35) 59 | 60 | communities = partition.membership 61 | 62 | num_communities = len(set(communities)) 63 | print(f"Communities: {num_communities}") 64 | 65 | return communities 66 | 67 | def _create_chunks_from_communities(self, segments: List[str], communities: List[int], rearrange: bool) -> List[str]: 68 | if rearrange: 69 | # Group segments by community 70 | community_groups = {} 71 | for segment, community in zip(segments, communities): 72 | if community not in community_groups: 73 | community_groups[community] = [] 74 | community_groups[community].append(segment) 75 | 76 | # Create chunks from rearranged communities 77 | chunks = [' '.join(group).strip() for group in community_groups.values() if group] 78 | else: 79 | # Create chunks respecting original order 80 | chunks = [] 81 | current_community = communities[0] 82 | current_chunk = [] 83 | 84 | for segment, community in zip(segments, communities): 85 | if community != current_community: 86 | chunks.append(' '.join(current_chunk).strip()) 87 | current_chunk = [] 88 | current_community = community 89 | current_chunk.append(segment) 90 | 91 | # Add the last chunk 92 | if current_chunk: 93 | chunks.append(' '.join(current_chunk).strip()) 94 | 95 | return [chunk for chunk in chunks if chunk] # Remove any empty chunks 96 | 97 | def _identify_breakpoints(self, communities: List[int]) -> List[int]: 98 | breakpoints = [] 99 | for i in range(1, len(communities)): 100 | if communities[i] != communities[i-1]: 101 | breakpoints.append(i) 102 | return breakpoints 103 | 104 | def _create_similarity_graph(self, embeddings: np.ndarray, similarity_threshold: float) -> ig.Graph: 105 | similarities = np.dot(embeddings, embeddings.T) 106 | np.fill_diagonal(similarities, 0) 107 | similarities = np.maximum(similarities, 0) 108 | similarities = (similarities - np.min(similarities)) / (np.max(similarities) - np.min(similarities)) 109 | 110 | # Apply similarity threshold 111 | adjacency_matrix = (similarities >= similarity_threshold).astype(int) 112 | 113 | G = ig.Graph.Adjacency(adjacency_matrix.tolist()) 114 | G.es['weight'] = similarities[np.where(adjacency_matrix)] 115 | return G 116 | 117 | def _find_optimal_partition(self, G: ig.Graph, resolution: float) -> la.VertexPartition: 118 | return la.find_partition( 119 | G, 120 | la.CPMVertexPartition, 121 | weights='weight', 122 | resolution_parameter=resolution 123 | ) 124 | 125 | def _split_oversized_communities(self, membership: List[int], max_size: int) -> List[int]: 126 | community_sizes = {} 127 | for comm in membership: 128 | community_sizes[comm] = community_sizes.get(comm, 0) + 1 129 | 130 | new_membership = [] 131 | current_comm = max(membership) + 1 132 | for i, comm in enumerate(membership): 133 | if community_sizes[comm] > max_size: 134 | if i % max_size == 0: 135 | current_comm += 1 136 | new_membership.append(current_comm) 137 | else: 138 | new_membership.append(comm) 139 | 140 | return new_membership 141 | 142 | 143 | class SentenceSplitter: 144 | @staticmethod 145 | def split_text_by_sentences(text: str, chunk_size: int = 5, overlap: int = 1, language: str = 'en') -> List[str]: 146 | """ 147 | Split the text into chunks of sentences with overlap. 148 | 149 | :param text: The input text to split. 150 | :param chunk_size: The number of sentences per chunk. 151 | :param overlap: The number of sentences to overlap between chunks. 152 | :param language: The language of the text (default: 'en'). 153 | :return: A list of text chunks. 
154 | """ 155 | splitter = ExternalSentenceSplitter(language=language) 156 | sentences = splitter.split(text) 157 | chunks = [] 158 | 159 | for i in range(0, len(sentences), chunk_size - overlap): 160 | chunk = ' '.join(sentences[i:i + chunk_size]) 161 | chunks.append(chunk.strip()) 162 | 163 | print(f"Created {len(chunks)} chunks with {chunk_size} sentences each and {overlap} sentence overlap") 164 | return chunks 165 | -------------------------------------------------------------------------------- /taskflowai/tools/wikipedia_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import requests 4 | from typing import List, Dict, Optional 5 | 6 | class WikipediaTools: 7 | @staticmethod 8 | def get_article(title, include_images=False): 9 | """ 10 | Retrieve a Wikipedia article by its title. 11 | 12 | This method fetches the content of a Wikipedia article, including its extract, URL, and optionally, images. 13 | 14 | Args: 15 | title (str): The title of the Wikipedia article to retrieve. 16 | include_images (bool, optional): Whether to include images in the response. Defaults to False. 17 | 18 | Returns: 19 | dict: A dictionary containing the article data, including: 20 | - 'extract': The main text content of the article. 21 | - 'fullurl': The full URL of the article on Wikipedia. 22 | - 'pageid': The unique identifier of the page. 23 | - 'title': The title of the article. 24 | - 'thumbnail' (optional): Information about the article's thumbnail image, if available and requested. 25 | 26 | Raises: 27 | requests.exceptions.RequestException: If there's an error fetching the article from the Wikipedia API. 28 | KeyError, ValueError: If there's an error parsing the API response. 29 | 30 | Note: 31 | This method uses the Wikipedia API to fetch article data. The API has rate limits and usage policies 32 | that should be respected when making frequent requests. 33 | """ 34 | print(f"Getting article for title: {title}") 35 | base_url = "https://en.wikipedia.org/w/api.php" 36 | params = { 37 | "action": "query", 38 | "titles": title, 39 | "prop": "extracts|info|pageimages", 40 | "inprop": "url", 41 | "redirects": "", 42 | "format": "json", 43 | "origin": "*", 44 | "pithumbsize": "400" if include_images else "0" 45 | } 46 | 47 | try: 48 | response = requests.get(base_url, params=params) 49 | response.raise_for_status() 50 | data = response.json() 51 | pages = data["query"]["pages"] 52 | page_id = next(iter(pages)) 53 | article = pages[page_id] 54 | return article 55 | except requests.exceptions.RequestException as e: 56 | print(f"Error fetching article: {e}") 57 | return None 58 | except (KeyError, ValueError) as e: 59 | print(f"Error parsing response: {e}") 60 | return None 61 | 62 | @staticmethod 63 | def search_articles(query: str, num_results: int = 10) -> List[Dict[str, str]]: 64 | """ 65 | Search for Wikipedia articles based on a given query. 66 | 67 | Args: 68 | query (str): The search query string. 69 | num_results (int, optional): The maximum number of search results to return. Defaults to 10. 70 | 71 | Returns: 72 | List[Dict[str, str]]: A list of dictionaries containing detailed information about each search result. 73 | Each dictionary includes: 74 | - 'title': The title of the article. 75 | - 'fullurl': The full URL of the article on Wikipedia. 76 | - 'snippet': A brief extract or snippet from the article. 
77 | 78 | Raises: 79 | requests.exceptions.RequestException: If there's an error fetching search results from the Wikipedia API. 80 | KeyError, ValueError: If there's an error parsing the API response. 81 | 82 | Note: 83 | This method uses the Wikipedia API to perform the search. The API has rate limits and usage policies 84 | that should be respected when making frequent requests. 85 | """ 86 | print(f"Searching articles for query: {query}") 87 | base_url = "https://en.wikipedia.org/w/api.php" 88 | params = { 89 | "action": "query", 90 | "list": "search", 91 | "srsearch": query, 92 | "srlimit": num_results, 93 | "format": "json", 94 | "origin": "*" 95 | } 96 | 97 | try: 98 | response = requests.get(base_url, params=params) 99 | response.raise_for_status() 100 | data = response.json() 101 | search_results = data["query"]["search"] 102 | 103 | # Fetch additional details for each search result 104 | detailed_results = [] 105 | for result in search_results: 106 | page_id = result['pageid'] 107 | detailed_params = { 108 | "action": "query", 109 | "pageids": page_id, 110 | "prop": "info|extracts|pageimages", 111 | "inprop": "url", 112 | "exintro": "", 113 | "explaintext": "", 114 | "pithumbsize": "250", 115 | "format": "json", 116 | "origin": "*" 117 | } 118 | detailed_response = requests.get(base_url, params=detailed_params) 119 | detailed_response.raise_for_status() 120 | detailed_data = detailed_response.json() 121 | page_data = detailed_data["query"]["pages"][str(page_id)] 122 | 123 | detailed_result = { 124 | "title": page_data.get("title"), 125 | "fullurl": page_data.get("fullurl"), 126 | "snippet": page_data.get("extract", "") 127 | } 128 | detailed_results.append(detailed_result) 129 | 130 | return detailed_results 131 | except requests.exceptions.RequestException as e: 132 | print(f"Error searching articles: {e}") 133 | return [] 134 | except (KeyError, ValueError) as e: 135 | print(f"Error parsing response: {e}") 136 | return [] 137 | 138 | @staticmethod 139 | def get_main_image(title: str, thumb_size: int = 250) -> Optional[str]: 140 | """ 141 | Retrieve the main image for a given Wikipedia article title. 142 | 143 | This method queries the Wikipedia API to fetch the main image (thumbnail) 144 | associated with the specified article title. 145 | 146 | Args: 147 | title (str): The title of the Wikipedia article. 148 | thumb_size (int, optional): The desired size of the thumbnail in pixels. Defaults to 250. 149 | 150 | Returns: 151 | Optional[str]: The URL of the main image if found, None otherwise. 152 | 153 | Raises: 154 | requests.exceptions.RequestException: If there's an error in the HTTP request. 155 | KeyError, ValueError: If there's an error parsing the API response. 
156 | """ 157 | print(f"Getting main image for title: {title}") 158 | base_url = "https://en.wikipedia.org/w/api.php" 159 | params = { 160 | "action": "query", 161 | "titles": title, 162 | "prop": "pageimages", 163 | "pithumbsize": thumb_size, 164 | "format": "json", 165 | "origin": "*" 166 | } 167 | 168 | try: 169 | response = requests.get(base_url, params=params) 170 | response.raise_for_status() 171 | data = response.json() 172 | pages = data["query"]["pages"] 173 | page_id = next(iter(pages)) 174 | image_info = pages[page_id].get("thumbnail") 175 | if image_info: 176 | return image_info["source"] 177 | else: 178 | return None 179 | except requests.exceptions.RequestException as e: 180 | print(f"Error fetching main image: {e}") 181 | return None 182 | except (KeyError, ValueError) as e: 183 | print(f"Error parsing response: {e}") 184 | return None 185 | 186 | @staticmethod 187 | def search_images(query: str, limit: int = 20, thumb_size: int = 250) -> List[Dict[str, str]]: 188 | """ 189 | Search for images on Wikimedia Commons based on a given query. 190 | 191 | This method queries the Wikimedia Commons API to fetch images related to the specified query. 192 | 193 | Args: 194 | query (str): The search query for finding images. 195 | limit (int, optional): The maximum number of image results to return. Defaults to 20. 196 | thumb_size (int, optional): The desired size of the thumbnail in pixels. Defaults to 250. 197 | 198 | Returns: 199 | List[Dict[str, str]]: A list of dictionaries containing image information. 200 | Each dictionary includes 'title', 'url', and 'thumbnail' keys. 201 | 202 | Raises: 203 | requests.exceptions.RequestException: If there's an error in the HTTP request. 204 | KeyError, ValueError: If there's an error parsing the API response. 205 | """ 206 | print(f"Searching images for query: {query}") 207 | base_url = "https://commons.wikimedia.org/w/api.php" 208 | params = { 209 | "action": "query", 210 | "generator": "search", 211 | "gsrnamespace": "6", 212 | "gsrsearch": f"intitle:{query}", 213 | "gsrlimit": limit, 214 | "prop": "pageimages|info", 215 | "pithumbsize": thumb_size, 216 | "inprop": "url", 217 | "format": "json", 218 | "origin": "*" 219 | } 220 | 221 | try: 222 | response = requests.get(base_url, params=params) 223 | response.raise_for_status() 224 | data = response.json() 225 | pages = data["query"]["pages"] 226 | image_results = [] 227 | for page_id, page_data in pages.items(): 228 | image_info = { 229 | "title": page_data["title"], 230 | "url": page_data["fullurl"], 231 | "thumbnail": page_data.get("thumbnail", {}).get("source") 232 | } 233 | image_results.append(image_info) 234 | 235 | # New code to format the output 236 | formatted_results = [] 237 | separator = "-" * 30 # Create a separator line 238 | for i, image in enumerate(image_results, 1): 239 | formatted_image = f"\nImage {i}:\n" 240 | formatted_image += f" Title: {image['title']}\n" 241 | formatted_image += f" URL: {image['url']}\n" 242 | formatted_image += f" Thumbnail: {image['thumbnail']}\n" 243 | formatted_image += f"{separator}" 244 | formatted_results.append(formatted_image) 245 | 246 | return "".join(formatted_results) 247 | 248 | except requests.exceptions.RequestException as e: 249 | print(f"Error searching images: {e}") 250 | return "Error occurred while searching for images." 251 | except (KeyError, ValueError) as e: 252 | print(f"Error parsing response: {e}") 253 | return "Error occurred while parsing the response." 
254 | -------------------------------------------------------------------------------- /taskflowai/tools/yahoo_finance_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | import json 4 | from typing import List, Dict, Any, Union 5 | 6 | def check_yfinance(): 7 | try: 8 | import yfinance as yf 9 | return yf 10 | except ImportError: 11 | raise ImportError("yfinance is required for YahooFinanceTools. Install with `pip install taskflowai[yahoo_finance_tools]`") 12 | 13 | def check_yahoofinance(): 14 | try: 15 | import yahoofinance 16 | return yahoofinance 17 | except ImportError: 18 | raise ImportError("yahoofinance is required for YahooFinanceTools. Install with `pip install taskflowai[yahoo_finance_tools]`") 19 | 20 | def check_pandas(): 21 | try: 22 | import pandas as pd 23 | return pd 24 | except ImportError: 25 | raise ImportError("pandas is required for YahooFinanceTools. Install with `pip install taskflowai[yahoo_finance_tools]`") 26 | 27 | 28 | class YahooFinanceTools: 29 | @staticmethod 30 | def get_ticker_info(ticker: str) -> Dict[str, Any]: 31 | """ 32 | Get comprehensive information about a stock ticker. 33 | 34 | Args: 35 | ticker (str): The stock ticker symbol. 36 | 37 | Returns: 38 | Dict[str, Any]: A dictionary containing detailed stock information. 39 | 40 | Raises: 41 | ValueError: If the ticker is invalid or data cannot be retrieved. 42 | """ 43 | try: 44 | yf = check_yfinance() 45 | stock = yf.Ticker(ticker) 46 | info = stock.info 47 | 48 | # Get the latest price 49 | history = stock.history(period="1d") 50 | latest_price = history['Close'].iloc[-1] if not history.empty else None 51 | 52 | return { 53 | "name": info.get("longName"), 54 | "symbol": info.get("symbol"), 55 | "current_price": latest_price, 56 | "currency": info.get("currency"), 57 | "market_cap": info.get("marketCap"), 58 | "sector": info.get("sector"), 59 | "industry": info.get("industry"), 60 | "pe_ratio": info.get("forwardPE"), 61 | "dividend_yield": info.get("dividendYield"), 62 | "52_week_high": info.get("fiftyTwoWeekHigh"), 63 | "52_week_low": info.get("fiftyTwoWeekLow"), 64 | "50_day_average": info.get("fiftyDayAverage"), 65 | "200_day_average": info.get("twoHundredDayAverage"), 66 | "volume": info.get("volume"), 67 | "avg_volume": info.get("averageVolume"), 68 | "beta": info.get("beta"), 69 | "book_value": info.get("bookValue"), 70 | "price_to_book": info.get("priceToBook"), 71 | "earnings_growth": info.get("earningsGrowth"), 72 | "revenue_growth": info.get("revenueGrowth"), 73 | "profit_margins": info.get("profitMargins"), 74 | "analyst_target_price": info.get("targetMeanPrice"), 75 | "recommendation": info.get("recommendationKey"), 76 | } 77 | except Exception as e: 78 | raise ValueError(f"Error retrieving info for ticker {ticker}: {str(e)}") 79 | 80 | @staticmethod 81 | def get_historical_data(ticker: str, period: str = "1y", interval: str = "1wk") -> str: 82 | """ 83 | Get historical price data for a stock ticker. 84 | 85 | Args: 86 | ticker (str): The stock ticker symbol. 87 | period (str): The time period to retrieve data for (e.g., "1d", "5d", "1mo", "3mo", "6mo", "1y", "2y", "5y", "10y", "ytd", "max"). 88 | interval (str): The interval between data points (e.g., "5m", "30m", "1h", "1d", "5d", "1wk", "1mo", "3mo"). 89 | 90 | Returns: 91 | str: A JSON string containing historical price data. 
92 | 93 | Raises: 94 | ValueError: If the ticker is invalid or data cannot be retrieved. 95 | """ 96 | try: 97 | yf = check_yfinance() 98 | stock = yf.Ticker(ticker) 99 | data = stock.history(period=period, interval=interval) 100 | 101 | # Convert DataFrame to a dictionary of records 102 | data_dict = data.reset_index().to_dict(orient='records') 103 | 104 | # Convert datetime objects to strings and round price values 105 | for record in data_dict: 106 | record['Date'] = record['Date'].isoformat() 107 | for key in ['Open', 'High', 'Low', 'Close']: 108 | if key in record: 109 | record[key] = round(record[key], 2) 110 | 111 | # Serialize to JSON string 112 | return json.dumps(data_dict, default=str) 113 | except Exception as e: 114 | raise ValueError(f"Error retrieving historical data for ticker {ticker}: {str(e)}") 115 | 116 | @staticmethod 117 | def calculate_returns(tickers: Union[str, List[str]], period: str = "1y", interval: str = "1d"): 118 | """ 119 | Calculate daily returns for given stock ticker(s). 120 | 121 | Args: 122 | tickers (Union[str, List[str]]): The stock ticker symbol or a list of symbols. 123 | period (str): The time period to retrieve data for (e.g., "1d", "1mo", "1y", "ytd"). 124 | interval (str): The interval between data points (e.g., "1wk", "1mo"). 125 | 126 | Returns: 127 | Dict[str, Series]: A dictionary where keys are ticker symbols and values are Series containing daily returns. 128 | 129 | Raises: 130 | ValueError: If data cannot be retrieved for any ticker or if the DataFrame structure is unexpected. 131 | """ 132 | if isinstance(tickers, str): 133 | tickers = [tickers] 134 | 135 | returns = {} 136 | try: 137 | yf = check_yfinance() 138 | data = YahooFinanceTools.download_multiple_tickers(tickers, period=period, interval=interval) 139 | 140 | if data.empty: 141 | raise ValueError("No data returned from download_multiple_tickers") 142 | 143 | for ticker in tickers: 144 | if ('Close' in data.columns.get_level_values('Price') and 145 | ticker in data.columns.get_level_values('Ticker')): 146 | returns[ticker] = data[ticker]['Close'].pct_change() 147 | else: 148 | raise ValueError(f"'Close' column not found for ticker {ticker}") 149 | 150 | return returns 151 | except Exception as e: 152 | raise ValueError(f"Error calculating returns for tickers {tickers}: {str(e)}") 153 | 154 | @staticmethod 155 | def get_financials(ticker: str, statement: str = "income"): 156 | """ 157 | Get financial statements for a stock ticker. 158 | 159 | Args: 160 | ticker (str): The stock ticker symbol. 161 | statement (str): The type of financial statement ("income", "balance", or "cash"). 162 | 163 | Returns: 164 | DataFrame: A DataFrame containing the requested financial statement. 165 | 166 | Raises: 167 | ValueError: If the ticker is invalid, data cannot be retrieved, or an invalid statement type is provided. 168 | """ 169 | try: 170 | yf = check_yfinance() 171 | stock = yf.Ticker(ticker) 172 | if statement == "income": 173 | return stock.financials 174 | elif statement == "balance": 175 | return stock.balance_sheet 176 | elif statement == "cash": 177 | return stock.cashflow 178 | else: 179 | raise ValueError("Invalid statement type. Choose 'income', 'balance', or 'cash'.") 180 | except Exception as e: 181 | raise ValueError(f"Error retrieving {statement} statement for ticker {ticker}: {str(e)}") 182 | 183 | @staticmethod 184 | def get_recommendations(ticker: str): 185 | """ 186 | Get analyst recommendations for a stock ticker. 
187 | 188 | Args: 189 | ticker (str): The stock ticker symbol. 190 | 191 | Returns: 192 | DataFrame: A DataFrame containing analyst recommendations. 193 | 194 | Raises: 195 | ValueError: If the ticker is invalid or data cannot be retrieved. 196 | """ 197 | try: 198 | yf = check_yfinance() 199 | stock = yf.Ticker(ticker) 200 | return stock.recommendations 201 | except Exception as e: 202 | raise ValueError(f"Error retrieving recommendations for ticker {ticker}: {str(e)}") 203 | 204 | @staticmethod 205 | def download_multiple_tickers(tickers: List[str], period: str = "1mo", interval: str = "1d"): 206 | """ 207 | Download historical data for multiple tickers. 208 | 209 | Args: 210 | tickers (List[str]): A list of stock ticker symbols. 211 | period (str): The time period to retrieve data for (e.g., "1d", "1mo", "1y"). 212 | interval (str): The interval between data points (e.g., "1m", "1h", "1d"). 213 | 214 | Returns: 215 | DataFrame: A DataFrame containing historical price data for all tickers. 216 | 217 | Raises: 218 | ValueError: If any ticker is invalid or data cannot be retrieved. 219 | """ 220 | try: 221 | yf = check_yfinance() 222 | data = yf.download(" ".join(tickers), period=period, interval=interval, group_by="ticker") 223 | return data 224 | except Exception as e: 225 | raise ValueError(f"Error downloading data for tickers {tickers}: {str(e)}") 226 | 227 | @staticmethod 228 | def get_asset_profile(ticker: str) -> Dict[str, Any]: 229 | """ 230 | Get the asset profile for a given stock ticker. 231 | 232 | Args: 233 | ticker (str): The stock ticker symbol. 234 | 235 | Returns: 236 | Dict[str, Any]: A dictionary containing the asset profile information. 237 | 238 | Raises: 239 | ValueError: If the ticker is invalid or data cannot be retrieved. 240 | """ 241 | try: 242 | yahoofinance = check_yahoofinance() 243 | profile = yahoofinance.AssetProfile(ticker) 244 | return profile.to_dfs() 245 | except Exception as e: 246 | raise ValueError(f"Error retrieving asset profile for ticker {ticker}: {str(e)}") 247 | 248 | @staticmethod 249 | def get_balance_sheet(ticker: str, quarterly: bool = False): 250 | """ 251 | Get the balance sheet for a given stock ticker. 252 | 253 | Args: 254 | ticker (str): The stock ticker symbol. 255 | quarterly (bool): If True, retrieve quarterly data; if False, retrieve annual data. 256 | 257 | Returns: 258 | pd.DataFrame: A DataFrame containing the balance sheet data. 259 | 260 | Raises: 261 | ValueError: If the ticker is invalid or data cannot be retrieved. 262 | """ 263 | try: 264 | yahoofinance = check_yahoofinance() 265 | if quarterly: 266 | balance_sheet = yahoofinance.BalanceSheetQuarterly(ticker) 267 | else: 268 | balance_sheet = yahoofinance.BalanceSheet(ticker) 269 | return balance_sheet.to_dfs()['Balance Sheet'] 270 | except Exception as e: 271 | raise ValueError(f"Error retrieving balance sheet for ticker {ticker}: {str(e)}") 272 | 273 | @staticmethod 274 | def get_cash_flow(ticker: str, quarterly: bool = False): 275 | """ 276 | Get the cash flow statement for a given stock ticker. 277 | 278 | Args: 279 | ticker (str): The stock ticker symbol. 280 | quarterly (bool): If True, retrieve quarterly data; if False, retrieve annual data. 281 | 282 | Returns: 283 | pd.DataFrame: A DataFrame containing the cash flow statement data. 284 | 285 | Raises: 286 | ValueError: If the ticker is invalid or data cannot be retrieved. 
287 | """ 288 | try: 289 | yahoofinance = check_yahoofinance() 290 | if quarterly: 291 | cash_flow = yahoofinance.CashFlowQuarterly(ticker) 292 | else: 293 | cash_flow = yahoofinance.CashFlow(ticker) 294 | return cash_flow.to_dfs()['Cash Flow'] 295 | except Exception as e: 296 | raise ValueError(f"Error retrieving cash flow statement for ticker {ticker}: {str(e)}") 297 | 298 | @staticmethod 299 | def get_income_statement(ticker: str, quarterly: bool = False): 300 | 301 | """ 302 | Get the income statement for a given stock ticker. 303 | 304 | Args: 305 | ticker (str): The stock ticker symbol. 306 | quarterly (bool): If True, retrieve quarterly data; if False, retrieve annual data. 307 | 308 | Returns: 309 | pd.DataFrame: A DataFrame containing the income statement data. 310 | 311 | Raises: 312 | ValueError: If the ticker is invalid or data cannot be retrieved. 313 | """ 314 | try: 315 | yahoofinance = check_yahoofinance() 316 | if quarterly: 317 | income_statement = yahoofinance.IncomeStatementQuarterly(ticker) 318 | else: 319 | income_statement = yahoofinance.IncomeStatement(ticker) 320 | return income_statement.to_dfs()['Income Statement'] 321 | except Exception as e: 322 | raise ValueError(f"Error retrieving income statement for ticker {ticker}: {str(e)}") 323 | 324 | @staticmethod 325 | def get_custom_historical_data(ticker: str, start_date: str, end_date: str, 326 | frequency: str = '1d', event: str = 'history'): 327 | """ 328 | Get custom historical data for a stock ticker with specified parameters. 329 | 330 | Args: 331 | ticker (str): The stock ticker symbol. 332 | start_date (str): The start date for the query (format: 'YYYY-MM-DD'). 333 | end_date (str): The end date for the query (format: 'YYYY-MM-DD'). 334 | frequency (str): The data frequency ('1d', '1wk', or '1mo'). Default is '1d'. 335 | event (str): The type of event data to retrieve. Default is 'history'. 336 | 337 | Returns: 338 | pd.DataFrame: A DataFrame containing the custom historical data. 339 | 340 | Raises: 341 | ValueError: If the ticker is invalid, dates are incorrect, or data cannot be retrieved. 342 | """ 343 | try: 344 | yahoofinance = check_yahoofinance() 345 | historical_data = yahoofinance.HistoricalPrices( 346 | ticker, start_date, end_date, 347 | frequency=yahoofinance.DataFrequency(frequency), 348 | event=yahoofinance.DataEvent(event) 349 | ) 350 | return historical_data.to_dfs() 351 | except Exception as e: 352 | raise ValueError(f"Error retrieving custom historical data for ticker {ticker}: {str(e)}") 353 | 354 | @staticmethod 355 | def technical_analysis(ticker: str, period: str = "1y") -> Dict[str, Any]: 356 | """ 357 | Perform technical analysis for a given stock ticker. 358 | 359 | Args: 360 | ticker (str): The stock ticker symbol. 361 | period (str): The time period for historical data (e.g., "1mo", "3mo", "1y"). 362 | 363 | Returns: 364 | Dict[str, Any]: A dictionary containing various technical analysis indicators. 365 | 366 | Raises: 367 | ValueError: If the ticker is invalid or data cannot be retrieved. 
368 | """ 369 | try: 370 | # Get historical data 371 | data = YahooFinanceTools.get_historical_data(ticker, period) 372 | 373 | # Calculate moving averages 374 | data['SMA_50'] = data['Close'].rolling(window=50).mean() 375 | data['SMA_200'] = data['Close'].rolling(window=200).mean() 376 | 377 | # Calculate Relative Strength Index (RSI) 378 | delta = data['Close'].diff() 379 | gain = (delta.where(delta > 0, 0)).rolling(window=14).mean() 380 | loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean() 381 | rs = gain / loss 382 | data['RSI'] = 100 - (100 / (1 + rs)) 383 | 384 | # Calculate MACD 385 | exp1 = data['Close'].ewm(span=12, adjust=False).mean() 386 | exp2 = data['Close'].ewm(span=26, adjust=False).mean() 387 | data['MACD'] = exp1 - exp2 388 | data['Signal_Line'] = data['MACD'].ewm(span=9, adjust=False).mean() 389 | 390 | # Calculate Bollinger Bands 391 | data['BB_Middle'] = data['Close'].rolling(window=20).mean() 392 | data['BB_Upper'] = data['BB_Middle'] + (data['Close'].rolling(window=20).std() * 2) 393 | data['BB_Lower'] = data['BB_Middle'] - (data['Close'].rolling(window=20).std() * 2) 394 | 395 | latest = data.iloc[-1] 396 | 397 | return { 398 | "current_price": latest['Close'], 399 | "sma_50": latest['SMA_50'], 400 | "sma_200": latest['SMA_200'], 401 | "rsi": latest['RSI'], 402 | "macd": latest['MACD'], 403 | "macd_signal": latest['Signal_Line'], 404 | "bollinger_upper": latest['BB_Upper'], 405 | "bollinger_middle": latest['BB_Middle'], 406 | "bollinger_lower": latest['BB_Lower'], 407 | "volume": latest['Volume'] 408 | } 409 | except Exception as e: 410 | raise ValueError(f"Error performing technical analysis for ticker {ticker}: {str(e)}") 411 | 412 | @staticmethod 413 | def fundamental_analysis(ticker: str) -> Dict[str, Any]: 414 | """ 415 | Perform a comprehensive fundamental analysis for a given stock ticker. 416 | 417 | Args: 418 | ticker (str): The stock ticker symbol. 419 | 420 | Returns: 421 | Dict[str, Any]: A dictionary containing various fundamental analysis metrics. 422 | 423 | Raises: 424 | ValueError: If the ticker is invalid or data cannot be retrieved. 
425 | """ 426 | try: 427 | # Get basic info 428 | info = YahooFinanceTools.get_ticker_info(ticker) 429 | 430 | # Get financial statements 431 | income_statement = YahooFinanceTools.get_income_statement(ticker) 432 | balance_sheet = YahooFinanceTools.get_balance_sheet(ticker) 433 | cash_flow = YahooFinanceTools.get_cash_flow(ticker) 434 | 435 | # Calculate additional metrics 436 | latest_year = income_statement.columns[0] 437 | 438 | revenue = income_statement.loc['Total Revenue', latest_year] 439 | net_income = income_statement.loc['Net Income', latest_year] 440 | total_assets = balance_sheet.loc['Total Assets', latest_year] 441 | total_liabilities = balance_sheet.loc['Total Liabilities Net Minority Interest', latest_year] 442 | total_equity = balance_sheet.loc['Total Equity Gross Minority Interest', latest_year] 443 | 444 | # Return on Equity (ROE) 445 | roe = net_income / total_equity 446 | 447 | # Return on Assets (ROA) 448 | roa = net_income / total_assets 449 | 450 | # Debt to Equity Ratio 451 | debt_to_equity = total_liabilities / total_equity 452 | 453 | # Current Ratio 454 | current_assets = balance_sheet.loc['Current Assets', latest_year] 455 | current_liabilities = balance_sheet.loc['Current Liabilities', latest_year] 456 | current_ratio = current_assets / current_liabilities 457 | 458 | # Free Cash Flow 459 | operating_cash_flow = cash_flow.loc['Operating Cash Flow', latest_year] 460 | capital_expenditures = cash_flow.loc['Capital Expenditure', latest_year] 461 | free_cash_flow = operating_cash_flow - capital_expenditures 462 | 463 | return { 464 | **info, 465 | "revenue": revenue, 466 | "net_income": net_income, 467 | "total_assets": total_assets, 468 | "total_liabilities": total_liabilities, 469 | "total_equity": total_equity, 470 | "return_on_equity": roe, 471 | "return_on_assets": roa, 472 | "debt_to_equity_ratio": debt_to_equity, 473 | "current_ratio": current_ratio, 474 | "free_cash_flow": free_cash_flow 475 | } 476 | except Exception as e: 477 | raise ValueError(f"Error performing fundamental analysis for ticker {ticker}: {str(e)}") -------------------------------------------------------------------------------- /taskflowai/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 TaskFlowAI Contributors. Licensed under Apache License 2.0. 2 | 3 | from datetime import datetime 4 | import base64 5 | import re 6 | import json 7 | from typing import Union 8 | class Utils: 9 | 10 | @staticmethod 11 | def update_conversation_history(history, role, content): 12 | """ 13 | Format a message with a timestamp and role, then update the conversation history. 14 | 15 | Args: 16 | history (list): The current conversation history as a list of formatted messages. 17 | role (str): The role of the message sender (either "User" or "Assistant"). 18 | content (str): The message content. 19 | 20 | Returns: 21 | list: Updated conversation history. 22 | """ 23 | timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 24 | formatted_message = f"[{timestamp}] {role}: {content}" 25 | history.append(formatted_message) 26 | return history 27 | 28 | 29 | @staticmethod 30 | def image_to_base64(image_path: str, scale_factor: float = 0.5) -> str: 31 | """ 32 | Convert an image to a base64-encoded string, with optional resizing. 33 | 34 | Args: 35 | image_path (str): The path to the image file. 36 | scale_factor (float, optional): Factor to scale the image dimensions. Defaults to 0.5. 
37 | 
38 |         Returns:
39 |             str: Base64-encoded string representation of the (optionally resized) image.
40 | 
41 |         Raises:
42 |             IOError: If there's an error opening or processing the image file.
43 |         """
44 |         from io import BytesIO
45 |         from PIL import Image  # Pillow is assumed available; the previous raw-byte approach never decoded the image
46 | 
47 |         with Image.open(image_path) as img:
48 |             # Scale both dimensions by scale_factor, keeping at least one pixel per side
49 |             new_size = (max(1, int(img.width * scale_factor)), max(1, int(img.height * scale_factor)))
50 |             resized_img = img.resize(new_size)
51 |             buffer = BytesIO()
52 |             resized_img.save(buffer, format=img.format or "PNG")
53 |             return base64.b64encode(buffer.getvalue()).decode("utf-8")
54 | 
55 | 
56 |     @staticmethod
57 |     def parse_json_response(response: Union[str, dict]) -> dict:
58 |         """
59 |         Parse a JSON object from a string response or return the dict if already parsed.
60 | 
61 |         Args:
62 |             response (Union[str, dict]): The response, either a string containing a JSON object or an already parsed dict.
63 | 
64 |         Returns:
65 |             dict: The parsed JSON object.
66 | 
67 |         Raises:
68 |             ValueError: If no valid JSON object is found in the response string.
69 |         """
70 |         if isinstance(response, dict):
71 |             return response
72 | 
73 |         if isinstance(response, str):
74 |             try:
75 |                 return json.loads(response)
76 |             except json.JSONDecodeError:
77 |                 json_match = re.search(r'\{.*\}', response, re.DOTALL)  # Greedy match so nested objects are captured in full
78 |                 if json_match:
79 |                     json_str = json_match.group(0)
80 |                     try:
81 |                         return json.loads(json_str)
82 |                     except json.JSONDecodeError:
83 |                         raise ValueError(f"Failed to parse JSON from response: {response}")
84 |                 else:
85 |                     raise ValueError(f"No JSON object found in response: {response}")
86 | 
87 |         raise ValueError(f"Unexpected response type: {type(response)}")
88 | 
--------------------------------------------------------------------------------
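Usage sketch for taskflowai/utils.py (illustrative only, under the assumption that the package is importable as taskflowai; the conversation contents below are invented for the example):

    from taskflowai.utils import Utils

    history = []
    history = Utils.update_conversation_history(history, "User", "What is 2 + 2?")
    history = Utils.update_conversation_history(history, "Assistant", 'The answer is {"answer": 4}')

    # parse_json_response accepts either a dict or a raw string and extracts the embedded JSON object
    parsed = Utils.parse_json_response('The answer is {"answer": 4}')
    print(history[-1])   # [<timestamp>] Assistant: The answer is {"answer": 4}
    print(parsed)        # {'answer': 4}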