├── lcforecast ├── tool │ ├── prompt.py │ ├── __init__.py │ └── tool.py └── agentkit │ ├── __init__.py │ ├── toolkit.py │ ├── prompt.py │ └── base.py ├── pyproject.toml ├── requirements.txt ├── .gitignore ├── setup.cfg ├── setup.py ├── LICENSE └── README.md /lcforecast/tool/prompt.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 99 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | langchain==0.0.155 2 | pandas==1.5.3 -------------------------------------------------------------------------------- /lcforecast/agentkit/__init__.py: -------------------------------------------------------------------------------- 1 | """Forecast agent.""" 2 | from .toolkit import ForecastToolkit -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__/ 3 | .vscode/ 4 | venv/ 5 | .ipynb_checkpoints/ 6 | *.pyc 7 | .envrc 8 | -------------------------------------------------------------------------------- /lcforecast/tool/__init__.py: -------------------------------------------------------------------------------- 1 | """Tools for forecasting with time series.""" 2 | from .tool import EmaForecastTool -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E203, E501, W503 3 | per-file-ignores = __init__.py:F401 4 | max-line-length = 99 5 | 6 | [isort] 7 | profile = black 8 | multi_line_output = 3 9 | 
10 | [mypy] 11 | ignore_missing_imports = True 12 | allow_redefinition = True -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools  # Packaging script for the langchain_forecast distribution. 2 | 3 | setuptools.setup( 4 | name="langchain_forecast", 5 | version="0.1.0", 6 | author="Pedro Lima", 7 | author_email="pedro.lima@gmail.com", 8 | description="Forecasting tool for langchain AI", 9 | long_description_content_type="text/markdown",  # NOTE(review): no long_description is passed, so this setting has nothing to apply to; presumably README.md was meant to be read in here - confirm 10 | url="https://github.com/pvl/langchain_forecast", 11 | packages=setuptools.find_packages(),  # NOTE(review): the repo tree shows no lcforecast/__init__.py; find_packages() is expected to skip directories without one, so this may install no packages - confirm (find_namespace_packages(include=["lcforecast*"]) or add the file) 12 | install_requires=open('requirements.txt').readlines(),  # NOTE(review): the file handle is never closed, and the relative path is resolved against the current working directory; requirements.txt may also need a MANIFEST.in entry to ship in an sdist - confirm 13 | classifiers=[ 14 | "Programming Language :: Python :: 3", 15 | "Operating System :: OS Independent", 16 | ], 17 | ) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Pedro Lima 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. -------------------------------------------------------------------------------- /lcforecast/agentkit/toolkit.py: -------------------------------------------------------------------------------- 1 | """Toolkit combining SQL database tools with the EMA forecasting tool.""" 2 | from typing import List 3 | 4 | from pydantic import Field 5 | 6 | from langchain.agents.agent_toolkits.base import BaseToolkit 7 | from langchain.llms.base import BaseLLM 8 | from langchain.sql_database import SQLDatabase 9 | from langchain.tools import BaseTool 10 | from langchain.tools.sql_database.tool import ( 11 | InfoSQLDatabaseTool, 12 | ListSQLDatabaseTool, 13 | QueryCheckerTool, 14 | QuerySQLDataBaseTool, 15 | ) 16 | from lcforecast.tool import EmaForecastTool 17 | 18 | 19 | class ForecastToolkit(BaseToolkit): 20 | """Toolkit bundling the SQL query, schema-info, table-list and query-checker tools for ``db`` with ``EmaForecastTool``.""" 21 | 22 | db: SQLDatabase = Field(exclude=True) 23 | llm: BaseLLM = Field(exclude=True) 24 | 25 | @property 26 | def dialect(self) -> str: 27 | """Return the SQL dialect name of ``db`` (create_forecast_agent formats it into the prompt prefix).""" 28 | return self.db.dialect 29 | 30 | class Config: 31 | """Configuration for this pydantic object.""" 32 | 33 | arbitrary_types_allowed = True 34 | 35 | def get_tools(self) -> List[BaseTool]: 36 | """Return the four SQL tools bound to ``db`` (the query checker also uses ``llm``), followed by ``EmaForecastTool``.""" 37 | return [ 38 | QuerySQLDataBaseTool(db=self.db), 39 | InfoSQLDatabaseTool(db=self.db), 40 | ListSQLDatabaseTool(db=self.db), 41 | QueryCheckerTool(db=self.db, llm=self.llm), 42 | EmaForecastTool() 43 | ] 44 | -------------------------------------------------------------------------------- /lcforecast/agentkit/prompt.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | 3 | FORECAST_PREFIX =
"""You are an agent designed to do forecasting with data from an SQL database. 4 | Given an input forecasting question, create a syntactically correct {dialect} query to run. 5 | This query should return a date value, usually aggregated by month or week, and a metric to forecast. 6 | With the results of the query execute the forecast tool and return the results to the user. 7 | Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results. 8 | You can order the results by a relevant column to return the most interesting examples in the database. 9 | Never query for all the columns from a specific table, only ask for the relevant columns given the question. 10 | You have access to tools for interacting with the database. 11 | Only use the below tools. Only use the information returned by the below tools to construct your final answer. 12 | You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again. 13 | 14 | DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. 15 | 16 | If the question does not seem related to the database, just return "I don't know" as the answer. 17 | """ 18 | 19 | FORECAST_SUFFIX = """Begin! 20 | 21 | Question: {input} 22 | Thought: I should look at the tables in the database to see what I can query.
23 | {agent_scratchpad}""" 24 | -------------------------------------------------------------------------------- /lcforecast/agentkit/base.py: -------------------------------------------------------------------------------- 1 | """Forecasting agent.""" 2 | from typing import Any, List, Optional 3 | 4 | from langchain.agents.agent import AgentExecutor 5 | from langchain.agents.mrkl.base import ZeroShotAgent 6 | from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS 7 | from langchain.callbacks.base import BaseCallbackManager 8 | from langchain.chains.llm import LLMChain 9 | from langchain.llms.base import BaseLLM 10 | from lcforecast.agentkit.prompt import FORECAST_PREFIX, FORECAST_SUFFIX 11 | from lcforecast.agentkit.toolkit import ForecastToolkit 12 | 13 | 14 | def create_forecast_agent( 15 | llm: BaseLLM, 16 | toolkit: ForecastToolkit, 17 | callback_manager: Optional[BaseCallbackManager] = None, 18 | prefix: str = FORECAST_PREFIX, 19 | suffix: str = FORECAST_SUFFIX, 20 | format_instructions: str = FORMAT_INSTRUCTIONS, 21 | input_variables: Optional[List[str]] = None, 22 | top_k: int = 10,  # NOTE(review): the prompt caps query results at top_k rows; the README's 12-month example ended up with LIMIT 10 - consider a larger default 23 | max_iterations: Optional[int] = 15, 24 | max_execution_time: Optional[float] = None, 25 | early_stopping_method: str = "force", 26 | verbose: bool = False, 27 | **kwargs: Any, 28 | ) -> AgentExecutor: 29 | """Construct a forecasting AgentExecutor from an LLM and a ForecastToolkit. ``prefix`` is formatted with the toolkit's SQL dialect and ``top_k``; the ZeroShotAgent may only call the toolkit's tools, extra ``kwargs`` are passed to ZeroShotAgent, and the iteration/time limits and early-stopping method are passed to the AgentExecutor.""" 30 | tools = toolkit.get_tools() 31 | prefix = prefix.format(dialect=toolkit.dialect, top_k=top_k) 32 | prompt = ZeroShotAgent.create_prompt( 33 | tools, 34 | prefix=prefix, 35 | suffix=suffix, 36 | format_instructions=format_instructions, 37 | input_variables=input_variables, 38 | ) 39 | llm_chain = LLMChain( 40 | llm=llm, 41 | prompt=prompt, 42 | callback_manager=callback_manager, 43 | ) 44 | tool_names = [tool.name for tool in tools] 45 | agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) 46 | return AgentExecutor.from_agent_and_tools( 47 | agent=agent, 48 | tools=tools,
49 | verbose=verbose, 50 | max_iterations=max_iterations, 51 | max_execution_time=max_execution_time, 52 | early_stopping_method=early_stopping_method, 53 | ) 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Forecasting Tool for LangChain AI 2 | 3 | This tool adds a simple exponential moving average forecast to the langchain AI. When used with an SQL tool it allows the agent to fetch history data from a database, forecast a future value and then reason with that information. 4 | 5 | ## Installation 6 | 7 | Clone this repository then run: 8 | 9 | ``` 10 | $ python -m pip install . 11 | ``` 12 | 13 | ## Usage 14 | 15 | The following code will create an agent that has the forecast tool. First export to the shell the OPENAI_API_KEY and your database connection string as DB_CONN. 16 | 17 | ``` 18 | import os 19 | from langchain.llms.openai import OpenAI 20 | from langchain.sql_database import SQLDatabase 21 | from lcforecast.agentkit import ForecastToolkit 22 | from lcforecast.agentkit.base import create_forecast_agent 23 | 24 | llm = OpenAI(temperature=0) 25 | db = SQLDatabase.from_uri(os.environ["DB_CONN"]) 26 | 27 | toolkit = ForecastToolkit(db=db, llm=llm) 28 | 29 | agent_executor = create_forecast_agent( 30 | llm=llm, 31 | toolkit=toolkit, 32 | verbose=True 33 | ) 34 | 35 | agent_executor.run("Forecast the invoice total for next month based on the last 12 months of history") 36 | ``` 37 | 38 | Example output for this agent run. 39 | 40 | ``` 41 | > Entering new AgentExecutor chain...
42 | Action: list_tables_sql_db 43 | Action Input: "" 44 | Observation: 45 | 46 | Thought: I should check the schema of the Invoice table to see what columns I can query 47 | Action: schema_sql_db 48 | Action Input: "Invoice" 49 | 50 | Thought: I should query the Invoice table for the total invoice amount for the last 12 months 51 | Action: query_sql_db 52 | Action Input: SELECT DATE_FORMAT(createdAt, '%Y-%m') AS month, SUM(total) AS total FROM Invoice WHERE createdAt > DATE_SUB(NOW(), INTERVAL 12 MONTH) GROUP BY month ORDER BY month DESC LIMIT 10 53 | 54 | Thought: I should use the EMA forecasting tool to predict the total invoice amount for June 2023 55 | Action: ema_forecast 56 | 57 | Observation: {'forecast_date': '2023-05-31', 'value': 3219.1} 58 | Thought: I now know the final answer 59 | Final Answer: The forecasted invoice total is 3219.1 60 | ``` -------------------------------------------------------------------------------- /lcforecast/tool/tool.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | """EMA forecasting tool for LangChain agents, fed with SQL query results.""" 3 | from pydantic import BaseModel, Extra, Field, validator, root_validator  # NOTE(review): none of these names are used in this module 4 | from typing import Any, Dict, Union 5 | import json 6 | from langchain.tools.base import BaseTool 7 | import re 8 | from datetime import datetime 9 | import pandas as pd 10 | 11 | 12 | RawDateType = Union[str, datetime]  # a date given either as text or as a datetime 13 | 14 | 15 | class EmaForecastTool(BaseTool): 16 | """LangChain tool that forecasts the next period of a date/value series with an exponential moving average.""" 17 | 18 | name = "ema_forecast" 19 | description = """ 20 | Input to this tool is a valid json object with keys date and values. The date has a sequence 21 | of dates and values has the corresponding values for the metric to forecast. This is the data 22 | returned from the SQL query. This tool will return a forecast for the next period. 23 | 24 | If an error is returned, rewrite the inputs and try again.
25 | """ 26 | 27 | def _run(self, data: str) -> Dict: 28 | """ exponential moving average forecast""" 29 | import traceback 30 | 31 | if "'" in data: 32 | # Trying hack to replace double quotes 33 | data = data.replace("'", '"') 34 | 35 | try: 36 | data = json.loads(data) 37 | dates = data["date"] 38 | sequence = data["values"] 39 | seq = [] 40 | for val in sequence: 41 | if type(val) == str and re.match(r"Decimal\('[\d\.]+'\)", val): 42 | seq.append(eval(val)) 43 | elif type(val) in (list, tuple) and len(val) == 2: 44 | seq.append(val[1]) 45 | else: 46 | seq.append(val) 47 | dates = [conv_date(d) for d in dates] 48 | df = pd.DataFrame(zip(dates, seq), columns=["timestamp", "values"]) 49 | df["timestamp"] = pd.to_datetime(df["timestamp"]) 50 | df = df.sort_values("timestamp") 51 | df = remove_last_period(df) 52 | ydf = create_forecast_range_single(df) 53 | res = do_ewm(df, ydf) 54 | return {"forecast_date": str(res.timestamp.values[0]), "value": res["values"].values[0]} 55 | except: 56 | traceback.print_exc() 57 | return {"forecast_date": "?", "value": "unknown"} 58 | 59 | async def _arun(self, query: str) -> str: 60 | raise NotImplementedError("EmaForecastTool does not support async") 61 | 62 | 63 | def remove_last_period(df, timecol: str="timestamp"): 64 | # quick workaround, always remove the last period that may be incomplete 65 | last = list(sorted(df[timecol].values))[-1] 66 | return df[df[timecol] != last].copy() 67 | 68 | 69 | def conv_date(dt): 70 | if re.match(r"\d{4}-\d{2}", dt): 71 | return dt + "-01" 72 | return dt 73 | 74 | 75 | def estimate_period(df: pd.DataFrame, timecol: str="timestamp") -> str: 76 | """ simple estimation of period supporting daily, weekly and monthly """ 77 | avg_days = df[timecol].diff().mean().days 78 | if avg_days in [29,30,31]: 79 | return "M" 80 | elif avg_days in [6,7]: 81 | return "W" 82 | elif avg_days == 1: 83 | avg_hr = df[timecol].diff().mean().seconds / (60*60) 84 | if avg_hr > 23: 85 | return "D" 86 | return "" 87 | 
def create_forecast_range_single(df: pd.DataFrame, timecol: str = "timestamp") -> pd.DataFrame:
    """Return a one-row frame holding the period that follows the last row of *df*.

    Args:
        df: History sorted by *timecol*; its last row is the starting point.
        timecol: Name of the datetime column.

    Returns:
        A DataFrame with the single column *timecol*, stepped by ``estimate_period``.
    """
    freq = estimate_period(df, timecol=timecol)
    rng = pd.date_range(df[timecol].values[-1], periods=2, freq=freq)
    # rng[0] stands for the last observed period; keep only the step after it.
    return pd.DataFrame(rng[1:], columns=[timecol])


def create_forecast_range(
    df: pd.DataFrame,
    end_date: "RawDateType",
    timecol: str = "timestamp",
    min_periods: int = 1,
) -> pd.DataFrame:
    """Return the periods after the last row of *df* up to *end_date*.

    Args:
        df: History sorted by *timecol*; its last row is the starting point.
        end_date: Last date to forecast (anything ``pd.to_datetime`` accepts).
        timecol: Name of the datetime column.
        min_periods: Minimum number of future periods to return. If *end_date* is
            too close to yield them, the range is extended by step count instead.

    Returns:
        A DataFrame with the single column *timecol*.
    """
    freq = estimate_period(df, timecol=timecol)
    end_date = pd.to_datetime(end_date)
    start = df[timecol].values[-1]
    rng = pd.date_range(start, end=end_date, freq=freq)
    # rng[0] stands for the last observed period, so min_periods + 1 entries are needed.
    if len(rng) < min_periods + 1:
        rng = pd.date_range(start, periods=min_periods + 1, freq=freq)
    return pd.DataFrame(rng[1:], columns=[timecol])


def do_ewm(
    df: pd.DataFrame,
    ydf: pd.DataFrame,
    timecol: str = "timestamp",
    valcol: str = "values",
    include_history: bool = False,
    com: float = 0.8,
) -> pd.DataFrame:
    """Forecast each timestamp in *ydf* with an exponentially weighted mean.

    Timestamps are processed in sorted order. Each forecast is appended to the
    working series before the next one is computed, so later steps build on
    earlier forecasts.

    Args:
        df: History with *timecol* and numeric *valcol* columns.
        ydf: Frame whose *timecol* column lists the timestamps to forecast.
        timecol: Name of the datetime column.
        valcol: Name of the value column.
        include_history: If True, return the history followed by the forecasts.
        com: Center of mass passed to ``Series.ewm``; the default 0.8 is the value
            that used to be hard-coded.

    Returns:
        The forecast rows, or history plus forecasts when *include_history* is set.
    """
    resdf = df.copy()
    for ts in sorted(ydf[timecol].values):
        yhat = resdf[valcol].ewm(com=com).mean().values[-1]
        resdf = pd.concat([resdf, pd.DataFrame([{timecol: ts, valcol: yhat}])])
    if include_history:
        return resdf
    return resdf[resdf[timecol].isin(ydf[timecol].values)].copy()