├── .gitignore
├── DocstoreTransform.py
├── DocumentContextExtractor.py
├── HybridSearchDemo.py
├── README.md
├── data
│   └── declaration.txt
├── perf_tests
│   ├── asynciotest.py
│   ├── asynciotest2.py
│   ├── perf_tests.py
│   ├── perftest3.py
│   ├── prideandprejudice.txt
│   └── token_comparator.py
├── requirements.txt
└── test_document_context_extractor.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | hybriddemo
3 | *.pyc
4 | __pycache__
5 | .devcontainer
--------------------------------------------------------------------------------
/DocstoreTransform.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.llms import ChatMessage, LLM
2 | from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
3 | from llama_index.core.extractors import BaseExtractor
4 | from llama_index.core.schema import Document, Node
5 | from llama_index.core import Settings
6 | from llama_index.core.storage.docstore.simple_docstore import DocumentStore
7 | from llama_index.core.node_parser import TokenTextSplitter
8 | from typing import Optional, Dict, List, Set, Union, Literal
9 | from textwrap import dedent
10 | import importlib.util
11 | import logging
12 | import asyncio
13 | import random
14 | from functools import lru_cache
15 |
16 | OversizeStrategy = Literal["truncate_first", "truncate_last", "warn", "error", "ignore"]
17 | MetadataDict = Dict[str, str]
18 |
19 | DEFAULT_CONTEXT_PROMPT: str = dedent("""
20 | Generate keywords and brief phrases describing the main topics, entities, and actions in this text.
21 | Replace pronouns with their specific referents. Format as comma-separated phrases.
22 | Exclude meta-commentary about the text itself.
23 | """).strip()
24 |
25 | DEFAULT_KEY: str = "context"
26 |
27 | class DocumentContextExtractor(BaseExtractor):
28 | """
29 |     Extracts contextual information from document chunks using LLM-based analysis for enhanced RAG accuracy.
30 | """
31 |
32 | def __init__(
33 | self,
34 | docstore: DocumentStore,
35 | llm: Optional[LLM] = None,
36 | key: str = DEFAULT_KEY,
37 | prompt: str = DEFAULT_CONTEXT_PROMPT,
38 | num_workers: int = DEFAULT_NUM_WORKERS,
39 | max_context_length: int = 128000,
40 | max_contextual_tokens: int = 512,
41 | oversized_document_strategy: OversizeStrategy = "truncate_first",
42 | warn_on_oversize: bool = True,
43 | **kwargs
44 | ) -> None:
45 | if not importlib.util.find_spec("tiktoken"):
46 | raise ValueError("TikToken is required for DocumentContextExtractor. Please install tiktoken.")
47 |
48 | llm = llm or Settings.llm
49 | doc_ids: Set[str] = set()
50 |
51 | super().__init__(
52 | key=key,
53 | prompt=prompt,
54 | llm=llm,
55 | docstore=docstore,
56 | num_workers=num_workers,
57 | doc_ids=doc_ids,
58 | max_context_length=max_context_length,
59 | oversized_document_strategy=oversized_document_strategy,
60 | max_contextual_tokens=max_contextual_tokens,
61 | warn_on_oversize=warn_on_oversize,
62 | **kwargs
63 | )
64 |
65 | @staticmethod
66 | def _truncate_text(
67 | text: str,
68 | max_token_count: int,
69 | how: Literal['first', 'last'] = 'first'
70 | ) -> str:
71 | text_splitter = TokenTextSplitter(chunk_size=max_token_count, chunk_overlap=0)
72 | chunks = text_splitter.split_text(text)
73 |
74 | if not chunks:
75 | return ""
76 |
77 | if how == 'first':
78 | return chunks[0]
79 | elif how == 'last':
80 | return chunks[-1]
81 |
82 | raise ValueError("Invalid truncation method. Must be 'first' or 'last'.")
83 |
84 | @staticmethod
85 | def _count_tokens(text: str) -> int:
86 | text_splitter = TokenTextSplitter(chunk_size=1, chunk_overlap=0)
87 | tokens = text_splitter.split_text(text)
88 | return len(tokens)
89 |
90 | async def _agenerate_node_context(
91 | self,
92 | node: Node,
93 | metadata: MetadataDict,
94 | document: Document,
95 | prompt: str,
96 | key: str
97 | ) -> MetadataDict:
98 | cached_text = f"{document.text}"
99 | messages = [
100 | ChatMessage(
101 | role="user",
102 | content=[
103 | {
104 | "text": cached_text,
105 | "block_type": "text",
106 | "cache_control": {"type": "ephemeral"},
107 | },
108 | {
109 | "text": f"Here is the chunk we want to situate within the whole document:\n{node.text}\n{prompt}",
110 | "block_type": "text",
111 | },
112 | ],
113 | ),
114 | ]
115 |
116 | max_retries = 5
117 | base_delay = 60
118 |
119 | for attempt in range(max_retries):
120 | try:
121 | response = await self.llm.achat(
122 | messages,
123 | max_tokens=self.max_contextual_tokens,
124 | extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
125 | )
126 | metadata[key] = response.message.blocks[0].text
127 | return metadata
128 |
129 | except Exception as e:
130 | is_rate_limit = any(
131 | message in str(e).lower()
132 | for message in ["rate limit", "too many requests", "429"]
133 | )
134 |
135 | if is_rate_limit and attempt < max_retries - 1:
136 | delay = (base_delay * (2 ** attempt)) + (random.random() * 0.5)
137 | logging.warning(
138 | f"Rate limit hit, retrying in {delay:.1f} seconds "
139 | f"(attempt {attempt + 1}/{max_retries})"
140 | )
141 | await asyncio.sleep(delay)
142 | continue
143 |
144 | if is_rate_limit:
145 | logging.error(f"Failed after {max_retries} retries due to rate limiting")
146 | else:
147 | logging.warning(f"Error generating context for node {node.node_id}: {str(e)}")
148 |
149 | return metadata
150 |
151 | async def aextract(self, nodes: List[Node]) -> List[MetadataDict]:
152 | """
153 | Extract context for multiple nodes asynchronously.
154 | This is the main entry point for the extractor.
155 | """
156 | metadata_list = [{} for _ in nodes]
157 | metadata_map = {node.node_id: metadata_dict for metadata_dict, node in zip(metadata_list, nodes)}
158 |
159 |         # lru_cache cannot safely wrap a coroutine function (the cached coroutine can only be awaited once), so cache resolved documents in a plain dict
160 |         doc_cache: Dict[str, Optional[Document]] = {}
161 |         async def _get_cached_document(doc_id: str) -> Optional[Document]:
162 |             if doc_id in doc_cache:
163 |                 return doc_cache[doc_id]
164 |             doc = await self.docstore.aget_document(doc_id)
165 |             if self.max_context_length is not None:
166 |                 strategy = self.oversized_document_strategy
167 |                 token_count = self._count_tokens(doc.text)
168 |                 if token_count > self.max_context_length:
169 |                     message = (
170 |                         f"Document {doc.doc_id} is too large ({token_count} tokens) "
171 |                         f"to be processed. Doc metadata: {doc.metadata}"
172 |                     )
173 |                     if self.warn_on_oversize:
174 |                         logging.warning(message)
175 |                     if strategy == "truncate_first":
176 |                         doc.text = self._truncate_text(doc.text, self.max_context_length, 'first')
177 |                     elif strategy == "truncate_last":
178 |                         doc.text = self._truncate_text(doc.text, self.max_context_length, 'last')
179 |                     elif strategy == "error":
180 |                         raise ValueError(message)
181 |                     elif strategy == "ignore":
182 |                         doc = None
183 |                     else:
184 |                         raise ValueError(f"Unknown oversized document strategy: {strategy}")
185 |             doc_cache[doc_id] = doc
186 |             return doc
187 | 
188 | tasks = []
189 | for node in nodes:
190 | if not node.source_node:
191 | continue
192 |
193 | doc = await _get_cached_document(node.source_node.node_id)
194 | if not doc:
195 | continue
196 |
197 | metadata = metadata_map[node.node_id]
198 | tasks.append(self._agenerate_node_context(node, metadata, doc, self.prompt, self.key))
199 |
200 | if tasks:
201 | await run_jobs(tasks, show_progress=self.show_progress, workers=self.num_workers)
202 |
203 | return metadata_list
--------------------------------------------------------------------------------
/DocumentContextExtractor.py:
--------------------------------------------------------------------------------
1 | from llama_index.core.llms import ChatMessage, LLM
2 | from llama_index.core.async_utils import DEFAULT_NUM_WORKERS, run_jobs
3 | from llama_index.core.extractors import BaseExtractor
4 | from llama_index.core.schema import Document, Node
5 | from llama_index.core import Settings
6 | from llama_index.core.storage.docstore.simple_docstore import DocumentStore
7 | from typing import Optional, Dict, List, Tuple, Set, Union, Literal, Any
8 | from textwrap import dedent
10 | import importlib.util
10 | import logging
11 | import asyncio
12 | import random
13 | from functools import lru_cache
14 | import tiktoken
15 |
16 | OversizeStrategy = Literal["truncate_first", "truncate_last", "warn", "error", "ignore"]
17 | MetadataDict = Dict[str, str]
18 |
19 | DEFAULT_CONTEXT_PROMPT: str = dedent("""
20 | Generate keywords and brief phrases describing the main topics, entities, and actions in this text.
21 | Replace pronouns with their specific referents. Format as comma-separated phrases.
22 | Exclude meta-commentary about the text itself.
23 | """).strip()
24 |
25 | DEFAULT_KEY: str = "context"
26 |
27 | class DocumentContextExtractor(BaseExtractor):
28 | """
29 | An LLM-based context extractor for enhancing RAG accuracy through document analysis.
30 |
31 | This extractor processes documents and their nodes to generate contextual metadata,
32 | implementing the approach described in the Anthropic "Contextual Retrieval" blog post.
33 | It handles rate limits, document size constraints, and parallel processing of nodes.
34 |
35 | Attributes:
36 | llm (LLM): Language model instance for generating context
37 | docstore (DocumentStore): Storage for parent documents
38 | key (str): Metadata key for storing extracted context
39 | prompt (str): Prompt template for context generation
40 | doc_ids (Set[str]): Set of processed document IDs
41 | max_context_length (int): Maximum allowed document context length
42 | max_contextual_tokens (int): Maximum tokens in generated context
43 | oversized_document_strategy (OversizeStrategy): Strategy for handling large documents
44 | warn_on_oversize (bool): Whether to log warnings for oversized documents
45 | tiktoken_encoder (str): Name of the tiktoken encoding to use
46 |
47 | Example:
48 | ```python
49 | extractor = DocumentContextExtractor(
50 | docstore=my_docstore,
51 | llm=my_llm,
52 | max_context_length=64000,
53 | max_contextual_tokens=256
54 | )
55 | metadata_list = await extractor.aextract(nodes)
56 | ```
57 | """
58 |
59 | # Pydantic fields
60 | llm: LLM
61 | docstore: DocumentStore
62 | key: str
63 | prompt: str
64 | doc_ids: Set[str]
65 | max_context_length: int
66 | max_contextual_tokens: int
67 | oversized_document_strategy: OversizeStrategy
68 | warn_on_oversize: bool = True
69 | tiktoken_encoder: str
70 |
71 | def __init__(
72 | self,
73 | docstore: DocumentStore,
74 | llm: LLM,
75 | key: Optional[str] = DEFAULT_KEY,
76 | prompt: Optional[str] = DEFAULT_CONTEXT_PROMPT,
77 | num_workers: int = DEFAULT_NUM_WORKERS,
78 | max_context_length: int = 128000,
79 | max_contextual_tokens: int = 512,
80 | oversized_document_strategy: OversizeStrategy = "truncate_first",
81 | warn_on_oversize: bool = True,
82 | tiktoken_encoder: str = "cl100k_base",
83 | **kwargs
84 | ) -> None:
85 | if not importlib.util.find_spec("tiktoken"):
86 | raise ValueError("TikToken is required for DocumentContextExtractor. Please install tiktoken.")
87 |
88 | # Process input parameters
89 |
90 | llm = llm or Settings.llm
91 | doc_ids: Set[str] = set()
92 |
93 | super().__init__(
94 | key=key,
95 | prompt=prompt,
96 | llm=llm,
97 | docstore=docstore,
98 | num_workers=num_workers,
99 | doc_ids=doc_ids,
100 | max_context_length=max_context_length,
101 | oversized_document_strategy=oversized_document_strategy,
102 | max_contextual_tokens=max_contextual_tokens,
103 | warn_on_oversize=warn_on_oversize,
104 | tiktoken_encoder=tiktoken_encoder,
105 | **kwargs
106 | )
107 |
108 |     # This can take a surprisingly long time on longer docs, so we cache it. For oversized docs, we end up counting twice, the second time without the cache.
109 |     # But if you're repeatedly running way-oversized docs, the time that takes won't be what matters anyway.
110 | @staticmethod
111 | @lru_cache(maxsize=1000)
112 | def _count_tokens(text: str, encoder:str="cl100k_base") -> int:
113 | """
114 | This can take a surprisingly long time on longer docs so we cache it, and we need to call it on every doc, regardless of size.
115 | """
116 | encoding = tiktoken.get_encoding(encoder)
117 | return len(encoding.encode(text))
118 |
119 | @staticmethod
120 | @lru_cache(maxsize=10)
121 | def _truncate_text(text: str, max_token_count: int, how: Literal['first', 'last'] = 'first', encoder="cl100k_base") -> str:
122 | """
123 |         This can take a couple of seconds. A small cache is nice here because the async calls will mostly happen in order. If you DO hit an oversized document,
124 |         you would otherwise be re-truncating thousands of times as you process each chunk of your 200+ page document.
125 | """
126 | encoding = tiktoken.get_encoding(encoder)
127 | tokens = encoding.encode(text)
128 |
129 | if how == 'first':
130 | truncated_tokens = tokens[:max_token_count]
131 | else: # 'last'
132 | truncated_tokens = tokens[-max_token_count:]
133 |
134 | return encoding.decode(truncated_tokens)
135 |
136 | async def _agenerate_node_context(
137 | self,
138 | node: Node,
139 | metadata: MetadataDict,
140 | document: Document,
141 | prompt: str,
142 | key: str
143 | ) -> MetadataDict:
144 | """
145 | Generate context for a node using LLM with retry logic.
146 |
147 | Implements exponential backoff for rate limit handling and uses prompt
148 | caching when available. The function retries on rate limits and handles
149 | various error conditions gracefully.
150 |
151 | Args:
152 | node: Node to generate context for
153 | metadata: Metadata dictionary to update
154 | document: Parent document containing the node
155 | prompt: Prompt template for context generation
156 | key: Metadata key for storing generated context
157 |
158 | Returns:
159 | Updated metadata dictionary with generated context
160 |
161 | Note:
162 | Uses exponential backoff starting at 60 seconds with up to 5 retries
163 | for rate limit handling.
164 | """
165 | cached_text = f"{document.text}"
166 | messages = [
167 | ChatMessage(
168 | role="user",
169 | content=[
170 | {
171 | "text": cached_text,
172 | "block_type": "text",
173 | "cache_control": {"type": "ephemeral"},
174 | },
175 | {
176 | "text": f"Here is the chunk we want to situate within the whole document:\n{node.text}\n{prompt}",
177 | "block_type": "text",
178 | },
179 | ],
180 | ),
181 | ]
182 |
183 | max_retries = 5
184 | base_delay = 60
185 |
186 | for attempt in range(max_retries):
187 | try:
188 | response = await self.llm.achat(
189 | messages,
190 | max_tokens=self.max_contextual_tokens,
191 | extra_headers={"anthropic-beta": "prompt-caching-2024-07-31"},
192 | )
193 | metadata[key] = response.message.blocks[0].text
194 | return metadata
195 |
196 | except Exception as e:
197 | is_rate_limit = any(
198 | message in str(e).lower()
199 | for message in ["rate limit", "too many requests", "429"]
200 | )
201 |
202 | if is_rate_limit and attempt < max_retries - 1:
203 | delay = (base_delay * (2 ** attempt)) + (random.random() * 0.5)
204 | logging.warning(
205 | f"Rate limit hit, retrying in {delay:.1f} seconds "
206 | f"(attempt {attempt + 1}/{max_retries})"
207 | )
208 | await asyncio.sleep(delay)
209 | continue
210 |
211 | if is_rate_limit:
212 | logging.error(f"Failed after {max_retries} retries due to rate limiting")
213 | else:
214 | logging.warning(f"Error generating context for node {node.node_id}: {str(e)}")
215 |
216 | return metadata
217 |
218 |     async def _get_document(self, doc_id: str) -> Optional[Document]:
219 |         """Counting tokens can be slow, as can awaiting the docstore (potentially), so the token-counting and truncation helpers above keep small lru_caches."""
220 |
221 | # first we need to get the document
222 | doc = await self.docstore.aget_document(doc_id)
223 |
224 | # then truncate if necessary.
225 | if self.max_context_length is not None:
226 | strategy = self.oversized_document_strategy
227 | token_count = self._count_tokens(doc.text, self.tiktoken_encoder)
228 | if token_count > self.max_context_length:
229 | message = (
230 |                     f"Document {doc.doc_id} is too large ({token_count} tokens) "
231 | f"to be processed. Doc metadata: {doc.metadata}"
232 | )
233 |
234 | if self.warn_on_oversize:
235 | logging.warning(message)
236 |
237 | if strategy == "truncate_first":
238 | doc.text = self._truncate_text(doc.text, self.max_context_length, 'first', self.tiktoken_encoder)
239 | elif strategy == "truncate_last":
240 | doc.text = self._truncate_text(doc.text, self.max_context_length, 'last', self.tiktoken_encoder)
241 | elif strategy == "error":
242 | raise ValueError(message)
243 | elif strategy == "ignore":
244 |                     return None
245 | else:
246 | raise ValueError(f"Unknown oversized document strategy: {strategy}")
247 |
248 | return doc
249 |
250 | async def aextract(self, nodes: List[Node]) -> List[MetadataDict]:
251 | """
252 | Extract context for multiple nodes asynchronously, optimized for loosely ordered nodes.
253 | Processes each node independently without guaranteeing sequential document handling.
254 | Nodes will be *mostly* processed in document-order assuming nodes get passed in document-order.
255 |
256 | Args:
257 | nodes: List of nodes to process, ideally grouped by source document
258 |
259 | Returns:
260 | List of metadata dictionaries with generated context
261 | """
262 | metadata_list = [{} for _ in nodes]
263 | metadata_map = {node.node_id: metadata_dict for metadata_dict, node in zip(metadata_list, nodes)}
264 |
265 | # iterate over all the nodes and generate the jobs
266 | node_tasks = []
267 | for node in nodes:
268 | if not node.source_node:
269 |                 continue
270 |
271 | doc = await self._get_document(node.source_node.node_id)
272 |
273 | if not doc:
274 | continue
275 |
276 | metadata = metadata_map[node.node_id]
277 | task = self._agenerate_node_context(node, metadata, doc, self.prompt, self.key)
278 | node_tasks.append(task)
279 |
280 | # then run the jobs
281 | await run_jobs(
282 | node_tasks,
283 | show_progress=self.show_progress,
284 | workers=self.num_workers,
285 | )
286 |
287 | return metadata_list
--------------------------------------------------------------------------------
/HybridSearchDemo.py:
--------------------------------------------------------------------------------
1 | from llama_index.core import VectorStoreIndex, StorageContext, Settings
2 | from llama_index.core.storage.index_store.simple_index_store import SimpleIndexStore
3 | from llama_index.vector_stores.qdrant import QdrantVectorStore
4 | from qdrant_client import QdrantClient, AsyncQdrantClient
5 |
6 | from llama_index.core import SimpleDirectoryReader
7 | from llama_index.core.node_parser import SentenceSplitter
8 | from llama_index.llms.openai import OpenAI
9 | from llama_index.core.postprocessor import LLMRerank
10 | from llama_index.core.storage.docstore.simple_docstore import SimpleDocumentStore
11 | from llama_index.embeddings.openai import OpenAIEmbedding
12 |
13 | import os
14 | from DocumentContextExtractor import DocumentContextExtractor
15 |
16 | # TODO: add 'query context' to this
17 |
18 | class HybridSearchWithContext:
19 | CHUNK_SIZE = 512
20 | CHUNK_OVERLAP = 50
21 | SIMILARITY_TOP_K = 10
22 | SPARSE_TOP_K = 20
23 |     RERANKER_TOP_N = 3
24 | def __init__(self, name:str):
25 | """
26 | :param name: The name of the index, required for the underlying vector store
27 | """
28 | # Initialize clients
29 | client = QdrantClient(":memory:")
30 | aclient = AsyncQdrantClient(":memory:")
31 | self.index_store_path = f"{name}"
32 |
33 | if not os.path.exists(self.index_store_path):
34 | os.makedirs(self.index_store_path)
35 |
36 |         # Initialize LLMs
37 | self.context_llm = OpenAI(model="gpt-4o-mini")
38 | self.answering_llm = OpenAI(model="gpt-4o-mini")
39 |
40 | self.embed_model = OpenAIEmbedding(model="text-embedding-3-small")
41 |
42 | sample_embedding = self.embed_model.get_query_embedding("sample text")
43 | self.embed_size = len(sample_embedding)
44 |
45 | self.reranker = LLMRerank(
46 | choice_batch_size=5,
47 |             top_n=self.RERANKER_TOP_N,
48 | llm=self.context_llm
49 | )
50 |
51 | # Create vector store
52 | self.vector_store = QdrantVectorStore(
53 | name,
54 | client=client,
55 | aclient=aclient,
56 | enable_hybrid=True,
57 | batch_size=20,
58 | dim=self.embed_size
59 | )
60 |
61 | # Initialize storage context
62 | if os.path.exists(os.path.join(self.index_store_path, "index_store.json")):
63 | index_store=SimpleIndexStore.from_persist_dir(persist_dir=self.index_store_path)
64 | else:
65 | index_store=SimpleIndexStore()
66 |
67 | self.storage_context = StorageContext.from_defaults(vector_store=self.vector_store,
68 | index_store=index_store)
69 |
70 | # Create text splitter
71 | self.text_splitter = SentenceSplitter(
72 | chunk_size=self.CHUNK_SIZE,
73 | chunk_overlap=self.CHUNK_OVERLAP
74 | )
75 |
76 |         # DocumentContextExtractor requires a document store.
77 |         # The first two arguments (docstore and llm) are required.
78 |         # max_contextual_tokens plus chunk_size should be a little less than the max input size of your embedding model, to give some headroom.
79 | self.document_context_extractor = DocumentContextExtractor(docstore=self.storage_context.docstore,
80 | llm=self.context_llm, max_context_length=128000,
81 | max_contextual_tokens=512,
82 | oversized_document_strategy="truncate_first")
83 |
84 |
85 |
86 | self.index = VectorStoreIndex.from_vector_store(
87 | vector_store=self.vector_store,
88 | embed_model=self.embed_model,
89 | storage_context=self.storage_context,
90 | transformations=[self.text_splitter, self.document_context_extractor]
91 | )
92 |
93 | self.storage_context.persist(persist_dir=self.index_store_path)
94 |
95 | def add_directory(self, directory):
96 | reader = SimpleDirectoryReader(directory)
97 | documents = reader.load_data()
98 |
99 | self.storage_context.docstore.add_documents(documents)
100 | for doc in documents:
101 | self.index.insert(doc)
102 |
103 | self.query_engine = self.index.as_query_engine(
104 | similarity_top_k=self.SIMILARITY_TOP_K,
105 | sparse_top_k=self.SPARSE_TOP_K,
106 | vector_store_query_mode="hybrid",
107 | llm=self.answering_llm,
108 | node_postprocessors=[self.reranker]
109 | )
110 |
111 | self.retriever = self.index.as_retriever(
112 | similarity_top_k=self.SIMILARITY_TOP_K,
113 | sparse_top_k=self.SPARSE_TOP_K,
114 | vector_store_query_mode="hybrid"
115 | )
116 |
117 | self.storage_context.persist(persist_dir=self.index_store_path)
118 |
119 | def get_raw_search_results(self, question):
120 |
121 | retrieved_nodes = self.retriever.retrieve(question)
122 |
123 |         # node.text on each retrieved node gives the raw chunk text if that's all you need
124 |
125 | return retrieved_nodes
126 |
127 |     def query(self, question):
128 | response = self.query_engine.query(
129 | question
130 | )
131 |
132 | return response
133 |
134 |
135 |
136 |
137 | if __name__=='__main__':
138 | from dotenv import load_dotenv
139 | load_dotenv()
140 | hybrid_search = HybridSearchWithContext(name="hybriddemo")
141 | hybrid_search.add_directory("./data")
142 |
143 | question = "Why was this document written?"
144 | print(hybrid_search.get_raw_search_results(question))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # DEPRECATED. THIS HAS NOW BEEN MERGED INTO LLAMA_INDEX. PLEASE USE `llama_index.core.extractors.DocumentContextExtractor`: https://github.com/run-llama/llama_index/pull/17367
3 |
4 | ## Summary
5 |
6 | This repository contains a llama_index implementation of "contextual retrieval" (https://www.anthropic.com/news/contextual-retrieval)
7 |
8 | It implements a custom llama_index Extractor class, which can then be used in a llama_index ingestion pipeline. You initialize it with a document store and an LLM that provides the context, and you have to keep the document store up to date yourself.
9 |
10 | ## Motivation
11 |
12 | Anthropic made a 'cookbook' notebook to demo this. llama_index also made a demo of it here: https://docs.llamaindex.ai/en/stable/examples/cookbooks/contextual_retrieval
13 |
14 | The problem: there are tons of edge cases when trying to replicate this at scale, over 100s of documents:
15 |
16 | - rate limits are a huge problem
17 |
18 | - cost
19 |
20 | - I want to put this into a pipeline
21 |
22 | - documents too large for context window
23 |
24 | - prompt caching doesn't work via llama_index interface
25 |
26 | - error handling
27 |
28 | - chunk + context can be too big for the embedding model
29 |
30 | - and much more!
31 |
32 | ## Demo
33 |
34 | See HybridSearchDemo.py for a demo of the extractor in action with Qdrant hybrid search, effectively re-implementing the blog post. All the OTHER parts of the blog post (reranking, hybrid search) are already well implemented in llama_index, in my opinion.
35 |
36 | ## Usage
37 |
38 | ```python
39 | docstore = SimpleDocumentStore()
40 |
41 | llm = OpenRouter(model="openai/gpt-4o-mini")
42 |
43 | # initialize the extractor
44 | extractor = DocumentContextExtractor(docstore, llm)
45 |
46 | storage_context = StorageContext.from_defaults(vector_store=vector_store,
47 | docstore=docstore,
48 | index_store=index_store)
49 | index = VectorStoreIndex.from_vector_store(
50 | vector_store=vector_store,
51 | embed_model=embed_model,
52 | storage_context=storage_context,
53 |     transformations=[text_splitter, extractor]
54 | )
55 |
56 | reader = SimpleDirectoryReader(directory)
57 | documents = reader.load_data()
58 |
59 | # have to keep this updated for the DocumentContextExtractor to function.
60 | storage_context.docstore.add_documents(documents)
61 | for doc in documents:
62 |     index.insert(doc)
63 | ```
64 |
65 | ### Custom key and prompt
66 |
67 | By default, the extractor adds a key called "context" to each node, using a reasonable default prompt taken from the blog post cookbook, but you can pass in your own key and prompt like so:
68 |
69 | ```python
70 | extractor = DocumentContextExtractor(docstore, llm, key="context", prompt="Give the document context for this chunk")
71 | ```
72 |
73 | ## Model selection
74 |
75 | You need something fast, with high rate limits, a long context window, and low cost on input tokens.
76 |
77 | Recommended models:
78 |
79 | - gpt 4o-mini
80 |
81 | - Gemini flash models
82 |
83 | - long-context local models (would love recommendations)
84 |
85 | gpt-4o-mini is king. The 128k context, strong quality, and automatic prompt caching make it absolutely perfect. Throw $50 at OpenAI and wait 7 days, and they'll give you 2M tokens/minute at $0.075 per million input tokens.
86 | You're going to pay roughly doc_size * (doc_size // chunk_size) input tokens for each document, plus roughly (num_chunks * 200) output tokens.
87 | This means 10-50 million tokens to process Pride and Prejudice, if you don't split it into chapters first.
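As a rough illustration, here is a back-of-the-envelope sketch of that math (the 100k-token document size, 512-token chunks, and 200-token contexts below are assumptions, not measurements):

```python
# Back-of-the-envelope cost estimate for contextual retrieval preprocessing.
# Assumption: each chunk's LLM call re-sends the full document as input and
# generates ~200 output tokens of context.
def estimate_tokens(doc_tokens: int, chunk_tokens: int = 512, context_tokens: int = 200):
    num_chunks = doc_tokens // chunk_tokens
    input_tokens = doc_tokens * num_chunks   # the full document is resent once per chunk
    output_tokens = num_chunks * context_tokens
    return input_tokens, output_tokens

inp, out = estimate_tokens(100_000)  # e.g. a 100k-token document with 512-token chunks
print(f"~{inp / 1e6:.1f}M input tokens, ~{out / 1e3:.1f}k output tokens")  # ~19.5M input, ~39.0k output
print(f"~${inp / 1e6 * 0.075:.2f} of input at $0.075 per million tokens, before prompt-caching discounts")
```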
88 |
89 |
90 | ## TODO
91 | - TEST 'succinct' prompt performance vs 'full' prompt performance!
92 | - Support for Batch requests (supported by anthropic and openai) to handle truly massive amounts of documents?
93 | - fix this bug because it prevents Llama index from working with Python 3.10: https://github.com/run-llama/llama_index/discussions/14351
94 | - add a TransformComponent that splits documents into smaller documents and then adds them to the docstore
95 | - or better yet, a TransformComponent that simply adds the nodes to the docstore and does nothing else
96 |     - then you can build a pipeline like this: ChapterSplitter -> DocstoreCatcher -> SentenceSplitter -> DocumentContextExtractor (a rough sketch of DocstoreCatcher is at the end of this README)
97 | - make a pull request to llama_index
98 |
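For reference, here is a rough sketch of what that pass-through docstore transform could look like. This is an illustration only, not code from this repo: the `DocstoreCatcher` name is hypothetical, and depending on your llama_index version the pydantic field for the docstore may need `arbitrary_types_allowed` or a private attribute.

```python
from typing import Any, List, Sequence

from llama_index.core.schema import BaseNode, TransformComponent
from llama_index.core.storage.docstore.simple_docstore import SimpleDocumentStore


class DocstoreCatcher(TransformComponent):
    """Hypothetical pass-through transform: records every incoming node in a docstore, returns the nodes unchanged."""

    docstore: SimpleDocumentStore  # may need arbitrary_types_allowed in your pydantic config

    def __call__(self, nodes: Sequence[BaseNode], **kwargs: Any) -> List[BaseNode]:
        # Keep the docstore in sync so DocumentContextExtractor can look up parent documents later.
        self.docstore.add_documents(list(nodes))
        return list(nodes)
```

You could then place it between your splitters in an ingestion pipeline, e.g. `transformations=[chapter_splitter, DocstoreCatcher(docstore=docstore), sentence_splitter, context_extractor]`.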
--------------------------------------------------------------------------------
/data/declaration.txt:
--------------------------------------------------------------------------------
1 | In Congress, July 4, 1776
2 |
3 | The unanimous Declaration of the thirteen united States of America, When in the Course of human events, it becomes necessary for one people to dissolve the political bands which have connected them with another, and to assume among the powers of the earth, the separate and equal station to which the Laws of Nature and of Nature's God entitle them, a decent respect to the opinions of mankind requires that they should declare the causes which impel them to the separation.
4 |
5 | We hold these truths to be self-evident, that all men are created equal, that they are endowed by their Creator with certain unalienable Rights, that among these are Life, Liberty and the pursuit of Happiness.--That to secure these rights, Governments are instituted among Men, deriving their just powers from the consent of the governed, --That whenever any Form of Government becomes destructive of these ends, it is the Right of the People to alter or to abolish it, and to institute new Government, laying its foundation on such principles and organizing its powers in such form, as to them shall seem most likely to effect their Safety and Happiness. Prudence, indeed, will dictate that Governments long established should not be changed for light and transient causes; and accordingly all experience hath shewn, that mankind are more disposed to suffer, while evils are sufferable, than to right themselves by abolishing the forms to which they are accustomed. But when a long train of abuses and usurpations, pursuing invariably the same Object evinces a design to reduce them under absolute Despotism, it is their right, it is their duty, to throw off such Government, and to provide new Guards for their future security.--Such has been the patient sufferance of these Colonies; and such is now the necessity which constrains them to alter their former Systems of Government. The history of the present King of Great Britain is a history of repeated injuries and usurpations, all having in direct object the establishment of an absolute Tyranny over these States. To prove this, let Facts be submitted to a candid world.
6 |
7 | He has refused his Assent to Laws, the most wholesome and necessary for the public good.
8 |
9 | He has forbidden his Governors to pass Laws of immediate and pressing importance, unless suspended in their operation till his Assent should be obtained; and when so suspended, he has utterly neglected to attend to them.
10 |
11 | He has refused to pass other Laws for the accommodation of large districts of people, unless those people would relinquish the right of Representation in the Legislature, a right inestimable to them and formidable to tyrants only.
12 |
13 | He has called together legislative bodies at places unusual, uncomfortable, and distant from the depository of their public Records, for the sole purpose of fatiguing them into compliance with his measures.
14 |
15 | He has dissolved Representative Houses repeatedly, for opposing with manly firmness his invasions on the rights of the people.
16 |
17 | He has refused for a long time, after such dissolutions, to cause others to be elected; whereby the Legislative powers, incapable of Annihilation, have returned to the People at large for their exercise; the State remaining in the mean time exposed to all the dangers of invasion from without, and convulsions within.
18 |
19 | He has endeavoured to prevent the population of these States; for that purpose obstructing the Laws for Naturalization of Foreigners; refusing to pass others to encourage their migrations hither, and raising the conditions of new Appropriations of Lands.
20 |
21 | He has obstructed the Administration of Justice, by refusing his Assent to Laws for establishing Judiciary powers.
22 |
23 | He has made Judges dependent on his Will alone, for the tenure of their offices, and the amount and payment of their salaries.
24 |
25 | He has erected a multitude of New Offices, and sent hither swarms of Officers to harrass our people, and eat out their substance.
26 |
27 | He has kept among us, in times of peace, Standing Armies without the Consent of our legislatures.
28 |
29 | He has affected to render the Military independent of and superior to the Civil power.
30 |
31 | He has combined with others to subject us to a jurisdiction foreign to our constitution, and unacknowledged by our laws; giving his Assent to their Acts of pretended Legislation:
32 |
33 | For Quartering large bodies of armed troops among us:
34 |
35 | For protecting them, by a mock Trial, from punishment for any Murders which they should commit on the Inhabitants of these States:
36 |
37 | For cutting off our Trade with all parts of the world:
38 |
39 | For imposing Taxes on us without our Consent:
40 |
41 | For depriving us in many cases, of the benefits of Trial by Jury:
42 |
43 | For transporting us beyond Seas to be tried for pretended offences:
44 |
45 | For abolishing the free System of English Laws in a neighbouring Province, establishing therein an Arbitrary government, and enlarging its Boundaries so as to render it at once an example and fit instrument for introducing the same absolute rule into these Colonies:
46 |
47 | For taking away our Charters, abolishing our most valuable Laws, and altering fundamentally the Forms of our Governments:
48 |
49 | For suspending our own Legislatures, and declaring themselves invested with power to legislate for us in all cases whatsoever.
50 |
51 | He has abdicated Government here, by declaring us out of his Protection and waging War against us.
52 |
53 | He has plundered our seas, ravaged our Coasts, burnt our towns, and destroyed the lives of our people.
54 |
55 | He is at this time transporting large Armies of foreign Mercenaries to compleat the works of death, desolation and tyranny, already begun with circumstances of Cruelty & perfidy scarcely paralleled in the most barbarous ages, and totally unworthy the Head of a civilized nation.
56 |
57 | He has constrained our fellow Citizens taken Captive on the high Seas to bear Arms against their Country, to become the executioners of their friends and Brethren, or to fall themselves by their Hands.
58 |
59 | He has excited domestic insurrections amongst us, and has endeavoured to bring on the inhabitants of our frontiers, the merciless Indian Savages, whose known rule of warfare, is an undistinguished destruction of all ages, sexes and conditions.
60 |
61 | In every stage of these Oppressions We have Petitioned for Redress in the most humble terms: Our repeated Petitions have been answered only by repeated injury. A Prince whose character is thus marked by every act which may define a Tyrant, is unfit to be the ruler of a free people.
62 |
63 | Nor have We been wanting in attentions to our Brittish brethren. We have warned them from time to time of attempts by their legislature to extend an unwarrantable jurisdiction over us. We have reminded them of the circumstances of our emigration and settlement here. We have appealed to their native justice and magnanimity, and we have conjured them by the ties of our common kindred to disavow these usurpations, which, would inevitably interrupt our connections and correspondence. They too have been deaf to the voice of justice and of consanguinity. We must, therefore, acquiesce in the necessity, which denounces our Separation, and hold them, as we hold the rest of mankind, Enemies in War, in Peace Friends.
64 |
65 | We, therefore, the Representatives of the united States of America, in General Congress, Assembled, appealing to the Supreme Judge of the world for the rectitude of our intentions, do, in the Name, and by Authority of the good People of these Colonies, solemnly publish and declare, That these United Colonies are, and of Right ought to be Free and Independent States; that they are Absolved from all Allegiance to the British Crown, and that all political connection between them and the State of Great Britain, is and ought to be totally dissolved; and that as Free and Independent States, they have full Power to levy War, conclude Peace, contract Alliances, establish Commerce, and to do all other Acts and Things which Independent States may of right do. And for the support of this Declaration, with a firm reliance on the protection of divine Providence, we mutually pledge to each other our Lives, our Fortunes and our sacred Honor.
--------------------------------------------------------------------------------
/perf_tests/asynciotest.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import time
3 |
4 | async def task(num, sem):
5 | async with sem: # Acquire a slot from the semaphore
6 | await asyncio.sleep(1)
7 | return (num, time.perf_counter_ns())
8 |
9 | async def main():
10 | start_time = time.perf_counter_ns()
11 | print(f"Starting tasks at {start_time}")
12 |
13 | # Create semaphore with max 5 concurrent tasks
14 | sem = asyncio.Semaphore(5)
15 |
16 | tasks = [task(i, sem) for i in range(20)]
17 | results = await asyncio.gather(*tasks)
18 |
19 | # Sort by timestamp and print
20 | for num, completion_time in sorted(results, key=lambda x: x[1]):
21 | relative_time = (completion_time - start_time) / 1000 # ns to μs
22 | print(f"Task {num:3d} completed at +{relative_time:.3f}μs")
23 |
24 | print(f"\nAll tasks completed in {(time.perf_counter_ns() - start_time)/1e9:.3f} seconds")
25 |
26 | asyncio.run(main())
--------------------------------------------------------------------------------
/perf_tests/asynciotest2.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | async def main():
4 | data = [x for x in range(1000)]
5 |
6 | async def aiter(iterable):
7 | for v in iterable:
8 | yield v
9 |
10 | aiterable = aiter(data)
11 | aiterables = [aiterable] * 1000
12 |
13 | values = await asyncio.gather(
14 | *[it.__anext__() for it in aiterables])
15 |
16 | assert values == data, f'{values} != {data}'
17 |
18 | # asyncio.get_event_loop() with no running loop is deprecated; asyncio.run creates and closes one for us
19 | asyncio.run(main())
--------------------------------------------------------------------------------
/perf_tests/perf_tests.py:
--------------------------------------------------------------------------------
1 | import time
2 | import asyncio
3 | from llama_index.core.node_parser import TokenTextSplitter
4 | from llama_index.core.schema import Document
5 | from llama_index.core.storage.docstore.simple_docstore import SimpleDocumentStore
6 | import numpy as np
7 | from typing import Literal
8 | from functools import lru_cache
9 | import tiktoken
10 |
11 | # Module-level helpers copied from DocumentContextExtractor (no @staticmethod outside a class)
12 | @lru_cache(maxsize=1000)
13 | def count_tokens(text: str, encoder: str = "cl100k_base") -> int:
14 |     encoding = tiktoken.get_encoding(encoder)
15 |     return len(encoding.encode(text))
16 | 
17 | 
18 | def truncate_text(text: str, max_token_count: int, how: Literal['first', 'last'] = 'first', encoder: str = "cl100k_base") -> str:
19 | encoding = tiktoken.get_encoding(encoder)
20 | tokens = encoding.encode(text)
21 |
22 | if how == 'first':
23 | truncated_tokens = tokens[:max_token_count]
24 | else: # last
25 | truncated_tokens = tokens[-max_token_count:]
26 |
27 | return encoding.decode(truncated_tokens)
28 |
29 | async def run_performance_test():
30 | # Create test document and store
31 | docstore = SimpleDocumentStore()
32 |
33 | with open('prideandprejudice.txt', 'r') as f:
34 | large_text = f.read()
35 |
36 | doc = Document(text=large_text)
37 | docstore.add_documents([doc])
38 |
39 | # Test each operation separately
40 | n_runs = 5
41 | token_count_times = []
42 | truncate_times = []
43 | docstore_times = []
44 |
45 | for _ in range(n_runs):
46 | # Test token counting
47 | start = time.time()
48 | _ = count_tokens(large_text)
49 | token_count_times.append(time.time() - start)
50 |
51 | # Test truncation
52 | start = time.time()
53 | _ = truncate_text(large_text, 1000)
54 | truncate_times.append(time.time() - start)
55 |
56 | # Test document retrieval
57 | start = time.time()
58 | _ = await docstore.aget_document(doc.doc_id)
59 | docstore_times.append(time.time() - start)
60 |
61 | print("\nToken counting:")
62 | print(f"Average: {np.mean(token_count_times):.3f}s")
63 | print(f"Min: {min(token_count_times):.3f}s")
64 | print(f"Max: {max(token_count_times):.3f}s")
65 |
66 | print("\nTruncation:")
67 | print(f"Average: {np.mean(truncate_times):.3f}s")
68 | print(f"Min: {min(truncate_times):.3f}s")
69 | print(f"Max: {max(truncate_times):.3f}s")
70 |
71 | print("\nDocument retrieval:")
72 | print(f"Average: {np.mean(docstore_times):.3f}s")
73 | print(f"Min: {min(docstore_times):.3f}s")
74 | print(f"Max: {max(docstore_times):.3f}s")
75 |
76 | if __name__ == "__main__":
77 | asyncio.run(run_performance_test())
--------------------------------------------------------------------------------
/perf_tests/perftest3.py:
--------------------------------------------------------------------------------
1 | import time
2 | import tiktoken
3 | from functools import lru_cache
4 | import numpy as np
5 |
6 | @lru_cache(maxsize=1000)
7 | def count_tokens_cached(text: str) -> int:
8 | encoding = tiktoken.get_encoding("cl100k_base")
9 | return len(encoding.encode(text))
10 |
11 | def count_tokens_direct(text: str) -> int:
12 | encoding = tiktoken.get_encoding("cl100k_base")
13 | return len(encoding.encode(text))
14 |
15 | def run_test():
16 | # Create test documents of different sizes
17 | sizes = [1000, 10000, 100000]
18 |
19 | for size in sizes:
20 | text = f"This is a test document. " * size
21 |         print(f"\nTesting with {size} repetitions (~{size * 5} words):")
22 |
23 | # Test uncached first run
24 | start = time.time()
25 | _ = count_tokens_cached(text)
26 | first_cache_time = time.time() - start
27 |
28 | # Test cached second run
29 | start = time.time()
30 | _ = count_tokens_cached(text)
31 | second_cache_time = time.time() - start
32 |
33 | # Test direct tokenization
34 | start = time.time()
35 | _ = count_tokens_direct(text)
36 | direct_time = time.time() - start
37 |
38 | print(f"First run with cache: {first_cache_time:.3f}s")
39 | print(f"Second run with cache: {second_cache_time:.3f}s")
40 | print(f"Direct tokenization: {direct_time:.3f}s")
41 |
42 | if __name__ == "__main__":
43 | run_test()
--------------------------------------------------------------------------------
/perf_tests/token_comparator.py:
--------------------------------------------------------------------------------
1 | import tiktoken
2 |
3 | def compare_encodings(text: str):
4 | encodings = ["cl100k_base", "p50k_base", "r50k_base"]
5 | results = {}
6 |
7 | for enc_name in encodings:
8 | encoding = tiktoken.get_encoding(enc_name)
9 | count = len(encoding.encode(text))
10 | results[enc_name] = count
11 |
12 | return results
13 |
14 | # Test with different types of text
15 | test_cases = {
16 | "Simple English": "This is a simple test of the encoding systems.",
17 | "Technical": "The DocumentContextExtractor class implements efficient token-based text processing with configurable chunking strategies.",
18 | "Mixed": "Here's some code: for i in range(10): print(f'Value: {i}')",
19 | "Special Chars": "Special characters like 你好, üñîçødé, and emojis 🌟 can affect encoding.",
20 | "Long Technical": """The transformer architecture employs multi-headed self-attention mechanisms to process sequential data.
21 | Each attention head can learn different aspects of the relationships between tokens in the sequence.
22 | The model uses positional encodings to maintain sequence order information.""" * 3
23 | }
24 |
25 | for name, text in test_cases.items():
26 | print(f"\n{name}:")
27 | counts = compare_encodings(text)
28 | base = counts['cl100k_base']
29 | for enc, count in counts.items():
30 | diff = ((count - base) / base) * 100
31 | print(f"{enc}: {count} tokens ({diff:+.1f}% vs cl100k)")
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | botocore
2 | python-dotenv
3 |
4 | llama-index
5 | llama-index-core
6 |
7 | llama-index-llms-openai
8 | llama-index-vector-stores-qdrant
9 | llama-index-embeddings-openai
--------------------------------------------------------------------------------
/test_document_context_extractor.py:
--------------------------------------------------------------------------------
1 | # When upstreaming, this test belongs in llama-index-core/tests/node_parser/ (e.g. metadata_extractor.py)
2 | import pytest
3 | from llama_index.core.schema import Document, Node
4 | from llama_index.core.storage.docstore.simple_docstore import SimpleDocumentStore
5 | from llama_index.core.llms import MockLLM
6 | from llama_index.core.llms import ChatMessage, ChatResponse
7 | from DocumentContextExtractor import DocumentContextExtractor
8 |
9 | @pytest.fixture
10 | def mock_llm():
11 | class CustomMockLLM(MockLLM):
12 |         async def achat(self, messages, **kwargs):
13 |             # Mock response that simulates context generation; the extractor calls achat and reads response.message
14 |             return ChatResponse(message=ChatMessage(
15 |                 role="assistant",
16 |                 blocks=[{"text": "Context for the provided chunk", "block_type": "text"}]
17 |             ))
18 |
19 | return CustomMockLLM()
20 |
21 | @pytest.fixture
22 | def sample_documents():
23 | # Create some test documents
24 | docs = [
25 | Document(
26 | text="This is chapter 1. It contains important information. This is a test document.",
27 | metadata={"title": "Doc 1"}
28 | ),
29 | Document(
30 | text="Chapter 2 builds on previous concepts. It introduces new ideas. More test content here.",
31 | metadata={"title": "Doc 2"}
32 | )
33 | ]
34 | return docs
35 |
36 | @pytest.fixture
37 | def docstore(sample_documents):
38 | # Initialize docstore with sample documents
39 | docstore = SimpleDocumentStore()
40 | for doc in sample_documents:
41 | docstore.add_documents([doc])
42 | return docstore
43 |
44 | @pytest.fixture
45 | def context_extractor(docstore, mock_llm):
46 | return DocumentContextExtractor(
47 | docstore=docstore,
48 | llm=mock_llm,
49 | max_context_length=1000,
50 | max_contextual_tokens=100,
51 | oversized_document_strategy="truncate_first"
52 | )
53 |
54 | @pytest.mark.asyncio
55 | async def test_context_extraction_basic(context_extractor, sample_documents):
56 | # Create nodes from the first document
57 | nodes = [
58 | Node(
59 | text="This is chapter 1.",
60 | metadata={},
61 | source_node=sample_documents[0]
62 | ),
63 | Node(
64 | text="It contains important information.",
65 | metadata={},
66 | source_node=sample_documents[0]
67 | )
68 | ]
69 |
70 | # Extract context
71 | metadata_list = await context_extractor.aextract(nodes)
72 |
73 | # Verify each node got context
74 | assert len(metadata_list) == len(nodes)
75 | for metadata in metadata_list:
76 | assert "context" in metadata
77 | assert metadata["context"] == "Context for the provided chunk"
78 |
79 | @pytest.mark.asyncio
80 | async def test_context_extraction_oversized_document():
81 | # Create a very large document
82 | large_doc = Document(
83 | text="This is a very long document. " * 1000,
84 | metadata={"title": "Large Doc"}
85 | )
86 |
87 | docstore = SimpleDocumentStore()
88 | docstore.add_documents([large_doc])
89 |
90 | extractor = DocumentContextExtractor(
91 | docstore=docstore,
92 | llm=MockLLM(),
93 | max_context_length=100, # Small limit to trigger truncation
94 | max_contextual_tokens=50,
95 | oversized_document_strategy="truncate_first"
96 | )
97 |
98 | node = Node(
99 | text="This is a test chunk.",
100 | metadata={},
101 | source_node=large_doc
102 | )
103 |
104 | # Should not raise an error due to truncation strategy
105 | metadata_list = await extractor.aextract([node])
106 | assert len(metadata_list) == 1
107 |
108 | @pytest.mark.asyncio
109 | async def test_context_extraction_custom_prompt(docstore, mock_llm):
110 | custom_prompt = "Generate a detailed context for this chunk:"
111 | extractor = DocumentContextExtractor(
112 | docstore=docstore,
113 | llm=mock_llm,
114 |         prompt=custom_prompt,
115 | max_context_length=1000,
116 | max_contextual_tokens=100
117 | )
118 |
119 | node = Node(
120 | text="Test chunk",
121 | metadata={},
122 | source_node=next(iter(docstore.docs.values()))
123 | )
124 |
125 | metadata_list = await extractor.aextract([node])
126 | assert len(metadata_list) == 1
127 | assert "context" in metadata_list[0]
128 |
129 | @pytest.mark.asyncio
130 | async def test_multiple_documents_context(context_extractor, sample_documents):
131 | # Create nodes from different documents
132 | nodes = [
133 | Node(
134 | text="This is chapter 1.",
135 | metadata={},
136 | source_node=sample_documents[0]
137 | ),
138 | Node(
139 | text="Chapter 2 builds on previous concepts.",
140 | metadata={},
141 | source_node=sample_documents[1]
142 | )
143 | ]
144 |
145 | metadata_list = await context_extractor.aextract(nodes)
146 | assert len(metadata_list) == 2
147 | for metadata in metadata_list:
148 | assert "context" in metadata
149 |
150 | def test_invalid_oversized_strategy():
151 | with pytest.raises(ValueError):
152 | DocumentContextExtractor(
153 | docstore=SimpleDocumentStore(),
154 | llm=MockLLM(),
155 | max_context_length=1000,
156 | max_contextual_tokens=100,
157 | oversized_document_strategy="invalid_strategy"
158 | )
--------------------------------------------------------------------------------