├── .gitignore
├── LICENSE
├── README.md
├── WikipediaQA.py
├── WikipediaQA_batch_runs.ipynb
├── constants.py
├── gradio_app.py
├── images
│   └── temp.txt
├── requirements.txt
└── template.env
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Jou-ching Sung
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Document Q&A on Wikipedia articles
2 | Run [document Q&A](https://python.langchain.com/en/latest/use_cases/question_answering.html) on Wikipedia articles. Use [Wikipedia-API](https://pypi.org/project/Wikipedia-API/) to search/retrieve/beautify Wikipedia articles, [LangChain](https://python.langchain.com/en/latest/index.html) for the Q&A framework, and OpenAI & [HuggingFace](https://huggingface.co/) models for embeddings and LLMs. The meat of the code is in `WikipediaQA.py`.
3 |
4 | For the accompanying blog post, see [https://georgesung.github.io/ai/llm-qa-eval-wikipedia/](https://georgesung.github.io/ai/llm-qa-eval-wikipedia/).
5 |
6 | ## Architecture
7 | **Search and index Wikipedia article**
8 |
9 | 
10 |
11 | **Q&A on article**
12 |
13 | 
14 | ## Instructions
15 | ### Batch runs
16 | For a batch run over different LLMs and embedding models, you can run the notebook `WikipediaQA_batch_runs.ipynb` in your own compute instance, or open the same notebook in Colab.
17 |
22 | ### Interactive app
23 | To run an interactive Gradio app, do the following:
24 | * `pip install -r requirements.txt`
25 | * If you're using OpenAI ada embeddings and/or GPT 3.5, then `cp template.env .env`, and edit `.env` to include your OpenAI API key
26 | * `python gradio_app.py`
27 |
28 | ## Results with different LLMs and embeddings
29 | For detailed results and analysis, see the full blog post [here](https://georgesung.github.io/ai/llm-qa-eval-wikipedia/).
30 |
--------------------------------------------------------------------------------
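
A minimal sketch of driving the `WikipediaQA` class directly, e.g. for scripted runs outside the notebook. The config keys and method names are taken from `WikipediaQA.py` below; the search query and question are illustrative, and the local models assume a CUDA GPU:

```python
from constants import EMB_INSTRUCTOR_XL, LLM_FLAN_T5_XL
from WikipediaQA import WikipediaQA

# Config keys mirror those read by init_models() and get_answer()
qa = WikipediaQA({
    "embedding": EMB_INSTRUCTOR_XL,
    "llm": LLM_FLAN_T5_XL,
    "load_in_8bit": False,
    "question_check": True,
})
qa.init_models()                                      # load embedding model and LLM
title, text = qa.search_and_read_page("Alan Turing")  # index the top search hit
print(title)
print(qa.get_answer("Where was Alan Turing born?"))
```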
/WikipediaQA.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 |
5 | import requests
6 | import wikipediaapi
7 | from InstructorEmbedding import INSTRUCTOR
8 | from langchain import HuggingFacePipeline
9 | from langchain.chains import RetrievalQA
10 | from langchain.docstore.document import Document
11 | from langchain.embeddings import HuggingFaceInstructEmbeddings
12 | from langchain.embeddings.openai import OpenAIEmbeddings
13 | from langchain.llms import OpenAI
14 | from langchain.prompts import PromptTemplate
15 | from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
16 | from langchain.vectorstores import Chroma
17 | from transformers import pipeline
18 |
19 | from constants import *
20 |
21 | class WikipediaQA:
22 | question_check_template = """Given the following pieces of context, determine whether the question can be answered by the information in the context.
23 | Respond with 'yes' or 'no'.
24 | {context}
25 | Question: {question}
26 | """
27 | QUESTION_CHECK_PROMPT = PromptTemplate(
28 | template=question_check_template, input_variables=["context", "question"]
29 | )
30 | def __init__(self, config: dict = None):
31 | self.config = config if config is not None else {}  # avoid a shared mutable default dict
32 | self.embedding = None
33 | self.vectordb = None
34 | self.llm = None
35 | self.qa = None
36 |
37 | # The following class methods are useful to create global GPU model instances
38 | # This way we don't need to reload models in an interactive app,
39 | # and the same model instance can be used across multiple user sessions
40 | @classmethod
41 | def create_instructor_xl(cls):
42 | return HuggingFaceInstructEmbeddings(model_name=EMB_INSTRUCTOR_XL, model_kwargs={"device": "cuda"})
43 |
44 | @classmethod
45 | def create_flan_t5_xxl(cls, load_in_8bit=False):
46 | # Local flan-t5-xxl with 8-bit quantization for inference
47 | # Wrap it in HF pipeline for use with LangChain
48 | return pipeline(
49 | task="text2text-generation",
50 | model="google/flan-t5-xxl",
51 | model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
52 | )
53 |
54 | @classmethod
55 | def create_flan_t5_xl(cls, load_in_8bit=False):
56 | return pipeline(
57 | task="text2text-generation",
58 | model="google/flan-t5-xl",
59 | model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
60 | )
61 |
62 | @classmethod
63 | def create_fastchat_t5_xl(cls, load_in_8bit=False):
64 | return pipeline(
65 | task="text2text-generation",
66 | model="lmsys/fastchat-t5-3b-v1.0",
67 | model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.}
68 | )
69 |
70 | def init_models(self) -> None:
71 | """ Initialize new models based on config """
72 | load_in_8bit = self.config["load_in_8bit"]
73 |
74 | if self.config["embedding"] == EMB_OPENAI_ADA:
75 | # OpenAI ada embeddings API
76 | self.embedding = OpenAIEmbeddings()
77 | elif self.config["embedding"] == EMB_INSTRUCTOR_XL:
78 | # Local INSTRUCTOR-XL embeddings
79 | if self.embedding is None:
80 | self.embedding = WikipediaQA.create_instructor_xl()
81 | else:
82 | raise ValueError("Invalid config")
83 |
84 | if self.config["llm"] == LLM_OPENAI_GPT35:
85 | # OpenAI GPT 3.5 API
86 | pass
87 | elif self.config["llm"] == LLM_FLAN_T5_XL:
88 | if self.llm is None:
89 | self.llm = WikipediaQA.create_flan_t5_xl(load_in_8bit=load_in_8bit)
90 | elif self.config["llm"] == LLM_FLAN_T5_XXL:
91 | if self.llm is None:
92 | self.llm = WikipediaQA.create_flan_t5_xxl(load_in_8bit=load_in_8bit)
93 | elif self.config["llm"] == LLM_FASTCHAT_T5_XL:
94 | if self.llm is None:
95 | self.llm = WikipediaQA.create_fastchat_t5_xl(load_in_8bit=load_in_8bit)
96 | else:
97 | raise ValueError("Invalid config")
98 |
99 | def search_and_read_page(self, search_query: str) -> tuple[str, str]:
100 | """
101 | Searches Wikipedia for the given query and takes the first result,
102 | then chunks its text and indexes it into a vector store.
103 |
104 | Returns the title and text of the page
105 | """
106 | # Search Wikipedia and get first result
107 | wiki_wiki = wikipediaapi.Wikipedia('en')
108 | docs = {}
109 | search_params = {"action": "query", "format": "json", "list": "search", "srsearch": search_query}
110 | search_response = requests.get("https://en.wikipedia.org/w/api.php", params=search_params).json()  # params handles URL-encoding
111 | wiki_title = search_response["query"]["search"][0]["title"]
112 | wiki_text = wiki_wiki.page(wiki_title).text
113 | docs[wiki_title] = wiki_text
114 |
115 | # Create new vector store and index it
116 | self.vectordb = None
117 | documents = [Document(page_content=docs[title]) for title in docs]
118 |
119 | # Split by section (blank lines), then split by token limit
120 | text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0)
121 | texts = text_splitter.split_documents(documents)
122 | text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=10, encoding_name="cl100k_base") # may be inexact
123 | texts = text_splitter.split_documents(texts)
124 |
125 | self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding)
126 |
127 | # Create the LangChain chain
128 | if self.config["llm"] == LLM_OPENAI_GPT35:
129 | # Use ChatGPT API
130 | self.qa = RetrievalQA.from_chain_type(llm=OpenAI(model_name=LLM_OPENAI_GPT35, temperature=0.), chain_type="stuff",\
131 | retriever=self.vectordb.as_retriever(search_kwargs={"k":4}))
132 | else:
133 | # Use local LLM
134 | hf_llm = HuggingFacePipeline(pipeline=self.llm)
135 | self.qa = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",\
136 | retriever=self.vectordb.as_retriever(search_kwargs={"k":4}))
137 | if self.config["question_check"]:
138 | self.q_check = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",\
139 | retriever=self.vectordb.as_retriever(search_kwargs={"k":4}))
140 | self.q_check.combine_documents_chain.llm_chain.prompt = WikipediaQA.QUESTION_CHECK_PROMPT
141 |
142 | return wiki_title, wiki_text
143 |
144 | def get_answer(self, question: str) -> str:
145 | if self.config["llm"] != LLM_OPENAI_GPT35 and self.config["question_check"]:
146 | # For local LLMs, do a self-check to see if question can be answered
147 | # If unanswerable, respond with "I don't know"
148 | answerable = self.q_check.run(question)
149 | if self.config["llm"] == LLM_FASTCHAT_T5_XL:
150 | answerable = self._clean_fastchat_t5_output(answerable)
151 | if answerable != "yes":
152 | return "I don't know"
153 |
154 | # Answer the question
155 | answer = self.qa.run(question)
156 | if self.config["llm"] == LLM_FASTCHAT_T5_XL:
157 | answer = self._clean_fastchat_t5_output(answer)
158 | return answer
159 |
160 | def _clean_fastchat_t5_output(self, answer: str) -> str:
161 | # Remove <pad> tags, double spaces, trailing newline
162 | answer = re.sub(r"<pad>\s+", "", answer)
163 | answer = re.sub(r"  ", " ", answer)
164 | answer = re.sub(r"\n$", "", answer)
165 | return answer
166 |
--------------------------------------------------------------------------------
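
The two-stage chunking inside `search_and_read_page` (a paragraph-level `CharacterTextSplitter`, then a `TokenTextSplitter` with the `cl100k_base` encoding) can be exercised in isolation. A minimal sketch; the sample text is illustrative:

```python
from langchain.docstore.document import Document
from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter

docs = [Document(page_content="Intro paragraph.\n\nA much longer section of article text...")]

# Stage 1: split on blank lines (CharacterTextSplitter's default separator is "\n\n");
# it warns if a single section exceeds chunk_size but keeps it intact
texts = CharacterTextSplitter(chunk_size=100, chunk_overlap=0).split_documents(docs)

# Stage 2: re-split any chunk still over ~1000 tokens, counted with cl100k_base
texts = TokenTextSplitter(chunk_size=1000, chunk_overlap=10,
                          encoding_name="cl100k_base").split_documents(texts)

for t in texts:
    print(len(t.page_content), repr(t.page_content[:40]))
```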
/constants.py:
--------------------------------------------------------------------------------
1 | EMB_OPENAI_ADA = "text-embedding-ada-002"
2 | EMB_INSTRUCTOR_XL = "hkunlp/instructor-xl"
3 |
4 | LLM_OPENAI_GPT35 = "gpt-3.5-turbo"
5 | LLM_FLAN_T5_XXL = "google/flan-t5-xxl"
6 | LLM_FLAN_T5_XL = "google/flan-t5-xl"
7 | LLM_FASTCHAT_T5_XL = "lmsys/fastchat-t5-3b-v1.0"
8 |
--------------------------------------------------------------------------------
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import time
4 |
5 | import wikipediaapi
6 | from InstructorEmbedding import INSTRUCTOR
7 | from langchain import HuggingFacePipeline
8 | from langchain.chains import RetrievalQA
9 | from langchain.docstore.document import Document
10 | from langchain.embeddings import HuggingFaceInstructEmbeddings
11 | from langchain.embeddings.openai import OpenAIEmbeddings
12 | from langchain.llms import OpenAI
13 | from langchain.prompts import PromptTemplate
14 | from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter
15 | from langchain.vectorstores import Chroma
16 | from transformers import pipeline
17 |
18 | from constants import *
19 | from WikipediaQA import WikipediaQA
20 |
21 | from dotenv import load_dotenv
22 | import gradio as gr
23 |
24 | # Load OpenAI API key
25 | load_dotenv()
26 |
27 | # Global model instances
28 | instructor_xl = None
29 | flan_t5_xl = None
30 | flan_t5_xxl = None
31 | fastchat_t5_xl = None
32 |
33 | with gr.Blocks() as qa_app:
34 | # Initialization
35 | qa = gr.State(WikipediaQA({"question_check": True, "load_in_8bit": False}))
36 |
37 | # Layout
38 | gr.Markdown("""# Wikipedia Q&A
39 | Use OpenAI and/or local models for embeddings/LLM
40 | """)
41 |
42 | with gr.Tab("Model setup"):
43 | gr.Markdown("""Select embedding and LLM models
44 | OpenAI models require an API key; duplicate this space and use [secrets](https://huggingface.co/docs/hub/spaces-overview#managing-secrets)
45 | """)
46 | with gr.Row() as row:
47 | with gr.Column():
48 | emb_radio = gr.Radio([EMB_OPENAI_ADA, EMB_INSTRUCTOR_XL],
49 | label="Select embedding model")
50 | llm_radio = gr.Radio([LLM_OPENAI_GPT35, LLM_FLAN_T5_XL, LLM_FLAN_T5_XXL, LLM_FASTCHAT_T5_XL],
51 | label="Select LLM",
52 | info="Note: flan-t5-xxl will run out of memory on a single A10G")
53 | with gr.Column():
54 | model_text_box = gr.Textbox(label="Current models")
55 | model_load_btn = gr.Button("Load models")
56 |
57 | with gr.Tab("Read Wikipedia"):
58 | gr.Markdown("""Search Wikipedia and get the first result
59 | Chunk the article and index it into a local vector store
60 | """)
61 | with gr.Row() as row:
62 | with gr.Column():
63 | query = gr.Textbox(label="Wikipedia search query")
64 | search_btn = gr.Button("Search")
65 | with gr.Column():
66 | wiki_title_box = gr.Textbox(label="Article title")
67 | wiki_text_box = gr.Textbox(label="Article text")
68 |
69 | with gr.Tab("Q&A"):
70 | gr.Markdown("""Ask a question about the Wikipedia article""")
71 | with gr.Row() as row:
72 | with gr.Column():
73 | question = gr.Textbox(label="Enter your question")
74 | question_btn = gr.Button("Ask")
75 | with gr.Column():
76 | answer_box = gr.Textbox(label="Answer")
77 |
78 | # Logic
79 | # Model setup
80 | def load_model(emb, llm, qa):
81 | global instructor_xl
82 | global flan_t5_xl
83 | global flan_t5_xxl
84 | global fastchat_t5_xl
85 |
86 | if emb == EMB_OPENAI_ADA:
87 | qa.embedding = OpenAIEmbeddings()
88 | elif emb == EMB_INSTRUCTOR_XL:
89 | if instructor_xl is None:
90 | instructor_xl = WikipediaQA.create_instructor_xl()
91 | qa.embedding = instructor_xl
92 | else:
93 | raise ValueError("Invalid embedding setting")
94 | qa.config["embedding"] = emb
95 |
96 | if llm == LLM_OPENAI_GPT35:
97 | pass
98 | elif llm == LLM_FLAN_T5_XL:
99 | if flan_t5_xl is None:
100 | flan_t5_xl = WikipediaQA.create_flan_t5_xl()
101 | qa.llm = flan_t5_xl
102 | elif llm == LLM_FLAN_T5_XXL:
103 | if flan_t5_xxl is None:
104 | flan_t5_xxl = WikipediaQA.create_flan_t5_xxl()
105 | qa.llm = flan_t5_xxl
106 | elif llm == LLM_FASTCHAT_T5_XL:
107 | if fastchat_t5_xl is None:
108 | fastchat_t5_xl = WikipediaQA.create_fastchat_t5_xl()
109 | qa.llm = fastchat_t5_xl
110 | else:
111 | raise ValueError("Invalid LLM setting")
112 | qa.config["llm"] = llm
113 |
114 | return f"Embedding model: {emb}\nLLM: {llm}"
115 |
116 | model_load_btn.click(
117 | load_model,
118 | [emb_radio, llm_radio, qa],
119 | [model_text_box]
120 | )
121 |
122 | # Search Wikipedia
123 | def wiki_search(query, qa):
124 | wiki_title, wiki_text = qa.search_and_read_page(query)
125 | return {
126 | wiki_title_box: wiki_title,
127 | wiki_text_box: wiki_text,
128 | }
129 |
130 | search_btn.click(
131 | wiki_search,
132 | [query, qa],
133 | [wiki_title_box, wiki_text_box]
134 | )
135 |
136 | # Q&A
137 | def qa_fn(question, qa):
138 | answer = qa.get_answer(question)
139 | return {answer_box: answer}
140 |
141 | question_btn.click(
142 | qa_fn,
143 | [question, qa],
144 | [answer_box]
146 | )
147 | qa_app.launch(debug=True)
148 |
--------------------------------------------------------------------------------
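
A design note on `gradio_app.py`: heavyweight models are cached in module-level globals so each loads at most once per process, while `gr.State` holds a per-session `WikipediaQA` object that the event handlers mutate in place. A stripped-down sketch of that pattern; the names here are illustrative, not from the repo:

```python
import gradio as gr

shared_model = None  # process-wide cache, shared across all user sessions

def load(session):
    global shared_model
    if shared_model is None:
        shared_model = {"name": "expensive-model"}  # stand-in for a slow model load
    session["model"] = shared_model  # per-session state points at the shared instance
    return "loaded"

with gr.Blocks() as demo:
    session = gr.State({})  # a fresh dict per browser session
    status = gr.Textbox(label="Status")
    gr.Button("Load").click(load, [session], [status])

# demo.launch()
```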
/images/temp.txt:
--------------------------------------------------------------------------------
1 | just creating a new folder for images, delete this later
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers
2 | langchain
3 | accelerate
4 | bitsandbytes
5 | chromadb
6 | beautifulsoup4
7 | openai
8 | tiktoken
9 | sentence_transformers
10 | wikipedia-api==0.5.8
11 | InstructorEmbedding
12 | gradio
13 | python-dotenv
14 | urllib3==1.26.6
15 |
--------------------------------------------------------------------------------
/template.env:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=hello
2 |
--------------------------------------------------------------------------------
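
For reference, `gradio_app.py` picks this file up (after `cp template.env .env`) via python-dotenv. A minimal sketch of that load step:

```python
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory
print("OPENAI_API_KEY set:", os.getenv("OPENAI_API_KEY") is not None)
```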