├── .env.sample ├── .gitignore ├── Makefile ├── README.md ├── plots └── warriors_victories.png ├── poetry.lock ├── pyproject.toml ├── src ├── __init__.py ├── agent.py ├── models.py ├── plot.py └── tools.py └── tests └── __init__.py /.env.sample: -------------------------------------------------------------------------------- 1 | COHERE_API_KEY=your_cohere_api_key_goes_here 2 | TAVILY_API_KEY=your_tavily_api_key_goes_here -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .ruff_cache/ 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | #build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | 163 | # VSCode 164 | .vscode/* 165 | !.vscode/settings.json 166 | !.vscode/tasks.json 167 | !.vscode/launch.json 168 | !.vscode/extensions.json 169 | !.vscode/*.code-snippets 170 | 171 | # Local History for Visual Studio Code 172 | .history/ 173 | 174 | # Built Visual Studio Code Extensions 175 | *.vsix 176 | 177 | # Artifacts 178 | results/ 179 | output*/ 180 | model_cache/ 181 | output* 182 | user_data.sh 183 | 184 | set_env_variables.sh 185 | logs/ 186 | .DS_Store 187 | gradio_cached_examples 188 | trust_policy.json 189 | 190 | .TODOS 191 | ca.cert 192 | *.pkl -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: 2 | @echo "Installing dependencies" 3 | curl -sSL https://install.python-poetry.org | python3 - 4 | poetry install 5 | 6 | run-llm: 7 | poetry run python src/models.py \ 8 | --input "Create a plot with Python of the number of games won by the Golden State Warriors in each of the last 2 NBA seasons." \ 9 | --model_name gpt-3.5-turbo \ 10 | --model_provider openai 11 | 12 | run-agent: 13 | poetry run python src/agent.py \ 14 | --input "Create a plot of the number of games won by the golden state warriors in each of the last 2 seasons. Save the plot as a png file under plots/warriors_victories.png" \ 15 | --model_name command-r-plus \ 16 | --model_provider cohere 17 | lint: 18 | @echo "Fixing linting issues" 19 | poetry run ruff check --fix . 20 | 21 | format: 22 | @echo "Formatting Python code" 23 | poetry run ruff format . 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Let's build our first LLM agent 2 | 3 | 4 | ## Table of Contents 5 | - [What is this repo about?](#what-is-this-repo-about) 6 | - [How to run the code](#run-the-whole-thing-in-3-minutes) 7 | - [Wanna get more hands-on content like this?](#wanna-get-more-hands-on-content-like-this) 8 | 9 | ## What is this repo about? 10 | In this repository you will find a Python implementation of an LLM agent that can generate visualisations using public data available on the Internet. 11 | 12 | ### What is an LLM agent? 13 | An agent is essentially a wrapper around your LLM, that provides extra functionality like 14 | 15 | - Tool usage. The LLM is able to select and use tools, like internet search, to fetch relevant information it might need to accomplish the task. 16 | 17 | - Multi-step reasoning. The LLM can generate a plan, execute it, and adjust it based on the partial outcomes obtained. 18 | 19 | The LLM acts as a reasoning machine, that helps the agent choose the sequence of actions to take to solve the task. 20 | 21 | Let me show you how to build a ReAct (Reason and Act) agent in Python that can generate the plot we want. 22 | 23 | 24 | ## Run the whole thing in 3 minutes 25 | 26 | 1. Create Python virtual environment and install all dependencies using Python Poetry 27 | ``` 28 | $ make install 29 | ``` 30 | 31 | 2. Set API keys for Tavily and Cohere in an `.env` file 32 | ``` 33 | $ cp .env.sample .env 34 | ``` 35 | and replace placeholders with your keys. 36 | 37 | 3. Ask the agent to generate a plot example 38 | ``` 39 | $ make run-agent 40 | ``` 41 | 42 | ## Wanna get more hands-on content like this? 43 | 44 | Join 15k+ builders in the Real-World ML Newsletter, and learn to build real-world ML products. 
45 | 46 | -> [Click to join for FREE](https://www.realworldml.net/subscribe) -------------------------------------------------------------------------------- /plots/warriors_victories.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Paulescu/plot-generator-agent/53ebedf774fa30adf20fea39f447635e14e5b477/plots/warriors_victories.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "src" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Paulescu "] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.10" 10 | langchain = "^0.1.17" 11 | langchain-cohere = "^0.1.4" 12 | langchain-experimental = "^0.0.57" 13 | fire = "^0.6.0" 14 | matplotlib = "^3.8.4" 15 | python-dotenv = "^1.0.1" 16 | langchainhub = "^0.1.15" 17 | pandas = "^2.2.2" 18 | langchain-openai = "^0.1.6" 19 | 20 | 21 | [tool.poetry.group.dev.dependencies] 22 | ruff = "^0.4.3" 23 | 24 | [build-system] 25 | requires = ["poetry-core"] 26 | build-backend = "poetry.core.masonry.api" 27 | 28 | [tool.ruff] 29 | line-length = 88 30 | 31 | [tool.ruff.format] 32 | quote-style = "single" 33 | indent-style = "space" 34 | docstring-code-format = true 35 | 36 | [tool.ruff.lint] 37 | extend-select = ["I"] 38 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Paulescu/plot-generator-agent/53ebedf774fa30adf20fea39f447635e14e5b477/src/__init__.py -------------------------------------------------------------------------------- /src/agent.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List 3 | 4 | from dotenv import find_dotenv, 
load_dotenv 5 | from langchain import hub 6 | from langchain.agents import AgentExecutor, create_react_agent 7 | from langchain_core.prompts import ChatPromptTemplate 8 | from langchain_core.language_models.chat_models import BaseChatModel 9 | from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent 10 | 11 | from src.tools import get_tools, Tool 12 | from src.models import get_chat_model 13 | 14 | # Load the environment variables from my .env file 15 | load_dotenv(find_dotenv()) 16 | 17 | # Set up logging 18 | logging.basicConfig( 19 | level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s' 20 | ) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | 25 | def get_react_agent_executor( 26 | llm: BaseChatModel, 27 | tools: List[Tool], 28 | model_provider: str) -> AgentExecutor: 29 | """ 30 | Returns an agent executor that can execute a multi-hop react agent. 31 | If the model_provider is 'cohere', the agent will be created using the Cohere API. 32 | 33 | Args: 34 | llm (BaseChatModel): The chat-based LLM model to use. 35 | tools (List[Tool]): A list of tools that can be used by the agent. 36 | model_provider (str): The provider of the chat-based LLM model. 37 | 38 | Returns: 39 | AgentExecutor: An agent executor that can execute a multi-hop react agent. 
40 | """ 41 | if model_provider == 'cohere': 42 | prompt = ChatPromptTemplate.from_template("{input}") 43 | agent = create_cohere_react_agent( 44 | llm=llm, 45 | tools=tools, 46 | prompt=prompt, 47 | ) 48 | else: 49 | prompt = hub.pull('hwchase17/react') 50 | agent = create_react_agent( 51 | llm=llm, 52 | tools=tools, 53 | prompt=prompt, 54 | ) 55 | 56 | agent_executor = AgentExecutor( 57 | agent=agent, 58 | tools=tools, 59 | verbose=True 60 | 61 | ) 62 | return agent_executor 63 | 64 | 65 | def run( 66 | model_provider: str, 67 | model_name: str, 68 | input: str 69 | ): 70 | """ 71 | Creates an agent executor that can execute a multi-hop react agent, using the 72 | specified model, and runs the agent with the given input. 73 | 74 | Args: 75 | model_name (str): The name of the chat-based LLM model to use. 76 | input (str): The input to the agent. 77 | """ 78 | llm = get_chat_model(model_provider, model_name) 79 | tools = get_tools() 80 | 81 | agent_executor = get_react_agent_executor(llm, tools, model_provider) 82 | 83 | # breakpoint() 84 | 85 | agent_executor.invoke( 86 | { 87 | 'input': input, 88 | } 89 | ) 90 | 91 | 92 | if __name__ == '__main__': 93 | from fire import Fire 94 | Fire(run) -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from langchain_core.language_models.chat_models import BaseChatModel 4 | from langchain_cohere.chat_models import ChatCohere 5 | from langchain_community.chat_models import ChatOllama 6 | from langchain_openai import ChatOpenAI 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def get_chat_model(model_provider: str, model_name: str) -> BaseChatModel: 12 | """ 13 | Return a chat-based LLM model from the given `model_provider` with the given `model_name`. 14 | 15 | Args: 16 | model_provider (str): The provider of the chat-based LLM model. 
17 | model_name (str): The name of the chat-based LLM model to use. 18 | 19 | Returns: 20 | ChatOllama: A chat-based LLM model. 21 | """ 22 | assert model_provider in [ 23 | 'ollama', 24 | 'cohere', 25 | 'openai' 26 | ], f'Invalid model provider: {model_provider}' 27 | 28 | if model_provider == 'ollama': 29 | logger.info(f'Loading Ollama model: {model_name}') 30 | llm = ChatOllama(model=model_name, temperature=0.0) 31 | 32 | elif model_provider == 'cohere': 33 | logger.info(f'Loading Cohere model: {model_name}') 34 | llm = ChatCohere(model=model_name, temperature=0.0) 35 | 36 | elif model_provider == 'openai': 37 | logger.info(f'Loading OpenAI model: {model_name}') 38 | llm = ChatOpenAI(model=model_name, temperature=0.0) 39 | 40 | return llm 41 | 42 | def run(model_provider: str, model_name: str, input: str): 43 | 44 | llm = get_chat_model(model_provider, model_name) 45 | output = llm.invoke(input) 46 | print(output.content) 47 | 48 | if __name__ == '__main__': 49 | 50 | from fire import Fire 51 | from dotenv import load_dotenv, find_dotenv 52 | 53 | # Load the environment variables from my .env file 54 | load_dotenv(find_dotenv()) 55 | 56 | Fire(run) -------------------------------------------------------------------------------- /src/plot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | # Data for the number of games won by the Golden State Warriors in the last 2 NBA seasons 4 | seasons = ['2019-2020', '2020-2021'] 5 | games_won = [15, 39] 6 | 7 | # Create a bar plot 8 | plt.figure(figsize=(10, 6)) 9 | plt.bar(seasons, games_won, color='blue') 10 | 11 | # Add title and labels 12 | plt.title('Number of Games Won by the Golden State Warriors in the Last 2 NBA Seasons') 13 | plt.xlabel('Season') 14 | plt.ylabel('Number of Games Won') 15 | 16 | # Show plot 17 | plt.show() -------------------------------------------------------------------------------- /src/tools.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from langchain_core.pydantic_v1 import BaseModel, Field 4 | from langchain.agents import Tool 5 | from langchain_experimental.utilities import PythonREPL 6 | from langchain_community.tools.tavily_search import TavilySearchResults 7 | 8 | def get_tools() -> List[Tool]: 9 | """ 10 | Returns a list of tools that can be used by an agent. 11 | """ 12 | tools = [] 13 | 14 | # Create the tools 15 | internet_search = get_tabily_search_tool() 16 | repl_tool = get_python_interpreter_tool() 17 | 18 | tools.append(internet_search) 19 | tools.append(repl_tool) 20 | 21 | return tools 22 | 23 | def get_tabily_search_tool() -> Tool: 24 | """ 25 | Returns a tool that searches the internet for a query using the Tavily API. 26 | 27 | Returns: 28 | TavilySearchResults: A tool that searches the internet for a query using the Tavily API. 29 | """ 30 | internet_search = TavilySearchResults() 31 | internet_search.name = 'internet_search' 32 | internet_search.description = 'Returns a list of relevant document snippets for a textual query retrieved from the internet.' 33 | 34 | class TavilySearchInput(BaseModel): 35 | query: str = Field(description='Query to search the internet with') 36 | 37 | internet_search.args_schema = TavilySearchInput 38 | 39 | return internet_search 40 | 41 | 42 | def get_python_interpreter_tool() -> Tool: 43 | """ 44 | Creates a tool that executes python code and returns the result. 45 | 46 | Returns: 47 | Tool: A tool that executes python code and returns the result. 48 | """ 49 | python_repl = PythonREPL() 50 | 51 | repl_tool = Tool( 52 | name='python_repl', 53 | description='Executes python code and returns the result. 
The code runs in a static sandbox without interactive mode, so print output or save output to a file.', 54 | func=python_repl.run, 55 | ) 56 | repl_tool.name = 'python_interpreter' 57 | 58 | # from langchain_core.pydantic_v1 import BaseModel, Field 59 | class ToolInput(BaseModel): 60 | code: str = Field(description='Python code to execute.') 61 | 62 | repl_tool.args_schema = ToolInput 63 | 64 | return repl_tool 65 | 66 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Paulescu/plot-generator-agent/53ebedf774fa30adf20fea39f447635e14e5b477/tests/__init__.py --------------------------------------------------------------------------------