├── .gitignore
├── .env.sample
├── requirements.txt
├── Makefile
├── crawl_index.py
├── emb.py
├── static
│   ├── styles.css
│   ├── es-gpt.js
│   ├── p.html
│   ├── sse.js
│   └── index.html
├── app.py
├── README.md
└── es_gpt.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | __pycache__
3 | .venv
4 |
5 |
--------------------------------------------------------------------------------
/.env.sample:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY=
2 | OPENAI_GPT_ENGINE=text-babbage-001
3 | OPENAI_GPT_MAX_TOKENS=2000
4 | ES_URL=https://localhost:9200
5 | ES_USER=elastic
6 | ES_PASS=zg=
7 | ES_CA_CERT=config/certs/http_ca.crt
8 |
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scholarly
2 | certifi
3 | fastapi
4 | requests
5 | elasticsearch
6 | uvicorn
7 | tiktoken
8 | openai
9 | matplotlib
10 | plotly
11 | pandas
12 | scipy
13 | scikit-learn
14 | pytest
15 | sentence-transformers
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | VENV = .venv
2 | PYTHON = $(VENV)/bin/python3
3 | PIP = $(VENV)/bin/pip3
4 | UVICORN = $(VENV)/bin/uvicorn
5 | PYTEST = $(VENV)/bin/pytest
6 |
7 | include .env
8 | export
9 |
10 | # Need to use python 3.9 for aws lambda
11 | $(VENV)/bin/activate: requirements.txt
12 | python3 -m venv $(VENV)
13 | $(PIP) install -r requirements.txt
14 |
15 | emb: $(VENV)/bin/activate
16 | $(PYTHON) emb.py
17 |
18 | crawl: $(VENV)/bin/activate
19 | $(PYTHON) crawl_index.py
20 |
21 | esgpt: $(VENV)/bin/activate
22 | $(PYTHON) es_gpt.py
23 |
24 | test: $(VENV)/bin/activate
25 | $(PYTEST) --verbose es_gpt_test.py -s -vv
26 |
27 | app: $(VENV)/bin/activate
28 | $(UVICORN) app:app --reload --port 7002
29 |
30 | clean:
31 | rm -rf __pycache__
32 | rm -rf $(VENV)
33 |
--------------------------------------------------------------------------------
/crawl_index.py:
--------------------------------------------------------------------------------
1 | from scholarly import scholarly
2 | from es_gpt import ESGPT
3 |
4 |
5 | def get_text_from_paper(x):
6 | title = x['bib'].get('title', '')
7 | abstract = x['bib'].get('abstract', '')
8 | return title + " " + abstract
9 |
10 |
11 | # Create an instance of the ESGPT class
12 | esgpt = ESGPT(index_name="papers")
13 |
14 | # Look up the author (Sung Kim) by Google Scholar ID and fetch the publication list
15 | author = scholarly.search_author_id("JE_m2UgAAAAJ")
16 | papers = scholarly.fill(author, sections=['publications'])
17 | # Index each paper in Elasticsearch
18 | for paper in papers['publications']:
19 | print(paper)
20 | paper = scholarly.fill(paper, sections=[])
21 | paper_dict = paper['bib']
22 | id = paper['author_pub_id']
23 |
24 | # Index the paper in Elasticsearch
25 | text = get_text_from_paper(paper)
26 | esgpt.index(doc_id=id, doc=paper_dict, text=text)
27 |
--------------------------------------------------------------------------------
/emb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import openai
3 | from sentence_transformers import SentenceTransformer
4 |
5 |
6 | EMB_USE_OPENAI = os.getenv('EMB_USE_OPENAI', '0')
7 |
8 |
9 | def _get_openai_embedding(input):
10 | openai.api_key = os.environ["OPENAI_API_KEY"]
11 | return openai.Embedding.create(
12 | input=input, engine='text-embedding-ada-002')['data'][0]['embedding']
13 |
14 |
15 | def _get_transformer_embedding(input):
16 | model = SentenceTransformer('paraphrase-MiniLM-L6-v2')
17 |
18 | # Sentences are encoded by calling model.encode()
19 | embedding = model.encode(input)
20 | return embedding
21 |
22 |
23 | def get_embedding(input):
24 | if EMB_USE_OPENAI == '1':
25 | return _get_openai_embedding(input)
26 | else:
27 | return _get_transformer_embedding(input)
28 |
29 |
30 | if __name__ == "__main__":
31 | print("Transformer: ", _get_transformer_embedding('hello world')[0])
32 | print("OpenAI: ", _get_openai_embedding('hello world'))
33 |
--------------------------------------------------------------------------------
/static/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | display: flex;
3 | flex-wrap: wrap;
4 | }
5 |
6 | #search-container {
7 | flex: 0 0 100%;
8 | padding: 20px;
9 | box-sizing: border-box;
10 | }
11 |
12 | #results-container {
13 | flex: 0 0 70%;
14 | padding: 20px;
15 | box-sizing: border-box;
16 | }
17 |
18 | #summary-container {
19 | flex: 0 0 30%;
20 | padding: 20px;
21 | box-sizing: border-box;
22 | background-color: #f2f2f2;
23 | }
24 |
25 | .search-input-container {
26 | display: flex;
27 | margin-bottom: 20px;
28 | }
29 |
30 | #query {
31 | flex: 1;
32 | margin-right: 10px;
33 | }
34 |
35 | #results {
36 | max-height: 80vh;
37 | overflow-y: scroll;
38 | }
39 |
40 | .search-result {
41 | margin-bottom: 10px;
42 | padding-bottom: 10px;
43 | border-bottom: 1px solid #ddd;
44 | }
45 |
46 | .search-result a {
47 | font-size: 18px;
48 | font-weight: bold;
49 | color: #1a0dab;
50 | text-decoration: none;
51 | }
52 |
53 | .search-result a:hover {
54 | text-decoration: underline;
55 | }
56 |
57 |
--------------------------------------------------------------------------------
/static/es-gpt.js:
--------------------------------------------------------------------------------
1 | function summarizeOnChange(inputId, resultsDivId, outputDivId) {
2 | const input_elem = document.getElementById(inputId);
3 | const results_elem = document.getElementById(resultsDivId);
4 | const output_elem = document.getElementById(outputDivId);
5 |
6 | console.log("input: " + input_elem, "results: " + results_elem);
7 | let timeoutId;
8 |
9 | function handleChange() {
10 | // Clear any previous timeouts
11 | clearTimeout(timeoutId);
12 |
13 |       // Delay execution of the summarize function by 777ms to give the results time to load
14 | timeoutId = setTimeout(() => {
15 | console.log("Summarizing...");
16 | summarize(input_elem, results_elem, output_elem);
17 | }, 777); // FIXME: hopefully this is long enough to load the results
18 | }
19 |
20 | // Watch for changes to the results element using MutationObserver
21 | const observer = new MutationObserver(handleChange);
22 | observer.observe(results_elem, { childList: true, subtree: true });
23 | }
24 |
25 | function summarize(input_elem, results_elem, output_elem) {
26 | const text_results = results_elem.textContent;
27 | const q = input_elem.value
28 |
29 | const payload = {
30 | q: q,
31 | text_results: text_results,
32 | };
33 |
34 | // SSE is a class defined in sse.js
35 | // Should be imported in the HTML file before this script
36 | const eventSource = new SSE("summary", { method: 'POST', payload: JSON.stringify(payload) });
37 |
38 | eventSource.onmessage = function (event) {
39 | if (!event.data || event.data == "[DONE]") {
40 | eventSource.close();
41 | return;
42 | }
43 | const result = JSON.parse(event.data);
44 | output_elem.innerHTML += result['choices'][0]['text'];
45 | return;
46 | };
47 | eventSource.stream();
48 | }
49 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, HTTPException, Request
2 | from fastapi.staticfiles import StaticFiles
3 | from fastapi.responses import StreamingResponse
4 | import json
5 | from es_gpt import ESGPT
6 |
7 | # Create an instance of the ESGPT class
8 | es = ESGPT(index_name="papers")
9 |
10 | # Create a FastAPI app
11 | app = FastAPI()
12 |
13 | # Define the search route
14 |
15 |
16 | @app.get("/search")
17 | async def search(q: str):
18 | # Perform a search for the query
19 | results = es.search(q)
20 |
21 | # Stream the search results to the client
22 | async def stream_response():
23 | for hit in results:
24 | yield "data: " + json.dumps(hit) + "\n\n"
25 |         yield "data: [DONE]\n\n"
26 |
27 | return StreamingResponse(stream_response(), media_type="text/event-stream")
28 |
29 | # Define the summary route
30 |
31 |
32 | @app.post("/summary")
33 | async def summary(request: Request):
34 |
35 | payload = await request.json()
36 | q = payload["q"]
37 | text_results = payload.get("text_results", "")
38 |
39 | if text_results:
40 | # Generate summaries of the search results
41 | resp = es.gpt_answer(q, text_results=text_results)
42 | else:
43 | es_results = es.search(q)
44 |
45 | if es_results:
46 | # Generate summaries of the search results
47 | resp = es.gpt_answer(q, es_results=es_results)
48 | else:
49 | resp = es.gpt_answer(q, text_results="No results found")
50 |
51 | if resp.status_code != 200:
52 | raise HTTPException(resp.status_code, resp.text)
53 |
54 | return StreamingResponse(resp.iter_content(1),
55 | media_type="text/event-stream")
56 |
57 | ## Added this, but it does not seem to work properly yet, so it is not used for now.
58 | @app.post("/question-suggestion")
59 | async def question_suggestion(request: Request):
60 | payload = await request.json()
61 | text_results = payload.get("text_results", "")
62 |
63 |     if text_results:
64 |         # Generate follow-up questions from the most relevant result text
65 | resp = es.gpt_question_generator(text_results=text_results)
66 |
67 | if resp.status_code != 200:
68 | raise HTTPException(resp.status_code, resp.text)
69 | else:
70 | return StreamingResponse(resp.iter_content(1), media_type="text/event-stream")
71 | else:
72 | raise HTTPException(400, "Unexpected Error")
73 |
74 | @app.post("/question")
75 | async def question(request: Request):
76 | payload = await request.json()
77 | q = payload.get("q", "")
78 |
79 | if q:
80 | resp = es.gpt_direct_answer(q)
81 | else:
82 | raise HTTPException(400, "Unexpected Error")
83 |
84 | if resp.status_code != 200:
85 | raise HTTPException(resp.status_code, resp.text)
86 |
87 | return StreamingResponse(resp.iter_content(1),
88 | media_type="text/event-stream")
89 |
90 | # Define the static files route
91 | # Need to set html=True to serve index.html
92 | # Need to put at the end of the routes
93 | app.mount("/", StaticFiles(directory="static", html=True), name="static")
94 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Elasticsearch + GPT3 Answerer
2 | Want to turn your (elastic) search into something as hot as Bing + ChatGPT? Look no further than the Elasticsearch + GPT3 Answerer! Our program intercepts Elasticsearch results and sends them to GPT3 to provide accurate and relevant answers to your queries. Plus, it's just plain fun to use!
3 |
4 | ## Features
5 | * Intercept Elasticsearch results and send them to GPT3 for more accurate answers
6 | * Two installation options: all-in-one and on-the-fly
7 | * Live demo available to see the program in action
8 |
9 |
10 |
11 | It is designed to help users get more accurate and relevant answers to their queries, by leveraging the power of Elasticsearch and GPT3.
12 |
13 | ## Live Demo
14 | 
15 |
16 | Check out our live demo at https://es-gpt.sung.devstage.ai/ to see the Elasticsearch + GPT3 Answerer in action! Please note that the site may be unstable and we are currently using the text-ada-001 model for proof of concept, so the GPT answer may be poor. However, this demo shows the concept of how the Elasticsearch + GPT3 Answerer works.
17 |
18 | ## How it works
19 | See this diagram. In short: the search page queries Elasticsearch, renders the matching results, and then sends the query together with the result text to GPT3, which streams an answer back to the page over server-sent events.
20 |
21 |
22 |
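At the code level, a query boils down to roughly the following (a condensed sketch of what app.py and es_gpt.py do; streaming plumbing and error handling omitted):

```python
# Condensed sketch of the per-query flow; see app.py and es_gpt.py for the real code.
from es_gpt import ESGPT

es = ESGPT(index_name="papers")

def answer(q: str):
    hits = es.search(q)                       # 1. full-text search in Elasticsearch
    resp = es.gpt_answer(q, es_results=hits)  # 2. build a context from the hits and call GPT3
    return resp                               # 3. a streaming response, relayed to the browser as SSE
```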
23 | ## Installation
24 | To use the Elasticsearch + GPT3 Answerer, you'll need to have access to both Elasticsearch and GPT3, as well as Python installed on your system. We offer two installation options:
25 |
26 | ### All-in-one installation
27 | To use the all-in-one installation, follow these steps:
28 |
29 | Clone this repository to your local machine.
30 | ```bash
31 | $ git clone https://github.com/hunkim/es-gpt.git
32 | $ cd es-gpt
33 | ```
34 |
35 | Modify .env with your Elasticsearch and GPT3 credentials (see .env.sample), and adapt crawl_index.py to index your documents (a sketch follows the command below). Then crawl and index:
36 | ```bash
37 | $ make crawl
38 | ```
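If your documents are not Google Scholar papers, the same `ESGPT.index()` call can still be used. A minimal sketch (the documents below are made-up placeholders):

```python
# Hypothetical example of indexing your own documents; the docs below are placeholders.
from es_gpt import ESGPT

esgpt = ESGPT(index_name="papers")

my_docs = [
    {"id": "doc-1", "title": "An example title", "abstract": "An example abstract."},
    {"id": "doc-2", "title": "Another title", "abstract": "Another abstract."},
]

for d in my_docs:
    # Embeddings for the text are computed and stored alongside the document.
    esgpt.index(doc_id=d["id"], doc=d, text=d["title"] + " " + d["abstract"])
```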
39 |
40 | Then start the backend server:
41 | ```bash
42 | $ make app
43 | ```
44 |
45 | Then visit the backend server (http://localhost:7002 by default). The web page intercepts the Elasticsearch results and sends them to GPT3 to produce an answer. This method is fast at query time, because document embeddings are computed during indexing.
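The bundled web page talks to the backend via sse.js, but you can also exercise the `/summary` endpoint directly. A minimal sketch of a streaming client, assuming the server is running locally on port 7002 (the Makefile default):

```python
# Minimal sketch of a /summary client; assumes `make app` is serving on localhost:7002.
import json
import requests

# An empty text_results makes the server run its own Elasticsearch search for the query.
payload = {"q": "example query", "text_results": ""}

with requests.post("http://localhost:7002/summary", json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        # Each event is an OpenAI completion chunk; print the text as it arrives.
        print(json.loads(data)["choices"][0]["text"], end="", flush=True)
```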
46 |
47 | ### On-the-fly installation
48 | To use the on-the-fly installation, follow these steps:
49 |
50 | Add the following scripts to your search page. See `static/p.html`. Specify the query, results, and gpt_answer output div IDs in your original search page:
51 | ```html
52 |
53 |
54 |
59 | ```
60 | Modify the .env with your Elasticsearch and GPT3 credentials.
61 | Install the required dependencies and start the backend by running the following command in your terminal:
62 | ```
63 | $ make app
64 | ```
65 |
66 | Run your search web page and enter a query. The program intercepts the HTML results and sends them to GPT3 to produce an answer. This method is convenient but slower, because the search results and the query are embedded on the fly.
67 |
68 | ## Contributing
69 | We welcome contributions from the community! If you have ideas for how to improve the Elasticsearch + GPT3 Answerer, please open an issue or submit a pull request. We love hearing from fellow search enthusiasts!
70 |
71 | ## License
72 | This program is licensed under the MIT License
73 |
--------------------------------------------------------------------------------
/static/p.html:
--------------------------------------------------------------------------------
[Only the page text is preserved in this snapshot; the HTML markup and scripts are not.]
Title: Summarize GPT
Heading: Elasticsearch with GPT (Sung's paper search)
Sections: Search Results; GPT Answer ("poor result due to text-ada-001")
Footer: Visit https://github.com/hunkim/es-gpt for more information
--------------------------------------------------------------------------------
/static/sse.js:
--------------------------------------------------------------------------------
1 | /**
2 |  * Copyright (C) 2016 Maxime Petazzoni.
3 | * All rights reserved.
4 | */
5 |
6 | var SSE = function (url, options) {
7 | if (!(this instanceof SSE)) {
8 | return new SSE(url, options);
9 | }
10 |
11 | this.INITIALIZING = -1;
12 | this.CONNECTING = 0;
13 | this.OPEN = 1;
14 | this.CLOSED = 2;
15 |
16 | this.url = url;
17 |
18 | options = options || {};
19 | this.headers = options.headers || {};
20 | this.payload = options.payload !== undefined ? options.payload : '';
21 | this.method = options.method || (this.payload && 'POST' || 'GET');
22 | this.withCredentials = !!options.withCredentials;
23 |
24 | this.FIELD_SEPARATOR = ':';
25 | this.listeners = {};
26 |
27 | this.xhr = null;
28 | this.readyState = this.INITIALIZING;
29 | this.progress = 0;
30 | this.chunk = '';
31 |
32 | this.addEventListener = function(type, listener) {
33 | if (this.listeners[type] === undefined) {
34 | this.listeners[type] = [];
35 | }
36 |
37 | if (this.listeners[type].indexOf(listener) === -1) {
38 | this.listeners[type].push(listener);
39 | }
40 | };
41 |
42 | this.removeEventListener = function(type, listener) {
43 | if (this.listeners[type] === undefined) {
44 | return;
45 | }
46 |
47 | var filtered = [];
48 | this.listeners[type].forEach(function(element) {
49 | if (element !== listener) {
50 | filtered.push(element);
51 | }
52 | });
53 | if (filtered.length === 0) {
54 | delete this.listeners[type];
55 | } else {
56 | this.listeners[type] = filtered;
57 | }
58 | };
59 |
60 | this.dispatchEvent = function(e) {
61 | if (!e) {
62 | return true;
63 | }
64 |
65 | e.source = this;
66 |
67 | var onHandler = 'on' + e.type;
68 | if (this.hasOwnProperty(onHandler)) {
69 | this[onHandler].call(this, e);
70 | if (e.defaultPrevented) {
71 | return false;
72 | }
73 | }
74 |
75 | if (this.listeners[e.type]) {
76 | return this.listeners[e.type].every(function(callback) {
77 | callback(e);
78 | return !e.defaultPrevented;
79 | });
80 | }
81 |
82 | return true;
83 | };
84 |
85 | this._setReadyState = function(state) {
86 | var event = new CustomEvent('readystatechange');
87 | event.readyState = state;
88 | this.readyState = state;
89 | this.dispatchEvent(event);
90 | };
91 |
92 | this._onStreamFailure = function(e) {
93 | var event = new CustomEvent('error');
94 | event.data = e.currentTarget.response;
95 | this.dispatchEvent(event);
96 | this.close();
97 | }
98 |
99 | this._onStreamAbort = function(e) {
100 | this.dispatchEvent(new CustomEvent('abort'));
101 | this.close();
102 | }
103 |
104 | this._onStreamProgress = function(e) {
105 | if (!this.xhr) {
106 | return;
107 | }
108 |
109 | if (this.xhr.status !== 200) {
110 | this._onStreamFailure(e);
111 | return;
112 | }
113 |
114 | if (this.readyState == this.CONNECTING) {
115 | this.dispatchEvent(new CustomEvent('open'));
116 | this._setReadyState(this.OPEN);
117 | }
118 |
119 | var data = this.xhr.responseText.substring(this.progress);
120 | this.progress += data.length;
121 | data.split(/(\r\n|\r|\n){2}/g).forEach(function(part) {
122 | if (part.trim().length === 0) {
123 | this.dispatchEvent(this._parseEventChunk(this.chunk.trim()));
124 | this.chunk = '';
125 | } else {
126 | this.chunk += part;
127 | }
128 | }.bind(this));
129 | };
130 |
131 | this._onStreamLoaded = function(e) {
132 | this._onStreamProgress(e);
133 |
134 | // Parse the last chunk.
135 | this.dispatchEvent(this._parseEventChunk(this.chunk));
136 | this.chunk = '';
137 | };
138 |
139 | /**
140 | * Parse a received SSE event chunk into a constructed event object.
141 | */
142 | this._parseEventChunk = function(chunk) {
143 | if (!chunk || chunk.length === 0) {
144 | return null;
145 | }
146 |
147 | var e = {'id': null, 'retry': null, 'data': '', 'event': 'message'};
148 | chunk.split(/\n|\r\n|\r/).forEach(function(line) {
149 | line = line.trimRight();
150 | var index = line.indexOf(this.FIELD_SEPARATOR);
151 | if (index <= 0) {
152 | // Line was either empty, or started with a separator and is a comment.
153 | // Either way, ignore.
154 | return;
155 | }
156 |
157 | var field = line.substring(0, index);
158 | if (!(field in e)) {
159 | return;
160 | }
161 |
162 | var value = line.substring(index + 1).trimLeft();
163 | if (field === 'data') {
164 | e[field] += value;
165 | } else {
166 | e[field] = value;
167 | }
168 | }.bind(this));
169 |
170 | var event = new CustomEvent(e.event);
171 | event.data = e.data;
172 | event.id = e.id;
173 | return event;
174 | };
175 |
176 | this._checkStreamClosed = function() {
177 | if (!this.xhr) {
178 | return;
179 | }
180 |
181 | if (this.xhr.readyState === XMLHttpRequest.DONE) {
182 | this._setReadyState(this.CLOSED);
183 | }
184 | };
185 |
186 | this.stream = function() {
187 | this._setReadyState(this.CONNECTING);
188 |
189 | this.xhr = new XMLHttpRequest();
190 | this.xhr.addEventListener('progress', this._onStreamProgress.bind(this));
191 | this.xhr.addEventListener('load', this._onStreamLoaded.bind(this));
192 | this.xhr.addEventListener('readystatechange', this._checkStreamClosed.bind(this));
193 | this.xhr.addEventListener('error', this._onStreamFailure.bind(this));
194 | this.xhr.addEventListener('abort', this._onStreamAbort.bind(this));
195 | this.xhr.open(this.method, this.url);
196 | for (var header in this.headers) {
197 | this.xhr.setRequestHeader(header, this.headers[header]);
198 | }
199 | this.xhr.withCredentials = this.withCredentials;
200 | this.xhr.send(this.payload);
201 | };
202 |
203 | this.close = function() {
204 | if (this.readyState === this.CLOSED) {
205 | return;
206 | }
207 |
208 | this.xhr.abort();
209 | this.xhr = null;
210 | this._setReadyState(this.CLOSED);
211 | };
212 | };
213 |
214 | // Export our SSE module for npm.js
215 | if (typeof exports !== 'undefined') {
216 | exports.SSE = SSE;
217 | }
218 |
--------------------------------------------------------------------------------
/es_gpt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import requests
4 | import re
5 | import pandas as pd
6 | import string
7 | from elasticsearch import Elasticsearch
8 |
9 | import tiktoken
10 | import openai
11 | from openai.embeddings_utils import distances_from_embeddings
12 | from emb import get_embedding
13 |
14 |
15 | ES_URL = os.environ["ES_URL"]
16 | ES_USER = os.environ["ES_USER"]
17 | ES_PASS = os.environ["ES_PASS"]
18 | ES_CA_CERT = os.environ["ES_CA_CERT"]
19 |
20 |
21 | class ESGPT:
22 | def __init__(self, index_name):
23 | self.es = Elasticsearch(ES_URL, basic_auth=(ES_USER, ES_PASS),
24 | ca_certs=ES_CA_CERT, verify_certs=True)
25 | self.index_name = index_name
26 |
27 | # FIXME: remove .strip()
28 | self.model_engine = os.environ["OPENAI_GPT_ENGINE"].strip()
29 | self.model_max_tokens = int(os.environ["OPENAI_GPT_MAX_TOKENS"])
30 | self.api_key = os.environ["OPENAI_API_KEY"]
31 | openai.api_key = self.api_key
32 | self.max_tokens = 1000
33 | self.split_max_tokens = 500
34 |
35 | # Load the cl100k_base tokenizer which is designed to work with the ada-002 model
36 | self.tokenizer = tiktoken.get_encoding("cl100k_base")
37 |
38 | self.answer_generation_prompt = "Based on the context below\"\n\nContext: {}\n\n---\n\nPlease provide concise answer for this questions: {}"
39 | self.question_suggestion_prompt = "Based on the context below\"\n\nContext: {}\n\n---\n\nPlease recommend 3 more questions to be curious about {}"
40 | self.just_question_prompt = "{}{}"
41 |
42 | def index(self, doc_id, doc, text):
43 | doc["embeddings_dict_list"] = self._create_emb_dict_list(text)
44 | self.es.index(index=self.index_name,
45 | id=doc_id,
46 | document=doc)
47 |
48 | def search(self, query):
49 | es_query = {
50 | "query_string": {"query": query}
51 | }
52 |
53 | results = self.es.search(index=self.index_name, query=es_query)
54 | return results['hits']['hits']
55 |
56 | def _paper_results_to_text(self, results):
57 | text_result = ""
58 | for paper in results:
59 | title = paper["_source"].get("title", "")
60 | abstract = paper["_source"].get("abstract", "")
61 | paper_str = f"{title}:\n{abstract}\n\n"
62 | text_result += paper_str
63 | return text_result
64 |
65 | # Code from https://github.com/openai/openai-cookbook/blob/main/apps/web-crawl-q-and-a/web-qa.py
66 | # Function to split the text into chunks of a maximum number of tokens
67 | def _split_into_many(self, text):
68 | sentences = []
69 | for sentence in re.split(r'[{}]'.format(string.punctuation), text):
70 | sentence = sentence.strip()
71 | if sentence and (any(char.isalpha() for char in sentence) or any(char.isdigit() for char in sentence)):
72 | sentences.append(sentence)
73 |
74 | n_tokens = [len(self.tokenizer.encode(" " + sentence))
75 | for sentence in sentences]
76 |
77 | chunks = []
78 | tokens_so_far = 0
79 | chunk = []
80 |
81 | # Loop through the sentences and tokens joined together in a tuple
82 | for sentence, token in zip(sentences, n_tokens):
83 | # If the number of tokens so far plus the number of tokens in the current sentence is greater
84 | # than the max number of tokens, then add the chunk to the list of chunks and reset
85 | # the chunk and tokens so far
86 | if tokens_so_far + token > self.split_max_tokens and chunk:
87 | chunks.append(". ".join(chunk) + ".")
88 | chunk = []
89 | tokens_so_far = 0
90 |
91 | # If the number of tokens in the current sentence is greater than the max number of
92 | # tokens, go to the next sentence
93 | if token > self.split_max_tokens:
94 | continue
95 |
96 | # Otherwise, add the sentence to the chunk and add the number of tokens to the total
97 | chunk.append(sentence)
98 | tokens_so_far += token + 1
99 |
100 | # Add the last chunk to the list of chunks
101 | if chunk:
102 | chunks.append(". ".join(chunk) + ".")
103 |
104 | return chunks
105 |
106 |
107 | def _create_emb_dict_list(self, long_text):
108 | shortened = self._split_into_many(long_text)
109 |
110 | embeddings_dict_list = []
111 |
112 | for text in shortened:
113 | n_tokens = len(self.tokenizer.encode(text))
114 | embeddings = get_embedding(input=text)
115 | embeddings_dict = {}
116 | embeddings_dict["text"] = text
117 | embeddings_dict["n_tokens"] = n_tokens
118 | embeddings_dict["embeddings"] = embeddings
119 | embeddings_dict_list.append(embeddings_dict)
120 |
121 | return embeddings_dict_list
122 |
123 | def _create_context(self, question, df):
124 | """
125 | Create a context for a question by finding the most similar context from the dataframe
126 | """
127 |
128 | # Get the embeddings for the question
129 | q_embeddings = get_embedding(input=question)
130 |
131 | # Get the distances from the embeddings
132 | df['distances'] = distances_from_embeddings(
133 | q_embeddings, df['embeddings'].values, distance_metric='cosine')
134 |
135 | returns = []
136 | cur_len = 0
137 |
138 | # Sort by distance and add the text to the context until the context is too long
139 | for i, row in df.sort_values('distances', ascending=True).iterrows():
140 | # Add the length of the text to the current length
141 | cur_len += row['n_tokens'] + 4
142 |
143 | # If the context is too long, break
144 | if cur_len > self.max_tokens:
145 | break
146 |
147 | # Else add it to the text that is being returned
148 | returns.append(row["text"])
149 |
150 | # Return the context and the length of the context
151 | return "\n\n###\n\n".join(returns), cur_len
152 |
153 | def _gpt_api_call(self, query, input_token_len, context, call_type):
154 | if call_type == "answer":
155 | prompt = self.answer_generation_prompt
156 | elif call_type == "question":
157 | prompt = self.just_question_prompt
158 | else:
159 | prompt = self.question_suggestion_prompt
160 |
161 | body = {
162 | "model": self.model_engine,
163 | "prompt": prompt.format(context, query),
164 | "max_tokens": self.model_max_tokens - input_token_len,
165 | "n": 1,
166 | "temperature": 0.5,
167 | "stream": True,
168 | }
169 |
170 | headers = {"Content-Type": "application/json",
171 | "Authorization": f"Bearer {self.api_key}"}
172 |
173 | resp = requests.post("https://api.openai.com/v1/completions",
174 | headers=headers,
175 | data=json.dumps(body),
176 | stream=True)
177 | return resp
178 |
179 |
180 | def gpt_answer(self, query, es_results=None, text_results=None):
181 | # Generate summaries for each search result
182 | if text_results:
183 | input_token_len = len(self.tokenizer.encode(text_results))
184 | if input_token_len < self.max_tokens:
185 | context = text_results
186 | else:
187 | emb_dict_list = self._create_emb_dict_list(text_results)
188 | df = pd.DataFrame(columns=["text", "n_tokens", "embeddings"])
189 | for emb_dict in emb_dict_list:
190 |                     # DataFrame.append was removed in pandas 2.x; use pd.concat instead
191 |                     df = pd.concat([df, pd.DataFrame([emb_dict])], ignore_index=True)
191 |
192 | context, input_token_len = self._create_context(
193 | question=query,
194 | df=df)
195 | elif es_results:
196 | result_json_str = self._paper_results_to_text(es_results)
197 | if not result_json_str:
198 | result_json_str = "No results found"
199 |
200 | input_token_len = len(self.tokenizer.encode(result_json_str))
201 | if input_token_len < self.max_tokens:
202 | context = result_json_str
203 | else:
204 | # Create a pandas DataFrame from the list of embeddings dictionaries
205 | df = pd.DataFrame(columns=["text", "n_tokens", "embeddings"])
206 |
207 | # extract embeddings_dict from es_results and append to the dataframe
208 | for hit in es_results:
209 | embeddings_dict_list = hit['_source']['embeddings_dict_list']
210 | for embeddings_dict in embeddings_dict_list:
211 |                         df = pd.concat([df, pd.DataFrame([embeddings_dict])], ignore_index=True)
212 |
213 | context, input_token_len = self._create_context(
214 | question=query,
215 | df=df)
216 | else:
217 | assert False, "Must provide either es_results or text_results"
218 |
219 | return self._gpt_api_call(query, input_token_len, context, call_type="answer")
220 |
221 | def gpt_question_generator(self, text_results=None):
222 | if text_results:
223 | input_token_len = len(self.tokenizer.encode(text_results))
224 | if input_token_len < self.max_tokens:
225 | context = text_results
226 | else:
227 | context = text_results[:self.max_tokens]
228 | input_token_len = self.max_tokens
229 | else:
230 | assert False, "Text results are not found"
231 |
232 | return self._gpt_api_call("", input_token_len, context, call_type="suggestion")
233 |
234 | def gpt_direct_answer(self, q):
235 | input_token_len = len(self.tokenizer.encode(q))
236 | if input_token_len < self.max_tokens:
237 | query = q
238 | else:
239 | query = q[:self.max_tokens]
240 | input_token_len = self.max_tokens
241 | return self._gpt_api_call(q, input_token_len, "", call_type="question")
242 |
243 |
244 |
245 |
246 |
247 | # Example usage
248 | if __name__ == "__main__":
249 | esgpt = ESGPT("papers")
250 |     query = "How to fix these bugs?"
251 | res = esgpt.search(query=query)
252 | res_str = esgpt._paper_results_to_text(res)
253 |
254 | # Pass ES results with precomputed embeddings
255 | res = esgpt.gpt_answer(query=query, es_results=res)
256 | print(res.text)
257 |
258 | # Pass text results and do embeddings on the fly
259 | # Note: This will be slower
260 | res = esgpt.gpt_answer(query=query, text_results=res_str)
261 | print(res.text)
262 |
--------------------------------------------------------------------------------
/static/index.html:
--------------------------------------------------------------------------------
[Only the page text is preserved in this snapshot; the HTML markup and scripts are not.]
Title: Elastic search GPT
Heading: Elasticsearch with GPT (Sung's paper search)
Section: Search results
--------------------------------------------------------------------------------