├── api ├── __init__.py └── app.py ├── model ├── __init__.py ├── dev │ ├── __init__.py │ └── properties.json ├── dataprep.py ├── reference_codesnippets.py ├── alternate_functions.py └── llm.py ├── .gitignore ├── requirements.txt ├── Dockerfile ├── setup.py ├── streamlit └── streamlit_app.py ├── README.md └── notebooks ├── rag_llm_v2.ipynb └── rag_llm_learning.ipynb /api/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/dev/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | build/ 3 | sheet_simplify.egg-info -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sivadhulipala1999/SheetSimplify_with_RAG/HEAD/requirements.txt -------------------------------------------------------------------------------- /model/dev/properties.json: -------------------------------------------------------------------------------- 1 | { 2 | "huggingface_model_name": "mistralai/Mistral-7B-Instruct-v0.2", 3 | "embedding_model_name": "thenlper/gte-small", 4 | "api_key": "KEY GOES HERE" 5 | } -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # syntax=docker/dockerfile:1 2 | 3 | FROM python:3.9 4 | WORKDIR /app 5 | COPY requirements.txt . 
6 | RUN pip install -r requirements.txt 7 | COPY . . 8 | CMD ["python", "api/app.py"] 9 | EXPOSE 8000 10 | # EXPOSE exposes the port only in case of docker -p and not with -P -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='sheet-simplify', 5 | version='0.1', 6 | packages=find_packages(exclude=['tests']), 7 | install_requires=[ 8 | 'pylint', 9 | 'torch', 10 | 'torchvision' 11 | ] 12 | ) -------------------------------------------------------------------------------- /streamlit/streamlit_app.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from model import llm 3 | 4 | st.title("Sheet Simplify") 5 | 6 | 7 | user_api_key = st.sidebar.text_input('API Key', type='password') 8 | 9 | 10 | def invoke_llm(user_api_key, text): 11 | llm_chain = llm.setup("dev", user_api_key) 12 | st.info(llm_chain.invoke(text).split( 13 | "assistant")[1].split("<|im_end|>")[0]) 14 | 15 | 16 | with st.form("my_form"): 17 | text = st.text_area("Enter text:", "Give a summary of the data provided") 18 | submitted = st.form_submit_button("Submit") 19 | if (user_api_key is None) or (user_api_key == ""): 20 | st.warning("Please enter your API key. 
It is mandatory", icon="⚠") 21 | else: 22 | invoke_llm(user_api_key, text) 23 | -------------------------------------------------------------------------------- /api/app.py: -------------------------------------------------------------------------------- 1 | """Flask web server exposing endpoints to LLM chats.""" 2 | import os 3 | 4 | from flask import Flask, request, jsonify 5 | from model import llm 6 | 7 | os.environ["CUDA_VISIBLE_DEVICES"] = "" # Do not use GPU 8 | 9 | app = Flask(__name__) 10 | 11 | 12 | @app.route("/") 13 | def index(): 14 | """Provide simple health check route.""" 15 | return "Welcome to Sheet Simplify!" 16 | 17 | 18 | @app.route("/v1/summary", methods=["GET", "POST"]) 19 | def summary(): 20 | """Provide a summary of the data provided. Responds to both GET and POST requests.""" 21 | summary_question = "Give me a summary of the data provided" 22 | llm_chain = llm.setup("dev") 23 | return llm_chain.invoke(summary_question).split("assistant")[1] 24 | 25 | 26 | def main(): 27 | """Run the app.""" 28 | app.run(host="0.0.0.0", port=8000, debug=False) 29 | 30 | 31 | if __name__ == "__main__": 32 | main() 33 | -------------------------------------------------------------------------------- /model/dataprep.py: -------------------------------------------------------------------------------- 1 | from langchain.document_loaders.csv_loader import CSVLoader 2 | from langchain.vectorstores import FAISS 3 | from langchain.embeddings import HuggingFaceInstructEmbeddings 4 | from langchain_community.vectorstores.utils import DistanceStrategy 5 | 6 | 7 | EMBEDDING_MODEL_NAME = "thenlper/gte-small" 8 | 9 | 10 | def prep_vectorstore_csv(filepath): 11 | """Take the CSV data and load it into the FAISS vector store using HuggingFace embeddings""" 12 | loader = CSVLoader(file_path=filepath, encoding="utf-8", csv_args={ 13 | 'delimiter': ','}) 14 | data = loader.load() 15 | embeddings = HuggingFaceInstructEmbeddings( 16 | model_name=EMBEDDING_MODEL_NAME, 17 | 
multi_process=True, 18 | model_kwargs={"device": "cuda"}, 19 | encode_kwargs={"normalize_embeddings": True},) 20 | vectorstore = FAISS.from_documents( 21 | data, embeddings, distance_strategy=DistanceStrategy.COSINE) 22 | return vectorstore 23 | 24 | 25 | ######## Personal Notes ######## 26 | # Using cosine similarity here since Euclidean will not be accurate in multi-dimensional spaces and with sparse vectors 27 | -------------------------------------------------------------------------------- /model/reference_codesnippets.py: -------------------------------------------------------------------------------- 1 | """File with a bunch a code snippets I found online. Not a good way to organize such content. To be refactored""" 2 | 3 | 4 | # from langchain.agents.agent_toolkits import SQLDatabaseToolkit 5 | # from langchain.agents.agent_types import AgentType 6 | # from langchain.agents import create_sql_agent 7 | # from langchain.agents.agent_toolkits import SQLDatabaseToolkit 8 | # from langchain.agents import AgentExecutor 9 | 10 | # pg_uri = f"postgresql+psycopg2://{username}:{password}@{host}:{port}/{mydatabase}" 11 | # db = SQLDatabase.from_uri(pg_uri) 12 | 13 | # repo_id = "mistralai/Mistral-7B-Instruct-v0.2" 14 | 15 | # llm = HuggingFaceEndpoint( 16 | # repo_id=repo_id, max_length=128, temperature=0.5, token=HUGGINGFACEHUB_API_TOKEN 17 | # ) 18 | 19 | 20 | # agent_executor = create_sql_agent( 21 | # llm=llm, 22 | # db=db, 23 | # verbose=True, 24 | # agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION 25 | # ) 26 | 27 | # agent_executor.run( 28 | # "what is the id of host spencer ?" 
29 | # ) 30 | 31 | 32 | # import os 33 | # from langchain import PromptTemplate, HuggingFaceHub, LLMChain, OpenAI, SQLDatabase, HuggingFacePipeline 34 | # from langchain.agents import create_csv_agent 35 | # from langchain.chains.sql_database.base import SQLDatabaseChain 36 | # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig 37 | # import transformers 38 | 39 | # model_id = 'google/flan-t5-xxl' 40 | # config = AutoConfig.from_pretrained(model_id) 41 | # tokenizer = AutoTokenizer.from_pretrained(model_id) 42 | # model = AutoModelForSeq2SeqLM.from_pretrained(model_id, config=config) 43 | # pipe = pipeline('text2text-generation', 44 | # model=model, 45 | # tokenizer=tokenizer, 46 | # max_length=1024 47 | # ) 48 | # local_llm = HuggingFacePipeline(pipeline=pipe) 49 | 50 | # agent = create_csv_agent(llm=local_llm, path="dummy_data.csv", verbose=True) 51 | # agent.run('how many unique status are there?') 52 | -------------------------------------------------------------------------------- /model/alternate_functions.py: -------------------------------------------------------------------------------- 1 | from transformers import pipeline 2 | from transformers import AutoModelForCausalLM, BitsAndBytesConfig 3 | from langchain_experimental.agents import create_pandas_dataframe_agent 4 | from langchain.prompts import ChatPromptTemplate 5 | 6 | import torch 7 | import pandas as pd 8 | 9 | 10 | """The file contains all the functions which would be good for reference""" 11 | 12 | 13 | def prep_local_llm_chain(model_name, prompt, tokenizer, data): 14 | """Add the prompt to the Local LLM via a chain to make it work for our case""" 15 | 16 | # Configure quantization in the LLM to make it more efficient 17 | bnb_config = BitsAndBytesConfig( 18 | load_in_4bit=True, 19 | bnb_4bit_use_double_quant=True, 20 | bnb_4bit_quant_type="nf4", 21 | bnb_4bit_compute_dtype=torch.bfloat16, 22 | ) 23 | 24 | # 25 | model = AutoModelForCausalLM.from_pretrained( 26 | 
model_name, quantization_config=bnb_config) 27 | 28 | reader_llm = pipeline( 29 | model=model, 30 | tokenizer=tokenizer, 31 | task="text-generation", 32 | do_sample=True, 33 | temperature=0.2, 34 | repetition_penalty=1.1, 35 | return_full_text=False, 36 | max_new_tokens=500, 37 | ) 38 | 39 | agent_executor = create_pandas_dataframe_agent( 40 | reader_llm, data, agent_type="tool-calling", verbose=True 41 | ) 42 | 43 | return agent_executor 44 | 45 | 46 | def prep_prompt(): 47 | """Build a standard chat prompt with RAG as the idea -- did not work for my case""" 48 | messages = [ 49 | ("system", "Answer the query using the context provided. Be succinct. Your name is The Data Master"), 50 | ("human", "Hello! I come to you for understanding the data."), 51 | ("ai", "Hey! All good. Just ask me the exact questions and I will be glad to help."), 52 | ("human", "{user_input}") 53 | ] 54 | chat_template = ChatPromptTemplate.from_messages(messages) 55 | messages = chat_template.format_messages( 56 | user_input="What are the columns in the data?") 57 | return messages 58 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SheetSimplify with RAG LLMs 2 | The aim of this project is to simplify data retrieval from Excel Sheets using RAG LLMs, hence the name! Many organizations currently store their data in Excel sheets and have stored decades' worth of data in them. However, retrieving data from these sheets becomes quite difficult unless the user has some technical background. The idea of Natural Language Querying (NLQ) is to exactly solve this issue by allowing users to ask simple questions to a model and get appropriate and rational responses. This NLQ can be achieved using RAG LLMs, which is what we aim to build in this project. 
3 | 4 | # The approach 5 | Instead of fine-tuning the model on the relevant data, which consumes significant resources, we retrieve the parts of the dataset that are relevant to each question and pass them to the LLM as context in the prompt, so that it answers based on that context. This is the basic idea behind Retrieval-Augmented Generation (RAG). 6 | 7 | Once the model is ready, we expose a simple API endpoint that responds with a summary of key information from the data. Users may also want to query the model further, for which we provide a Streamlit app. 8 | 9 | Since we would like the whole application to be easy to distribute, we 'dockerize' it. 10 | 11 | # The repo 12 | The repo structure is based on a standard template for production ML projects [3]. 13 | 14 | Notebooks: Typically used for data analysis and exploration. Since we are dealing with LLMs, I decided to add my understanding of different concepts to the notebooks here.
15 | Model: Contains the python file that preps and invokes the LLM.
16 | API: Contains the Flask app that exposes the HTTP endpoints, including one that makes a simple call to the LLM. The following endpoints are exposed so far:
17 |   - "/" - home page which just says "Welcome to Sheet Simplify!"
18 |   - "/v1/summary" - which provides a summary of the data provided as per the LLM
19 | streamlit: Contains the code to set up the Streamlit app. Not in the original template. 20 | 21 | There are several other folders from the original template which were not relevant to this case and hence have been omitted. 22 | 23 | # Usage Tips 24 | 1. Before running the scripts in this repo, it is very important to perform a pip install on the entire project so that the internal packages become available to each other. To do this, run the following: 25 | pip install . 26 | 2. Regardless of the script you want to run, it is very important to execute it from the root directory. For example, you would run the llm.py file as follows from the root directory (note that `python -m` takes a dotted module name, not a file path): 27 | python -m model.llm
28 | This prepares the model and starts the chat on the command line. 29 | 3. For the Streamlit app, run the following command: 30 | streamlit run streamlit/streamlit_app.py
31 | This starts the Streamlit web app frontend on localhost. 32 | 4. You can also run the Flask app to hit the API endpoints using the command: 33 | python -m api.app
34 | This initiates the flask server in localhost. Remember, you would have to add the endpoint /v1/summary to the URL in the browser to hit the endpoint. 35 | Of course, you can also hit these endpoints from other API tools like Postman. 36 | 37 | ## References 38 | 1. LangChain's Blogpost on Retrieval from Excel Sheet 39 | 2. This YouTube Video explaining how to use it 40 | 3. Production ML Project Template 41 | 4. Open Source LLMs as LangChain Agents 42 | 5. Advanced RAG tutorial from HuggingFace 43 | -------------------------------------------------------------------------------- /notebooks/rag_llm_v2.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SheetSimplify - Alternate \n", 8 | "\n", 9 | "I created a simple RAG LLM in the file \"rag_llm.ipynb\" by using LangChain to make a call to OpenAI's GPT 3.5 Turbo model with the context being fed from the excel file for the RAG part. Now, a major problem with this approach is using the API token from OpenAI which has a price tag. 
In this notebook, I would like to take a different approach and do the following \n", 10 | "\n", 11 | "- Use open source LLMs to avoid the cost concerns\n", 12 | "- Bring the model to production " 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 1, 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "from langchain_community.llms import HuggingFaceEndpoint" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 3, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "import os\n", 31 | "from getpass import getpass\n", 32 | "os.environ[\"HUGGINGFACEHUB_API_TOKEN\"] = os.environ.get(\"HUGGINGFACEHUB_API_TOKEN\") or getpass(\"Hugging Face token: \")" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 6, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from langchain.chains import LLMChain\n", 42 | "from langchain_core.prompts import PromptTemplate" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 11, 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "question = \"Who won the FIFA World Cup in the year 1994? \"\n", 52 | "\n", 53 | "template = \"\"\"Question: {question}\n", 54 | "\n", 55 | "Answer: Let's think step by step.\"\"\"\n", 56 | "\n", 57 | "prompt = PromptTemplate.from_template(template)" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 12, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "name": "stderr", 67 | "output_type": "stream", 68 | "text": [ 69 | "WARNING! max_length is not default parameter.\n", 70 | " max_length was transferred to model_kwargs.\n", 71 | " Please make sure that max_length is what you intended.\n", 72 | "WARNING! 
token is not default parameter.\n", 73 | " token was transferred to model_kwargs.\n", 74 | " Please make sure that token is what you intended.\n", 75 | "d:\\Projects\\SheetSimplify_with_RAG\\.venv\\lib\\site-packages\\langchain_core\\_api\\deprecation.py:117: LangChainDeprecationWarning: The function `run` was deprecated in LangChain 0.1.0 and will be removed in 0.2.0. Use invoke instead.\n", 76 | " warn_deprecated(\n" 77 | ] 78 | }, 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", 84 | "Token is valid (permission: read).\n", 85 | "Your token has been saved to C:\\Users\\parth\\.cache\\huggingface\\token\n", 86 | "Login successful\n", 87 | " The FIFA World Cup is an international football tournament that takes place every four years. The 1994 FIFA World Cup was held in the United States from June 17 to July 17, 1994. The final match was played on July 17, 1994. The teams that reached the final were Brazil and Italy. Brazil won the match with a score of 0-0 (3-2 in the penalty shootout). 
Therefore, Brazil won the FIFA World Cup in the year 1994.\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "repo_id = \"mistralai/Mistral-7B-Instruct-v0.2\"\n", 93 | "\n", 94 | "llm = HuggingFaceEndpoint(\n", 95 | " repo_id=repo_id, max_length=128, temperature=0.5, token=os.environ[\"HUGGINGFACEHUB_API_TOKEN\"]\n", 96 | ")\n", 97 | "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", 98 | "print(llm_chain.run(question))" 99 | ] 100 | } 101 | ], 102 | "metadata": { 103 | "kernelspec": { 104 | "display_name": ".venv", 105 | "language": "python", 106 | "name": "python3" 107 | }, 108 | "language_info": { 109 | "codemirror_mode": { 110 | "name": "ipython", 111 | "version": 3 112 | }, 113 | "file_extension": ".py", 114 | "mimetype": "text/x-python", 115 | "name": "python", 116 | "nbconvert_exporter": "python", 117 | "pygments_lexer": "ipython3", 118 | "version": "3.9.8" 119 | } 120 | }, 121 | "nbformat": 4, 122 | "nbformat_minor": 2 123 | } 124 | -------------------------------------------------------------------------------- /model/llm.py: -------------------------------------------------------------------------------- 1 | """Prepares the RAG LLM for a chat with the user""" 2 | 3 | from langchain_community.llms import HuggingFaceHub 4 | from langchain.document_loaders.csv_loader import CSVLoader 5 | from langchain_community.vectorstores import FAISS 6 | from langchain_community.embeddings import HuggingFaceInstructEmbeddings 7 | from langchain_community.vectorstores.utils import DistanceStrategy 8 | from langchain_core.prompts import PromptTemplate 9 | from langchain_core.runnables import RunnablePassthrough 10 | 11 | import json 12 | import os 13 | 14 | 15 | def prep_llm_chain(repo_id, prompt, retriever): 16 | """Prepare the LLM Agent which answers based on the data given and follows the provided prompt""" 17 | llm = HuggingFaceHub( 18 | repo_id=repo_id, huggingfacehub_api_token=os.environ["HUGGINGFACEHUB_API_TOKEN"], 19 | model_kwargs={"temperature": 0.5, "max_length": 
1024}) 20 | 21 | agent_executor = ( 22 | {"context": retriever, "question": RunnablePassthrough()} 23 | | prompt 24 | | llm 25 | ) 26 | 27 | return agent_executor 28 | 29 | 30 | def chat(llm): 31 | """Initiates a chat with the LLM via the command. Run this script directly via command line for this.""" 32 | print("> Chat with the Sheet_Simplify. Please enter your queries here. Press 'quit' to stop the chat") 33 | while True: 34 | user_input = input("> ") 35 | if user_input == "quit": 36 | break 37 | else: 38 | print("Sheet Simplify: ", llm.invoke( 39 | user_input).split("assistant")[1]) 40 | 41 | 42 | def prep_rag_prompt(): 43 | """Prepare the prompt to make the LLM a RAG based model that answers queries like a chatbot""" 44 | 45 | # Note that Mistral does not take system prompts directly and hence a bit of formatting is needed 46 | # Different models have different prompt structures which is absolutely painful to navigate through 47 | # source of this fix: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/41 48 | sys_prompt = """Using the information contained in the context, 49 | give a comprehensive answer to the question. 50 | Respond only to the question asked, response should be concise and relevant to the question. 51 | Provide the number of the source document when relevant. 52 | If the answer cannot be deduced from the context, do not give an answer.""" 53 | 54 | prompt = """Context: 55 | {context} 56 | --- 57 | Now here is the question you need to answer. 
58 | 59 | Question: {question}""" 60 | 61 | prefix = "<|im_start|>" 62 | suffix = "<|im_end|>\n" 63 | sys_format = prefix + "system\n" + sys_prompt + suffix 64 | user_format = prefix + "user\n" + prompt + suffix 65 | assistant_format = prefix + "assistant\n" 66 | input_text = sys_format + user_format + assistant_format 67 | 68 | prompt_in_chat_format = [ 69 | { 70 | "role": "user", 71 | "content": input_text, 72 | }, 73 | ] 74 | 75 | return PromptTemplate.from_template(input_text) 76 | 77 | 78 | def prep_vectorstore_csv(filepath, embedding_model_name): 79 | """Takes the CSV data and loads it into the FAISS vector store using HuggingFace embeddings""" 80 | loader = CSVLoader(file_path=filepath, encoding="utf-8", csv_args={ 81 | 'delimiter': ','}) 82 | data = loader.load() 83 | embeddings = HuggingFaceInstructEmbeddings( 84 | model_name=embedding_model_name, 85 | model_kwargs={"device": "cuda"}, 86 | encode_kwargs={"normalize_embeddings": True},) 87 | vectorstore = FAISS.from_documents( 88 | data, embeddings, distance_strategy=DistanceStrategy.COSINE) 89 | return vectorstore.as_retriever() 90 | 91 | 92 | def setup(environment_name, user_api_key=None): 93 | """Setup the LLM Chain for further usage from API or Streamlit app""" 94 | with open(f"model/{environment_name}/properties.json") as f: 95 | json_contents = f.read() 96 | env_data = json.loads(json_contents) 97 | 98 | if user_api_key is None: 99 | os.environ["HUGGINGFACEHUB_API_TOKEN"] = env_data["api_key"] 100 | else: 101 | os.environ["HUGGINGFACEHUB_API_TOKEN"] = user_api_key 102 | model_name = env_data["huggingface_model_name"] 103 | embedding_model_name = env_data["embedding_model_name"] 104 | 105 | # Prep the CSV data 106 | retriever = prep_vectorstore_csv("data/Train.csv", embedding_model_name) 107 | 108 | # Prepare the prompt for the LLM and chain it to the model 109 | # tokenizer, prompt_template = prep_rag_prompt(model_name) 110 | prompt_template = prep_rag_prompt() 111 | 112 | # prep the LLM Agent 113 | 
llm_chain = prep_llm_chain(model_name, prompt_template, retriever) 114 | 115 | return llm_chain 116 | 117 | 118 | def main(environment_name, user_api_key=None, API_CALL=False, APP_INVOKE=False): 119 | llm_chain = setup(environment_name, user_api_key) 120 | chat(llm_chain) 121 | 122 | 123 | if __name__ == "__main__": 124 | main("dev") 125 | -------------------------------------------------------------------------------- /notebooks/rag_llm_learning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# SheetSimplify \n", 8 | "\n", 9 | "In this notebook, we shall implement a RAG LLM to retrieve data from an excel sheet. The excel sheet has been taken from here and focuses on sales data for some company. This sheet is meant to be an example for our prototype and can of course be replaced with excel sheets which are internal to the organization. " 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Data Preprocessing" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 44, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import pandas as pd \n", 26 | "\n", 27 | "df = pd.read_csv(\"data/Train.csv\")" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 45, 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "data": { 37 | "text/html": [ 38 | "
\n", 39 | "\n", 52 | "\n", 53 | " \n", 54 | " \n", 55 | " \n", 56 | " \n", 57 | " \n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | " \n", 87 | " \n", 88 | " \n", 89 | " \n", 90 | " \n", 91 | " \n", 92 | " \n", 93 | " \n", 94 | " \n", 95 | " \n", 96 | " \n", 97 | " \n", 98 | " \n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | "
Item_IdentifierItem_WeightItem_Fat_ContentItem_VisibilityItem_TypeItem_MRPOutlet_IdentifierOutlet_Establishment_YearOutlet_SizeOutlet_Location_TypeOutlet_TypeItem_Outlet_Sales
0FDA159.30Low Fat0.016047Dairy249.8092OUT0491999MediumTier 1Supermarket Type13735.1380
1DRC015.92Regular0.019278Soft Drinks48.2692OUT0182009MediumTier 3Supermarket Type2443.4228
2FDN1517.50Low Fat0.016760Meat141.6180OUT0491999MediumTier 1Supermarket Type12097.2700
3FDX0719.20Regular0.000000Fruits and Vegetables182.0950OUT0101998NaNTier 3Grocery Store732.3800
4NCD198.93Low Fat0.000000Household53.8614OUT0131987HighTier 3Supermarket Type1994.7052
\n", 148 | "
" 149 | ], 150 | "text/plain": [ 151 | " Item_Identifier Item_Weight Item_Fat_Content Item_Visibility \\\n", 152 | "0 FDA15 9.30 Low Fat 0.016047 \n", 153 | "1 DRC01 5.92 Regular 0.019278 \n", 154 | "2 FDN15 17.50 Low Fat 0.016760 \n", 155 | "3 FDX07 19.20 Regular 0.000000 \n", 156 | "4 NCD19 8.93 Low Fat 0.000000 \n", 157 | "\n", 158 | " Item_Type Item_MRP Outlet_Identifier \\\n", 159 | "0 Dairy 249.8092 OUT049 \n", 160 | "1 Soft Drinks 48.2692 OUT018 \n", 161 | "2 Meat 141.6180 OUT049 \n", 162 | "3 Fruits and Vegetables 182.0950 OUT010 \n", 163 | "4 Household 53.8614 OUT013 \n", 164 | "\n", 165 | " Outlet_Establishment_Year Outlet_Size Outlet_Location_Type \\\n", 166 | "0 1999 Medium Tier 1 \n", 167 | "1 2009 Medium Tier 3 \n", 168 | "2 1999 Medium Tier 1 \n", 169 | "3 1998 NaN Tier 3 \n", 170 | "4 1987 High Tier 3 \n", 171 | "\n", 172 | " Outlet_Type Item_Outlet_Sales \n", 173 | "0 Supermarket Type1 3735.1380 \n", 174 | "1 Supermarket Type2 443.4228 \n", 175 | "2 Supermarket Type1 2097.2700 \n", 176 | "3 Grocery Store 732.3800 \n", 177 | "4 Supermarket Type1 994.7052 " 178 | ] 179 | }, 180 | "execution_count": 45, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | } 184 | ], 185 | "source": [ 186 | "df.head()" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 46, 192 | "metadata": {}, 193 | "outputs": [ 194 | { 195 | "name": "stdout", 196 | "output_type": "stream", 197 | "text": [ 198 | "\n", 199 | "RangeIndex: 8523 entries, 0 to 8522\n", 200 | "Data columns (total 12 columns):\n", 201 | " # Column Non-Null Count Dtype \n", 202 | "--- ------ -------------- ----- \n", 203 | " 0 Item_Identifier 8523 non-null object \n", 204 | " 1 Item_Weight 7060 non-null float64\n", 205 | " 2 Item_Fat_Content 8523 non-null object \n", 206 | " 3 Item_Visibility 8523 non-null float64\n", 207 | " 4 Item_Type 8523 non-null object \n", 208 | " 5 Item_MRP 8523 non-null float64\n", 209 | " 6 Outlet_Identifier 8523 non-null object \n", 210 | " 7 
Outlet_Establishment_Year 8523 non-null int64 \n", 211 | " 8 Outlet_Size 6113 non-null object \n", 212 | " 9 Outlet_Location_Type 8523 non-null object \n", 213 | " 10 Outlet_Type 8523 non-null object \n", 214 | " 11 Item_Outlet_Sales 8523 non-null float64\n", 215 | "dtypes: float64(4), int64(1), object(7)\n", 216 | "memory usage: 799.2+ KB\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "df.info()" 222 | ] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 47, 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "data": { 231 | "text/plain": [ 232 | "Item_Identifier 0\n", 233 | "Item_Weight 1463\n", 234 | "Item_Fat_Content 0\n", 235 | "Item_Visibility 0\n", 236 | "Item_Type 0\n", 237 | "Item_MRP 0\n", 238 | "Outlet_Identifier 0\n", 239 | "Outlet_Establishment_Year 0\n", 240 | "Outlet_Size 2410\n", 241 | "Outlet_Location_Type 0\n", 242 | "Outlet_Type 0\n", 243 | "Item_Outlet_Sales 0\n", 244 | "dtype: int64" 245 | ] 246 | }, 247 | "execution_count": 47, 248 | "metadata": {}, 249 | "output_type": "execute_result" 250 | } 251 | ], 252 | "source": [ 253 | "df.isna().sum()" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "A quick glance at the data shows us that there is no discrepancy in the data types and that there are couple of null entries in the columns \"Item_Weight\" and \"Outlet_Size\". We have to remember however, that this is a simple dataset downloaded from Kaggle and hence the data is extremely friendly. In real life, this is rarely the case and hence data pre-processing would take much more steps and much longer. " 261 | ] 262 | }, 263 | { 264 | "cell_type": "markdown", 265 | "metadata": {}, 266 | "source": [ 267 | "## Building the LLM Agent " 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "metadata": {}, 273 | "source": [ 274 | "Using LangChain, we shall now build an LLM that uses the DataFrame for its context. 
Under the hood, this would mean that the context is set to the particular sheet we provide and then a system prompt is generated to make sure that the LLM utilizes this context appropriately. \n", 275 | "\n", 276 | "The agent would also need to make API calls to OpenAI. Hence, we provide an OpenAI key to fulfill this requirement. " 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 48, 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "import os \n", 286 | "\n", 287 | "# os.environ[\"OPENAI_API_KEY\"] = \"\" # Your key goes here" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 49, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "from langchain_experimental.agents import create_csv_agent\n", 297 | "from langchain_openai import ChatOpenAI" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 50, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0) \n", 307 | "agent = create_csv_agent(llm, \n", 308 | " 'data/Train.csv',\n", 309 | " agent_type=\"openai-tools\", \n", 310 | " verbose=True)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": 51, 316 | "metadata": {}, 317 | "outputs": [ 318 | { 319 | "data": { 320 | "text/plain": [ 321 | "AgentExecutor(verbose=True, agent=RunnableMultiActionAgent(runnable=RunnableAssign(mapper={\n", 322 | " agent_scratchpad: RunnableLambda(lambda x: format_to_openai_tool_messages(x['intermediate_steps']))\n", 323 | "})\n", 324 | "| ChatPromptTemplate(input_variables=['agent_scratchpad', 'input'], input_types={'agent_scratchpad': typing.List[typing.Union[langchain_core.messages.ai.AIMessage, langchain_core.messages.human.HumanMessage, langchain_core.messages.chat.ChatMessage, langchain_core.messages.system.SystemMessage, langchain_core.messages.function.FunctionMessage, langchain_core.messages.tool.ToolMessage]]}, 
messages=[SystemMessage(content='\\nYou are working with a pandas dataframe in Python. The name of the dataframe is `df`.\\nThis is the result of `print(df.head())`:\\n| | Item_Identifier | Item_Weight | Item_Fat_Content | Item_Visibility | Item_Type | Item_MRP | Outlet_Identifier | Outlet_Establishment_Year | Outlet_Size | Outlet_Location_Type | Outlet_Type | Item_Outlet_Sales |\\n|---:|:------------------|--------------:|:-------------------|------------------:|:----------------------|-----------:|:--------------------|----------------------------:|:--------------|:-----------------------|:------------------|--------------------:|\\n| 0 | FDA15 | 9.3 | Low Fat | 0.0160473 | Dairy | 249.809 | OUT049 | 1999 | Medium | Tier 1 | Supermarket Type1 | 3735.14 |\\n| 1 | DRC01 | 5.92 | Regular | 0.0192782 | Soft Drinks | 48.2692 | OUT018 | 2009 | Medium | Tier 3 | Supermarket Type2 | 443.423 |\\n| 2 | FDN15 | 17.5 | Low Fat | 0.0167601 | Meat | 141.618 | OUT049 | 1999 | Medium | Tier 1 | Supermarket Type1 | 2097.27 |\\n| 3 | FDX07 | 19.2 | Regular | 0 | Fruits and Vegetables | 182.095 | OUT010 | 1998 | nan | Tier 3 | Grocery Store | 732.38 |\\n| 4 | NCD19 | 8.93 | Low Fat | 0 | Household | 53.8614 | OUT013 | 1987 | High | Tier 3 | Supermarket Type1 | 994.705 |'), HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=['input'], template='{input}')), MessagesPlaceholder(variable_name='agent_scratchpad')])\n", 325 | "| RunnableBinding(bound=ChatOpenAI(client=, async_client=, temperature=0.0, openai_api_key=SecretStr('**********'), openai_proxy=''), kwargs={'tools': [{'type': 'function', 'function': {'name': 'python_repl_ast', 'description': 'A Python shell. Use this to execute python commands. Input should be a valid python command. 
When using this tool, sometimes output is abbreviated - make sure it does not look abbreviated before using it in your answer.', 'parameters': {'type': 'object', 'properties': {'query': {'description': 'code snippet to run', 'type': 'string'}}, 'required': ['query']}}}]})\n", 326 | "| OpenAIToolsAgentOutputParser(), input_keys_arg=['input'], return_keys_arg=['output'], stream_runnable=True), tools=[PythonAstREPLTool(locals={'df': Item_Identifier Item_Weight Item_Fat_Content Item_Visibility \\\n", 327 | "0 FDA15 9.300 Low Fat 0.016047 \n", 328 | "1 DRC01 5.920 Regular 0.019278 \n", 329 | "2 FDN15 17.500 Low Fat 0.016760 \n", 330 | "3 FDX07 19.200 Regular 0.000000 \n", 331 | "4 NCD19 8.930 Low Fat 0.000000 \n", 332 | "... ... ... ... ... \n", 333 | "8518 FDF22 6.865 Low Fat 0.056783 \n", 334 | "8519 FDS36 8.380 Regular 0.046982 \n", 335 | "8520 NCJ29 10.600 Low Fat 0.035186 \n", 336 | "8521 FDN46 7.210 Regular 0.145221 \n", 337 | "8522 DRG01 14.800 Low Fat 0.044878 \n", 338 | "\n", 339 | " Item_Type Item_MRP Outlet_Identifier \\\n", 340 | "0 Dairy 249.8092 OUT049 \n", 341 | "1 Soft Drinks 48.2692 OUT018 \n", 342 | "2 Meat 141.6180 OUT049 \n", 343 | "3 Fruits and Vegetables 182.0950 OUT010 \n", 344 | "4 Household 53.8614 OUT013 \n", 345 | "... ... ... ... \n", 346 | "8518 Snack Foods 214.5218 OUT013 \n", 347 | "8519 Baking Goods 108.1570 OUT045 \n", 348 | "8520 Health and Hygiene 85.1224 OUT035 \n", 349 | "8521 Snack Foods 103.1332 OUT018 \n", 350 | "8522 Soft Drinks 75.4670 OUT046 \n", 351 | "\n", 352 | " Outlet_Establishment_Year Outlet_Size Outlet_Location_Type \\\n", 353 | "0 1999 Medium Tier 1 \n", 354 | "1 2009 Medium Tier 3 \n", 355 | "2 1999 Medium Tier 1 \n", 356 | "3 1998 NaN Tier 3 \n", 357 | "4 1987 High Tier 3 \n", 358 | "... ... ... ... 
 \n", 359 | "8518 1987 High Tier 3 \n", 360 | "8519 2002 NaN Tier 2 \n", 361 | "8520 2004 Small Tier 2 \n", 362 | "8521 2009 Medium Tier 3 \n", 363 | "8522 1997 Small Tier 1 \n", 364 | "\n", 365 | " Outlet_Type Item_Outlet_Sales \n", 366 | "0 Supermarket Type1 3735.1380 \n", 367 | "1 Supermarket Type2 443.4228 \n", 368 | "2 Supermarket Type1 2097.2700 \n", 369 | "3 Grocery Store 732.3800 \n", 370 | "4 Supermarket Type1 994.7052 \n", 371 | "... ... ... \n", 372 | "8518 Supermarket Type1 2778.3834 \n", 373 | "8519 Supermarket Type1 549.2850 \n", 374 | "8520 Supermarket Type1 1193.1136 \n", 375 | "8521 Supermarket Type2 1845.5976 \n", 376 | "8522 Supermarket Type1 765.6700 \n", 377 | "\n", 378 | "[8523 rows x 12 columns]})])" 379 | ] 380 | }, 381 | "execution_count": 51, 382 | "metadata": {}, 383 | "output_type": "execute_result" 384 | } 385 | ], 386 | "source": [ 387 | "agent" 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "metadata": {}, 393 | "source": [ 394 | "The agent we generated has the characteristics shown above. Note that we have not written a prompt for the agent ourselves: `create_csv_agent` supplies LangChain's default prompt, which appears as the system message in the output above. \n", 395 | "\n", 396 | "We shall now test the agent we have built. " 397 | ] 398 | }, 399 | { 400 | "cell_type": "markdown", 401 | "metadata": {}, 402 | "source": [ 403 | "## Run the model! 
" 404 | ] 405 | }, 406 | { 407 | "cell_type": "code", 408 | "execution_count": 52, 409 | "metadata": {}, 410 | "outputs": [ 411 | { 412 | "name": "stdout", 413 | "output_type": "stream", 414 | "text": [ 415 | "\n", 416 | "\n", 417 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 418 | "\u001b[32;1m\u001b[1;3m\n", 419 | "Invoking: `python_repl_ast` with `{'query': 'df.shape[0]'}`\n", 420 | "\n", 421 | "\n", 422 | "\u001b[0m\u001b[36;1m\u001b[1;3m8523\u001b[0m\u001b[32;1m\u001b[1;3mThere are 8523 rows in the dataframe.\u001b[0m\n", 423 | "\n", 424 | "\u001b[1m> Finished chain.\u001b[0m\n" 425 | ] 426 | }, 427 | { 428 | "data": { 429 | "text/plain": [ 430 | "'There are 8523 rows in the dataframe.'" 431 | ] 432 | }, 433 | "execution_count": 52, 434 | "metadata": {}, 435 | "output_type": "execute_result" 436 | } 437 | ], 438 | "source": [ 439 | "agent.run(\"how many rows are there?\")" 440 | ] 441 | }, 442 | { 443 | "cell_type": "code", 444 | "execution_count": 53, 445 | "metadata": {}, 446 | "outputs": [ 447 | { 448 | "name": "stdout", 449 | "output_type": "stream", 450 | "text": [ 451 | "\n", 452 | "\n", 453 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 454 | "\u001b[32;1m\u001b[1;3m\n", 455 | "Invoking: `python_repl_ast` with `{'query': \"df[df['Item_Fat_Content'] == 'Low Fat']['Item_Type'].value_counts()\"}`\n", 456 | "\n", 457 | "\n", 458 | "\u001b[0m\u001b[36;1m\u001b[1;3mItem_Type\n", 459 | "Household 840\n", 460 | "Snack Foods 645\n", 461 | "Fruits and Vegetables 580\n", 462 | "Health and Hygiene 481\n", 463 | "Frozen Foods 424\n", 464 | "Dairy 382\n", 465 | "Soft Drinks 339\n", 466 | "Canned 314\n", 467 | "Baking Goods 301\n", 468 | "Hard Drinks 199\n", 469 | "Meat 159\n", 470 | "Others 156\n", 471 | "Breads 126\n", 472 | "Starchy Foods 72\n", 473 | "Breakfast 39\n", 474 | "Seafood 32\n", 475 | "Name: count, dtype: int64\u001b[0m\u001b[32;1m\u001b[1;3mThe number of low fat items sold in different categories are as follows:\n", 476 
| "- Household: 840\n", 477 | "- Snack Foods: 645\n", 478 | "- Fruits and Vegetables: 580\n", 479 | "- Health and Hygiene: 481\n", 480 | "- Frozen Foods: 424\n", 481 | "- Dairy: 382\n", 482 | "- Soft Drinks: 339\n", 483 | "- Canned: 314\n", 484 | "- Baking Goods: 301\n", 485 | "- Hard Drinks: 199\n", 486 | "- Meat: 159\n", 487 | "- Others: 156\n", 488 | "- Breads: 126\n", 489 | "- Starchy Foods: 72\n", 490 | "- Breakfast: 39\n", 491 | "- Seafood: 32\u001b[0m\n", 492 | "\n", 493 | "\u001b[1m> Finished chain.\u001b[0m\n" 494 | ] 495 | }, 496 | { 497 | "data": { 498 | "text/plain": [ 499 | "'The number of low fat items sold in different categories are as follows:\\n- Household: 840\\n- Snack Foods: 645\\n- Fruits and Vegetables: 580\\n- Health and Hygiene: 481\\n- Frozen Foods: 424\\n- Dairy: 382\\n- Soft Drinks: 339\\n- Canned: 314\\n- Baking Goods: 301\\n- Hard Drinks: 199\\n- Meat: 159\\n- Others: 156\\n- Breads: 126\\n- Starchy Foods: 72\\n- Breakfast: 39\\n- Seafood: 32'" 500 | ] 501 | }, 502 | "execution_count": 53, 503 | "metadata": {}, 504 | "output_type": "execute_result" 505 | } 506 | ], 507 | "source": [ 508 | "agent.run(\"How many low fat items were sold in different categories?\")" 509 | ] 510 | }, 511 | { 512 | "cell_type": "code", 513 | "execution_count": 55, 514 | "metadata": {}, 515 | "outputs": [ 516 | { 517 | "name": "stdout", 518 | "output_type": "stream", 519 | "text": [ 520 | "\n", 521 | "\n", 522 | "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n", 523 | "\u001b[32;1m\u001b[1;3m\n", 524 | "Invoking: `python_repl_ast` with `{'query': \"df.groupby('Item_Type')['Item_Outlet_Sales'].sum().idxmax(), df.groupby('Item_Type')['Item_Outlet_Sales'].sum().max()\"}`\n", 525 | "\n", 526 | "\n", 527 | "\u001b[0m\u001b[36;1m\u001b[1;3m('Fruits and Vegetables', 2820059.8168)\u001b[0m\u001b[32;1m\u001b[1;3mThe highest earning item category was \"Fruits and Vegetables\" with total earnings of $2,820,059.82.\u001b[0m\n", 528 | "\n", 529 | "\u001b[1m> 
Finished chain.\u001b[0m\n" 530 | ] 531 | }, 532 | { 533 | "data": { 534 | "text/plain": [ 535 | "'The highest earning item category was \"Fruits and Vegetables\" with total earnings of $2,820,059.82.'" 536 | ] 537 | }, 538 | "execution_count": 55, 539 | "metadata": {}, 540 | "output_type": "execute_result" 541 | } 542 | ], 543 | "source": [ 544 | "agent.run(\"What was the highest earning item category? How much were the earnings?\")" 545 | ] 546 | } 547 | ], 548 | "metadata": { 549 | "kernelspec": { 550 | "display_name": ".venv", 551 | "language": "python", 552 | "name": "python3" 553 | }, 554 | "language_info": { 555 | "codemirror_mode": { 556 | "name": "ipython", 557 | "version": 3 558 | }, 559 | "file_extension": ".py", 560 | "mimetype": "text/x-python", 561 | "name": "python", 562 | "nbconvert_exporter": "python", 563 | "pygments_lexer": "ipython3", 564 | "version": "3.9.8" 565 | } 566 | }, 567 | "nbformat": 4, 568 | "nbformat_minor": 2 569 | } 570 | --------------------------------------------------------------------------------