├── raptor.jpg
├── demo
│   ├── cinderella
│   └── sample.txt
├── raptor_dark.png
├── raptor
│   ├── Retrievers.py
│   ├── tree_structures.py
│   ├── __init__.py
│   ├── EmbeddingModels.py
│   ├── SummarizationModels.py
│   ├── cluster_tree_builder.py
│   ├── cluster_utils.py
│   ├── utils.py
│   ├── QAModels.py
│   ├── FaissRetriever.py
│   ├── tree_retriever.py
│   ├── RetrievalAugmentation.py
│   └── tree_builder.py
├── requirements.txt
├── .gitignore
├── LICENSE.txt
├── README.md
└── demo.ipynb

/raptor.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rubenvangenugten/raptor/master/raptor.jpg
--------------------------------------------------------------------------------
/demo/cinderella:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rubenvangenugten/raptor/master/demo/cinderella
--------------------------------------------------------------------------------
/raptor_dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rubenvangenugten/raptor/master/raptor_dark.png
--------------------------------------------------------------------------------
/raptor/Retrievers.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod 2 | from typing import List 3 | 4 | 5 | class BaseRetriever(ABC): 6 | @abstractmethod 7 | def retrieve(self, query: str) -> str: 8 | pass 9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | faiss-cpu 2 | numpy==1.26.3 3 | openai==1.3.3 4 | scikit-learn 5 | sentence-transformers==2.2.2 6 | tenacity==8.2.3 7 | tiktoken==0.5.1 8 | torch 9 | transformers==4.38.1 10 | umap-learn==0.5.5 11 | urllib3==1.26.6 12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Vim 7 | *.swp 8 | 9 | # Jupyter Notebook 10 | .ipynb_checkpoints 11 | 12 | # mac 13 | .DS_Store 14 | 15 | # Other 16 | .vscode 17 | *.tsv 18 | *.pt 19 | gpt*.txt 20 | *.env 21 | local/ 22 | local_* 23 | build/ 24 | *.egg-info/ 25 | !/data/*.json 26 | /dist/ 27 | checklist.md 28 | finetuning_ckpts/ 29 | * copy* 30 | .idea 31 | assertion.log 32 | *.log 33 | *.db
--------------------------------------------------------------------------------
/raptor/tree_structures.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Set 2 | 3 | 4 | class Node: 5 | """ 6 | Represents a node in the hierarchical tree structure. 7 | """ 8 | 9 | def __init__(self, text: str, index: int, children: Set[int], embeddings) -> None: 10 | self.text = text 11 | self.index = index 12 | self.children = children 13 | self.embeddings = embeddings 14 | 15 | 16 | class Tree: 17 | """ 18 | Represents the entire hierarchical tree structure. 
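Holds every node in the tree (all_nodes), the top-layer summary nodes (root_nodes), the chunk-level nodes the tree was built from (leaf_nodes), the total number of layers (num_layers), and a mapping from each layer index to the nodes on that layer (layer_to_nodes).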
19 | """ 20 | 21 | def __init__( 22 | self, all_nodes, root_nodes, leaf_nodes, num_layers, layer_to_nodes 23 | ) -> None: 24 | self.all_nodes = all_nodes 25 | self.root_nodes = root_nodes 26 | self.leaf_nodes = leaf_nodes 27 | self.num_layers = num_layers 28 | self.layer_to_nodes = layer_to_nodes 29 | -------------------------------------------------------------------------------- /raptor/__init__.py: -------------------------------------------------------------------------------- 1 | # raptor/__init__.py 2 | from .cluster_tree_builder import ClusterTreeBuilder, ClusterTreeConfig 3 | from .EmbeddingModels import (BaseEmbeddingModel, OpenAIEmbeddingModel, 4 | SBertEmbeddingModel) 5 | from .FaissRetriever import FaissRetriever, FaissRetrieverConfig 6 | from .QAModels import (BaseQAModel, GPT3QAModel, GPT3TurboQAModel, GPT4QAModel, 7 | UnifiedQAModel) 8 | from .RetrievalAugmentation import (RetrievalAugmentation, 9 | RetrievalAugmentationConfig) 10 | from .Retrievers import BaseRetriever 11 | from .SummarizationModels import (BaseSummarizationModel, 12 | GPT3SummarizationModel, 13 | GPT3TurboSummarizationModel) 14 | from .tree_builder import TreeBuilder, TreeBuilderConfig 15 | from .tree_retriever import TreeRetriever, TreeRetrieverConfig 16 | from .tree_structures import Node, Tree 17 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) Parth Sarthi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 
22 | -------------------------------------------------------------------------------- /raptor/EmbeddingModels.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABC, abstractmethod 3 | 4 | from openai import OpenAI 5 | from sentence_transformers import SentenceTransformer 6 | from tenacity import retry, stop_after_attempt, wait_random_exponential 7 | 8 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 9 | 10 | 11 | class BaseEmbeddingModel(ABC): 12 | @abstractmethod 13 | def create_embedding(self, text): 14 | pass 15 | 16 | 17 | class OpenAIEmbeddingModel(BaseEmbeddingModel): 18 | def __init__(self, model="text-embedding-ada-002"): 19 | self.client = OpenAI() 20 | self.model = model 21 | 22 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 23 | def create_embedding(self, text): 24 | text = text.replace("\n", " ") 25 | return ( 26 | self.client.embeddings.create(input=[text], model=self.model) 27 | .data[0] 28 | .embedding 29 | ) 30 | 31 | 32 | class SBertEmbeddingModel(BaseEmbeddingModel): 33 | def __init__(self, model_name="sentence-transformers/multi-qa-mpnet-base-cos-v1"): 34 | self.model = SentenceTransformer(model_name) 35 | 36 | def create_embedding(self, text): 37 | return self.model.encode(text) 38 | -------------------------------------------------------------------------------- /raptor/SummarizationModels.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from abc import ABC, abstractmethod 4 | 5 | from openai import OpenAI 6 | from tenacity import retry, stop_after_attempt, wait_random_exponential 7 | 8 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 9 | 10 | 11 | class BaseSummarizationModel(ABC): 12 | @abstractmethod 13 | def summarize(self, context, max_tokens=150): 14 | pass 15 | 16 | 17 | class GPT3TurboSummarizationModel(BaseSummarizationModel): 18 | def __init__(self, model="gpt-3.5-turbo"): 19 | 20 | self.model = model 21 | 22 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 23 | def summarize(self, context, max_tokens=500, stop_sequence=None): 24 | 25 | try: 26 | client = OpenAI() 27 | 28 | response = client.chat.completions.create( 29 | model=self.model, 30 | messages=[ 31 | {"role": "system", "content": "You are a helpful assistant."}, 32 | { 33 | "role": "user", 34 | "content": f"Write a summary of the following, including as many key details as possible: {context}:", 35 | }, 36 | ], 37 | max_tokens=max_tokens, 38 | ) 39 | 40 | return response.choices[0].message.content 41 | 42 | except Exception as e: 43 | print(e) 44 | return e 45 | 46 | 47 | class GPT3SummarizationModel(BaseSummarizationModel): 48 | def __init__(self, model="text-davinci-003"): 49 | 50 | self.model = model 51 | 52 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 53 | def summarize(self, context, max_tokens=500, stop_sequence=None): 54 | 55 | try: 56 | client = OpenAI() 57 | 58 | response = client.chat.completions.create( 59 | model=self.model, 60 | messages=[ 61 | {"role": "system", "content": "You are a helpful assistant."}, 62 | { 63 | "role": "user", 64 | "content": f"Write a summary of the following, including as many key details as possible: {context}:", 65 | }, 66 | ], 67 | max_tokens=max_tokens, 68 | ) 69 | 70 | return response.choices[0].message.content 71 | 72 | except Exception as e: 73 | 
print(e) 74 | return e 75 | -------------------------------------------------------------------------------- /raptor/cluster_tree_builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | from concurrent.futures import ThreadPoolExecutor 4 | from threading import Lock 5 | from typing import Dict, List, Set 6 | 7 | from .cluster_utils import ClusteringAlgorithm, RAPTOR_Clustering 8 | from .tree_builder import TreeBuilder, TreeBuilderConfig 9 | from .tree_structures import Node, Tree 10 | from .utils import (distances_from_embeddings, get_children, get_embeddings, 11 | get_node_list, get_text, 12 | indices_of_nearest_neighbors_from_distances, split_text) 13 | 14 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 15 | 16 | 17 | class ClusterTreeConfig(TreeBuilderConfig): 18 | def __init__( 19 | self, 20 | reduction_dimension=10, 21 | clustering_algorithm=RAPTOR_Clustering, # Default to RAPTOR clustering 22 | clustering_params={}, # Pass additional params as a dict 23 | *args, 24 | **kwargs, 25 | ): 26 | super().__init__(*args, **kwargs) 27 | self.reduction_dimension = reduction_dimension 28 | self.clustering_algorithm = clustering_algorithm 29 | self.clustering_params = clustering_params 30 | 31 | def log_config(self): 32 | base_summary = super().log_config() 33 | cluster_tree_summary = f""" 34 | Reduction Dimension: {self.reduction_dimension} 35 | Clustering Algorithm: {self.clustering_algorithm.__name__} 36 | Clustering Parameters: {self.clustering_params} 37 | """ 38 | return base_summary + cluster_tree_summary 39 | 40 | 41 | class ClusterTreeBuilder(TreeBuilder): 42 | def __init__(self, config) -> None: 43 | super().__init__(config) 44 | 45 | if not isinstance(config, ClusterTreeConfig): 46 | raise ValueError("config must be an instance of ClusterTreeConfig") 47 | self.reduction_dimension = config.reduction_dimension 48 | self.clustering_algorithm = config.clustering_algorithm 49 | self.clustering_params = config.clustering_params 50 | 51 | logging.info( 52 | f"Successfully initialized ClusterTreeBuilder with Config {config.log_config()}" 53 | ) 54 | 55 | def construct_tree( 56 | self, 57 | current_level_nodes: Dict[int, Node], 58 | all_tree_nodes: Dict[int, Node], 59 | layer_to_nodes: Dict[int, List[Node]], 60 | use_multithreading: bool = False, 61 | ) -> Dict[int, Node]: 62 | logging.info("Using Cluster TreeBuilder") 63 | 64 | next_node_index = len(all_tree_nodes) 65 | 66 | def process_cluster( 67 | cluster, new_level_nodes, next_node_index, summarization_length, lock 68 | ): 69 | node_texts = get_text(cluster) 70 | 71 | summarized_text = self.summarize( 72 | context=node_texts, 73 | max_tokens=summarization_length, 74 | ) 75 | 76 | logging.info( 77 | f"Node Texts Length: {len(self.tokenizer.encode(node_texts))}, Summarized Text Length: {len(self.tokenizer.encode(summarized_text))}" 78 | ) 79 | 80 | __, new_parent_node = self.create_node( 81 | next_node_index, summarized_text, {node.index for node in cluster} 82 | ) 83 | 84 | with lock: 85 | new_level_nodes[next_node_index] = new_parent_node 86 | 87 | for layer in range(self.num_layers): 88 | 89 | new_level_nodes = {} 90 | 91 | logging.info(f"Constructing Layer {layer}") 92 | 93 | node_list_current_layer = get_node_list(current_level_nodes) 94 | 95 | if len(node_list_current_layer) <= self.reduction_dimension + 1: 96 | self.num_layers = layer 97 | logging.info( 98 | f"Stopping Layer construction: Cannot Create More Layers. 
Total Layers in tree: {layer}" 99 | ) 100 | break 101 | 102 | clusters = self.clustering_algorithm.perform_clustering( 103 | node_list_current_layer, 104 | self.cluster_embedding_model, 105 | reduction_dimension=self.reduction_dimension, 106 | **self.clustering_params, 107 | ) 108 | 109 | lock = Lock() 110 | 111 | summarization_length = self.summarization_length 112 | logging.info(f"Summarization Length: {summarization_length}") 113 | 114 | if use_multithreading: 115 | with ThreadPoolExecutor() as executor: 116 | for cluster in clusters: 117 | executor.submit( 118 | process_cluster, 119 | cluster, 120 | new_level_nodes, 121 | next_node_index, 122 | summarization_length, 123 | lock, 124 | ) 125 | next_node_index += 1 126 | executor.shutdown(wait=True) 127 | 128 | else: 129 | for cluster in clusters: 130 | process_cluster( 131 | cluster, 132 | new_level_nodes, 133 | next_node_index, 134 | summarization_length, 135 | lock, 136 | ) 137 | next_node_index += 1 138 | 139 | layer_to_nodes[layer + 1] = list(new_level_nodes.values()) 140 | current_level_nodes = new_level_nodes 141 | all_tree_nodes.update(new_level_nodes) 142 | 143 | tree = Tree( 144 | all_tree_nodes, 145 | layer_to_nodes[layer + 1], 146 | layer_to_nodes[0], 147 | layer + 1, 148 | layer_to_nodes, 149 | ) 150 | 151 | return current_level_nodes 152 | -------------------------------------------------------------------------------- /raptor/cluster_utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | from abc import ABC, abstractmethod 4 | from typing import List, Optional 5 | 6 | import numpy as np 7 | import tiktoken 8 | import umap 9 | from sklearn.mixture import GaussianMixture 10 | 11 | # Initialize logging 12 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 13 | 14 | from .tree_structures import Node 15 | # Import necessary methods from other modules 16 | from .utils import get_embeddings 17 | 18 | # Set a random seed for reproducibility 19 | RANDOM_SEED = 224 20 | random.seed(RANDOM_SEED) 21 | 22 | 23 | def global_cluster_embeddings( 24 | embeddings: np.ndarray, 25 | dim: int, 26 | n_neighbors: Optional[int] = None, 27 | metric: str = "cosine", 28 | ) -> np.ndarray: 29 | if n_neighbors is None: 30 | n_neighbors = int((len(embeddings) - 1) ** 0.5) 31 | reduced_embeddings = umap.UMAP( 32 | n_neighbors=n_neighbors, n_components=dim, metric=metric 33 | ).fit_transform(embeddings) 34 | return reduced_embeddings 35 | 36 | 37 | def local_cluster_embeddings( 38 | embeddings: np.ndarray, dim: int, num_neighbors: int = 10, metric: str = "cosine" 39 | ) -> np.ndarray: 40 | reduced_embeddings = umap.UMAP( 41 | n_neighbors=num_neighbors, n_components=dim, metric=metric 42 | ).fit_transform(embeddings) 43 | return reduced_embeddings 44 | 45 | 46 | def get_optimal_clusters( 47 | embeddings: np.ndarray, max_clusters: int = 50, random_state: int = RANDOM_SEED 48 | ) -> int: 49 | max_clusters = min(max_clusters, len(embeddings)) 50 | n_clusters = np.arange(1, max_clusters) 51 | bics = [] 52 | for n in n_clusters: 53 | gm = GaussianMixture(n_components=n, random_state=random_state) 54 | gm.fit(embeddings) 55 | bics.append(gm.bic(embeddings)) 56 | optimal_clusters = n_clusters[np.argmin(bics)] 57 | return optimal_clusters 58 | 59 | 60 | def GMM_cluster(embeddings: np.ndarray, threshold: float, random_state: int = 0): 61 | n_clusters = get_optimal_clusters(embeddings) 62 | gm = GaussianMixture(n_components=n_clusters, random_state=random_state) 
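# Fit the mixture model, then soft-assign labels: each embedding is given every cluster whose posterior probability exceeds `threshold`, so one node can belong to several clusters.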
63 | gm.fit(embeddings) 64 | probs = gm.predict_proba(embeddings) 65 | labels = [np.where(prob > threshold)[0] for prob in probs] 66 | return labels, n_clusters 67 | 68 | 69 | def perform_clustering( 70 | embeddings: np.ndarray, dim: int, threshold: float, verbose: bool = False 71 | ) -> List[np.ndarray]: 72 | reduced_embeddings_global = global_cluster_embeddings(embeddings, min(dim, len(embeddings) -2)) 73 | global_clusters, n_global_clusters = GMM_cluster( 74 | reduced_embeddings_global, threshold 75 | ) 76 | 77 | if verbose: 78 | logging.info(f"Global Clusters: {n_global_clusters}") 79 | 80 | all_local_clusters = [np.array([]) for _ in range(len(embeddings))] 81 | total_clusters = 0 82 | 83 | for i in range(n_global_clusters): 84 | global_cluster_embeddings_ = embeddings[ 85 | np.array([i in gc for gc in global_clusters]) 86 | ] 87 | if verbose: 88 | logging.info( 89 | f"Nodes in Global Cluster {i}: {len(global_cluster_embeddings_)}" 90 | ) 91 | if len(global_cluster_embeddings_) == 0: 92 | continue 93 | if len(global_cluster_embeddings_) <= dim + 1: 94 | local_clusters = [np.array([0]) for _ in global_cluster_embeddings_] 95 | n_local_clusters = 1 96 | else: 97 | reduced_embeddings_local = local_cluster_embeddings( 98 | global_cluster_embeddings_, dim 99 | ) 100 | local_clusters, n_local_clusters = GMM_cluster( 101 | reduced_embeddings_local, threshold 102 | ) 103 | 104 | if verbose: 105 | logging.info(f"Local Clusters in Global Cluster {i}: {n_local_clusters}") 106 | 107 | for j in range(n_local_clusters): 108 | local_cluster_embeddings_ = global_cluster_embeddings_[ 109 | np.array([j in lc for lc in local_clusters]) 110 | ] 111 | indices = np.where( 112 | (embeddings == local_cluster_embeddings_[:, None]).all(-1) 113 | )[1] 114 | for idx in indices: 115 | all_local_clusters[idx] = np.append( 116 | all_local_clusters[idx], j + total_clusters 117 | ) 118 | 119 | total_clusters += n_local_clusters 120 | 121 | if verbose: 122 | logging.info(f"Total Clusters: {total_clusters}") 123 | return all_local_clusters 124 | 125 | 126 | class ClusteringAlgorithm(ABC): 127 | @abstractmethod 128 | def perform_clustering(self, embeddings: np.ndarray, **kwargs) -> List[List[int]]: 129 | pass 130 | 131 | 132 | class RAPTOR_Clustering(ClusteringAlgorithm): 133 | def perform_clustering( 134 | nodes: List[Node], 135 | embedding_model_name: str, 136 | max_length_in_cluster: int = 3500, 137 | tokenizer=tiktoken.get_encoding("cl100k_base"), 138 | reduction_dimension: int = 10, 139 | threshold: float = 0.1, 140 | verbose: bool = False, 141 | ) -> List[List[Node]]: 142 | # Get the embeddings from the nodes 143 | embeddings = np.array([node.embeddings[embedding_model_name] for node in nodes]) 144 | 145 | # Perform the clustering 146 | clusters = perform_clustering( 147 | embeddings, dim=reduction_dimension, threshold=threshold 148 | ) 149 | 150 | # Initialize an empty list to store the clusters of nodes 151 | node_clusters = [] 152 | 153 | # Iterate over each unique label in the clusters 154 | for label in np.unique(np.concatenate(clusters)): 155 | # Get the indices of the nodes that belong to this cluster 156 | indices = [i for i, cluster in enumerate(clusters) if label in cluster] 157 | 158 | # Add the corresponding nodes to the node_clusters list 159 | cluster_nodes = [nodes[i] for i in indices] 160 | 161 | # Base case: if the cluster only has one node, do not attempt to recluster it 162 | if len(cluster_nodes) == 1: 163 | node_clusters.append(cluster_nodes) 164 | continue 165 | 166 | # Calculate the total 
length of the text in the nodes 167 | total_length = sum( 168 | [len(tokenizer.encode(node.text)) for node in cluster_nodes] 169 | ) 170 | 171 | # If the total length exceeds the maximum allowed length, recluster this cluster 172 | if total_length > max_length_in_cluster: 173 | if verbose: 174 | logging.info( 175 | f"reclustering cluster with {len(cluster_nodes)} nodes" 176 | ) 177 | node_clusters.extend( 178 | RAPTOR_Clustering.perform_clustering( 179 | cluster_nodes, embedding_model_name, max_length_in_cluster 180 | ) 181 | ) 182 | else: 183 | node_clusters.append(cluster_nodes) 184 | 185 | return node_clusters 186 | -------------------------------------------------------------------------------- /raptor/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import re 3 | from typing import Dict, List, Set 4 | 5 | import numpy as np 6 | import tiktoken 7 | from scipy import spatial 8 | 9 | from .tree_structures import Node 10 | 11 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 12 | 13 | 14 | def reverse_mapping(layer_to_nodes: Dict[int, List[Node]]) -> Dict[Node, int]: 15 | node_to_layer = {} 16 | for layer, nodes in layer_to_nodes.items(): 17 | for node in nodes: 18 | node_to_layer[node.index] = layer 19 | return node_to_layer 20 | 21 | 22 | def split_text( 23 | text: str, tokenizer: tiktoken.get_encoding("cl100k_base"), max_tokens: int, overlap: int = 0 24 | ): 25 | """ 26 | Splits the input text into smaller chunks based on the tokenizer and maximum allowed tokens. 27 | 28 | Args: 29 | text (str): The text to be split. 30 | tokenizer (CustomTokenizer): The tokenizer to be used for splitting the text. 31 | max_tokens (int): The maximum allowed tokens. 32 | overlap (int, optional): The number of overlapping tokens between chunks. Defaults to 0. 33 | 34 | Returns: 35 | List[str]: A list of text chunks. 
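Example (illustrative usage only; `document_text` is a placeholder for your own input string):
    tokenizer = tiktoken.get_encoding("cl100k_base")
    chunks = split_text(document_text, tokenizer, max_tokens=100, overlap=10)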
36 | """ 37 | # Split the text into sentences using multiple delimiters 38 | delimiters = [".", "!", "?", "\n"] 39 | regex_pattern = "|".join(map(re.escape, delimiters)) 40 | sentences = re.split(regex_pattern, text) 41 | 42 | # Calculate the number of tokens for each sentence 43 | n_tokens = [len(tokenizer.encode(" " + sentence)) for sentence in sentences] 44 | 45 | chunks = [] 46 | current_chunk = [] 47 | current_length = 0 48 | 49 | for sentence, token_count in zip(sentences, n_tokens): 50 | # If the sentence is empty or consists only of whitespace, skip it 51 | if not sentence.strip(): 52 | continue 53 | 54 | # If the sentence is too long, split it into smaller parts 55 | if token_count > max_tokens: 56 | sub_sentences = re.split(r"[,;:]", sentence) 57 | sub_token_counts = [len(tokenizer.encode(" " + sub_sentence)) for sub_sentence in sub_sentences] 58 | 59 | sub_chunk = [] 60 | sub_length = 0 61 | 62 | for sub_sentence, sub_token_count in zip(sub_sentences, sub_token_counts): 63 | if sub_length + sub_token_count > max_tokens: 64 | chunks.append(" ".join(sub_chunk)) 65 | sub_chunk = sub_chunk[-overlap:] if overlap > 0 else [] 66 | sub_length = sum(sub_token_counts[max(0, len(sub_chunk) - overlap):len(sub_chunk)]) 67 | 68 | sub_chunk.append(sub_sentence) 69 | sub_length += sub_token_count 70 | 71 | if sub_chunk: 72 | chunks.append(" ".join(sub_chunk)) 73 | 74 | # If adding the sentence to the current chunk exceeds the max tokens, start a new chunk 75 | elif current_length + token_count > max_tokens: 76 | chunks.append(" ".join(current_chunk)) 77 | current_chunk = current_chunk[-overlap:] if overlap > 0 else [] 78 | current_length = sum(n_tokens[max(0, len(current_chunk) - overlap):len(current_chunk)]) 79 | current_chunk.append(sentence) 80 | current_length += token_count 81 | 82 | # Otherwise, add the sentence to the current chunk 83 | else: 84 | current_chunk.append(sentence) 85 | current_length += token_count 86 | 87 | # Add the last chunk if it's not empty 88 | if current_chunk: 89 | chunks.append(" ".join(current_chunk)) 90 | 91 | return chunks 92 | 93 | 94 | def distances_from_embeddings( 95 | query_embedding: List[float], 96 | embeddings: List[List[float]], 97 | distance_metric: str = "cosine", 98 | ) -> List[float]: 99 | """ 100 | Calculates the distances between a query embedding and a list of embeddings. 101 | 102 | Args: 103 | query_embedding (List[float]): The query embedding. 104 | embeddings (List[List[float]]): A list of embeddings to compare against the query embedding. 105 | distance_metric (str, optional): The distance metric to use for calculation. Defaults to 'cosine'. 106 | 107 | Returns: 108 | List[float]: The calculated distances between the query embedding and the list of embeddings. 109 | """ 110 | distance_metrics = { 111 | "cosine": spatial.distance.cosine, 112 | "L1": spatial.distance.cityblock, 113 | "L2": spatial.distance.euclidean, 114 | "Linf": spatial.distance.chebyshev, 115 | } 116 | 117 | if distance_metric not in distance_metrics: 118 | raise ValueError( 119 | f"Unsupported distance metric '{distance_metric}'. Supported metrics are: {list(distance_metrics.keys())}" 120 | ) 121 | 122 | distances = [ 123 | distance_metrics[distance_metric](query_embedding, embedding) 124 | for embedding in embeddings 125 | ] 126 | 127 | return distances 128 | 129 | 130 | def get_node_list(node_dict: Dict[int, Node]) -> List[Node]: 131 | """ 132 | Converts a dictionary of node indices to a sorted list of nodes. 
133 | 134 | Args: 135 | node_dict (Dict[int, Node]): Dictionary of node indices to nodes. 136 | 137 | Returns: 138 | List[Node]: Sorted list of nodes. 139 | """ 140 | indices = sorted(node_dict.keys()) 141 | node_list = [node_dict[index] for index in indices] 142 | return node_list 143 | 144 | 145 | def get_embeddings(node_list: List[Node], embedding_model: str) -> List: 146 | """ 147 | Extracts the embeddings of nodes from a list of nodes. 148 | 149 | Args: 150 | node_list (List[Node]): List of nodes. 151 | embedding_model (str): The name of the embedding model to be used. 152 | 153 | Returns: 154 | List: List of node embeddings. 155 | """ 156 | return [node.embeddings[embedding_model] for node in node_list] 157 | 158 | 159 | def get_children(node_list: List[Node]) -> List[Set[int]]: 160 | """ 161 | Extracts the children of nodes from a list of nodes. 162 | 163 | Args: 164 | node_list (List[Node]): List of nodes. 165 | 166 | Returns: 167 | List[Set[int]]: List of sets of node children indices. 168 | """ 169 | return [node.children for node in node_list] 170 | 171 | 172 | def get_text(node_list: List[Node]) -> str: 173 | """ 174 | Generates a single text string by concatenating the text from a list of nodes. 175 | 176 | Args: 177 | node_list (List[Node]): List of nodes. 178 | 179 | Returns: 180 | str: Concatenated text. 181 | """ 182 | text = "" 183 | for node in node_list: 184 | text += f"{' '.join(node.text.splitlines())}" 185 | text += "\n\n" 186 | return text 187 | 188 | 189 | def indices_of_nearest_neighbors_from_distances(distances: List[float]) -> np.ndarray: 190 | """ 191 | Returns the indices of nearest neighbors sorted in ascending order of distance. 192 | 193 | Args: 194 | distances (List[float]): A list of distances between embeddings. 195 | 196 | Returns: 197 | np.ndarray: An array of indices sorted by ascending distance. 198 | """ 199 | return np.argsort(distances) 200 | -------------------------------------------------------------------------------- /raptor/QAModels.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from openai import OpenAI 5 | 6 | 7 | import getpass 8 | from abc import ABC, abstractmethod 9 | 10 | import torch 11 | from tenacity import retry, stop_after_attempt, wait_random_exponential 12 | from transformers import T5ForConditionalGeneration, T5Tokenizer 13 | 14 | 15 | class BaseQAModel(ABC): 16 | @abstractmethod 17 | def answer_question(self, context, question): 18 | pass 19 | 20 | 21 | class GPT3QAModel(BaseQAModel): 22 | def __init__(self, model="text-davinci-003"): 23 | """ 24 | Initializes the GPT-3 model with the specified model version. 25 | 26 | Args: 27 | model (str, optional): The GPT-3 model version to use for generating summaries. Defaults to "text-davinci-003". 28 | """ 29 | self.model = model 30 | self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) 31 | 32 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 33 | def answer_question(self, context, question, max_tokens=150, stop_sequence=None): 34 | """ 35 | Generates a summary of the given context using the GPT-3 model. 36 | 37 | Args: 38 | context (str): The text to summarize. 39 | max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150. 40 | stop_sequence (str, optional): The sequence at which to stop summarization. Defaults to None. 41 | 42 | Returns: 43 | str: The generated summary. 
44 | """ 45 | try: 46 | response = self.client.completions.create( 47 | prompt=f"Using the following information {context}. Answer the following question in less than 5-7 words, if possible: {question}", 48 | temperature=0, 49 | max_tokens=max_tokens, 50 | top_p=1, 51 | frequency_penalty=0, 52 | presence_penalty=0, 53 | stop=stop_sequence, 54 | model=self.model, 55 | ) 56 | return response.choices[0].text.strip() 57 | 58 | except Exception as e: 59 | print(e) 60 | return "" 61 | 62 | 63 | class GPT3TurboQAModel(BaseQAModel): 64 | def __init__(self, model="gpt-3.5-turbo"): 65 | """ 66 | Initializes the GPT-3.5 Turbo model with the specified model version. 67 | 68 | Args: 69 | model (str, optional): The model version to use for answering questions. Defaults to "gpt-3.5-turbo". 70 | """ 71 | self.model = model 72 | self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) 73 | 74 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 75 | def _attempt_answer_question( 76 | self, context, question, max_tokens=150, stop_sequence=None 77 | ): 78 | """ 79 | Generates an answer to the given question using the provided context. 80 | 81 | Args: 82 | context (str): The context from which to answer the question. 83 | max_tokens (int, optional): The maximum number of tokens in the generated answer. Defaults to 150. 84 | stop_sequence (str, optional): The sequence at which to stop generation. Defaults to None. 85 | 86 | Returns: 87 | str: The generated answer. 88 | """ 89 | response = self.client.chat.completions.create( 90 | model=self.model, 91 | messages=[ 92 | {"role": "system", "content": "You are Question Answering Portal"}, 93 | { 94 | "role": "user", 95 | "content": f"Given Context: {context} Give the best full answer amongst the option to question {question}", 96 | }, 97 | ], 98 | temperature=0, 99 | ) 100 | 101 | return response.choices[0].message.content.strip() 102 | 103 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 104 | def answer_question(self, context, question, max_tokens=150, stop_sequence=None): 105 | 106 | try: 107 | return self._attempt_answer_question( 108 | context, question, max_tokens=max_tokens, stop_sequence=stop_sequence 109 | ) 110 | except Exception as e: 111 | print(e) 112 | return e 113 | 114 | 115 | class GPT4QAModel(BaseQAModel): 116 | def __init__(self, model="gpt-4"): 117 | """ 118 | Initializes the GPT-4 model with the specified model version. 119 | 120 | Args: 121 | model (str, optional): The model version to use for answering questions. Defaults to "gpt-4". 122 | """ 123 | self.model = model 124 | self.client = OpenAI(api_key=os.environ["OPENAI_API_KEY"]) 125 | 126 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 127 | def _attempt_answer_question( 128 | self, context, question, max_tokens=150, stop_sequence=None 129 | ): 130 | """ 131 | Generates an answer to the given question using the provided context. 132 | 133 | Args: 134 | context (str): The context from which to answer the question. 135 | max_tokens (int, optional): The maximum number of tokens in the generated answer. Defaults to 150. 136 | stop_sequence (str, optional): The sequence at which to stop generation. Defaults to None. 137 | 138 | Returns: 139 | str: The generated answer. 
140 | """ 141 | response = self.client.chat.completions.create( 142 | model=self.model, 143 | messages=[ 144 | {"role": "system", "content": "You are Question Answering Portal"}, 145 | { 146 | "role": "user", 147 | "content": f"Given Context: {context} Give the best full answer amongst the option to question {question}", 148 | }, 149 | ], 150 | temperature=0, 151 | ) 152 | 153 | return response.choices[0].message.content.strip() 154 | 155 | @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6)) 156 | def answer_question(self, context, question, max_tokens=150, stop_sequence=None): 157 | 158 | try: 159 | return self._attempt_answer_question( 160 | context, question, max_tokens=max_tokens, stop_sequence=stop_sequence 161 | ) 162 | except Exception as e: 163 | print(e) 164 | return e 165 | 166 | 167 | class UnifiedQAModel(BaseQAModel): 168 | def __init__(self, model_name="allenai/unifiedqa-v2-t5-3b-1363200"): 169 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 170 | self.model = T5ForConditionalGeneration.from_pretrained(model_name).to( 171 | self.device 172 | ) 173 | self.tokenizer = T5Tokenizer.from_pretrained(model_name) 174 | 175 | def run_model(self, input_string, **generator_args): 176 | input_ids = self.tokenizer.encode(input_string, return_tensors="pt").to( 177 | self.device 178 | ) 179 | res = self.model.generate(input_ids, **generator_args) 180 | return self.tokenizer.batch_decode(res, skip_special_tokens=True) 181 | 182 | def answer_question(self, context, question): 183 | input_string = question + " \\n " + context 184 | output = self.run_model(input_string) 185 | return output[0] 186 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 5 | 6 | 11 | 12 | 13 | 14 | <!-- image (light and dark mode variants) omitted --> 15 | 16 | 17 | ## RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval 18 | 19 | **RAPTOR** introduces a novel approach to retrieval-augmented language models by constructing a recursive tree structure from documents. This allows for more efficient and context-aware information retrieval across large texts, addressing common limitations in traditional language models. 20 | 21 | 22 | 23 | For detailed methodologies and implementations, refer to the original paper: 24 | 25 | - [RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval](https://arxiv.org/abs/2401.18059) 26 | 27 | [![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-sm.svg)](https://huggingface.co/papers/2401.18059) 28 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/raptor-recursive-abstractive-processing-for/question-answering-on-quality)](https://paperswithcode.com/sota/question-answering-on-quality?p=raptor-recursive-abstractive-processing-for) 29 | 30 | ## Installation 31 | 32 | Before using RAPTOR, ensure Python 3.8+ is installed. 
Clone the RAPTOR repository and install necessary dependencies: 33 | 34 | ```bash 35 | git clone https://github.com/parthsarthi03/raptor.git 36 | cd raptor 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ## Basic Usage 41 | 42 | To get started with RAPTOR, follow these steps: 43 | 44 | ### Setting Up RAPTOR 45 | 46 | First, set your OpenAI API key and initialize the RAPTOR configuration: 47 | 48 | ```python 49 | import os 50 | os.environ["OPENAI_API_KEY"] = "your-openai-api-key" 51 | 52 | from raptor import RetrievalAugmentation 53 | 54 | # Initialize with default configuration. For advanced configurations, check the documentation. [WIP] 55 | RA = RetrievalAugmentation() 56 | ``` 57 | 58 | ### Adding Documents to the Tree 59 | 60 | Add your text documents to RAPTOR for indexing: 61 | 62 | ```python 63 | with open('sample.txt', 'r') as file: 64 | text = file.read() 65 | RA.add_documents(text) 66 | ``` 67 | 68 | ### Answering Questions 69 | 70 | You can now use RAPTOR to answer questions based on the indexed documents: 71 | 72 | ```python 73 | question = "How did Cinderella reach her happy ending?" 74 | answer = RA.answer_question(question=question) 75 | print("Answer: ", answer) 76 | ``` 77 | 78 | ### Saving and Loading the Tree 79 | 80 | Save the constructed tree to a specified path: 81 | 82 | ```python 83 | SAVE_PATH = "demo/cinderella" 84 | RA.save(SAVE_PATH) 85 | ``` 86 | 87 | Load the saved tree back into RAPTOR: 88 | 89 | ```python 90 | RA = RetrievalAugmentation(tree=SAVE_PATH) 91 | answer = RA.answer_question(question=question) 92 | ``` 93 | 94 | 95 | ### Extending RAPTOR with other Models 96 | 97 | RAPTOR is designed to be flexible and allows you to integrate any models for summarization, question-answering (QA), and embedding generation. Here is how to extend RAPTOR with your own models: 98 | 99 | #### Custom Summarization Model 100 | 101 | If you wish to use a different language model for summarization, you can do so by extending the `BaseSummarizationModel` class. Implement the `summarize` method to integrate your custom summarization logic: 102 | 103 | ```python 104 | from raptor import BaseSummarizationModel 105 | 106 | class CustomSummarizationModel(BaseSummarizationModel): 107 | def __init__(self): 108 | # Initialize your model here 109 | pass 110 | 111 | def summarize(self, context, max_tokens=150): 112 | # Implement your summarization logic here 113 | # Return the summary as a string 114 | summary = "Your summary here" 115 | return summary 116 | ``` 117 | 118 | #### Custom QA Model 119 | 120 | For custom QA models, extend the `BaseQAModel` class and implement the `answer_question` method. This method should return the best answer found by your model given a context and a question: 121 | 122 | ```python 123 | from raptor import BaseQAModel 124 | 125 | class CustomQAModel(BaseQAModel): 126 | def __init__(self): 127 | # Initialize your model here 128 | pass 129 | 130 | def answer_question(self, context, question): 131 | # Implement your QA logic here 132 | # Return the answer as a string 133 | answer = "Your answer here" 134 | return answer 135 | ``` 136 | 137 | #### Custom Embedding Model 138 | 139 | To use a different embedding model, extend the `BaseEmbeddingModel` class. 
Implement the `create_embedding` method, which should return a vector representation of the input text: 140 | 141 | ```python 142 | from raptor import BaseEmbeddingModel 143 | 144 | class CustomEmbeddingModel(BaseEmbeddingModel): 145 | def __init__(self): 146 | # Initialize your model here 147 | pass 148 | 149 | def create_embedding(self, text): 150 | # Implement your embedding logic here 151 | # Return the embedding as a numpy array or a list of floats 152 | embedding = [0.0] * embedding_dim # Replace with actual embedding logic 153 | return embedding 154 | ``` 155 | 156 | #### Integrating Custom Models with RAPTOR 157 | 158 | After implementing your custom models, integrate them with RAPTOR as follows: 159 | 160 | ```python 161 | from raptor import RetrievalAugmentation, RetrievalAugmentationConfig 162 | 163 | # Initialize your custom models 164 | custom_summarizer = CustomSummarizationModel() 165 | custom_qa = CustomQAModel() 166 | custom_embedding = CustomEmbeddingModel() 167 | 168 | # Create a config with your custom models 169 | custom_config = RetrievalAugmentationConfig( 170 | summarization_model=custom_summarizer, 171 | qa_model=custom_qa, 172 | embedding_model=custom_embedding 173 | ) 174 | 175 | # Initialize RAPTOR with your custom config 176 | RA = RetrievalAugmentation(config=custom_config) 177 | ``` 178 | 179 | Check out `demo.ipynb` for examples on how to specify your own summarization/QA models, such as Llama/Mistral/Gemma, and Embedding Models such as SBERT, for use with RAPTOR. 180 | 181 | Note: More examples and ways to configure RAPTOR are forthcoming. Advanced usage and additional features will be provided in the documentation and repository updates. 182 | 183 | ## Contributing 184 | 185 | RAPTOR is an open-source project, and contributions are welcome. Whether you're fixing bugs, adding new features, or improving documentation, your help is appreciated. 186 | 187 | ## License 188 | 189 | RAPTOR is released under the MIT License. See the LICENSE file in the repository for full details. 190 | 191 | ## Citation 192 | 193 | If RAPTOR assists in your research, please cite it as follows: 194 | 195 | ```bibtex 196 | @inproceedings{sarthi2024raptor, 197 | title={RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval}, 198 | author={Sarthi, Parth and Abdullah, Salman and Tuli, Aditi and Khanna, Shubh and Goldie, Anna and Manning, Christopher D.}, 199 | booktitle={International Conference on Learning Representations (ICLR)}, 200 | year={2024} 201 | } 202 | ``` 203 | 204 | Stay tuned for more examples, configuration guides, and updates. 
205 | -------------------------------------------------------------------------------- /raptor/FaissRetriever.py: -------------------------------------------------------------------------------- 1 | import random 2 | from concurrent.futures import ProcessPoolExecutor 3 | 4 | import faiss 5 | import numpy as np 6 | import tiktoken 7 | from tqdm import tqdm 8 | 9 | from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel 10 | from .Retrievers import BaseRetriever 11 | from .utils import split_text 12 | 13 | 14 | class FaissRetrieverConfig: 15 | def __init__( 16 | self, 17 | max_tokens=100, 18 | max_context_tokens=3500, 19 | use_top_k=False, 20 | embedding_model=None, 21 | question_embedding_model=None, 22 | top_k=5, 23 | tokenizer=tiktoken.get_encoding("cl100k_base"), 24 | embedding_model_string=None, 25 | ): 26 | if max_tokens < 1: 27 | raise ValueError("max_tokens must be at least 1") 28 | 29 | if top_k < 1: 30 | raise ValueError("top_k must be at least 1") 31 | 32 | if max_context_tokens is not None and max_context_tokens < 1: 33 | raise ValueError("max_context_tokens must be at least 1 or None") 34 | 35 | if embedding_model is not None and not isinstance( 36 | embedding_model, BaseEmbeddingModel 37 | ): 38 | raise ValueError( 39 | "embedding_model must be an instance of BaseEmbeddingModel or None" 40 | ) 41 | 42 | if question_embedding_model is not None and not isinstance( 43 | question_embedding_model, BaseEmbeddingModel 44 | ): 45 | raise ValueError( 46 | "question_embedding_model must be an instance of BaseEmbeddingModel or None" 47 | ) 48 | 49 | self.top_k = top_k 50 | self.max_tokens = max_tokens 51 | self.max_context_tokens = max_context_tokens 52 | self.use_top_k = use_top_k 53 | self.embedding_model = embedding_model or OpenAIEmbeddingModel() 54 | self.question_embedding_model = question_embedding_model or self.embedding_model 55 | self.tokenizer = tokenizer 56 | self.embedding_model_string = embedding_model_string or "OpenAI" 57 | 58 | def log_config(self): 59 | config_summary = """ 60 | FaissRetrieverConfig: 61 | Max Tokens: {max_tokens} 62 | Max Context Tokens: {max_context_tokens} 63 | Use Top K: {use_top_k} 64 | Embedding Model: {embedding_model} 65 | Question Embedding Model: {question_embedding_model} 66 | Top K: {top_k} 67 | Tokenizer: {tokenizer} 68 | Embedding Model String: {embedding_model_string} 69 | """.format( 70 | max_tokens=self.max_tokens, 71 | max_context_tokens=self.max_context_tokens, 72 | use_top_k=self.use_top_k, 73 | embedding_model=self.embedding_model, 74 | question_embedding_model=self.question_embedding_model, 75 | top_k=self.top_k, 76 | tokenizer=self.tokenizer, 77 | embedding_model_string=self.embedding_model_string, 78 | ) 79 | return config_summary 80 | 81 | 82 | class FaissRetriever(BaseRetriever): 83 | """ 84 | FaissRetriever is a class that retrieves similar context chunks for a given query using Faiss. 85 | encoders_type is 'same' if the question and context encoder is the same, 86 | otherwise, encoders_type is 'different'. 
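Example (illustrative sketch; assumes an OpenAI API key is configured so the default OpenAIEmbeddingModel can be used, and that `doc_text` holds your document text):
    retriever = FaissRetriever(FaissRetrieverConfig())
    retriever.build_from_text(doc_text)
    context = retriever.retrieve("How did Cinderella reach her happy ending?")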
87 | """ 88 | 89 | def __init__(self, config): 90 | self.embedding_model = config.embedding_model 91 | self.question_embedding_model = config.question_embedding_model 92 | self.index = None 93 | self.context_chunks = None 94 | self.max_tokens = config.max_tokens 95 | self.max_context_tokens = config.max_context_tokens 96 | self.use_top_k = config.use_top_k 97 | self.tokenizer = config.tokenizer 98 | self.top_k = config.top_k 99 | self.embedding_model_string = config.embedding_model_string 100 | 101 | def build_from_text(self, doc_text): 102 | """ 103 | Builds the index from a given text. 104 | 105 | :param doc_text: A string containing the document text. 106 | :param tokenizer: A tokenizer used to split the text into chunks. 107 | :param max_tokens: An integer representing the maximum number of tokens per chunk. 108 | """ 109 | self.context_chunks = np.array( 110 | split_text(doc_text, self.tokenizer, self.max_tokens) 111 | ) 112 | 113 | with ProcessPoolExecutor() as executor: 114 | futures = [ 115 | executor.submit(self.embedding_model.create_embedding, context_chunk) 116 | for context_chunk in self.context_chunks 117 | ] 118 | 119 | self.embeddings = [] 120 | for future in tqdm(futures, total=len(futures), desc="Building embeddings"): 121 | self.embeddings.append(future.result()) 122 | 123 | self.embeddings = np.array(self.embeddings, dtype=np.float32) 124 | 125 | self.index = faiss.IndexFlatIP(self.embeddings.shape[1]) 126 | self.index.add(self.embeddings) 127 | 128 | def build_from_leaf_nodes(self, leaf_nodes): 129 | """ 130 | Builds the index from a given text. 131 | 132 | :param doc_text: A string containing the document text. 133 | :param tokenizer: A tokenizer used to split the text into chunks. 134 | :param max_tokens: An integer representing the maximum number of tokens per chunk. 135 | """ 136 | 137 | self.context_chunks = [node.text for node in leaf_nodes] 138 | 139 | self.embeddings = np.array( 140 | [node.embeddings[self.embedding_model_string] for node in leaf_nodes], 141 | dtype=np.float32, 142 | ) 143 | 144 | self.index = faiss.IndexFlatIP(self.embeddings.shape[1]) 145 | self.index.add(self.embeddings) 146 | 147 | def sanity_check(self, num_samples=4): 148 | """ 149 | Perform a sanity check by recomputing embeddings of a few randomly-selected chunks. 150 | 151 | :param num_samples: The number of samples to test. 152 | """ 153 | indices = random.sample(range(len(self.context_chunks)), num_samples) 154 | 155 | for i in indices: 156 | original_embedding = self.embeddings[i] 157 | recomputed_embedding = self.embedding_model.create_embedding( 158 | self.context_chunks[i] 159 | ) 160 | assert np.allclose( 161 | original_embedding, recomputed_embedding 162 | ), f"Embeddings do not match for index {i}!" 163 | 164 | print(f"Sanity check passed for {num_samples} random samples.") 165 | 166 | def retrieve(self, query: str) -> str: 167 | """ 168 | Retrieves the k most similar context chunks for a given query. 169 | 170 | :param query: A string containing the query. 171 | :param k: An integer representing the number of similar context chunks to retrieve. 172 | :return: A string containing the retrieved context chunks. 
173 | """ 174 | query_embedding = np.array( 175 | [ 176 | np.array( 177 | self.question_embedding_model.create_embedding(query), 178 | dtype=np.float32, 179 | ).squeeze() 180 | ] 181 | ) 182 | 183 | context = "" 184 | 185 | if self.use_top_k: 186 | _, indices = self.index.search(query_embedding, self.top_k) 187 | for i in range(self.top_k): 188 | context += self.context_chunks[indices[0][i]] 189 | 190 | else: 191 | range_ = int(self.max_context_tokens / self.max_tokens) 192 | _, indices = self.index.search(query_embedding, range_) 193 | total_tokens = 0 194 | for i in range(range_): 195 | tokens = len(self.tokenizer.encode(self.context_chunks[indices[0][i]])) 196 | context += self.context_chunks[indices[0][i]] 197 | if total_tokens + tokens > self.max_context_tokens: 198 | break 199 | total_tokens += tokens 200 | 201 | return context 202 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "912cd8c6-d405-4dfe-8897-46108e6a6af7", 6 | "metadata": {}, 7 | "source": [ 8 | "# RAPTOR: Recursive Abstractive Processing for Tree-Organized Retrieval" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "631b09a3", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# NOTE: An OpenAI API key must be set here for application initialization, even if not in use.\n", 19 | "# If you're not utilizing OpenAI models, assign a placeholder string (e.g., \"not_used\").\n", 20 | "import os\n", 21 | "os.environ[\"OPENAI_API_KEY\"] = \"your-openai-key\"" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "e2d7d995-7beb-40b5-9a44-afd350b7d221", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "# Cinderella story defined in sample.txt\n", 32 | "with open('demo/sample.txt', 'r') as file:\n", 33 | " text = file.read()\n", 34 | "\n", 35 | "print(text[:100])" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "c7d51ebd-5597-4fdd-8c37-32636395081b", 41 | "metadata": {}, 42 | "source": [ 43 | "1) **Building**: RAPTOR recursively embeds, clusters, and summarizes chunks of text to construct a tree with varying levels of summarization from the bottom up. You can create a tree from the text in 'sample.txt' using `RA.add_documents(text)`.\n", 44 | "\n", 45 | "2) **Querying**: At inference time, the RAPTOR model retrieves information from this tree, integrating data across lengthy documents at different abstraction levels. You can perform queries on the tree with `RA.answer_question`." 
46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "f4f58830-9004-48a4-b50e-61a855511d24", 51 | "metadata": {}, 52 | "source": [ 53 | "### Building the tree" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "3753fcf9-0a8e-4ab3-bf3a-6be38ef6cd1e", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "from raptor import RetrievalAugmentation " 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "id": "7e843edf", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "RA = RetrievalAugmentation()\n", 74 | "\n", 75 | "# construct the tree\n", 76 | "RA.add_documents(text)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "id": "f219d60a-1f0b-4cee-89eb-2ae026f13e63", 82 | "metadata": {}, 83 | "source": [ 84 | "### Querying from the tree\n", 85 | "\n", 86 | "```python\n", 87 | "question = # any question\n", 88 | "RA.answer_question(question)\n", 89 | "```" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "id": "1b4037c5-ad5a-424b-80e4-a67b8e00773b", 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "question = \"How did Cinderella reach her happy ending ?\"\n", 100 | "\n", 101 | "answer = RA.answer_question(question=question)\n", 102 | "\n", 103 | "print(\"Answer: \", answer)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "f5be7e57", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# Save the tree by calling RA.save(\"path/to/save\")\n", 114 | "SAVE_PATH = \"demo/cinderella\"\n", 115 | "RA.save(SAVE_PATH)" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "2e845de9", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "# load back the tree by passing it into RetrievalAugmentation\n", 126 | "\n", 127 | "RA = RetrievalAugmentation(tree=SAVE_PATH)\n", 128 | "\n", 129 | "answer = RA.answer_question(question=question)\n", 130 | "print(\"Answer: \", answer)" 131 | ] 132 | }, 133 | { 134 | "cell_type": "markdown", 135 | "id": "277ab6ea-1c79-4ed1-97de-1c2e39d6db2e", 136 | "metadata": {}, 137 | "source": [ 138 | "## Using other Open Source Models for Summarization/QA/Embeddings\n", 139 | "\n", 140 | "If you want to use other models such as Llama or Mistral, you can very easily define your own models and use them with RAPTOR. " 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "id": "f86cbe7e", 147 | "metadata": {}, 148 | "outputs": [], 149 | "source": [ 150 | "import torch\n", 151 | "from raptor import BaseSummarizationModel, BaseQAModel, BaseEmbeddingModel, RetrievalAugmentationConfig\n", 152 | "from transformers import AutoTokenizer, pipeline" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": null, 158 | "id": "fe5cef43", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "# if you want to use the Gemma, you will need to authenticate with HuggingFace, Skip this step, if you have the model already downloaded\n", 163 | "from huggingface_hub import login\n", 164 | "login()" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": null, 170 | "id": "245b91a5", 171 | "metadata": {}, 172 | "outputs": [], 173 | "source": [ 174 | "from transformers import AutoTokenizer, pipeline\n", 175 | "import torch\n", 176 | "\n", 177 | "# You can define your own Summarization model by extending the base Summarization Class. 
\n", 178 | "class GEMMASummarizationModel(BaseSummarizationModel):\n", 179 | " def __init__(self, model_name=\"google/gemma-2b-it\"):\n", 180 | " # Initialize the tokenizer and the pipeline for the GEMMA model\n", 181 | " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 182 | " self.summarization_pipeline = pipeline(\n", 183 | " \"text-generation\",\n", 184 | " model=model_name,\n", 185 | " model_kwargs={\"torch_dtype\": torch.bfloat16},\n", 186 | " device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'), # Use \"cpu\" if CUDA is not available\n", 187 | " )\n", 188 | "\n", 189 | " def summarize(self, context, max_tokens=150):\n", 190 | " # Format the prompt for summarization\n", 191 | " messages=[\n", 192 | " {\"role\": \"user\", \"content\": f\"Write a summary of the following, including as many key details as possible: {context}:\"}\n", 193 | " ]\n", 194 | " \n", 195 | " prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", 196 | " \n", 197 | " # Generate the summary using the pipeline\n", 198 | " outputs = self.summarization_pipeline(\n", 199 | " prompt,\n", 200 | " max_new_tokens=max_tokens,\n", 201 | " do_sample=True,\n", 202 | " temperature=0.7,\n", 203 | " top_k=50,\n", 204 | " top_p=0.95\n", 205 | " )\n", 206 | " \n", 207 | " # Extracting and returning the generated summary\n", 208 | " summary = outputs[0][\"generated_text\"].strip()\n", 209 | " return summary\n" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": null, 215 | "id": "a171496d", 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "class GEMMAQAModel(BaseQAModel):\n", 220 | " def __init__(self, model_name= \"google/gemma-2b-it\"):\n", 221 | " # Initialize the tokenizer and the pipeline for the model\n", 222 | " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n", 223 | " self.qa_pipeline = pipeline(\n", 224 | " \"text-generation\",\n", 225 | " model=model_name,\n", 226 | " model_kwargs={\"torch_dtype\": torch.bfloat16},\n", 227 | " device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),\n", 228 | " )\n", 229 | "\n", 230 | " def answer_question(self, context, question):\n", 231 | " # Apply the chat template for the context and question\n", 232 | " messages=[\n", 233 | " {\"role\": \"user\", \"content\": f\"Given Context: {context} Give the best full answer amongst the option to question {question}\"}\n", 234 | " ]\n", 235 | " prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)\n", 236 | " \n", 237 | " # Generate the answer using the pipeline\n", 238 | " outputs = self.qa_pipeline(\n", 239 | " prompt,\n", 240 | " max_new_tokens=256,\n", 241 | " do_sample=True,\n", 242 | " temperature=0.7,\n", 243 | " top_k=50,\n", 244 | " top_p=0.95\n", 245 | " )\n", 246 | " \n", 247 | " # Extracting and returning the generated answer\n", 248 | " answer = outputs[0][\"generated_text\"][len(prompt):]\n", 249 | " return answer" 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": null, 255 | "id": "878f7c7b", 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [ 259 | "from sentence_transformers import SentenceTransformer\n", 260 | "class SBertEmbeddingModel(BaseEmbeddingModel):\n", 261 | " def __init__(self, model_name=\"sentence-transformers/multi-qa-mpnet-base-cos-v1\"):\n", 262 | " self.model = SentenceTransformer(model_name)\n", 263 | "\n", 264 | " def create_embedding(self, text):\n", 265 | " return 
self.model.encode(text)\n" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "255791ce", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "RAC = RetrievalAugmentationConfig(summarization_model=GEMMASummarizationModel(), qa_model=GEMMAQAModel(), embedding_model=SBertEmbeddingModel())" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": null, 281 | "id": "fee46f1d", 282 | "metadata": {}, 283 | "outputs": [], 284 | "source": [ 285 | "RA = RetrievalAugmentation(config=RAC)" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "afe05daf", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "with open('demo/sample.txt', 'r') as file:\n", 296 | " text = file.read()\n", 297 | " \n", 298 | "RA.add_documents(text)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "id": "7eee5847", 305 | "metadata": {}, 306 | "outputs": [], 307 | "source": [ 308 | "question = \"How did Cinderella reach her happy ending?\"\n", 309 | "\n", 310 | "answer = RA.answer_question(question=question)\n", 311 | "\n", 312 | "print(\"Answer: \", answer)" 313 | ] 314 | } 315 | ], 316 | "metadata": { 317 | "kernelspec": { 318 | "display_name": "RAPTOR_env", 319 | "language": "python", 320 | "name": "raptor_env" 321 | }, 322 | "language_info": { 323 | "codemirror_mode": { 324 | "name": "ipython", 325 | "version": 3 326 | }, 327 | "file_extension": ".py", 328 | "mimetype": "text/x-python", 329 | "name": "python", 330 | "nbconvert_exporter": "python", 331 | "pygments_lexer": "ipython3", 332 | "version": "3.8.16" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 5 337 | } 338 | -------------------------------------------------------------------------------- /raptor/tree_retriever.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from typing import Dict, List, Set 4 | 5 | import tiktoken 6 | from tenacity import retry, stop_after_attempt, wait_random_exponential 7 | 8 | from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel 9 | from .Retrievers import BaseRetriever 10 | from .tree_structures import Node, Tree 11 | from .utils import (distances_from_embeddings, get_children, get_embeddings, 12 | get_node_list, get_text, 13 | indices_of_nearest_neighbors_from_distances, 14 | reverse_mapping) 15 | 16 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 17 | 18 | 19 | class TreeRetrieverConfig: 20 | def __init__( 21 | self, 22 | tokenizer=None, 23 | threshold=None, 24 | top_k=None, 25 | selection_mode=None, 26 | context_embedding_model=None, 27 | embedding_model=None, 28 | num_layers=None, 29 | start_layer=None, 30 | ): 31 | if tokenizer is None: 32 | tokenizer = tiktoken.get_encoding("cl100k_base") 33 | self.tokenizer = tokenizer 34 | 35 | if threshold is None: 36 | threshold = 0.5 37 | if not isinstance(threshold, float) or not (0 <= threshold <= 1): 38 | raise ValueError("threshold must be a float between 0 and 1") 39 | self.threshold = threshold 40 | 41 | if top_k is None: 42 | top_k = 5 43 | if not isinstance(top_k, int) or top_k < 1: 44 | raise ValueError("top_k must be an integer and at least 1") 45 | self.top_k = top_k 46 | 47 | if selection_mode is None: 48 | selection_mode = "top_k" 49 | if not isinstance(selection_mode, str) or selection_mode not in [ 50 | "top_k", 51 | "threshold", 52 | ]: 53 | raise ValueError( 54 | 
"selection_mode must be a string and either 'top_k' or 'threshold'" 55 | ) 56 | self.selection_mode = selection_mode 57 | 58 | if context_embedding_model is None: 59 | context_embedding_model = "OpenAI" 60 | if not isinstance(context_embedding_model, str): 61 | raise ValueError("context_embedding_model must be a string") 62 | self.context_embedding_model = context_embedding_model 63 | 64 | if embedding_model is None: 65 | embedding_model = OpenAIEmbeddingModel() 66 | if not isinstance(embedding_model, BaseEmbeddingModel): 67 | raise ValueError( 68 | "embedding_model must be an instance of BaseEmbeddingModel" 69 | ) 70 | self.embedding_model = embedding_model 71 | 72 | if num_layers is not None: 73 | if not isinstance(num_layers, int) or num_layers < 0: 74 | raise ValueError("num_layers must be an integer and at least 0") 75 | self.num_layers = num_layers 76 | 77 | if start_layer is not None: 78 | if not isinstance(start_layer, int) or start_layer < 0: 79 | raise ValueError("start_layer must be an integer and at least 0") 80 | self.start_layer = start_layer 81 | 82 | def log_config(self): 83 | config_log = """ 84 | TreeRetrieverConfig: 85 | Tokenizer: {tokenizer} 86 | Threshold: {threshold} 87 | Top K: {top_k} 88 | Selection Mode: {selection_mode} 89 | Context Embedding Model: {context_embedding_model} 90 | Embedding Model: {embedding_model} 91 | Num Layers: {num_layers} 92 | Start Layer: {start_layer} 93 | """.format( 94 | tokenizer=self.tokenizer, 95 | threshold=self.threshold, 96 | top_k=self.top_k, 97 | selection_mode=self.selection_mode, 98 | context_embedding_model=self.context_embedding_model, 99 | embedding_model=self.embedding_model, 100 | num_layers=self.num_layers, 101 | start_layer=self.start_layer, 102 | ) 103 | return config_log 104 | 105 | 106 | class TreeRetriever(BaseRetriever): 107 | 108 | def __init__(self, config, tree) -> None: 109 | if not isinstance(tree, Tree): 110 | raise ValueError("tree must be an instance of Tree") 111 | 112 | if config.num_layers is not None and config.num_layers > tree.num_layers + 1: 113 | raise ValueError( 114 | "num_layers in config must be less than or equal to tree.num_layers + 1" 115 | ) 116 | 117 | if config.start_layer is not None and config.start_layer > tree.num_layers: 118 | raise ValueError( 119 | "start_layer in config must be less than or equal to tree.num_layers" 120 | ) 121 | 122 | self.tree = tree 123 | self.num_layers = ( 124 | config.num_layers if config.num_layers is not None else tree.num_layers + 1 125 | ) 126 | self.start_layer = ( 127 | config.start_layer if config.start_layer is not None else tree.num_layers 128 | ) 129 | 130 | if self.num_layers > self.start_layer + 1: 131 | raise ValueError("num_layers must be less than or equal to start_layer + 1") 132 | 133 | self.tokenizer = config.tokenizer 134 | self.top_k = config.top_k 135 | self.threshold = config.threshold 136 | self.selection_mode = config.selection_mode 137 | self.embedding_model = config.embedding_model 138 | self.context_embedding_model = config.context_embedding_model 139 | 140 | self.tree_node_index_to_layer = reverse_mapping(self.tree.layer_to_nodes) 141 | 142 | logging.info( 143 | f"Successfully initialized TreeRetriever with Config {config.log_config()}" 144 | ) 145 | 146 | def create_embedding(self, text: str) -> List[float]: 147 | """ 148 | Generates embeddings for the given text using the specified embedding model. 149 | 150 | Args: 151 | text (str): The text for which to generate embeddings. 
152 | 153 | Returns: 154 | List[float]: The generated embeddings. 155 | """ 156 | return self.embedding_model.create_embedding(text) 157 | 158 | def retrieve_information_collapse_tree(self, query: str, top_k: int, max_tokens: int) -> str: 159 | """ 160 | Retrieves the most relevant information from the tree based on the query. 161 | 162 | Args: 163 | query (str): The query text. 164 | max_tokens (int): The maximum number of tokens. 165 | 166 | Returns: 167 | Tuple[List[Node], str]: The selected nodes and the context created from them. 168 | """ 169 | 170 | query_embedding = self.create_embedding(query) 171 | 172 | selected_nodes = [] 173 | 174 | node_list = get_node_list(self.tree.all_nodes) 175 | 176 | embeddings = get_embeddings(node_list, self.context_embedding_model) 177 | 178 | distances = distances_from_embeddings(query_embedding, embeddings) 179 | 180 | indices = indices_of_nearest_neighbors_from_distances(distances) 181 | 182 | total_tokens = 0 183 | for idx in indices[:top_k]: 184 | 185 | node = node_list[idx] 186 | node_tokens = len(self.tokenizer.encode(node.text)) 187 | 188 | if total_tokens + node_tokens > max_tokens: 189 | break 190 | 191 | selected_nodes.append(node) 192 | total_tokens += node_tokens 193 | 194 | context = get_text(selected_nodes) 195 | return selected_nodes, context 196 | 197 | def retrieve_information( 198 | self, current_nodes: List[Node], query: str, num_layers: int 199 | ) -> str: 200 | """ 201 | Retrieves the most relevant information from the tree based on the query. 202 | 203 | Args: 204 | current_nodes (List[Node]): A List of the current nodes. 205 | query (str): The query text. 206 | num_layers (int): The number of layers to traverse. 207 | 208 | Returns: 209 | Tuple[List[Node], str]: The selected nodes and the context created from them. 210 | """ 211 | 212 | query_embedding = self.create_embedding(query) 213 | 214 | selected_nodes = [] 215 | 216 | node_list = current_nodes 217 | 218 | for layer in range(num_layers): 219 | 220 | embeddings = get_embeddings(node_list, self.context_embedding_model) 221 | 222 | distances = distances_from_embeddings(query_embedding, embeddings) 223 | 224 | indices = indices_of_nearest_neighbors_from_distances(distances) 225 | 226 | if self.selection_mode == "threshold": 227 | best_indices = [ 228 | index for index in indices if distances[index] > self.threshold 229 | ] 230 | 231 | elif self.selection_mode == "top_k": 232 | best_indices = indices[: self.top_k] 233 | 234 | nodes_to_add = [node_list[idx] for idx in best_indices] 235 | 236 | selected_nodes.extend(nodes_to_add) 237 | 238 | if layer != num_layers - 1: 239 | 240 | child_nodes = [] 241 | 242 | for index in best_indices: 243 | child_nodes.extend(node_list[index].children) 244 | 245 | # take the unique values 246 | child_nodes = list(dict.fromkeys(child_nodes)) 247 | node_list = [self.tree.all_nodes[i] for i in child_nodes] 248 | 249 | context = get_text(selected_nodes) 250 | return selected_nodes, context 251 | 252 | def retrieve( 253 | self, 254 | query: str, 255 | start_layer: int = None, 256 | num_layers: int = None, 257 | top_k: int = 10, 258 | max_tokens: int = 3500, 259 | collapse_tree: bool = True, 260 | return_layer_information: bool = False, 261 | ) -> str: 262 | """ 263 | Queries the tree and returns the most relevant information. 264 | 265 | Args: 266 | query (str): The query text. 267 | start_layer (int): The layer to start from. Defaults to self.start_layer. 268 | num_layers (int): The number of layers to traverse. Defaults to self.num_layers. 
269 | max_tokens (int): The maximum number of tokens. Defaults to 3500. 270 | collapse_tree (bool): Whether to retrieve from all nodes of the collapsed tree instead of traversing layer by layer. Defaults to True. 271 | 272 | Returns: 273 | str: The retrieved context, or a (context, layer_information) tuple if return_layer_information is True. 274 | """ 275 | 276 | if not isinstance(query, str): 277 | raise ValueError("query must be a string") 278 | 279 | if not isinstance(max_tokens, int) or max_tokens < 1: 280 | raise ValueError("max_tokens must be an integer and at least 1") 281 | 282 | if not isinstance(collapse_tree, bool): 283 | raise ValueError("collapse_tree must be a boolean") 284 | 285 | # Set defaults 286 | start_layer = self.start_layer if start_layer is None else start_layer 287 | num_layers = self.num_layers if num_layers is None else num_layers 288 | 289 | if not isinstance(start_layer, int) or not ( 290 | 0 <= start_layer <= self.tree.num_layers 291 | ): 292 | raise ValueError( 293 | "start_layer must be an integer between 0 and tree.num_layers" 294 | ) 295 | 296 | if not isinstance(num_layers, int) or num_layers < 1: 297 | raise ValueError("num_layers must be an integer and at least 1") 298 | 299 | if num_layers > (start_layer + 1): 300 | raise ValueError("num_layers must be less than or equal to start_layer + 1") 301 | 302 | if collapse_tree: 303 | logging.info("Using collapsed_tree") 304 | selected_nodes, context = self.retrieve_information_collapse_tree( 305 | query, top_k, max_tokens 306 | ) 307 | else: 308 | layer_nodes = self.tree.layer_to_nodes[start_layer] 309 | selected_nodes, context = self.retrieve_information( 310 | layer_nodes, query, num_layers 311 | ) 312 | 313 | if return_layer_information: 314 | 315 | layer_information = [] 316 | 317 | for node in selected_nodes: 318 | layer_information.append( 319 | { 320 | "node_index": node.index, 321 | "layer_number": self.tree_node_index_to_layer[node.index], 322 | } 323 | ) 324 | 325 | return context, layer_information 326 | 327 | return context 328 | -------------------------------------------------------------------------------- /raptor/RetrievalAugmentation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import pickle 3 | 4 | from .cluster_tree_builder import ClusterTreeBuilder, ClusterTreeConfig 5 | from .EmbeddingModels import BaseEmbeddingModel 6 | from .QAModels import BaseQAModel, GPT3TurboQAModel 7 | from .SummarizationModels import BaseSummarizationModel 8 | from .tree_builder import TreeBuilder, TreeBuilderConfig 9 | from .tree_retriever import TreeRetriever, TreeRetrieverConfig 10 | from .tree_structures import Node, Tree 11 | 12 | # Define a dictionary to map supported tree builders to their respective configs 13 | supported_tree_builders = {"cluster": (ClusterTreeBuilder, ClusterTreeConfig)} 14 | 15 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 16 | 17 | 18 | class RetrievalAugmentationConfig: 19 | def __init__( 20 | self, 21 | tree_builder_config=None, 22 | tree_retriever_config=None, # Change from default instantiation 23 | qa_model=None, 24 | embedding_model=None, 25 | summarization_model=None, 26 | tree_builder_type="cluster", 27 | # New parameters for TreeRetrieverConfig and TreeBuilderConfig 28 | # TreeRetrieverConfig arguments 29 | tr_tokenizer=None, 30 | tr_threshold=0.5, 31 | tr_top_k=5, 32 | tr_selection_mode="top_k", 33 | tr_context_embedding_model="OpenAI", 34 | tr_embedding_model=None, 35 | tr_num_layers=None, 36 | tr_start_layer=None, 37 | # TreeBuilderConfig arguments 38 | tb_tokenizer=None, 39 | tb_max_tokens=100, 
40 | tb_num_layers=5, 41 | tb_threshold=0.5, 42 | tb_top_k=5, 43 | tb_selection_mode="top_k", 44 | tb_summarization_length=100, 45 | tb_summarization_model=None, 46 | tb_embedding_models=None, 47 | tb_cluster_embedding_model="OpenAI", 48 | ): 49 | # Validate tree_builder_type 50 | if tree_builder_type not in supported_tree_builders: 51 | raise ValueError( 52 | f"tree_builder_type must be one of {list(supported_tree_builders.keys())}" 53 | ) 54 | 55 | # Validate qa_model 56 | if qa_model is not None and not isinstance(qa_model, BaseQAModel): 57 | raise ValueError("qa_model must be an instance of BaseQAModel") 58 | 59 | if embedding_model is not None and not isinstance( 60 | embedding_model, BaseEmbeddingModel 61 | ): 62 | raise ValueError( 63 | "embedding_model must be an instance of BaseEmbeddingModel" 64 | ) 65 | elif embedding_model is not None: 66 | if tb_embedding_models is not None: 67 | raise ValueError( 68 | "Only one of 'tb_embedding_models' or 'embedding_model' should be provided, not both." 69 | ) 70 | tb_embedding_models = {"EMB": embedding_model} 71 | tr_embedding_model = embedding_model 72 | tb_cluster_embedding_model = "EMB" 73 | tr_context_embedding_model = "EMB" 74 | 75 | if summarization_model is not None and not isinstance( 76 | summarization_model, BaseSummarizationModel 77 | ): 78 | raise ValueError( 79 | "summarization_model must be an instance of BaseSummarizationModel" 80 | ) 81 | 82 | elif summarization_model is not None: 83 | if tb_summarization_model is not None: 84 | raise ValueError( 85 | "Only one of 'tb_summarization_model' or 'summarization_model' should be provided, not both." 86 | ) 87 | tb_summarization_model = summarization_model 88 | 89 | # Set TreeBuilderConfig 90 | tree_builder_class, tree_builder_config_class = supported_tree_builders[ 91 | tree_builder_type 92 | ] 93 | if tree_builder_config is None: 94 | tree_builder_config = tree_builder_config_class( 95 | tokenizer=tb_tokenizer, 96 | max_tokens=tb_max_tokens, 97 | num_layers=tb_num_layers, 98 | threshold=tb_threshold, 99 | top_k=tb_top_k, 100 | selection_mode=tb_selection_mode, 101 | summarization_length=tb_summarization_length, 102 | summarization_model=tb_summarization_model, 103 | embedding_models=tb_embedding_models, 104 | cluster_embedding_model=tb_cluster_embedding_model, 105 | ) 106 | 107 | elif not isinstance(tree_builder_config, tree_builder_config_class): 108 | raise ValueError( 109 | f"tree_builder_config must be a direct instance of {tree_builder_config_class} for tree_builder_type '{tree_builder_type}'" 110 | ) 111 | 112 | # Set TreeRetrieverConfig 113 | if tree_retriever_config is None: 114 | tree_retriever_config = TreeRetrieverConfig( 115 | tokenizer=tr_tokenizer, 116 | threshold=tr_threshold, 117 | top_k=tr_top_k, 118 | selection_mode=tr_selection_mode, 119 | context_embedding_model=tr_context_embedding_model, 120 | embedding_model=tr_embedding_model, 121 | num_layers=tr_num_layers, 122 | start_layer=tr_start_layer, 123 | ) 124 | elif not isinstance(tree_retriever_config, TreeRetrieverConfig): 125 | raise ValueError( 126 | "tree_retriever_config must be an instance of TreeRetrieverConfig" 127 | ) 128 | 129 | # Assign the created configurations to the instance 130 | self.tree_builder_config = tree_builder_config 131 | self.tree_retriever_config = tree_retriever_config 132 | self.qa_model = qa_model or GPT3TurboQAModel() 133 | self.tree_builder_type = tree_builder_type 134 | 135 | def log_config(self): 136 | config_summary = """ 137 | RetrievalAugmentationConfig: 138 | 
{tree_builder_config} 139 | 140 | {tree_retriever_config} 141 | 142 | QA Model: {qa_model} 143 | Tree Builder Type: {tree_builder_type} 144 | """.format( 145 | tree_builder_config=self.tree_builder_config.log_config(), 146 | tree_retriever_config=self.tree_retriever_config.log_config(), 147 | qa_model=self.qa_model, 148 | tree_builder_type=self.tree_builder_type, 149 | ) 150 | return config_summary 151 | 152 | 153 | class RetrievalAugmentation: 154 | """ 155 | A Retrieval Augmentation class that combines the TreeBuilder and TreeRetriever classes. 156 | Enables adding documents to the tree, retrieving information, and answering questions. 157 | """ 158 | 159 | def __init__(self, config=None, tree=None): 160 | """ 161 | Initializes a RetrievalAugmentation instance with the specified configuration. 162 | Args: 163 | config (RetrievalAugmentationConfig): The configuration for the RetrievalAugmentation instance. 164 | tree: The tree instance or the path to a pickled tree file. 165 | """ 166 | if config is None: 167 | config = RetrievalAugmentationConfig() 168 | if not isinstance(config, RetrievalAugmentationConfig): 169 | raise ValueError( 170 | "config must be an instance of RetrievalAugmentationConfig" 171 | ) 172 | 173 | # Check if tree is a string (indicating a path to a pickled tree) 174 | if isinstance(tree, str): 175 | try: 176 | with open(tree, "rb") as file: 177 | self.tree = pickle.load(file) 178 | if not isinstance(self.tree, Tree): 179 | raise ValueError("The loaded object is not an instance of Tree") 180 | except Exception as e: 181 | raise ValueError(f"Failed to load tree from {tree}: {e}") 182 | elif isinstance(tree, Tree) or tree is None: 183 | self.tree = tree 184 | else: 185 | raise ValueError( 186 | "tree must be an instance of Tree, a path to a pickled Tree, or None" 187 | ) 188 | 189 | tree_builder_class = supported_tree_builders[config.tree_builder_type][0] 190 | self.tree_builder = tree_builder_class(config.tree_builder_config) 191 | 192 | self.tree_retriever_config = config.tree_retriever_config 193 | self.qa_model = config.qa_model 194 | 195 | if self.tree is not None: 196 | self.retriever = TreeRetriever(self.tree_retriever_config, self.tree) 197 | else: 198 | self.retriever = None 199 | 200 | logging.info( 201 | f"Successfully initialized RetrievalAugmentation with Config {config.log_config()}" 202 | ) 203 | 204 | def add_documents(self, docs): 205 | """ 206 | Adds documents to the tree and creates a TreeRetriever instance. 207 | 208 | Args: 209 | docs (str): The input text to add to the tree. 210 | """ 211 | if self.tree is not None: 212 | user_input = input( 213 | "Warning: Overwriting existing tree. Did you mean to call 'add_to_existing' instead? (y/n): " 214 | ) 215 | if user_input.lower() == "y": 216 | # self.add_to_existing(docs) 217 | return 218 | 219 | self.tree = self.tree_builder.build_from_text(text=docs) 220 | self.retriever = TreeRetriever(self.tree_retriever_config, self.tree) 221 | 222 | def retrieve( 223 | self, 224 | question, 225 | start_layer: int = None, 226 | num_layers: int = None, 227 | top_k: int = 10, 228 | max_tokens: int = 3500, 229 | collapse_tree: bool = True, 230 | return_layer_information: bool = True, 231 | ): 232 | """ 233 | Retrieves information and answers a question using the TreeRetriever instance. 234 | 235 | Args: 236 | question (str): The question to answer. 237 | start_layer (int): The layer to start from. Defaults to self.start_layer. 238 | num_layers (int): The number of layers to traverse. Defaults to self.num_layers. 
239 | max_tokens (int): The maximum number of tokens. Defaults to 3500. 240 | collapse_tree (bool): Whether to retrieve from all nodes of the collapsed tree. Defaults to True. 241 | 242 | Returns: 243 | str: The context from which the answer can be found, or a (context, layer_information) tuple if return_layer_information is True. 244 | 245 | Raises: 246 | ValueError: If the TreeRetriever instance has not been initialized. 247 | """ 248 | if self.retriever is None: 249 | raise ValueError( 250 | "The TreeRetriever instance has not been initialized. Call 'add_documents' first." 251 | ) 252 | 253 | return self.retriever.retrieve( 254 | question, 255 | start_layer, 256 | num_layers, 257 | top_k, 258 | max_tokens, 259 | collapse_tree, 260 | return_layer_information, 261 | ) 262 | 263 | def answer_question( 264 | self, 265 | question, 266 | top_k: int = 10, 267 | start_layer: int = None, 268 | num_layers: int = None, 269 | max_tokens: int = 3500, 270 | collapse_tree: bool = True, 271 | return_layer_information: bool = False, 272 | ): 273 | """ 274 | Retrieves information and answers a question using the TreeRetriever instance. 275 | 276 | Args: 277 | question (str): The question to answer. 278 | start_layer (int): The layer to start from. Defaults to self.start_layer. 279 | num_layers (int): The number of layers to traverse. Defaults to self.num_layers. 280 | max_tokens (int): The maximum number of tokens. Defaults to 3500. 281 | collapse_tree (bool): Whether to retrieve from all nodes of the collapsed tree. Defaults to True. 282 | 283 | Returns: 284 | str: The answer to the question. 285 | 286 | Raises: 287 | ValueError: If the TreeRetriever instance has not been initialized. 288 | """ 289 | # if return_layer_information: 290 | context, layer_information = self.retrieve( 291 | question, start_layer, num_layers, top_k, max_tokens, collapse_tree, True 292 | ) 293 | 294 | answer = self.qa_model.answer_question(context, question) 295 | 296 | if return_layer_information: 297 | return answer, layer_information 298 | 299 | return answer 300 | 301 | def save(self, path): 302 | if self.tree is None: 303 | raise ValueError("There is no tree to save.") 304 | with open(path, "wb") as file: 305 | pickle.dump(self.tree, file) 306 | logging.info(f"Tree successfully saved to {path}") 307 | -------------------------------------------------------------------------------- /demo/sample.txt: -------------------------------------------------------------------------------- 1 | The wife of a rich man fell sick, and as she felt that her end 2 | was drawing near, she called her only daughter to her bedside and 3 | said, dear child, be good and pious, and then the 4 | good God will always protect you, and I will look down on you 5 | from heaven and be near you. Thereupon she closed her eyes and 6 | departed. Every day the maiden went out to her mother's grave, 7 | and wept, and she remained pious and good. When winter came 8 | the snow spread a white sheet over the grave, and by the time the 9 | spring sun had drawn it off again, the man had taken another wife. 10 | The woman had brought with her into the house two daughters, 11 | who were beautiful and fair of face, but vile and black of heart. 12 | Now began a bad time for the poor step-child. Is the stupid goose 13 | to sit in the parlor with us, they said. He who wants to eat bread 14 | must earn it. Out with the kitchen-wench. They took her pretty 15 | clothes away from her, put an old grey bedgown on her, and gave 16 | her wooden shoes. 
Just look at the proud princess, how decked 17 | out she is, they cried, and laughed, and led her into the kitchen. 18 | There she had to do hard work from morning till night, get up 19 | before daybreak, carry water, light fires, cook and wash. Besides 20 | this, the sisters did her every imaginable injury - they mocked her 21 | and emptied her peas and lentils into the ashes, so that she was 22 | forced to sit and pick them out again. In the evening when she had 23 | worked till she was weary she had no bed to go to, but had to sleep 24 | by the hearth in the cinders. And as on that account she always 25 | looked dusty and dirty, they called her cinderella. 26 | It happened that the father was once going to the fair, and he 27 | asked his two step-daughters what he should bring back for them. 28 | Beautiful dresses, said one, pearls and jewels, said the second. 29 | And you, cinderella, said he, what will you have. Father 30 | break off for me the first branch which knocks against your hat on 31 | your way home. So he bought beautiful dresses, pearls and jewels 32 | for his two step-daughters, and on his way home, as he was riding 33 | through a green thicket, a hazel twig brushed against him and 34 | knocked off his hat. Then he broke off the branch and took it with 35 | him. When he reached home he gave his step-daughters the things 36 | which they had wished for, and to cinderella he gave the branch 37 | from the hazel-bush. Cinderella thanked him, went to her mother's 38 | grave and planted the branch on it, and wept so much that the tears 39 | fell down on it and watered it. And it grew and became a handsome 40 | tree. Thrice a day cinderella went and sat beneath it, and wept and 41 | prayed, and a little white bird always came on the tree, and if 42 | cinderella expressed a wish, the bird threw down to her what she 43 | had wished for. 44 | It happened, however, that the king gave orders for a festival 45 | which was to last three days, and to which all the beautiful young 46 | girls in the country were invited, in order that his son might choose 47 | himself a bride. When the two step-sisters heard that they too were 48 | to appear among the number, they were delighted, called cinderella 49 | and said, comb our hair for us, brush our shoes and fasten our 50 | buckles, for we are going to the wedding at the king's palace. 51 | Cinderella obeyed, but wept, because she too would have liked to 52 | go with them to the dance, and begged her step-mother to allow 53 | her to do so. You go, cinderella, said she, covered in dust and 54 | dirt as you are, and would go to the festival. You have no clothes 55 | and shoes, and yet would dance. As, however, cinderella went on 56 | asking, the step-mother said at last, I have emptied a dish of 57 | lentils into the ashes for you, if you have picked them out again in 58 | two hours, you shall go with us. The maiden went through the 59 | back-door into the garden, and called, you tame pigeons, you 60 | turtle-doves, and all you birds beneath the sky, come and help me 61 | to pick 62 | the good into the pot, 63 | the bad into the crop. 64 | Then two white pigeons came in by the kitchen window, and 65 | afterwards the turtle-doves, and at last all the birds beneath the 66 | sky, came whirring and crowding in, and alighted amongst the ashes. 67 | And the pigeons nodded with their heads and began pick, pick, 68 | pick, pick, and the rest began also pick, pick, pick, pick, and 69 | gathered all the good grains into the dish. 
Hardly had one hour 70 | passed before they had finished, and all flew out again. Then the 71 | girl took the dish to her step-mother, and was glad, and believed 72 | that now she would be allowed to go with them to the festival. 73 | But the step-mother said, no, cinderella, you have no clothes and 74 | you can not dance. You would only be laughed at. And as 75 | cinderella wept at this, the step-mother said, if you can pick two 76 | dishes of lentils out of the ashes for me in one hour, you shall go 77 | with us. And she thought to herself, that she most certainly 78 | cannot do again. When the step-mother had emptied the two 79 | dishes of lentils amongst the ashes, the maiden went through the 80 | back-door into the garden and cried, you tame pigeons, you 81 | turtle-doves, and all you birds beneath the sky, come and help me 82 | to pick 83 | the good into the pot, 84 | the bad into the crop. 85 | Then two white pigeons came in by the kitchen-window, and 86 | afterwards the turtle-doves, and at length all the birds beneath the 87 | sky, came whirring and crowding in, and alighted amongst the 88 | ashes. And the doves nodded with their heads and began pick, 89 | pick, pick, pick, and the others began also pick, pick, pick, pick, 90 | and gathered all the good seeds into the dishes, and before half an 91 | hour was over they had already finished, and all flew out again. 92 | Then the maiden was delighted, and believed that she might now go 93 | with them to the wedding. But the step-mother said, all this will 94 | not help. You cannot go with us, for you have no clothes and can 95 | not dance. We should be ashamed of you. On this she turned her 96 | back on cinderella, and hurried away with her two proud daughters. 97 | As no one was now at home, cinderella went to her mother's 98 | grave beneath the hazel-tree, and cried - 99 | shiver and quiver, little tree, 100 | silver and gold throw down over me. 101 | Then the bird threw a gold and silver dress down to her, and 102 | slippers embroidered with silk and silver. She put on the dress 103 | with all speed, and went to the wedding. Her step-sisters and the 104 | step-mother however did not know her, and thought she must be a 105 | foreign princess, for she looked so beautiful in the golden dress. 106 | They never once thought of cinderella, and believed that she was 107 | sitting at home in the dirt, picking lentils out of the ashes. The 108 | prince approached her, took her by the hand and danced with her. 109 | He would dance with no other maiden, and never let loose of her 110 | hand, and if any one else came to invite her, he said, this is my 111 | partner. 112 | She danced till it was evening, and then she wanted to go home. 113 | But the king's son said, I will go with you and bear you company, 114 | for he wished to see to whom the beautiful maiden belonged. 115 | She escaped from him, however, and sprang into the 116 | pigeon-house. The king's son waited until her father came, and 117 | then he told him that the unknown maiden had leapt into the 118 | pigeon-house. The old man thought, can it be cinderella. And 119 | they had to bring him an axe and a pickaxe that he might hew 120 | the pigeon-house to pieces, but no one was inside it. 
And when they 121 | got home cinderella lay in her dirty clothes among the ashes, and 122 | a dim little oil-lamp was burning on the mantle-piece, for 123 | cinderella had jumped quickly down from the back of the pigeon-house 124 | and had run to the little hazel-tree, and there she had taken off 125 | her beautiful clothes and laid them on the grave, and the bird had 126 | taken them away again, and then she had seated herself in the 127 | kitchen amongst the ashes in her grey gown. 128 | Next day when the festival began afresh, and her parents and 129 | the step-sisters had gone once more, cinderella went to the 130 | hazel-tree and said - 131 | shiver and quiver, my little tree, 132 | silver and gold throw down over me. 133 | Then the bird threw down a much more beautiful dress than on 134 | the preceding day. And when cinderella appeared at the wedding 135 | in this dress, every one was astonished at her beauty. The king's 136 | son had waited until she came, and instantly took her by the hand 137 | and danced with no one but her. When others came and invited 138 | her, he said, this is my partner. When evening came she wished 139 | to leave, and the king's son followed her and wanted to see into 140 | which house she went. But she sprang away from him, and into 141 | the garden behind the house. Therein stood a beautiful tall tree on 142 | which hung the most magnificent pears. She clambered so nimbly 143 | between the branches like a squirrel that the king's son did not 144 | know where she was gone. He waited until her father came, and 145 | said to him, the unknown maiden has escaped from me, and I 146 | believe she has climbed up the pear-tree. The father thought, 147 | can it be cinderella. And had an axe brought and cut the 148 | tree down, but no one was on it. And when they got into the 149 | kitchen, cinderella lay there among the ashes, as usual, for she 150 | had jumped down on the other side of the tree, had taken the 151 | beautiful dress to the bird on the little hazel-tree, and put on her 152 | grey gown. 153 | On the third day, when the parents and sisters had gone away, 154 | cinderella went once more to her mother's grave and said to the 155 | little tree - 156 | shiver and quiver, my little tree, 157 | silver and gold throw down over me. 158 | And now the bird threw down to her a dress which was more 159 | splendid and magnificent than any she had yet had, and the 160 | slippers were golden. And when she went to the festival in the 161 | dress, no one knew how to speak for astonishment. The king's son 162 | danced with her only, and if any one invited her to dance, he said 163 | this is my partner. 164 | When evening came, cinderella wished to leave, and the king's 165 | son was anxious to go with her, but she escaped from him so quickly 166 | that he could not follow her. The king's son, however, had 167 | employed a ruse, and had caused the whole staircase to be smeared 168 | with pitch, and there, when she ran down, had the maiden's left 169 | slipper remained stuck. The king's son picked it up, and it was 170 | small and dainty, and all golden. Next morning, he went with it to 171 | the father, and said to him, no one shall be my wife but she whose 172 | foot this golden slipper fits. Then were the two sisters glad, 173 | for they had pretty feet. The eldest went with the shoe into her 174 | room and wanted to try it on, and her mother stood by. But she 175 | could not get her big toe into it, and the shoe was too small for 176 | her. 
Then her mother gave her a knife and said, cut the toe off, 177 | when you are queen you will have no more need to go on foot. The 178 | maiden cut the toe off, forced the foot into the shoe, swallowed 179 | the pain, and went out to the king's son. Then he took her on his 180 | his horse as his bride and rode away with her. They were 181 | obliged, however, to pass the grave, and there, on the hazel-tree, 182 | sat the two pigeons and cried - 183 | turn and peep, turn and peep, 184 | there's blood within the shoe, 185 | the shoe it is too small for her, 186 | the true bride waits for you. 187 | Then he looked at her foot and saw how the blood was trickling 188 | from it. He turned his horse round and took the false bride 189 | home again, and said she was not the true one, and that the 190 | other sister was to put the shoe on. Then this one went into her 191 | chamber and got her toes safely into the shoe, but her heel was 192 | too large. So her mother gave her a knife and said, cut a bit 193 | off your heel, when you are queen you will have no more need 194 | to go on foot. The maiden cut a bit off her heel, forced 195 | her foot into the shoe, swallowed the pain, and went out to the 196 | king's son. He took her on his horse as his bride, and rode away 197 | with her, but when they passed by the hazel-tree, the two pigeons 198 | sat on it and cried - 199 | turn and peep, turn and peep, 200 | there's blood within the shoe, 201 | the shoe it is too small for her, 202 | the true bride waits for you. 203 | He looked down at her foot and saw how the blood was running 204 | out of her shoe, and how it had stained her white stocking quite 205 | red. Then he turned his horse and took the false bride home 206 | again. This also is not the right one, said he, have you no 207 | other daughter. No, said the man, there is still a little 208 | stunted kitchen-wench which my late wife left behind her, but 209 | she cannot possibly be the bride. The king's son said he was 210 | to send her up to him, but the mother answered, oh, no, she is 211 | much too dirty, she cannot show herself. But he absolutely 212 | insisted on it, and cinderella had to be called. She first 213 | washed her hands and face clean, and then went and bowed down 214 | before the king's son, who gave her the golden shoe. Then she 215 | seated herself on a stool, drew her foot out of the heavy 216 | wooden shoe, and put it into the slipper, which fitted like a 217 | glove. And when she rose up and the king's son looked at her 218 | face he recognized the beautiful maiden who had danced with 219 | him and cried, that is the true bride. The step-mother and 220 | the two sisters were horrified and became pale with rage, he, 221 | however, took cinderella on his horse and rode away with her. As 222 | they passed by the hazel-tree, the two white doves cried - 223 | turn and peep, turn and peep, 224 | no blood is in the shoe, 225 | the shoe is not too small for her, 226 | the true bride rides with you, 227 | and when they had cried that, the two came flying down and 228 | placed themselves on cinderella's shoulders, one on the right, 229 | the other on the left, and remained sitting there. 230 | When the wedding with the king's son was to be celebrated, the 231 | two false sisters came and wanted to get into favor with 232 | cinderella and share her good fortune. When the betrothed 233 | couple went to church, the elder was at the right side and the 234 | younger at the left, and the pigeons pecked out one eye from 235 | each of them. 
Afterwards as they came back the elder was at 236 | the left, and the younger at the right, and then the pigeons 237 | pecked out the other eye from each. And thus, for their 238 | wickedness and falsehood, they were punished with blindness 239 | all their days. -------------------------------------------------------------------------------- /raptor/tree_builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | from abc import abstractclassmethod 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from threading import Lock 7 | from typing import Dict, List, Optional, Set, Tuple 8 | 9 | import openai 10 | import tiktoken 11 | from tenacity import retry, stop_after_attempt, wait_random_exponential 12 | 13 | from .EmbeddingModels import BaseEmbeddingModel, OpenAIEmbeddingModel 14 | from .SummarizationModels import (BaseSummarizationModel, 15 | GPT3TurboSummarizationModel) 16 | from .tree_structures import Node, Tree 17 | from .utils import (distances_from_embeddings, get_children, get_embeddings, 18 | get_node_list, get_text, 19 | indices_of_nearest_neighbors_from_distances, split_text) 20 | 21 | logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) 22 | 23 | 24 | class TreeBuilderConfig: 25 | def __init__( 26 | self, 27 | tokenizer=None, 28 | max_tokens=None, 29 | num_layers=None, 30 | threshold=None, 31 | top_k=None, 32 | selection_mode=None, 33 | summarization_length=None, 34 | summarization_model=None, 35 | embedding_models=None, 36 | cluster_embedding_model=None, 37 | ): 38 | if tokenizer is None: 39 | tokenizer = tiktoken.get_encoding("cl100k_base") 40 | self.tokenizer = tokenizer 41 | 42 | if max_tokens is None: 43 | max_tokens = 100 44 | if not isinstance(max_tokens, int) or max_tokens < 1: 45 | raise ValueError("max_tokens must be an integer and at least 1") 46 | self.max_tokens = max_tokens 47 | 48 | if num_layers is None: 49 | num_layers = 5 50 | if not isinstance(num_layers, int) or num_layers < 1: 51 | raise ValueError("num_layers must be an integer and at least 1") 52 | self.num_layers = num_layers 53 | 54 | if threshold is None: 55 | threshold = 0.5 56 | if not isinstance(threshold, (int, float)) or not (0 <= threshold <= 1): 57 | raise ValueError("threshold must be a number between 0 and 1") 58 | self.threshold = threshold 59 | 60 | if top_k is None: 61 | top_k = 5 62 | if not isinstance(top_k, int) or top_k < 1: 63 | raise ValueError("top_k must be an integer and at least 1") 64 | self.top_k = top_k 65 | 66 | if selection_mode is None: 67 | selection_mode = "top_k" 68 | if selection_mode not in ["top_k", "threshold"]: 69 | raise ValueError("selection_mode must be either 'top_k' or 'threshold'") 70 | self.selection_mode = selection_mode 71 | 72 | if summarization_length is None: 73 | summarization_length = 100 74 | self.summarization_length = summarization_length 75 | 76 | if summarization_model is None: 77 | summarization_model = GPT3TurboSummarizationModel() 78 | if not isinstance(summarization_model, BaseSummarizationModel): 79 | raise ValueError( 80 | "summarization_model must be an instance of BaseSummarizationModel" 81 | ) 82 | self.summarization_model = summarization_model 83 | 84 | if embedding_models is None: 85 | embedding_models = {"OpenAI": OpenAIEmbeddingModel()} 86 | if not isinstance(embedding_models, dict): 87 | raise ValueError( 88 | "embedding_models must be a dictionary of model_name: instance pairs" 89 | ) 90 | for model in 
embedding_models.values(): 91 | if not isinstance(model, BaseEmbeddingModel): 92 | raise ValueError( 93 | "All embedding models must be an instance of BaseEmbeddingModel" 94 | ) 95 | self.embedding_models = embedding_models 96 | 97 | if cluster_embedding_model is None: 98 | cluster_embedding_model = "OpenAI" 99 | if cluster_embedding_model not in self.embedding_models: 100 | raise ValueError( 101 | "cluster_embedding_model must be a key in the embedding_models dictionary" 102 | ) 103 | self.cluster_embedding_model = cluster_embedding_model 104 | 105 | def log_config(self): 106 | config_log = """ 107 | TreeBuilderConfig: 108 | Tokenizer: {tokenizer} 109 | Max Tokens: {max_tokens} 110 | Num Layers: {num_layers} 111 | Threshold: {threshold} 112 | Top K: {top_k} 113 | Selection Mode: {selection_mode} 114 | Summarization Length: {summarization_length} 115 | Summarization Model: {summarization_model} 116 | Embedding Models: {embedding_models} 117 | Cluster Embedding Model: {cluster_embedding_model} 118 | """.format( 119 | tokenizer=self.tokenizer, 120 | max_tokens=self.max_tokens, 121 | num_layers=self.num_layers, 122 | threshold=self.threshold, 123 | top_k=self.top_k, 124 | selection_mode=self.selection_mode, 125 | summarization_length=self.summarization_length, 126 | summarization_model=self.summarization_model, 127 | embedding_models=self.embedding_models, 128 | cluster_embedding_model=self.cluster_embedding_model, 129 | ) 130 | return config_log 131 | 132 | 133 | class TreeBuilder: 134 | """ 135 | The TreeBuilder class is responsible for building a hierarchical text abstraction 136 | structure, known as a "tree," using summarization models and 137 | embedding models. 138 | """ 139 | 140 | def __init__(self, config) -> None: 141 | """Initializes the tokenizer, maximum tokens, number of layers, top-k value, threshold, and selection mode.""" 142 | 143 | self.tokenizer = config.tokenizer 144 | self.max_tokens = config.max_tokens 145 | self.num_layers = config.num_layers 146 | self.top_k = config.top_k 147 | self.threshold = config.threshold 148 | self.selection_mode = config.selection_mode 149 | self.summarization_length = config.summarization_length 150 | self.summarization_model = config.summarization_model 151 | self.embedding_models = config.embedding_models 152 | self.cluster_embedding_model = config.cluster_embedding_model 153 | 154 | logging.info( 155 | f"Successfully initialized TreeBuilder with Config {config.log_config()}" 156 | ) 157 | 158 | def create_node( 159 | self, index: int, text: str, children_indices: Optional[Set[int]] = None 160 | ) -> Tuple[int, Node]: 161 | """Creates a new node with the given index, text, and (optionally) children indices. 162 | 163 | Args: 164 | index (int): The index of the new node. 165 | text (str): The text associated with the new node. 166 | children_indices (Optional[Set[int]]): A set of indices representing the children of the new node. 167 | If not provided, an empty set will be used. 168 | 169 | Returns: 170 | Tuple[int, Node]: A tuple containing the index and the newly created node. 171 | """ 172 | if children_indices is None: 173 | children_indices = set() 174 | 175 | embeddings = { 176 | model_name: model.create_embedding(text) 177 | for model_name, model in self.embedding_models.items() 178 | } 179 | return (index, Node(text, index, children_indices, embeddings)) 180 | 181 | def create_embedding(self, text) -> List[float]: 182 | """ 183 | Generates embeddings for the given text using the specified embedding model. 
184 | 185 | Args: 186 | text (str): The text for which to generate embeddings. 187 | 188 | Returns: 189 | List[float]: The generated embeddings. 190 | """ 191 | return self.embedding_models[self.cluster_embedding_model].create_embedding( 192 | text 193 | ) 194 | 195 | def summarize(self, context, max_tokens=150) -> str: 196 | """ 197 | Generates a summary of the input context using the specified summarization model. 198 | 199 | Args: 200 | context (str): The context to summarize. 201 | max_tokens (int, optional): The maximum number of tokens in the generated summary. Defaults to 150. 202 | 203 | Returns: 204 | str: The generated summary. 205 | """ 206 | return self.summarization_model.summarize(context, max_tokens) 207 | 208 | def get_relevant_nodes(self, current_node, list_nodes) -> List[Node]: 209 | """ 210 | Retrieves the top-k most relevant nodes to the current node from the list of nodes 211 | based on cosine distance in the embedding space. 212 | 213 | Args: 214 | current_node (Node): The current node. 215 | list_nodes (List[Node]): The list of nodes. 216 | 217 | Returns: 218 | List[Node]: The top-k most relevant nodes. 219 | """ 220 | embeddings = get_embeddings(list_nodes, self.cluster_embedding_model) 221 | distances = distances_from_embeddings( 222 | current_node.embeddings[self.cluster_embedding_model], embeddings 223 | ) 224 | indices = indices_of_nearest_neighbors_from_distances(distances) 225 | 226 | if self.selection_mode == "threshold": 227 | best_indices = [ 228 | index for index in indices if distances[index] > self.threshold 229 | ] 230 | 231 | elif self.selection_mode == "top_k": 232 | best_indices = indices[: self.top_k] 233 | 234 | nodes_to_add = [list_nodes[idx] for idx in best_indices] 235 | 236 | return nodes_to_add 237 | 238 | def multithreaded_create_leaf_nodes(self, chunks: List[str]) -> Dict[int, Node]: 239 | """Creates leaf nodes using multithreading from the given list of text chunks. 240 | 241 | Args: 242 | chunks (List[str]): A list of text chunks to be turned into leaf nodes. 243 | 244 | Returns: 245 | Dict[int, Node]: A dictionary mapping node indices to the corresponding leaf nodes. 246 | """ 247 | with ThreadPoolExecutor() as executor: 248 | future_nodes = { 249 | executor.submit(self.create_node, index, text): (index, text) 250 | for index, text in enumerate(chunks) 251 | } 252 | 253 | leaf_nodes = {} 254 | for future in as_completed(future_nodes): 255 | index, node = future.result() 256 | leaf_nodes[index] = node 257 | 258 | return leaf_nodes 259 | 260 | def build_from_text(self, text: str, use_multithreading: bool = True) -> Tree: 261 | """Builds the hierarchical tree from the input text, optionally using multithreading. 262 | 263 | Args: 264 | text (str): The input text. 265 | use_multithreading (bool, optional): Whether to use multithreading when creating leaf nodes. 266 | Default: True. 267 | 268 | Returns: 269 | Tree: The constructed tree structure. 
270 | """ 271 | chunks = split_text(text, self.tokenizer, self.max_tokens) 272 | 273 | logging.info("Creating Leaf Nodes") 274 | 275 | if use_multithreading: 276 | leaf_nodes = self.multithreaded_create_leaf_nodes(chunks) 277 | else: 278 | leaf_nodes = {} 279 | for index, text in enumerate(chunks): 280 | __, node = self.create_node(index, text) 281 | leaf_nodes[index] = node 282 | 283 | layer_to_nodes = {0: list(leaf_nodes.values())} 284 | 285 | logging.info(f"Created {len(leaf_nodes)} Leaf Embeddings") 286 | 287 | logging.info("Building All Nodes") 288 | 289 | all_nodes = copy.deepcopy(leaf_nodes) 290 | 291 | root_nodes = self.construct_tree(all_nodes, all_nodes, layer_to_nodes) 292 | 293 | tree = Tree(all_nodes, root_nodes, leaf_nodes, self.num_layers, layer_to_nodes) 294 | 295 | return tree 296 | 297 | @abstractclassmethod 298 | def construct_tree( 299 | self, 300 | current_level_nodes: Dict[int, Node], 301 | all_tree_nodes: Dict[int, Node], 302 | layer_to_nodes: Dict[int, List[Node]], 303 | use_multithreading: bool = True, 304 | ) -> Dict[int, Node]: 305 | """ 306 | Constructs the hierarchical tree structure layer by layer by iteratively summarizing groups 307 | of relevant nodes and updating the current_level_nodes and all_tree_nodes dictionaries at each step. 308 | 309 | Args: 310 | current_level_nodes (Dict[int, Node]): The current set of nodes. 311 | all_tree_nodes (Dict[int, Node]): The dictionary of all nodes. 312 | use_multithreading (bool): Whether to use multithreading to speed up the process. 313 | 314 | Returns: 315 | Dict[int, Node]: The final set of root nodes. 316 | """ 317 | pass 318 | 319 | # logging.info("Using Transformer-like TreeBuilder") 320 | 321 | # def process_node(idx, current_level_nodes, new_level_nodes, all_tree_nodes, next_node_index, lock): 322 | # relevant_nodes_chunk = self.get_relevant_nodes( 323 | # current_level_nodes[idx], current_level_nodes 324 | # ) 325 | 326 | # node_texts = get_text(relevant_nodes_chunk) 327 | 328 | # summarized_text = self.summarize( 329 | # context=node_texts, 330 | # max_tokens=self.summarization_length, 331 | # ) 332 | 333 | # logging.info( 334 | # f"Node Texts Length: {len(self.tokenizer.encode(node_texts))}, Summarized Text Length: {len(self.tokenizer.encode(summarized_text))}" 335 | # ) 336 | 337 | # next_node_index, new_parent_node = self.create_node( 338 | # next_node_index, 339 | # summarized_text, 340 | # {node.index for node in relevant_nodes_chunk} 341 | # ) 342 | 343 | # with lock: 344 | # new_level_nodes[next_node_index] = new_parent_node 345 | 346 | # for layer in range(self.num_layers): 347 | # logging.info(f"Constructing Layer {layer}: ") 348 | 349 | # node_list_current_layer = get_node_list(current_level_nodes) 350 | # next_node_index = len(all_tree_nodes) 351 | 352 | # new_level_nodes = {} 353 | # lock = Lock() 354 | 355 | # if use_multithreading: 356 | # with ThreadPoolExecutor() as executor: 357 | # for idx in range(0, len(node_list_current_layer)): 358 | # executor.submit(process_node, idx, node_list_current_layer, new_level_nodes, all_tree_nodes, next_node_index, lock) 359 | # next_node_index += 1 360 | # executor.shutdown(wait=True) 361 | # else: 362 | # for idx in range(0, len(node_list_current_layer)): 363 | # process_node(idx, node_list_current_layer, new_level_nodes, all_tree_nodes, next_node_index, lock) 364 | 365 | # layer_to_nodes[layer + 1] = list(new_level_nodes.values()) 366 | # current_level_nodes = new_level_nodes 367 | # all_tree_nodes.update(new_level_nodes) 368 | 369 | # return 
new_level_nodes 370 | --------------------------------------------------------------------------------