4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Ethan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BriefGPT
2 |
3 | BriefGPT is a locally run tool for summarizing and querying documents with OpenAI's models. Your documents, embeddings, and API keys stay on your machine, although document text is still sent to the OpenAI API for embedding and answering.
4 |
5 | ## Update
6 | Added support for fully local use! Documents are embedded with Instructor, and the LLM can be either LlamaCpp or GPT4All (ggml format). Put your model in the 'models' folder, set the `MODEL_TYPE` (`LlamaCpp` or `GPT4All`) and `MODEL_PATH` environment variables in `test.env`, and run `streamlit run local_app.py` to get started. Tested with the following models: [Llama](https://huggingface.co/eachadea/ggml-vicuna-13b-1.1/blob/main/ggml-vic13b-q5_1.bin), [GPT4ALL](https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin).
7 |
8 | Please note that this is experimental: it is significantly slower, and output quality may vary. PRs welcome!
9 |
10 | # Example (using the "Sparks of AGI" paper, sped up)
11 | 
12 |
13 |
14 |
15 |
16 | # Setup
17 | 1. Clone the repository.
18 | 2. Install the requirements:
19 | `pip install -r requirements.txt`
20 | 3. Set your OpenAI API key (`OPENAI_API_KEY`) in `test.env`.
21 | 4. Navigate to the project directory and run:
22 | `streamlit run main.py`
23 | 5. Add your PDFs or .txt files to the `documents` folder in the project directory.
24 | 6. If you use EPUB files, make sure pandoc is installed and on your PATH.
25 |
26 |
27 |
28 |
29 | # How it works
30 | ## Chat
31 | 1. Creating and saving embeddings - once you load a file, it is split into chunks of about 500 characters (with a 50-character overlap), embedded, and stored as a FAISS index in the 'embeddings' folder. The saved index is reused the next time you load the same document into the chat.
32 | 2. Retrieving, ranking, and processing results - a similarity search on the index returns the top 10 chunks. These are then re-ranked by a function that strips stopwords from the query and uses fuzzy string matching to score how closely each chunk's wording matches the remaining query words; the top 5 are kept. This tends to give better results than similarity search alone.
33 | 3. Output - the re-ranked chunks and the user's question are passed to the LLM, and the response is displayed along with its sources.
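The chunking in step 1 can be illustrated with a simplified character splitter. This is a sketch, not the project's code: `chat_utils.py` uses LangChain's `RecursiveCharacterTextSplitter` with `chunk_size=500` and `chunk_overlap=50`, while the hypothetical `split_with_overlap` below just slides a fixed-size window so that consecutive chunks share `chunk_overlap` characters.

```python
from typing import List


def split_with_overlap(text: str, chunk_size: int = 500, chunk_overlap: int = 50) -> List[str]:
    # Hypothetical helper: fixed-size character windows. Each new chunk starts
    # chunk_overlap characters before the previous one ended, so a sentence cut
    # at a boundary still appears whole in at least one of the two chunks.
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


if __name__ == "__main__":
    sample = "".join(str(i % 10) for i in range(1200))
    print([len(c) for c in split_with_overlap(sample)])  # [500, 500, 300]
```

LangChain's recursive splitter additionally tries to break on paragraph, line, and word boundaries before falling back to raw characters, so its chunks are usually somewhat shorter than `chunk_size` and rarely cut words in half.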
34 |
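The re-ranking in step 2 can be sketched in plain Python. This is a rough stand-in under stated assumptions, not the project's implementation: difflib's `SequenceMatcher` replaces fuzzywuzzy's `fuzz.partial_ratio`, a small hand-written stopword set replaces NLTK's English list, and the names `filter_stopwords`, `partial_score`, and `rerank` are illustrative only.

```python
from difflib import SequenceMatcher
from typing import List

# Tiny illustrative stopword set; the project uses NLTK's English list instead.
STOPWORDS = {"a", "an", "the", "of", "in", "on", "to", "and", "for",
             "is", "are", "what", "who", "does", "do", "how"}


def filter_stopwords(question: str) -> str:
    # Lower-case each word, trim surrounding punctuation, and drop stopwords.
    words = [w.strip("?.,!;:").lower() for w in question.split()]
    return " ".join(w for w in words if w and w not in STOPWORDS)


def partial_score(query: str, text: str) -> int:
    # Best 0-100 similarity between the shorter string and any equally long
    # window of the longer one: a rough stand-in for fuzz.partial_ratio.
    query, text = query.lower(), text.lower()
    if not query or not text:
        return 0
    if len(query) > len(text):
        query, text = text, query
    n = len(query)
    best = max(SequenceMatcher(None, query, text[i:i + n]).ratio()
               for i in range(len(text) - n + 1))
    return round(best * 100)


def rerank(question: str, chunks: List[str], num_results: int = 5) -> List[str]:
    # chunks are assumed to arrive in similarity-search order (best first).
    filtered = filter_stopwords(question)
    if not filtered:
        # Nothing left to match on: keep the similarity-search ranking.
        return chunks[:num_results]
    # sorted() is stable, so tied chunks keep their similarity-search order.
    return sorted(chunks, key=lambda c: partial_score(filtered, c), reverse=True)[:num_results]


if __name__ == "__main__":
    docs = ["The weather today is sunny.",
            "GPT-4 shows sparks of general intelligence.",
            "Cooking pasta requires boiling water."]
    print(rerank("What does GPT-4 show?", docs, 1))
```

Matching on the stopword-free query keeps short function words such as "what" or "the" from inflating the score of chunks that share nothing substantive with the question.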
35 |
36 |
37 |
38 | ## Summarization
39 | 1. Input - both documents and YouTube URLs are supported; for a URL, the video's transcript is fetched and summarized.
40 | 2. Processing and embedding - before embedding, documents are stripped of any special tokens that might cause errors. Documents are embedded in chunks whose size depends on the overall document's size.
41 | 3. Clustering - once the documents are embedded, the embeddings are grouped into clusters with the K-means algorithm. The number of clusters can be fixed (10) or chosen automatically with the elbow method. The embedding closest to each cluster centroid is then retrieved. The idea, at least, is that each cluster represents a different theme or idea, and the retrieved embedding is the chunk that best captures it.
42 | 4. Summarization - summarization is performed in two steps. First, each retrieved embedding is matched with its corresponding text chunk, and each chunk is summarized by GPT-3.5 in a separate API call; these calls run in parallel. The resulting chunk summaries are then passed to GPT-3.5 or GPT-4 to produce the final summary.
43 | 5. Output - the summary is displayed on the page and saved as a text file in the 'summaries' folder.
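The clustering in step 3 can be sketched with plain NumPy. This is an illustration under assumptions, not necessarily how `summary_utils.py` does it: it uses a minimal Lloyd's k-means with deterministic farthest-point initialisation instead of a library `KMeans`, and it takes the elbow as the k with the largest second difference in inertia, which is only one of several elbow heuristics. All function names are hypothetical.

```python
import numpy as np


def _nearest(points: np.ndarray, centroids: np.ndarray) -> np.ndarray:
    # Index of the closest centroid for every point (Euclidean distance).
    dists = np.linalg.norm(points[:, None, :] - centroids[None, :, :], axis=2)
    return dists.argmin(axis=1)


def kmeans(points: np.ndarray, k: int, iters: int = 100):
    # Farthest-point initialisation (a deterministic simplification of k-means++):
    # start at the first point, then keep adding the point farthest from every
    # centroid chosen so far.
    centroids = [points[0]]
    for _ in range(1, k):
        gaps = np.linalg.norm(points[:, None, :] - np.array(centroids)[None, :, :], axis=2).min(axis=1)
        centroids.append(points[int(gaps.argmax())])
    centroids = np.array(centroids, dtype=float)

    # Lloyd's iterations: assign points, then move each centroid to its cluster mean.
    for _ in range(iters):
        labels = _nearest(points, centroids)
        updated = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
            for j in range(k)
        ])
        if np.allclose(updated, centroids):
            break
        centroids = updated

    labels = _nearest(points, centroids)
    # Inertia: total squared distance from each point to its centroid.
    inertia = float(((points - centroids[labels]) ** 2).sum())
    return centroids, labels, inertia


def representative_indices(embeddings: np.ndarray, k: int) -> list:
    # For each centroid, take the chunk whose embedding lies closest to it,
    # then return those chunk indices in document order.
    centroids, _, _ = kmeans(embeddings, k)
    dists = np.linalg.norm(embeddings[:, None, :] - centroids[None, :, :], axis=2)
    return sorted({int(i) for i in dists.argmin(axis=0)})


def elbow_k(embeddings: np.ndarray, max_k: int) -> int:
    # Inertia keeps falling as k grows; the "elbow" is taken here as the k
    # where the curve bends most sharply (largest second difference).
    inertias = [kmeans(embeddings, k)[2] for k in range(1, max_k + 1)]
    if len(inertias) < 3:
        return len(inertias)
    bends = [inertias[i - 1] - 2 * inertias[i] + inertias[i + 1]
             for i in range(1, len(inertias) - 1)]
    return int(np.argmax(bends)) + 2  # bends[0] corresponds to k = 2


if __name__ == "__main__":
    # Three well-separated 2-D blobs standing in for chunk embeddings.
    blobs = np.array([[0, 0], [0, 1], [1, 0],
                      [10, 10], [10, 11], [11, 10],
                      [20, 0], [20, 1], [21, 0]], dtype=float)
    print(elbow_k(blobs, 5))                  # 3
    print(representative_indices(blobs, 3))   # [0, 3, 6]
```

Returning the representative indices in document order is a design choice of this sketch; it lets the per-chunk summaries be combined roughly in the order the ideas appear in the source.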
44 | 
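The two-step summarization in step 4 follows a map-reduce pattern. A minimal sketch, assuming the model calls are wrapped in two plain functions: the hypothetical `map_fn` and `combine_fn`, which in this project would presumably apply the `file_map` and `file_combine` prompts from `my_prompts.py`. A thread pool stands in for however the app actually parallelises its API calls.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def map_reduce_summary(chunks: List[str],
                       map_fn: Callable[[str], str],
                       combine_fn: Callable[[str], str],
                       max_workers: int = 8) -> str:
    # Map step: summarize each chunk independently. The calls are I/O-bound
    # (network requests), so threads let them run concurrently; pool.map
    # returns results in the same order as the input chunks.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        partial_summaries = list(pool.map(map_fn, chunks))
    # Reduce step: one final call over all partial summaries joined together.
    return combine_fn("\n\n".join(partial_summaries))


if __name__ == "__main__":
    # Hypothetical stand-ins for the real model calls.
    print(map_reduce_summary(["alpha", "beta", "gamma"],
                             lambda s: s[:3].upper(),
                             lambda s: s.replace("\n\n", " | ")))  # ALP | BET | GAM
```

Because only the final call sees every partial summary, a stronger (slower) model can be reserved for the reduce step while the cheaper model handles the many map calls.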
45 |
46 |
47 |
48 | Improved support for locally run LLMs is coming.
49 |
50 | Built using LangChain! This project was made for fun and is likely full of bugs. It is not fully optimized. Contributions and bug reports are welcome!
51 |
52 | todo: keep summary in session state, save transcripts when loaded to summarize
53 |
--------------------------------------------------------------------------------
/chat_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from langchain.text_splitter import RecursiveCharacterTextSplitter
4 | from langchain.embeddings import OpenAIEmbeddings
5 | from langchain.vectorstores import FAISS
6 | from langchain.llms import OpenAI
7 |
8 |
9 | from fuzzywuzzy import fuzz
10 |
11 | import nltk
12 | from nltk.corpus import stopwords
13 | from nltk.tokenize import word_tokenize
14 |
15 | from my_prompts import chat_prompt, hypothetical_prompt
16 |
17 | from dotenv import load_dotenv
18 |
19 | from summary_utils import doc_loader, remove_special_tokens, directory_loader
20 |
21 |
22 | nltk.download('stopwords')
23 | nltk.download('punkt')
24 |
25 |
26 | def create_and_save_directory_embeddings(directory_path, name):
27 |     embeddings = OpenAIEmbeddings()
28 |     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
29 |     docs = directory_loader(directory_path)
30 |     split_docs = splitter.split_documents(docs)
31 |     processed_split_docs = remove_special_tokens(split_docs)
32 |     db = FAISS.from_documents(processed_split_docs, embeddings)
33 |     db.save_local(folder_path='directory_embeddings', index_name=name)
34 |     return db
35 |
36 | def create_and_save_chat_embeddings(file_path):
37 |     name = os.path.split(file_path)[1].split('.')[0]  # index name: file name up to the first '.'
38 |     embeddings = OpenAIEmbeddings()
39 |     doc = doc_loader(file_path)
40 |     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
41 |     split_docs = splitter.split_documents(doc)
42 |     processed_split_docs = remove_special_tokens(split_docs)
43 |     db = FAISS.from_documents(processed_split_docs, embeddings)
44 |     db.save_local(folder_path='embeddings', index_name=name)
45 |
46 |
47 | def load_chat_embeddings(file_path):
48 |     name = os.path.split(file_path)[1].split('.')[0]
49 |     embeddings = OpenAIEmbeddings()
50 |     db = FAISS.load_local(folder_path='embeddings', index_name=name, embeddings=embeddings)
51 |     return db
52 |
53 |
54 |
55 |
56 | def results_from_db(db: FAISS, question, num_results=10):
57 |     results = db.similarity_search(question, k=num_results)
58 |     return results
59 |
60 |
61 | def rerank_fuzzy_matching(question, results, num_results=5):
62 |     filtered_question = filter_stopwords(question)
63 |     if filtered_question == '':
64 |         return results[:num_results]  # nothing left to match on: keep similarity-search order
65 |     scores_and_results = []
66 |     for result in results:
67 |         score = fuzz.partial_ratio(filtered_question, result.page_content)
68 |         scores_and_results.append((score, result))
69 |
70 |     scores_and_results.sort(key=lambda x: x[0], reverse=True)
71 |     reranked = [result for score, result in scores_and_results]
72 |
73 |     return reranked[:num_results]
74 |
75 |
76 | def filter_stopwords(question):
77 |     stop_words = set(stopwords.words('english'))  # build the lookup set once, not per word
78 |     words = word_tokenize(question)
79 |     filtered_words = [word for word in words if word.lower() not in stop_words]
80 |     return ' '.join(filtered_words)
81 |
82 |
83 | def qa_from_db(question, db, llm_name, hypothetical):
84 |     llm = create_llm(llm_name)
85 |     if hypothetical:
86 |         hypothetical_llm = create_llm(llm_name)
87 |         hypothetical_answer = hypothetical_document_embeddings(question, hypothetical_llm)
88 |         results = results_from_db(db, hypothetical_answer)  # search with the hypothetical answer (HyDE)
89 |     else:
90 |         results = results_from_db(db, question)
91 |     reranked_results = rerank_fuzzy_matching(question, results)
92 |     reranked_content = [result.page_content for result in reranked_results]
93 |
94 |     if not isinstance(llm_name, str):  # a model object (e.g. a local LLM): short prompt, top 2 chunks only
95 |         message = f'Answer the user question based on the context. Question: {question} Context: {reranked_content[:2]} Answer:'
96 |     else:
97 |         message = f'{chat_prompt} ---------- Context: {reranked_content} -------- User Question: {question} ---------- Response:'
98 |     formatted_sources = source_formatter(reranked_results)
99 |     output = llm(message)
100 |     return output, formatted_sources
101 |
102 |
103 |
104 | def source_formatter(sources):
105 |     formatted_strings = []
106 |     for doc in sources:
107 |         source_name = os.path.basename(doc.metadata['source'].replace('\\', '/'))  # handles Windows and POSIX paths
108 |         source_content = doc.page_content.replace('\n', ' ')  # replace newlines with spaces
109 |         formatted_string = f"Source name: {source_name} | Source content: '{source_content}' - end of content"
110 |         formatted_strings.append(formatted_string)
111 |     final_string = '\n\n\n'.join(formatted_strings)
112 |     return final_string
113 |
114 | def create_llm(llm_name):
115 |     if not isinstance(llm_name, str):
116 |         return llm_name  # already a model object
117 |     else:
118 |         llm = OpenAI(model_name=llm_name)
119 |         return llm
120 |
121 | def hypothetical_document_embeddings(question, llm):
122 |     message = f'{hypothetical_prompt} {question} :'
123 |     output = llm(message)
124 |     print("output: ", output)
125 |     return output
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
--------------------------------------------------------------------------------
/directory_embeddings/testembeddings.faiss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-johnstonn/BriefGPT/6bd265ed7a678de1c7124350d04801b3857bc190/directory_embeddings/testembeddings.faiss
--------------------------------------------------------------------------------
/documents/sparksofagi.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-johnstonn/BriefGPT/6bd265ed7a678de1c7124350d04801b3857bc190/documents/sparksofagi.pdf
--------------------------------------------------------------------------------
/embeddings/sparksofagi.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-johnstonn/BriefGPT/6bd265ed7a678de1c7124350d04801b3857bc190/embeddings/sparksofagi.pkl
--------------------------------------------------------------------------------
/local_app.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import streamlit as st
4 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
5 |
6 | from streamlit_chat import message as st_message
7 |
8 | import pandas as pd
9 |
10 | from local_chat_utils import load_db_from_file_and_create_if_not_exists_local
11 | from streamlit_app_utils import generate_answer
12 |
13 | from langchain.llms import GPT4All, LlamaCpp
14 |
15 | from dotenv import load_dotenv
16 |
17 |
18 |
19 | load_dotenv('test.env')
20 |
21 | model_type = os.getenv('MODEL_TYPE', '')  # 'LlamaCpp' or 'GPT4All' (case-insensitive); default avoids None.lower()
22 | model_path = os.getenv('MODEL_PATH')
23 |
24 |
25 | accepted_filetypes = ['.txt', '.pdf', '.epub']
26 |
27 | # The model is initialized here. Configure it with your parameters and the path to your model.
28 |
29 |
30 | with st.spinner('Initializing LLM...'):
31 |     if 'llm' not in st.session_state:  # load the model only once per session
32 |         with st.spinner('Loading LLM...'):
33 |             if model_type.lower() == 'llamacpp':
34 |                 llm = LlamaCpp(model_path=model_path, n_ctx=1000)
35 |                 st.session_state.llm = llm
36 |             elif model_type.lower() == 'gpt4all':
37 |                 llm = GPT4All(model=model_path, backend='gptj', n_ctx=1000)
38 |                 st.session_state.llm = llm
39 |             else:
40 |                 st.warning('Invalid model type. GPT4All or LlamaCpp supported - make sure you specify it in your env file.')
41 |
42 |
43 | def chat():
44 |     st.title('Chat')
45 |     if 'text_input' not in st.session_state:
46 |         st.session_state.text_input = ''
47 |     directory = 'documents'
48 |     files = os.listdir(directory)
49 |     files = [file for file in files if file.endswith(tuple(accepted_filetypes))]
50 |     selected_file = st.selectbox('Select a file', files)
51 |     st.write('You selected: ' + selected_file)
52 |     selected_file_path = os.path.join(directory, selected_file)
53 |
54 |     if st.button('Load file (first time might take a second...) pressing this button will reset the chat history'):
55 |         db = load_db_from_file_and_create_if_not_exists_local(selected_file_path, 'hkunlp/instructor-base')
56 |         st.session_state.db = db
57 |         st.session_state.history = []
58 |
59 |     user_input = st.text_input('Enter your question', key='text_input')
60 |
61 |     if st.button('Ask') and 'db' in st.session_state:
62 |         answer = generate_answer(st.session_state.db, st.session_state.llm)
63 |
64 |
65 |     if 'history' not in st.session_state:
66 |         st.session_state.history = []
67 |     for i, chat in enumerate(st.session_state.history):
68 |         st_message(**chat, key=str(i))
69 |
70 |
71 |
72 |
73 | def documents():
74 |     st.title('Documents')
75 |     st.markdown('Documents are stored in the documents folder in the project directory.')
76 |     directory = 'documents'
77 |     files = os.listdir(directory)
78 |     files = [file for file in files if file.endswith(tuple(accepted_filetypes))]
79 |     if files:
80 |         files_df = pd.DataFrame(files, columns=['File Name'], index=range(1, len(files) + 1))
81 |         st.dataframe(files_df, width=1000)
82 |     else:
83 |         st.write('No documents found in documents folder. Add some documents first!')
84 |
85 |
86 |
87 | PAGES = {
88 | "Chat": chat,
89 | "Documents": documents,
90 | }
91 |
92 | st.sidebar.title("Navigation")
93 | selection = st.sidebar.radio("Go to", list(PAGES.keys()))
94 | st.sidebar.markdown(' [Contact author](mailto:ethanujohnston@gmail.com)')
95 | st.sidebar.markdown(' [Github](https://github.com/e-johnstonn/BriefGPT)')
96 | page = PAGES[selection]
97 | page()
--------------------------------------------------------------------------------
/local_chat_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from langchain import FAISS
4 | from langchain.embeddings import HuggingFaceInstructEmbeddings
5 | from langchain.text_splitter import RecursiveCharacterTextSplitter
6 |
7 | from summary_utils import doc_loader, remove_special_tokens
8 |
9 | import streamlit as st
10 |
11 |
12 | def create_and_save_local(file_path, model_name):
13 |     name = os.path.split(file_path)[1].split('.')[0]
14 |     embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
15 |     doc = doc_loader(file_path)
16 |     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
17 |     split_docs = splitter.split_documents(doc)
18 |     processed_split_docs = remove_special_tokens(split_docs)
19 |     db = FAISS.from_documents(processed_split_docs, embeddings)
20 |     db.save_local(folder_path='local_embeddings', index_name=name)
21 |
22 |
23 | def load_local_embeddings(file_path, model_name):
24 |     name = os.path.split(file_path)[1].split('.')[0]
25 |     embeddings = HuggingFaceInstructEmbeddings(model_name=model_name)
26 |     db = FAISS.load_local(folder_path='local_embeddings', index_name=name, embeddings=embeddings)
27 |     return db
28 |
29 |
30 |
31 | def load_db_from_file_and_create_if_not_exists_local(file_path, model_name):
32 |     with st.spinner('Loading chat embeddings...'):
33 |         try:
34 |             db = load_local_embeddings(file_path, model_name)
35 |             print('success')
36 |         except RuntimeError:  # FAISS raises RuntimeError when the index file does not exist yet
37 |             print('not found')
38 |             create_and_save_local(file_path, model_name)
39 |             db = load_local_embeddings(file_path, model_name)
40 |     if db:
41 |         st.success('Loaded successfully! Start a chat below.')
42 |     else:
43 |         st.warning('Something went wrong... failed to load chat embeddings.')
44 |     return db
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/local_embeddings/sparksofagi.faiss:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-johnstonn/BriefGPT/6bd265ed7a678de1c7124350d04801b3857bc190/local_embeddings/sparksofagi.faiss
--------------------------------------------------------------------------------
/local_embeddings/sparksofagi.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/e-johnstonn/BriefGPT/6bd265ed7a678de1c7124350d04801b3857bc190/local_embeddings/sparksofagi.pkl
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import streamlit as st
3 | from streamlit_chat import message as st_message
4 | from dotenv import load_dotenv
5 |
6 | from chat_utils import create_and_save_directory_embeddings
7 | from streamlit_app_utils import process_summarize_button, generate_answer, load_db_from_file_and_create_if_not_exists, validate_api_key, load_dir_chat_embeddings
8 |
9 | from summary_utils import transcript_loader
10 |
11 | import pandas as pd
12 |
13 | import glob
14 |
15 |
16 |
17 |
18 | # YouTube stuff is kinda broken! I'll fix it soon.
19 |
20 | load_dotenv('test.env')
21 |
22 | st.set_page_config(page_title='BriefGPT')
23 |
24 | accepted_filetypes = ['.txt', '.pdf', '.epub']
25 |
26 | def summarize():
27 |     """
28 |     Render the Summarize page of the Streamlit app.
29 |
30 |     :return: None.
31 |     """
32 |     st.title("Summarize")
33 |     st.write("Summaries are saved to the 'summaries' folder in the project directory.")
34 |
35 |     input_method = st.radio("Select input method", ('Document', 'YouTube URL'))
36 |
37 |     if input_method == 'Document':
38 |         directory = 'documents'
39 |         files = os.listdir(directory)
40 |         files = [file for file in files if file.endswith(tuple(accepted_filetypes))]
41 |         if files:
42 |             selected_file = st.selectbox('Select a file', files)
43 |             st.write('You selected: ' + selected_file)
44 |             selected_file_path = os.path.join(directory, selected_file)
45 |         else:
46 |             st.write('No documents found in documents folder. Add some documents first!')
47 |             return
48 |
49 |     if input_method == 'YouTube URL':
50 |         youtube_url = st.text_input("Enter a YouTube URL to summarize")
51 |
52 |     use_gpt_4 = st.checkbox("Use GPT-4 for the final prompt (STRONGLY recommended, requires GPT-4 API access - progress bar will appear to get stuck as GPT-4 is slow)", value=True)
53 |     find_clusters = st.checkbox('Optimal clustering (saves on tokens)', value=False)
54 |
55 |
56 |
57 |     if st.button('Summarize (click once and wait)'):
58 |         if input_method == 'Document':
59 |             process_summarize_button(selected_file_path, use_gpt_4, find_clusters)
60 |
61 |         else:
62 |             doc = transcript_loader(youtube_url)
63 |             process_summarize_button(doc, use_gpt_4, find_clusters, file=False)
64 |
65 |
66 |
67 | def chat():
68 |     dir_or_doc = st.radio('Select a chat method', ('Document', 'Directory'))
69 |     st.title('Chat')
70 |     model_name = st.radio('Select a model', ('gpt-3.5-turbo', 'gpt-4'))
71 |     hypothetical = st.checkbox('Use hypothetical embeddings', value=False)
72 |     if dir_or_doc == 'Document':
73 |         if 'text_input' not in st.session_state:
74 |             st.session_state.text_input = ''
75 |         directory = 'documents'
76 |         files = os.listdir(directory)
77 |         files = [file for file in files if file.endswith(tuple(accepted_filetypes))]
78 |         selected_file = st.selectbox('Select a file', files)
79 |         st.write('You selected: ' + selected_file)
80 |         selected_file_path = os.path.join(directory, selected_file)
81 |
82 |         if st.button('Load file (first time might take a second...) pressing this button will reset the chat history'):
83 |             db = load_db_from_file_and_create_if_not_exists(selected_file_path)
84 |             st.session_state.db = db
85 |             st.session_state.history = []
86 |
87 |     else:
88 |         if 'text_input' not in st.session_state:
89 |             st.session_state.text_input = ''
90 |         load_or_create = st.checkbox('Load from existing directory (already embedded)', value=False)
91 |         if load_or_create:
92 |             embeddings = os.listdir('directory_embeddings')
93 |             embeddings = [file for file in embeddings if file.endswith('.faiss')]
94 |             select_embedding = st.selectbox('Select an embedding', embeddings)
95 |             load = st.button('Load embeddings')
96 |             if load:
97 |                 embedding_file_path = os.path.join('directory_embeddings', select_embedding)
98 |                 db = load_dir_chat_embeddings(embedding_file_path)
99 |                 st.session_state.db = db
100 |                 st.session_state.history = []
101 |
102 |         else:
103 |             directory = st.text_input('Enter a directory to load from - just "documents" will load the default documents folder')
104 |             name = st.text_input('Enter a unique nickname for the directory')
105 |             if st.button('Load directory (first time might take a second...) pressing this button will reset the chat history'):
106 |                 with st.spinner('Loading directory...'):
107 |                     db = create_and_save_directory_embeddings(directory, name)
108 |                     st.session_state.db = db
109 |                     st.success('Directory loaded successfully')
110 |                     st.session_state.history = []
111 |
112 |     user_input = st.text_input('Enter your question', key='text_input')
113 |
114 |     if st.button('Ask') and 'db' in st.session_state and validate_api_key(model_name):
115 |         answer = generate_answer(st.session_state.db, model_name, hypothetical)
116 |
117 |     if 'history' not in st.session_state:
118 |         st.session_state.history = []
119 |     if 'sources' not in st.session_state:
120 |         st.session_state.sources = []
121 |     for i, chat in enumerate(st.session_state.history):
122 |         st_message(**chat, key=str(i))
123 |     for i, source in enumerate(st.session_state.sources):
124 |         with st.expander('Sources', expanded=False):
125 |             st.markdown(source)
126 |
127 |
128 | def documents():
129 |     st.title('Documents')
130 |     st.markdown('Documents are stored in the documents folder in the project directory.')
131 |     directory = 'documents'
132 |     files = os.listdir(directory)
133 |     files = [file for file in files if file.endswith(tuple(accepted_filetypes))]
134 |     if files:
135 |         files_df = pd.DataFrame(files, columns=['File Name'], index=range(1, len(files) + 1))
136 |         st.dataframe(files_df, width=1000)
137 |     else:
138 |         st.write('No documents found in documents folder. Add some documents first!')
139 |
140 |
141 | def compare_results():
142 |     st.title('Compare')
143 |     st.write("Compare retrieval results using hypothetical embeddings vs. normal embeddings. Support for comparing multiple models coming soon.")
144 |     model_name = 'gpt-3.5-turbo'
145 |
146 |     if 'text_input' not in st.session_state:
147 |         st.session_state.text_input = ''
148 |     directory = 'documents'
149 |     files = os.listdir(directory)
150 |     files = [file for file in files if file.endswith(tuple(accepted_filetypes))]
151 |     selected_file = st.selectbox('Select a file', files)
152 |     st.write('You selected: ' + selected_file)
153 |     selected_file_path = os.path.join(directory, selected_file)
154 |
155 |     if st.button('Load file (first time might take a second...) pressing this button will reset the chat history'):
156 |         db = load_db_from_file_and_create_if_not_exists(selected_file_path)
157 |         st.session_state.db = db
158 |         st.session_state.history = []
159 |
160 |
161 |
162 |
163 |     user_input = st.text_input('Enter your question', key='text_input')
164 |
165 |     if st.button('Ask') and 'db' in st.session_state and validate_api_key(model_name):
166 |         st.markdown('Question: ' + user_input)
167 |         answer_a, sources_a = generate_answer(st.session_state.db, model_name, hypothetical=True)
168 |         answer_b, sources_b = generate_answer(st.session_state.db, model_name, hypothetical=False)
169 |
170 |         col1, col2 = st.columns(2)
171 |
172 |         with col1:
173 |             st.header('Hypothetical embeddings')
174 |             st.markdown(answer_a)
175 |             with st.expander('Sources', expanded=False):
176 |                 st.markdown(sources_a)
177 |         with col2:
178 |             st.header('Normal embeddings')
179 |             st.markdown(answer_b)
180 |             with st.expander('Sources', expanded=False):
181 |                 st.markdown(sources_b)
182 |
183 |         st.session_state.history = []
184 |         st.session_state.sources = []
185 |
186 |
187 |
188 |
189 | PAGES = {
190 | "Chat": chat,
191 | "Summarize": summarize,
192 | "Documents": documents,
193 | "Compare": compare_results
194 | }
195 |
196 | st.sidebar.title("Navigation")
197 | selection = st.sidebar.radio("Go to", list(PAGES.keys()))
198 | st.sidebar.markdown(' [Contact author](mailto:ethanujohnston@gmail.com)')
199 | st.sidebar.markdown(' [Github](https://github.com/e-johnstonn/BriefGPT)')
200 | st.sidebar.markdown('[More info on hypothetical embeddings here](https://arxiv.org/abs/2212.10496)', unsafe_allow_html=True)
201 | page = PAGES[selection]
202 | page()
203 |
204 |
205 |
206 |
207 |
208 |
209 |
--------------------------------------------------------------------------------
/models/put your model here!:
--------------------------------------------------------------------------------
1 | https://gpt4all.io/models/ggml-gpt4all-j-v1.3-groovy.bin
2 |
3 | This link will be useful... soon
4 |
--------------------------------------------------------------------------------
/my_prompts.py:
--------------------------------------------------------------------------------
1 | file_map = """
2 | You will be given a single section from a text. This will be enclosed in triple quotes (''').
3 | Please provide a cohesive summary of the following section excerpt, focusing on the key points and main ideas, while maintaining clarity and conciseness.
4 |
5 | '''{text}'''
6 |
7 | FULL SUMMARY:
8 | """
9 |
10 |
11 | file_combine = """
12 | Read all the provided summaries from a larger document. They will be enclosed in triple quotes (''').
13 | Determine what the overall document is about and summarize it with this information in mind.
14 | Synthesize the info into a well-formatted easy-to-read synopsis, structured like an essay that summarizes them cohesively.
15 | Do not simply reword the provided text. Do not copy the structure from the provided text.
16 | Avoid repetition. Connect all the ideas together.
17 | Preceding the synopsis, write a short, bullet form list of key takeaways.
18 | Format in HTML. Text should be divided into paragraphs. Paragraphs should be indented.
19 |
20 | '''{text}'''
21 |
22 |
23 | """
24 |
25 | youtube_map = """
26 | You will be given a single section from a transcript of a YouTube video. This will be enclosed in triple quotes (''').
27 | Please provide a cohesive summary of the section of the transcript, focusing on the key points and main ideas, while maintaining clarity and conciseness.
28 |
29 | '''{text}'''
30 |
31 | FULL SUMMARY:
32 | """
33 |
34 |
35 | youtube_combine = """
36 | Read all the provided summaries from a YouTube transcript. They will be enclosed in triple quotes (''').
37 | Determine what the overall video is about and summarize it with this information in mind.
38 | Synthesize the info into a well-formatted easy-to-read synopsis, structured like an essay that summarizes them cohesively.
39 | Do not simply reword the provided text. Do not copy the structure from the provided text.
40 | Avoid repetition. Connect all the ideas together.
41 | Preceding the synopsis, write a short, bullet form list of key takeaways.
42 | Format in HTML. Text should be divided into paragraphs. Paragraphs should be indented.
43 |
44 | '''{text}'''
45 |
46 |
47 | """
48 |
49 | chat_prompt = """
50 | You will be provided some context from a document.
51 | Based on this context, answer the user question.
52 | Only answer based on the given context.
53 | If you cannot answer, say 'I don't know' and recommend a different question.
54 |
55 | """
56 |
57 | hypothetical_prompt = """
58 | Given the user's question, please generate a response that mimics the exact format in which the relevant information would appear within a document, even if the information does not exist.
59 | The response should not offer explanations, context, or commentary, but should emulate the precise structure in which the answer would be found in a hypothetical document.
60 | Factuality is not important, the priority is the hypothetical structure of the excerpt. Use made-up facts to emulate the structure.
61 | For example, if the user question is "who are the authors?", the response should be something like
62 | 'Authors: John Smith, Jane Doe, and Bob Jones'
63 | The user's question is:
64 |
65 | """
66 |
67 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.4
2 | aiosignal==1.3.1
3 | altair==4.2.2
4 | anyio==3.6.2
5 | argilla==1.7.0
6 | async-timeout==4.0.2
7 | attrs==23.1.0
8 | backoff==2.2.1
9 | blinker==1.6.2
10 | cachetools==5.3.0
11 | certifi==2023.5.7
12 | cffi==1.15.1
13 | charset-normalizer==3.1.0
14 | click==8.1.3
15 | colorama==0.4.6
16 | commonmark==0.9.1
17 | contourpy==1.0.7
18 | cryptography==40.0.2
19 | cycler==0.11.0
20 | dataclasses-json==0.5.7
21 | datasets==2.12.0
22 | decorator==5.1.1
23 | Deprecated==1.2.13
24 | dill==0.3.6
25 | entrypoints==0.4
26 | et-xmlfile==1.1.0
27 | faiss-cpu==1.7.4
28 | filelock==3.12.0
29 | fonttools==4.39.4
30 | frozenlist==1.3.3
31 | fsspec==2023.5.0
32 | fuzzywuzzy==0.18.0
33 | gitdb==4.0.10
34 | GitPython==3.1.31
35 | greenlet==2.0.2
36 | h11==0.14.0
37 | httpcore==0.16.3
38 | httpx==0.23.3
39 | huggingface-hub==0.14.1
40 | idna==3.4
41 | importlib-metadata==6.6.0
42 | importlib-resources==5.12.0
43 | InstructorEmbedding==1.0.0
44 | Jinja2==3.1.2
45 | joblib==1.2.0
46 | jsonlines==3.1.0
47 | jsonschema==4.17.3
48 | kiwisolver==1.4.4
49 | langchain==0.0.169
50 | Levenshtein==0.21.0
51 | llama-cpp-python==0.1.50
52 | lxml==4.9.2
53 | Markdown==3.4.3
54 | markdown-it-py==2.2.0
55 | MarkupSafe==2.1.2
56 | marshmallow==3.19.0
57 | marshmallow-enum==1.5.1
58 | matplotlib==3.7.1
59 | mdurl==0.1.2
60 | monotonic==1.6
61 | mpmath==1.3.0
62 | msg-parser==1.2.0
63 | multidict==6.0.4
64 | multiprocess==0.70.14
65 | mypy-extensions==1.0.0
66 | networkx==3.1
67 | nltk==3.8.1
68 | numexpr==2.8.4
69 | numpy==1.23.5
70 | olefile==0.46
71 | openai==0.27.6
72 | openapi-schema-pydantic==1.2.4
73 | openpyxl==3.1.2
74 | packaging==23.1
75 | pandas==1.5.3
76 | pandoc==2.3
77 | pdfminer.six==20221105
78 | Pillow==9.5.0
79 | plumbum==1.8.1
80 | ply==3.11
81 | protobuf==3.20.3
82 | pyarrow==12.0.0
83 | pycparser==2.21
84 | pydantic==1.10.7
85 | pydeck==0.8.1b0
86 | Pygments==2.15.1
87 | pygpt4all==1.1.0
88 | pygptj==2.0.3
89 | pyllamacpp==2.1.3
90 | Pympler==1.0.1
91 | pypandoc==1.11
92 | pyparsing==3.0.9
93 | pypdf==3.8.1
94 | PyPDF2==3.0.1
95 | pyrsistent==0.19.3
96 | python-dateutil==2.8.2
97 | python-docx==0.8.11
98 | python-dotenv==1.0.0
99 | python-Levenshtein==0.21.0
100 | python-magic==0.4.27
101 | python-pptx==0.6.21
102 | pytz==2023.3
103 | PyYAML==6.0
104 | rapidfuzz==3.0.0
105 | regex==2023.5.5
106 | requests==2.30.0
107 | responses==0.18.0
108 | rfc3986==1.5.0
109 | rich==13.0.1
110 | scikit-learn==1.2.2
111 | scipy==1.10.1
112 | sentence-transformers==2.2.2
113 | sentencepiece==0.1.99
114 | six==1.16.0
115 | smmap==5.0.0
116 | sniffio==1.3.0
117 | SQLAlchemy==2.0.13
118 | streamlit==1.22.0
119 | streamlit-chat==0.0.2.2
120 | sympy==1.12
121 | tenacity==8.2.2
122 | threadpoolctl==3.1.0
123 | tiktoken==0.4.0
124 | tokenizers==0.12.1
125 | toml==0.10.2
126 | toolz==0.12.0
127 | torch==2.0.1
128 | torchvision==0.15.2
129 | tornado==6.3.2
130 | tqdm==4.65.0
131 | transformers==4.20.0
132 | typer==0.9.0
133 | typing-inspect==0.8.0
134 | typing_extensions==4.5.0
135 | tzdata==2023.3
136 | tzlocal==5.0
137 | unstructured==0.6.6
138 | urllib3==2.0.2
139 | validators==0.20.0
140 | watchdog==3.0.0
141 | wrapt==1.14.1
142 | XlsxWriter==3.1.0
143 | xxhash==3.2.0
144 | yarl==1.9.2
145 | youtube-transcript-api==0.6.0
146 | zipp==3.15.0
147 |
148 |
--------------------------------------------------------------------------------
/streamlit_app_utils.py:
--------------------------------------------------------------------------------
1 | import PyPDF2
2 |
3 | from io import StringIO
4 |
5 | from langchain import FAISS
6 | from langchain.chat_models import ChatOpenAI
7 | from langchain.embeddings import OpenAIEmbeddings
8 |
9 | from chat_utils import load_chat_embeddings, create_and_save_chat_embeddings, qa_from_db, doc_loader
10 |
11 | import streamlit as st
12 |
13 | from my_prompts import file_map, file_combine, youtube_map, youtube_combine
14 |
15 | import os
16 |
17 | from summary_utils import doc_to_text, token_counter, summary_prompt_creator, doc_to_final_summary
18 |
19 |
20 | def pdf_to_text(pdf_file):
21 | """
22 | Extract the text from every page of a PDF file.
23 |
24 | :param pdf_file: The PDF file (a path or file-like object) to read.
25 |
26 | :return: The extracted text, encoded as UTF-8 bytes.
27 | """
28 | pdf_reader = PyPDF2.PdfReader(pdf_file)
29 | text = StringIO()
30 | for i in range(len(pdf_reader.pages)):
31 | p = pdf_reader.pages[i]
32 | text.write(p.extract_text())
33 | return text.getvalue().encode('utf-8')
34 |
35 |
36 | def check_gpt_4():
37 | """
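38 | Check whether the configured API key can use GPT-4 by sending a short test prompt.
39 |
40 | The API key is read from the OPENAI_API_KEY environment variable.
41 |
42 | :return: True if the key can use GPT-4, False otherwise (including on network or rate-limit errors).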
43 | """
44 | try:
45 | ChatOpenAI(model_name='gpt-4').call_as_llm('Hi')
46 | return True
47 | except Exception as e:
48 | return False
49 |
50 |
51 |
52 | def token_limit(doc, maximum=200000):
53 | """
54 | Check that a document does not exceed a maximum number of tokens.
55 |
56 | :param doc: The langchain Document object to check.
57 |
58 | :param maximum: The maximum number of tokens allowed.
59 |
60 | :return: True if the document has at most the maximum number of tokens, False otherwise.
61 | """
62 | text = doc_to_text(doc)
63 | count = token_counter(text)
64 | print(count)
65 | if count > maximum:
66 | return False
67 | return True
68 |
69 |
70 | def token_minimum(doc, minimum=2000):
71 | """
72 | Check that a document has at least a minimum number of tokens.
73 |
74 | :param doc: The langchain Document object to check.
75 |
76 | :param minimum: The minimum number of tokens allowed.
77 |
78 | :return: True if the document has at least the minimum number of tokens, False otherwise.
79 | """
80 | text = doc_to_text(doc)
81 | count = token_counter(text)
82 | if count < minimum:
83 | return False
84 | return True
85 |
86 |
87 | def validate_api_key(model_name='gpt-3.5-turbo'):
88 | try:
89 | ChatOpenAI(model_name=model_name).call_as_llm('Hi')
90 | print('API Key is valid')
91 | return True
92 | except Exception as e:
93 | print(e)
94 | st.warning('API key is invalid or OpenAI is having issues.')
95 | return False
96 |
97 |
98 | def create_chat_model_for_summary(use_gpt_4):
99 | """
100 | Create the chat model for the per-chunk (map) summaries, sized so the combined summaries fit the final model's context window.
101 |
102 | The map step always uses gpt-3.5-turbo; when GPT-4 handles the final step, its larger context window allows longer chunk summaries (500 instead of 250 tokens).
103 |
104 | :param use_gpt_4: Whether GPT-4 will be used for the final summary.
105 |
106 | :return: A ChatOpenAI chat model.
107 | """
108 | if use_gpt_4:
109 | return ChatOpenAI(temperature=0, max_tokens=500, model_name='gpt-3.5-turbo')
110 | else:
111 | return ChatOpenAI(temperature=0, max_tokens=250, model_name='gpt-3.5-turbo')
112 |
113 |
114 | def process_summarize_button(file_or_transcript, use_gpt_4, find_clusters, file=True):
115 | """
116 | Processes the summarize button, and displays the summary if input and doc size are valid
117 |
118 | :param file_or_transcript: The file uploaded by the user or the transcript from the YouTube URL
119 |
120 | :param file: True if file_or_transcript is a file path, False if it is a loaded YouTube transcript
121 |
122 | :param use_gpt_4: Whether to use GPT-4 or not
123 |
124 | :param find_clusters: Whether to find optimal clusters or not, experimental
125 |
126 | :return: None
127 | """
128 | if not validate_input(file_or_transcript, use_gpt_4):
129 | return
130 |
131 | with st.spinner("Summarizing... please wait..."):
132 |
133 | if file:
134 | doc = doc_loader(file_or_transcript)
135 | map_prompt = file_map
136 | combine_prompt = file_combine
137 | head, tail = os.path.split(file_or_transcript)
138 | name = tail.split('.')[0]
139 |
140 | else:
141 | doc = file_or_transcript
142 | map_prompt = youtube_map
143 | combine_prompt = youtube_combine
144 | name = str(file_or_transcript)[30:44].strip()
145 |
146 | llm = create_chat_model_for_summary(use_gpt_4)
147 | initial_prompt_list = summary_prompt_creator(map_prompt, 'text', llm)
148 | final_prompt_list = summary_prompt_creator(combine_prompt, 'text', llm)
149 |
150 | if doc is None or not validate_doc_size(doc):
151 | return
152 |
153 | if find_clusters:
154 | summary = doc_to_final_summary(doc, 10, initial_prompt_list, final_prompt_list, use_gpt_4, find_clusters)
155 |
156 | else:
157 | summary = doc_to_final_summary(doc, 10, initial_prompt_list, final_prompt_list, use_gpt_4)
158 |
159 | st.markdown(summary, unsafe_allow_html=True)
160 | with open(f'summaries/{name}_summary.txt', 'w') as f:
161 | f.write(summary)
162 | st.text(f' Summary saved to summaries/{name}_summary.txt')
163 |
164 |
165 |
166 |
167 | def validate_doc_size(doc):
168 | """
169 | Check that the document has between 2,000 and 800,000 tokens (inclusive), warning the user otherwise.
170 |
171 | :param doc: doc to validate
172 |
173 | :return: True if the doc is valid, False otherwise
174 | """
175 | if not token_limit(doc, 800000):
176 | st.warning('File or transcript too big!')
177 | return False
178 |
179 | if not token_minimum(doc, 2000):
180 | st.warning('File or transcript too small!')
181 | return False
182 | return True
183 |
184 |
185 | def validate_input(file_or_transcript, use_gpt_4):
186 | """
187 | Validates the user input, and displays warnings if the input is invalid
188 |
189 | :param file_or_transcript: The file uploaded by the user or the YouTube URL entered by the user
190 |
191 | :param use_gpt_4: Whether the user wants to use GPT-4
192 |
193 | :return: True if the input is valid, False otherwise
194 | """
195 | if file_or_transcript is None:
196 | st.warning("Please upload a file or enter a YouTube URL.")
197 | return False
198 |
199 | if not validate_api_key():
200 | st.warning('Key not valid or API is down.')
201 | return False
202 |
203 | if use_gpt_4 and not check_gpt_4():
204 | st.warning('Key not valid for GPT-4.')
205 | return False
206 |
207 | return True
208 |
209 |
210 | def generate_answer(db=None, llm_model=None, hypothetical=False):
211 | user_message = st.session_state.text_input
212 | if db and user_message.strip() != "":
213 | with st.spinner('Generating answer...'):
214 | print('About to call API')
215 | sys_message, sources = qa_from_db(user_message, db, llm_model, hypothetical)
216 | print('Done calling API')
217 | st.session_state.history.append({'message': user_message, 'is_user': True})
218 | st.session_state.history.append({'message': sys_message, 'is_user': False})
219 | st.session_state.sources = []
220 | st.session_state.sources.append(sources)
221 | return sys_message, sources
222 | else:
223 | print(user_message)
224 | print('failed')
225 | print(db)
226 |
227 |
228 | def load_db_from_file_and_create_if_not_exists(file_path):
229 | with st.spinner('Loading chat embeddings...'):
230 | try:
231 | db = load_chat_embeddings(file_path)
232 | print('success')
233 | except RuntimeError:
234 | print('not found')
235 | create_and_save_chat_embeddings(file_path)
236 | db = load_chat_embeddings(file_path)
237 | if db:
238 | st.success('Loaded successfully! Start a chat below.')
239 | else:
240 | st.warning('Something went wrong... failed to load chat embeddings.')
241 | return db
242 |
243 |
244 | def load_dir_chat_embeddings(file_path):
245 | name = os.path.split(file_path)[1].split('.')[0]
246 | embeddings = OpenAIEmbeddings()
247 | try:
248 | db = FAISS.load_local(folder_path='directory_embeddings', index_name=name, embeddings=embeddings)
249 | st.success('Embeddings loaded successfully.')
250 | except Exception as e:
251 | st.warning('Loading embeddings failed. Please try again.')
252 | return None
253 |
254 | return db
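`process_summarize_button` derives the summary filename with `tail.split('.')[0]`, which truncates names that contain extra dots (for example `v1.2_notes.pdf` becomes `v1`). Below is a minimal sketch of an alternative that uses only the standard library's `os.path.splitext`. The helper name `summary_path_for` and its `out_dir` parameter are hypothetical and not part of BriefGPT.

```python
import os


def summary_path_for(file_path: str, out_dir: str = "summaries") -> str:
    """Hypothetical helper: build the summary path for an uploaded file.

    os.path.splitext strips only the final extension, so dots inside the
    file name are preserved (unlike str.split('.')[0]).
    """
    stem = os.path.splitext(os.path.basename(file_path))[0]
    return os.path.join(out_dir, f"{stem}_summary.txt")


if __name__ == "__main__":
    path = "uploads/v1.2_notes.pdf"
    # The approach used in process_summarize_button keeps only the text before the first dot.
    print(os.path.split(path)[1].split(".")[0])
    # splitext keeps the full stem.
    print(summary_path_for(path))
```

Files without an extension keep their whole name as the stem, so `README` maps to `README_summary.txt`.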
--------------------------------------------------------------------------------
/summaries/sparksofagi_summary.txt:
--------------------------------------------------------------------------------
1 | Key Takeaways:
2 |
3 | GPT-4 exhibits advanced capabilities in natural language processing, mathematical reasoning, image and music generation, and theory of mind.
4 | While GPT-4 outperforms previous models like ChatGPT, it still has limitations and is not perfect at what it does.
5 | GPT-4 can generate output-consistent and process-consistent explanations, but may still produce nonsensical or wrong outputs.
6 | The AI can provide useful suggestions for data visualization and even implement them in Python.
7 |
8 |
9 | The document discusses the capabilities and limitations of GPT-4, a large language model developed by OpenAI. GPT-4 exhibits advanced abilities in natural language processing, allowing it to interact with APIs for calendar and email functions, answer questions, and coordinate events. It also demonstrates improved mathematical reasoning compared to previous models like ChatGPT, although it still struggles with large numbers and complicated expressions.
10 |
11 | GPT-4 can generate and manipulate images and music, successfully combining objects with letters of the alphabet and generating 2D and 3D images according to detailed instructions. However, it does not seem to understand harmony in music generation. The model's ability to generate code from prompts could be combined with existing image synthesis models to produce higher-quality images.
12 |
13 | The document highlights GPT-4's progress towards artificial general intelligence (AGI), with its capabilities at or above human-level in various tasks. However, it is not perfect and does not come close to being able to do anything that a human can do. GPT-4 performs better than GPT-3 in categories related to people and places, but fact-checking may require inputs from an external corpus. The model also occasionally states that neither answer is correct, and its rationale for decisions may differ from human experts.
14 |
15 | GPT-4 demonstrates advanced theory of mind and is remarkably good at generating reasonable and coherent explanations. The quality of explanations can be evaluated by checking output consistency and process consistency. The model is good at generating output-consistent explanations, even when the output is nonsensical or wrong. Experiments have been conducted to test the process consistency of GPT-4's explanations.
16 |
17 | In a text-based game scenario, GPT-4 is able to navigate through different rooms and interact with objects. Additionally, the AI can provide useful suggestions for data visualization, such as using a network graph to represent IMDb dataset information, and even implement the visualization in Python using networkx, pandas, and plotly libraries.
18 |
19 | Overall, GPT-4 represents significant progress in AI research, particularly in natural language processing, mathematical reasoning, image and music generation, and theory of mind. However, it still has limitations and is not perfect at what it does, indicating that there is still room for improvement in the pursuit of AGI.
--------------------------------------------------------------------------------
/summary_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import urllib.parse
4 | from concurrent.futures import ThreadPoolExecutor, as_completed
5 |
6 | import tiktoken
7 | from langchain import PromptTemplate
8 | from langchain.chains.summarize import load_summarize_chain
9 | from langchain.chat_models import ChatOpenAI
10 | from langchain.document_loaders import YoutubeLoader, TextLoader, PyPDFLoader, UnstructuredEPubLoader
11 | from langchain.embeddings import OpenAIEmbeddings
12 | from langchain.schema import Document
13 |
14 | import numpy as np
15 | from langchain.text_splitter import TokenTextSplitter
16 | from sklearn.cluster import KMeans
17 |
18 | import matplotlib.pyplot as plt
19 |
20 | import streamlit as st
21 |
22 |
23 | def doc_loader(file_path: str):
24 | """
25 | Load a .txt, .pdf, or .epub file into a list of langchain Document objects.
26 |
27 | :param file_path: The path to the document to load.
28 |
29 | :return: A list of langchain Document objects, or None if an EPUB file fails to load.
30 | """
31 | if file_path.endswith('.txt'):
32 | loader = TextLoader(file_path)
33 | elif file_path.endswith('.pdf'):
34 | loader = PyPDFLoader(file_path)
35 | elif file_path.endswith('.epub'):
36 | try:
37 | return UnstructuredEPubLoader(file_path).load()  # pandoc is only needed when loading, so load inside the try
38 | except Exception as e:
39 | st.warning('Error loading file - ensure you have pandoc installed and added to PATH.')
40 | return None
41 | else:
42 | raise ValueError(f'Unsupported file type: {file_path}')
43 | return loader.load()
44 |
45 | def directory_loader(directory):
46 | files = os.listdir(directory)
47 | documents = []
48 | mixed_documents = []
49 | for file in files:
50 | if file.endswith('.txt'):
51 | loader = TextLoader(os.path.join(directory, file))
52 | documents.append(loader.load())
53 | elif file.endswith('.pdf'):
54 | loader = PyPDFLoader(os.path.join(directory, file))
55 | documents.append(loader.load())
56 | elif file.endswith('.epub'):
57 | loader = UnstructuredEPubLoader(os.path.join(directory, file))
58 | documents.append(loader.load())
59 | for doc in documents:
60 | for section in doc:
61 | mixed_documents.append(section)
62 | return mixed_documents
63 |
64 |
65 |
66 | def token_counter(text: str):
67 | """
68 | Count the number of tokens in a string of text.
69 |
70 | :param text: The text to count the tokens of.
71 |
72 | :return: The number of tokens in the text.
73 | """
74 | encoding = tiktoken.get_encoding('cl100k_base')
75 | token_list = encoding.encode(text, disallowed_special=())
76 | tokens = len(token_list)
77 | return tokens
78 |
79 |
80 | def doc_to_text(document):
81 | """
82 | Convert a langchain Document object into a string of text.
83 |
84 | :param document: The loaded langchain Document object to convert.
85 |
86 | :return: A string of text.
87 | """
88 | text = ''
89 | for i in document:
90 | text += i.page_content + ' '  # Separate pages so words at page boundaries don't merge.
91 | special_tokens = ['<|endoftext|>', '<|fim_prefix|>', '<|fim_middle|>', '<|fim_suffix|>', '<|endofprompt|>']
92 | words = text.split()
93 | filtered_words = [word for word in words if word not in special_tokens]
94 | text = ' '.join(filtered_words)
95 | return text
96 |
97 | def remove_special_tokens(docs):
98 | special_tokens = ['<|endoftext|>', '<|fim_prefix|>', '<|fim_middle|>', '<|fim_suffix|>', '<|endofprompt|>']
99 | for doc in docs:
100 | content = doc.page_content
101 | for special in special_tokens:
102 | content = content.replace(special, '')
103 | doc.page_content = content
104 | return docs
105 |
106 |
107 |
108 | def embed_docs_openai(docs):
109 | """
110 | Embed a list of documents into a list of vectors.
111 |
112 | :param docs: A list of langchain Document objects to embed (special tokens are stripped from them in place).
113 |
114 | The OpenAI API key is read from the OPENAI_API_KEY environment variable.
115 |
116 | :return: A list of vectors.
117 | """
118 | docs = remove_special_tokens(docs)
119 | embeddings = OpenAIEmbeddings()
120 | vectors = embeddings.embed_documents([x.page_content for x in docs])
121 | return vectors
122 |
123 |
124 | def kmeans_clustering(vectors, num_clusters=None):
125 | """
126 | Cluster a list of vectors using K-Means clustering.
127 |
128 | :param vectors: A list of vectors to cluster.
129 |
130 | :param num_clusters: The number of clusters to use. If None, the optimal number of clusters will be determined.
131 |
132 | :return: A K-Means clustering object.
133 | """
134 | if num_clusters is None:
135 | inertia_values = calculate_inertia(vectors)
136 | num_clusters = determine_optimal_clusters(inertia_values)
137 | print(f'Optimal number of clusters: {num_clusters}')
138 |
139 | kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(vectors)
140 | return kmeans
141 |
142 |
143 | def get_closest_vectors(vectors, kmeans):
144 | """
145 | Get the closest vectors to the cluster centers of a K-Means clustering object.
146 |
147 | :param vectors: A list of vectors to cluster.
148 |
149 | :param kmeans: A K-Means clustering object.
150 |
151 | :return: A list of indices of the closest vectors to the cluster centers.
152 | """
153 | closest_indices = []
154 | for i in range(len(kmeans.cluster_centers_)):
155 | distances = np.linalg.norm(vectors - kmeans.cluster_centers_[i], axis=1)
156 | closest_index = np.argmin(distances)
157 | closest_indices.append(closest_index)
158 |
159 | selected_indices = sorted(set(closest_indices))  # Two centers can share a nearest vector; keep each chunk once.
160 | return selected_indices
161 |
162 |
163 | def map_vectors_to_docs(indices, docs):
164 | """
165 | Map a list of indices to a list of loaded langchain Document objects.
166 |
167 | :param indices: A list of indices to map.
168 |
169 | :param docs: A list of langchain Document objects to map to.
170 |
171 | :return: A list of loaded langchain Document objects.
172 | """
173 | selected_docs = [docs[i] for i in indices]
174 | return selected_docs
175 |
176 |
177 | def create_summarize_chain(prompt_list):
178 | """
179 | Create a langchain summarize chain from a list of prompts.
180 |
181 | :param prompt_list: A list containing the template, input variables, and llm to use for the chain.
182 |
183 | :return: A langchain summarize chain.
184 | """
185 | template = PromptTemplate(template=prompt_list[0], input_variables=[prompt_list[1]])
186 | chain = load_summarize_chain(llm=prompt_list[2], chain_type='stuff', prompt=template)
187 | return chain
188 |
189 |
190 | def parallelize_summaries(summary_docs, initial_chain, progress_bar, max_workers=4):
191 | """
192 | Summarize a list of loaded langchain Document objects by running a langchain summarize chain on each of them in parallel.
193 |
194 | :param summary_docs: A list of loaded langchain Document objects to summarize.
195 |
196 | :param initial_chain: A langchain summarize chain to use for summarization.
197 |
198 | :param progress_bar: A streamlit progress bar to display the progress of the summarization.
199 |
200 | :param max_workers: The maximum number of workers to use for parallelization.
201 |
202 | :return: A list of summaries in the same order as summary_docs; documents whose summary failed are skipped.
203 | """
204 | doc_summaries = [None] * len(summary_docs)  # One slot per chunk so summaries keep document order.
205 | with ThreadPoolExecutor(max_workers=max_workers) as executor:
206 | future_to_index = {executor.submit(initial_chain.run, [doc]): i for i, doc in enumerate(summary_docs)}
207 |
208 | for done, future in enumerate(as_completed(future_to_index), start=1):
209 | index = future_to_index[future]
210 |
211 | try:
212 | doc_summaries[index] = future.result()
213 |
214 | except Exception as exc:
215 | print(f'Chunk {index} generated an exception: {exc}')
216 |
217 | num = done / (len(summary_docs) + 1)
218 | progress_bar.progress(num)  # Remove this line and all references to it if you are not using Streamlit.
219 |
220 | # Keep document order and drop chunks whose summary failed.
221 | return [summary for summary in doc_summaries if summary is not None]
222 |
223 |
224 |
225 |
226 |
227 | def create_summary_from_docs(summary_docs, initial_chain, final_sum_list, use_gpt_4):
228 | """
229 | Summarize a list of loaded langchain Document objects using multiple langchain summarize chains.
230 |
231 | :param summary_docs: A list of loaded langchain Document objects to summarize.
232 |
233 | :param initial_chain: The initial langchain summarize chain to use.
234 |
235 | :param final_sum_list: A list containing the template, input variables, and llm to use for the final chain.
236 |
237 | :param use_gpt_4: Whether to use GPT-4 or GPT-3.5-turbo for summarization.
238 |
239 | :return: A string containing the summary.
240 | """
241 |
242 | progress = st.progress(0) # Create a progress bar to show the progress of summarization.
243 | # Remove this line and all references to it if you are not using Streamlit.
244 |
245 | doc_summaries = parallelize_summaries(summary_docs, initial_chain, progress_bar=progress)
246 |
247 | summaries = '\n'.join(doc_summaries)
248 | count = token_counter(summaries)
249 |
250 | if use_gpt_4:
251 | max_tokens = 7500 - int(count)
252 | model = 'gpt-4'
253 |
254 | else:
255 | max_tokens = 3800 - int(count)
256 | model = 'gpt-3.5-turbo'
257 |
258 | final_sum_list[2] = ChatOpenAI(temperature=.7, max_tokens=max_tokens, model_name=model)
259 | final_sum_chain = create_summarize_chain(final_sum_list)
260 | summaries = Document(page_content=summaries)
261 | final_summary = final_sum_chain.run([summaries])
262 |
263 | progress.progress(1.0) # Remove this line and all references to it if you are not using Streamlit.
264 | time.sleep(0.4) # Remove this line and all references to it if you are not using Streamlit.
265 | progress.empty() # Remove this line and all references to it if you are not using Streamlit.
266 |
267 | return final_summary
268 |
269 |
270 | def split_by_tokens(doc, num_clusters, ratio=5, minimum_tokens=200, maximum_tokens=2000):
271 | """
272 | Split a langchain Document object into a list of smaller langchain Document objects.
273 |
274 | :param doc: The langchain Document object to split.
275 |
276 | :param num_clusters: The number of clusters to use.
277 |
278 | :param ratio: The number of chunks to create per cluster.
279 |
280 | :param minimum_tokens: The smallest allowed chunk size, in tokens.
281 |
282 | :param maximum_tokens: The largest allowed chunk size, in tokens.
283 |
284 | :return: A list of langchain Document objects.
285 | """
286 | text_doc = doc_to_text(doc)
287 | tokens = token_counter(text_doc)
288 | chunks = num_clusters * ratio
289 | max_tokens = int(tokens / chunks)
290 | max_tokens = max(minimum_tokens, min(max_tokens, maximum_tokens))
291 | overlap = int(max_tokens/10)
292 |
293 | splitter = TokenTextSplitter(chunk_size=max_tokens, chunk_overlap=overlap)
294 | split_doc = splitter.create_documents([text_doc])
295 | return split_doc
296 |
297 |
298 | def extract_summary_docs(langchain_document, num_clusters, find_clusters):
299 | """
300 | Automatically convert a single langchain Document object into a list of smaller langchain Document objects that represent each cluster.
301 |
302 | :param langchain_document: The langchain Document object to summarize.
303 |
304 | :param num_clusters: The number of clusters to use.
305 |
306 | :param find_clusters: Whether to find the optimal number of clusters to use.
307 |
308 | :return: A list of langchain Document objects.
309 | """
310 | split_document = split_by_tokens(langchain_document, num_clusters)
311 | vectors = embed_docs_openai(split_document)
312 |
313 | if find_clusters:
314 | kmeans = kmeans_clustering(vectors, None)
315 |
316 | else:
317 | kmeans = kmeans_clustering(vectors, num_clusters)
318 |
319 | indices = get_closest_vectors(vectors, kmeans)
320 | summary_docs = map_vectors_to_docs(indices, split_document)
321 | return summary_docs
322 |
323 |
324 | def doc_to_final_summary(langchain_document, num_clusters, initial_prompt_list, final_prompt_list, use_gpt_4, find_clusters=False):
325 | """
326 | Automatically summarize a single langchain Document object using multiple langchain summarize chains.
327 |
328 | :param langchain_document: The langchain Document object to summarize.
329 |
330 | :param num_clusters: The number of clusters to use.
331 |
332 | :param initial_prompt_list: A list containing the template, input variables, and llm to use for the initial (per-chunk) chain.
333 |
334 | :param final_prompt_list: A list containing the template, input variables, and llm to use for the final chain.
335 |
336 | :param use_gpt_4: Whether to use GPT-4 or GPT-3.5-turbo for summarization.
337 |
338 | :param find_clusters: Whether to automatically find the optimal number of clusters to use.
339 |
340 | :return: A string containing the summary.
341 | """
342 | initial_chain = create_summarize_chain(initial_prompt_list)
343 | summary_docs = extract_summary_docs(langchain_document, num_clusters, find_clusters)
344 | output = create_summary_from_docs(summary_docs, initial_chain, final_prompt_list, use_gpt_4)
345 | return output
346 |
347 |
348 | def summary_prompt_creator(prompt, input_var, llm):
349 | """
350 | Create a list containing the template, input variables, and llm to use for a langchain summarize chain.
351 |
352 | :param prompt: The template to use for the chain.
353 |
354 | :param input_var: The name of the template's input variable (for example 'text').
355 |
356 | :param llm: The llm to use for the chain.
357 |
358 | :return: A list containing the template, input variables, and llm to use for the chain.
359 | """
360 | prompt_list = [prompt, input_var, llm]
361 | return prompt_list
362 |
363 |
364 | def extract_video_id(video_url):
365 | """
366 | Extract the YouTube video ID from a YouTube video URL.
367 |
368 | :param video_url: The URL of the YouTube video.
369 |
370 | :return: The ID of the YouTube video.
371 | """
372 | parsed_url = urllib.parse.urlparse(video_url)
373 | if parsed_url.hostname == 'youtu.be':
374 | return parsed_url.path[1:]
375 |
376 | elif parsed_url.hostname in ('www.youtube.com', 'youtube.com'):
377 |
378 | if parsed_url.path == '/watch':
379 | p = urllib.parse.parse_qs(parsed_url.query)
380 | return p.get('v', [None])[0]
381 |
382 | elif parsed_url.path.startswith('/embed/'):
383 | return parsed_url.path.split('/embed/')[1]
384 |
385 | elif parsed_url.path.startswith('/v/'):
386 | return parsed_url.path.split('/v/')[1]
387 |
388 | return None
389 |
390 |
391 | def transcript_loader(video_url):
392 | """
393 | Load the transcript of a YouTube video into a loaded langchain Document object.
394 |
395 | :param video_url: The URL of the YouTube video to load the transcript of.
396 |
397 | :return: A loaded langchain Document object.
398 | """
399 | transcript = YoutubeLoader(video_id=extract_video_id(video_url))
400 | loaded = transcript.load()
401 | return loaded
402 |
403 |
404 | def calculate_inertia(vectors, max_clusters=12):
405 | """
406 | Calculate the inertia values for a range of clusters.
407 |
408 | :param vectors: A list of vectors to cluster.
409 |
410 | :param max_clusters: The maximum number of clusters to use.
411 |
412 | :return: A list of inertia values.
413 | """
414 | inertia_values = []
415 | for num_clusters in range(1, min(max_clusters, len(vectors)) + 1):  # KMeans needs at least as many samples as clusters.
416 | kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(vectors)
417 | inertia_values.append(kmeans.inertia_)
418 | return inertia_values
419 |
420 |
421 | def plot_elbow(inertia_values):
422 | """
423 | Plot the inertia values for a range of clusters. Just for fun!
424 |
425 | :param inertia_values: A list of inertia values.
426 |
427 | :return: None.
428 | """
429 | plt.plot(range(1, len(inertia_values) + 1), inertia_values)  # inertia_values[0] corresponds to 1 cluster.
430 | plt.xlabel('Number of Clusters')
431 | plt.ylabel('Inertia')
432 | plt.show()
433 |
434 |
435 | def determine_optimal_clusters(inertia_values):
436 | """
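437 | Determine the number of clusters from the inertia values. For each pair of consecutive points (k, inertia_k) and (k + 1, inertia_k+1), compute the distance from the point (1, 0) to the line through them, and return k + 1 for the pair with the largest distance.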
438 |
439 | :param inertia_values: A list of inertia values.
440 |
441 | :return: The optimal number of clusters to use.
442 | """
443 | distances = []
444 | for i in range(len(inertia_values) - 1):
445 | p1 = np.array([i + 1, inertia_values[i]])
446 | p2 = np.array([i + 2, inertia_values[i + 1]])
447 | d = np.linalg.norm(np.cross(p2 - p1, p1 - np.array([1,0]))) / np.linalg.norm(p2 - p1)
448 | distances.append(d)
449 | optimal_clusters = distances.index(max(distances)) + 2
450 | return optimal_clusters
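The chunk sizing in `split_by_tokens` and the final-summary budget in `create_summary_from_docs` are plain arithmetic. Below is a minimal sketch that reproduces both formulas with an explicit token count instead of tiktoken, so the numbers can be checked by hand. The function names `chunk_parameters` and `final_summary_budget` are hypothetical. The 7500 and 3800 limits are copied from `create_summary_from_docs` and presumably leave room for the prompt template below the 8,192-token GPT-4 and 4,096-token gpt-3.5-turbo context windows.

```python
from typing import Tuple


def chunk_parameters(total_tokens: int, num_clusters: int, ratio: int = 5,
                     minimum_tokens: int = 200, maximum_tokens: int = 2000) -> Tuple[int, int]:
    """Return (chunk_size, chunk_overlap) using the same arithmetic as split_by_tokens."""
    chunks = num_clusters * ratio                      # target number of chunks
    chunk_size = int(total_tokens / chunks)            # tokens per chunk before clamping
    chunk_size = max(minimum_tokens, min(chunk_size, maximum_tokens))
    overlap = int(chunk_size / 10)                     # 10% overlap between neighbouring chunks
    return chunk_size, overlap


def final_summary_budget(summary_tokens: int, use_gpt_4: bool) -> int:
    """Completion tokens left for the final summary, as computed in create_summary_from_docs."""
    limit = 7500 if use_gpt_4 else 3800
    return limit - summary_tokens


if __name__ == "__main__":
    print(chunk_parameters(100_000, 10))   # long document: clamped to the 2,000-token maximum
    print(chunk_parameters(2_000, 10))     # minimum-size document: clamped up to 200 tokens
    print(final_summary_budget(2_500, use_gpt_4=False))
```

For a 100,000-token document and the default 10 clusters, the target is 50 chunks of 2,000 tokens with a 200-token overlap. The map-step model caps each chunk summary at 250 tokens (500 when GPT-4 is selected), and at most 12 chunks are selected when `find_clusters` is on. The joined summaries should therefore stay below both limits, keeping the budget positive.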
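`as_completed` yields futures in completion order, so collecting results from it does not by itself keep chunk summaries in document order. Below is a minimal sketch of an order-preserving alternative using `ThreadPoolExecutor.map`, which returns results in input order. `summarize_chunk` is a hypothetical stand-in for `initial_chain.run`. Unlike `parallelize_summaries`, this version has no progress bar and lets the first exception propagate.

```python
import random
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List


def summarize_chunk(chunk: str) -> str:
    """Hypothetical stand-in for initial_chain.run: wait a random time, then 'summarize'."""
    time.sleep(random.uniform(0, 0.05))  # simulate variable API latency
    return chunk.upper()


def ordered_parallel_map(fn: Callable[[str], str], chunks: List[str], max_workers: int = 4) -> List[str]:
    """Apply fn to every chunk concurrently; results come back in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in the order of `chunks`,
        # regardless of which call finishes first.
        return list(executor.map(fn, chunks))


if __name__ == "__main__":
    print(ordered_parallel_map(summarize_chunk, ["intro", "methods", "results", "conclusion"]))
```

Keeping the selected chunks in document order matters here because `get_closest_vectors` sorts the chosen indices so that the final combine step sees the summaries in the order the source presents them.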
--------------------------------------------------------------------------------
/test.env:
--------------------------------------------------------------------------------
1 | OPENAI_API_KEY = your-key
2 | MODEL_TYPE = LlamaCpp or GPT4ALL
3 | MODEL_PATH = models/modelname.bin
--------------------------------------------------------------------------------