(patch{},v[16])]'
6 | pathId: /document/v1/pdf_page/pdf_page/docid/https%3A/static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf
7 | id: https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf
8 |
--------------------------------------------------------------------------------
/src/vespa/vespa_config.yaml:
--------------------------------------------------------------------------------
1 | # vespa:
2 | # tenant_name: "cube-digital"
3 | # app_name: "test4"
4 | # schema_name: "pdf_page"
5 | # connections: 1
6 | # timeout: 180
7 | # hits_per_query: 5
8 | # schema:
9 | # max_query_terms: 64
10 | # hnsw_max_links: 32
11 | # hnsw_neighbors: 400
12 | # rerank_count: 10
13 | # tensor_dimensions: 16
14 |
15 |
16 |
17 | vespa:
18 | app_name: "test"
19 | tenant_name: "ml-vanguards"
20 | connections: 1
21 | timeout: 180
22 | hits_per_query: 5
23 | schema_name: "pdf_page"
24 | tensor_dimensions: 16
25 | schema:
26 | max_query_terms: 64
27 | hnsw_max_links: 32
28 | hnsw_neighbors: 400
29 | rerank_count: 10
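
For reference, a minimal sketch of how this YAML maps onto the dataclasses in src/vespa/datatypes.py; it assumes PyYAML is available and the relative path is illustrative:

```python
import yaml

from src.vespa.datatypes import VespaQueryConfig

with open("src/vespa/vespa_config.yaml") as f:
    raw = yaml.safe_load(f)["vespa"]

# from_dict() splits the nested "schema" block out into a VespaSchemaConfig
config = VespaQueryConfig.from_dict(raw)
print(config.app_name, config.schema_config.rerank_count)  # test 10
```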
--------------------------------------------------------------------------------
/src/vespa/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from colpali_engine.models import ColQwen2, ColQwen2Processor
3 |
4 |
5 | def create_model(model_name: str = "vidore/colqwen2-v0.1"):
6 | """
7 | Load a pre-trained ColQwen2 model and processor.
8 |
9 | Args:
10 | model_name: The name of the pre-trained model to load (default: "vidore/colqwen2-v0.1")
11 |
12 | Returns:
13 | A tuple (model, processor) containing the pre-trained model and processor
14 | """
15 | model = ColQwen2.from_pretrained(
16 | model_name, torch_dtype=torch.bfloat16, device_map="auto"
17 | )
18 | processor = ColQwen2Processor.from_pretrained(model_name)
19 | model.eval()
20 | return model, processor
21 |
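
A minimal usage sketch for create_model(); the query text is illustrative and the batch handling assumes the standard colpali_engine processor API:

```python
import torch

from src.vespa.model import create_model

model, processor = create_model()

# Embed a query into per-token multi-vectors (one 128-dim vector per token)
batch = processor.process_queries(["Percentage of non-fresh water as source?"])
batch = {k: v.to(model.device) for k, v in batch.items()}
with torch.no_grad():
    query_embeddings = model(**batch)
print(query_embeddings.shape)  # (batch, num_tokens, 128)
```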
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/datatypes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import Enum
3 | from typing import Any, Dict
4 |
5 |
6 | class EmbeddingType(Enum):
7 | COLPALI = "ColPali"
8 | TRADITIONAL = "Traditional"
9 |
10 |
11 | @dataclass
12 | class AnalysisConfig:
13 | """Configuration parameters for PDF analysis"""
14 |
15 | visual_threshold: int = 15
16 | text_density_threshold: float = 0.25
17 | layout_threshold: int = 100
18 | min_image_size: int = 1000
19 | table_row_threshold: int = 5
20 | table_weight: float = 0.3
21 |
22 |
23 | @dataclass
24 | class AnalysisResult:
25 | """Structured container for analysis results"""
26 |
27 | score: float
28 | details: Dict[str, Any]
29 | confidence: float = 1.0
30 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "ocr-rag-systems"
3 | version = "0.1.0"
4 | description = ""
5 | authors = ["Vesa Alex"]
6 | readme = "README.md"
7 |
8 | [tool.poetry.dependencies]
9 | python = "^3.12"
10 | torch = "^2.5.1"
11 | surya-ocr = "^0.8.1"
12 | numpy = "1.26.4"
13 | pytesseract = "^0.3.13"
14 | easyocr = "^1.7.2"
15 | paddleocr = "^2.9.1"
16 | python-doctr = "^0.10.0"
17 | pdf2image = "^1.17.0"
18 | pandas = "^2.2.3"
19 | ollama = "^0.4.4"
20 | together = "^1.3.10"
21 | colpali-engine = "0.3.1"
22 | vidore-benchmark = "4.0.0"
23 | google-generativeai = "^0.8.3"
24 | pypdf = "5.0.1"
25 | pyvespa = "^0.51.0"
26 | vespacli = "^8.453.24"
27 | requests = "^2.32.3"
28 | ipython = "^8.31.0"
29 | pymupdf = "^1.25.1"
30 |
31 |
32 | [build-system]
33 | requires = ["poetry-core"]
34 | build-backend = "poetry.core.masonry.api"
35 |
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/interfaces.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import fitz
4 |
5 | from src.pdf_embedding_decider.datatypes import AnalysisResult
6 |
7 | logging.basicConfig(level=logging.INFO)
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class PDFAnalyzer:
12 | """Base class for PDF analysis with error handling and resource management"""
13 |
14 | def analyze(self, pdf_path: str) -> AnalysisResult:
15 | try:
16 | doc = fitz.open(pdf_path)
17 | result = self._analyze_document(doc)
18 | return result
19 | except FileNotFoundError:
20 | logger.error(f"PDF file not found: {pdf_path}")
21 | raise
22 | except Exception as e:
23 | logger.error(f"Error analyzing PDF: {str(e)}")
24 | raise
25 | finally:
26 | if "doc" in locals():
27 | doc.close()
28 |
29 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult:
30 | raise NotImplementedError
31 |
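
A minimal sketch of the template-method contract this base class defines: subclasses override only _analyze_document, while analyze() handles opening, error logging, and cleanup. PageCountAnalyzer is hypothetical:

```python
import fitz

from src.pdf_embedding_decider.datatypes import AnalysisResult
from src.pdf_embedding_decider.interfaces import PDFAnalyzer


class PageCountAnalyzer(PDFAnalyzer):
    """Hypothetical analyzer that scores a document by its page count."""

    def _analyze_document(self, doc: fitz.Document) -> AnalysisResult:
        return AnalysisResult(score=len(doc), details={"pages": len(doc)})


result = PageCountAnalyzer().analyze("example.pdf")  # path is illustrative
```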
--------------------------------------------------------------------------------
/src/vespa/utils.py:
--------------------------------------------------------------------------------
1 | from IPython.display import HTML, display
2 |
3 |
4 | def display_query_results(query, response, hits=5):
5 | query_time = response.json.get("timing", {}).get("searchtime", -1)
6 | query_time = round(query_time, 2)
7 | count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
8 | html_content = f"Query text: '{query}', query time {query_time}s, count={count}, top results:
"
9 |
10 | for i, hit in enumerate(response.hits[:hits]):
11 | title = hit["fields"]["title"]
12 | url = hit["fields"]["url"]
13 | page = hit["fields"]["page_number"]
14 | image = hit["fields"]["image"]
15 | score = hit["relevance"]
16 |
17 | html_content += f"PDF Result {i + 1}
"
18 | html_content += f'Title: {title}, page {page+1} with score {score:.2f}
'
19 | html_content += (
20 | f'
'
21 | )
22 |
23 | display(HTML(html_content))
24 |
--------------------------------------------------------------------------------
/src/ocr_benchmark/engines/config.py:
--------------------------------------------------------------------------------
1 | """Configuration classes for OCR engines."""
2 |
3 | from dataclasses import dataclass, field
4 | from typing import List, Optional
5 |
6 |
7 | @dataclass
8 | class TesseractConfig:
9 | """Configuration for Tesseract OCR."""
10 |
11 | tessdata_path: Optional[str] = None
12 | language: str = "eng"
13 | psm: int = 3 # Page segmentation mode
14 | oem: int = 3 # OCR Engine mode
15 |
16 |
17 | @dataclass
18 | class EasyOCRConfig:
19 | """Configuration for EasyOCR."""
20 |
21 | languages: List[str] = ("en",)
22 | gpu: bool = False
23 | model_storage_directory: Optional[str] = None
24 | download_enabled: bool = True
25 |
26 |
27 | @dataclass
28 | class PaddleOCRConfig:
29 | """Configuration for PaddleOCR."""
30 |
31 | use_angle_cls: bool = True
32 | lang: str = "en"
33 | use_gpu: bool = False
34 | show_log: bool = False
35 |
36 |
37 | @dataclass
38 | class DocTRConfig:
39 | """Configuration for DocTR."""
40 |
41 | pretrained: bool = True
42 | assume_straight_pages: bool = True
43 | straighten_pages: bool = True
44 |
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/components/text_density_analyzer.py:
--------------------------------------------------------------------------------
1 | import fitz
2 |
3 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer
4 | from src.pdf_embedding_decider.datatypes import AnalysisConfig, AnalysisResult
5 |
6 |
7 | class TextDensityAnalyzer(PDFAnalyzer):
8 | def __init__(self, config: AnalysisConfig):
9 | self.config = config
10 |
11 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult:
12 | total_density = 0
13 | page_densities = []
14 |
15 | for page_num, page in enumerate(doc):
16 | text_area = self._calculate_text_area(page)
17 | page_area = page.rect.width * page.rect.height
18 | density = text_area / page_area if page_area > 0 else 0
19 |
20 | page_densities.append(
21 | {
22 | "page": page_num + 1,
23 | "density": density,
24 | "text_area": text_area,
25 | "page_area": page_area,
26 | }
27 | )
28 | total_density += density
29 |
30 | avg_density = total_density / len(doc) if len(doc) > 0 else 0
31 | return AnalysisResult(
32 | score=avg_density,
33 | details={"average_density": avg_density, "page_densities": page_densities},
34 | )
35 |
36 | @staticmethod
37 | def _calculate_text_area(page: fitz.Page) -> float:
38 | text_area = 0
39 | for block in page.get_text("blocks"):
40 | x0, y0, x1, y1, *_ = block
41 | text_area += (x1 - x0) * (y1 - y0)
42 | return text_area
43 |
--------------------------------------------------------------------------------
/test_vespa_indexing.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from src.vespa.datatypes import PDFInput
4 | from src.vespa.indexing.pdf_processor import PDFProcessor
5 | from src.vespa.indexing.prepare_feed import VespaFeedPreparator
6 | from src.vespa.indexing.run import run_indexing
7 |
8 |
9 | async def index_documents():
10 | """Example function showing how to use run_indexing."""
11 | config_path = "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/src/vespa/vespa_config.yaml"
12 |
13 | pdfs = [
14 | PDFInput(
15 | title="Building a Resilient Strategy",
16 | url="https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf",
17 | )
18 | ]
19 |
20 | # pdfs = [
21 | # PDFInput(
22 | # title="Building a Resilient Strategy",
23 | # url="https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf",
24 | # )
25 | # ]
26 |
27 | # Explicitly create the processors
28 | pdf_processor = PDFProcessor()
29 | feed_preparator = VespaFeedPreparator()
30 |
31 | try:
32 | await run_indexing(
33 | config_path=config_path,
34 | pdfs=pdfs,
35 | pdf_processor=pdf_processor,
36 | feed_preparator=feed_preparator,
37 | )
38 | except Exception as e:
39 | print(f"Error occurred: {e}")
40 | raise
41 |
42 |
43 | def main():
44 | """Entry point for the application."""
45 | asyncio.run(index_documents())
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/components/layout_analyzer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Any, Dict
3 |
4 | import fitz
5 |
6 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer
7 | from src.pdf_embedding_decider.datatypes import AnalysisResult
8 |
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class LayoutAnalyzer(PDFAnalyzer):
14 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult:
15 | total_complexity = 0
16 | page_layouts = []
17 |
18 | for page_num, page in enumerate(doc):
19 | layout_info = self._analyze_page_layout(page)
20 | complexity = layout_info["block_count"] + len(layout_info["alignments"])
21 | total_complexity += complexity
22 |
23 | page_layouts.append(
24 | {"page": page_num + 1, **layout_info, "complexity_score": complexity}
25 | )
26 |
27 | return AnalysisResult(
28 | score=total_complexity,
29 | details={
30 | "total_complexity": total_complexity,
31 | "page_layouts": page_layouts,
32 | },
33 | )
34 |
35 | def _analyze_page_layout(self, page: fitz.Page) -> Dict[str, Any]:
36 | text_blocks = page.get_text("blocks")
37 | alignments = set()
38 |
39 | for block in text_blocks:
40 | x0, y0, x1, y1, *_ = block
41 | if x0 < page.rect.width * 0.3:
42 | alignments.add("left")
43 | elif x1 > page.rect.width * 0.7:
44 | alignments.add("right")
45 | else:
46 | alignments.add("center")
47 |
48 | return {"block_count": len(text_blocks), "alignments": list(alignments)}
49 |
--------------------------------------------------------------------------------
/src/ocr_benchmark/base.py:
--------------------------------------------------------------------------------
1 | """Base class for OCR engines following the Interface Segregation Principle."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any, Dict, List
5 |
6 | from PIL import Image
7 |
8 |
9 | class OCREngine(ABC):
10 | """Abstract base class for OCR engines."""
11 |
12 | @property
13 | @abstractmethod
14 | def name(self) -> str:
15 | """Return the name of the OCR engine."""
16 | pass
17 |
18 | @abstractmethod
19 | def initialize(self) -> None:
20 | """Initialize the OCR engine with required models and configurations."""
21 | pass
22 |
23 | @abstractmethod
24 | def process_image(self, image: Image.Image) -> Dict[str, Any]:
25 | """
26 | Process a single image and return the extracted text and metadata.
27 |
28 | Args:
29 | image: PIL Image object to process
30 |
31 | Returns:
32 | Dictionary containing:
33 | - text: extracted text
34 | - confidence: confidence scores if available
35 | - boxes: bounding boxes if available
36 | - tables: detected tables if available
37 | """
38 | pass
39 |
40 | @abstractmethod
41 | def process_images(self, images: List[Image.Image]) -> List[Dict[str, Any]]:
42 | """
43 | Process multiple images and return results for each.
44 |
45 | Args:
46 | images: List of PIL Image objects
47 |
48 | Returns:
49 | List of dictionaries containing results for each image
50 | """
51 | pass
52 |
53 | def cleanup(self) -> None:
54 | """Cleanup resources. Override if needed."""
55 | pass
56 |
57 | def __enter__(self):
58 | """Context manager entry."""
59 | self.initialize()
60 | return self
61 |
62 | def __exit__(self, exc_type, exc_val, exc_tb):
63 | """Context manager exit."""
64 | self.cleanup()
65 |
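
Because initialize() and cleanup() are wired into __enter__/__exit__, any concrete engine works as a context manager. A hedged usage sketch (the image path is illustrative):

```python
from PIL import Image

from src.ocr_benchmark.engines.easy_ocr import EasyOCREngine

with EasyOCREngine() as engine:  # initialize() on entry, cleanup() on exit
    result = engine.process_image(Image.open("page.png"))
    print(result["text"], result["confidence"])
```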
--------------------------------------------------------------------------------
/test_vespa_inference.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | from typing import Dict, List
4 |
5 | from src.vespa.datatypes import QueryResult
6 | from src.vespa.retrieval.run import run_queries
7 |
8 | logging.basicConfig(level=logging.INFO)
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | async def main() -> Dict[str, List[QueryResult]]:
13 | # Path to your config file
14 | config_path = (
15 | "/Users/vesaalexandru/Workspaces/cube/cube-publication/"
16 | "evaluate-ocr-rag-systems/src/vespa/vespa_config.yaml"
17 | )
18 |
19 | # Test queries
20 | queries = [
21 | "Percentage of non-fresh water as source?",
22 | ]
23 |
24 | try:
25 | logger.info("Starting query execution")
26 | results = await run_queries(
27 | config_path=config_path, queries=queries, display_results=True
28 | )
29 | logger.info("Query execution completed successfully")
30 | return results
31 |
32 | except Exception as e:
33 | logger.error(f"Query execution failed: {str(e)}")
34 | raise
35 |
36 |
37 | if __name__ == "__main__":
38 | import base64
39 | from io import BytesIO
40 |
41 | import google.generativeai as genai
42 | from PIL import Image
43 | from vidore_benchmark.utils.image_utils import scale_image
44 |
45 | results = asyncio.run(main())
46 |
47 | genai.configure(api_key="YOUR_GEMINI_API_KEY")  # placeholder; never commit real keys
48 |
49 | queries = [
50 | "Percentage of non-fresh water as source?",
51 | # "Policies related to nature risk?",
52 | # "How much of produced water is recycled?",
53 | ]
54 |
55 | best_hit = results["Percentage of non-fresh water as source?"][0]
56 | pdf_url = best_hit.url
57 | pdf_title = best_hit.title
58 | # Optional: rank pages by per-page MaxSim scores when the rank profile
59 | # exports them via matchfeatures, e.g.:
60 | # match_scores = best_hit.source["fields"]["matchfeatures"]["max_sim_per_page"]
61 | # best_page, score = max(match_scores.items(), key=lambda x: x[1])
62 | # Decode the stored base64 page image from the best hit
63 | image_data = base64.b64decode(best_hit.source["fields"]["image"])
64 | image = Image.open(BytesIO(image_data))
65 | scaled_image = scale_image(image, 720)
66 | # display(scaled_image)  # uncomment in a notebook
67 |
68 | model = genai.GenerativeModel(model_name="gemini-1.5-flash")
69 | response = model.generate_content([queries[0], image])
70 | print(response)
71 |
--------------------------------------------------------------------------------
/vespa_inference.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 |
3 | from vespa.deployment import VespaCloud
4 | from vespa.io import VespaQueryResponse
5 |
6 | from vespa_vision_rag import VespaSetup
7 |
8 | cloud = VespaSetup("test")
9 |
10 |
11 | async def query_vespa(queries, vespa_cloud):
12 | """
13 | Query Vespa for each user query.
14 | """
15 | app = vespa_cloud.get_application()
16 |
17 | async with app.asyncio(connections=1, timeout=180) as session:
18 | for query in queries:
19 | response: VespaQueryResponse = await session.query(
20 | yql="select title,url,image,page_number from pdf_page where userInput(@userQuery)",
21 | ranking="default",
22 | userQuery=query,
23 | timeout=120,
24 | hits=5, # Adjust the number of hits returned
25 | body={"presentation.timing": True},
26 | )
27 | if response.is_successful():
28 | display_query_results(query, response)
29 | else:
30 | print(f"Query failed for '{query}': {response.json()}")
31 |
32 |
33 | def display_query_results(query, response):
34 | """
35 | Display query results in a readable format.
36 | """
37 | query_time = response.json.get("timing", {}).get("searchtime", -1)
38 | total_count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
39 | print(f"Query: {query}")
40 | print(f"Query Time: {query_time}s, Total Results: {total_count}")
41 | for idx, hit in enumerate(response.hits):
42 | title = hit["fields"].get("title", "N/A")
43 | url = hit["fields"].get("url", "N/A")
44 | page_number = hit["fields"].get("page_number", "N/A")
45 | print(f" Result {idx + 1}:")
46 | print(f" Title: {title}")
47 | print(f" URL: {url}")
48 | print(f" Page: {page_number}")
49 |
50 |
51 | async def main():
52 | # Define the queries you want to execute
53 | queries = [
54 | "Percentage of non-fresh water as source?",
55 | "Policies related to nature risk?",
56 | "How much of produced water is recycled?",
57 | ]
58 |
59 | # Initialize VespaCloud
60 | vespa_cloud = VespaCloud(
61 | tenant="cube-digital",
62 | application="test",
63 | application_package=cloud.app_package,  # package definition for the deployed app
64 | )
65 |
66 | # Run queries against the Vespa app
67 | await query_vespa(queries, vespa_cloud)
68 |
69 |
70 | if __name__ == "__main__":
71 | asyncio.run(main())
72 |
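
Note that the "default" rank profile's second phase computes max_sim over query(qt), which this script never supplies, so the reranking phase has no query-token tensor to work with. A hedged sketch of also passing ColQwen query embeddings (the {token_idx: [128 floats]} dict shape is an assumption about how the embeddings were prepared):

```python
async def query_with_embeddings(session, query: str, query_embedding: dict):
    """query_embedding: {token_idx: [128 floats]} from the ColQwen model (assumed)."""
    return await session.query(
        yql="select title,url,image,page_number from pdf_page where userInput(@userQuery)",
        ranking="default",
        userQuery=query,
        timeout=120,
        hits=5,
        body={
            "presentation.timing": True,
            "input.query(qt)": query_embedding,
        },
    )
```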
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/components/visual_element_analyzer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import fitz
3 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer
4 | from src.pdf_embedding_decider.datatypes import AnalysisConfig, AnalysisResult
5 |
6 |
7 | class VisualElementAnalyzer(PDFAnalyzer):
8 | def __init__(self, config: AnalysisConfig):
9 | self.config = config
10 |
11 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult:
12 | total_visual_elements = 0
13 | image_details = []
14 | drawing_details = []
15 |
16 | for page_num, page in enumerate(doc):
17 | # Analyze images
18 | images = page.get_images(full=True)
19 | filtered_images = [
20 | img for img in images if self._is_significant_image(page, img)
21 | ]
22 | image_details.extend(
23 | [
24 | {"page": page_num + 1, "size": self._get_image_size(page, img)}
25 | for img in filtered_images
26 | ]
27 | )
28 |
29 | # Analyze vector graphics
30 | drawings = page.get_drawings()
31 | significant_drawings = self._filter_significant_drawings(drawings)
32 | drawing_details.extend(
33 | [
34 | {"page": page_num + 1, "complexity": len(draw["items"])}
35 | for draw in significant_drawings
36 | ]
37 | )
38 |
39 | total_visual_elements += len(filtered_images) + len(significant_drawings)
40 |
41 | return AnalysisResult(
42 | score=total_visual_elements,
43 | details={
44 | "total_elements": total_visual_elements,
45 | "images": image_details,
46 | "drawings": drawing_details,
47 | },
48 | )
49 |
50 | def _is_significant_image(self, page: fitz.Page, image: tuple) -> bool:
51 | """Filter out small or insignificant images"""
52 | xref = image[0]
53 | pix = fitz.Pixmap(page.parent, xref)
54 | area = pix.width * pix.height
55 | return area >= self.config.min_image_size
56 |
57 | def _filter_significant_drawings(self, drawings: List[dict]) -> List[dict]:
58 | """Filter out simple decorative elements"""
59 | return [d for d in drawings if len(d["items"]) > 2]
60 |
61 | @staticmethod
62 | def _get_image_size(page: fitz.Page, image: tuple) -> dict:
63 | xref = image[0]
64 | pix = fitz.Pixmap(page.parent, xref)
65 | return {"width": pix.width, "height": pix.height}
66 |
--------------------------------------------------------------------------------
/test2/schemas/pdf_page.sd:
--------------------------------------------------------------------------------
1 | schema pdf_page {
2 | document pdf_page {
3 | field id type string {
4 | indexing: summary | index
5 | match {
6 | word
7 | }
8 | }
9 | field url type string {
10 | indexing: summary | index
11 | }
12 | field title type string {
13 | indexing: summary | index
14 | index: enable-bm25
15 | match {
16 | text
17 | }
18 | }
19 | field page_number type int {
20 | indexing: summary | attribute
21 | }
22 | field image type raw {
23 | indexing: summary
24 | }
25 | field text type string {
26 | indexing: index
27 | index: enable-bm25
28 | match {
29 | text
30 | }
31 | }
32 | field embedding type tensor(patch{}, v[16]) {
33 | indexing: attribute | index
34 | attribute {
35 | distance-metric: hamming
36 | }
37 | index {
38 | hnsw {
39 | max-links-per-node: 32
40 | neighbors-to-explore-at-insert: 400
41 | }
42 | }
43 | }
44 | }
45 | fieldset default {
46 | fields: title, text
47 | }
48 | rank-profile default {
49 | inputs {
50 | query(qt) tensor(querytoken{}, v[128])
51 |
52 | }
53 | function max_sim() {
54 | expression {
55 |
56 | sum(
57 | reduce(
58 | sum(
59 | query(qt) * unpack_bits(attribute(embedding)) , v
60 | ),
61 | max, patch
62 | ),
63 | querytoken
64 | )
65 |
66 | }
67 | }
68 | function bm25_score() {
69 | expression {
70 | bm25(title) + bm25(text)
71 | }
72 | }
73 | first-phase {
74 | expression {
75 | bm25_score
76 | }
77 | }
78 | second-phase {
79 | rerank-count: 100
80 | expression {
81 | max_sim
82 | }
83 | }
84 | }
85 | }
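
In plain terms, max_sim is ColPali-style late interaction: for each query token, take the best dot product over all page patches, then sum those per-token maxima. A NumPy sketch of the same reduction (dimensions illustrative; in the schema the stored vectors are additionally bit-packed, which unpack_bits reverses):

```python
import numpy as np

q = np.random.randn(8, 128)       # query(qt): querytoken x v
p = np.random.randn(700, 128)     # unpacked page embedding: patch x v

sims = q @ p.T                    # inner sum over v for every (querytoken, patch)
max_sim = sims.max(axis=1).sum()  # reduce max over patch, then sum over querytoken
```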
--------------------------------------------------------------------------------
/colpal/schemas/pdf_page.sd:
--------------------------------------------------------------------------------
1 | schema pdf_page {
2 | document pdf_page {
3 | field id type string {
4 | indexing: summary | index
5 | match {
6 | word
7 | }
8 | }
9 | field url type string {
10 | indexing: summary | index
11 | }
12 | field title type string {
13 | indexing: summary | index
14 | index: enable-bm25
15 | match {
16 | text
17 | }
18 | }
19 | field page_number type int {
20 | indexing: summary | attribute
21 | }
22 | field image type raw {
23 | indexing: summary
24 | }
25 | field text type string {
26 | indexing: index
27 | index: enable-bm25
28 | match {
29 | text
30 | }
31 | }
32 | field embedding type tensor(patch{}, v[16]) {
33 | indexing: attribute | index
34 | attribute {
35 | distance-metric: hamming
36 | }
37 | index {
38 | hnsw {
39 | max-links-per-node: 32
40 | neighbors-to-explore-at-insert: 400
41 | }
42 | }
43 | }
44 | }
45 | fieldset default {
46 | fields: title, text
47 | }
48 | rank-profile default {
49 | inputs {
50 | query(qt) tensor(querytoken{}, v[128])
51 |
52 | }
53 | function max_sim() {
54 | expression {
55 |
56 | sum(
57 | reduce(
58 | sum(
59 | query(qt) * unpack_bits(attribute(embedding)) , v
60 | ),
61 | max, patch
62 | ),
63 | querytoken
64 | )
65 |
66 | }
67 | }
68 | function bm25_score() {
69 | expression {
70 | bm25(title) + bm25(text)
71 | }
72 | }
73 | first-phase {
74 | expression {
75 | bm25_score
76 | }
77 | }
78 | second-phase {
79 | rerank-count: 100
80 | expression {
81 | max_sim
82 | }
83 | }
84 | }
85 | }
--------------------------------------------------------------------------------
/src/vespa/datatypes.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from typing import Any, Dict, List, Optional
3 |
4 | from torch import Tensor
5 |
6 |
7 | @dataclass
8 | class PDFInput:
9 | """Data class for PDF input information."""
10 |
11 | title: str
12 | url: str
13 |
14 |
15 | @dataclass
16 | class PDFData:
17 | """Data class to store processed PDF information."""
18 |
19 | url: str
20 | title: str
21 | images: List[Any] # PIL.Image type
22 | texts: List[str]
23 | embeddings: List[Tensor]
24 |
25 |
26 | @dataclass
27 | class PDFPage:
28 | id: str
29 | url: str
30 | title: str
31 | page_number: int
32 | image: str
33 | text: str
34 | embedding: Dict[int, str]
35 |
36 |
37 | @dataclass
38 | class VespaSchemaConfig:
39 | """Configuration for Vespa schema settings."""
40 |
41 | max_query_terms: int = 64
42 | hnsw_max_links: int = 32
43 | hnsw_neighbors: int = 400
44 | rerank_count: int = 10
45 | tensor_dimensions: int = 16
46 |
47 | @classmethod
48 | def from_dict(cls, config_dict: Dict[str, Any]) -> "VespaSchemaConfig":
49 | """Create schema config from dictionary with default values."""
50 | return cls(
51 | max_query_terms=config_dict.get("max_query_terms", 64),
52 | hnsw_max_links=config_dict.get("hnsw_max_links", 32),
53 | hnsw_neighbors=config_dict.get("hnsw_neighbors", 400),
54 | rerank_count=config_dict.get("rerank_count", 10),
55 | tensor_dimensions=config_dict.get("tensor_dimensions", 16),
56 | )
57 |
58 |
59 | @dataclass
60 | class VespaConfig:
61 | """Configuration for Vespa deployment."""
62 |
63 | app_name: str
64 | tenant_name: str
65 | connections: int = 1
66 | timeout: int = 180
67 | schema_name: str = "pdf_page"
68 | schema_config: Optional[VespaSchemaConfig] = None
69 |
70 | @classmethod
71 | def from_dict(cls, config_dict: Dict[str, Any]) -> "VespaConfig":
72 | """Create VespaConfig from dictionary, handling schema config separately."""
73 | # Extract and convert schema config if present
74 | schema_dict = config_dict.pop("schema", {})
75 | schema_config = (
76 | VespaSchemaConfig.from_dict(schema_dict) if schema_dict else None
77 | )
78 |
79 | return cls(**config_dict, schema_config=schema_config)
80 |
81 |
82 | @dataclass
83 | class VespaQueryConfig:
84 | app_name: str
85 | tenant_name: str
86 | connections: int = 1
87 | timeout: int = 180
88 | hits_per_query: int = 5
89 | schema_name: str = "pdf_page"
90 | tensor_dimensions: int = 16 # Added this field
91 | schema_config: Optional[VespaSchemaConfig] = None
92 |
93 | @classmethod
94 | def from_dict(cls, config_dict: Dict[str, Any]) -> "VespaQueryConfig":
95 | schema_dict = config_dict.pop("schema", {})
96 | schema_config = (
97 | VespaSchemaConfig.from_dict(schema_dict) if schema_dict else None
98 | )
99 | return cls(
100 | **config_dict,  # "schema" was already popped above
101 | schema_config=schema_config,
102 | )
103 |
104 |
105 | @dataclass
106 | class QueryResult:
107 | """Data class for query results."""
108 |
109 | title: str
110 | url: str
111 | page_number: int
112 | relevance: float
113 | text: str
114 | source: Dict[str, Any]
115 |
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/analyze.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict
3 |
4 | from src.pdf_embedding_decider.components.layout_analyzer import LayoutAnalyzer
5 | from src.pdf_embedding_decider.components.table_detector import TableDetector
6 | from src.pdf_embedding_decider.components.text_density_analyzer import (
7 | TextDensityAnalyzer,
8 | )
9 | from src.pdf_embedding_decider.components.visual_element_analyzer import (
10 | VisualElementAnalyzer,
11 | )
12 | from src.pdf_embedding_decider.datatypes import (
13 | AnalysisConfig,
14 | AnalysisResult,
15 | EmbeddingType,
16 | )
17 |
18 | # Configure logging
19 | logging.basicConfig(level=logging.INFO)
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class PDFEmbeddingDecider:
24 | def __init__(self, config: AnalysisConfig):
25 | self.config = config
26 | self.analyzers = {
27 | "visual": VisualElementAnalyzer(config),
28 | "density": TextDensityAnalyzer(config),
29 | "layout": LayoutAnalyzer(),
30 | "table": TableDetector(config),
31 | }
32 |
33 | def analyze(self, pdf_path: str) -> Dict[str, AnalysisResult]:
34 | """Run all analyses and return detailed results"""
35 | return {
36 | name: analyzer.analyze(pdf_path)
37 | for name, analyzer in self.analyzers.items()
38 | }
39 |
40 | def decide(self, pdf_path: str) -> EmbeddingType:
41 | """Determine the appropriate embedding type based on PDF analysis"""
42 | try:
43 | results = self.analyze(pdf_path)
44 |
45 | # Log detailed analysis results
46 | logger.info("Analysis Results:")
47 | for analyzer_name, result in results.items():
48 | logger.info(f"{analyzer_name}: {result}")
49 |
50 | # Enhanced decision logic incorporating table analysis
51 | table_score = results["table"].score
52 |
53 | if table_score > 0:
54 | table_influence = min(1.0, table_score * self.config.table_weight)
55 | adjusted_density_threshold = self.config.text_density_threshold * (
56 | 1 - table_influence
57 | )
58 | else:
59 | adjusted_density_threshold = self.config.text_density_threshold
60 |
61 | if (
62 | results["visual"].score > self.config.visual_threshold
63 | or results["density"].score < adjusted_density_threshold
64 | or (
65 | results["layout"].score > self.config.layout_threshold
66 | and table_score == 0
67 | )
68 | ): # Only consider complex layout if not tabular
69 | return EmbeddingType.COLPALI
70 |
71 | return EmbeddingType.TRADITIONAL
72 |
73 | except Exception as e:
74 | logger.error(f"Error deciding embedding type: {str(e)}")
75 | raise
76 |
77 |
78 | if __name__ == "__main__":
79 | config = AnalysisConfig(
80 | visual_threshold=15,
81 | text_density_threshold=0.25,
82 | layout_threshold=100,
83 | min_image_size=1000,
84 | table_row_threshold=5,
85 | table_weight=0.3,
86 | )
87 |
88 | # Initialize decider
89 | decider = PDFEmbeddingDecider(config)
90 |
91 | # Example usage
92 | pdf_path = "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/data/aiminded-extras-octomrbie-decembrie-2023.pdf"
93 | # pdf_path = "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/data/paper01-1-2.pdf"
94 | try:
95 | # Get detailed analysis
96 | analysis_results = decider.analyze(pdf_path)
97 |
98 | # Get final decision
99 | embedding_type = decider.decide(pdf_path)
100 |
101 | logger.info(f"Recommended embedding type: {embedding_type.value}")
102 | logger.info("Detailed analysis results:")
103 | for analyzer_name, result in analysis_results.items():
104 | logger.info(f"{analyzer_name}: {result}")
105 |
106 | except Exception as e:
107 | logger.error(f"Error processing PDF: {str(e)}")
108 |
--------------------------------------------------------------------------------
/src/vespa/indexing/prepare_feed.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 | import logging
4 | from io import BytesIO
5 | from typing import Dict, List, Optional
6 |
7 | import numpy as np
8 | import numpy.typing as npt
9 | from PIL import Image
10 | from torch import Tensor
11 |
12 | from src.vespa.indexing.pdf_processor import PDFData
13 |
14 | logging.basicConfig(level=logging.DEBUG)
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class VespaFeedPreparator:
19 | def __init__(self, max_image_height: int = 800) -> None:
20 | self.max_image_height = max_image_height
21 |
22 | def prepare_feed(self, pdf_data: List[PDFData]) -> List[Dict]:
23 | try:
24 | vespa_feed = []
25 | for pdf in pdf_data:
26 | logger.info(f"Processing PDF: {pdf.title}")
27 | for page_number, (text, embedding, image) in enumerate(
28 | zip(pdf.texts, pdf.embeddings, pdf.images)
29 | ):
30 | page_id = self._generate_page_id(pdf.url, page_number)
31 | processed_image = self._process_image(image)
32 | embedding_dict = self._process_embeddings(embedding)
33 |
34 | doc = {
35 | "fields": {
36 | "id": page_id,
37 | "url": pdf.url,
38 | "title": pdf.title,
39 | "page_number": page_number,
40 | "image": processed_image,
41 | "text": text,
42 | "embedding": {
43 | "blocks": self._convert_to_patch_blocks(embedding_dict)
44 | },
45 | }
46 | }
47 | vespa_feed.append(doc)
48 | return vespa_feed
49 | except Exception as e:
50 | logger.error(f"Failed to prepare feed: {str(e)}")
51 | raise
52 |
53 | def _convert_to_patch_blocks(self, embedding_dict: Dict[int, str]) -> List[Dict]:
54 | return [
55 | {"address": {"patch": patch_idx}, "values": vector}
56 | for patch_idx, vector in embedding_dict.items()
57 | if vector != "0" * 32
58 | ]
59 |
60 | def _generate_page_id(self, url: str, page_number: int) -> str:
61 | content = f"{url}{page_number}".encode("utf-8")
62 | return hashlib.sha256(content).hexdigest()
63 |
64 | def _process_embeddings(self, embedding: Tensor) -> Dict[int, str]:
65 | embedding_dict = {}
66 | embedding_float = embedding.detach().cpu().float()
67 | embedding_np = embedding_float.numpy()
68 | for idx, patch_embedding in enumerate(embedding_np):
69 | binary_vector = self._convert_to_binary_vector(patch_embedding)
70 | embedding_dict[idx] = binary_vector
71 | return embedding_dict
72 |
73 | @staticmethod
74 | def _convert_to_binary_vector(patch_embedding: npt.NDArray[np.float32]) -> str:
75 | binary = np.packbits(np.where(patch_embedding > 0, 1, 0))
76 | return binary.astype(np.int8).tobytes().hex()
77 |
78 | def _process_image(self, image: Image.Image) -> str:
79 | resized_image = self._resize_image(image)
80 | return self._encode_image_base64(resized_image)
81 |
82 | def _resize_image(
83 | self, image: Image.Image, target_width: Optional[int] = 640
84 | ) -> Image.Image:
85 | width, height = image.size
86 | if height > self.max_image_height:
87 | ratio = self.max_image_height / height
88 | new_width = int(width * ratio)
89 | return image.resize((new_width, self.max_image_height), Image.LANCZOS)
90 | if target_width and width > target_width:
91 | ratio = target_width / width
92 | new_height = int(height * ratio)
93 | return image.resize((target_width, new_height), Image.LANCZOS)
94 | return image
95 |
96 | @staticmethod
97 | def _encode_image_base64(image: Image.Image) -> str:
98 | buffered = BytesIO()
99 | image.save(buffered, format="JPEG", quality=85, optimize=True)
100 | return base64.b64encode(buffered.getvalue()).decode("utf-8")
101 |
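
A worked sketch of the binarization in _convert_to_binary_vector: each 128-dim float patch vector becomes a 128-bit sign pattern packed into 16 bytes, i.e. the 32-hex-character strings that feed the schema's tensor(patch{}, v[16]) field with hamming distance (all-zero patterns are then dropped by _convert_to_patch_blocks). The input vector here is random:

```python
import numpy as np

patch = np.random.randn(128).astype(np.float32)
bits = np.where(patch > 0, 1, 0)  # sign bit per dimension
packed = np.packbits(bits)        # 128 bits -> 16 uint8 values
hex_vector = packed.astype(np.int8).tobytes().hex()
assert len(hex_vector) == 32      # 16 bytes -> 32 hex characters
```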
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | /data
140 |
141 | # Spyder project settings
142 | .spyderproject
143 | .spyproject
144 |
145 | # Rope project settings
146 | .ropeproject
147 |
148 | # mkdocs documentation
149 | /site
150 |
151 | # mypy
152 | .mypy_cache/
153 | .dmypy.json
154 | dmypy.json
155 |
156 | # Pyre type checker
157 | .pyre/
158 |
159 | # pytype static type analyzer
160 | .pytype/
161 |
162 | # Cython debug symbols
163 | cython_debug/
164 |
165 | # PyCharm
166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
168 | # and can be added to the global gitignore or merged into this file. For a more nuclear
169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
170 |
171 | # IntelliJ IDEA
172 | .idea/
173 |
174 | # Visual Studio Code
175 | .vscode/
176 |
177 | # PyPI configuration file
178 | .pypirc
--------------------------------------------------------------------------------
/src/pdf_embedding_decider/components/table_detector.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | import fitz
4 |
5 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer
6 | from src.pdf_embedding_decider.datatypes import AnalysisConfig, AnalysisResult
7 |
8 |
9 | class TableDetector(PDFAnalyzer):
10 | """Analyzes the presence and structure of tables in the document"""
11 |
12 | def __init__(self, config: AnalysisConfig):
13 | self.config = config
14 |
15 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult:
16 | total_tables = 0
17 | table_details = []
18 |
19 | for page_num, page in enumerate(doc):
20 | tables = self._detect_tables(page)
21 | total_tables += len(tables)
22 | table_details.extend(
23 | [
24 | {
25 | "page": page_num + 1,
26 | "rows": table["rows"],
27 | "columns": table["columns"],
28 | "area": table["area"],
29 | }
30 | for table in tables
31 | ]
32 | )
33 |
34 | return AnalysisResult(
35 | score=total_tables,
36 | details={"total_tables": total_tables, "table_details": table_details},
37 | )
38 |
39 | def _detect_tables(self, page: fitz.Page) -> List[Dict[str, Any]]:
40 | """Detect tables based on text block alignment and spacing"""
41 | blocks = page.get_text("blocks")
42 | tables = []
43 |
44 | # Sort blocks by vertical position
45 | sorted_blocks = sorted(blocks, key=lambda b: (b[1], b[0])) # Sort by y, then x
46 |
47 | current_table = {"rows": [], "y_positions": set()}
48 |
49 | for block in sorted_blocks:
50 | x0, y0, x1, y1, text, *_ = block
51 |
52 | # Check if block is part of current table structure
53 | if current_table["rows"]:
54 | last_y = max(current_table["y_positions"])
55 | y_gap = y0 - last_y
56 |
57 | if y_gap > 20: # New table or not part of table
58 | if len(current_table["rows"]) >= self.config.table_row_threshold:
59 | tables.append(self._finalize_table(current_table))
60 | current_table = {"rows": [], "y_positions": set()}
61 |
62 | current_table["rows"].append({"text": text, "bbox": (x0, y0, x1, y1)})
63 | current_table["y_positions"].add(y0)
64 |
65 | # Check last table
66 | if len(current_table["rows"]) >= self.config.table_row_threshold:
67 | tables.append(self._finalize_table(current_table))
68 |
69 | return tables
70 |
71 | def _finalize_table(self, table_data: Dict) -> Dict[str, Any]:
72 | """Calculate table metrics"""
73 | rows = table_data["rows"]
74 | x_positions = []
75 | for row in rows:
76 | x_positions.extend([row["bbox"][0], row["bbox"][2]])
77 |
78 | # Estimate columns by analyzing x-position clusters
79 | x_clusters = self._cluster_positions(x_positions)
80 |
81 | return {
82 | "rows": len(rows),
83 | "columns": len(x_clusters)
84 | // 2, # Divide by 2 as we counted start/end positions
85 | "area": self._calculate_table_area(rows),
86 | }
87 |
88 | @staticmethod
89 | def _cluster_positions(
90 | positions: List[float], threshold: float = 10
91 | ) -> List[float]:
92 | """Cluster similar x-positions together"""
93 | positions = sorted(positions)
94 | clusters = []
95 | current_cluster = [positions[0]]
96 |
97 | for pos in positions[1:]:
98 | if pos - current_cluster[-1] <= threshold:
99 | current_cluster.append(pos)
100 | else:
101 | clusters.append(sum(current_cluster) / len(current_cluster))
102 | current_cluster = [pos]
103 |
104 | if current_cluster:
105 | clusters.append(sum(current_cluster) / len(current_cluster))
106 |
107 | return clusters
108 |
109 | @staticmethod
110 | def _calculate_table_area(rows: List[Dict]) -> float:
111 | """Calculate the total area occupied by the table"""
112 | if not rows:
113 | return 0
114 |
115 | x_min = min(row["bbox"][0] for row in rows)
116 | x_max = max(row["bbox"][2] for row in rows)
117 | y_min = min(row["bbox"][1] for row in rows)
118 | y_max = max(row["bbox"][3] for row in rows)
119 |
120 | return (x_max - x_min) * (y_max - y_min)
121 |
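
A worked example of the column estimate: _cluster_positions collapses nearby x-coordinates into cluster averages, and because every row contributed both a start and an end x-position, the cluster count is halved (the coordinates below are illustrative):

```python
from src.pdf_embedding_decider.components.table_detector import TableDetector

# Left/right edges of text blocks laid out in two columns
positions = [50.0, 51.0, 210.0, 212.0, 300.0, 302.0, 460.0, 461.0]

# With threshold=10 this clusters to [50.5, 211.0, 301.0, 460.5]
clusters = TableDetector._cluster_positions(positions, threshold=10)
print(len(clusters) // 2)  # -> 2 columns
```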
--------------------------------------------------------------------------------
/src/prompts/llama_vision.py:
--------------------------------------------------------------------------------
1 | from src.interfaces.interfaces import BasePromptTemplate
2 |
3 |
4 | class AnalyzePdfImage(BasePromptTemplate):
5 | prompt: str = """You are an OCR assistant specialized in analyzing PDF images. Your tasks are:
6 |
7 | 1. **Full Text Recognition:** Extract all visible text from the page, including equations, special characters, and inline formatting. Ensure accuracy in representing symbols, numbers, and technical terminology.
8 | 2. **Structure and Formatting Preservation:** Recreate the original structure and formatting of the text, including:
9 | - Headings and subheadings (use appropriate markdown syntax for headers).
10 | - Paragraphs (preserve line breaks and text alignment).
11 | - Lists (use bullet points or numbered lists as necessary).
12 | - Inline formatting (e.g., bold, italics, subscript, superscript).
13 | 3. **Table Extraction:** Identify and extract any tables from the page, preserving their structure and content. Include:
14 | - Table headers, rows, and columns.
15 | - Ensure proper alignment and representation using markdown table formatting.
16 | 4. **Figure and Graph Recognition:** Detect any figures, graphs, or charts on the page and provide the following:
17 | - A detailed description of each figure or graph.
18 | - Key elements such as titles, axes labels, data points, and trends.
19 | - Include a markdown-formatted list with clear annotations.
20 | 5. **Metadata Extraction:** Extract and include metadata such as:
21 | - Document title, authors, publication year, and source information (if available).
22 | - Licensing and attribution details.
23 | 6. **Additional Context:** Capture any footnotes, captions, or side notes present on the page. Include these in a separate "Notes" section to preserve supplementary information.
24 |
25 | **Output Format:**
26 | Return the extracted content in the following markdown format:
27 |
28 | - Use `#`, `##`, or `###` for headings to match the original hierarchy.
29 | - Represent lists with `-` or `1.` for bullet points and numbered lists.
30 | - Use markdown table formatting to represent tables, ensuring alignment and clarity.
31 | - Provide descriptions for figures and graphs as a markdown-formatted list.
32 | - Add a "Key Insights" section summarizing critical findings or observations.
33 | - Include a "Notes" section for supplementary information such as footnotes or captions.
34 |
35 | **Example Output:**
36 | ```
37 | # Document Title
38 |
39 | ## Metadata
40 | - **Title:** A Survey on Image Data Augmentation for Deep Learning
41 | - **Authors:** Connor Shorten and Taghi M. Khoshgoftaar
42 | - **Year:** 2019
43 | - **License:** Creative Commons Attribution 4.0 International License
44 |
45 | ## Key Insights
46 | - Data Augmentation improves model robustness and generalization, addressing overfitting in limited datasets.
47 | - Techniques include geometric transformations, GAN-based augmentations, and meta-learning.
48 |
49 | ## Abstract
50 | Deep convolutional neural networks have performed remarkably well on many Computer Vision tasks. However, these networks rely heavily on big data to avoid overfitting. Data Augmentation encompasses techniques such as geometric transformations, color space augmentations, and adversarial training. This paper outlines promising developments in these areas and their impact on deep learning model performance.
51 |
52 | ## Introduction
53 | Deep learning models excel in discriminative tasks like image classification and segmentation, leveraging architectures like AlexNet and ResNet. This paper focuses on augmenting data to expand training datasets, improving performance and generalization.
54 |
55 | ### Methodology
56 | - **Geometric Transformations:** Rotation, scaling, translation.
57 | - **Color Space Augmentation:** RGB to HSV conversion, random jittering.
58 | - **Kernel Filters:** Gaussian blur, median filter.
59 | - **Mixing Images:** Random erasing, GAN-generated augmentations.
60 |
61 | ### Results
62 | The proposed framework was evaluated on datasets like CIFAR-10 and SVHN, outperforming state-of-the-art augmentation techniques in accuracy and robustness.
63 |
64 | ### Graph Descriptions
65 | 1. *Figure 1:* Validation vs. training error graph illustrating overfitting (left) and desired generalization (right). The x-axis represents training epochs, and the y-axis represents error rates.
66 |
67 | ## Notes
68 | - *Figure 1:* Caption reads "Validation vs. training error over epochs."
69 | - All data was sourced from publicly available datasets.
70 | ```
71 |
72 | **Instructions:**
73 | - Ensure high fidelity in text and structural representation.
74 | - Capture all elements on the page, including supplementary information.
75 | - Use markdown for clarity and consistency in the output.
76 |
77 | """
78 |
79 | def create_template(
80 | self,
81 | ) -> str:
82 | return self.prompt.format()
83 |
--------------------------------------------------------------------------------
/src/ocr_benchmark/engines/easy_ocr.py:
--------------------------------------------------------------------------------
1 | """EasyOCR engine implementation."""
2 |
3 | from typing import Any, Dict, List, Optional
4 |
5 | import easyocr
6 | import numpy as np
7 | from PIL import Image
8 |
9 | from src.ocr_benchmark.base import OCREngine
10 | from src.ocr_benchmark.engines.config import EasyOCRConfig
11 | from src.ocr_benchmark.utils.image_processing import ensure_rgb
12 |
13 |
14 | class EasyOCREngine(OCREngine):
15 | """EasyOCR engine implementation."""
16 |
17 | def __init__(self, config: Optional[EasyOCRConfig] = None):
18 | """
19 | Initialize EasyOCREngine.
20 |
21 | Args:
22 | config: Optional configuration for EasyOCR
23 | """
24 | self._config = config or EasyOCRConfig()
25 | self._initialized = False
26 | self._reader = None
27 |
28 | @property
29 | def name(self) -> str:
30 | return "EasyOCR"
31 |
32 | def initialize(self) -> None:
33 | """Initialize EasyOCR with configuration."""
34 | if not self._initialized:
35 | self._reader = easyocr.Reader(
36 | lang_list=self._config.languages,
37 | gpu=self._config.gpu,
38 | model_storage_directory=self._config.model_storage_directory,
39 | download_enabled=self._config.download_enabled,
40 | )
41 | self._initialized = True
42 |
43 | def process_image(self, image: Image.Image) -> Dict[str, Any]:
44 | """
45 | Process a single image using EasyOCR.
46 |
47 | Args:
48 | image: PIL Image to process
49 |
50 | Returns:
51 | Dictionary containing:
52 | - text: extracted text
53 | - confidence: mean confidence score
54 | - boxes: detected text boxes
55 | """
56 | if not self._initialized:
57 | self.initialize()
58 |
59 | # Ensure image is in RGB format
60 | image = ensure_rgb(image)
61 |
62 | # Convert PIL Image to numpy array
63 | image_np = np.array(image)
64 |
65 | # Process image with EasyOCR
66 | results = self._reader.readtext(image_np)
67 |
68 | # Extract text and confidence scores
69 | boxes = []
70 | full_text = []
71 | confidences = []
72 |
73 | for bbox, text, conf in results:
74 | full_text.append(text)
75 | confidences.append(conf)
76 |
77 | # Convert bbox points to (x1, y1, x2, y2) format
78 | x1 = min(point[0] for point in bbox)
79 | y1 = min(point[1] for point in bbox)
80 | x2 = max(point[0] for point in bbox)
81 | y2 = max(point[1] for point in bbox)
82 |
83 | boxes.append(
84 | {
85 | "text": text,
86 | "conf": conf,
87 | "box": (int(x1), int(y1), int(x2), int(y2)),
88 | }
89 | )
90 |
91 | # Calculate mean confidence
92 | mean_confidence = sum(confidences) / len(confidences) if confidences else 0
93 |
94 | return {
95 | "text": " ".join(full_text),
96 | "confidence": mean_confidence,
97 | "boxes": boxes,
98 | }
99 |
100 | def process_images(self, images: List[Image.Image]) -> List[Dict[str, Any]]:
101 | """Process multiple images and return results for each."""
102 | return [self.process_image(image) for image in images]
103 |
104 |
105 | if __name__ == "__main__":
106 | import os
107 |
108 | from PIL import Image, ImageDraw, ImageFont
109 |
110 | # Create a test image
111 | img = Image.new("RGB", (800, 200), color="white")
112 | d = ImageDraw.Draw(img)
113 | try:
114 | font = ImageFont.truetype("Arial.ttf", 60)
115 | except OSError:  # Arial not available; fall back to the default font
116 | font = ImageFont.load_default()
117 |
118 | d.text((50, 50), "Hello, EasyOCR Test!", fill="black", font=font)
119 |
120 | # Save test image
121 | test_image_path = "test_easyocr.png"
122 | img.save(test_image_path)
123 |
124 | try:
125 | # Initialize engine
126 | engine = EasyOCREngine(EasyOCRConfig(languages=["en"]))
127 |
128 | # Process test image
129 | with Image.open(test_image_path) as test_img:
130 | result = engine.process_image(test_img)
131 |
132 | # Print results
133 | print("\nEasyOCR Test Results:")
134 | print(f"Detected Text: {result['text']}")
135 | print(f"Confidence: {result['confidence']:.2f}")
136 | print("Detected Boxes:", len(result["boxes"]))
137 |
138 | for box in result["boxes"]:
139 | print(f"- Text: {box['text']}, Confidence: {box['conf']:.2f}")
140 | print(f" Box coordinates: {box['box']}")
141 |
142 | except Exception as e:
143 | print(f"Error during test: {str(e)}")
144 |
145 | finally:
146 | # Cleanup
147 | if os.path.exists(test_image_path):
148 | os.remove(test_image_path)
149 |
--------------------------------------------------------------------------------
/src/ocr_benchmark/utils/pdf_processing.py:
--------------------------------------------------------------------------------
1 | """Utility functions for PDF processing and conversion."""
2 |
3 | import logging
4 | import tempfile
5 | from dataclasses import dataclass
6 | from pathlib import Path
7 | from typing import List, Optional, Tuple
8 |
9 | import pdf2image
10 | from pdf2image.exceptions import (
11 | PDFPageCountError,
12 | PDFSyntaxError,
13 | )
14 | from PIL import Image
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | @dataclass
20 | class PDFConversionSettings:
21 | """Settings for PDF to image conversion."""
22 |
23 | dpi: int = 300
24 | grayscale: bool = False
25 | use_cropbox: bool = False
26 | strict: bool = False
27 | thread_count: int = 4
28 | raise_on_error: bool = True
29 |
30 |
31 | class PDFProcessingError(Exception):
32 | """Base exception for PDF processing errors."""
33 |
34 | pass
35 |
36 |
37 | def validate_pdf(pdf_path: Path) -> Tuple[bool, str]:
38 | """
39 | Validate PDF file existence and readability.
40 |
41 | Args:
42 | pdf_path: Path to PDF file
43 |
44 | Returns:
45 | Tuple of (is_valid, message)
46 | """
47 | try:
48 | if not pdf_path.exists():
49 | return False, f"File not found: {pdf_path}"
50 |
51 | if pdf_path.stat().st_size == 0:
52 | return False, f"File is empty: {pdf_path}"
53 |
54 | # Try converting first page to validate format
55 | pdf2image.convert_from_path(str(pdf_path), first_page=1, last_page=1)
56 | return True, "Valid PDF file"
57 |
58 | except PDFSyntaxError:
59 | return False, f"Invalid PDF format or corrupted file: {pdf_path}"
60 | except PDFPageCountError:
61 | return False, f"Error determining page count: {pdf_path}"
62 | except Exception as e:
63 | return False, f"Error validating PDF: {str(e)}"
64 |
65 |
66 | def pdf_to_images(
67 | pdf_path: Path, settings: Optional[PDFConversionSettings] = None
68 | ) -> List[Image.Image]:
69 | """
70 | Convert PDF file to a list of PIL Images.
71 |
72 | Args:
73 | pdf_path: Path to PDF file
74 | settings: Optional conversion settings
75 |
76 | Returns:
77 | List of PIL Images, one per page
78 |
79 | Raises:
80 | PDFProcessingError: If conversion fails and raise_on_error is True
81 | """
82 | if settings is None:
83 | settings = PDFConversionSettings()
84 |
85 | try:
86 | # Validate PDF first
87 | is_valid, message = validate_pdf(pdf_path)
88 | if not is_valid and settings.raise_on_error:
89 | raise PDFProcessingError(message)
90 |
91 | # Create temporary directory for conversion
92 | with tempfile.TemporaryDirectory() as temp_dir:
93 | try:
94 | # Convert PDF to images
95 | images = pdf2image.convert_from_path(
96 | str(pdf_path),
97 | dpi=settings.dpi,
98 | grayscale=settings.grayscale,
99 | use_cropbox=settings.use_cropbox,
100 | strict=settings.strict,
101 | thread_count=settings.thread_count,
102 | output_folder=temp_dir,
103 | )
104 |
105 | logger.info(f"Successfully converted PDF with {len(images)} pages")
106 | return images
107 |
108 | except Exception as e:
109 | error_msg = f"Error converting PDF to images: {str(e)}"
110 | if settings.raise_on_error:
111 | raise PDFProcessingError(error_msg)
112 | logger.error(error_msg)
113 | return []
114 |
115 | except Exception as e:
116 | error_msg = f"PDF processing error: {str(e)}"
117 | if settings.raise_on_error:
118 | raise PDFProcessingError(error_msg)
119 | logger.error(error_msg)
120 | return []
121 |
122 |
123 | def estimate_pdf_size(pdf_path: Path) -> Tuple[float, str]:
124 | """
125 | Estimate memory requirements for PDF conversion.
126 |
127 | Args:
128 | pdf_path: Path to PDF file
129 |
130 | Returns:
131 | Tuple of (size_in_bytes, human_readable_size)
132 | """
133 | try:
134 | # Convert first page to estimate size per page
135 | sample = pdf2image.convert_from_path(str(pdf_path), first_page=1, last_page=1)[
136 | 0
137 | ]
138 | width, height = sample.size
139 | channels = len(sample.getbands())
140 |
141 | # Get total page count
142 | info = pdf2image.pdfinfo_from_path(str(pdf_path))
143 | total_pages = info["Pages"]
144 |
145 | # Estimate total memory requirement
146 | bytes_per_page = width * height * channels
147 | total_bytes = bytes_per_page * total_pages
148 |
149 |         # Convert to a human-readable string without losing the raw byte count
150 |         size = float(total_bytes)
151 |         for unit in ["B", "KB", "MB", "GB", "TB"]:
152 |             if size < 1024:
153 |                 break
154 |             size /= 1024
155 |         return total_bytes, f"{size:.1f}{unit}"
156 |
157 | except Exception as e:
158 | logger.error(f"Error estimating PDF size: {str(e)}")
159 | return 0, "Unknown"
160 |
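
# Usage sketch (illustrative, not part of the file above; "sample.pdf" is a
# placeholder). For scale: a US Letter page at 300 dpi is roughly 2550x3300 px,
# so one RGB page costs 2550 * 3300 * 3 bytes, about 25 MB, which is the kind
# of figure estimate_pdf_size reports before a full conversion is attempted.
if __name__ == "__main__":
    sample = Path("sample.pdf")
    _, readable = estimate_pdf_size(sample)
    print(f"Estimated in-memory size: {readable}")
    pages = pdf_to_images(sample, PDFConversionSettings(dpi=150, grayscale=True))
    print(f"Converted {len(pages)} pages")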
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import os
4 | import tempfile
5 | from typing import Dict, List, Optional, Type
6 |
7 | import requests
8 | from pdf2image import convert_from_path
9 |
10 | from src.interfaces.interfaces import BasePromptTemplate
11 |
12 |
13 | class PDFVisionProcessor:
14 | def __init__(self, temp_dir: Optional[str] = None):
15 | """
16 | Initialize the PDF Vision Processor.
17 | Args:
18 |             temp_dir: Optional working directory. If None, a fresh directory is created via tempfile.mkdtemp().
19 | """
20 | self.temp_dir = temp_dir or tempfile.mkdtemp()
21 | os.makedirs(self.temp_dir, exist_ok=True)
22 |
23 | def convert_pdf_to_images(self, pdf_path: str) -> List[str]:
24 | """
25 | Convert PDF pages to images and save them in the temporary directory.
26 | Args:
27 | pdf_path: Path to the PDF file
28 | Returns:
29 | List of paths to generated images
30 | """
31 | image_paths = []
32 | try:
33 | # Convert PDF to images
34 | pages = convert_from_path(pdf_path)
35 |
36 | # Save each page as an image
37 | for i, page in enumerate(pages):
38 | image_path = os.path.join(self.temp_dir, f"page_{i}.png")
39 | page.save(image_path, "PNG")
40 | image_paths.append(image_path)
41 |
42 | return image_paths
43 | except Exception as e:
44 | print(f"Error converting PDF to images: {str(e)}")
45 | return []
46 |
47 | @staticmethod
48 | def encode_image_to_base64(image_path: str) -> str:
49 | """Convert an image file to a base64 encoded string."""
50 | with open(image_path, "rb") as image_file:
51 | return base64.b64encode(image_file.read()).decode("utf-8")
52 |
53 | def process_pdfs(
54 |         self, pdf_paths: List[str], prompt: Type[BasePromptTemplate]
55 | ) -> Dict[str, List[Dict]]:
56 | """
57 | Process multiple PDFs and perform vision OCR on each page.
58 | Args:
59 | pdf_paths: List of paths to PDF files
60 |             prompt: Prompt template class; instantiated per page to build the OCR instruction
61 | Returns:
62 | Dictionary with PDF paths as keys and lists of OCR results as values
63 | """
64 | results = {}
65 |
66 | for pdf_path in pdf_paths:
67 | pdf_results = []
68 | image_paths = self.convert_pdf_to_images(pdf_path)
69 |
70 | for image_path in image_paths:
71 | ocr_result = self.perform_vision_ocr(image_path, prompt)
72 | if ocr_result:
73 | pdf_results.append(ocr_result)
74 |
75 | results[pdf_path] = pdf_results
76 |
77 | return results
78 |
79 | def perform_vision_ocr(
80 |         self, image_path: str, prompt: Type[BasePromptTemplate]
81 | ) -> Optional[Dict]:
82 |         """Perform OCR on the given image using Llama 3.2-Vision served by a local Ollama instance."""
83 | base64_image = self.encode_image_to_base64(image_path)
84 |
85 | response = requests.post(
86 | "http://localhost:11434/api/chat",
87 | json={
88 | "model": "llama3.2-vision",
89 | "messages": [
90 | {
91 | "role": "user",
92 | "content": prompt().create_template(),
93 | "images": [base64_image],
94 | },
95 | ],
96 | },
97 | )
98 |
99 | if response.status_code == 200:
100 | full_content = ""
101 | for line in response.iter_lines():
102 | if line:
103 | json_obj = json.loads(line)
104 | full_content += json_obj["message"]["content"]
105 |
106 | try:
107 | return json.loads(full_content)
108 | except json.JSONDecodeError:
109 | return {"raw_content": full_content}
110 | else:
111 | print(f"Error: {response.status_code} {response.text}")
112 | return None
113 |
114 | def cleanup(self):
115 | """Remove temporary files and directory."""
116 | try:
117 | for file in os.listdir(self.temp_dir):
118 | os.remove(os.path.join(self.temp_dir, file))
119 | os.rmdir(self.temp_dir)
120 | except Exception as e:
121 | print(f"Error during cleanup: {str(e)}")
122 |
123 |
124 | if __name__ == "__main__":
125 | from src.prompts.llama_vision import AnalyzePdfImage
126 |
127 | # Example usage
128 | pdf_paths = [
129 | "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/data/paper01-1-2.pdf"
130 | ]
131 |
132 | # Initialize processor with custom temp directory (optional)
133 | processor = PDFVisionProcessor()
134 |
135 | try:
136 | # Process PDFs and get results
137 | results = processor.process_pdfs(pdf_paths, AnalyzePdfImage)
138 |
139 | # Print results
140 | for pdf_path, ocr_results in results.items():
141 | print(f"\nResults for {pdf_path}:")
142 | for i, result in enumerate(ocr_results):
143 | print(f"Page {i + 1}:", result)
144 |
145 | finally:
146 | # Clean up temporary files
147 | processor.cleanup()
148 |
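
# Illustrative contract sketch (hypothetical class, not in the repository):
# process_pdfs receives the prompt *class* itself, and perform_vision_ocr
# instantiates it and calls create_template() to build the instruction, so a
# stand-in only needs that one method:
#
#     class MinimalOcrPrompt(BasePromptTemplate):
#         def create_template(self) -> str:
#             return "Extract all text from this page and return JSON."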
--------------------------------------------------------------------------------
/src/vespa/indexing/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from time import time
4 | from typing import Any, Dict, List, Optional
5 |
6 | import yaml
7 | from tqdm import tqdm
8 |
9 | from src.vespa.datatypes import (
10 | PDFInput,
11 | VespaConfig,
12 | VespaSchemaConfig,
13 | )
14 | from src.vespa.indexing.pdf_processor import PDFProcessor
15 | from src.vespa.indexing.prepare_feed import VespaFeedPreparator
16 | from src.vespa.setup import VespaSetup
17 | from vespa.application import Vespa
18 | from vespa.deployment import VespaCloud
19 |
20 | logging.basicConfig(level=logging.DEBUG)
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | class VespaDeployer:
25 | def __init__(
26 | self,
27 | config: VespaConfig,
28 | pdf_processor: Optional[PDFProcessor] = None,
29 | feed_preparator: Optional[VespaFeedPreparator] = None,
30 | ) -> None:
31 | self.config = config
32 | self.pdf_processor = pdf_processor or PDFProcessor()
33 | self.feed_preparator = feed_preparator or VespaFeedPreparator()
34 |
35 | async def _feed_data(self, app: Vespa, vespa_feed: List[Dict]) -> None:
36 | failed_documents = []
37 |
38 | async with app.asyncio(
39 | connections=self.config.connections, timeout=self.config.timeout
40 | ) as session:
41 | logger.info("Starting data feed process")
42 |
43 | for doc in tqdm(vespa_feed, desc="Feeding documents"):
44 | try:
45 | logger.debug(f"Feeding document: {doc['fields']['url']}")
46 |
47 | response = await session.feed_data_point(
48 | schema=self.config.schema_name,
49 | data_id=doc["fields"]["url"],
50 | fields=doc["fields"],
51 | )
52 |
53 | if not response.is_successful():
54 | error_msg = (
55 | response.get_json()
56 | if response.get_json()
57 | else str(response)
58 | )
59 | logger.error(f"Feed failed: {error_msg}")
60 | failed_documents.append(
61 | {"id": doc["fields"]["url"], "error": error_msg}
62 | )
63 |
64 | except Exception as e:
65 | logger.error(f"Feed error for {doc['fields']['url']}: {str(e)}")
66 | failed_documents.append(
67 | {"id": doc["fields"]["url"], "error": str(e)}
68 | )
69 |
70 | if failed_documents:
71 | self._save_failed_documents(failed_documents)
72 | error_details = "\n".join(
73 | [f"Doc {doc['id']}: {doc['error']}" for doc in failed_documents]
74 | )
75 | raise Exception(f"Documents failed to feed:\n{error_details}")
76 |
77 | @staticmethod
78 | def _save_failed_documents(failed_docs: List[Dict[str, Any]]) -> None:
79 | output_dir = Path("logs/failed_documents")
80 | output_dir.mkdir(parents=True, exist_ok=True)
81 | output_file = output_dir / f"failed_documents_{int(time())}.yaml"
82 | with open(output_file, "w") as f:
83 | yaml.dump(failed_docs, f)
84 | logger.info(f"Saved failed documents to {output_file}")
85 |
86 | async def deploy_and_feed(self, vespa_feed: List[Dict]) -> Vespa:
87 | try:
88 | vespa_setup = VespaSetup(
89 | app_name=self.config.app_name,
90 | schema_config=self.config.schema_config or VespaSchemaConfig(),
91 | )
92 |
93 | vespa_cloud = VespaCloud(
94 | tenant=self.config.tenant_name,
95 | application=self.config.app_name,
96 | application_package=vespa_setup.app_package,
97 | )
98 |
99 | logger.info("Deploying to Vespa Cloud")
100 | app = vespa_cloud.deploy()
101 | await self._feed_data(app, vespa_feed)
102 | return app
103 |
104 | except Exception as e:
105 | logger.error(f"Vespa deployment failed: {str(e)}")
106 | raise
107 |
108 |
109 | async def run_indexing(
110 | config_path: str,
111 | pdfs: List[PDFInput],
112 | pdf_processor: Optional[PDFProcessor] = None,
113 | feed_preparator: Optional[VespaFeedPreparator] = None,
114 | ) -> None:
115 | try:
116 | if not pdfs:
117 | raise ValueError("PDF list cannot be empty")
118 |
119 | config_path = Path(config_path)
120 | if not config_path.exists():
121 | raise FileNotFoundError(f"Configuration file not found: {config_path}")
122 |
123 | with open(config_path) as f:
124 | config_data = yaml.safe_load(f)
125 |
126 | if not config_data.get("vespa"):
127 | raise ValueError("Invalid configuration: 'vespa' section missing")
128 |
129 | schema_config = config_data["vespa"].get("schema")
130 | vespa_config = VespaConfig(
131 | app_name=config_data["vespa"]["app_name"],
132 | tenant_name=config_data["vespa"]["tenant_name"],
133 | connections=config_data["vespa"].get("connections", 1),
134 | timeout=config_data["vespa"].get("timeout", 180),
135 | schema_name=config_data["vespa"].get("schema_name", "pdf_page"),
136 | schema_config=VespaSchemaConfig.from_dict(schema_config)
137 | if schema_config
138 | else None,
139 | )
140 |
141 | deployer = VespaDeployer(
142 | config=vespa_config,
143 | pdf_processor=pdf_processor,
144 | feed_preparator=feed_preparator,
145 | )
146 |
147 |         processed_data = deployer.pdf_processor.process_pdf(
148 |             [{"title": pdf.title, "url": pdf.url} for pdf in pdfs]
149 |         )
150 |         vespa_feed = deployer.feed_preparator.prepare_feed(processed_data)
151 |
152 | await deployer.deploy_and_feed(vespa_feed)
153 | logger.info("Indexing process completed successfully")
154 |
155 | except Exception as e:
156 | logger.error(f"Indexing failed: {str(e)}")
157 | raise
158 |
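
# Usage sketch (illustrative; assumes PDFInput accepts title/url keywords,
# matching how run_indexing reads pdf.title and pdf.url). The config path
# points at the repository's vespa_config.yaml; swap in your own.
if __name__ == "__main__":
    import asyncio

    asyncio.run(
        run_indexing(
            config_path="src/vespa/vespa_config.yaml",
            pdfs=[
                PDFInput(
                    title="Sustainability Highlights - Nature",
                    url="https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf",
                )
            ],
            pdf_processor=PDFProcessor(),
            feed_preparator=VespaFeedPreparator(),
        )
    )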
--------------------------------------------------------------------------------
/src/ocr_benchmark/utils/image_processing.py:
--------------------------------------------------------------------------------
1 | """Utility functions for image processing and manipulation."""
2 |
3 | import logging
4 | from dataclasses import dataclass
5 | from typing import Optional
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image, ImageEnhance, ImageOps
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | @dataclass
15 | class ImagePreprocessingConfig:
16 | """Configuration for image preprocessing."""
17 |
18 | resize_width: Optional[int] = None
19 | resize_height: Optional[int] = None
20 | contrast_factor: float = 1.0
21 | brightness_factor: float = 1.0
22 | sharpen_factor: float = 1.0
23 | denoise: bool = False
24 | deskew: bool = False
25 | binarize: bool = False
26 |
27 |
28 | def ensure_rgb(image: Image.Image) -> Image.Image:
29 | """
30 | Ensure image is in RGB format.
31 |
32 | Args:
33 | image: Input PIL Image
34 |
35 | Returns:
36 | PIL Image in RGB format
37 | """
38 | if image.mode != "RGB":
39 | return image.convert("RGB")
40 | return image
41 |
42 |
43 | def resize_image(
44 | image: Image.Image,
45 | width: Optional[int] = None,
46 | height: Optional[int] = None,
47 | maintain_aspect: bool = True,
48 | ) -> Image.Image:
49 | """
50 | Resize image while optionally maintaining aspect ratio.
51 |
52 | Args:
53 | image: Input PIL Image
54 | width: Target width in pixels
55 | height: Target height in pixels
56 | maintain_aspect: Whether to maintain aspect ratio
57 |
58 | Returns:
59 | Resized PIL Image
60 | """
61 | if not width and not height:
62 | return image
63 |
64 | orig_width, orig_height = image.size
65 |
66 | if maintain_aspect:
67 | if width and height:
68 | # Use the dimension that results in a smaller image
69 | width_ratio = width / orig_width
70 | height_ratio = height / orig_height
71 | ratio = min(width_ratio, height_ratio)
72 | new_width = int(orig_width * ratio)
73 | new_height = int(orig_height * ratio)
74 | elif width:
75 | ratio = width / orig_width
76 | new_width = width
77 | new_height = int(orig_height * ratio)
78 | else:
79 | ratio = height / orig_height
80 | new_width = int(orig_width * ratio)
81 | new_height = height
82 | else:
83 | new_width = width if width else orig_width
84 | new_height = height if height else orig_height
85 |
86 | return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
87 |
88 |
89 | def enhance_image(
90 | image: Image.Image,
91 | contrast: float = 1.0,
92 | brightness: float = 1.0,
93 | sharpness: float = 1.0,
94 | ) -> Image.Image:
95 | """
96 | Enhance image using contrast, brightness, and sharpness adjustments.
97 |
98 | Args:
99 | image: Input PIL Image
100 | contrast: Contrast enhancement factor
101 | brightness: Brightness enhancement factor
102 | sharpness: Sharpness enhancement factor
103 |
104 | Returns:
105 | Enhanced PIL Image
106 | """
107 | if contrast != 1.0:
108 | image = ImageEnhance.Contrast(image).enhance(contrast)
109 | if brightness != 1.0:
110 | image = ImageEnhance.Brightness(image).enhance(brightness)
111 | if sharpness != 1.0:
112 | image = ImageEnhance.Sharpness(image).enhance(sharpness)
113 | return image
114 |
115 |
116 | def denoise_image(image: Image.Image) -> Image.Image:
117 | """
118 | Apply denoising to image.
119 |
120 | Args:
121 | image: Input PIL Image
122 |
123 | Returns:
124 | Denoised PIL Image
125 | """
126 | # Convert to numpy array for OpenCV processing
127 | img_array = np.array(image)
128 | denoised = cv2.fastNlMeansDenoisingColored(img_array, None, 10, 10, 7, 21)
129 | return Image.fromarray(denoised)
130 |
131 |
132 | def deskew_image(image: Image.Image) -> Image.Image:
133 | """
134 | Deskew image by detecting and correcting rotation.
135 |
136 | Args:
137 | image: Input PIL Image
138 |
139 | Returns:
140 | Deskewed PIL Image
141 | """
142 | # Convert to numpy array and grayscale
143 | img_array = np.array(image)
144 | gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
145 |
146 | # Detect edges
147 | edges = cv2.Canny(gray, 50, 150, apertureSize=3)
148 |
149 | # Detect lines using Hough transform
150 | lines = cv2.HoughLines(edges, 1, np.pi / 180, 100)
151 |
152 | if lines is not None:
153 | # Calculate the most common angle
154 | angles = []
155 | for rho, theta in lines[:, 0]:
156 | angle = np.degrees(theta)
157 | if angle < 45:
158 | angles.append(angle)
159 | elif angle > 135:
160 | angles.append(angle - 180)
161 |
162 | if angles:
163 | median_angle = np.median(angles)
164 | if abs(median_angle) > 0.5: # Only rotate if angle is significant
165 | return image.rotate(
166 | median_angle, expand=True, fillcolor=(255, 255, 255)
167 | )
168 |
169 | return image
170 |
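
# Heuristic note on deskew_image: HoughLines returns (rho, theta) where theta
# is the angle of each line's normal, so the < 45 / > 135 degree filter keeps
# near-vertical strokes; their median deviation is taken as the page skew, and
# pages that yield no such lines are returned unchanged.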
171 |
172 | def binarize_image(image: Image.Image) -> Image.Image:
173 | """
174 | Convert image to binary (black and white) using adaptive thresholding.
175 |
176 | Args:
177 | image: Input PIL Image
178 |
179 | Returns:
180 | Binarized PIL Image
181 | """
182 | # Convert to grayscale
183 | gray = ImageOps.grayscale(image)
184 |
185 | # Convert to numpy array for OpenCV processing
186 | img_array = np.array(gray)
187 |
188 | # Apply adaptive thresholding
189 | binary = cv2.adaptiveThreshold(
190 | img_array,
191 | 255,
192 | cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
193 | cv2.THRESH_BINARY,
194 | 11, # Block size
195 | 2, # Constant subtracted from mean
196 | )
197 |
198 | return Image.fromarray(binary)
199 |
200 |
201 | def preprocess_image(
202 | image: Image.Image, config: ImagePreprocessingConfig
203 | ) -> Image.Image:
204 | """
205 | Apply a series of preprocessing steps to an image.
206 |
207 | Args:
208 | image: Input PIL Image
209 | config: Preprocessing configuration
210 |
211 | Returns:
212 | Preprocessed PIL Image
213 | """
214 | try:
215 | # Ensure RGB format
216 | image = ensure_rgb(image)
217 |
218 | # Resize if needed
219 | if config.resize_width or config.resize_height:
220 | image = resize_image(image, config.resize_width, config.resize_height)
221 |
222 | # Apply enhancements
223 | if any(
224 | factor != 1.0
225 | for factor in [
226 | config.contrast_factor,
227 | config.brightness_factor,
228 | config.sharpen_factor,
229 | ]
230 | ):
231 | image = enhance_image(
232 | image,
233 | config.contrast_factor,
234 | config.brightness_factor,
235 | config.sharpen_factor,
236 | )
237 |
238 | # Apply denoising
239 | if config.denoise:
240 | image = denoise_image(image)
241 |
242 | # Apply deskewing
243 | if config.deskew:
244 | image = deskew_image(image)
245 |
246 | # Apply binarization
247 | if config.binarize:
248 | image = binarize_image(image)
249 |
250 | return image
251 |
252 | except Exception as e:
253 | logger.error(f"Error preprocessing image: {str(e)}")
254 | raise
255 |
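
# Usage sketch (illustrative; "page_0.png" is a placeholder). A typical
# OCR-oriented pass: resize to a working width, lift contrast slightly, then
# denoise, deskew, and binarize in the order preprocess_image applies them.
if __name__ == "__main__":
    ocr_config = ImagePreprocessingConfig(
        resize_width=1600,
        contrast_factor=1.3,
        denoise=True,
        deskew=True,
        binarize=True,
    )
    cleaned = preprocess_image(Image.open("page_0.png"), ocr_config)
    cleaned.save("page_0_clean.png")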
--------------------------------------------------------------------------------
/src/vespa/setup.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional
3 |
4 | from src.vespa.datatypes import VespaSchemaConfig
5 | from src.vespa.exceptions import VespaSetupError
6 | from vespa.package import (
7 | HNSW,
8 | ApplicationPackage,
9 | Document,
10 | Field,
11 | FieldSet,
12 | FirstPhaseRanking,
13 | Function,
14 | RankProfile,
15 | Schema,
16 | SecondPhaseRanking,
17 | )
18 |
19 | # Configure logging
20 | logging.basicConfig(
21 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
22 | )
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class VespaSetup:
27 | """Class for setting up Vespa application package and schema."""
28 |
29 | def __init__(
30 | self, app_name: str, schema_config: Optional[VespaSchemaConfig] = None
31 | ) -> None:
32 | """
33 | Initialize Vespa setup.
34 |
35 | Args:
36 | app_name: Name of the Vespa application
37 | schema_config: Optional configuration for schema settings
38 | """
39 | self.app_name = app_name
40 | self.config = schema_config or VespaSchemaConfig()
41 |
42 | try:
43 | logger.info(f"Creating Vespa schema for application: {app_name}")
44 | self.schema = self._create_schema()
45 | self.app_package = ApplicationPackage(name=app_name, schema=[self.schema])
46 | except Exception as e:
47 | logger.error(f"Failed to create Vespa setup: {str(e)}")
48 | raise VespaSetupError(f"Setup failed: {str(e)}")
49 |
50 | def _create_schema(self) -> Schema:
51 | """
52 | Create the Vespa schema with fields and rank profiles.
53 |
54 | Returns:
55 | Configured Schema object
56 |
57 | Raises:
58 | VespaSetupError: If schema creation fails
59 | """
60 | try:
61 | schema = Schema(
62 | name="pdf_page",
63 | document=Document(fields=self._create_fields()),
64 | fieldsets=[FieldSet(name="default", fields=["title", "text"])],
65 | )
66 |
67 | self._add_rank_profiles(schema)
68 | return schema
69 |
70 | except Exception as e:
71 | logger.error(f"Failed to create schema: {str(e)}")
72 | raise VespaSetupError(f"Schema creation failed: {str(e)}")
73 |
74 | def _create_fields(self) -> List[Field]:
75 | """
76 | Create the fields for the schema.
77 |
78 | Returns:
79 | List of Field objects
80 | """
81 | return [
82 | Field(
83 | name="id", type="string", indexing=["summary", "index"], match=["word"]
84 | ),
85 | Field(name="url", type="string", indexing=["summary", "index"]),
86 | Field(
87 | name="title",
88 | type="string",
89 | indexing=["summary", "index"],
90 | match=["text"],
91 | index="enable-bm25",
92 | ),
93 | Field(name="page_number", type="int", indexing=["summary", "attribute"]),
94 | Field(name="image", type="raw", indexing=["summary"]),
95 | Field(
96 | name="text",
97 | type="string",
98 | indexing=["index"],
99 | match=["text"],
100 | index="enable-bm25",
101 | ),
102 | Field(
103 | name="embedding",
104 | type=f"tensor(patch{{}}, v[{self.config.tensor_dimensions}])",
105 | indexing=["attribute", "index"],
106 | ann=HNSW(
107 | distance_metric="hamming",
108 | max_links_per_node=self.config.hnsw_max_links,
109 | neighbors_to_explore_at_insert=self.config.hnsw_neighbors,
110 | ),
111 | ),
112 | ]
113 |
114 | def _add_rank_profiles(self, schema: Schema) -> None:
115 | """
116 | Add rank profiles to the schema.
117 |
118 | Args:
119 | schema: Schema object to add rank profiles to
120 | """
121 | # Add default ranking profile
122 | schema.add_rank_profile(self._create_default_profile())
123 |
124 | # Add retrieval and rerank profile
125 | schema.add_rank_profile(self._create_retrieval_rerank_profile())
126 |
127 | def _create_default_profile(self) -> RankProfile:
128 | """
129 | Create the default rank profile.
130 |
131 | Returns:
132 | Configured RankProfile object
133 | """
134 | return RankProfile(
135 | name="default",
136 | inputs=[("query(qt)", "tensor(querytoken{}, v[128])")],
137 | functions=[
138 | Function(
139 | name="max_sim",
140 | expression="""
141 | sum(
142 | reduce(
143 | sum(
144 | query(qt) * unpack_bits(attribute(embedding)) , v
145 | ),
146 | max, patch
147 | ),
148 | querytoken
149 | )
150 | """,
151 | ),
152 | Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
153 | ],
154 | first_phase=FirstPhaseRanking(expression="bm25_score"),
155 | second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=100),
156 | )
157 |
158 | def _create_retrieval_rerank_profile(self) -> RankProfile:
159 | """
160 | Create the retrieval and rerank profile.
161 |
162 | Returns:
163 | Configured RankProfile object
164 | """
165 | input_query_tensors = []
166 |
167 | # Add query tensors for each term
168 | for i in range(self.config.max_query_terms):
169 | input_query_tensors.append(
170 | (f"query(rq{i})", f"tensor(v[{self.config.tensor_dimensions}])")
171 | )
172 |
173 | # Add additional query tensors
174 | input_query_tensors.extend(
175 | [
176 | ("query(qt)", "tensor(querytoken{}, v[128])"),
177 | (
178 | "query(qtb)",
179 | f"tensor(querytoken{{}}, v[{self.config.tensor_dimensions}])",
180 | ),
181 | ]
182 | )
183 |
184 | return RankProfile(
185 | name="retrieval-and-rerank",
186 | inputs=input_query_tensors,
187 | functions=[
188 | Function(
189 | name="max_sim",
190 | expression="""
191 | sum(
192 | reduce(
193 | sum(
194 | query(qt) * unpack_bits(attribute(embedding)) , v
195 | ),
196 | max, patch
197 | ),
198 | querytoken
199 | )
200 | """,
201 | ),
202 | Function(
203 | name="max_sim_binary",
204 | expression="""
205 | sum(
206 | reduce(
207 | 1/(1 + sum(
208 | hamming(query(qtb), attribute(embedding)) ,v)
209 | ),
210 | max,
211 | patch
212 | ),
213 | querytoken
214 | )
215 | """,
216 | ),
217 | ],
218 | first_phase=FirstPhaseRanking(expression="max_sim_binary"),
219 | second_phase=SecondPhaseRanking(
220 | expression="max_sim", rerank_count=self.config.rerank_count
221 | ),
222 | )
223 |
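
# Usage sketch (illustrative; assumes pyvespa's ApplicationPackage.to_files).
# Writing the package to disk is a cheap way to inspect the generated schema
# before deploying -- the pdf_page.sd files under test*/schemas/ below are
# exactly this kind of output.
if __name__ == "__main__":
    setup = VespaSetup(app_name="test", schema_config=VespaSchemaConfig())
    setup.app_package.to_files("generated_app")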
--------------------------------------------------------------------------------
/test/schemas/pdf_page.sd:
--------------------------------------------------------------------------------
1 | schema pdf_page {
2 | document pdf_page {
3 | field id type string {
4 | indexing: summary | index
5 | match {
6 | word
7 | }
8 | }
9 | field url type string {
10 | indexing: summary | index
11 | }
12 | field title type string {
13 | indexing: summary | index
14 | index: enable-bm25
15 | match {
16 | text
17 | }
18 | }
19 | field page_number type int {
20 | indexing: summary | attribute
21 | }
22 | field image type raw {
23 | indexing: summary
24 | }
25 | field text type string {
26 | indexing: index
27 | index: enable-bm25
28 | match {
29 | text
30 | }
31 | }
32 | field embedding type tensor(patch{}, v[16]) {
33 | indexing: attribute | index
34 | attribute {
35 | distance-metric: hamming
36 | }
37 | index {
38 | hnsw {
39 | max-links-per-node: 32
40 | neighbors-to-explore-at-insert: 400
41 | }
42 | }
43 | }
44 | }
45 | fieldset default {
46 | fields: title, text
47 | }
48 | rank-profile default {
49 | inputs {
50 | query(qt) tensor(querytoken{}, v[128])
51 |
52 | }
53 | function max_sim() {
54 | expression {
55 |
56 | sum(
57 | reduce(
58 | sum(
59 | query(qt) * unpack_bits(attribute(embedding)) , v
60 | ),
61 | max, patch
62 | ),
63 | querytoken
64 | )
65 |
66 | }
67 | }
68 | function bm25_score() {
69 | expression {
70 | bm25(title) + bm25(text)
71 | }
72 | }
73 | first-phase {
74 | expression {
75 | bm25_score
76 | }
77 | }
78 | second-phase {
79 | rerank-count: 100
80 | expression {
81 | max_sim
82 | }
83 | }
84 | }
85 | rank-profile retrieval-and-rerank {
86 | inputs {
87 | query(rq0) tensor(v[16])
88 | query(rq1) tensor(v[16])
89 | query(rq2) tensor(v[16])
90 | query(rq3) tensor(v[16])
91 | query(rq4) tensor(v[16])
92 | query(rq5) tensor(v[16])
93 | query(rq6) tensor(v[16])
94 | query(rq7) tensor(v[16])
95 | query(rq8) tensor(v[16])
96 | query(rq9) tensor(v[16])
97 | query(rq10) tensor(v[16])
98 | query(rq11) tensor(v[16])
99 | query(rq12) tensor(v[16])
100 | query(rq13) tensor(v[16])
101 | query(rq14) tensor(v[16])
102 | query(rq15) tensor(v[16])
103 | query(rq16) tensor(v[16])
104 | query(rq17) tensor(v[16])
105 | query(rq18) tensor(v[16])
106 | query(rq19) tensor(v[16])
107 | query(rq20) tensor(v[16])
108 | query(rq21) tensor(v[16])
109 | query(rq22) tensor(v[16])
110 | query(rq23) tensor(v[16])
111 | query(rq24) tensor(v[16])
112 | query(rq25) tensor(v[16])
113 | query(rq26) tensor(v[16])
114 | query(rq27) tensor(v[16])
115 | query(rq28) tensor(v[16])
116 | query(rq29) tensor(v[16])
117 | query(rq30) tensor(v[16])
118 | query(rq31) tensor(v[16])
119 | query(rq32) tensor(v[16])
120 | query(rq33) tensor(v[16])
121 | query(rq34) tensor(v[16])
122 | query(rq35) tensor(v[16])
123 | query(rq36) tensor(v[16])
124 | query(rq37) tensor(v[16])
125 | query(rq38) tensor(v[16])
126 | query(rq39) tensor(v[16])
127 | query(rq40) tensor(v[16])
128 | query(rq41) tensor(v[16])
129 | query(rq42) tensor(v[16])
130 | query(rq43) tensor(v[16])
131 | query(rq44) tensor(v[16])
132 | query(rq45) tensor(v[16])
133 | query(rq46) tensor(v[16])
134 | query(rq47) tensor(v[16])
135 | query(rq48) tensor(v[16])
136 | query(rq49) tensor(v[16])
137 | query(rq50) tensor(v[16])
138 | query(rq51) tensor(v[16])
139 | query(rq52) tensor(v[16])
140 | query(rq53) tensor(v[16])
141 | query(rq54) tensor(v[16])
142 | query(rq55) tensor(v[16])
143 | query(rq56) tensor(v[16])
144 | query(rq57) tensor(v[16])
145 | query(rq58) tensor(v[16])
146 | query(rq59) tensor(v[16])
147 | query(rq60) tensor(v[16])
148 | query(rq61) tensor(v[16])
149 | query(rq62) tensor(v[16])
150 | query(rq63) tensor(v[16])
151 | query(qt) tensor(querytoken{}, v[128])
152 | query(qtb) tensor(querytoken{}, v[16])
153 |
154 | }
155 | function max_sim() {
156 | expression {
157 |
158 | sum(
159 | reduce(
160 | sum(
161 | query(qt) * unpack_bits(attribute(embedding)) , v
162 | ),
163 | max, patch
164 | ),
165 | querytoken
166 | )
167 |
168 | }
169 | }
170 | function max_sim_binary() {
171 | expression {
172 |
173 | sum(
174 | reduce(
175 | 1/(1 + sum(
176 | hamming(query(qtb), attribute(embedding)) ,v)
177 | ),
178 | max,
179 | patch
180 | ),
181 | querytoken
182 | )
183 |
184 | }
185 | }
186 | first-phase {
187 | expression {
188 | max_sim_binary
189 | }
190 | }
191 | second-phase {
192 | rerank-count: 10
193 | expression {
194 | max_sim
195 | }
196 | }
197 | }
198 | }
--------------------------------------------------------------------------------
/test3/schemas/pdf_page.sd:
--------------------------------------------------------------------------------
1 | schema pdf_page {
2 | document pdf_page {
3 | field id type string {
4 | indexing: summary | index
5 | match {
6 | word
7 | }
8 | }
9 | field url type string {
10 | indexing: summary | index
11 | }
12 | field title type string {
13 | indexing: summary | index
14 | index: enable-bm25
15 | match {
16 | text
17 | }
18 | }
19 | field page_number type int {
20 | indexing: summary | attribute
21 | }
22 | field image type raw {
23 | indexing: summary
24 | }
25 | field text type string {
26 | indexing: index
27 | index: enable-bm25
28 | match {
29 | text
30 | }
31 | }
32 | field embedding type tensor(patch{}, v[16]) {
33 | indexing: attribute | index
34 | attribute {
35 | distance-metric: hamming
36 | }
37 | index {
38 | hnsw {
39 | max-links-per-node: 32
40 | neighbors-to-explore-at-insert: 400
41 | }
42 | }
43 | }
44 | }
45 | fieldset default {
46 | fields: title, text
47 | }
48 | rank-profile default {
49 | inputs {
50 | query(qt) tensor(querytoken{}, v[128])
51 |
52 | }
53 | function max_sim() {
54 | expression {
55 |
56 | sum(
57 | reduce(
58 | sum(
59 | query(qt) * unpack_bits(attribute(embedding)) , v
60 | ),
61 | max, patch
62 | ),
63 | querytoken
64 | )
65 |
66 | }
67 | }
68 | function bm25_score() {
69 | expression {
70 | bm25(title) + bm25(text)
71 | }
72 | }
73 | first-phase {
74 | expression {
75 | bm25_score
76 | }
77 | }
78 | second-phase {
79 | rerank-count: 100
80 | expression {
81 | max_sim
82 | }
83 | }
84 | }
85 | rank-profile retrieval-and-rerank {
86 | inputs {
87 | query(rq0) tensor(v[16])
88 | query(rq1) tensor(v[16])
89 | query(rq2) tensor(v[16])
90 | query(rq3) tensor(v[16])
91 | query(rq4) tensor(v[16])
92 | query(rq5) tensor(v[16])
93 | query(rq6) tensor(v[16])
94 | query(rq7) tensor(v[16])
95 | query(rq8) tensor(v[16])
96 | query(rq9) tensor(v[16])
97 | query(rq10) tensor(v[16])
98 | query(rq11) tensor(v[16])
99 | query(rq12) tensor(v[16])
100 | query(rq13) tensor(v[16])
101 | query(rq14) tensor(v[16])
102 | query(rq15) tensor(v[16])
103 | query(rq16) tensor(v[16])
104 | query(rq17) tensor(v[16])
105 | query(rq18) tensor(v[16])
106 | query(rq19) tensor(v[16])
107 | query(rq20) tensor(v[16])
108 | query(rq21) tensor(v[16])
109 | query(rq22) tensor(v[16])
110 | query(rq23) tensor(v[16])
111 | query(rq24) tensor(v[16])
112 | query(rq25) tensor(v[16])
113 | query(rq26) tensor(v[16])
114 | query(rq27) tensor(v[16])
115 | query(rq28) tensor(v[16])
116 | query(rq29) tensor(v[16])
117 | query(rq30) tensor(v[16])
118 | query(rq31) tensor(v[16])
119 | query(rq32) tensor(v[16])
120 | query(rq33) tensor(v[16])
121 | query(rq34) tensor(v[16])
122 | query(rq35) tensor(v[16])
123 | query(rq36) tensor(v[16])
124 | query(rq37) tensor(v[16])
125 | query(rq38) tensor(v[16])
126 | query(rq39) tensor(v[16])
127 | query(rq40) tensor(v[16])
128 | query(rq41) tensor(v[16])
129 | query(rq42) tensor(v[16])
130 | query(rq43) tensor(v[16])
131 | query(rq44) tensor(v[16])
132 | query(rq45) tensor(v[16])
133 | query(rq46) tensor(v[16])
134 | query(rq47) tensor(v[16])
135 | query(rq48) tensor(v[16])
136 | query(rq49) tensor(v[16])
137 | query(rq50) tensor(v[16])
138 | query(rq51) tensor(v[16])
139 | query(rq52) tensor(v[16])
140 | query(rq53) tensor(v[16])
141 | query(rq54) tensor(v[16])
142 | query(rq55) tensor(v[16])
143 | query(rq56) tensor(v[16])
144 | query(rq57) tensor(v[16])
145 | query(rq58) tensor(v[16])
146 | query(rq59) tensor(v[16])
147 | query(rq60) tensor(v[16])
148 | query(rq61) tensor(v[16])
149 | query(rq62) tensor(v[16])
150 | query(rq63) tensor(v[16])
151 | query(qt) tensor(querytoken{}, v[128])
152 | query(qtb) tensor(querytoken{}, v[16])
153 |
154 | }
155 | function max_sim() {
156 | expression {
157 |
158 | sum(
159 | reduce(
160 | sum(
161 | query(qt) * unpack_bits(attribute(embedding)) , v
162 | ),
163 | max, patch
164 | ),
165 | querytoken
166 | )
167 |
168 | }
169 | }
170 | function max_sim_binary() {
171 | expression {
172 |
173 | sum(
174 | reduce(
175 | 1/(1 + sum(
176 | hamming(query(qtb), attribute(embedding)) ,v)
177 | ),
178 | max,
179 | patch
180 | ),
181 | querytoken
182 | )
183 |
184 | }
185 | }
186 | first-phase {
187 | expression {
188 | max_sim_binary
189 | }
190 | }
191 | second-phase {
192 | rerank-count: 10
193 | expression {
194 | max_sim
195 | }
196 | }
197 | }
198 | }
--------------------------------------------------------------------------------
/test4/schemas/pdf_page.sd:
--------------------------------------------------------------------------------
1 | schema pdf_page {
2 | document pdf_page {
3 | field id type string {
4 | indexing: summary | index
5 | match {
6 | word
7 | }
8 | }
9 | field url type string {
10 | indexing: summary | index
11 | }
12 | field title type string {
13 | indexing: summary | index
14 | index: enable-bm25
15 | match {
16 | text
17 | }
18 | }
19 | field page_number type int {
20 | indexing: summary | attribute
21 | }
22 | field image type raw {
23 | indexing: summary
24 | }
25 | field text type string {
26 | indexing: index
27 | index: enable-bm25
28 | match {
29 | text
30 | }
31 | }
32 | field embedding type tensor(patch{}, v[16]) {
33 | indexing: attribute | index
34 | attribute {
35 | distance-metric: hamming
36 | }
37 | index {
38 | hnsw {
39 | max-links-per-node: 32
40 | neighbors-to-explore-at-insert: 400
41 | }
42 | }
43 | }
44 | }
45 | fieldset default {
46 | fields: title, text
47 | }
48 | rank-profile default {
49 | inputs {
50 | query(qt) tensor(querytoken{}, v[128])
51 |
52 | }
53 | function max_sim() {
54 | expression {
55 |
56 | sum(
57 | reduce(
58 | sum(
59 | query(qt) * unpack_bits(attribute(embedding)) , v
60 | ),
61 | max, patch
62 | ),
63 | querytoken
64 | )
65 |
66 | }
67 | }
68 | function bm25_score() {
69 | expression {
70 | bm25(title) + bm25(text)
71 | }
72 | }
73 | first-phase {
74 | expression {
75 | bm25_score
76 | }
77 | }
78 | second-phase {
79 | rerank-count: 100
80 | expression {
81 | max_sim
82 | }
83 | }
84 | }
85 | rank-profile retrieval-and-rerank {
86 | inputs {
87 | query(rq0) tensor(v[16])
88 | query(rq1) tensor(v[16])
89 | query(rq2) tensor(v[16])
90 | query(rq3) tensor(v[16])
91 | query(rq4) tensor(v[16])
92 | query(rq5) tensor(v[16])
93 | query(rq6) tensor(v[16])
94 | query(rq7) tensor(v[16])
95 | query(rq8) tensor(v[16])
96 | query(rq9) tensor(v[16])
97 | query(rq10) tensor(v[16])
98 | query(rq11) tensor(v[16])
99 | query(rq12) tensor(v[16])
100 | query(rq13) tensor(v[16])
101 | query(rq14) tensor(v[16])
102 | query(rq15) tensor(v[16])
103 | query(rq16) tensor(v[16])
104 | query(rq17) tensor(v[16])
105 | query(rq18) tensor(v[16])
106 | query(rq19) tensor(v[16])
107 | query(rq20) tensor(v[16])
108 | query(rq21) tensor(v[16])
109 | query(rq22) tensor(v[16])
110 | query(rq23) tensor(v[16])
111 | query(rq24) tensor(v[16])
112 | query(rq25) tensor(v[16])
113 | query(rq26) tensor(v[16])
114 | query(rq27) tensor(v[16])
115 | query(rq28) tensor(v[16])
116 | query(rq29) tensor(v[16])
117 | query(rq30) tensor(v[16])
118 | query(rq31) tensor(v[16])
119 | query(rq32) tensor(v[16])
120 | query(rq33) tensor(v[16])
121 | query(rq34) tensor(v[16])
122 | query(rq35) tensor(v[16])
123 | query(rq36) tensor(v[16])
124 | query(rq37) tensor(v[16])
125 | query(rq38) tensor(v[16])
126 | query(rq39) tensor(v[16])
127 | query(rq40) tensor(v[16])
128 | query(rq41) tensor(v[16])
129 | query(rq42) tensor(v[16])
130 | query(rq43) tensor(v[16])
131 | query(rq44) tensor(v[16])
132 | query(rq45) tensor(v[16])
133 | query(rq46) tensor(v[16])
134 | query(rq47) tensor(v[16])
135 | query(rq48) tensor(v[16])
136 | query(rq49) tensor(v[16])
137 | query(rq50) tensor(v[16])
138 | query(rq51) tensor(v[16])
139 | query(rq52) tensor(v[16])
140 | query(rq53) tensor(v[16])
141 | query(rq54) tensor(v[16])
142 | query(rq55) tensor(v[16])
143 | query(rq56) tensor(v[16])
144 | query(rq57) tensor(v[16])
145 | query(rq58) tensor(v[16])
146 | query(rq59) tensor(v[16])
147 | query(rq60) tensor(v[16])
148 | query(rq61) tensor(v[16])
149 | query(rq62) tensor(v[16])
150 | query(rq63) tensor(v[16])
151 | query(qt) tensor(querytoken{}, v[128])
152 | query(qtb) tensor(querytoken{}, v[16])
153 |
154 | }
155 | function max_sim() {
156 | expression {
157 |
158 | sum(
159 | reduce(
160 | sum(
161 | query(qt) * unpack_bits(attribute(embedding)) , v
162 | ),
163 | max, patch
164 | ),
165 | querytoken
166 | )
167 |
168 | }
169 | }
170 | function max_sim_binary() {
171 | expression {
172 |
173 | sum(
174 | reduce(
175 | 1/(1 + sum(
176 | hamming(query(qtb), attribute(embedding)) ,v)
177 | ),
178 | max,
179 | patch
180 | ),
181 | querytoken
182 | )
183 |
184 | }
185 | }
186 | first-phase {
187 | expression {
188 | max_sim_binary
189 | }
190 | }
191 | second-phase {
192 | rerank-count: 10
193 | expression {
194 | max_sim
195 | }
196 | }
197 | }
198 | }
--------------------------------------------------------------------------------
/vespa_vision_rag.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from io import BytesIO
3 |
4 | import numpy as np
5 | import requests
6 | import torch
7 | from colpali_engine.models import ColQwen2, ColQwen2Processor
8 | from pdf2image import convert_from_path
9 | from pypdf import PdfReader
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from vespa.deployment import VespaCloud
13 | from vespa.package import (
14 | HNSW,
15 | ApplicationPackage,
16 | Document,
17 | Field,
18 | FieldSet,
19 | FirstPhaseRanking,
20 | Function,
21 | RankProfile,
22 | Schema,
23 | SecondPhaseRanking,
24 | )
25 |
26 |
27 | class PDFProcessor:
28 | def __init__(self, model_name="vidore/colqwen2-v0.1"):
29 | self.model = ColQwen2.from_pretrained(
30 | model_name, torch_dtype=torch.bfloat16, device_map="auto"
31 | )
32 | self.processor = ColQwen2Processor.from_pretrained(model_name)
33 | self.model.eval()
34 |
35 | def download_pdf(self, url):
36 | response = requests.get(url)
37 | if response.status_code == 200:
38 | return BytesIO(response.content)
39 | raise Exception(f"Failed to download PDF: Status code {response.status_code}")
40 |
41 | def get_pdf_content(self, pdf_url):
42 | pdf_file = self.download_pdf(pdf_url)
43 | temp_file = "temp.pdf"
44 | with open(temp_file, "wb") as f:
45 | f.write(pdf_file.read())
46 |
47 | reader = PdfReader(temp_file)
48 | page_texts = [page.extract_text() for page in reader.pages]
49 | images = convert_from_path(temp_file)
50 | assert len(images) == len(page_texts)
51 | return images, page_texts
52 |
53 | def process_pdf(self, pdf_metadata):
54 | pdf_data = []
55 | for pdf in pdf_metadata:
56 | images, texts = self.get_pdf_content(pdf["url"])
57 | embeddings = self.generate_embeddings(images)
58 | pdf_data.append(
59 | {
60 | "url": pdf["url"],
61 | "title": pdf["title"],
62 | "images": images,
63 | "texts": texts,
64 | "embeddings": embeddings,
65 | }
66 | )
67 | return pdf_data
68 |
69 | def generate_embeddings(self, images):
70 | embeddings = []
71 | dataloader = DataLoader(
72 | images,
73 | batch_size=2,
74 | shuffle=False,
75 | collate_fn=lambda x: self.processor.process_images(x),
76 | )
77 |
78 | for batch in tqdm(dataloader):
79 | with torch.no_grad():
80 | batch = {k: v.to(self.model.device) for k, v in batch.items()}
81 | batch_embeddings = self.model(**batch)
82 | embeddings.extend(list(torch.unbind(batch_embeddings.to("cpu"))))
83 | return embeddings
84 |
85 |
86 | class VespaSetup:
87 | def __init__(self, app_name):
88 | self.app_name = app_name
89 | self.schema = self._create_schema()
90 | self.app_package = ApplicationPackage(name=app_name, schema=[self.schema])
91 |
92 | def _create_schema(self):
93 | schema = Schema(
94 | name="pdf_page",
95 | document=Document(
96 | fields=[
97 | Field(
98 | name="id",
99 | type="string",
100 | indexing=["summary", "index"],
101 | match=["word"],
102 | ),
103 | Field(name="url", type="string", indexing=["summary", "index"]),
104 | Field(
105 | name="title",
106 | type="string",
107 | indexing=["summary", "index"],
108 | match=["text"],
109 | index="enable-bm25",
110 | ),
111 | Field(
112 | name="page_number",
113 | type="int",
114 | indexing=["summary", "attribute"],
115 | ),
116 | Field(name="image", type="raw", indexing=["summary"]),
117 | Field(
118 | name="text",
119 | type="string",
120 | indexing=["index"],
121 | match=["text"],
122 | index="enable-bm25",
123 | ),
124 | Field(
125 | name="embedding",
126 | type="tensor(patch{}, v[16])",
127 | indexing=["attribute", "index"],
128 | ann=HNSW(
129 | distance_metric="hamming",
130 | max_links_per_node=32,
131 | neighbors_to_explore_at_insert=400,
132 | ),
133 | ),
134 | ]
135 | ),
136 | fieldsets=[FieldSet(name="default", fields=["title", "text"])],
137 | )
138 | self._add_rank_profiles(schema)
139 | return schema
140 |
141 | def _add_rank_profiles(self, schema):
142 | default_profile = RankProfile(
143 | name="default",
144 | inputs=[("query(qt)", "tensor(querytoken{}, v[128])")],
145 | functions=[
146 | Function(
147 | name="max_sim",
148 | expression="""
149 | sum(
150 | reduce(
151 | sum(
152 | query(qt) * unpack_bits(attribute(embedding)) , v
153 | ),
154 | max, patch
155 | ),
156 | querytoken
157 | )
158 | """,
159 | ),
160 | Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
161 | ],
162 | first_phase=FirstPhaseRanking(expression="bm25_score"),
163 | second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=100),
164 | )
165 | schema.add_rank_profile(default_profile)
166 |
167 |
168 | def prepare_vespa_feed(pdf_data):
169 | vespa_feed = []
170 | for pdf in pdf_data:
171 | for page_number, (text, embedding, image) in enumerate(
172 | zip(pdf["texts"], pdf["embeddings"], pdf["images"])
173 | ):
174 | embedding_dict = {}
175 | for idx, patch_embedding in enumerate(embedding):
176 | binary_vector = (
177 | np.packbits(np.where(patch_embedding > 0, 1, 0))
178 | .astype(np.int8)
179 | .tobytes()
180 | .hex()
181 | )
182 | embedding_dict[idx] = binary_vector
183 |
184 | page = {
185 |                 "id": f"{pdf['url']}#page-{page_number}",  # stable id; hash() is salted per process
186 | "url": pdf["url"],
187 | "title": pdf["title"],
188 | "page_number": page_number,
189 | "image": get_base64_image(resize_image(image, 640)),
190 | "text": text,
191 | "embedding": embedding_dict,
192 | }
193 | vespa_feed.append(page)
194 | return vespa_feed
195 |
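
# Note on the encoding above: each 128-dim patch embedding is binarized
# (positive -> 1, else 0) and packed into 128 / 8 = 16 bytes, then hex-encoded.
# That is why the schema declares the field as tensor(patch{}, v[16]) with a
# hamming distance metric: similarity is computed over the packed bits.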
196 |
197 | def resize_image(image, max_height=800):
198 | width, height = image.size
199 | if height > max_height:
200 | ratio = max_height / height
201 | return image.resize((int(width * ratio), int(height * ratio)))
202 | return image
203 |
204 |
205 | def get_base64_image(image):
206 | buffered = BytesIO()
207 | image.save(buffered, format="JPEG")
208 | return str(base64.b64encode(buffered.getvalue()), "utf-8")
209 |
210 |
211 | async def deploy_and_feed(vespa_feed):
212 | vespa_setup = VespaSetup("test")
213 |
214 | vespa_cloud = VespaCloud(
215 | tenant="cube-digital",
216 | application="test",
217 | application_package=vespa_setup.app_package,
218 | )
219 |
220 | app = vespa_cloud.deploy()
221 |
222 | async with app.asyncio(connections=1, timeout=180) as session:
223 | for page in tqdm(vespa_feed):
224 | response = await session.feed_data_point(
225 | data_id=page["id"], fields=page, schema="pdf_page"
226 | )
227 | if not response.is_successful():
228 |                 print(response.get_json())
229 | return app
230 |
231 |
232 | async def main():
233 | # Example usage
234 | sample_pdfs = [
235 | {
236 | "title": "Building a Resilient Strategy for the Energy Transition",
237 | "url": "https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf",
238 | }
239 | ]
240 |
241 | processor = PDFProcessor()
242 | pdf_data = processor.process_pdf(sample_pdfs)
243 | vespa_feed = prepare_vespa_feed(pdf_data)
244 |
245 | # Deploy to Vespa Cloud (requires configuration)
246 | await deploy_and_feed(vespa_feed=vespa_feed)
247 |
248 |
249 | if __name__ == "__main__":
250 | import asyncio
251 |
252 | asyncio.run(main())
253 |
254 |
--------------------------------------------------------------------------------
/src/vespa/retrieval/run.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 | from typing import Dict, List, Optional
4 |
5 | import torch
6 | import yaml
7 | from colpali_engine.models import ColQwen2, ColQwen2Processor
8 | from torch.utils.data import DataLoader
9 |
10 | from src.vespa.datatypes import QueryResult, VespaQueryConfig
11 | from src.vespa.exceptions import VespaQueryError
12 | from src.vespa.setup import VespaSetup
13 | from vespa.deployment import VespaCloud
14 | from vespa.io import VespaQueryResponse
15 |
16 | logging.basicConfig(level=logging.DEBUG)
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | class VespaQuerier:
21 | def __init__(
22 | self,
23 | config: VespaQueryConfig,
24 | setup: Optional[VespaSetup] = None,
25 | model_name: str = "vidore/colqwen2-v0.1",
26 | ):
27 | self.config = config
28 | self.setup = setup or VespaSetup(
29 | app_name=config.app_name, schema_config=config.schema_config
30 | )
31 |
32 | logger.info(f"Loading model {model_name}")
33 | self.model = ColQwen2.from_pretrained(
34 | model_name, torch_dtype=torch.bfloat16, device_map="auto"
35 | )
36 | self.processor = ColQwen2Processor.from_pretrained(model_name)
37 | self.model.eval()
38 |
39 | self._init_vespa_cloud()
40 |
41 | def _init_vespa_cloud(self) -> None:
42 | try:
43 | self.vespa_cloud = VespaCloud(
44 | tenant=self.config.tenant_name,
45 | application=self.config.app_name,
46 | application_package=self.setup.app_package,
47 | )
48 | except Exception as e:
49 | logger.error(f"Failed to initialize VespaCloud: {str(e)}")
50 | raise VespaQueryError(f"VespaCloud initialization failed: {str(e)}")
51 |
52 |     def _prepare_query_tensors(self, query: str) -> Dict[str, Dict[int, List[float]]]:
53 | dataloader = DataLoader(
54 | [query],
55 | batch_size=1,
56 | shuffle=False,
57 | collate_fn=lambda x: self.processor.process_queries(x),
58 | )
59 |
60 | with torch.no_grad():
61 | batch_query = next(iter(dataloader))
62 | batch_query = {k: v.to(self.model.device) for k, v in batch_query.items()}
63 | embeddings_query = self.model(**batch_query)
64 | query_embedding = embeddings_query[0].cpu().float()
65 |
66 | # Convert embedding to list format
67 | embedding_list = query_embedding.tolist()
68 | if isinstance(embedding_list[0], list):
69 | # Handle case where embedding is 2D
70 | query_tensors = {i: emb for i, emb in enumerate(embedding_list)}
71 | else:
72 | # Handle case where embedding is 1D
73 | query_tensors = {0: embedding_list}
74 |
75 | return {"input.query(qt)": query_tensors}
76 |
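
# The mapped-tensor format above mirrors the rank-profile declaration: the
# model emits one 128-dim float vector per query token, and the
# {token_index: [128 floats]} dict binds to input tensor(querytoken{}, v[128]).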
77 | async def execute_queries(self, queries: List[str]) -> Dict[str, List[QueryResult]]:
78 | try:
79 | app = self.vespa_cloud.get_application()
80 | results: Dict[str, List[QueryResult]] = {}
81 |
82 | async with app.asyncio(
83 | connections=self.config.connections, timeout=self.config.timeout
84 | ) as session:
85 | for query in queries:
86 | try:
87 | logger.info(f"Executing query: {query}")
88 | query_tensors = self._prepare_query_tensors(query)
89 |
90 | logger.debug(f"Query tensors: {query_tensors}")
91 |
92 | response = await session.query(
93 | yql="select id, title, url, text, page_number, image from pdf_page where userInput(@userQuery)",
94 | userQuery=query,
95 | hits=self.config.hits_per_query,
96 | body={
97 | **query_tensors,
98 | "presentation.timing": True,
99 | "timeout": str(self.config.timeout),
100 | },
101 | )
102 |
103 | if not response.is_successful():
104 | error_msg = (
105 | response.get_json()
106 | if hasattr(response, "get_json")
107 | else str(response)
108 | )
109 | logger.error(f"Query response error: {error_msg}")
110 | raise VespaQueryError(f"Query failed: {error_msg}")
111 |
112 | results[query] = self._process_response(response)
113 |
114 | except Exception as e:
115 | logger.error(f"Query failed for '{query}': {str(e)}")
116 | results[query] = []
117 |
118 | return results
119 |
120 | except Exception as e:
121 | logger.error(f"Failed to execute queries: {str(e)}")
122 | raise VespaQueryError(f"Query execution failed: {str(e)}")
123 |
124 | def _process_response(self, response: VespaQueryResponse) -> List[QueryResult]:
125 | try:
126 | results = []
127 | for hit in response.hits:
128 | fields = hit["fields"]
129 | results.append(
130 | QueryResult(
131 | title=fields.get("title", "N/A"),
132 | url=fields.get("url", "N/A"),
133 | page_number=fields.get("page_number", -1),
134 | relevance=float(hit.get("relevance", 0.0)),
135 | text=fields.get("text", ""),
136 | source=hit,
137 | )
138 | )
139 | return results
140 | except Exception as e:
141 | logger.error(f"Error processing response: {str(e)}")
142 | return []
143 |
144 | def display_results(self, query: str, results: List[QueryResult]) -> None:
145 | print(f"\nQuery: {query}")
146 | print(f"Total Results: {len(results)}")
147 |
148 | for idx, result in enumerate(results, 1):
149 | print(f"\nResult {idx}:")
150 | print(f"Title: {result.title}")
151 | print(f"URL: {result.url}")
152 | print(f"Page: {result.page_number}")
153 | print(f"Score: {result.relevance:.4f}")
154 | text_preview = (
155 | result.text[:200] + "..." if len(result.text) > 200 else result.text
156 | )
157 | print(f"Text Preview: {text_preview}")
158 |
159 |
160 | async def run_queries(
161 | config_path: str, queries: List[str], display_results: bool = True
162 | ) -> Dict[str, List[QueryResult]]:
163 | try:
164 | config_path = Path(config_path)
165 | if not config_path.exists():
166 | raise FileNotFoundError(f"Configuration file not found: {config_path}")
167 |
168 | with open(config_path) as f:
169 | config_data = yaml.safe_load(f)
170 |
171 | if not config_data.get("vespa"):
172 | raise ValueError("Invalid configuration: 'vespa' section missing")
173 |
174 | vespa_config = VespaQueryConfig.from_dict(config_data["vespa"])
175 | querier = VespaQuerier(vespa_config)
176 |
177 | logger.info(f"Executing {len(queries)} queries")
178 | results = await querier.execute_queries(queries)
179 |
180 | if display_results:
181 | for query, query_results in results.items():
182 | querier.display_results(query, query_results)
183 |
184 | return results
185 |
186 | except Exception as e:
187 | logger.error(f"Failed to run queries: {str(e)}")
188 | raise
189 |
190 |
191 | async def main() -> Dict[str, List[QueryResult]]:
192 | # Path to your config file
193 | config_path = (
194 | "/Users/vesaalexandru/Workspaces/cube/cube-publication/"
195 | "evaluate-ocr-rag-systems/src/vespa/vespa_config.yaml"
196 | )
197 |
198 | # Test queries
199 | queries = [
200 | "Percentage of non-fresh water as source?",
201 | ]
202 |
203 | try:
204 | logger.info("Starting query execution")
205 | results = await run_queries(
206 | config_path=config_path, queries=queries, display_results=True
207 | )
208 | logger.info("Query execution completed successfully")
209 | return results
210 |
211 | except Exception as e:
212 | logger.error(f"Query execution failed: {str(e)}")
213 | raise
214 |
215 |
216 | if __name__ == "__main__":
217 | import asyncio
218 |
219 | try:
220 | results = asyncio.run(main())
221 |
222 | # Print summary of results
223 | print("\nResults Summary:")
224 | for query, query_results in results.items():
225 | print(f"\nQuery: {query}")
226 | print(f"Number of results: {len(query_results)}")
227 | if query_results:
228 | print(f"Top result score: {query_results[0].relevance:.4f}")
229 | print(f"Top result title: {query_results[0].title}")
230 |
231 | except KeyboardInterrupt:
232 | logger.info("Query execution interrupted by user")
233 | except Exception as e:
234 | logger.error(f"Error occurred: {str(e)}")
235 | raise
236 |
--------------------------------------------------------------------------------
/src/vespa/indexing/pdf_processor.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import platform
4 | import tempfile
5 | from io import BytesIO
6 | from typing import Any, Dict, List, Optional, Tuple
7 |
8 | import requests
9 | import torch
10 | from colpali_engine.models import ColQwen2, ColQwen2Processor
11 | from pdf2image import convert_from_path
12 | from pdf2image.exceptions import PDFPageCountError
13 | from pypdf import PdfReader
14 | from torch import Tensor
15 | from torch import device as torch_device
16 | from torch.utils.data import DataLoader
17 | from tqdm import tqdm
18 |
19 | from src.vespa.datatypes import PDFData
20 | from src.vespa.exceptions import PDFProcessingError
21 |
22 | # Configure logging
23 | logging.basicConfig(
24 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
25 | )
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | class PDFProcessor:
30 | """Class for processing PDFs and generating embeddings using ColQwen2 model."""
31 |
32 | def __init__(
33 | self,
34 | model_name: str = "vidore/colqwen2-v0.1",
35 | batch_size: int = 2,
36 | device: Optional[str] = None,
37 | ) -> None:
38 | """
39 | Initialize the PDF processor with the specified model.
40 |
41 | Args:
42 | model_name: Name of the pretrained model to use
43 | batch_size: Batch size for processing images
44 | device: Optional specific device to use ('cuda', 'mps', 'cpu').
45 | If None, will automatically select the optimal device.
46 |
47 | Raises:
48 | RuntimeError: If model loading fails
49 | """
50 | self.batch_size = batch_size
51 | self.device = self._get_optimal_device(device)
52 |
53 | try:
54 | logger.info(f"Loading model {model_name} on {self.device}")
55 | self.model = ColQwen2.from_pretrained(
56 | model_name,
57 | torch_dtype=torch.bfloat16,
58 | device_map="auto" if self.device.type == "cuda" else None,
59 | )
60 | self.processor = ColQwen2Processor.from_pretrained(model_name)
61 | self.model.eval()
62 |
63 | # Move model to device if not using device_map="auto"
64 | if self.device.type != "cuda":
65 | self.model.to(self.device)
66 |
67 | except Exception as e:
68 | logger.error(f"Failed to load model: {str(e)}")
69 | raise RuntimeError(f"Model initialization failed: {str(e)}")
70 |
71 | def _get_optimal_device(
72 | self, requested_device: Optional[str] = None
73 | ) -> torch_device:
74 | """
75 | Determine the optimal device for model execution.
76 |
77 | Args:
78 | requested_device: Optional specific device to use
79 |
80 | Returns:
81 | torch.device: The optimal device for the current system
82 | """
83 | if requested_device:
84 | # If a specific device is requested, try to use it
85 | if requested_device == "cuda" and not torch.cuda.is_available():
86 | logger.warning(
87 | "CUDA requested but not available, falling back to optimal device"
88 | )
89 | elif requested_device == "mps" and not (
90 | platform.system() == "Darwin" and torch.backends.mps.is_available()
91 | ):
92 | logger.warning(
93 | "MPS requested but not available, falling back to optimal device"
94 | )
95 | else:
96 | return torch.device(requested_device)
97 |
98 | # Automatic device selection
99 | if torch.cuda.is_available():
100 | logger.info("Using CUDA device")
101 | return torch.device("cuda")
102 | elif platform.system() == "Darwin" and torch.backends.mps.is_available():
103 | try:
104 | # Test MPS allocation
105 |                 _ = torch.zeros(1, device="mps")  # probe MPS allocation
106 | logger.info("Using MPS device")
107 | return torch.device("mps")
108 | except Exception as e:
109 | logger.warning(f"MPS device available but failed allocation test: {e}")
110 | logger.info("Falling back to CPU")
111 | return torch.device("cpu")
112 | else:
113 | logger.info("Using CPU device")
114 | return torch.device("cpu")
115 |
116 | def _handle_device_fallback(
117 | self, batch: Dict[str, Tensor]
118 | ) -> Tuple[Dict[str, Tensor], torch_device]:
119 | """
120 | Handle device fallback if the current device fails.
121 |
122 | Args:
123 | batch: The current batch of data
124 |
125 | Returns:
126 | Tuple of (processed batch, new device)
127 |
128 | Raises:
129 | PDFProcessingError: If processing fails on all devices
130 | """
131 | devices_to_try = []
132 |
133 | # Add fallback devices in order of preference
134 | if self.device.type != "cpu":
135 | devices_to_try.append("cpu")
136 | if self.device.type != "cuda" and torch.cuda.is_available():
137 | devices_to_try.append("cuda")
138 |
139 | for device_type in devices_to_try:
140 | try:
141 | new_device = torch.device(device_type)
142 | logger.warning(f"Attempting fallback to {device_type}")
143 |
144 | self.model.to(new_device)
145 | new_batch = {k: v.to(new_device) for k, v in batch.items()}
146 |
147 | # Test the new device with a forward pass
148 | with torch.no_grad():
149 | _ = self.model(**new_batch)
150 |
151 | self.device = new_device
152 | return new_batch, new_device
153 |
154 | except Exception as e:
155 | logger.warning(f"Fallback to {device_type} failed: {e}")
156 | continue
157 |
158 | raise PDFProcessingError("Failed to process batch on all available devices")
159 |
160 | def process_pdf(self, pdf_metadata: List[Dict[str, str]]) -> List[PDFData]:
161 | """
162 | Process multiple PDFs and generate their embeddings.
163 |
164 | Args:
165 | pdf_metadata: List of dictionaries containing PDF metadata (must have 'url' and 'title')
166 |
167 | Returns:
168 | List of PDFData objects containing processed information
169 |
170 | Raises:
171 | PDFProcessingError: If processing any PDF fails
172 | """
173 |
174 | pdf_data: List[PDFData] = []
175 |
176 | for pdf in pdf_metadata:
177 | try:
178 | logger.info(f"Processing PDF: {pdf['title']}")
179 | images, texts = self.get_pdf_content(pdf["url"])
180 | embeddings = self.generate_embeddings(images)
181 |
182 | pdf_data.append(
183 | PDFData(
184 | url=pdf["url"],
185 | title=pdf["title"],
186 | images=images,
187 | texts=texts,
188 | embeddings=embeddings,
189 | )
190 | )
191 |
192 | except Exception as e:
193 | logger.error(f"Failed to process PDF {pdf['title']}: {str(e)}")
194 |                 raise PDFProcessingError(f"Failed to process PDF: {str(e)}") from e
195 |
196 | return pdf_data
197 |
198 | def get_pdf_content(self, pdf_url: str) -> Tuple[List[Any], List[str]]:
199 | """
200 | Extract images and text content from PDF.
201 |
202 | Args:
203 | pdf_url: URL of the PDF to process
204 |
205 | Returns:
206 | Tuple containing lists of images and extracted text
207 |
208 | Raises:
209 | PDFProcessingError: If PDF processing fails
210 | """
211 | try:
212 | logger.info(f"Downloading PDF from {pdf_url}")
213 | response = requests.get(pdf_url, timeout=30)
214 | response.raise_for_status()
215 |             # Persist the downloaded bytes to a temp file for PdfReader / pdf2image
216 |
217 |             with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
218 |                 temp_path = temp_file.name
219 |                 temp_file.write(response.content)
220 |
221 | try:
222 | reader = PdfReader(temp_path)
223 | page_texts = [page.extract_text() for page in reader.pages]
224 |
225 | logger.info("Converting PDF pages to images")
226 | images = convert_from_path(temp_path)
227 |
228 | if len(images) != len(page_texts):
229 | raise PDFProcessingError(
230 | "Mismatch between number of images and texts"
231 | )
232 |
233 | return images, page_texts
234 |
235 | except PDFPageCountError as e:
236 |                 raise PDFProcessingError(f"Failed to process PDF pages: {str(e)}") from e
237 | finally:
238 | # Clean up temporary file
239 | os.unlink(temp_path)
240 |
241 | except requests.exceptions.RequestException as e:
242 |             raise PDFProcessingError(f"Failed to download PDF: {str(e)}") from e
243 | except Exception as e:
244 |             raise PDFProcessingError(f"Failed to process PDF: {str(e)}") from e
245 |
246 | @torch.no_grad()
247 | def generate_embeddings(self, images: List[Any]) -> List[Tensor]:
248 | """
249 | Generate embeddings for a list of images.
250 |
251 | Args:
252 | images: List of PIL images to process
253 |
254 | Returns:
255 | List of tensor embeddings
256 |
257 | Raises:
258 | PDFProcessingError: If embedding generation fails
259 | """
260 | try:
261 | embeddings: List[Tensor] = []
262 | dataloader = DataLoader(
263 | images,
264 | batch_size=self.batch_size,
265 | shuffle=False,
266 |             collate_fn=self.processor.process_images,
267 | )
268 |
269 | logger.info(f"Generating embeddings using device: {self.device}")
270 | for batch in tqdm(dataloader, desc="Processing batches"):
271 | try:
272 | batch = {k: v.to(self.device) for k, v in batch.items()}
273 | batch_embeddings = self.model(**batch)
274 | embeddings.extend(list(torch.unbind(batch_embeddings.to("cpu"))))
275 |
276 | except RuntimeError as e:
277 | if any(err in str(e) for err in ["MPS", "CUDA", "device"]):
278 | # Try fallback to another device
279 | batch, new_device = self._handle_device_fallback(batch)
280 | batch_embeddings = self.model(**batch)
281 | embeddings.extend(
282 | list(torch.unbind(batch_embeddings.to("cpu")))
283 | )
284 | else:
285 | raise
286 |
287 | return embeddings
288 |
289 | except Exception as e:
290 | logger.error(f"Failed to generate embeddings: {str(e)}")
291 |             raise PDFProcessingError(f"Embedding generation failed: {str(e)}") from e
292 |
--------------------------------------------------------------------------------
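The device-selection and fallback code above encodes a fixed priority order: CUDA, then MPS on macOS, then CPU. A standalone sketch of that same check, handy for verifying what a given machine will pick before loading the model; everything here is standard PyTorch, nothing is assumed beyond the priority order shown in `_get_optimal_device`:

import platform

import torch

# Mirrors _get_optimal_device's priority order: CUDA -> MPS (macOS) -> CPU.
if torch.cuda.is_available():
    print("would select: cuda")
elif platform.system() == "Darwin" and torch.backends.mps.is_available():
    print("would select: mps")
else:
    print("would select: cpu")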
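In `generate_embeddings`, each forward pass returns one multi-vector embedding per page in the batch, and `torch.unbind` splits the batch dimension so the `embeddings` list accumulates one tensor per page. A toy illustration of that step; the patch count and embedding dimension below are made up for the example, since real shapes depend on the page and model:

import torch

# Stand-in for a ColQwen2 batch output: (batch, n_patches, dim).
batch_embeddings = torch.randn(4, 730, 128)  # illustrative shapes only

# Move to CPU, then unbind along dim 0 -> one (n_patches, dim) tensor per page.
per_page = list(torch.unbind(batch_embeddings.to("cpu")))

assert len(per_page) == 4
assert per_page[0].shape == (730, 128)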
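Finally, a minimal end-to-end usage sketch. The class definition sits earlier in this file, outside the excerpt above, so the class name and constructor arguments here are assumptions; the metadata shape (`url` and `title` keys) and the `PDFData` fields are taken straight from `process_pdf`:

# Hypothetical class name and constructor -- the actual definition appears
# earlier in this file; only process_pdf's contract is taken from above.
processor = PDFEmbeddingProcessor(batch_size=2)

pdf_metadata = [
    {
        # Placeholder URL: process_pdf downloads it, so point this at a real PDF.
        "url": "https://example.com/report.pdf",
        "title": "Example Report",
    }
]

pdf_data = processor.process_pdf(pdf_metadata)
for doc in pdf_data:
    print(f"{doc.title}: {len(doc.images)} pages, {len(doc.embeddings)} embeddings")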