├── .gitignore ├── LICENSE ├── README.md ├── WikipediaQA.py ├── WikipediaQA_batch_runs.ipynb ├── constants.py ├── gradio_app.py ├── images └── temp.txt ├── requirements.txt └── template.env /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Jou-ching Sung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Document Q&A on Wikipedia articles 2 | Run [document Q&A](https://python.langchain.com/en/latest/use_cases/question_answering.html) on Wikipedia articles. Use [Wikipedia-API](https://pypi.org/project/Wikipedia-API/) to search/retrieve/beautify Wikipedia articles, [LangChain](https://python.langchain.com/en/latest/index.html) for the Q&A framework, and OpenAI & [HuggingFace](https://huggingface.co/) models for embeddings and LLMs. The meat of the code is in `WikipediaQA.py`. 
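At a glance, the class is used roughly as follows. This is a minimal sketch based on the code in `WikipediaQA.py` and the constants in `constants.py`; the search query and question strings are only illustrative examples, and the local models assume a GPU is available.

```python
from WikipediaQA import WikipediaQA
from constants import EMB_INSTRUCTOR_XL, LLM_FLAN_T5_XL

# Pick an embedding model and LLM (see constants.py for the supported options)
qa = WikipediaQA({
    "embedding": EMB_INSTRUCTOR_XL,   # local INSTRUCTOR-XL embeddings
    "llm": LLM_FLAN_T5_XL,            # local flan-t5-xl
    "load_in_8bit": False,
    "question_check": True,           # guardrail: only answer if the context supports it
})
qa.init_models()

# Search Wikipedia, take the first hit, then chunk and index it into a Chroma vector store
title, text = qa.search_and_read_page("Apollo 11")  # example query

# Ask a question against the indexed article
print(qa.get_answer("When did Apollo 11 land on the Moon?"))  # example question
```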
3 | 4 | For the accompanying blog post, see [https://georgesung.github.io/ai/llm-qa-eval-wikipedia/](https://georgesung.github.io/ai/llm-qa-eval-wikipedia/) 5 | 6 | ## Architecture 7 | **Search and index Wikipedia article** 8 | 9 | ![arch](https://georgesung.github.io/assets/img/wikiqa_search.svg) 10 | 11 | **Q&A on article** 12 | 13 | ![arch](https://georgesung.github.io/assets/img/wikiqa_guardrail.svg) 14 | ## Instructions 15 | ### Batch runs 16 | For a batch run over different LLMs and embedding models, you can run the notebook `WikipediaQA_batch_runs.ipynb` in your own compute instance, or run the same notebook on Colab: 17 | 18 | 19 | Open In Colab 20 | 21 | 22 | ### Interactive app 23 | To run an interactive Gradio app, do the following: 24 | * `pip install -r requirements.txt` 25 | * If you're using OpenAI ada embeddings and/or GPT 3.5, then `cp template.env .env`, and edit `.env` to include your OpenAI API key 26 | * `python gradio_app.py` 27 | 28 | ## Results with different LLMs and embeddings 29 | For detailed results and analysis, see the full blog post [here](https://georgesung.github.io/ai/llm-qa-eval-wikipedia/) 30 | -------------------------------------------------------------------------------- /WikipediaQA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | 5 | import requests 6 | import wikipediaapi 7 | from InstructorEmbedding import INSTRUCTOR 8 | from langchain import HuggingFacePipeline 9 | from langchain.chains import RetrievalQA 10 | from langchain.docstore.document import Document 11 | from langchain.embeddings import HuggingFaceInstructEmbeddings 12 | from langchain.embeddings.openai import OpenAIEmbeddings 13 | from langchain.llms import OpenAI 14 | from langchain.prompts import PromptTemplate 15 | from langchain.text_splitter import CharacterTextSplitter, TokenTextSplitter 16 | from langchain.vectorstores import Chroma 17 | from transformers import pipeline 18 | 19 | from constants import * 20 | 21 | class WikipediaQA: 22 | question_check_template = """Given the following pieces of context, determine if the question is able to be answered by the information in the context. 23 | Respond with 'yes' or 'no'. 
24 | {context} 25 | Question: {question} 26 | """ 27 | QUESTION_CHECK_PROMPT = PromptTemplate( 28 | template=question_check_template, input_variables=["context", "question"] 29 | ) 30 | def __init__(self, config: dict={}): 31 | self.config = config 32 | self.embedding = None 33 | self.vectordb = None 34 | self.llm = None 35 | self.qa = None 36 | 37 | # The following class methods are useful to create global GPU model instances 38 | # This way we don't need to reload models in an interactive app, 39 | # and the same model instance can be used across multiple user sessions 40 | @classmethod 41 | def create_instructor_xl(cls): 42 | return HuggingFaceInstructEmbeddings(model_name=EMB_INSTRUCTOR_XL, model_kwargs={"device": "cuda"}) 43 | 44 | @classmethod 45 | def create_flan_t5_xxl(cls, load_in_8bit=False): 46 | # Local flan-t5-xxl with 8-bit quantization for inference 47 | # Wrap it in HF pipeline for use with LangChain 48 | return pipeline( 49 | task="text2text-generation", 50 | model="google/flan-t5-xxl", 51 | model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.} 52 | ) 53 | 54 | @classmethod 55 | def create_flan_t5_xl(cls, load_in_8bit=False): 56 | return pipeline( 57 | task="text2text-generation", 58 | model="google/flan-t5-xl", 59 | model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.} 60 | ) 61 | 62 | @classmethod 63 | def create_fastchat_t5_xl(cls, load_in_8bit=False): 64 | return pipeline( 65 | task="text2text-generation", 66 | model = "lmsys/fastchat-t5-3b-v1.0", 67 | model_kwargs={"device_map": "auto", "load_in_8bit": load_in_8bit, "max_length": 512, "temperature": 0.} 68 | ) 69 | 70 | def init_models(self) -> None: 71 | """ Initialize new models based on config """ 72 | load_in_8bit = self.config["load_in_8bit"] 73 | 74 | if self.config["embedding"] == EMB_OPENAI_ADA: 75 | # OpenAI ada embeddings API 76 | self.embedding = OpenAIEmbeddings() 77 | elif self.config["embedding"] == EMB_INSTRUCTOR_XL: 78 | # Local INSTRUCTOR-XL embeddings 79 | if self.embedding is None: 80 | self.embedding = WikipediaQA.create_instructor_xl() 81 | else: 82 | raise ValueError("Invalid config") 83 | 84 | if self.config["llm"] == LLM_OPENAI_GPT35: 85 | # OpenAI GPT 3.5 API 86 | pass 87 | elif self.config["llm"] == LLM_FLAN_T5_XL: 88 | if self.llm is None: 89 | self.llm = WikipediaQA.create_flan_t5_xl(load_in_8bit=load_in_8bit) 90 | elif self.config["llm"] == LLM_FLAN_T5_XXL: 91 | if self.llm is None: 92 | self.llm = WikipediaQA.create_flan_t5_xxl(load_in_8bit=load_in_8bit) 93 | elif self.config["llm"] == LLM_FASTCHAT_T5_XL: 94 | if self.llm is None: 95 | self.llm = WikipediaQA.create_fastchat_t5_xl(load_in_8bit=load_in_8bit) 96 | else: 97 | raise ValueError("Invalid config") 98 | 99 | def search_and_read_page(self, search_query: str) -> tuple[str, str]: 100 | """ 101 | Searches wikipedia for the given query, take the first result 102 | Then chunks the text of it and indexes it into a vector store 103 | 104 | Returns the title and text of the page 105 | """ 106 | # Search Wikipedia and get first result 107 | wiki_wiki = wikipediaapi.Wikipedia('en') 108 | docs = {} 109 | search_url = f"https://en.wikipedia.org/w/api.php?action=query&format=json&list=search&srsearch={search_query}" 110 | search_response = requests.get(search_url).json() 111 | wiki_title = search_response["query"]["search"][0]["title"] 112 | wiki_text = wiki_wiki.page(wiki_title).text 113 | docs[wiki_title] = wiki_text 114 | 115 | # Create new 
vector store and index it 116 | self.vectordb = None 117 | documents = [Document(page_content=docs[title]) for title in docs] 118 | 119 | # Split by section, then split by token limit 120 | text_splitter = CharacterTextSplitter(chunk_size=100, chunk_overlap=0) 121 | texts = text_splitter.split_documents(documents) 122 | text_splitter = TokenTextSplitter(chunk_size=1000, chunk_overlap=10, encoding_name="cl100k_base") # may be inexact 123 | texts = text_splitter.split_documents(texts) 124 | 125 | self.vectordb = Chroma.from_documents(documents=texts, embedding=self.embedding) 126 | 127 | # Create the LangChain chain 128 | if self.config["llm"] == LLM_OPENAI_GPT35: 129 | # Use ChatGPT API 130 | self.qa = RetrievalQA.from_chain_type(llm=OpenAI(model_name=LLM_OPENAI_GPT35, temperature=0.), chain_type="stuff",\ 131 | retriever=self.vectordb.as_retriever(search_kwargs={"k":4})) 132 | else: 133 | # Use local LLM 134 | hf_llm = HuggingFacePipeline(pipeline=self.llm) 135 | self.qa = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",\ 136 | retriever=self.vectordb.as_retriever(search_kwargs={"k":4})) 137 | if self.config["question_check"]: 138 | self.q_check = RetrievalQA.from_chain_type(llm=hf_llm, chain_type="stuff",\ 139 | retriever=self.vectordb.as_retriever(search_kwargs={"k":4})) 140 | self.q_check.combine_documents_chain.llm_chain.prompt = WikipediaQA.QUESTION_CHECK_PROMPT 141 | 142 | return wiki_title, wiki_text 143 | 144 | def get_answer(self, question: str) -> str: 145 | if self.config["llm"] != LLM_OPENAI_GPT35 and self.config["question_check"]: 146 | # For local LLMs, do a self-check to see if question can be answered 147 | # If unanswerable, respond with "I don't know" 148 | answerable = self.q_check.run(question) 149 | if self.config["llm"] == LLM_FASTCHAT_T5_XL: 150 | answerable = self._clean_fastchat_t5_output(answerable) 151 | if answerable != "yes": 152 | return "I don't know" 153 | 154 | # Answer the question 155 | answer = self.qa.run(question) 156 | if self.config["llm"] == LLM_FASTCHAT_T5_XL: 157 | answer = self._clean_fastchat_t5_output(answer) 158 | return answer 159 | 160 | def _clean_fastchat_t5_output(self, answer: str) -> str: 161 | # Remove <pad> tags, double spaces, trailing newline 162 | answer = re.sub(r"<pad>\s+", "", answer) 163 | answer = re.sub(r"  ", " ", answer) 164 | answer = re.sub(r"\n$", "", answer) 165 | return answer 166 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | EMB_OPENAI_ADA = "text-embedding-ada-002" 2 | EMB_INSTRUCTOR_XL = "hkunlp/instructor-xl" 3 | 4 | LLM_OPENAI_GPT35 = "gpt-3.5-turbo" 5 | LLM_FLAN_T5_XXL = "google/flan-t5-xxl" 6 | LLM_FLAN_T5_XL = "google/flan-t5-xl" 7 | LLM_FASTCHAT_T5_XL = "lmsys/fastchat-t5-3b-v1.0" 8 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import time 4 | 5 | import wikipediaapi 6 | from InstructorEmbedding import INSTRUCTOR 7 | from langchain import HuggingFacePipeline 8 | from langchain.chains import RetrievalQA 9 | from langchain.docstore.document import Document 10 | from langchain.embeddings import HuggingFaceInstructEmbeddings 11 | from langchain.embeddings.openai import OpenAIEmbeddings 12 | from langchain.llms import OpenAI 13 | from langchain.prompts import PromptTemplate 14 | from langchain.text_splitter import 
CharacterTextSplitter, TokenTextSplitter 15 | from langchain.vectorstores import Chroma 16 | from transformers import pipeline 17 | 18 | from constants import * 19 | from WikipediaQA import WikipediaQA 20 | 21 | from dotenv import load_dotenv 22 | import gradio as gr 23 | 24 | # Load OpenAI API key 25 | load_dotenv() 26 | 27 | # Global model instances 28 | instructor_xl = None 29 | flan_t5_xl = None 30 | flan_t5_xxl = None 31 | fastchat_t5_xl = None 32 | 33 | with gr.Blocks() as qa_app: 34 | # Initialization 35 | qa = gr.State(WikipediaQA({"question_check": True, "load_in_8bit": False})) 36 | 37 | # Layout 38 | gr.Markdown("""# Wikipedia Q&A 39 | Use OpenAI and/or local models for embeddings/LLM 40 | """) 41 | 42 | with gr.Tab("Model setup"): 43 | gr.Markdown("""Select embedding and LLM models 44 | OpenAI models require an API key, duplicate this space and use [secrets](https://huggingface.co/docs/hub/spaces-overview#managing-secrets) 45 | """) 46 | with gr.Row() as row: 47 | with gr.Column(): 48 | emb_radio = gr.Radio([EMB_OPENAI_ADA, EMB_INSTRUCTOR_XL], 49 | label="Select embedding model") 50 | llm_radio = gr.Radio([LLM_OPENAI_GPT35, LLM_FLAN_T5_XL, LLM_FLAN_T5_XXL, LLM_FASTCHAT_T5_XL], 51 | label="Select LLM model", 52 | info="Note: flan-t5-xxl will run out of memory on a single A10G") 53 | with gr.Column(): 54 | model_text_box = gr.Textbox(label="Current models") 55 | model_load_btn = gr.Button("Load models") 56 | 57 | with gr.Tab("Read Wikipedia"): 58 | gr.Markdown("""Search Wikipedia and get the first result 59 | Chunk the article and index local vector store 60 | """) 61 | with gr.Row() as row: 62 | with gr.Column(): 63 | query = gr.Textbox(label="Wikipedia search query") 64 | search_btn = gr.Button("Search") 65 | with gr.Column(): 66 | wiki_title_box = gr.Textbox(label="Article title") 67 | wiki_text_box = gr.Textbox(label="Article text") 68 | 69 | with gr.Tab("Q&A"): 70 | gr.Markdown("""Ask a question about the Wikipedia article""") 71 | with gr.Row() as row: 72 | with gr.Column(): 73 | question = gr.Textbox(label="Enter your question") 74 | question_btn = gr.Button("Ask") 75 | with gr.Column(): 76 | answer_box = gr.Textbox(label="Answer") 77 | 78 | # Logic 79 | # Model setup 80 | def load_model(emb, llm, qa): 81 | global instructor_xl 82 | global flan_t5_xl 83 | global flan_t5_xxl 84 | global fastchat_t5_xl 85 | 86 | if emb == EMB_OPENAI_ADA: 87 | qa.embedding = OpenAIEmbeddings() 88 | elif emb == EMB_INSTRUCTOR_XL: 89 | if instructor_xl is None: 90 | instructor_xl = WikipediaQA.create_instructor_xl() 91 | qa.embedding = instructor_xl 92 | else: 93 | raise ValueError("Invalid embedding setting") 94 | qa.config["embedding"] = emb 95 | 96 | if llm == LLM_OPENAI_GPT35: 97 | pass 98 | elif llm == LLM_FLAN_T5_XL: 99 | if flan_t5_xl is None: 100 | flan_t5_xl = WikipediaQA.create_flan_t5_xl() 101 | qa.llm = flan_t5_xl 102 | elif llm == LLM_FLAN_T5_XXL: 103 | if flan_t5_xxl is None: 104 | flan_t5_xxl = WikipediaQA.create_flan_t5_xxl() 105 | qa.llm = flan_t5_xxl 106 | elif llm == LLM_FASTCHAT_T5_XL: 107 | if fastchat_t5_xl is None: 108 | fastchat_t5_xl = WikipediaQA.create_fastchat_t5_xl() 109 | qa.llm = fastchat_t5_xl 110 | else: 111 | raise ValueError("Invalid LLM setting") 112 | qa.config["llm"] = llm 113 | 114 | return f"Embedding model: {emb}\nLLM: {llm}" 115 | 116 | model_load_btn.click( 117 | load_model, 118 | [emb_radio, llm_radio, qa], 119 | [model_text_box] 120 | ) 121 | 122 | # Search Wikipedia 123 | def wiki_search(query, qa): 124 | wiki_title, wiki_text = 
qa.search_and_read_page(query) 125 | return { 126 | wiki_title_box: wiki_title, 127 | wiki_text_box: wiki_text, 128 | } 129 | 130 | search_btn.click( 131 | wiki_search, 132 | [query, qa], 133 | [wiki_title_box, wiki_text_box] 134 | ) 135 | 136 | # Q&A 137 | def qa_fn(question, qa): 138 | answer = qa.get_answer(question) 139 | return {answer_box: answer} 140 | 141 | question_btn.click( 142 | qa_fn, 143 | [question, qa], 144 | [answer_box] 145 | 146 | ) 147 | qa_app.launch(debug=True) 148 | -------------------------------------------------------------------------------- /images/temp.txt: -------------------------------------------------------------------------------- 1 | just creating a new folder for images, delete this later 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | langchain 3 | accelerate 4 | bitsandbytes 5 | chromadb 6 | beautifulsoup4 7 | openai 8 | tiktoken 9 | sentence_transformers 10 | wikipedia-api==0.5.8 11 | InstructorEmbedding 12 | gradio 13 | python-dotenv 14 | urllib3==1.26.6 15 | -------------------------------------------------------------------------------- /template.env: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY=hello 2 | --------------------------------------------------------------------------------