├── misc └── RemoLogo.png ├── requirements.txt ├── app ├── embedding_pipline.py ├── utils.py └── app.py ├── raptor_pipeline ├── pipeline.py ├── convert_html2md.py └── raptor.py ├── .gitignore └── README.md /misc/RemoLogo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MSNP1381/advanced_raptor_rag/HEAD/misc/RemoLogo.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | langchain_postgres 2 | langchain-text-splitters 3 | unstructured[md] 4 | langchain_postgres 5 | langchain-google-vertexai 6 | semantic-text-splitter 7 | langchain 8 | umap-learn 9 | scikit-learn 10 | langchain_community 11 | tiktoken 12 | langchain-openai 13 | semantic-text-splitter 14 | sentence-transformers -------------------------------------------------------------------------------- /app/embedding_pipline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from dotenv import load_dotenv, find_dotenv 4 | import vertexai 5 | 6 | from langchain_google_vertexai import VertexAIEmbeddings 7 | from langchain_postgres import PGVector 8 | from tqdm import tqdm 9 | 10 | load_dotenv(find_dotenv(), override=True) 11 | 12 | # %% 13 | PROJECT_ID = "" # @param {type:"string"} 14 | LOCATION = "us-central1" # @param {type:"string"} 15 | 16 | 17 | vertexai.init(project=PROJECT_ID, location=LOCATION) 18 | 19 | # %% 20 | 21 | 22 | embd = VertexAIEmbeddings(model_name="text-embedding-005") 23 | 24 | 25 | # %% 26 | 27 | 28 | # See docker command above to launch a postgres instance with pgvector enabled. 29 | connection = os.environ["PG_CONN"] 30 | collection_name = "my_rag" 31 | 32 | 33 | vectorstore = PGVector( 34 | embeddings=embd, 35 | collection_name=collection_name, 36 | connection=connection, 37 | use_jsonb=True, 38 | ) 39 | # Now, use all_texts to build the vectorstore with Chroma 40 | retriever = vectorstore.as_retriever() 41 | 42 | 43 | # %% 44 | 45 | 46 | def reset_db(vectorstore): 47 | vectorstore.delete_collection() 48 | 49 | vectorstore.drop_tables() 50 | 51 | vectorstore.create_tables_if_not_exists() 52 | 53 | vectorstore.create_collection() 54 | 55 | 56 | def save_results(results, filename): 57 | """Saves the results to a pickle file. 58 | 59 | Args: 60 | results: The results to save. 61 | filename: The name of the file to save to. 62 | """ 63 | with open(filename, "wb") as f: 64 | pickle.dump(results, f) 65 | 66 | 67 | def load_results(filename): 68 | """Loads the results from a pickle file. 69 | 70 | Args: 71 | filename: The name of the file to load from. 72 | 73 | Returns: 74 | The loaded results. 75 | """ 76 | with open(filename, "rb") as f: 77 | return pickle.load(f) 78 | 79 | 80 | # %% 81 | 82 | # reset_db() 83 | 84 | # %% 85 | all_texts = load_results("all_texts.pickle") 86 | 87 | # %% 88 | 89 | batch_size = 1000 # Adjust based on your vector dimensions 90 | for i in tqdm(range(0, len(all_texts), batch_size)): 91 | batch = all_texts[i : i + batch_size] 92 | vectorstore.add_documents(batch) 93 | -------------------------------------------------------------------------------- /raptor_pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | # %% 2 | import matplotlib.pyplot as plt 3 | from langchain_community.document_loaders import UnstructuredMarkdownLoader 4 | from langchain_community.document_loaders import DirectoryLoader 5 | from raptor import num_tokens_from_string, recursive_embed_cluster_summarize 6 | import pickle 7 | from langchain_text_splitters import RecursiveCharacterTextSplitter 8 | import asyncio 9 | from dotenv import load_dotenv, find_dotenv 10 | 11 | load_dotenv(find_dotenv(), override=True) 12 | 13 | # %% 14 | ####################### 15 | # 16 | # load Documents 17 | # 18 | ###################### 19 | 20 | # you can add title and url to document metadata for usage in rag app 21 | loader = DirectoryLoader( 22 | "./markdown_output/", 23 | glob="**/*.md", 24 | show_progress=True, 25 | max_concurrency=-1, 26 | silent_errors=True, 27 | loader_cls=UnstructuredMarkdownLoader, 28 | ) 29 | docs = loader.load() 30 | print(len(docs)) 31 | with open("loaded_docs.pickle", "wb") as f: 32 | pickle.dump(docs, f) 33 | # with open("loaded_docs.pickle", "rb") as f: 34 | # docs=pickle.load(f) 35 | # # Doc texts 36 | docs_texts = [d.page_content for d in docs] 37 | print("Number of documents:", len(docs_texts)) 38 | 39 | # %% 40 | 41 | ####################### 42 | # 43 | # visualize token counts 44 | # 45 | ###################### 46 | 47 | counts = [num_tokens_from_string(d, "cl100k_base") for d in docs_texts] 48 | 49 | plt.figure(figsize=(10, 6)) 50 | plt.hist(counts, bins=30, color="blue", edgecolor="black", alpha=0.7) 51 | plt.title("Histogram of Token Counts") 52 | plt.xlabel("Token Count") 53 | plt.ylabel("Frequency") 54 | plt.grid(axis="y", alpha=0.75) 55 | 56 | plt.savefig("token_counts_histogram.png") 57 | # %% 58 | 59 | 60 | concatenated_content = "\n\n\n --- \n\n\n".join([doc.page_content for doc in docs]) 61 | 62 | chunk_size_tok = 3600 # choose based on visualization 63 | 64 | text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( 65 | chunk_size=chunk_size_tok, chunk_overlap=0 66 | ) 67 | texts_split = text_splitter.split_text(concatenated_content) 68 | 69 | with open("splited_txt.pickle", "wb") as f: 70 | texts_split = pickle.dump(texts_split, f) 71 | 72 | # with open("splited_txt.pickle", "rb") as f: 73 | # texts_split =pickle.load( f) 74 | 75 | 76 | print("split done") 77 | # %% 78 | ####################### 79 | # 80 | # run main process 81 | # 82 | ###################### 83 | 84 | 85 | leaf_texts = texts_split 86 | # For large document sets, we should: 87 | # 1. Use higher dimensionality reduction for better cluster separation 88 | # 2. Adjust threshold for more precise clusters 89 | # 3. Set appropriate recursion levels for hierarchical summarization 90 | 91 | # Note: these are in raptor module 92 | # in embed(texts, max_tokens_per_request=1_000_000, batch_size=512) i have implemented a max_tokens_per_request parameter to avoid the error of exceeding the token limit. 93 | # You can set it to a value that is less than the max token limit of your model. 94 | 95 | # in embed_cluster_summarize_texts( texts: List[str], level: int, batch_size=80) i have added a batching to avoid the error of exceeding the token limit. 96 | 97 | 98 | # Configure and run the recursive embedding and clustering 99 | loop = asyncio.get_event_loop() 100 | results = loop.run_until_complete( 101 | recursive_embed_cluster_summarize(leaf_texts, level=1, n_levels=3) 102 | ) 103 | 104 | 105 | def save_results(results, filename): 106 | """Saves the results to a pickle file. 107 | 108 | Args: 109 | results: The results to save. 110 | filename: The name of the file to save to. 111 | """ 112 | with open(filename, "wb") as f: 113 | pickle.dump(results, f) 114 | 115 | 116 | def load_results(filename): 117 | """Loads the results from a pickle file. 118 | 119 | Args: 120 | filename: The name of the file to load from. 121 | 122 | Returns: 123 | The loaded results. 124 | """ 125 | with open(filename, "rb") as f: 126 | return pickle.load(f) 127 | 128 | 129 | save_results(results, "results.pickle") 130 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | 176 | jobs 177 | 178 | html_output 179 | 180 | markdown_output 181 | embd_* 182 | archive* 183 | results.pickle 184 | 185 | *.zip 186 | application_default_credentials.json 187 | *.pickle -------------------------------------------------------------------------------- /raptor_pipeline/convert_html2md.py: -------------------------------------------------------------------------------- 1 | import os 2 | import trafilatura 3 | from tqdm import tqdm 4 | import multiprocessing 5 | from functools import partial 6 | 7 | 8 | def html_to_markdown(html_file_path, output_dir): 9 | """ 10 | Convert HTML file to Markdown using trafilatura, preserving directory structure. 11 | 12 | Args: 13 | html_file_path (str): Path to the HTML file. 14 | output_dir (str): Base directory to save the Markdown output, mirroring the HTML file's directory structure. 15 | 16 | Returns: 17 | tuple: (html_file_path, success, error_message) 18 | """ 19 | try: 20 | with open(html_file_path, "r", encoding="utf-8") as f: 21 | html_content = f.read() 22 | 23 | # Extract the main content using trafilatura 24 | downloaded = trafilatura.extract(html_content) 25 | 26 | if downloaded: 27 | # Determine the relative path from the base directory 28 | relative_path = os.path.relpath(html_file_path, start=HTML_DIR) 29 | 30 | # Create the output directory path 31 | output_path = os.path.join(output_dir, os.path.dirname(relative_path)) 32 | 33 | # Ensure the output directory exists 34 | os.makedirs(output_path, exist_ok=True) 35 | 36 | # Create the output file path 37 | output_file_path = os.path.join( 38 | output_path, 39 | os.path.splitext(os.path.basename(html_file_path))[0] + ".md", 40 | ) 41 | 42 | # Save the Markdown content to the output file 43 | with open(output_file_path, "w", encoding="utf-8") as md_file: 44 | md_file.write(downloaded) 45 | 46 | return (html_file_path, True, None) # Indicate success 47 | else: 48 | return ( 49 | html_file_path, 50 | False, 51 | "Failed to extract content", 52 | ) # Indicate failure to extract 53 | 54 | except Exception as e: 55 | return (html_file_path, False, str(e)) # Indicate failure with error message 56 | 57 | 58 | def process_directory(html_dir, output_dir, num_processes=None): 59 | """ 60 | Process all HTML files in a directory and its subdirectories. 61 | Files in each directory are processed in parallel. 62 | 63 | Args: 64 | html_dir (str): Root directory containing HTML files. 65 | output_dir (str): Base directory to save the Markdown output, mirroring the HTML file's directory structure. 66 | num_processes (int, optional): Number of processes to use. Defaults to CPU count. 67 | """ 68 | if num_processes is None: 69 | num_processes = multiprocessing.cpu_count() 70 | 71 | # Count total files for progress tracking 72 | total_files = 0 73 | for root, _, files in os.walk(html_dir): 74 | total_files += len( 75 | [f for f in files if f.endswith(".html") or f.endswith(".htm")] 76 | ) 77 | 78 | # Create a progress bar for tracking the overall conversion 79 | with tqdm(total=total_files, desc="Converting HTML to Markdown") as pbar: 80 | # Process each directory, but parallelize the file processing within each directory 81 | for root, _, files in os.walk(html_dir): 82 | # Filter for HTML files in the current directory 83 | html_files = [ 84 | os.path.join(root, file) 85 | for file in files 86 | if file.endswith(".html") or file.endswith(".htm") 87 | ] 88 | 89 | if not html_files: 90 | continue 91 | 92 | # Create a process pool 93 | with multiprocessing.Pool(processes=num_processes) as pool: 94 | # Create a partial function with fixed arguments 95 | process_func = partial(html_to_markdown, output_dir=output_dir) 96 | 97 | # Process files in parallel and handle results as they complete 98 | for file_path, success, error_msg in pool.imap_unordered( 99 | process_func, html_files 100 | ): 101 | pbar.update(1) 102 | if not success: 103 | pbar.write(f"Failed to convert {file_path}: {error_msg}") 104 | 105 | 106 | # Example Usage (replace with your actual directories) 107 | HTML_DIR = "html_output" # Replace with the path to your HTML directory 108 | MARKDOWN_DIR = ( 109 | "markdown_output" # Replace with the desired output directory for Markdown files 110 | ) 111 | 112 | if __name__ == "__main__": 113 | # Optional: specify number of processes, defaults to CPU count if not specified 114 | num_processes = multiprocessing.cpu_count() 115 | process_directory(HTML_DIR, MARKDOWN_DIR, num_processes) 116 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Advanced RAG + Raptor 2 | --- 3 | 4 | This project was conducted in partnership with [remolab](https://remolab.fr/), a startup solutions hub. 5 | 6 |

7 | 8 | Remolab Logo 9 | 10 |

11 | 12 | --- 13 | An advanced Retrieval-Augmented Generation (RAG) system with Raptor integration for enhanced semantic search and document retrieval. 14 | 15 | ```mermaid 16 | graph LR; 17 | 18 | subgraph raptor; 19 | ScrapedData --> RaptorPipLine 20 | end 21 | subgraph app 22 | %% direction ; 23 | InputQuery-->Contextualize-->QueryExpansion-->Retrieval-->ReRanking-->Output 24 | end 25 | 26 | raptor --> app 27 | ``` 28 | 29 | ## Overview 30 | 31 | This project implements an advanced RAG system with a two-stage pipeline: 32 | 33 | 1. **Raptor Pipeline**: A hierarchical document processing system that takes scraped web content, converts it to markdown, and creates semantically clustered document embeddings with hierarchical summarization. 34 | 35 | 2. **Application**: A sophisticated query processing system that handles user input through contextualization, query expansion, retrieval, and re-ranking for accurate, contextually-relevant responses. 36 | 37 | ## Project Structure 38 | 39 | ``` 40 | Advanced RAG + Raptor 41 | ├── app/ 42 | │ ├── app.py # Main Streamlit application 43 | │ ├── embedding_pipline.py # Embedding pipeline for documents 44 | │ └── utils.py # Utility functions for query processing 45 | └── raptor_pipeline/ 46 | ├── convert_html2md.py # HTML to Markdown converter 47 | ├── pipeline.py # Document processing and embedding generation 48 | └── raptor.py # Hierarchical clustering and summarization 49 | ``` 50 | 51 | ## Features 52 | 53 | ### Raptor Pipeline 54 | 55 | - **HTML to Markdown Conversion**: Converts HTML content to clean Markdown format 56 | - **Document Processing**: Handles large document collections efficiently 57 | - **Hierarchical Embedding**: Generates semantically rich embeddings 58 | - **Clustering**: Groups similar documents using advanced clustering algorithms 59 | - **Summarization**: Creates hierarchical summaries of document clusters 60 | 61 | ### Application Flow 62 | 63 | - **Contextualization**: Understands query in context of conversation history 64 | - **Query Expansion**: Generates multiple variations of the query to improve retrieval 65 | - **Retrieval**: Fetches relevant documents from the vector store 66 | - **Re-Ranking**: Prioritizes the most relevant documents using a cross-encoder model 67 | - **Response Generation**: Creates concise, accurate responses based on retrieved information 68 | 69 | ## Technologies Used 70 | 71 | - **LangChain**: Core framework for building the RAG pipeline 72 | - **Google Vertex AI**: For embeddings (text-embedding-005) and LLM (gemini-2.0-flash) 73 | - **PostgreSQL with pgvector**: Vector database for document storage and retrieval 74 | - **UMAP & Gaussian Mixture Models**: For dimensionality reduction and clustering 75 | - **Streamlit**: For the interactive user interface 76 | - **Trafilatura**: For clean HTML content extraction 77 | 78 | ## Getting Started 79 | 80 | ### Prerequisites 81 | 82 | - Python 3.8+ 83 | - PostgreSQL with pgvector extension 84 | - Google Cloud Platform account with Vertex AI enabled 85 | 86 | ### Installation 87 | 88 | 1. Clone the repository 89 | 2. Install dependencies: 90 | ``` 91 | pip install -r requirements.txt 92 | ``` 93 | 94 | 3. Set up environment variables: 95 | ``` 96 | PG_CONN="postgresql+psycopg://user:password@localhost:5432/database" 97 | ``` 98 | 99 | ### Data Processing 100 | 101 | 1. Convert HTML documents to Markdown: 102 | ``` 103 | python raptor_pipeline/convert_html2md.py 104 | ``` 105 | 106 | 2. Process documents through the Raptor pipeline: 107 | ``` 108 | python raptor_pipeline/pipeline.py 109 | ``` 110 | 111 | ### Running the Application 112 | 113 | Launch the Streamlit application: 114 | ``` 115 | streamlit run app/app.py 116 | ``` 117 | 118 | ## How It Works 119 | 120 | ### Document Processing Flow 121 | 122 | 1. HTML documents are converted to Markdown format 123 | 2. Documents are split into manageable chunks 124 | 3. Chunks are embedded using Vertex AI embeddings 125 | 4. Embeddings are clustered using hierarchical clustering 126 | 5. Clusters are summarized to create hierarchical summaries 127 | 6. Final embeddings and metadata are stored in PostgreSQL with pgvector 128 | 129 | ### Query Processing Flow 130 | 131 | 1. User query is contextualized with conversation history 132 | 2. Query is expanded to create multiple variations 133 | 3. Expanded queries are used to retrieve relevant documents 134 | 4. Retrieved documents are re-ranked for relevance 135 | 5. Most relevant documents are used to generate a response 136 | 6. Response is presented to the user with source references 137 | 138 | ## Advanced Features 139 | 140 | - **Recursive Clustering**: Multi-level document organization 141 | - **Contextualized Queries**: Consideration of conversation history 142 | - **Query Expansion**: Improved retrieval through query variations 143 | - **Cross-Encoder Re-Ranking**: More accurate document relevance scoring 144 | - **Source Attribution**: References to source documents in responses 145 | -------------------------------------------------------------------------------- /app/utils.py: -------------------------------------------------------------------------------- 1 | from langchain.prompts import PromptTemplate 2 | from langchain_google_vertexai import ChatVertexAI 3 | from pydantic import BaseModel, Field 4 | from langchain_core.vectorstores import VectorStoreRetriever 5 | from langchain_community.document_transformers import LongContextReorder 6 | from langchain_core.prompts import ( 7 | ChatPromptTemplate, 8 | MessagesPlaceholder, 9 | ) # Import the ChatPromptTemplate and MessagesPlaceholder classes 10 | from langchain_core.output_parsers import ( 11 | StrOutputParser, 12 | ) # Import the StrOutputParser class 13 | from sentence_transformers import CrossEncoder 14 | from langchain_core.documents import Document 15 | 16 | cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") 17 | 18 | 19 | class LineList(BaseModel): 20 | lines: list[str] = Field(description="expanded queries") 21 | 22 | 23 | def expand_query( 24 | query: str, 25 | model: ChatVertexAI, 26 | ) -> LineList: 27 | QUERY_PROMPT = PromptTemplate( 28 | input_variables=["question", "institute_name"], 29 | template=( 30 | """You are a query optimization assistant for information retrieval. Your task is to improve the chances of finding relevant documents by generating *three distinct variations* of a user's original question. These variations should aim to: 31 | 32 | * **Rephrase:** Express the same basic question using different words and sentence structures. 33 | * **Expand:** Add related terms, synonyms, or closely associated concepts that the user *might* have meant, even if they didn't explicitly mention them. *Be careful* not to add irrelevant information. 34 | * **Consider different aspects:** Think about the question from different angles, such as customer needs, product features, or common issues. 35 | 36 | 37 | Output format: Provide each alternative query on a *new line*. Do not include *any* numbering or labels. Do *not* include the original question in your output. 38 | 39 | Original question: {question}""" 40 | ), 41 | ) 42 | model.temperature = 0.1 43 | 44 | llm_chain = QUERY_PROMPT | model.with_structured_output(LineList) 45 | queries: LineList = llm_chain.invoke(query) 46 | return queries.lines 47 | 48 | 49 | def retrieve_expanded_queries(queries, retriever: VectorStoreRetriever): 50 | print("\n\n\n--------\n\n\n") 51 | docs = [retriever.invoke(query) for query in queries] 52 | unique_contents = set() 53 | unique_docs = [] 54 | for sublist in docs: 55 | for doc in sublist: 56 | if doc.page_content not in unique_contents: 57 | unique_docs.append(doc) 58 | unique_contents.add(doc.page_content) 59 | unique_contents = list(unique_contents) 60 | return unique_docs 61 | 62 | 63 | def rerank(unique_contents: list[Document], query): 64 | pairs = [] 65 | for doc in unique_contents: 66 | pairs.append([query, doc.page_content]) 67 | scores = cross_encoder.predict(pairs) 68 | 69 | scored_docs = zip(scores, unique_contents) 70 | sorted_docs = sorted(scored_docs, reverse=True) 71 | reranked_docs = [doc for _, doc in sorted_docs][0:8] 72 | reordering = LongContextReorder() 73 | reordered_docs = reordering.transform_documents(reranked_docs) 74 | return reordered_docs 75 | 76 | 77 | def contextualize_docs( 78 | llm: ChatVertexAI, retriever: VectorStoreRetriever, query, conversation 79 | ): 80 | # Define the system prompt for contextualizing the question 81 | contextualize_q_system_prompt = """You are a contextualization assistant for La Banque Postale's question-answering system. Your task is to take a follow-up question from a user and the history of the previous conversation, and rephrase the follow-up question into a single, standalone question that includes all the necessary context. 82 | 83 | Here's what you need to do: 84 | 85 | * **Understand the conversation:** Carefully analyze the conversation history to grasp the topic and all important details. 86 | * **Incorporate context:** Integrate relevant information from the conversation history *directly into* the rephrased question. The new question should be understandable *without* needing to see the previous conversation. 87 | * **Maintain clarity:** The rephrased question should be clear, concise, and easy to understand. Use language appropriate for La Banque Postale customers. 88 | * **Focus on information needs:** Ensure that the rephrased question accurately reflects the user's *underlying information need*, even if their original wording was ambiguous. 89 | * **Be autonomous:** the autonomous question should not contain words such as "it", "that", "this", "they", etc., and must contain all the entities for an autonomous question. 90 | 91 | """ 92 | 93 | # Create a ChatPromptTemplate for contextualizing the question 94 | contextualize_q_prompt = ChatPromptTemplate.from_messages( 95 | [ 96 | ("system", contextualize_q_system_prompt), # Set the system prompt 97 | MessagesPlaceholder("chat_history"), # Placeholder for the chat history 98 | ( 99 | "human", 100 | "Entrée de suivi : {input}\nQuestion autonome reformulée :", 101 | ), # Placeholder for the user's input question 102 | ] 103 | ) 104 | chain = contextualize_q_prompt | llm | StrOutputParser() 105 | result = chain.invoke({"chat_history": conversation, "input": query}) 106 | return result 107 | -------------------------------------------------------------------------------- /app/app.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import os 3 | from typing import List 4 | from langchain.prompts import ChatPromptTemplate 5 | from utils import expand_query, rerank, retrieve_expanded_queries, contextualize_docs 6 | from dotenv import load_dotenv 7 | from langchain_core.output_parsers import StrOutputParser 8 | import streamlit as st 9 | from langchain_postgres import PGVector 10 | from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI 11 | import vertexai 12 | from langchain_core.documents import Document 13 | 14 | 15 | load_dotenv(override=True) 16 | 17 | 18 | st.set_page_config(page_title="RAG Chatbot", layout="wide") 19 | st.title("advanced Information Assistant") 20 | 21 | # Initialize session state for chat history 22 | if "messages" not in st.session_state: 23 | st.session_state.messages = [] 24 | st.session_state.conversation = [] 25 | 26 | PROJECT_ID = "" # @param {type:"string"} 27 | LOCATION = "us-central1" # @param {type:"string"} 28 | 29 | 30 | vertexai.init(project=PROJECT_ID, location=LOCATION) 31 | embd = VertexAIEmbeddings(model_name="text-embedding-005") 32 | 33 | # from langchain_openai import ChatOpenAI 34 | 35 | # model = ChatOpenAI(temperature=0, model="gpt-4-1106-preview") 36 | 37 | model = ChatVertexAI(temperature=0, model="gemini-2.0-flash") 38 | 39 | # Load data 40 | # @st.cache_resource 41 | 42 | 43 | # See docker command above to launch a postgres instance with pgvector enabled. 44 | connection = os.environ["PG_CONN"] # Uses psycopg3! 45 | collection_name = "my_rag" 46 | 47 | 48 | vectorstore = PGVector( 49 | embeddings=embd, 50 | collection_name=collection_name, 51 | connection=connection, 52 | use_jsonb=True, 53 | ) 54 | # Now, use all_texts to build the vectorstore with Chroma 55 | retriever = vectorstore.as_retriever( 56 | search_type="mmr", search_kwargs={"k": 6, "lambda_mult": 0.5, "fetch_k": 50} 57 | ) 58 | # Prompt 59 | human_message = """\ 60 | 61 | {question} 62 | 63 | 64 | 65 | {context} 66 | """ 67 | prompt = ChatPromptTemplate.from_messages( 68 | [ 69 | ( 70 | "system", 71 | """You are a friendly and knowledgeable virtual assistant. Your primary role is to answer questions accurately and concisely, *exclusively* using the information found in the provided documents. 72 | 73 | Here's how you will operate: 74 | 75 | * **Data Source:** Your answers *must* be based solely on the retrieved documents wich is provided in tag. Do not use any external knowledge. 76 | * **Conciseness:** Keep your responses short and to the point, using a maximum of five sentences. 77 | * **Language:** Respond *exclusively* in French. 78 | * **Tone:** Maintain a polite, professional, and helpful tone, as if you were a customer service representative. 79 | * **Handling Unknowns:** If the answer is not found within the provided documents, respond with: "Je suis désolé, mais je ne trouve pas la réponse à votre question dans les documents fournis." 80 | * **Conversation History:** Use the previous turns in the conversation to ensure your responses are relevant and contextualized. 81 | * **Easy Language**: use easy and lucid language so any user can get it without any doubt. 82 | * **Greeting (Conditional):** If it's the start of a new conversation, begin your response with a brief, friendly French greeting (e.g., "Bonjour !", "Bien sûr !", "Avec plaisir !"). 83 | 84 | Your overall goal is to provide clear, accurate, and helpful information from the documentation to its customers, in French. 85 | """, 86 | ), 87 | # Add previous conversation turns 88 | ("placeholder", "{conversation_history}"), 89 | ("human", human_message), 90 | ] 91 | ) 92 | 93 | 94 | # Post-processing 95 | def format_docs(docs): 96 | formatted_docs = "" 97 | for doc in docs: 98 | if "url" in doc.metadata: 99 | appeded_txt = f"{doc.metadata['title']}\n{doc.metadata['url']}" 100 | else: 101 | appeded_txt = "provided from summary not any source" 102 | formatted_docs += f""" 103 | 104 | 105 | {appeded_txt} 106 | 107 | {doc.page_content} 108 | 109 | 110 | 111 | """ 112 | return formatted_docs 113 | 114 | 115 | def format_docs_ref(docs: List[Document]): 116 | output_txt = "\n\n\n" 117 | for index, doc in enumerate(docs): 118 | doc_metadata = doc.metadata 119 | if "url" in doc_metadata: 120 | # add #source url as reference in markdown fromat to list of sources 121 | output_txt += f"[source{index + 1}]({doc_metadata['url']}) | " 122 | else: 123 | output_txt += f"[source{index + 1}](#summary) | " 124 | return output_txt 125 | 126 | 127 | # Chain 128 | rag_chain = ( 129 | # {"context": | format_docs, "question": RunnablePassthrough()} 130 | # | 131 | prompt | model | StrOutputParser() 132 | ) 133 | 134 | # Display chat history 135 | for message in st.session_state.messages: 136 | with st.chat_message(message["role"]): 137 | if message["role"] == "user": 138 | print(message["content"]) 139 | 140 | st.markdown(message["content"]) 141 | 142 | # Get user input 143 | if user_query_init := st.chat_input("Ask about la post..."): 144 | # Add user message to chat history 145 | if st.session_state.messages: 146 | user_query = contextualize_docs( 147 | llm=model, 148 | retriever=retriever, 149 | query=user_query_init, 150 | conversation=(st.session_state.conversation), 151 | ) 152 | else: 153 | user_query = "" + user_query_init 154 | print(user_query) 155 | st.session_state.messages.append( 156 | {"role": "user", "content": deepcopy(user_query_init)} 157 | ) 158 | 159 | # Display user message 160 | with st.chat_message("user"): 161 | st.markdown(user_query_init) 162 | with st.chat_message("assistant"): 163 | with st.spinner("Expanding query..."): 164 | expanded_queries = expand_query(user_query, model) 165 | with st.spinner("Retrieving..."): 166 | retrieved = retrieve_expanded_queries(expanded_queries, retriever) 167 | with st.spinner("Reranking..."): 168 | reordered_docs = rerank(retrieved, user_query) 169 | with st.spinner("Thinking..."): 170 | llm_input = { 171 | "context": format_docs(reordered_docs), 172 | "question": user_query, 173 | "conversation_history": st.session_state.conversation, 174 | } 175 | 176 | response = rag_chain.invoke(llm_input) 177 | st.session_state.conversation.append( 178 | ( 179 | "human", 180 | human_message.format( 181 | question=user_query, context=format_docs(reordered_docs) 182 | ), 183 | ) 184 | ) 185 | st.session_state.conversation.append(("ai", response)) 186 | st.markdown(response + format_docs_ref(reordered_docs)) 187 | 188 | # Add assistant response to chat history 189 | st.session_state.messages.append({"role": "assistant", "content": response}) 190 | -------------------------------------------------------------------------------- /raptor_pipeline/raptor.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | import tiktoken 4 | from typing import Dict, List, Optional, Tuple 5 | import logging 6 | import vertexai 7 | from langchain_google_vertexai import VertexAIEmbeddings, ChatVertexAI 8 | import numpy as np 9 | import pandas as pd 10 | import umap 11 | from langchain.prompts import ChatPromptTemplate 12 | from langchain_core.output_parsers import StrOutputParser 13 | from sklearn.mixture import GaussianMixture 14 | import os 15 | 16 | RANDOM_SEED = 224 # Fixed seed for reproducibility 17 | summary_dict = {} 18 | # Configure logging 19 | logging.basicConfig( 20 | level=logging.INFO, 21 | format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", 22 | handlers=[logging.FileHandler("la_post_rag.log"), logging.StreamHandler()], 23 | ) 24 | logger = logging.getLogger("la_post_RAG") 25 | 26 | ### --- Code from citations referenced above (added comments and docstrings) --- ### 27 | 28 | 29 | PROJECT_ID = "research-and-development-wlc" # @param {type:"string"} 30 | LOCATION = "us-central1" # @param {type:"string"} 31 | 32 | 33 | vertexai.init(project=PROJECT_ID, location=LOCATION) 34 | 35 | print("Vertex AI initialized") 36 | 37 | embd = VertexAIEmbeddings(model_name="text-embedding-005") 38 | 39 | model = ChatVertexAI(temperature=0, model="gemini-2.0-flash-lite") 40 | print("Model loaded") 41 | 42 | 43 | def global_cluster_embeddings( 44 | embeddings: np.ndarray, 45 | dim: int, 46 | n_neighbors: Optional[int] = None, 47 | metric: str = "cosine", 48 | ) -> np.ndarray: 49 | """ 50 | Perform global dimensionality reduction on the embeddings using UMAP. 51 | 52 | Parameters: 53 | - embeddings: The input embeddings as a numpy array. 54 | - dim: The target dimensionality for the reduced space. 55 | - n_neighbors: Optional; the number of neighbors to consider for each point. 56 | If not provided, it defaults to the square root of the number of embeddings. 57 | - metric: The distance metric to use for UMAP. 58 | 59 | Returns: 60 | - A numpy array of the embeddings reduced to the specified dimensionality. 61 | """ 62 | logger.info(f"Starting global dimensionality reduction with target dim={dim}") 63 | if n_neighbors is None: 64 | n_neighbors = int((len(embeddings) - 1) ** 0.5) 65 | logger.debug(f"Using calculated n_neighbors={n_neighbors}") 66 | 67 | try: 68 | result = umap.UMAP( 69 | n_neighbors=n_neighbors, n_components=dim, metric=metric 70 | ).fit_transform(embeddings) 71 | logger.info(f"Global dimensionality reduction complete: {result.shape}") 72 | return result 73 | except Exception as e: 74 | logger.error(f"Error in global_cluster_embeddings: {str(e)}") 75 | raise 76 | 77 | 78 | def local_cluster_embeddings( 79 | embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine" 80 | ) -> np.ndarray: 81 | """ 82 | Perform local dimensionality reduction on the embeddings using UMAP, typically after global clustering. 83 | 84 | Parameters: 85 | - embeddings: The input embeddings as a numpy array. 86 | - dim: The target dimensionality for the reduced space. 87 | - num_neighbors: The number of neighbors to consider for each point. 88 | - metric: The distance metric to use for UMAP. 89 | 90 | Returns: 91 | - A numpy array of the embeddings reduced to the specified dimensionality. 92 | """ 93 | logger.info( 94 | f"Starting local dimensionality reduction with dim={dim}, num_neighbors={num_neighbors}" 95 | ) 96 | try: 97 | result = umap.UMAP( 98 | n_neighbors=num_neighbors, n_components=dim, metric=metric 99 | ).fit_transform(embeddings) 100 | logger.info(f"Local dimensionality reduction complete: {result.shape}") 101 | return result 102 | except Exception as e: 103 | logger.error(f"Error in local_cluster_embeddings: {str(e)}") 104 | raise 105 | 106 | 107 | def get_optimal_clusters( 108 | embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED 109 | ) -> int: 110 | """ 111 | Determine the optimal number of clusters using the Bayesian Information Criterion (BIC) with a Gaussian Mixture Model. 112 | 113 | Parameters: 114 | - embeddings: The input embeddings as a numpy array. 115 | - max_clusters: The maximum number of clusters to consider. 116 | - random_state: Seed for reproducibility. 117 | 118 | Returns: 119 | - An integer representing the optimal number of clusters found. 120 | """ 121 | logger.info(f"Finding optimal number of clusters (max={max_clusters})") 122 | max_clusters = min(max_clusters, len(embeddings)) 123 | n_clusters = np.arange(1, max_clusters) 124 | bics = [] 125 | 126 | for n in n_clusters: 127 | logger.debug(f"Testing cluster count: {n}") 128 | gm = GaussianMixture(n_components=n, random_state=random_state) 129 | gm.fit(embeddings) 130 | bics.append(gm.bic(embeddings)) 131 | 132 | optimal_clusters = n_clusters[np.argmin(bics)] 133 | logger.info(f"Optimal number of clusters determined: {optimal_clusters}") 134 | return optimal_clusters 135 | 136 | 137 | def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0): 138 | """ 139 | Cluster embeddings using a Gaussian Mixture Model (GMM) based on a probability threshold. 140 | 141 | Parameters: 142 | - embeddings: The input embeddings as a numpy array. 143 | - threshold: The probability threshold for assigning an embedding to a cluster. 144 | - random_state: Seed for reproducibility. 145 | 146 | Returns: 147 | - A tuple containing the cluster labels and the number of clusters determined. 148 | """ 149 | logger.info(f"Starting GMM clustering with threshold={threshold}") 150 | n_clusters = get_optimal_clusters(embeddings) 151 | logger.info(f"Using {n_clusters} clusters for GMM") 152 | 153 | gm = GaussianMixture(n_components=n_clusters, random_state=random_state) 154 | gm.fit(embeddings) 155 | probs = gm.predict_proba(embeddings) 156 | labels = [np.where(prob > threshold)[0] for prob in probs] 157 | 158 | # Log distribution of assignments 159 | cluster_counts = [len(label) for label in labels] 160 | avg_clusters_per_item = sum(cluster_counts) / len(labels) 161 | logger.info( 162 | f"GMM clustering complete. Avg clusters per item: {avg_clusters_per_item:.2f}" 163 | ) 164 | return labels, n_clusters 165 | 166 | 167 | def perform_clustering( 168 | embeddings: np.ndarray, 169 | dim: int, 170 | threshold: float, 171 | ) -> List[np.ndarray]: 172 | """ 173 | Perform clustering on the embeddings by first reducing their dimensionality globally, then clustering 174 | using a Gaussian Mixture Model, and finally performing local clustering within each global cluster. 175 | 176 | Parameters: 177 | - embeddings: The input embeddings as a numpy array. 178 | - dim: The target dimensionality for UMAP reduction. 179 | - threshold: The probability threshold for assigning an embedding to a cluster in GMM. 180 | 181 | Returns: 182 | - A list of numpy arrays, where each array contains the cluster IDs for each embedding. 183 | """ 184 | logger.info( 185 | f"Starting hierarchical clustering process with {len(embeddings)} embeddings" 186 | ) 187 | 188 | if len(embeddings) <= dim + 1: 189 | logger.warning( 190 | f"Insufficient data for clustering ({len(embeddings)} <= {dim + 1}). Assigning all to cluster 0." 191 | ) 192 | return [np.array([0]) for _ in range(len(embeddings))] 193 | 194 | # Global dimensionality reduction 195 | reduced_embeddings_global = global_cluster_embeddings(embeddings, dim) 196 | 197 | # Global clustering 198 | logger.info("Starting global clustering") 199 | global_clusters, n_global_clusters = GMM_cluster( 200 | reduced_embeddings_global, threshold 201 | ) 202 | logger.info(f"Found {n_global_clusters} global clusters") 203 | 204 | all_local_clusters = [np.array([]) for _ in range(len(embeddings))] 205 | total_clusters = 0 206 | 207 | # Iterate through each global cluster to perform local clustering 208 | for i in range(n_global_clusters): 209 | logger.info(f"Processing global cluster {i + 1}/{n_global_clusters}") 210 | # Extract embeddings belonging to the current global cluster 211 | global_cluster_indices = np.array([i in gc for gc in global_clusters]) 212 | global_cluster_embeddings_ = embeddings[global_cluster_indices] 213 | 214 | logger.debug( 215 | f"Global cluster {i} contains {len(global_cluster_embeddings_)} embeddings" 216 | ) 217 | 218 | if len(global_cluster_embeddings_) == 0: 219 | logger.warning(f"Global cluster {i} is empty, skipping") 220 | continue 221 | 222 | if len(global_cluster_embeddings_) <= dim + 1: 223 | # Handle small clusters with direct assignment 224 | logger.debug( 225 | f"Global cluster {i} too small, assigning all to local cluster 0" 226 | ) 227 | local_clusters = [np.array([0]) for _ in global_cluster_embeddings_] 228 | n_local_clusters = 1 229 | else: 230 | # Local dimensionality reduction and clustering 231 | logger.debug(f"Performing local clustering within global cluster {i}") 232 | reduced_embeddings_local = local_cluster_embeddings( 233 | global_cluster_embeddings_, dim 234 | ) 235 | local_clusters, n_local_clusters = GMM_cluster( 236 | reduced_embeddings_local, threshold 237 | ) 238 | logger.debug( 239 | f"Found {n_local_clusters} local clusters in global cluster {i}" 240 | ) 241 | 242 | # Assign local cluster IDs, adjusting for total clusters already processed 243 | for j in range(n_local_clusters): 244 | local_cluster_indices = np.array([j in lc for lc in local_clusters]) 245 | local_cluster_embeddings_ = global_cluster_embeddings_[ 246 | local_cluster_indices 247 | ] 248 | indices = np.where( 249 | (embeddings == local_cluster_embeddings_[:, None]).all(-1) 250 | )[1] 251 | for idx in indices: 252 | all_local_clusters[idx] = np.append( 253 | all_local_clusters[idx], j + total_clusters 254 | ) 255 | 256 | total_clusters += n_local_clusters 257 | 258 | logger.info(f"Clustering complete. Total clusters: {total_clusters}") 259 | return all_local_clusters 260 | 261 | 262 | ### --- Our code below --- ### 263 | 264 | 265 | def load_and_consolidate_embeddings(directory="embd_out"): 266 | """ 267 | Loads all embedding batch files from the specified directory and 268 | consolidates them into a single numpy array. 269 | 270 | Parameters: 271 | - directory: The directory where embedding batch files are stored 272 | 273 | Returns: 274 | - A single numpy array containing all embeddings 275 | """ 276 | logger.info(f"Loading and consolidating embeddings from {directory}") 277 | 278 | try: 279 | # Find all .npy files in the directory 280 | batch_files = glob.glob(os.path.join(directory, "batch_*.npy")) 281 | 282 | # Check if there's a consolidated file already 283 | all_embeddings_path = os.path.join(directory, "all_embeddings.npy") 284 | if os.path.exists(all_embeddings_path): 285 | logger.info(f"Found consolidated embeddings file: {all_embeddings_path}") 286 | batch_files = [f for f in batch_files if f != all_embeddings_path] 287 | 288 | # Load the existing consolidated embeddings 289 | all_embeddings = np.load(all_embeddings_path) 290 | logger.info( 291 | f"Loaded existing consolidated embeddings with shape {all_embeddings.shape}" 292 | ) 293 | 294 | # Load and append batch files if any 295 | if batch_files: 296 | logger.info( 297 | f"Found {len(batch_files)} additional batch files to append" 298 | ) 299 | for batch_file in sorted( 300 | batch_files, 301 | key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]), 302 | ): 303 | logger.info(f"Loading batch file: {batch_file}") 304 | batch_embeddings = np.load(batch_file) 305 | all_embeddings = np.vstack((all_embeddings, batch_embeddings)) 306 | 307 | # Save the updated consolidated embeddings 308 | np.save(all_embeddings_path, all_embeddings) 309 | logger.info( 310 | f"Updated consolidated embeddings saved to {all_embeddings_path}" 311 | ) 312 | else: 313 | if not batch_files: 314 | logger.warning(f"No embedding files found in {directory}") 315 | return None 316 | 317 | # Sort the batch files by their batch number 318 | batch_files.sort( 319 | key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]) 320 | ) 321 | 322 | # Load the first batch to get dimensions 323 | logger.info(f"Loading first batch: {batch_files[0]}") 324 | all_embeddings = np.load(batch_files[0]) 325 | 326 | # Load and append the rest of the batches 327 | for batch_file in batch_files[1:]: 328 | logger.info(f"Loading batch file: {batch_file}") 329 | batch_embeddings = np.load(batch_file) 330 | all_embeddings = np.vstack((all_embeddings, batch_embeddings)) 331 | 332 | # Save the consolidated embeddings 333 | np.save(all_embeddings_path, all_embeddings) 334 | logger.info(f"Consolidated embeddings saved to {all_embeddings_path}") 335 | 336 | logger.info(f"Consolidation complete. Final shape: {all_embeddings.shape}") 337 | return all_embeddings 338 | 339 | except Exception as e: 340 | logger.error(f"Error consolidating embeddings: {str(e)}") 341 | raise 342 | 343 | 344 | async def embed(texts, max_tokens_per_request=1_000_000, batch_size=512): 345 | """ 346 | Generate embeddings for a list of text documents. 347 | 348 | This function assumes the existence of an `embd` object with a method `embed_documents` 349 | that takes a list of texts and returns their embeddings. 350 | 351 | Parameters: 352 | - texts: List[str], a list of text documents to be embedded. 353 | - max_tokens_per_request: Maximum number of tokens to process in a single request 354 | - batch_size: Number of documents to process in each batch 355 | 356 | Returns: 357 | - numpy.ndarray: An array of embeddings for the given text documents. 358 | """ 359 | logger.info(f"Generating embeddings for {len(texts)} documents") 360 | 361 | if os.path.exists(f"embd_{len(texts)}_out/all_embeddings.npy"): 362 | logger.info("Found existing embeddings file, loading...") 363 | loaded_data = np.load(f"embd_{len(texts)}_out/all_embeddings.npy") 364 | if len(loaded_data) == len(texts): 365 | logger.info("Embeddings already exist for all documents, loading...") 366 | return loaded_data 367 | else: 368 | logger.info("Embeddings exist but not for all documents, regenerating...") 369 | 370 | try: 371 | # Create output directory if it doesn't exist 372 | os.makedirs(f"embd_{len(texts)}_out", exist_ok=True) 373 | 374 | # Process in batches to handle large volumes of text 375 | logger.info(f"Processing {len(texts)} documents for embedding") 376 | 377 | # Initialize list to store all embeddings 378 | all_embeddings = [] 379 | 380 | # Process in batches of specified size 381 | total_batches = (len(texts) + batch_size - 1) // batch_size 382 | for i in range(0, len(texts) + 1, batch_size): 383 | batch_texts = texts[i : i + batch_size] 384 | batch_filename = f"embd_{len(texts)}_out/batch_{i // batch_size}.npy" 385 | 386 | # Check if this batch already exists 387 | 388 | logger.info( 389 | f"Embedding batch {i // batch_size + 1}/{total_batches} ({len(batch_texts)} documents)" 390 | ) 391 | 392 | # Rough estimate: 1 token ≈ 4 characters 393 | estimated_tokens = sum( 394 | [num_tokens_from_string(i, "cl100k_base") for i in batch_texts] 395 | ) 396 | 397 | if estimated_tokens > max_tokens_per_request: 398 | logger.warning( 399 | f"Batch size may exceed token limit. Estimated tokens: {estimated_tokens}" 400 | ) 401 | # Process in smaller chunks if needed 402 | sub_batch_size = max( 403 | 1, int(batch_size * max_tokens_per_request / estimated_tokens) 404 | ) 405 | logger.info( 406 | f"Reducing sub-batch size to approximately {sub_batch_size} documents" 407 | ) 408 | 409 | batch_embeddings = [] 410 | for j in range(0, len(batch_texts), sub_batch_size): 411 | sub_batch = batch_texts[j : j + sub_batch_size] 412 | logger.info( 413 | f"Processing sub-batch {j // sub_batch_size + 1}/{(len(batch_texts) + sub_batch_size - 1) // sub_batch_size}" 414 | ) 415 | sub_embeddings = await embd.aembed_documents(sub_batch) 416 | batch_embeddings.extend(sub_embeddings) 417 | 418 | batch_embeddings = np.array(batch_embeddings) 419 | else: 420 | # Process the whole batch at once 421 | batch_embeddings = await embd.aembed_documents(batch_texts) 422 | 423 | # Save this batch's embeddings 424 | np.save(batch_filename, batch_embeddings) 425 | logger.info(f"Saved batch embeddings to {batch_filename}") 426 | 427 | all_embeddings.extend(batch_embeddings) 428 | 429 | # Convert to numpy array 430 | embeddings_np = np.array(all_embeddings) 431 | 432 | # Save the consolidated embeddings 433 | np.save(f"embd_{len(texts)}_out/all_embeddings.npy", embeddings_np) 434 | logger.info( 435 | f"All embeddings consolidated and saved. Shape: {embeddings_np.shape}" 436 | ) 437 | 438 | return embeddings_np 439 | except Exception as e: 440 | logger.error(f"Error generating embeddings: {str(e)}") 441 | raise 442 | 443 | 444 | async def embed_cluster_texts(texts): 445 | """ 446 | Embeds a list of texts and clusters them, returning a DataFrame with texts, their embeddings, and cluster labels. 447 | 448 | This function combines embedding generation and clustering into a single step. It assumes the existence 449 | of a previously defined `perform_clustering` function that performs clustering on the embeddings. 450 | 451 | Parameters: 452 | - texts: List[str], a list of text documents to be processed. 453 | 454 | Returns: 455 | - pandas.DataFrame: A DataFrame containing the original texts, their embeddings, and the assigned cluster labels. 456 | """ 457 | logger.info(f"Starting embed_cluster_texts for {len(texts)} documents") 458 | 459 | # Generate embeddings 460 | text_embeddings_np = await embed(texts) 461 | logger.info("Embeddings generated, proceeding to clustering") 462 | 463 | # Perform clustering on the embeddings 464 | cluster_labels = perform_clustering(text_embeddings_np, 10, 0.1) 465 | 466 | # Create and populate DataFrame 467 | logger.info("Creating results DataFrame") 468 | df = pd.DataFrame() 469 | df["text"] = texts 470 | df["embd"] = list(text_embeddings_np) 471 | df["cluster"] = cluster_labels 472 | 473 | # Log some statistics about the clustering 474 | cluster_counts = [len(labels) for labels in cluster_labels] 475 | logger.info( 476 | f"Clustering stats: Avg clusters per document: {sum(cluster_counts) / len(cluster_counts):.2f}" 477 | ) 478 | logger.info( 479 | f"Documents with no clusters: {sum(1 for c in cluster_counts if c == 0)}" 480 | ) 481 | 482 | return df 483 | 484 | 485 | def fmt_txt(df: pd.DataFrame) -> str: 486 | """ 487 | Formats the text documents in a DataFrame into a single string. 488 | 489 | Parameters: 490 | - df: DataFrame containing the 'text' column with text documents to format. 491 | 492 | Returns: 493 | - A single string where all text documents are joined by a specific delimiter. 494 | """ 495 | logger.debug(f"Formatting {len(df)} text documents") 496 | unique_txt = df["text"].tolist() 497 | return "--- --- \n --- --- ".join(unique_txt) 498 | 499 | 500 | async def embed_cluster_summarize_texts( 501 | texts: List[str], level: int, batch_size=80 502 | ) -> Tuple[pd.DataFrame, pd.DataFrame]: 503 | """ 504 | Embeds, clusters, and summarizes a list of texts. This function first generates embeddings for the texts, 505 | clusters them based on similarity, expands the cluster assignments for easier processing, and then summarizes 506 | the content within each cluster. 507 | 508 | Parameters: 509 | - texts: A list of text documents to be processed. 510 | - level: An integer parameter that could define the depth or detail of processing. 511 | 512 | Returns: 513 | - Tuple containing two DataFrames: 514 | 1. The first DataFrame (`df_clusters`) includes the original texts, their embeddings, and cluster assignments. 515 | 2. The second DataFrame (`df_summary`) contains summaries for each cluster, the specified level of detail, 516 | and the cluster identifiers. 517 | """ 518 | logger.info( 519 | f"Starting embed_cluster_summarize_texts with {len(texts)} documents at level {level}" 520 | ) 521 | 522 | # Embed and cluster the texts, resulting in a DataFrame with 'text', 'embd', and 'cluster' columns 523 | df_clusters = await embed_cluster_texts(texts) 524 | 525 | # Prepare to expand the DataFrame for easier manipulation of clusters 526 | logger.info("Expanding cluster assignments") 527 | expanded_list = [] 528 | 529 | # Expand DataFrame entries to document-cluster pairings for straightforward processing 530 | for index, row in df_clusters.iterrows(): 531 | for cluster in row["cluster"]: 532 | expanded_list.append( 533 | {"text": row["text"], "embd": row["embd"], "cluster": cluster} 534 | ) 535 | 536 | # Create a new DataFrame from the expanded list 537 | expanded_df = pd.DataFrame(expanded_list) 538 | 539 | # Retrieve unique cluster identifiers for processing 540 | all_clusters = expanded_df["cluster"].unique() 541 | 542 | logger.info(f"--Generated {len(all_clusters)} clusters--") 543 | 544 | # Summarization 545 | template = """You are a world class summarization expert. Please provide a concise and informative summary of the following text, extracting the key points and main ideas. 546 | Text: 547 | {context} 548 | """ 549 | prompt = ChatPromptTemplate.from_template(template) 550 | chain = prompt | model | StrOutputParser() 551 | 552 | # Format text within each cluster for summarization 553 | logger.info("Starting summarization of clusters") 554 | summaries = [] 555 | # Configurable batch size 556 | formatted_texts = [] 557 | cluster_ids = [] 558 | 559 | # Prepare data for batch processing 560 | for i, cluster_id in enumerate(all_clusters): 561 | logger.info(f"Preparing cluster {i + 1}/{len(all_clusters)} (ID: {cluster_id})") 562 | df_cluster = expanded_df[expanded_df["cluster"] == cluster_id] 563 | logger.debug(f"Cluster {cluster_id} contains {len(df_cluster)} documents") 564 | 565 | formatted_txt = fmt_txt(df_cluster) 566 | formatted_texts.append(formatted_txt) 567 | cluster_ids.append(cluster_id) 568 | 569 | # Process in batches 570 | for batch_start in range(0, len(formatted_texts), batch_size): 571 | batch_end = min(batch_start + batch_size, len(formatted_texts)) 572 | batch_texts = formatted_texts[batch_start:batch_end] 573 | batch_cluster_ids = cluster_ids[batch_start:batch_end] 574 | 575 | logger.info( 576 | f"Processing batch {batch_start // batch_size + 1}/{(len(formatted_texts) + batch_size - 1) // batch_size}" 577 | ) 578 | 579 | # Filter out texts that are already in the summary_dict 580 | new_batch_texts = [] 581 | new_batch_indices = [] 582 | 583 | for i, text in enumerate(batch_texts): 584 | if text not in summary_dict: 585 | new_batch_texts.append(text) 586 | new_batch_indices.append(i) 587 | 588 | if new_batch_texts: 589 | try: 590 | # Use batch invocation 591 | batch_inputs = [{"context": text} for text in new_batch_texts] 592 | batch_results = chain.batch(batch_inputs) 593 | 594 | # Store results 595 | for i, result in enumerate(batch_results): 596 | original_index = new_batch_indices[i] 597 | original_text = batch_texts[original_index] 598 | summary_dict[original_text] = result 599 | except Exception as e: 600 | logger.error(f"Error in batch processing: {str(e)}") 601 | # Handle batch failure by processing individually 602 | for i in new_batch_indices: 603 | text = batch_texts[i] 604 | cluster_id = batch_cluster_ids[i] 605 | try: 606 | result = await chain.ainvoke({"context": text}) 607 | summary_dict[text] = result 608 | except Exception as inner_e: 609 | logger.error( 610 | f"Error generating summary for cluster {cluster_id}: {str(inner_e)}" 611 | ) 612 | summary_dict[text] = f"Error generating summary: {str(inner_e)}" 613 | 614 | # Collect summaries for this batch 615 | for i in range(batch_start, batch_end): 616 | text = formatted_texts[i] 617 | summary = summary_dict[text] 618 | summaries.append(summary) 619 | logger.debug( 620 | f"Summary for cluster {cluster_ids[i]} has length {len(summary)}" 621 | ) 622 | # Create a DataFrame to store summaries with their corresponding cluster and level 623 | logger.info("Creating summary DataFrame") 624 | with open("summaries_dict.json", "w") as f: 625 | json.dump( 626 | { 627 | "summaries": summaries, 628 | "level": [level] * len(summaries), 629 | "cluster": list(all_clusters), 630 | }, 631 | f, 632 | ) 633 | df_summary = pd.DataFrame( 634 | { 635 | "summaries": summaries, 636 | "level": [level] * len(summaries), 637 | "cluster": list(all_clusters), 638 | } 639 | ) 640 | 641 | return df_clusters, df_summary 642 | 643 | 644 | async def recursive_embed_cluster_summarize( 645 | texts: List[str], level: int = 1, n_levels: int = 3 646 | ) -> Dict[int, Tuple[pd.DataFrame, pd.DataFrame]]: 647 | """ 648 | Recursively embeds, clusters, and summarizes texts up to a specified level or until 649 | the number of unique clusters becomes 1, storing the results at each level. 650 | 651 | Parameters: 652 | - texts: List[str], texts to be processed. 653 | - level: int, current recursion level (starts at 1). 654 | - n_levels: int, maximum depth of recursion. 655 | 656 | Returns: 657 | - Dict[int, Tuple[pd.DataFrame, pd.DataFrame]], a dictionary where keys are the recursion 658 | levels and values are tuples containing the clusters DataFrame and summaries DataFrame at that level. 659 | """ 660 | logger.info( 661 | f"Starting recursive processing at level {level}/{n_levels} with {len(texts)} documents" 662 | ) 663 | results = {} # Dictionary to store results at each level 664 | 665 | # Perform embedding, clustering, and summarization for the current level 666 | df_clusters, df_summary = await embed_cluster_summarize_texts(texts, level) 667 | 668 | # Store the results of the current level 669 | results[level] = (df_clusters, df_summary) 670 | 671 | # Determine if further recursion is possible and meaningful 672 | unique_clusters = df_summary["cluster"].nunique() 673 | logger.info(f"Level {level} produced {unique_clusters} unique clusters") 674 | 675 | if level < n_levels and unique_clusters > 1: 676 | logger.info( 677 | f"Proceeding to level {level + 1} with {len(df_summary['summaries'])} summaries" 678 | ) 679 | # Use summaries as the input texts for the next level of recursion 680 | new_texts = df_summary["summaries"].tolist() 681 | next_level_results = await recursive_embed_cluster_summarize( 682 | new_texts, level + 1, n_levels 683 | ) 684 | 685 | # Merge the results from the next level into the current results dictionary 686 | results.update(next_level_results) 687 | else: 688 | if unique_clusters <= 1: 689 | logger.info( 690 | f"Stopping recursion at level {level}: Only {unique_clusters} clusters found" 691 | ) 692 | else: 693 | logger.info(f"Stopping recursion at level {level}: Reached maximum depth") 694 | 695 | return results 696 | 697 | 698 | def num_tokens_from_string(string: str, encoding_name: str) -> int: 699 | """Returns the number of tokens in a text string.""" 700 | encoding = tiktoken.get_encoding(encoding_name) 701 | num_tokens = len(encoding.encode(string)) 702 | return num_tokens 703 | --------------------------------------------------------------------------------