├── log_init.py
├── chainlit.md
├── hr_model.py
├── source_splitter.py
├── geolocation.py
├── hr_chatbot_cli.py
├── README.md
├── config.py
├── generate_embeddings.py
├── .gitignore
├── hr_chatbot_chainlit.py
└── chain_factory.py
/log_init.py:
--------------------------------------------------------------------------------
import logging

# Shared logger used by every module in the project.
logging.basicConfig(level="INFO")

logger = logging.getLogger("hr_chatbot")
--------------------------------------------------------------------------------
/chainlit.md:
--------------------------------------------------------------------------------
# Welcome to Ask AIshu HR! 🚀🤖

You can use this chatbot to ask questions about the HR policies at Onepoint.

AIshu stands for "Artificial Intelligence (powered) Smart Hr Unit".

## Feedback 🔗

Please send your feedback to Gil and Sangeetha 💻😊

## More information about Onepoint

Here is more information about [Onepoint](https://www.onepointltd.com/). 📚
--------------------------------------------------------------------------------
/hr_model.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import List, TypeVar

from langchain.chains import RetrievalQAWithSourcesChain
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore

VST = TypeVar("VST", bound="VectorStore")


@dataclass
class QAData:
    """Bundles the vector store, the loaded documents and the QA chain for one region."""

    vst: VST
    documents: List[Document]
    chain: RetrievalQAWithSourcesChain
--------------------------------------------------------------------------------
/source_splitter.py:
--------------------------------------------------------------------------------
import re
from typing import List, Tuple

from log_init import logger


def source_splitter(sources: str) -> Tuple[List[str], List[str]]:
    """
    Splits the "sources" string returned by the chain into individual sources.

    Returns two parallel lists: the raw sources (e.g. "Code of Conduct.pdf page 1")
    and the corresponding file names with any trailing page reference stripped.
    """
    logger.info(f"There are sources: {sources}")
    raw_sources, file_sources = [], []
    # The LLM sometimes separates sources with commas and sometimes with dashes.
    split_char = "," if "," in sources else "-"
    for source in sources.split(split_char):
        source = source.strip()
        raw_sources.append(source)
        # Keep everything up to and including the ".pdf" extension.
        file_sources.append(re.sub(r"(.+\.pdf).*", r"\1", source))
    return raw_sources, file_sources


if __name__ == "__main__":
    sources = "04.16 Code of Conduct (1).pdf page 1, 04.13 HR Policies & Procedures V10.docx .pdf page 1"
    raw_sources, file_sources = source_splitter(sources)
    logger.info(f"raw sources: {raw_sources}")
    logger.info(f"file sources: {file_sources}")
    print()
    sources = "Family Friendly Rights & Policies V1.2.pdf page 16, page 22"
    raw_sources, file_sources = source_splitter(sources)
    logger.info(f"raw sources: {raw_sources}")
    logger.info(f"file sources: {file_sources}")
--------------------------------------------------------------------------------
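A worked example of what `source_splitter` returns for the second demo string above (values traced by hand through the regex, not captured from a run):

```python
from source_splitter import source_splitter

raw, files = source_splitter(
    "Family Friendly Rights & Policies V1.2.pdf page 16, page 22"
)
# raw   == ["Family Friendly Rights & Policies V1.2.pdf page 16", "page 22"]
# files == ["Family Friendly Rights & Policies V1.2.pdf", "page 22"]
# "page 22" contains no ".pdf", so the regex leaves it unchanged.
```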
/geolocation.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Optional

import requests

from log_init import logger


@dataclass
class GeoLocation:
    country_code: str
    country_name: str
    city: str
    postal: str
    latitude: float
    longitude: float


def extract_ip_address(environ: dict) -> Optional[str]:
    """Extracts the client IP address from the ASGI scope, if available."""
    asgi_scope = environ.get("asgi.scope")
    if asgi_scope:
        client = asgi_scope.get("client")
        return client[0]
    return None


def geolocate(ip_address: str) -> GeoLocation:
    """
    Resolves an IP address via the geolocation-db.com JSON API.

    Note: the service returns the string "Not found" for fields it cannot
    resolve, so callers should check country_code before trusting it.
    """
    response = requests.get(
        f"https://geolocation-db.com/json/{ip_address}&position=true"
    ).json()
    return GeoLocation(
        country_code=response["country_code"],
        country_name=response["country_name"],
        city=response["city"],
        postal=response["postal"],
        latitude=response["latitude"],
        longitude=response["longitude"],
    )


if __name__ == "__main__":
    response = geolocate("185.71.38.58")
    logger.info(response)
--------------------------------------------------------------------------------
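Both the HTTP call and the field lookups in `geolocate` can fail, and the service reports "Not found" for unknown IPs, so a defensive wrapper may be worth adding. The sketch below is a hypothetical helper (`safe_geolocate` is not part of the repo):

```python
from typing import Optional

from geolocation import GeoLocation, geolocate
from log_init import logger


def safe_geolocate(ip_address: Optional[str]) -> Optional[GeoLocation]:
    """Returns None instead of raising when the lookup fails."""
    if not ip_address:
        return None
    try:
        location = geolocate(ip_address)
        # geolocation-db.com reports "Not found" for unresolvable fields.
        if location.country_code == "Not found":
            return None
        return location
    except Exception:
        logger.exception("Geolocation lookup failed")
        return None
```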
/hr_chatbot_cli.py:
--------------------------------------------------------------------------------
import sys

from langchain.chains import RetrievalQAWithSourcesChain
from prompt_toolkit import HTML, PromptSession
from prompt_toolkit.history import FileHistory

from chain_factory import create_retrieval_chain, load_embeddings
from log_init import logger


def init_chain():
    humour = False
    if len(sys.argv) > 1 and sys.argv[1] == "humor":
        humour = True
        logger.warning("Humor flag activated")
    session = PromptSession(history=FileHistory(".agent-history-file"))
    docsearch, documents = load_embeddings()
    chain: RetrievalQAWithSourcesChain = create_retrieval_chain(
        docsearch, humour=humour
    )
    return session, chain


if __name__ == "__main__":
    session, chain = init_chain()

    while True:
        question = session.prompt(HTML("Type your question ('q' to exit): "))
        if question.lower() in ["q", "exit", "quit"]:
            break
        response = chain({"question": question})
        logger.info(f"Answer: {response['answer']}")
        logger.info(f"Sources: {response['sources']}")
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Onepoint HR Chatbot

This is a simple HR chatbot based on Chainlit, with conversation memory, which
chooses its vector database (UK or India) based on the user's geolocation.

## Installation

```
conda create -n langchain_chainlit python=3.11
conda activate langchain_chainlit
pip install langchain
pip install python-dotenv
pip install openai
pip install faiss-cpu
pip install tiktoken
pip install chainlit
pip install pdfminer
pip install pypdfium2
pip install prompt_toolkit
```

### Custom environment

```
# conda activate base
# conda remove -n langchain_chainlit_2 --all
conda create -n langchain_chainlit_2 python=3.11
conda activate langchain_chainlit_2
# pip install --force-reinstall /home/ubuntu/chainlit-0.5.3-py3-none-any.whl
pip install --force-reinstall C:\development\playground\chainlit\src\dist\chainlit-0.5.3-py3-none-any.whl
pip install langchain
pip install faiss-cpu
pip install tiktoken
pip install pdfminer
pip install pypdfium2
pip install black
```

## Configuration

Please make sure that you have a `.env` file with the following variables
(these are the names `config.py` actually reads):
```
OPENAI_API_KEY=
DOC_LOCATION_UK=
DOC_LOCATION_INDIA=
FAISS_STORE_UK=
FAISS_STORE_INDIA=
HUMOUR=
```
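
For illustration only, a filled-in `.env` might look like this (all values below are placeholders):
```
OPENAI_API_KEY=sk-...
DOC_LOCATION_UK=/data/hr_docs/uk
DOC_LOCATION_INDIA=/data/hr_docs/india
FAISS_STORE_UK=/data/faiss/uk
FAISS_STORE_INDIA=/data/faiss/india
HUMOUR=true
```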

## Running

With Chainlit:
```
chainlit run hr_chatbot_chainlit.py --port 8081
```

For development (with auto-reload):
```
chainlit run hr_chatbot_chainlit.py -w --port 8081
```

Command line:
```
python ./hr_chatbot_cli.py
```
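
The CLI also accepts a `humor` argument, which switches to the joke-telling prompt (see `init_chain` in `hr_chatbot_cli.py`):
```
python ./hr_chatbot_cli.py humor
```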
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
from pathlib import Path
import os

from dotenv import load_dotenv
from langchain.chat_models import ChatOpenAI
from langchain.embeddings.openai import OpenAIEmbeddings

load_dotenv()


class Config:
    faiss_persist_directory_uk = Path(os.environ["FAISS_STORE_UK"])
    faiss_persist_directory_india = Path(os.environ["FAISS_STORE_INDIA"])
    faiss_dirs = [faiss_persist_directory_uk, faiss_persist_directory_india]
    for d in faiss_dirs:
        if not d.exists():
            d.mkdir()

    doc_location_uk = Path(os.environ["DOC_LOCATION_UK"])
    doc_location_india = Path(os.environ["DOC_LOCATION_INDIA"])
    doc_locations = [doc_location_uk, doc_location_india]

    # Maps ISO country codes to the vector store and document folder to use.
    location_persistence_map = {
        "GB": {
            "faiss_persist_directory": faiss_persist_directory_uk,
            "doc_location": doc_location_uk,
        },
        "IN": {
            "faiss_persist_directory": faiss_persist_directory_india,
            "doc_location": doc_location_india,
        },
    }

    for location in doc_locations:
        if not location.exists():
            raise FileNotFoundError(f"Document location not found: {location}")

    embeddings = OpenAIEmbeddings(chunk_size=100)
    model = "gpt-3.5-turbo-16k"
    # model = "gpt-4"
    llm = ChatOpenAI(model=model, temperature=0)
    search_results = 5

    def __repr__(self) -> str:
        return f"""# Configuration
faiss_persist_directories: {self.faiss_dirs}
doc_locations: {self.doc_locations}

embeddings: {self.embeddings}

llm: {self.llm}
"""


cfg = Config()

if __name__ == "__main__":
    print(cfg)
--------------------------------------------------------------------------------
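The country-to-store map above is consulted by `create_pdf` in `hr_chatbot_chainlit.py` and by `load_all_chains` in `chain_factory.py`. A minimal sketch of the same lookup pattern ("GB" is just an example code):

```python
from config import cfg

# Same lookup pattern as create_pdf() and load_all_chains():
country_config = cfg.location_persistence_map.get("GB")
if country_config:
    print(country_config["faiss_persist_directory"])
    print(country_config["doc_location"])
```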
/generate_embeddings.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import List, TypeVar
import re

import numpy as np
from dotenv import load_dotenv
from langchain.document_loaders import PyPDFium2Loader
from langchain.schema import Document
from langchain.vectorstores import FAISS

from config import cfg
from log_init import logger

load_dotenv()

VST = TypeVar("VST", bound="VectorStore")


def load_pdfs(path: Path) -> List[Document]:
    """
    Loads the PDFs and extracts one document per page.
    The page details are added to the extracted metadata.

    Parameters:
        path (Path): The path where the PDFs are saved.

    Returns:
        List[Document]: A list of documents, one per page.
    """
    assert path.exists()
    all_pages = []
    for pdf in path.glob("*.pdf"):
        loader = PyPDFium2Loader(str(pdf.absolute()))
        pages: List[Document] = loader.load_and_split()
        for i, p in enumerate(pages):
            # Replace the full path in the metadata with "<file name> page <n>".
            file_name = re.sub(r".+[\\/]", "", p.metadata["source"])
            p.metadata["source"] = f"{file_name} page {i + 1}"
        all_pages.extend(pages)
        logger.info(f"Processed {pdf}, all_pages size: {len(all_pages)}")
    log_stats(all_pages)
    return all_pages


def log_stats(documents: List[Document]):
    logger.info(f"Total number of documents {len(documents)}")
    counts = [count_words(d) for d in documents]
    logger.info(f"Words max {np.max(counts)}")
    logger.info(f"Words min {np.min(counts)}")
    logger.info(f"Words mean {np.mean(counts)}")


def generate_embeddings(
    documents: List[Document], path: Path, faiss_persist_directory: str
) -> VST:
    """
    Receives a list of documents, generates the embeddings via the OpenAI API
    and persists the resulting FAISS index.

    Parameters:
        documents (List[Document]): The document list with one page per document.
        path (Path): The path where the documents are found.
        faiss_persist_directory (str): The directory where the index is persisted.

    Returns:
        VST: A reference to the vector store.
    """
    try:
        docsearch = FAISS.from_documents(documents, cfg.embeddings)
        docsearch.save_local(faiss_persist_directory)
        logger.info("Vector database persisted")
        return docsearch
    except Exception as e:
        logger.error(f"Failed to process {path}: {str(e)}")
        raise


def count_words(document: Document) -> int:
    splits = [s for s in re.split(r"[\s,.]", document.page_content) if len(s) > 0]
    return len(splits)


if __name__ == "__main__":
    documents = load_pdfs(cfg.doc_location_uk)
    assert len(documents) > 0
    logger.info(documents[2].page_content)
    generate_embeddings(
        documents, cfg.doc_location_uk, str(cfg.faiss_persist_directory_uk)
    )
--------------------------------------------------------------------------------
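As a quick smoke test for a persisted index, something along these lines should work (a sketch, assuming the UK store has already been generated; the query string is arbitrary):

```python
from langchain.vectorstores import FAISS

from config import cfg

# Load the persisted index with the same embeddings used to build it.
docsearch = FAISS.load_local(str(cfg.faiss_persist_directory_uk), cfg.embeddings)
for doc in docsearch.similarity_search("annual leave policy", k=3):
    print(doc.metadata["source"])
```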
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
--------------------------------------------------------------------------------
/hr_chatbot_chainlit.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import Dict, Optional

import chainlit as cl
from chainlit.context import get_emitter
from langchain.chains import RetrievalQAWithSourcesChain

from chain_factory import load_all_chains
from config import cfg
from geolocation import extract_ip_address, geolocate
from log_init import logger
from source_splitter import source_splitter

KEY_META_DATAS = "metadatas"
KEY_TEXTS = "texts"
KEY_GEOLOCATION_COUNTRY_CODE = "geolocation_country_code"


def set_session_vars(user_session_dict: Dict):
    for k, v in user_session_dict.items():
        cl.user_session.set(k, v)


def create_pdf(pdf_name: str, pdf_path: str) -> Optional[cl.File]:
    """
    Creates a file download button for a PDF file in case it is found.

    Parameters:
        pdf_name (str): The file name to display.
        pdf_path (str): The file path relative to the document folder.

    Returns:
        Optional[cl.File]: The file element, or None if the PDF was not found.
    """
    logger.info(f"Creating pdf for {pdf_path}")
    # Sending a pdf with the local file path
    country_code = cl.user_session.get(KEY_GEOLOCATION_COUNTRY_CODE)
    country_config = cfg.location_persistence_map.get(country_code)
    if country_config:
        logger.info("country_config found")
        doc_location: Path = country_config.get("doc_location")
        doc_path = doc_location / pdf_path
        if doc_path.exists():
            logger.info("Creating pdf component")
            return cl.File(
                name=pdf_name, display="inline", path=str(doc_path.absolute())
            )
        else:
            logger.info(f"doc path {doc_path} does not exist.")
    return None


@cl.langchain_factory(use_async=True)
async def init():
    """
    Loads the vector data store object and the PDF documents and creates the QA chain.
    Geolocates the user to pick the right vector store and sets up some session variables.

    Returns:
        RetrievalQAWithSourcesChain: The QA chain
    """

    emitter = get_emitter()
    # Please note this works only with a modified version of Chainlit.
    # The repo with this modification is here: https://github.com/gilfernandes/chainlit_hr_extension

    country_code = "GB"
    geolocation_failed = False

    try:
        remote_address = extract_ip_address(emitter.session.environ)
        geo_location = geolocate(remote_address)

        if geo_location.country_code != "Not found":
            country_code = geo_location.country_code
        # await display_location_details(geo_location, country_code)
    except Exception:
        logger.exception("Could not locate properly")
        geolocation_failed = True

    if geolocation_failed:
        await cl.Message(content="Geolocation failed ... I do not know where you are.").send()
    else:
        logger.info(f"Geo location: {geo_location}")

    msg = cl.Message(content="Processing files. Please wait.")
    await msg.send()
    chain_dict = load_all_chains(country_code)
    qa_data = chain_dict[country_code]

    documents = qa_data.documents

    chain: RetrievalQAWithSourcesChain = qa_data.chain
    metadatas = [d.metadata for d in documents]
    texts = [d.page_content for d in documents]

    set_session_vars(
        {
            KEY_META_DATAS: metadatas,
            KEY_TEXTS: texts,
            KEY_GEOLOCATION_COUNTRY_CODE: country_code,
        }
    )

    msg.content = f"You can now ask questions about Onepoint HR ({country_code})!"
    await msg.send()

    return chain


async def display_location_details(geo_location, country_code):
    geo_location_msg = cl.Message(
        content=f"""Geo location:
- country: {geo_location.country_name}
- country code: {country_code}"""
    )
    await geo_location_msg.send()


@cl.langchain_postprocess
async def process_response(res) -> cl.Message:
    """
    Tries to extract the sources and corresponding texts from the sources.

    Parameters:
        res (dict): A dictionary with the answer and sources provided by the LLM via LangChain.

    Returns:
        cl.Message: The message containing the answer and the list of sources with corresponding texts.
    """
    answer = res["answer"]
    sources = res["sources"].strip()
    source_elements = []

    # Get the metadata and texts from the user session
    metadatas = cl.user_session.get(KEY_META_DATAS)
    all_sources = [m["source"] for m in metadatas]
    texts = cl.user_session.get(KEY_TEXTS)

    found_sources = []
    pdf_elements = []
    if sources:
        logger.info(f"sources: {sources}")
        raw_sources, file_sources = source_splitter(sources)
        for i, source in enumerate(raw_sources):
            try:
                source_name = file_sources[i]
                pdf_element = create_pdf(source_name, source_name)
                if pdf_element:
                    pdf_elements.append(pdf_element)
                    logger.info(f"PDF Elements: {pdf_elements}")
                else:
                    logger.warning(f"No pdf element for {source_name}")

                index = all_sources.index(source)
                text = texts[index]
                found_sources.append(source)
                # Create the text element referenced in the message
                logger.info(f"Found text in {source_name}")
                source_elements.append(cl.Text(content=text, name=source_name))
            except ValueError as e:
                logger.error(f"Value error {e}")
                continue
        if found_sources:
            answer += f"\nSources: {', '.join(found_sources)}"
        else:
            answer += f"\n{sources}"

    logger.info(f"PDF Elements: {pdf_elements}")
    await cl.Message(content=answer, elements=source_elements).send()
    await cl.Message(content="PDF Downloads", elements=pdf_elements).send()


if __name__ == "__main__":
    pass
--------------------------------------------------------------------------------
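For reference, `process_response` reads the chain's two output keys, so the incoming dict looks roughly like this (a hypothetical example; the values are made up):

```python
# Shape of the dict produced by RetrievalQAWithSourcesChain and handed
# to process_response (illustrative values only):
res = {
    "answer": "Employees are entitled to 25 days of annual leave.",
    "sources": "04.13 HR Policies & Procedures V10.docx .pdf page 1",
}
```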
/chain_factory.py:
--------------------------------------------------------------------------------
import os
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, TypeVar

from langchain.chains import RetrievalQAWithSourcesChain
from langchain.memory import ConversationSummaryBufferMemory
from langchain.memory.utils import get_prompt_input_key
from langchain.prompts import PromptTemplate
from langchain.schema import Document
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStoreRetriever, VectorStore

from config import cfg
from generate_embeddings import load_pdfs, generate_embeddings
from hr_model import QAData
from log_init import logger

VST = TypeVar("VST", bound="VectorStore")


class KeySourceMemory(ConversationSummaryBufferMemory):
    """Memory that defaults to the "answer" output key, since
    RetrievalQAWithSourcesChain returns both "answer" and "sources"."""

    def _get_input_output(
        self, inputs: Dict[str, Any], outputs: Dict[str, str]
    ) -> Tuple[str, str]:
        if self.input_key is None:
            prompt_input_key = get_prompt_input_key(inputs, self.memory_variables)
        else:
            prompt_input_key = self.input_key
        if self.output_key is None:
            output_key = "answer"
        else:
            output_key = self.output_key
        return inputs[prompt_input_key], outputs[output_key]


def load_embeddings(
    embedding_dir: Path = cfg.faiss_persist_directory_uk,
    doc_location: Path = cfg.doc_location_uk,
) -> Tuple[VST, List[Document]]:
    """
    Loads the PDF documents to support text extraction in the Chainlit UI.
    In case there are no persisted embeddings, the embeddings are generated.
    In case the embeddings are persisted, then they are loaded from the file system.

    Returns:
        Tuple[VST, List[Document]]: Returns a reference to the vector store and the list of all PDF documents.
    """
    logger.info(f"Checking: {embedding_dir}")
    documents = load_pdfs(doc_location)
    assert len(documents) > 0
    if embedding_dir.exists() and len(list(embedding_dir.glob("*"))) > 0:
        logger.info(f"reading from existing directory: {embedding_dir}")
        docsearch = FAISS.load_local(str(embedding_dir), cfg.embeddings)
        return docsearch, documents
    return (
        generate_embeddings(documents, doc_location, str(embedding_dir.absolute())),
        documents,
    )


template = """Given the following extracted parts of a long document and a question, create a final answer with references ("SOURCES"). If you know a joke about the subject, make sure that you include it in the response.
If you don't know the answer, say that you don't know and make up some joke about the subject. Don't try to make up an answer.
ALWAYS return a "SOURCES" part in your answer.

QUESTION: Which state/country's law governs the interpretation of the contract?
=========
Content: This Agreement is governed by English law and the parties submit to the exclusive jurisdiction of the English courts in relation to any dispute (contractual or non-contractual) concerning this Agreement save that either party may apply to any court for an injunction or other relief to protect its Intellectual Property Rights.
Source: 28-pl
Content: No Waiver. Failure or delay in exercising any right or remedy under this Agreement shall not constitute a waiver of such (or any other) right or remedy.\n\n11.7 Severability. The invalidity, illegality or unenforceability of any term (or part of a term) of this Agreement shall not affect the continuation in force of the remainder of the term (if any) and this Agreement.\n\n11.8 No Agency. Except as expressly stated otherwise, nothing in this Agreement shall create an agency, partnership or joint venture of any kind between the parties.\n\n11.9 No Third-Party Beneficiaries.
Source: 30-pl
Content: (b) if Google believes, in good faith, that the Distributor has violated or caused Google to violate any Anti-Bribery Laws (as defined in Clause 8.5) or that such a violation is reasonably likely to occur,
Source: 4-pl
=========
FINAL ANSWER: This Agreement is governed by English law.
SOURCES: 28-pl

QUESTION: What did the president say about Michael Jackson?
=========
Content: Madam Speaker, Madam Vice President, our First Lady and Second Gentleman. Members of Congress and the Cabinet. Justices of the Supreme Court. My fellow Americans. \n\nLast year COVID-19 kept us apart. This year we are finally together again. \n\nTonight, we meet as Democrats Republicans and Independents. But most importantly as Americans. \n\nWith a duty to one another to the American people to the Constitution. \n\nAnd with an unwavering resolve that freedom will always triumph over tyranny. \n\nSix days ago, Russia’s Vladimir Putin sought to shake the foundations of the free world thinking he could make it bend to his menacing ways. But he badly miscalculated. \n\nHe thought he could roll into Ukraine and the world would roll over. Instead he met a wall of strength he never imagined. \n\nHe met the Ukrainian people. \n\nFrom President Zelenskyy to every Ukrainian, their fearlessness, their courage, their determination, inspires the world. \n\nGroups of citizens blocking tanks with their bodies. Everyone from students to retirees teachers turned soldiers defending their homeland.
Source: 0-pl
Content: And we won’t stop. \n\nWe have lost so much to COVID-19. Time with one another. And worst of all, so much loss of life. \n\nLet’s use this moment to reset. Let’s stop looking at COVID-19 as a partisan dividing line and see it for what it is: A God-awful disease. \n\nLet’s stop seeing each other as enemies, and start seeing each other for who we really are: Fellow Americans. \n\nWe can’t change how divided we’ve been. But we can change how we move forward—on COVID-19 and other issues we must face together. \n\nI recently visited the New York City Police Department days after the funerals of Officer Wilbert Mora and his partner, Officer Jason Rivera. \n\nThey were responding to a 9-1-1 call when a man shot and killed them with a stolen gun. \n\nOfficer Mora was 27 years old. \n\nOfficer Rivera was 22. \n\nBoth Dominican Americans who’d grown up on the same streets they later chose to patrol as police officers. \n\nI spoke with their families and told them that we are forever in debt for their sacrifice, and we will carry on their mission to restore the trust and safety every community deserves.
Source: 24-pl
Content: And a proud Ukrainian people, who have known 30 years of independence, have repeatedly shown that they will not tolerate anyone who tries to take their country backwards. \n\nTo all Americans, I will be honest with you, as I’ve always promised. A Russian dictator, invading a foreign country, has costs around the world. \n\nAnd I’m taking robust action to make sure the pain of our sanctions is targeted at Russia’s economy. And I will use every tool at our disposal to protect American businesses and consumers. \n\nTonight, I can announce that the United States has worked with 30 other countries to release 60 Million barrels of oil from reserves around the world. \n\nAmerica will lead that effort, releasing 30 Million barrels from our own Strategic Petroleum Reserve. And we stand ready to do more if necessary, unified with our allies. \n\nThese steps will help blunt gas prices here at home. And I know the news about what’s happening can seem alarming. \n\nBut I want you to know that we are going to be okay.
Source: 5-pl
Content: More support for patients and families. \n\nTo get there, I call on Congress to fund ARPA-H, the Advanced Research Projects Agency for Health. \n\nIt’s based on DARPA—the Defense Department project that led to the Internet, GPS, and so much more. \n\nARPA-H will have a singular purpose—to drive breakthroughs in cancer, Alzheimer’s, diabetes, and more. \n\nA unity agenda for the nation. \n\nWe can do this. \n\nMy fellow Americans—tonight , we have gathered in a sacred space—the citadel of our democracy. \n\nIn this Capitol, generation after generation, Americans have debated great questions amid great strife, and have done great things. \n\nWe have fought for freedom, expanded liberty, defeated totalitarianism and terror. \n\nAnd built the strongest, freest, and most prosperous nation the world has ever known. \n\nNow is the hour. \n\nOur moment of responsibility. \n\nOur test of resolve and conscience, of history itself. \n\nIt is in this moment that our character is formed. Our purpose is found. Our future is forged. \n\nWell I know this nation.
Source: 34-pl
=========
FINAL ANSWER: The president did not mention Michael Jackson. And here is a joke about Michael Jackson: Why did Michael Jackson go to the bakery? Because he wanted to "beat it" and grab some "moon-pies"!
SOURCES:

QUESTION: {question}
=========
{summaries}
=========
FINAL ANSWER:"""
HUMOUR_PROMPT = PromptTemplate(
    template=template, input_variables=["summaries", "question"]
)


def create_retrieval_chain(
    docsearch: VST, verbose: bool = False, humour: bool = True
) -> RetrievalQAWithSourcesChain:
    """
    Creates the QA chain with memory. In case the humour parameter is true,
    a manipulated prompt - one that tends to work jokes into the answers - is used.

    Parameters:
        docsearch (VST): A reference to the vector store.
        verbose (bool): Determines whether LangChain's internal logging is printed to the console or not.
        humour (bool): Determines whether the prompt for answers with jokes is used or not.

    Returns:
        RetrievalQAWithSourcesChain: The QA chain
    """
    memory = KeySourceMemory(llm=cfg.llm, input_key="question", output_key="answer")
    chain_type_kwargs = {}
    if verbose:
        chain_type_kwargs["verbose"] = True
    if humour:
        chain_type_kwargs["prompt"] = HUMOUR_PROMPT
    search_retriever: VectorStoreRetriever = docsearch.as_retriever()
    search_retriever.search_kwargs = {"k": cfg.search_results}
    qa_chain = RetrievalQAWithSourcesChain.from_chain_type(
        cfg.llm,
        retriever=search_retriever,
        chain_type="stuff",
        memory=memory,
        chain_type_kwargs=chain_type_kwargs,
    )

    return qa_chain


def load_all_chains(country_filter: Optional[str] = None) -> Dict[str, QAData]:
    res = {}
    for country, v in cfg.location_persistence_map.items():
        if country_filter is None or country_filter == country:
            faiss_persist_directory = v["faiss_persist_directory"]
            doc_location = v["doc_location"]
            vst, documents = load_embeddings(faiss_persist_directory, doc_location)
            chain = create_retrieval_chain(vst, humour=os.getenv("HUMOUR") == "true")
            res[country] = QAData(vst=vst, documents=documents, chain=chain)
    return res


if __name__ == "__main__":
    chain_dict = load_all_chains()
    logger.info(len(chain_dict.items()))
--------------------------------------------------------------------------------
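Putting the pieces together, a minimal programmatic use of the chain mirrors what `hr_chatbot_cli.py` does (assuming the documents and `.env` described in the README are in place):

```python
from chain_factory import create_retrieval_chain, load_embeddings

# Load (or build) the UK vector store and create the QA chain without humour.
docsearch, documents = load_embeddings()
chain = create_retrieval_chain(docsearch, humour=False)

response = chain({"question": "How many days of annual leave do I get?"})
print(response["answer"])
print(response["sources"])
```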