├── data
│   └── test_workers
│       ├── harrison.txt
│       └── robertson.pdf
├── misc
│   └── example.jpg
├── requirements.txt
├── streaming.py
├── utils.py
├── chains
│   ├── conversational_chain.py
│   └── conversational_retrieval_chain.py
├── README.md
├── knowledge_set.py
├── .gitignore
└── chat.py

/data/test_workers/harrison.txt:
--------------------------------------------------------------------------------
harrison worked at woowoo
--------------------------------------------------------------------------------
/misc/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkmenta/rag-chatgpt/HEAD/misc/example.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
langchain==0.0.348
streamlit==1.26.0
python-dotenv==1.0.0
openai==1.3.8
--------------------------------------------------------------------------------
/data/test_workers/robertson.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkmenta/rag-chatgpt/HEAD/data/test_workers/robertson.pdf
--------------------------------------------------------------------------------
/streaming.py:
--------------------------------------------------------------------------------
from langchain.callbacks.base import BaseCallbackHandler


class StreamHandler(BaseCallbackHandler):

    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs):
        self.text += token
        self.container.markdown(self.text)
--------------------------------------------------------------------------------
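
`StreamHandler` is wired up in `chat.py` further below. As a standalone reference, a minimal sketch of attaching it to a streaming LLM call inside a Streamlit script might look like this (the placeholder container and the prompt are illustrative, and it assumes `OPENAI_API_KEY` is available, e.g. via the `.env` file described in the README):

```python
# Sketch only: attach StreamHandler to a streaming ChatOpenAI call.
import streamlit as st
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI

from streaming import StreamHandler

load_dotenv()
handler = StreamHandler(st.empty())  # the placeholder's markdown is rewritten on every token
llm = ChatOpenAI(streaming=True, callbacks=[handler])
llm.predict("Write one sentence about RAG.")  # tokens appear in the placeholder as they stream
```
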
/utils.py:
--------------------------------------------------------------------------------
from openai import OpenAI


def get_available_openai_models(put_first=None, filter_by=None):
    model_list = OpenAI().models.list()
    model_list = [m.id for m in model_list.data]
    if filter_by:
        model_list = [m for m in model_list if filter_by in m]
    if put_first:
        assert put_first in model_list
        model_list.remove(put_first)
        model_list = [put_first] + model_list
    return model_list
--------------------------------------------------------------------------------
/chains/conversational_chain.py:
--------------------------------------------------------------------------------
from langchain.chains import LLMChain
from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
                               MessagesPlaceholder,
                               SystemMessagePromptTemplate)


class ConversationalChain(LLMChain):
    """Basic conversational chain.

    The original chains from `langchain` do not use the `MessagesPlaceholder` template:
    they summarize the whole conversation into a single prompt that is sent to the ChatGPT API.
    This chain instead uses the natural message structure of the ChatGPT API.
    """

    def __init__(self, llm, memory, system_message: str, verbose: bool = False):
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(system_message),
                # The `variable_name` here must match the key used by the memory
                MessagesPlaceholder(variable_name=memory.memory_key),
                HumanMessagePromptTemplate.from_template("{question}"),
            ]
        )
        super().__init__(llm=llm, prompt=prompt, verbose=verbose, memory=memory)
--------------------------------------------------------------------------------
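
A minimal usage sketch of this chain outside of the Streamlit UI, mirroring how `chat.py` constructs it (the system message is illustrative; `OPENAI_API_KEY` is assumed to be set):

```python
# Sketch only: ConversationalChain with an in-memory conversation buffer.
from langchain.chat_models import ChatOpenAI
from langchain.memory import ConversationBufferMemory

from chains.conversational_chain import ConversationalChain

memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
chain = ConversationalChain(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0.7),
    memory=memory,
    system_message="You are a helpful assistant.",  # illustrative system message
)
print(chain.run({"question": "Hi, who are you?"}))
# The memory keeps the exchange, so follow-up questions see the history:
print(chain.run({"question": "What did I just ask you?"}))
```
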
/README.md:
--------------------------------------------------------------------------------
# Retrieval Augmented Generation Lab with ChatGPT
This is a simple lab I have implemented to test Knowledge-Augmented or Retrieval-Augmented Generation (RAG) with Large Language Models. In particular, I am using LangChain, Streamlit, and the OpenAI ChatGPT API. This project has been an excuse to try RAG and LangChain.

Run it with:
```
streamlit run chat.py
```

Disclaimer: This is a personal project without any guarantees, and I am not planning to maintain it.

## Requirements
I have pinned the requirements to the versions I used during development, but the code will probably work with earlier and newer versions.

Install the requirements with
```
python3 -m pip install -r requirements.txt
```

You will also need to create a `.env` file with the content:
```
OPENAI_API_KEY="YOUR OPENAI API KEY"
```

## Example
After running the Streamlit app with
```
streamlit run chat.py
```
open the web interface in your browser.

The repository includes a few sample documents in the `data/test_workers` folder. You can ask ChatGPT about them:
![Example of the web interface and ChatGPT using the content of the documents to answer](/misc/example.jpg)
--------------------------------------------------------------------------------
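
As a quick sanity check that the `.env` file is picked up (this is what `chat.py` relies on via `python-dotenv`), a snippet like the following can be run from the repository root; the script name is illustrative:

```python
# check_env.py -- sketch only: verifies the OpenAI key is loaded from .env
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
assert os.getenv("OPENAI_API_KEY"), "OPENAI_API_KEY is missing; create the .env file as described above"
print("OPENAI_API_KEY loaded.")
```
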
/knowledge_set.py:
--------------------------------------------------------------------------------
import os

from langchain.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS


def compute_knowledge_vectorstore(name: str, embeddings, base_path: str = "data"):
    full_path = os.path.join(base_path, name)
    if not os.path.exists(full_path):
        os.makedirs(full_path)

    # Load the vectorstore and the list of already indexed files, if they exist
    tracked_files = set()
    vectorstore = None
    if os.path.exists(os.path.join(full_path, "faiss_index")):
        vectorstore = FAISS.load_local(os.path.join(full_path, "faiss_index"), embeddings=embeddings)
        with open(os.path.join(full_path, "tracked_files"), 'r') as f:
            for line in f:
                tracked_files.add(line.strip())

    # Check which files are not indexed yet
    all_files = [fn for fn in os.listdir(full_path)
                 if fn not in ('faiss_index', 'tracked_files')]
    missing_files = [fn for fn in all_files if fn not in tracked_files]

    # Load and split the missing files
    missing_docs = []
    for file in missing_files:
        if file.lower().endswith('.pdf'):
            loader = PyPDFLoader
        elif file.lower().endswith('.txt'):
            loader = TextLoader
        else:
            print(f'Unsupported file type: {file}')
            continue
        loader = loader(os.path.join(full_path, file))
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
        missing_docs.extend(text_splitter.split_documents(documents))

    if missing_docs:
        new_vectorstore = FAISS.from_documents(missing_docs, embedding=embeddings)
        if vectorstore is None:
            vectorstore = new_vectorstore
        else:
            vectorstore.merge_from(new_vectorstore)
        vectorstore.save_local(os.path.join(full_path, "faiss_index"))

        with open(os.path.join(full_path, "tracked_files"), 'w') as f:
            for fn in all_files:
                f.write(fn + '\n')
    return vectorstore
--------------------------------------------------------------------------------
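
A minimal sketch of using this module directly against the sample `data/test_workers` folder shipped with the repository (it assumes `OPENAI_API_KEY` is set and will create/update the `faiss_index` and `tracked_files` entries on disk, just as the app does):

```python
# Sketch only: builds (or loads) the FAISS index for the sample knowledge set
# and runs a similarity search through the retriever interface used in chat.py.
from dotenv import load_dotenv
from langchain.embeddings import OpenAIEmbeddings

from knowledge_set import compute_knowledge_vectorstore

load_dotenv()
vectorstore = compute_knowledge_vectorstore("test_workers", OpenAIEmbeddings())
retriever = vectorstore.as_retriever()
docs = retriever.get_relevant_documents("Where did Harrison work?")
print(docs[0].page_content)  # should surface the content of harrison.txt
```
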
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
.vscode/
data/*
.DS_Store
--------------------------------------------------------------------------------
/chat.py:
--------------------------------------------------------------------------------
import datetime
import os

import streamlit as st
from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import (ConversationBufferMemory,
                              StreamlitChatMessageHistory)

from chains.conversational_chain import ConversationalChain
from chains.conversational_retrieval_chain import (
    TEMPLATE, ConversationalRetrievalChain)
from knowledge_set import compute_knowledge_vectorstore
from streaming import StreamHandler
from utils import get_available_openai_models


@st.cache_resource(show_spinner=True)
def load_knowledge(knowledge, model_name):
    return compute_knowledge_vectorstore(knowledge, OpenAIEmbeddings(model=model_name))


@st.cache_data
def chat_model_list():
    return get_available_openai_models(put_first='gpt-3.5-turbo', filter_by='gpt')


@st.cache_data
def embedding_model_list():
    return get_available_openai_models(filter_by='embedding')


class StreamlitChatView:
    def __init__(self, knowledge_folder: str) -> None:
        st.set_page_config(page_title="RAG ChatGPT", page_icon="📚", layout="wide")
        with st.sidebar:
            st.title("RAG ChatGPT")
            with st.expander("Model parameters"):
                self.model_name = st.selectbox("Model:", options=chat_model_list())
                self.temperature = st.slider("Temperature", min_value=0., max_value=2., value=0.7, step=0.01)
                self.top_p = st.slider("Top p", min_value=0., max_value=1., value=1., step=0.01)
                self.frequency_penalty = st.slider("Frequency penalty", min_value=0., max_value=2., value=0., step=0.01)
                self.presence_penalty = st.slider("Presence penalty", min_value=0., max_value=2., value=0., step=0.01)
            with st.expander("Prompts"):
                curdate = datetime.datetime.now().strftime("%Y-%m-%d")
                model_name = self.model_name.replace('-turbo', '').upper()
                system_message = (f"You are ChatGPT, a large language model trained by OpenAI, "
                                  f"based on the {model_name} architecture.\n"
                                  f"Knowledge cutoff: 2021-09\n"
                                  f"Current date: {curdate}\n")
                self.system_message = st.text_area("System message", value=system_message)
                self.context_prompt = st.text_area("Context prompt", value=TEMPLATE)
            with st.expander("Embeddings parameters"):
                self.embeddings_model_name = st.selectbox("Embeddings model:", options=embedding_model_list())
            self.inject_knowledge = st.checkbox("Inject knowledge", value=True)
            knowledge_names = [fn for fn in os.listdir(knowledge_folder)
                               if os.path.isdir(os.path.join(knowledge_folder, fn))]
            self.knowledge = st.selectbox("Select a knowledge folder:", knowledge_names)
        self.user_query = st.chat_input(placeholder="Ask me anything!")

    def add_message(self, message: str, author: str):
        assert author in ["user", "assistant"]
        with st.chat_message(author):
            st.markdown(message)

    def add_message_stream(self, author: str):
        assert author in ["user", "assistant"]
        return StreamHandler(st.chat_message(author).empty())


def setup_memory():
    msgs = StreamlitChatMessageHistory(key="langchain_messages")
    return ConversationBufferMemory(memory_key="chat_history", chat_memory=msgs, return_messages=True)


def setup_chain(llm, memory, inject_knowledge, system_message, context_prompt, retriever):
    if not inject_knowledge:
        # Custom conversational chain
        return ConversationalChain(
            llm=llm,
            memory=memory,
            system_message=system_message,
            verbose=True)
    else:
        return ConversationalRetrievalChain(
            llm=llm,
            retriever=retriever,
            memory=memory,
            system_message=system_message,
            context_prompt=context_prompt,
            verbose=True)


STREAM = True

# Setup
load_dotenv()
view = StreamlitChatView("data")
memory = setup_memory()
retriever = None
if view.inject_knowledge:
    retriever = load_knowledge(view.knowledge, model_name=view.embeddings_model_name).as_retriever()
llm = ChatOpenAI(
    streaming=STREAM,
    model_name=view.model_name,
    temperature=view.temperature,
    top_p=view.top_p,
    frequency_penalty=view.frequency_penalty,
    presence_penalty=view.presence_penalty)
chain = setup_chain(llm=llm, memory=memory, inject_knowledge=view.inject_knowledge,
                    retriever=retriever, system_message=view.system_message,
                    context_prompt=view.context_prompt)

# Display previous messages
for message in memory.chat_memory.messages:
    view.add_message(message.content, 'assistant' if message.type == 'ai' else 'user')

# Send message
if view.user_query:
    view.add_message(view.user_query, "user")
    if STREAM:
        st_callback = view.add_message_stream("assistant")
        chain.run({"question": view.user_query}, callbacks=[st_callback])
    else:
        response = chain.run({"question": view.user_query})
        view.add_message(response, "assistant")
--------------------------------------------------------------------------------
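
To add your own knowledge set, it is enough to drop `.txt` or `.pdf` files into a new subfolder of `data/`; the folder then shows up in the sidebar's knowledge selector and is indexed by `knowledge_set.py` on first use. A hypothetical example (the folder and file names are illustrative):

```python
# Sketch only: creates a new knowledge folder that the app can index.
import os

os.makedirs("data/my_notes", exist_ok=True)
with open("data/my_notes/notes.txt", "w") as f:
    f.write("The quarterly review meeting is every first Monday of the month.\n")
# Restart (or rerun) the Streamlit app and pick "my_notes" in the sidebar.
```
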
/chains/conversational_retrieval_chain.py:
--------------------------------------------------------------------------------
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import (AsyncCallbackManagerForChainRun,
                                          CallbackManagerForChainRun)
from langchain.chains.base import Chain
from langchain.prompts import (ChatPromptTemplate, HumanMessagePromptTemplate,
                               MessagesPlaceholder,
                               SystemMessagePromptTemplate)
from langchain.schema import BaseRetriever
from langchain.schema.language_model import BaseLanguageModel
from pydantic import Extra

TEMPLATE = (
    "Question: {question}\n\n"
    "Use the following pieces of context to answer the question.\n"
    "If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
    "----------------\n"
    "{context}"
)


class ConversationalRetrievalChain(Chain):
    """
    An example of a custom chain.
    """
    system_message: str
    context_prompt: str = TEMPLATE
    """Prompt object to use."""
    llm: BaseLanguageModel
    output_key: str = "text"  #: :meta private:
    retriever: BaseRetriever
    # docs_combiner: BaseCombineDocumentsChain = StuffDocumentsChain()

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return ["question"]

    @property
    def output_keys(self) -> List[str]:
        """Will always return the text key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(self.system_message),
                # The `variable_name` here must match the key used by the memory
                MessagesPlaceholder(variable_name=self.memory.memory_key),
                HumanMessagePromptTemplate.from_template(self.context_prompt),
            ]
        )

        # TODO: it might make sense to query the vectorstore directly with a fixed k
        docs = self.retriever.get_relevant_documents(
            inputs['question'], callbacks=run_manager.get_child() if run_manager else None
        )
        inputs = inputs.copy()
        inputs['context'] = "\n\n".join([doc.page_content for doc in docs])

        # Your custom chain logic goes here
        # This is just an example that mimics LLMChain
        prompt_value = prompt.format_prompt(**inputs)

        # Whenever you call a language model, or another chain, you should pass
        # a callback manager to it. This allows the inner run to be tracked by
        # any callbacks that are registered on the outer run.
        # You can always obtain a callback manager for this by calling
        # `run_manager.get_child()` as shown below.
        response = self.llm.generate_prompt(
            [prompt_value], callbacks=run_manager.get_child() if run_manager else None
        )

        # If you want to log something about this run, you can do so by calling
        # methods on the `run_manager`, as shown below. This will trigger any
        # callbacks that are registered for that event.
        # if run_manager:
        #     run_manager.on_text("Log something about this run")

        return {self.output_key: response.generations[0][0].text}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        prompt = ChatPromptTemplate(
            messages=[
                SystemMessagePromptTemplate.from_template(self.system_message),
                # The `variable_name` here must match the key used by the memory
                MessagesPlaceholder(variable_name=self.memory.memory_key),
                HumanMessagePromptTemplate.from_template(self.context_prompt),
            ]
        )

        # TODO: it might make sense to query the vectorstore directly with a fixed k
        docs = await self.retriever.aget_relevant_documents(
            inputs['question'], callbacks=run_manager.get_child() if run_manager else None
        )
        inputs = inputs.copy()
        inputs['context'] = "\n\n".join([doc.page_content for doc in docs])

        # Your custom chain logic goes here
        # This is just an example that mimics LLMChain
        prompt_value = prompt.format_prompt(**inputs)

        # Whenever you call a language model, or another chain, you should pass
        # a callback manager to it. This allows the inner run to be tracked by
        # any callbacks that are registered on the outer run.
        # You can always obtain a callback manager for this by calling
        # `run_manager.get_child()` as shown below.
        response = await self.llm.agenerate_prompt(
            [prompt_value], callbacks=run_manager.get_child() if run_manager else None
        )

        # If you want to log something about this run, you can do so by calling
        # methods on the `run_manager`, as shown below. This will trigger any
        # callbacks that are registered for that event.
        # if run_manager:
        #     run_manager.on_text("Log something about this run")

        return {self.output_key: response.generations[0][0].text}

    @property
    def _chain_type(self) -> str:
        return "my_custom_chain"
--------------------------------------------------------------------------------
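
For completeness, a minimal sketch of driving this chain outside of Streamlit. The tiny in-memory FAISS index, the system message, and the question are illustrative (the sample text comes from `data/test_workers/harrison.txt`), and `OPENAI_API_KEY` is assumed to be set; `chat.py` shows the full wiring used by the app.

```python
# Sketch only: ConversationalRetrievalChain over a tiny in-memory index.
from langchain.chat_models import ChatOpenAI
from langchain.embeddings import OpenAIEmbeddings
from langchain.memory import ConversationBufferMemory
from langchain.vectorstores import FAISS

from chains.conversational_retrieval_chain import TEMPLATE, ConversationalRetrievalChain

retriever = FAISS.from_texts(
    ["harrison worked at woowoo"], embedding=OpenAIEmbeddings()
).as_retriever()
chain = ConversationalRetrievalChain(
    llm=ChatOpenAI(model_name="gpt-3.5-turbo"),
    retriever=retriever,
    memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
    system_message="You are a helpful assistant.",  # illustrative system message
    context_prompt=TEMPLATE,
)
print(chain.run({"question": "Where did Harrison work?"}))
```
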