├── log_init.py ├── chainlit.md ├── hr_model.py ├── source_splitter.py ├── geolocation.py ├── hr_chatbot_cli.py ├── README.md ├── config.py ├── generate_embeddings.py ├── .gitignore ├── hr_chatbot_chainlit.py └── chain_factory.py /log_init.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level="INFO") 4 | 5 | logger = logging.getLogger("hr_chatbot") 6 | -------------------------------------------------------------------------------- /chainlit.md: -------------------------------------------------------------------------------- 1 | # Welcome to Ask AIshu HR! 🚀🤖 2 | 3 | Using this chatbot you can ask questions about the HR policies at Onepoint. 4 | 5 | AIshu stands for "Artificial Intelligence (powered) Smart Hr Unit". 6 | 7 | ## Feedback 🔗 8 | 9 | Please send your feedback to Gil and Sangeetha 💻😊 10 | 11 | ## More information about Onepoint 12 | 13 | Here is more information about [Onepoint](https://www.onepointltd.com/). 📚 14 | 15 | 16 | -------------------------------------------------------------------------------- /hr_model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Tuple, List, TypeVar 3 | from langchain.schema import Document 4 | from langchain.chains import RetrievalQAWithSourcesChain 5 | from langchain.vectorstores.base import VectorStore 6 | 7 | VST = TypeVar("VST", bound="VectorStore") 8 | 9 | 10 | @dataclass 11 | class QAData: 12 | vst: VST 13 | documents: List[Document] 14 | chain: RetrievalQAWithSourcesChain 15 | -------------------------------------------------------------------------------- /source_splitter.py: -------------------------------------------------------------------------------- 1 | import re 2 | from log_init import logger 3 | 4 | 5 | def source_splitter(sources: str): 6 | logger.info(f"There are sources: {sources}") 7 | raw_sources, file_sources = [], [] 8 | split_char = "," if "," in sources else "-" 9 | for source in sources.split(split_char): 10 | source = source.strip() 11 | raw_sources.append(source) 12 | file_sources.append(re.sub(r"(.+\.pdf).*", r"\1", source)) 13 | return raw_sources, file_sources 14 | 15 | 16 | if __name__ == "__main__": 17 | sources = "04.16 Code of Conduct (1).pdf page 1, 04.13 HR Policies & Procedures V10.docx .pdf page 1" 18 | raw_sources, file_sources = source_splitter(sources) 19 | logger.info(f"raw sources: {raw_sources}") 20 | logger.info(f"file sources: {file_sources}") 21 | print() 22 | sources = """Family Friendly Rights & Policies V1.2.pdf page 16, page 22""" 23 | raw_sources, file_sources = source_splitter(sources) 24 | logger.info(f"raw sources: {raw_sources}") 25 | logger.info(f"file sources: {file_sources}") 26 | -------------------------------------------------------------------------------- /geolocation.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import requests 3 | 4 | from log_init import logger 5 | 6 | 7 | @dataclass 8 | class GeoLocation: 9 | country_code: str 10 | country_name: str 11 | city: str 12 | postal: str 13 | latitude: float 14 | longitude: float 15 | 16 | 17 | def extract_ip_address(environ: dict) -> str: 18 | asgi_scope = environ.get("asgi.scope") 19 | if asgi_scope: 20 | client = asgi_scope.get("client") 21 | return client[0] 22 | return None 23 | 24 | 25 | def geolocate(ip_address: str) -> GeoLocation: 26 | response = requests.get( 27 | f"https://geolocation-db.com/json/{ip_address}&position=true" 28 | ).json() 29 | return GeoLocation( 30 | country_code=response["country_code"], 31 | country_name=response["country_name"], 32 | city=response["city"], 33 | postal=response["postal"], 34 | latitude=response["latitude"], 35 | longitude=response["longitude"], 36 | ) 37 | 38 | 39 | if __name__ == "__main__": 40 | response = geolocate("185.71.38.58") 41 | logger.info(response) 42 | -------------------------------------------------------------------------------- /hr_chatbot_cli.py: -------------------------------------------------------------------------------- 1 | from langchain.chains import RetrievalQAWithSourcesChain 2 | from prompt_toolkit import HTML, prompt, PromptSession 3 | from prompt_toolkit.history import FileHistory 4 | 5 | from chain_factory import create_retrieval_chain, load_embeddinges 6 | 7 | from log_init import logger 8 | 9 | import sys 10 | 11 | 12 | def init_chain(): 13 | humour = False 14 | if len(sys.argv) > 1: 15 | if sys.argv[1] == "humor": 16 | humour = True 17 | logger.warning("Humor flag activated") 18 | session = PromptSession(history=FileHistory(".agent-history-file")) 19 | docsearch, documents = load_embeddinges() 20 | chain: RetrievalQAWithSourcesChain = create_retrieval_chain( 21 | docsearch, humour=humour 22 | ) 23 | return session, chain 24 | 25 | 26 | if __name__ == "__main__": 27 | session, chain = init_chain() 28 | 29 | while True: 30 | question = session.prompt( 31 | HTML("Type Your question ('q' to exit): ") 32 | ) 33 | if question.lower() in ["q", "exit", "quit"]: 34 | break 35 | response = chain({"question": question}) 36 | logger.info(f"Answer: {response['answer']}") 37 | logger.info(f"Sources: {response['sources']}") 38 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Onepoint HR Chatbot 2 | 3 | This is a simple HR chatbot based on Chainlit with memory support and able to 4 | choose its vector database based on geolocation. 5 | 6 | ## Installation 7 | 8 | ``` 9 | conda create -n langchain_chainlit python=3.11 10 | conda activate langchain_chainlit 11 | pip install langchain 12 | pip install python-dotenv 13 | pip install openai 14 | pip install faiss-cpu 15 | pip install tiktoken 16 | pip install chainlit 17 | pip install pdfminer 18 | pip install pypdfium2 19 | pip install prompt_toolkit 20 | ``` 21 | 22 | ### Custom environment 23 | 24 | ``` 25 | # conda activate base 26 | # conda remove -n langchain_chainlit_2 --all 27 | conda create -n langchain_chainlit_2 python=3.11 28 | conda activate langchain_chainlit_2 29 | # pip install --force-reinstall /home/ubuntu/chainlit-0.5.3-py3-none-any.whl 30 | pip install --force-reinstall C:\development\playground\chainlit\src\dist\chainlit-0.5.3-py3-none-any.whl 31 | pip install langchain 32 | pip install faiss-cpu 33 | pip install tiktoken 34 | pip install pdfminer 35 | pip install pypdfium2 36 | pip install black 37 | ``` 38 | 39 | ## Configuration 40 | 41 | Please make sure that you have a .env file with the following variables: 42 | ``` 43 | OPENAI_API_KEY= 44 | DOC_LOCATION= 45 | FAISS_STORE= 46 | HUMOUR= 47 | ``` 48 | 49 | ## Running 50 | 51 | With Chainlit: 52 | ``` 53 | chainlit run hr_chatbot_chainlit.py --port 8081 54 | ``` 55 | 56 | For Development: 57 | ``` 58 | chainlit run hr_chatbot_chainlit.py -w --port 8081 59 | ``` 60 | 61 | Command line: 62 | ``` 63 | python ./hr_chatbot_cli.py 64 | ``` -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from langchain.embeddings.openai import OpenAIEmbeddings 2 | from langchain.chat_models import ChatOpenAI 3 | from pathlib import Path 4 | import os 5 | 6 | from dotenv import load_dotenv 7 | 8 | load_dotenv() 9 | 10 | 11 | class Config: 12 | faiss_persist_directory_uk = Path(os.environ["FAISS_STORE_UK"]) 13 | faiss_persist_directory_india = Path(os.environ["FAISS_STORE_INDIA"]) 14 | faiss_dirs = [faiss_persist_directory_uk, faiss_persist_directory_india] 15 | for d in faiss_dirs: 16 | if not d.exists(): 17 | d.mkdir() 18 | 19 | doc_location_uk = Path(os.environ["DOC_LOCATION_UK"]) 20 | doc_location_india = Path(os.environ["DOC_LOCATION_INDIA"]) 21 | doc_locations = [doc_location_uk, doc_location_india] 22 | 23 | location_persistence_map = { 24 | "GB": { 25 | "faiss_persist_directory": faiss_persist_directory_uk, 26 | "doc_location": doc_location_uk, 27 | }, 28 | "IN": { 29 | "faiss_persist_directory": faiss_persist_directory_india, 30 | "doc_location": doc_location_india, 31 | }, 32 | } 33 | 34 | for location in doc_locations: 35 | if not location.exists(): 36 | raise Exception(f"File not found: {location}") 37 | 38 | embeddings = OpenAIEmbeddings(chunk_size=100) 39 | model = "gpt-3.5-turbo-16k" 40 | # model = 'gpt-4' 41 | llm = ChatOpenAI(model=model, temperature=0) 42 | search_results = 5 43 | 44 | def __repr__(self) -> str: 45 | return f"""# Configuration 46 | faiss_persist_directories: {self.faiss_dirs} 47 | doc_locations: {self.doc_locations} 48 | 49 | embeddings: {self.embeddings} 50 | 51 | llm: {self.llm} 52 | """ 53 | 54 | 55 | cfg = Config() 56 | 57 | if __name__ == "__main__": 58 | print(cfg) 59 | -------------------------------------------------------------------------------- /generate_embeddings.py: -------------------------------------------------------------------------------- 1 | from langchain.schema import Document 2 | from langchain.document_loaders import PyPDFium2Loader 3 | from langchain.vectorstores import FAISS 4 | 5 | from typing import TypeVar, List 6 | from pathlib import Path 7 | from dotenv import load_dotenv 8 | import numpy as np 9 | 10 | import os 11 | import re 12 | 13 | from config import cfg 14 | 15 | from log_init import logger 16 | 17 | load_dotenv() 18 | 19 | VST = TypeVar("VST", bound="VectorStore") 20 | 21 | 22 | def load_pdfs(path: Path) -> List[Document]: 23 | """ 24 | Loads the PDFs and extracts a document per page. 25 | The page details are added to the extracted metadata 26 | 27 | Parameters: 28 | path (Path): The path where the PDFs are saved. 29 | 30 | Returns: 31 | List[Document]: Returns a list of values 32 | """ 33 | assert path.exists() 34 | all_pages = [] 35 | for pdf in path.glob("*.pdf"): 36 | loader = PyPDFium2Loader(str(pdf.absolute())) 37 | pages: List[Document] = loader.load_and_split() 38 | for i, p in enumerate(pages): 39 | file_name = re.sub(r".+[\\/]", "", p.metadata["source"]) 40 | p.metadata["source"] = f"{file_name} page {i + 1}" 41 | all_pages.extend(pages) 42 | logger.info(f"Processed {pdf}, all_pages size: {len(all_pages)}") 43 | log_stats(all_pages) 44 | return all_pages 45 | 46 | 47 | def log_stats(documents: List[Document]): 48 | logger.info(f"Total number of documents {len(documents)}") 49 | counts = [] 50 | for d in documents: 51 | counts.append(count_words(d)) 52 | logger.info(f"Tokens Max {np.max(counts)}") 53 | logger.info(f"Tokens Min {np.min(counts)}") 54 | logger.info(f"Tokens Min {np.mean(counts)}") 55 | 56 | 57 | def generate_embeddings( 58 | documents: List[Document], path: Path, faiss_persist_directory: str 59 | ) -> VST: 60 | """ 61 | Receives a list of documents and generates the embeddings via OpenAI API. 62 | 63 | Parameters: 64 | documents (List[Document]): The document list with one page per document. 65 | path (Path): The path where the documents are found. 66 | 67 | Returns: 68 | VST: Recturs a reference to the vector store. 69 | """ 70 | try: 71 | docsearch = FAISS.from_documents(documents, cfg.embeddings) 72 | docsearch.save_local(faiss_persist_directory) 73 | logger.info("Vector database persisted") 74 | except Exception as e: 75 | logger.error(f"Failed to process {path}: {str(e)}") 76 | if "docsearch" in vars() or "docsearch" in globals(): 77 | docsearch.persist() 78 | return docsearch 79 | return docsearch 80 | 81 | 82 | def count_words(document: Document) -> int: 83 | splits = [s for s in re.split("[\s,.]", document.page_content) if len(s) > 0] 84 | return len(splits) 85 | 86 | 87 | if __name__ == "__main__": 88 | doc_location: str = os.environ["DOC_LOCATION"] 89 | documents = load_pdfs(Path(doc_location)) 90 | assert len(documents) > 0 91 | logger.info(documents[2].page_content) 92 | generate_embeddings(documents, doc_location) 93 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ -------------------------------------------------------------------------------- /hr_chatbot_chainlit.py: -------------------------------------------------------------------------------- 1 | from langchain.chains import RetrievalQAWithSourcesChain 2 | import chainlit as cl 3 | 4 | from chain_factory import load_all_chains 5 | from geolocation import extract_ip_address, geolocate 6 | from source_splitter import source_splitter 7 | from chainlit.context import get_emitter 8 | 9 | from log_init import logger 10 | 11 | from pathlib import Path 12 | from typing import Dict, Optional 13 | 14 | from config import cfg 15 | 16 | KEY_META_DATAS = "metadatas" 17 | KEY_TEXTS = "texts" 18 | KEY_GEOLOCATION_COUNTRY_CODE = "geolocation_country_code" 19 | 20 | 21 | def set_session_vars(user_session_dict: Dict): 22 | for k, v in user_session_dict.items(): 23 | cl.user_session.set(k, v) 24 | 25 | 26 | def create_pdf(pdf_name: str, pdf_path: str) -> Optional[cl.File]: 27 | """ 28 | Creates a file download button for a PDF file in case it is found. 29 | 30 | Parameters: 31 | pdf_name (str): The file name 32 | pdf_path (str): The file name 33 | 34 | Returns: 35 | RetrievalQAWithSourcesChain: The QA chain 36 | """ 37 | logger.info(f"Creating pdf for {pdf_path}") 38 | # Sending a pdf with the local file path 39 | country_code = cl.user_session.get(KEY_GEOLOCATION_COUNTRY_CODE) 40 | country_config = cfg.location_persistence_map.get(country_code) 41 | if country_config: 42 | logger.info("country_config found") 43 | doc_location: Path = country_config.get("doc_location") 44 | doc_path = doc_location / pdf_path 45 | if doc_path.exists(): 46 | logger.info("Creating pdf component") 47 | return cl.File( 48 | name=pdf_name, display="inline", path=str(doc_path.absolute()) 49 | ) 50 | else: 51 | logger.info(f"doc path {doc_path} does not exist.") 52 | return None 53 | 54 | 55 | @cl.langchain_factory(use_async=True) 56 | async def init(): 57 | """ 58 | Loads the vector data store object and the PDF documents. Creates the QA chain. 59 | Sets up some session variables and removes the Chainlit footer. 60 | 61 | Parameters: 62 | use_async (bool): Determines whether async is to be used or not. 63 | 64 | Returns: 65 | RetrievalQAWithSourcesChain: The QA chain 66 | """ 67 | 68 | emitter = get_emitter() 69 | # Please note this works only with a modified version of Streamlit 70 | # The repo with this modification are here: https://github.com/gilfernandes/chainlit_hr_extension 71 | 72 | country_code = "GB" 73 | geolocation_failed = False 74 | 75 | try: 76 | remote_address = extract_ip_address(emitter.session.environ) 77 | geo_location = geolocate(remote_address) 78 | 79 | if geo_location.country_code != "Not found": 80 | country_code = geo_location.country_code 81 | # await display_location_details(geo_location, country_code) 82 | except: 83 | logger.exception("Could not locate properly") 84 | geolocation_failed = True 85 | 86 | if geolocation_failed: 87 | await cl.Message(content=f"Geolocation failed ... I do not know where you are.").send() 88 | else: 89 | logger.info(f"Geo location: {geo_location}") 90 | 91 | msg = cl.Message(content=f"Processing files. Please wait.") 92 | await msg.send() 93 | chain_dict = load_all_chains(country_code) 94 | qa_data = chain_dict[country_code] 95 | 96 | documents = qa_data.documents 97 | 98 | chain: RetrievalQAWithSourcesChain = qa_data.chain 99 | metadatas = [d.metadata for d in documents] 100 | texts = [d.page_content for d in documents] 101 | 102 | set_session_vars( 103 | { 104 | KEY_META_DATAS: metadatas, 105 | KEY_TEXTS: texts, 106 | KEY_GEOLOCATION_COUNTRY_CODE: country_code, 107 | } 108 | ) 109 | 110 | 111 | msg.content = f"You can now ask questions about Onepoint HR ({country_code})!" 112 | await msg.send() 113 | 114 | return chain 115 | 116 | 117 | async def display_location_details(geo_location, country_code): 118 | geo_location_msg = cl.Message( 119 | content=f"""Geo location: 120 | - country: {geo_location.country_name} 121 | - country code: {country_code}""" 122 | ) 123 | await geo_location_msg.send() 124 | 125 | 126 | @cl.langchain_postprocess 127 | async def process_response(res) -> cl.Message: 128 | """ 129 | Tries to extract the sources and corresponding texts from the sources. 130 | 131 | Parameters: 132 | res (dict): A dictionary with the answer and sources provided by the LLM via LangChain. 133 | 134 | Returns: 135 | cl.Message: The message containing the answer and the list of sources with corresponding texts. 136 | """ 137 | answer = res["answer"] 138 | sources = res["sources"].strip() 139 | source_elements = [] 140 | 141 | # Get the metadata and texts from the user session 142 | metadatas = cl.user_session.get(KEY_META_DATAS) 143 | all_sources = [m["source"] for m in metadatas] 144 | texts = cl.user_session.get(KEY_TEXTS) 145 | 146 | found_sources = [] 147 | pdf_elements = [] 148 | if sources: 149 | logger.info(f"sources: {sources}") 150 | raw_sources, file_sources = source_splitter(sources) 151 | for i, source in enumerate(raw_sources): 152 | try: 153 | source_name = file_sources[i] 154 | pdf_element = create_pdf(source_name, source_name) 155 | if pdf_element: 156 | pdf_elements.append(pdf_element) 157 | logger.info(f"PDF Elements: {pdf_elements}") 158 | else: 159 | logger.warning(f"No pdf element for {source_name}") 160 | 161 | index = all_sources.index(source) 162 | text = texts[index] 163 | found_sources.append(source) 164 | # Create the text element referenced in the message 165 | logger.info(f"Found text in {source_name}") 166 | source_elements.append(cl.Text(content=text, name=source_name)) 167 | except ValueError as e: 168 | logger.error(f"Value error {e}") 169 | continue 170 | if found_sources: 171 | answer += f"\nSources: {', '.join(found_sources)}" 172 | else: 173 | answer += f"\n{sources}" 174 | 175 | logger.info(f"PDF Elements: {pdf_elements}") 176 | await cl.Message(content=answer, elements=source_elements).send() 177 | await cl.Message(content="PDF Downloads", elements=pdf_elements).send() 178 | 179 | 180 | if __name__ == "__main__": 181 | pass 182 | -------------------------------------------------------------------------------- /chain_factory.py: -------------------------------------------------------------------------------- 1 | from langchain.chains import RetrievalQAWithSourcesChain 2 | from langchain.memory import ConversationSummaryBufferMemory 3 | from langchain.vectorstores import FAISS 4 | from langchain.schema import Document 5 | from langchain.prompts import PromptTemplate 6 | from langchain.memory.utils import get_prompt_input_key 7 | from langchain.vectorstores.base import VectorStoreRetriever, VectorStore 8 | from config import cfg 9 | from typing import Any, Dict, Tuple, List, TypeVar 10 | 11 | 12 | import os 13 | from pathlib import Path 14 | 15 | from generate_embeddings import load_pdfs, generate_embeddings 16 | from hr_model import QAData 17 | from log_init import logger 18 | 19 | VST = TypeVar("VST", bound="VectorStore") 20 | 21 | 22 | class KeySourceMemory(ConversationSummaryBufferMemory): 23 | def _get_input_output( 24 | self, inputs: Dict[str, Any], outputs: Dict[str, str] 25 | ) -> Tuple[str, str]: 26 | if self.input_key is None: 27 | prompt_input_key = get_prompt_input_key(inputs, self.memory_variables) 28 | else: 29 | prompt_input_key = self.input_key 30 | if self.output_key is None: 31 | output_key = "answer" 32 | else: 33 | output_key = self.output_key 34 | return inputs[prompt_input_key], outputs[output_key] 35 | 36 | 37 | def load_embeddinges( 38 | embedding_dir: Path = cfg.faiss_persist_directory_uk, 39 | doc_location: Path = cfg.doc_location_uk, 40 | ) -> Tuple[VST, List[Document]]: 41 | """ 42 | Loads the PDF documents to support text extraction in the Chainlit UI. 43 | In case there are no persisted embeddings, the embeddings are generated. 44 | In case the embeddings are persisted, then they are loaded from the file system. 45 | 46 | Returns: 47 | Tuple[VST, List[Document]]: Recturs a reference to the vector store and the list of all pdf files. 48 | """ 49 | logger.info(f"Checking: {embedding_dir}") 50 | documents = load_pdfs(doc_location) 51 | assert len(documents) > 0 52 | if embedding_dir.exists() and len(list(embedding_dir.glob("*"))) > 0: 53 | logger.info(f"reading from existing directory: {embedding_dir}") 54 | docsearch = FAISS.load_local(embedding_dir, cfg.embeddings) 55 | return docsearch, documents 56 | return ( 57 | generate_embeddings(documents, doc_location, embedding_dir.absolute()), 58 | documents, 59 | ) 60 | 61 | 62 | template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). If you know a joke about the subject, make sure that you include it in the response. 63 | If you don't know the answer, say that you don't know and make up some joke about the subject. Don't try to make up an answer. 64 | ALWAYS return a "SOURCES" part in your answer. 65 | 66 | QUESTION: Which state/country's law governs the interpretation of the contract? 67 | ========= 68 | Content: This Agreement is governed by English law and the parties submit to the exclusive jurisdiction of the English courts in relation to any dispute (contractual or non-contractual) concerning this Agreement save that either party may apply to any court for an injunction or other relief to protect its Intellectual Property Rights. 69 | Source: 28-pl 70 | Content: No Waiver. Failure or delay in exercising any right or remedy under this Agreement shall not constitute a waiver of such (or any other) right or remedy.\n\n11.7 Severability. The invalidity, illegality or unenforceability of any term (or part of a term) of this Agreement shall not affect the continuation in force of the remainder of the term (if any) and this Agreement.\n\n11.8 No Agency. Except as expressly stated otherwise, nothing in this Agreement shall create an agency, partnership or joint venture of any kind between the parties.\n\n11.9 No Third-Party Beneficiaries. 71 | Source: 30-pl 72 | Content: (b) if Google believes, in good faith, that the Distributor has violated or caused Google to violate any Anti-Bribery Laws (as defined in Clause 8.5) or that such a violation is reasonably likely to occur, 73 | Source: 4-pl 74 | ========= 75 | FINAL ANSWER: This Agreement is governed by English law. 76 | SOURCES: 28-pl 77 | 78 | QUESTION: What did the president say about Michael Jackson? 79 | ========= 80 | Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland. 81 | Source: 0-pl 82 | Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves. 83 | Source: 24-pl 84 | Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay. 85 | Source: 5-pl 86 | Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation. 87 | Source: 34-pl 88 | ========= 89 | FINAL ANSWER: The president did not mention Michael Jackson. And here is a joke about Michael Jackson: Why did Michael Jackson go to the bakery? Because he wanted to "beat it" and grab some "moon-pies"! 90 | SOURCES: 91 | 92 | QUESTION: {question} 93 | ========= 94 | {summaries} 95 | ========= 96 | FINAL ANSWER:""" 97 | HUMOUR_PROMPT = PromptTemplate( 98 | template=template, input_variables=["summaries", "question"] 99 | ) 100 | 101 | 102 | def create_retrieval_chain( 103 | docsearch: VST, verbose: bool = False, humour: bool = True 104 | ) -> RetrievalQAWithSourcesChain: 105 | """ 106 | This function creates the QA chain with memory and in case the humour parameter is true, 107 | then a manipulated prompt - that tends to create jokes on certain occasions - is used. 108 | 109 | Parameters: 110 | docsearch (VST): A reference to the vector store. 111 | verbose (bool): Determines whether LangChain's internal logging is printed to the console or not. 112 | humour (bool): Determines whether the prompt for answers with jokes is used or not. 113 | 114 | Returns: 115 | RetrievalQAWithSourcesChain: The QA chain 116 | """ 117 | memory = KeySourceMemory(llm=cfg.llm, input_key="question", output_key="answer") 118 | chain_type_kwargs = {} 119 | if verbose: 120 | chain_type_kwargs["verbose"] = True 121 | if humour: 122 | chain_type_kwargs["prompt"] = HUMOUR_PROMPT 123 | search_retriever: VectorStoreRetriever = docsearch.as_retriever() 124 | search_retriever.search_kwargs = {"k": cfg.search_results} 125 | qa_chain = RetrievalQAWithSourcesChain.from_chain_type( 126 | cfg.llm, 127 | retriever=search_retriever, 128 | chain_type="stuff", 129 | memory=memory, 130 | chain_type_kwargs=chain_type_kwargs, 131 | ) 132 | 133 | return qa_chain 134 | 135 | 136 | def load_all_chains(country_filter: str = None) -> Dict[str, QAData]: 137 | res = {} 138 | for country, v in cfg.location_persistence_map.items(): 139 | if country_filter is None or country_filter == country: 140 | faiss_persist_directory = v["faiss_persist_directory"] 141 | doc_location = v["doc_location"] 142 | vst, documents = load_embeddinges(faiss_persist_directory, doc_location) 143 | chain = create_retrieval_chain(vst, humour=os.getenv("HUMOUR") == "true") 144 | res[country] = QAData(vst=vst, documents=documents, chain=chain) 145 | return res 146 | 147 | 148 | if __name__ == "__main__": 149 | chain_dict = load_all_chains() 150 | logger.info(len(chain_dict.items())) 151 | --------------------------------------------------------------------------------