├── .gitignore ├── README.md ├── agent.py ├── app.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 
94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ChatPDF 2 | 3 | Ask anything to PDFs. 4 | [Demo](https://chatpdfs.streamlit.app/) 5 | 6 | ## Installation 7 | 8 | Developed using `python 3.10` on windows. 9 | 10 | ```bash 11 | pip install -r requirements.txt 12 | ``` 13 | 14 | ## Usage 15 | 16 | ```bash 17 | streamlit run app.py 18 | ``` 19 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | import os 2 | from langchain.embeddings.openai import OpenAIEmbeddings 3 | from langchain.document_loaders import PyPDFLoader 4 | from langchain.text_splitter import RecursiveCharacterTextSplitter 5 | from langchain.vectorstores import FAISS 6 | 7 | from langchain.chains import ConversationalRetrievalChain 8 | from langchain.llms import OpenAI 9 | 10 | 11 | class Agent: 12 | def __init__(self, openai_api_key: str | None = None) -> None: 13 | # if openai_api_key is None, then it will look the enviroment variable OPENAI_API_KEY 14 | self.embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key) 15 | self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200) 16 
| 17 | self.llm = OpenAI(temperature=0, openai_api_key=openai_api_key) 18 | 19 | self.chat_history = None 20 | self.chain = None 21 | self.db = None 22 | 23 | def ask(self, question: str) -> str: 24 | if self.chain is None: 25 | response = "Please, add a document." 26 | else: 27 | response = self.chain({"question": question, "chat_history": self.chat_history}) 28 | response = response["answer"].strip() 29 | self.chat_history.append((question, response)) 30 | return response 31 | 32 | def ingest(self, file_path: os.PathLike) -> None: 33 | loader = PyPDFLoader(file_path) 34 | documents = loader.load() 35 | splitted_documents = self.text_splitter.split_documents(documents) 36 | 37 | if self.db is None: 38 | self.db = FAISS.from_documents(splitted_documents, self.embeddings) 39 | self.chain = ConversationalRetrievalChain.from_llm(self.llm, self.db.as_retriever()) 40 | self.chat_history = [] 41 | else: 42 | self.db.add_documents(splitted_documents) 43 | 44 | def forget(self) -> None: 45 | self.db = None 46 | self.chain = None 47 | self.chat_history = None 48 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import streamlit as st 4 | from streamlit_chat import message 5 | from agent import Agent 6 | 7 | st.set_page_config(page_title="ChatPDF") 8 | 9 | 10 | def display_messages(): 11 | st.subheader("Chat") 12 | for i, (msg, is_user) in enumerate(st.session_state["messages"]): 13 | message(msg, is_user=is_user, key=str(i)) 14 | st.session_state["thinking_spinner"] = st.empty() 15 | 16 | 17 | def process_input(): 18 | if st.session_state["user_input"] and len(st.session_state["user_input"].strip()) > 0: 19 | user_text = st.session_state["user_input"].strip() 20 | with st.session_state["thinking_spinner"], st.spinner(f"Thinking"): 21 | agent_text = st.session_state["agent"].ask(user_text) 22 | 23 | 
def read_and_save_file():
    """Re-ingest every file currently held by the ``file_uploader`` widget.

    Registered as the uploader's ``on_change`` callback. The agent's
    knowledge base, the displayed messages and the pending user input are
    reset first, then each uploaded file is written to a temporary file on
    disk and handed to the agent by path.
    """
    st.session_state["agent"].forget()  # to reset the knowledge base
    st.session_state["messages"] = []
    st.session_state["user_input"] = ""

    for file in st.session_state["file_uploader"]:
        # delete=False keeps the file on disk after it is closed so it can be
        # reopened by path during ingestion; removal is therefore manual.
        tf = tempfile.NamedTemporaryFile(delete=False)
        file_path = tf.name
        try:
            with tf:
                tf.write(file.getbuffer())

            with st.session_state["ingestion_spinner"], st.spinner(f"Ingesting {file.name}"):
                st.session_state["agent"].ingest(file_path)
        finally:
            # Always clean up, even when writing or ingestion fails.
            os.remove(file_path)


def is_openai_api_key_set() -> bool:
    """Return True when a non-empty OpenAI API key is stored in the session."""
    return bool(st.session_state["OPENAI_API_KEY"])


def main():
    """Render the ChatPDF page: API key field, uploader, chat and input box."""
    # First run of this session: seed the state this script relies on.
    if "messages" not in st.session_state:
        st.session_state["messages"] = []
        st.session_state["OPENAI_API_KEY"] = os.environ.get("OPENAI_API_KEY", "")
        if is_openai_api_key_set():
            st.session_state["agent"] = Agent(st.session_state["OPENAI_API_KEY"])
        else:
            st.session_state["agent"] = None

    st.header("ChatPDF")

    if st.text_input("OpenAI API Key", value=st.session_state["OPENAI_API_KEY"], key="input_OPENAI_API_KEY", type="password"):
        entered_key = st.session_state["input_OPENAI_API_KEY"]
        if len(entered_key) > 0 and entered_key != st.session_state["OPENAI_API_KEY"]:
            st.session_state["OPENAI_API_KEY"] = entered_key
            # A new key means a new Agent; anything ingested by the old one is gone.
            if st.session_state["agent"] is not None:
                st.warning("Please, upload the files again.")
            st.session_state["messages"] = []
            st.session_state["user_input"] = ""
            st.session_state["agent"] = Agent(st.session_state["OPENAI_API_KEY"])

    st.subheader("Upload a document")
    st.file_uploader(
        "Upload document",
        type=["pdf"],
        key="file_uploader",
        on_change=read_and_save_file,
        label_visibility="collapsed",
        accept_multiple_files=True,
        disabled=not is_openai_api_key_set(),
    )

    # Placeholder that read_and_save_file fills with its ingestion spinner.
    st.session_state["ingestion_spinner"] = st.empty()

    display_messages()
    st.text_input("Message", key="user_input", disabled=not is_openai_api_key_set(), on_change=process_input)

    st.divider()
    st.markdown("Source code: [Github](https://github.com/viniciusarruda/chatpdf)")


if __name__ == "__main__":
    main()