├── .gitignore ├── .env.sample ├── requirements.txt ├── Makefile ├── crawl_index.py ├── emb.py ├── static ├── styles.css ├── es-gpt.js ├── p.html ├── sse.js └── index.html ├── app.py ├── README.md └── es_gpt.py /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | __pycache__ 3 | .venv 4 | 5 | -------------------------------------------------------------------------------- /.env.sample: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY= 2 | OPENAI_GPT_ENGINE=text-babbage-001 3 | OPENAI_GPT_MAX_TOKENS=2000 4 | ES_URL=https://localhost:9200 5 | ES_USER=elastic 6 | ES_PASS=zg= 7 | ES_CA_CERT=config/certs/http_ca.crt 8 | 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scholarly 2 | certifi 3 | fastapi 4 | requests 5 | elasticsearch 6 | uvicorn 7 | tiktoken 8 | openai 9 | matplotlib 10 | plotly 11 | pandas 12 | scipy 13 | scikit-learn 14 | pytest 15 | sentence-transformers -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | VENV = .venv 2 | PYTHON = $(VENV)/bin/python3 3 | PIP = $(VENV)/bin/pip3 4 | UVICORN = $(VENV)/bin/uvicorn 5 | PYTEST = $(VENV)/bin/pytest 6 | 7 | include .env 8 | export 9 | 10 | # Need to use python 3.9 for aws lambda 11 | $(VENV)/bin/activate: requirements.txt 12 | python3 -m venv $(VENV) 13 | $(PIP) install -r requirements.txt 14 | 15 | emb: $(VENV)/bin/activate 16 | $(PYTHON) emb.py 17 | 18 | crawl: $(VENV)/bin/activate 19 | $(PYTHON) crawl_index.py 20 | 21 | esgpt: $(VENV)/bin/activate 22 | $(PYTHON) es_gpt.py 23 | 24 | test: $(VENV)/bin/activate 25 | $(PYTEST) --verbose es_gpt_test.py -s -vv 26 | 27 | app: $(VENV)/bin/activate 28 | $(UVICORN) app:app --reload --port 7002 29 | 30 | clean: 31 | rm -rf __pycache__ 32 | rm -rf $(VENV) 33 | -------------------------------------------------------------------------------- /crawl_index.py: -------------------------------------------------------------------------------- 1 | from scholarly import scholarly 2 | from es_gpt import ESGPT 3 | 4 | 5 | def get_text_from_paper(x): 6 | title = x['bib'].get('title', '') 7 | abstract = x['bib'].get('abstract', '') 8 | return title + " " + abstract 9 | 10 | 11 | # Create an instance of the ESGPT class 12 | esgpt = ESGPT(index_name="papers") 13 | 14 | # Search for papers by author ID, Sung Kim 15 | author = scholarly.search_author_id("JE_m2UgAAAAJ") 16 | papers = scholarly.fill(author, sections=['publications']) 17 | # Index each paper in Elasticsearch 18 | for paper in papers['publications']: 19 | print(paper) 20 | paper = scholarly.fill(paper, sections=[]) 21 | paper_dict = paper['bib'] 22 | id = paper['author_pub_id'] 23 | 24 | # Index the paper in Elasticsearch 25 | text = get_text_from_paper(paper) 26 | esgpt.index(doc_id=id, doc=paper_dict, text=text) 27 | -------------------------------------------------------------------------------- /emb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import openai 3 | from sentence_transformers import SentenceTransformer 4 | 5 | 6 | EMB_USE_OPENAI = os.getenv('EMB_USE_OPENAI', '0') 7 | 8 | 9 | def _get_openai_embedding(input): 10 | openai.api_key = os.environ["OPENAI_API_KEY"] 11 | return openai.Embedding.create( 12 | input=input, 
engine='text-embedding-ada-002')['data'][0]['embedding'] 13 | 14 | 15 | def _get_transformer_embedding(input): 16 | model = SentenceTransformer('paraphrase-MiniLM-L6-v2') 17 | 18 | # Sentences are encoded by calling model.encode() 19 | embedding = model.encode(input) 20 | return embedding 21 | 22 | 23 | def get_embedding(input): 24 | if EMB_USE_OPENAI == '1': 25 | return _get_openai_embedding(input) 26 | else: 27 | return _get_transformer_embedding(input) 28 | 29 | 30 | if __name__ == "__main__": 31 | print("Transformer: ", _get_transformer_embedding('hello world')[0]) 32 | print("OpenAI: ", _get_openai_embedding('hello world')) 33 | -------------------------------------------------------------------------------- /static/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | display: flex; 3 | flex-wrap: wrap; 4 | } 5 | 6 | #search-container { 7 | flex: 0 0 100%; 8 | padding: 20px; 9 | box-sizing: border-box; 10 | } 11 | 12 | #results-container { 13 | flex: 0 0 70%; 14 | padding: 20px; 15 | box-sizing: border-box; 16 | } 17 | 18 | #summary-container { 19 | flex: 0 0 30%; 20 | padding: 20px; 21 | box-sizing: border-box; 22 | background-color: #f2f2f2; 23 | } 24 | 25 | .search-input-container { 26 | display: flex; 27 | margin-bottom: 20px; 28 | } 29 | 30 | #query { 31 | flex: 1; 32 | margin-right: 10px; 33 | } 34 | 35 | #results { 36 | max-height: 80vh; 37 | overflow-y: scroll; 38 | } 39 | 40 | .search-result { 41 | margin-bottom: 10px; 42 | padding-bottom: 10px; 43 | border-bottom: 1px solid #ddd; 44 | } 45 | 46 | .search-result a { 47 | font-size: 18px; 48 | font-weight: bold; 49 | color: #1a0dab; 50 | text-decoration: none; 51 | } 52 | 53 | .search-result a:hover { 54 | text-decoration: underline; 55 | } 56 | 57 | -------------------------------------------------------------------------------- /static/es-gpt.js: -------------------------------------------------------------------------------- 1 | function summarizeOnChange(inputId, resultsDivId, outputDivId) { 2 | const input_elem = document.getElementById(inputId); 3 | const results_elem = document.getElementById(resultsDivId); 4 | const output_elem = document.getElementById(outputDivId); 5 | 6 | console.log("input: " + input_elem, "results: " + results_elem); 7 | let timeoutId; 8 | 9 | function handleChange() { 10 | // Clear any previous timeouts 11 | clearTimeout(timeoutId); 12 | 13 | // Delay execution of the summarize function by 500ms to give the results time to load 14 | timeoutId = setTimeout(() => { 15 | console.log("Summarizing..."); 16 | summarize(input_elem, results_elem, output_elem); 17 | }, 777); // FIXME: hopefully this is long enough to load the results 18 | } 19 | 20 | // Watch for changes to the results element using MutationObserver 21 | const observer = new MutationObserver(handleChange); 22 | observer.observe(results_elem, { childList: true, subtree: true }); 23 | } 24 | 25 | function summarize(input_elem, results_elem, output_elem) { 26 | const text_results = results_elem.textContent; 27 | const q = input_elem.value 28 | 29 | const payload = { 30 | q: q, 31 | text_results: text_results, 32 | }; 33 | 34 | // SSE is a class defined in sse.js 35 | // Should be imported in the HTML file before this script 36 | const eventSource = new SSE("summary", { method: 'POST', payload: JSON.stringify(payload) }); 37 | 38 | eventSource.onmessage = function (event) { 39 | if (!event.data || event.data == "[DONE]") { 40 | eventSource.close(); 41 | return; 42 | } 43 | const result = 
JSON.parse(event.data); 44 | output_elem.innerHTML += result['choices'][0]['text']; 45 | return; 46 | }; 47 | eventSource.stream(); 48 | } 49 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from fastapi import FastAPI, HTTPException, Request 2 | from fastapi.staticfiles import StaticFiles 3 | from fastapi.responses import StreamingResponse 4 | import json 5 | from es_gpt import ESGPT 6 | 7 | # Create an instance of the ESGPT class 8 | es = ESGPT(index_name="papers") 9 | 10 | # Create a FastAPI app 11 | app = FastAPI() 12 | 13 | # Define the search route 14 | 15 | 16 | @app.get("/search") 17 | async def search(q: str): 18 | # Perform a search for the query 19 | results = es.search(q) 20 | 21 | # Stream the search results to the client 22 | async def stream_response(): 23 | for hit in results: 24 | yield "data: " + json.dumps(hit) + "\n\n" 25 | yield "[DONE]" 26 | 27 | return StreamingResponse(stream_response(), media_type="text/event-stream") 28 | 29 | # Define the summary route 30 | 31 | 32 | @app.post("/summary") 33 | async def summary(request: Request): 34 | 35 | payload = await request.json() 36 | q = payload["q"] 37 | text_results = payload.get("text_results", "") 38 | 39 | if text_results: 40 | # Generate summaries of the search results 41 | resp = es.gpt_answer(q, text_results=text_results) 42 | else: 43 | es_results = es.search(q) 44 | 45 | if es_results: 46 | # Generate summaries of the search results 47 | resp = es.gpt_answer(q, es_results=es_results) 48 | else: 49 | resp = es.gpt_answer(q, text_results="No results found") 50 | 51 | if resp.status_code != 200: 52 | raise HTTPException(resp.status_code, resp.text) 53 | 54 | return StreamingResponse(resp.iter_content(1), 55 | media_type="text/event-stream") 56 | 57 | ## Added this endpoint, but it does not seem to work properly yet, so it is not used for now. 58 | @app.post("/question-suggestion") 59 | async def question_suggestion(request: Request): 60 | payload = await request.json() 61 | text_results = payload.get("text_results", "") 62 | 63 | if text_results is not None: 64 | # Generate follow-up questions from the most relevant abstract text 65 | resp = es.gpt_question_generator(text_results=text_results) 66 | 67 | if resp.status_code != 200: 68 | raise HTTPException(resp.status_code, resp.text) 69 | else: 70 | return StreamingResponse(resp.iter_content(1), media_type="text/event-stream") 71 | else: 72 | raise HTTPException(400, "Unexpected Error") 73 | 74 | @app.post("/question") 75 | async def question(request: Request): 76 | payload = await request.json() 77 | q = payload.get("q", "") 78 | 79 | if q: 80 | resp = es.gpt_direct_answer(q) 81 | else: 82 | raise HTTPException(400, "Unexpected Error") 83 | 84 | if resp.status_code != 200: 85 | raise HTTPException(resp.status_code, resp.text) 86 | 87 | return StreamingResponse(resp.iter_content(1), 88 | media_type="text/event-stream") 89 | 90 | # Define the static files route 91 | # Need to set html=True to serve index.html 92 | # Need to put at the end of the routes 93 | app.mount("/", StaticFiles(directory="static", html=True), name="static") 94 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Elasticsearch + GPT3 Answerer 2 | Want to turn your (elastic) search into something as hot as Bing + ChatGPT? Look no further than the Elasticsearch + GPT3 Answerer!
Our program intercepts Elasticsearch results and sends them to GPT3 to provide accurate and relevant answers to your queries. Plus, it's just plain fun to use! 3 | 4 | ## Features 5 | * Intercept Elasticsearch results and send them to GPT3 for more accurate answers 6 | * Two installation options: all-in-one and on-the-fly 7 | * Live demo available to see the program in action 8 | 9 | 10 | 11 | It is designed to help users get more accurate and relevant answers to their queries by leveraging the power of Elasticsearch and GPT3. 12 | 13 | ## Live Demo 14 | ![ezgif-2-48b3807122](https://user-images.githubusercontent.com/901975/219939314-a8f8f63e-75f6-4805-a743-2b03ab410e0c.gif) 15 | 16 | Check out our live demo at https://es-gpt.sung.devstage.ai/ to see the Elasticsearch + GPT3 Answerer in action! Please note that the site may be unstable, and we are currently using the text-ada-001 model as a proof of concept, so the GPT answer may be poor. However, this demo shows how the Elasticsearch + GPT3 Answerer works. 17 | 18 | ## How it works 19 | The search page sends the query to Elasticsearch, intercepts the returned results, and forwards the query together with the results to GPT3, which streams an answer back to the page. A usage sketch of the two HTTP endpoints is given at the end of this README. 20 | 21 | 22 | 23 | ## Installation 24 | To use the Elasticsearch + GPT3 Answerer, you'll need to have access to both Elasticsearch and GPT3, as well as Python installed on your system. We offer two installation options: 25 | 26 | ### All-in-one installation 27 | To use the all-in-one installation, follow these steps: 28 | 29 | Clone this repository to your local machine. 30 | ```bash 31 | $ git clone https://github.com/hunkim/es-gpt.git 32 | $ cd es-gpt 33 | ``` 34 | 35 | Create a .env file (see .env.sample) with your Elasticsearch and GPT3 credentials, modify crawl_index.py for your own documents, and index them: 36 | ```bash 37 | $ make crawl 38 | ``` 39 | 40 | Then run the backend server: 41 | ```bash 42 | $ make app 43 | ``` 44 | 45 | Then visit the backend server (http://localhost:7002 by default). The web page will intercept the Elasticsearch results and send them to GPT3 to provide a reasonable answer. This method is very fast, as the program embeds documents during indexing. 46 | 47 | ### On-the-fly installation 48 | To use the on-the-fly installation, follow these steps: 49 | 50 | Add scripts like the following to your search page (see `static/p.html`), using the query, results, and GPT-answer output div IDs of your original search page; the IDs in this snippet are only examples: 51 | ```html 52 | <!-- sse.js must be loaded before es-gpt.js; the element IDs below are examples --> 53 | <script src="sse.js"></script> 54 | <script src="es-gpt.js"></script> 55 | <script> 56 | summarizeOnChange("query", "results", "summary"); 57 | </script> 58 | 59 | ``` 60 | Modify the .env with your Elasticsearch and GPT3 credentials. 61 | Install the dependencies and start the backend server by running the following command in your terminal: 62 | ``` 63 | $ make app 64 | ``` 65 | 66 | Open your search web page and enter a query. The program will intercept the HTML results and send them to GPT3 to provide a reasonable answer. This method is convenient but slower, as the program embeds the search results and the query on the fly. 67 | 68 | ## Contributing 69 | We welcome contributions from the community! If you have ideas for how to improve the Elasticsearch + GPT3 Answerer, please open an issue or submit a pull request. We love hearing from fellow search enthusiasts! 70 | 71 | ## License 72 | This program is licensed under the MIT License.
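## Usage sketch

`app.py` exposes two endpoints: `GET /search` streams the Elasticsearch hits as server-sent events, and `POST /summary` relays a streaming GPT3 completion for the query plus the text of the results. The snippet below is only a rough sketch of calling them with the `requests` library; it assumes the backend started by `make app` is listening locally on port 7002 (the port set in the Makefile) and that the indexed documents carry `title` and `abstract` fields, as produced by `crawl_index.py`.

```python
import json
import requests

BASE = "http://localhost:7002"  # assumption: backend from `make app` running locally
query = "machine learning"      # example query

# 1) GET /search streams each Elasticsearch hit as an SSE "data: {...}" line and ends with "[DONE]".
hits = []
with requests.get(f"{BASE}/search", params={"q": query}, stream=True) as resp:
    for raw in resp.iter_lines():
        line = raw.decode("utf-8") if raw else ""
        if line.startswith("data: "):
            hits.append(json.loads(line[len("data: "):]))

# 2) POST /summary forwards the query and the visible result text to GPT3 and relays
#    OpenAI's streaming completion back as raw SSE lines.
text_results = "\n\n".join(
    f"{h['_source'].get('title', '')}:\n{h['_source'].get('abstract', '')}" for h in hits
)
with requests.post(f"{BASE}/summary",
                   json={"q": query, "text_results": text_results},
                   stream=True) as resp:
    for raw in resp.iter_lines():
        if raw:
            print(raw.decode("utf-8"))
```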
-------------------------------------------------------------------------------- /static/p.html: -------------------------------------------------------------------------------- [The HTML markup of this page did not survive extraction. Recoverable content: page title "Summarize GPT", a heading "Elasticsearch with GPT (Sung's paper search)", a search box, a "Search Results" pane, and a "GPT Answer" pane noted as giving a "poor result due to text-ada-001"; the inline styles and script includes were stripped.]
94 | Visit https://github.com/hunkim/es-gpt for more information 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /static/sse.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Copyright (C) 2016 Maxime Petazzoni . 3 | * All rights reserved. 4 | */ 5 | 6 | var SSE = function (url, options) { 7 | if (!(this instanceof SSE)) { 8 | return new SSE(url, options); 9 | } 10 | 11 | this.INITIALIZING = -1; 12 | this.CONNECTING = 0; 13 | this.OPEN = 1; 14 | this.CLOSED = 2; 15 | 16 | this.url = url; 17 | 18 | options = options || {}; 19 | this.headers = options.headers || {}; 20 | this.payload = options.payload !== undefined ? options.payload : ''; 21 | this.method = options.method || (this.payload && 'POST' || 'GET'); 22 | this.withCredentials = !!options.withCredentials; 23 | 24 | this.FIELD_SEPARATOR = ':'; 25 | this.listeners = {}; 26 | 27 | this.xhr = null; 28 | this.readyState = this.INITIALIZING; 29 | this.progress = 0; 30 | this.chunk = ''; 31 | 32 | this.addEventListener = function(type, listener) { 33 | if (this.listeners[type] === undefined) { 34 | this.listeners[type] = []; 35 | } 36 | 37 | if (this.listeners[type].indexOf(listener) === -1) { 38 | this.listeners[type].push(listener); 39 | } 40 | }; 41 | 42 | this.removeEventListener = function(type, listener) { 43 | if (this.listeners[type] === undefined) { 44 | return; 45 | } 46 | 47 | var filtered = []; 48 | this.listeners[type].forEach(function(element) { 49 | if (element !== listener) { 50 | filtered.push(element); 51 | } 52 | }); 53 | if (filtered.length === 0) { 54 | delete this.listeners[type]; 55 | } else { 56 | this.listeners[type] = filtered; 57 | } 58 | }; 59 | 60 | this.dispatchEvent = function(e) { 61 | if (!e) { 62 | return true; 63 | } 64 | 65 | e.source = this; 66 | 67 | var onHandler = 'on' + e.type; 68 | if (this.hasOwnProperty(onHandler)) { 69 | this[onHandler].call(this, e); 70 | if (e.defaultPrevented) { 71 | return false; 72 | } 73 | } 74 | 75 | if (this.listeners[e.type]) { 76 | return this.listeners[e.type].every(function(callback) { 77 | callback(e); 78 | return !e.defaultPrevented; 79 | }); 80 | } 81 | 82 | return true; 83 | }; 84 | 85 | this._setReadyState = function(state) { 86 | var event = new CustomEvent('readystatechange'); 87 | event.readyState = state; 88 | this.readyState = state; 89 | this.dispatchEvent(event); 90 | }; 91 | 92 | this._onStreamFailure = function(e) { 93 | var event = new CustomEvent('error'); 94 | event.data = e.currentTarget.response; 95 | this.dispatchEvent(event); 96 | this.close(); 97 | } 98 | 99 | this._onStreamAbort = function(e) { 100 | this.dispatchEvent(new CustomEvent('abort')); 101 | this.close(); 102 | } 103 | 104 | this._onStreamProgress = function(e) { 105 | if (!this.xhr) { 106 | return; 107 | } 108 | 109 | if (this.xhr.status !== 200) { 110 | this._onStreamFailure(e); 111 | return; 112 | } 113 | 114 | if (this.readyState == this.CONNECTING) { 115 | this.dispatchEvent(new CustomEvent('open')); 116 | this._setReadyState(this.OPEN); 117 | } 118 | 119 | var data = this.xhr.responseText.substring(this.progress); 120 | this.progress += data.length; 121 | data.split(/(\r\n|\r|\n){2}/g).forEach(function(part) { 122 | if (part.trim().length === 0) { 123 | this.dispatchEvent(this._parseEventChunk(this.chunk.trim())); 124 | this.chunk = ''; 125 | } else { 126 | this.chunk += part; 127 | } 128 | }.bind(this)); 129 | }; 130 | 131 | this._onStreamLoaded = function(e) { 132 | 
this._onStreamProgress(e); 133 | 134 | // Parse the last chunk. 135 | this.dispatchEvent(this._parseEventChunk(this.chunk)); 136 | this.chunk = ''; 137 | }; 138 | 139 | /** 140 | * Parse a received SSE event chunk into a constructed event object. 141 | */ 142 | this._parseEventChunk = function(chunk) { 143 | if (!chunk || chunk.length === 0) { 144 | return null; 145 | } 146 | 147 | var e = {'id': null, 'retry': null, 'data': '', 'event': 'message'}; 148 | chunk.split(/\n|\r\n|\r/).forEach(function(line) { 149 | line = line.trimRight(); 150 | var index = line.indexOf(this.FIELD_SEPARATOR); 151 | if (index <= 0) { 152 | // Line was either empty, or started with a separator and is a comment. 153 | // Either way, ignore. 154 | return; 155 | } 156 | 157 | var field = line.substring(0, index); 158 | if (!(field in e)) { 159 | return; 160 | } 161 | 162 | var value = line.substring(index + 1).trimLeft(); 163 | if (field === 'data') { 164 | e[field] += value; 165 | } else { 166 | e[field] = value; 167 | } 168 | }.bind(this)); 169 | 170 | var event = new CustomEvent(e.event); 171 | event.data = e.data; 172 | event.id = e.id; 173 | return event; 174 | }; 175 | 176 | this._checkStreamClosed = function() { 177 | if (!this.xhr) { 178 | return; 179 | } 180 | 181 | if (this.xhr.readyState === XMLHttpRequest.DONE) { 182 | this._setReadyState(this.CLOSED); 183 | } 184 | }; 185 | 186 | this.stream = function() { 187 | this._setReadyState(this.CONNECTING); 188 | 189 | this.xhr = new XMLHttpRequest(); 190 | this.xhr.addEventListener('progress', this._onStreamProgress.bind(this)); 191 | this.xhr.addEventListener('load', this._onStreamLoaded.bind(this)); 192 | this.xhr.addEventListener('readystatechange', this._checkStreamClosed.bind(this)); 193 | this.xhr.addEventListener('error', this._onStreamFailure.bind(this)); 194 | this.xhr.addEventListener('abort', this._onStreamAbort.bind(this)); 195 | this.xhr.open(this.method, this.url); 196 | for (var header in this.headers) { 197 | this.xhr.setRequestHeader(header, this.headers[header]); 198 | } 199 | this.xhr.withCredentials = this.withCredentials; 200 | this.xhr.send(this.payload); 201 | }; 202 | 203 | this.close = function() { 204 | if (this.readyState === this.CLOSED) { 205 | return; 206 | } 207 | 208 | this.xhr.abort(); 209 | this.xhr = null; 210 | this._setReadyState(this.CLOSED); 211 | }; 212 | }; 213 | 214 | // Export our SSE module for npm.js 215 | if (typeof exports !== 'undefined') { 216 | exports.SSE = SSE; 217 | } 218 | -------------------------------------------------------------------------------- /es_gpt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import re 5 | import pandas as pd 6 | import string 7 | from elasticsearch import Elasticsearch 8 | 9 | import tiktoken 10 | import openai 11 | from openai.embeddings_utils import distances_from_embeddings 12 | from emb import get_embedding 13 | 14 | 15 | ES_URL = os.environ["ES_URL"] 16 | ES_USER = os.environ["ES_USER"] 17 | ES_PASS = os.environ["ES_PASS"] 18 | ES_CA_CERT = os.environ["ES_CA_CERT"] 19 | 20 | 21 | class ESGPT: 22 | def __init__(self, index_name): 23 | self.es = Elasticsearch(ES_URL, basic_auth=(ES_USER, ES_PASS), 24 | ca_certs=ES_CA_CERT, verify_certs=True) 25 | self.index_name = index_name 26 | 27 | # FIXME: remove .strip() 28 | self.model_engine = os.environ["OPENAI_GPT_ENGINE"].strip() 29 | self.model_max_tokens = int(os.environ["OPENAI_GPT_MAX_TOKENS"]) 30 | self.api_key = 
os.environ["OPENAI_API_KEY"] 31 | openai.api_key = self.api_key 32 | self.max_tokens = 1000 33 | self.split_max_tokens = 500 34 | 35 | # Load the cl100k_base tokenizer which is designed to work with the ada-002 model 36 | self.tokenizer = tiktoken.get_encoding("cl100k_base") 37 | 38 | self.answer_generation_prompt = "Based on the context below\"\n\nContext: {}\n\n---\n\nPlease provide concise answer for this questions: {}" 39 | self.question_suggestion_prompt = "Based on the context below\"\n\nContext: {}\n\n---\n\nPlease recommend 3 more questions to be curious about {}" 40 | self.just_question_prompt = "{}{}" 41 | 42 | def index(self, doc_id, doc, text): 43 | doc["embeddings_dict_list"] = self._create_emb_dict_list(text) 44 | self.es.index(index=self.index_name, 45 | id=doc_id, 46 | document=doc) 47 | 48 | def search(self, query): 49 | es_query = { 50 | "query_string": {"query": query} 51 | } 52 | 53 | results = self.es.search(index=self.index_name, query=es_query) 54 | return results['hits']['hits'] 55 | 56 | def _paper_results_to_text(self, results): 57 | text_result = "" 58 | for paper in results: 59 | title = paper["_source"].get("title", "") 60 | abstract = paper["_source"].get("abstract", "") 61 | paper_str = f"{title}:\n{abstract}\n\n" 62 | text_result += paper_str 63 | return text_result 64 | 65 | # Code from https://github.com/openai/openai-cookbook/blob/main/apps/web-crawl-q-and-a/web-qa.py 66 | # Function to split the text into chunks of a maximum number of tokens 67 | def _split_into_many(self, text): 68 | sentences = [] 69 | for sentence in re.split(r'[{}]'.format(string.punctuation), text): 70 | sentence = sentence.strip() 71 | if sentence and (any(char.isalpha() for char in sentence) or any(char.isdigit() for char in sentence)): 72 | sentences.append(sentence) 73 | 74 | n_tokens = [len(self.tokenizer.encode(" " + sentence)) 75 | for sentence in sentences] 76 | 77 | chunks = [] 78 | tokens_so_far = 0 79 | chunk = [] 80 | 81 | # Loop through the sentences and tokens joined together in a tuple 82 | for sentence, token in zip(sentences, n_tokens): 83 | # If the number of tokens so far plus the number of tokens in the current sentence is greater 84 | # than the max number of tokens, then add the chunk to the list of chunks and reset 85 | # the chunk and tokens so far 86 | if tokens_so_far + token > self.split_max_tokens and chunk: 87 | chunks.append(". ".join(chunk) + ".") 88 | chunk = [] 89 | tokens_so_far = 0 90 | 91 | # If the number of tokens in the current sentence is greater than the max number of 92 | # tokens, go to the next sentence 93 | if token > self.split_max_tokens: 94 | continue 95 | 96 | # Otherwise, add the sentence to the chunk and add the number of tokens to the total 97 | chunk.append(sentence) 98 | tokens_so_far += token + 1 99 | 100 | # Add the last chunk to the list of chunks 101 | if chunk: 102 | chunks.append(". 
".join(chunk) + ".") 103 | 104 | return chunks 105 | 106 | 107 | def _create_emb_dict_list(self, long_text): 108 | shortened = self._split_into_many(long_text) 109 | 110 | embeddings_dict_list = [] 111 | 112 | for text in shortened: 113 | n_tokens = len(self.tokenizer.encode(text)) 114 | embeddings = get_embedding(input=text) 115 | embeddings_dict = {} 116 | embeddings_dict["text"] = text 117 | embeddings_dict["n_tokens"] = n_tokens 118 | embeddings_dict["embeddings"] = embeddings 119 | embeddings_dict_list.append(embeddings_dict) 120 | 121 | return embeddings_dict_list 122 | 123 | def _create_context(self, question, df): 124 | """ 125 | Create a context for a question by finding the most similar context from the dataframe 126 | """ 127 | 128 | # Get the embeddings for the question 129 | q_embeddings = get_embedding(input=question) 130 | 131 | # Get the distances from the embeddings 132 | df['distances'] = distances_from_embeddings( 133 | q_embeddings, df['embeddings'].values, distance_metric='cosine') 134 | 135 | returns = [] 136 | cur_len = 0 137 | 138 | # Sort by distance and add the text to the context until the context is too long 139 | for i, row in df.sort_values('distances', ascending=True).iterrows(): 140 | # Add the length of the text to the current length 141 | cur_len += row['n_tokens'] + 4 142 | 143 | # If the context is too long, break 144 | if cur_len > self.max_tokens: 145 | break 146 | 147 | # Else add it to the text that is being returned 148 | returns.append(row["text"]) 149 | 150 | # Return the context and the length of the context 151 | return "\n\n###\n\n".join(returns), cur_len 152 | 153 | def _gpt_api_call(self, query, input_token_len, context, call_type): 154 | if call_type == "answer": 155 | prompt = self.answer_generation_prompt 156 | elif call_type == "question": 157 | prompt = self.just_question_prompt 158 | else: 159 | prompt = self.question_suggestion_prompt 160 | 161 | body = { 162 | "model": self.model_engine, 163 | "prompt": prompt.format(context, query), 164 | "max_tokens": self.model_max_tokens - input_token_len, 165 | "n": 1, 166 | "temperature": 0.5, 167 | "stream": True, 168 | } 169 | 170 | headers = {"Content-Type": "application/json", 171 | "Authorization": f"Bearer {self.api_key}"} 172 | 173 | resp = requests.post("https://api.openai.com/v1/completions", 174 | headers=headers, 175 | data=json.dumps(body), 176 | stream=True) 177 | return resp 178 | 179 | 180 | def gpt_answer(self, query, es_results=None, text_results=None): 181 | # Generate summaries for each search result 182 | if text_results: 183 | input_token_len = len(self.tokenizer.encode(text_results)) 184 | if input_token_len < self.max_tokens: 185 | context = text_results 186 | else: 187 | emb_dict_list = self._create_emb_dict_list(text_results) 188 | df = pd.DataFrame(columns=["text", "n_tokens", "embeddings"]) 189 | for emb_dict in emb_dict_list: 190 | df = df.append(emb_dict, ignore_index=True) 191 | 192 | context, input_token_len = self._create_context( 193 | question=query, 194 | df=df) 195 | elif es_results: 196 | result_json_str = self._paper_results_to_text(es_results) 197 | if not result_json_str: 198 | result_json_str = "No results found" 199 | 200 | input_token_len = len(self.tokenizer.encode(result_json_str)) 201 | if input_token_len < self.max_tokens: 202 | context = result_json_str 203 | else: 204 | # Create a pandas DataFrame from the list of embeddings dictionaries 205 | df = pd.DataFrame(columns=["text", "n_tokens", "embeddings"]) 206 | 207 | # extract embeddings_dict from 
es_results and append to the dataframe 208 | for hit in es_results: 209 | embeddings_dict_list = hit['_source']['embeddings_dict_list'] 210 | for embeddings_dict in embeddings_dict_list: 211 | df = df.append(embeddings_dict, ignore_index=True) 212 | 213 | context, input_token_len = self._create_context( 214 | question=query, 215 | df=df) 216 | else: 217 | assert False, "Must provide either es_results or text_results" 218 | 219 | return self._gpt_api_call(query, input_token_len, context, call_type="answer") 220 | 221 | def gpt_question_generator(self, text_results=None): 222 | if text_results: 223 | input_token_len = len(self.tokenizer.encode(text_results)) 224 | if input_token_len < self.max_tokens: 225 | context = text_results 226 | else: 227 | context = text_results[:self.max_tokens] 228 | input_token_len = self.max_tokens 229 | else: 230 | assert False, "Text results are not found" 231 | 232 | return self._gpt_api_call("", input_token_len, context, call_type="suggestion") 233 | 234 | def gpt_direct_answer(self, q): 235 | input_token_len = len(self.tokenizer.encode(q)) 236 | if input_token_len < self.max_tokens: 237 | query = q 238 | else: 239 | query = q[:self.max_tokens] 240 | input_token_len = self.max_tokens 241 | return self._gpt_api_call(q, input_token_len, "", call_type="question") 242 | 243 | 244 | 245 | 246 | 247 | # Example usage 248 | if __name__ == "__main__": 249 | esgpt = ESGPT("papers") 250 | query = "How to fix this bugs?" 251 | res = esgpt.search(query=query) 252 | res_str = esgpt._paper_results_to_text(res) 253 | 254 | # Pass ES results with precomputed embeddings 255 | res = esgpt.gpt_answer(query=query, es_results=res) 256 | print(res.text) 257 | 258 | # Pass text results and do embeddings on the fly 259 | # Note: This will be slower 260 | res = esgpt.gpt_answer(query=query, text_results=res_str) 261 | print(res.text) 262 | -------------------------------------------------------------------------------- /static/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Elastic search GPT 6 | 7 | 8 | 9 | 10 | 11 | 12 | 210 | 211 | 212 | 213 |
214 |
215 |

Elasticsearch with GPT (Sung's paper search)

216 | 217 |
218 |
219 | 223 |
224 |
225 | 228 |
229 |
230 | 233 |
234 |
235 |
236 | 237 |
238 |
239 |

Search results

240 |
241 |
242 | 243 |
244 |

Summaries

245 | 247 | 248 | 260 | 261 | 262 | 263 |
264 |
265 |
266 | 281 | 307 | 308 | 309 | 310 | --------------------------------------------------------------------------------