├── .DS_Store ├── .gitignore ├── README.md ├── colpal ├── schemas │ └── pdf_page.sd ├── search │ └── query-profiles │ │ ├── default.xml │ │ └── types │ │ └── root.xml └── services.xml ├── data ├── paper01-1-2.pdf ├── paper01.pdf ├── paper02-3.pdf └── paper02.pdf ├── logs └── failed_documents │ └── failed_documents_1735757189.yaml ├── main.py ├── poetry.lock ├── pyproject.toml ├── src ├── interfaces │ └── interfaces.py ├── ocr_benchmark │ ├── __init__.py │ ├── base.py │ ├── engines │ │ ├── config.py │ │ └── easy_ocr.py │ └── utils │ │ ├── image_processing.py │ │ └── pdf_processing.py ├── pdf_embedding_decider │ ├── analyze.py │ ├── components │ │ ├── layout_analyzer.py │ │ ├── table_detector.py │ │ ├── text_density_analyzer.py │ │ └── visual_element_analyzer.py │ ├── datatypes.py │ └── interfaces.py ├── prompts │ └── llama_vision.py └── vespa │ ├── datatypes.py │ ├── exceptions.py │ ├── indexing │ ├── pdf_processor.py │ ├── prepare_feed.py │ └── run.py │ ├── model.py │ ├── retrieval │ └── run.py │ ├── setup.py │ ├── utils.py │ └── vespa_config.yaml ├── test.py ├── test ├── schemas │ └── pdf_page.sd ├── search │ └── query-profiles │ │ ├── default.xml │ │ └── types │ │ └── root.xml └── services.xml ├── test2 ├── schemas │ └── pdf_page.sd ├── search │ └── query-profiles │ │ ├── default.xml │ │ └── types │ │ └── root.xml └── services.xml ├── test3 ├── schemas │ └── pdf_page.sd ├── search │ └── query-profiles │ │ ├── default.xml │ │ └── types │ │ └── root.xml └── services.xml ├── test4 ├── schemas │ └── pdf_page.sd ├── search │ └── query-profiles │ │ ├── default.xml │ │ └── types │ │ └── root.xml └── services.xml ├── test_vespa_indexing.py ├── test_vespa_inference.py ├── vespa_inference.py └── vespa_vision_rag.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | /data 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. 
For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 170 | 171 | # IntelliJ IDEA 172 | .idea/ 173 | 174 | # Visual Studio Code 175 | .vscode/ 176 | 177 | # PyPI configuration file 178 | .pypirc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # evaluate-ocr-rag-systems -------------------------------------------------------------------------------- /colpal/schemas/pdf_page.sd: -------------------------------------------------------------------------------- 1 | schema pdf_page { 2 | document pdf_page { 3 | field id type string { 4 | indexing: summary | index 5 | match { 6 | word 7 | } 8 | } 9 | field url type string { 10 | indexing: summary | index 11 | } 12 | field title type string { 13 | indexing: summary | index 14 | index: enable-bm25 15 | match { 16 | text 17 | } 18 | } 19 | field page_number type int { 20 | indexing: summary | attribute 21 | } 22 | field image type raw { 23 | indexing: summary 24 | } 25 | field text type string { 26 | indexing: index 27 | index: enable-bm25 28 | match { 29 | text 30 | } 31 | } 32 | field embedding type tensor(patch{}, v[16]) { 33 | indexing: attribute | index 34 | attribute { 35 | distance-metric: hamming 36 | } 37 | index { 38 | hnsw { 39 | max-links-per-node: 32 40 | neighbors-to-explore-at-insert: 400 41 | } 42 | } 43 | } 44 | } 45 | fieldset default { 46 | fields: title, text 47 | } 48 | rank-profile default { 49 | inputs { 50 | query(qt) tensor(querytoken{}, v[128]) 51 | } 52 | function max_sim() { 53 | expression { 54 | sum( 55 | reduce( 56 | sum( 57 | query(qt) * unpack_bits(attribute(embedding)), v 58 | ), 59 | max, patch 60 | ), 61 | querytoken 62 | ) 63 | } 64 | } 65 | function bm25_score() { 66 | expression { 67 | bm25(title) + bm25(text) 68 | } 69 | } 70 | first-phase { 71 | expression { 72 | bm25_score 73 | } 74 | } 75 | second-phase { 76 | rerank-count: 100 77 | expression { 78 | max_sim 79 | } 80 | } 81 | } 82 | } -------------------------------------------------------------------------------- /colpal/search/query-profiles/default.xml: -------------------------------------------------------------------------------- 1 | <query-profile id="default" type="root"> 2 | </query-profile> -------------------------------------------------------------------------------- /colpal/search/query-profiles/types/root.xml: -------------------------------------------------------------------------------- 1 | <query-profile-type id="root"> 2 | </query-profile-type> -------------------------------------------------------------------------------- /colpal/services.xml: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="utf-8" ?> 2 | <services version="1.0"> 3 | <container id="colpal_container" version="1.0"> 4 | <search></search> 5 | <document-api></document-api> 6 | <document-processing></document-processing> 7 | </container> 8 | <content id="colpal_content" version="1.0"> 9 | <redundancy>1</redundancy> 10 | <documents> 11 | <document type="pdf_page" mode="index"></document> 12 | </documents> 13 | <nodes> 14 | <node distribution-key="0" hostalias="node1"></node> 15 | </nodes> 16 | </content> 17 | </services> -------------------------------------------------------------------------------- /data/paper01-1-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/data/paper01-1-2.pdf -------------------------------------------------------------------------------- /data/paper01.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/data/paper01.pdf -------------------------------------------------------------------------------- /data/paper02-3.pdf:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/data/paper02-3.pdf -------------------------------------------------------------------------------- /data/paper02.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/data/paper02.pdf -------------------------------------------------------------------------------- /logs/failed_documents/failed_documents_1735757189.yaml: -------------------------------------------------------------------------------- 1 | - error: 2 | message: 'No field ''images'' in the structure of type ''pdf_page'', which has 3 | the fields: [field ''id'' of type string, field ''url'' of type string, field 4 | ''title'' of type string, field ''page_number'' of type int, field ''image'' 5 | of type raw, field ''text'' of type string, field ''embedding'' of type tensor(patch{},v[16])]' 6 | pathId: /document/v1/pdf_page/pdf_page/docid/https%3A/static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf 7 | id: https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf 8 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/main.py -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ocr-rag-systems" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Vesa Alex"] 6 | readme = "README.md" 7 | 8 | [tool.poetry.dependencies] 9 | python = "^3.12" 10 | torch = "^2.5.1" 11 | surya-ocr = "^0.8.1" 12 | numpy = "1.26.4" 13 | pytesseract = "^0.3.13" 14 | easyocr = "^1.7.2" 15 | paddleocr = "^2.9.1" 16 | python-doctr = "^0.10.0" 17 | pdf2image = "^1.17.0" 18 | pandas = "^2.2.3" 19 | ollama = "^0.4.4" 20 | together = "^1.3.10" 21 | colpali-engine = "0.3.1" 22 | vidore-benchmark = "4.0.0" 23 | google-generativeai = "^0.8.3" 24 | pypdf = "5.0.1" 25 | pyvespa = "^0.51.0" 26 | vespacli = "^8.453.24" 27 | requests = "^2.32.3" 28 | ipython = "^8.31.0" 29 | pymupdf = "^1.25.1" 30 | 31 | 32 | [build-system] 33 | requires = ["poetry-core"] 34 | build-backend = "poetry.core.masonry.api" 35 | -------------------------------------------------------------------------------- /src/interfaces/interfaces.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | from pydantic import BaseModel 4 | 5 | 6 | class BasePromptTemplate(ABC, BaseModel): 7 | @abstractmethod 8 | def create_template(self, *args) -> str: 9 | pass 10 | -------------------------------------------------------------------------------- /src/ocr_benchmark/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvanguards/evaluate-ocr-rag-systems/d33f3978d5fd2e18e29abc3a1a23c93619cad397/src/ocr_benchmark/__init__.py -------------------------------------------------------------------------------- /src/ocr_benchmark/base.py: 
-------------------------------------------------------------------------------- 1 | """Base class for OCR engines following the Interface Segregation Principle.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Dict, List 5 | 6 | from PIL import Image 7 | 8 | 9 | class OCREngine(ABC): 10 | """Abstract base class for OCR engines.""" 11 | 12 | @property 13 | @abstractmethod 14 | def name(self) -> str: 15 | """Return the name of the OCR engine.""" 16 | pass 17 | 18 | @abstractmethod 19 | def initialize(self) -> None: 20 | """Initialize the OCR engine with required models and configurations.""" 21 | pass 22 | 23 | @abstractmethod 24 | def process_image(self, image: Image.Image) -> Dict[str, Any]: 25 | """ 26 | Process a single image and return the extracted text and metadata. 27 | 28 | Args: 29 | image: PIL Image object to process 30 | 31 | Returns: 32 | Dictionary containing: 33 | - text: extracted text 34 | - confidence: confidence scores if available 35 | - boxes: bounding boxes if available 36 | - tables: detected tables if available 37 | """ 38 | pass 39 | 40 | @abstractmethod 41 | def process_images(self, images: List[Image.Image]) -> List[Dict[str, Any]]: 42 | """ 43 | Process multiple images and return results for each. 44 | 45 | Args: 46 | images: List of PIL Image objects 47 | 48 | Returns: 49 | List of dictionaries containing results for each image 50 | """ 51 | pass 52 | 53 | def cleanup(self) -> None: 54 | """Cleanup resources. Override if needed.""" 55 | pass 56 | 57 | def __enter__(self): 58 | """Context manager entry.""" 59 | self.initialize() 60 | return self 61 | 62 | def __exit__(self, exc_type, exc_val, exc_tb): 63 | """Context manager exit.""" 64 | self.cleanup() 65 | -------------------------------------------------------------------------------- /src/ocr_benchmark/engines/config.py: -------------------------------------------------------------------------------- 1 | """Configuration classes for OCR engines.""" 2 | 3 | from dataclasses import dataclass, field 4 | from typing import List, Optional 5 | 6 | 7 | @dataclass 8 | class TesseractConfig: 9 | """Configuration for Tesseract OCR.""" 10 | 11 | tessdata_path: Optional[str] = None 12 | language: str = "eng" 13 | psm: int = 3 # Page segmentation mode 14 | oem: int = 3 # OCR Engine mode 15 | 16 | 17 | @dataclass 18 | class EasyOCRConfig: 19 | """Configuration for EasyOCR.""" 20 | 21 | languages: List[str] = field(default_factory=lambda: ["en"]) 22 | gpu: bool = False 23 | model_storage_directory: Optional[str] = None 24 | download_enabled: bool = True 25 | 26 | 27 | @dataclass 28 | class PaddleOCRConfig: 29 | """Configuration for PaddleOCR.""" 30 | 31 | use_angle_cls: bool = True 32 | lang: str = "en" 33 | use_gpu: bool = False 34 | show_log: bool = False 35 | 36 | 37 | @dataclass 38 | class DocTRConfig: 39 | """Configuration for DocTR.""" 40 | 41 | pretrained: bool = True 42 | assume_straight_pages: bool = True 43 | straighten_pages: bool = True 44 | -------------------------------------------------------------------------------- /src/ocr_benchmark/engines/easy_ocr.py: -------------------------------------------------------------------------------- 1 | """EasyOCR engine implementation.""" 2 | 3 | from typing import Any, Dict, List, Optional 4 | 5 | import easyocr 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from src.ocr_benchmark.base import OCREngine 10 | from src.ocr_benchmark.engines.config import EasyOCRConfig 11 | from src.ocr_benchmark.utils.image_processing import ensure_rgb 12 | 13 | 14 | class
EasyOCREngine(OCREngine): 15 | """EasyOCR engine implementation.""" 16 | 17 | def __init__(self, config: Optional[EasyOCRConfig] = None): 18 | """ 19 | Initialize EasyOCREngine. 20 | 21 | Args: 22 | config: Optional configuration for EasyOCR 23 | """ 24 | self._config = config or EasyOCRConfig() 25 | self._initialized = False 26 | self._reader = None 27 | 28 | @property 29 | def name(self) -> str: 30 | return "EasyOCR" 31 | 32 | def initialize(self) -> None: 33 | """Initialize EasyOCR with configuration.""" 34 | if not self._initialized: 35 | self._reader = easyocr.Reader( 36 | lang_list=self._config.languages, 37 | gpu=self._config.gpu, 38 | model_storage_directory=self._config.model_storage_directory, 39 | download_enabled=self._config.download_enabled, 40 | ) 41 | self._initialized = True 42 | 43 | def process_image(self, image: Image.Image) -> Dict[str, Any]: 44 | """ 45 | Process a single image using EasyOCR. 46 | 47 | Args: 48 | image: PIL Image to process 49 | 50 | Returns: 51 | Dictionary containing: 52 | - text: extracted text 53 | - confidence: mean confidence score 54 | - boxes: detected text boxes 55 | """ 56 | if not self._initialized: 57 | self.initialize() 58 | 59 | # Ensure image is in RGB format 60 | image = ensure_rgb(image) 61 | 62 | # Convert PIL Image to numpy array 63 | image_np = np.array(image) 64 | 65 | # Process image with EasyOCR 66 | results = self._reader.readtext(image_np) 67 | 68 | # Extract text and confidence scores 69 | boxes = [] 70 | full_text = [] 71 | confidences = [] 72 | 73 | for bbox, text, conf in results: 74 | full_text.append(text) 75 | confidences.append(conf) 76 | 77 | # Convert bbox points to (x1, y1, x2, y2) format 78 | x1 = min(point[0] for point in bbox) 79 | y1 = min(point[1] for point in bbox) 80 | x2 = max(point[0] for point in bbox) 81 | y2 = max(point[1] for point in bbox) 82 | 83 | boxes.append( 84 | { 85 | "text": text, 86 | "conf": conf, 87 | "box": (int(x1), int(y1), int(x2), int(y2)), 88 | } 89 | ) 90 | 91 | # Calculate mean confidence 92 | mean_confidence = sum(confidences) / len(confidences) if confidences else 0 93 | 94 | return { 95 | "text": " ".join(full_text), 96 | "confidence": mean_confidence, 97 | "boxes": boxes, 98 | } 99 | 100 | def process_images(self, images: List[Image.Image]) -> List[Dict[str, Any]]: 101 | """Process multiple images and return results for each.""" 102 | return [self.process_image(image) for image in images] 103 | 104 | 105 | if __name__ == "__main__": 106 | import os 107 | 108 | from PIL import Image, ImageDraw, ImageFont 109 | 110 | # Create a test image 111 | img = Image.new("RGB", (800, 200), color="white") 112 | d = ImageDraw.Draw(img) 113 | try: 114 | font = ImageFont.truetype("Arial.ttf", 60) 115 | except OSError: 116 | font = ImageFont.load_default() 117 | 118 | d.text((50, 50), "Hello, EasyOCR Test!", fill="black", font=font) 119 | 120 | # Save test image 121 | test_image_path = "test_easyocr.png" 122 | img.save(test_image_path) 123 | 124 | try: 125 | # Initialize engine 126 | engine = EasyOCREngine(EasyOCRConfig(languages=["en"])) 127 | 128 | # Process test image 129 | with Image.open(test_image_path) as test_img: 130 | result = engine.process_image(test_img) 131 | 132 | # Print results 133 | print("\nEasyOCR Test Results:") 134 | print(f"Detected Text: {result['text']}") 135 | print(f"Confidence: {result['confidence']:.2f}") 136 | print("Detected Boxes:", len(result["boxes"])) 137 | 138 | for box in result["boxes"]: 139 | print(f"- Text: {box['text']}, Confidence: {box['conf']:.2f}") 140 |
print(f" Box coordinates: {box['box']}") 141 | 142 | except Exception as e: 143 | print(f"Error during test: {str(e)}") 144 | 145 | finally: 146 | # Cleanup 147 | if os.path.exists(test_image_path): 148 | os.remove(test_image_path) 149 | -------------------------------------------------------------------------------- /src/ocr_benchmark/utils/image_processing.py: -------------------------------------------------------------------------------- 1 | """Utility functions for image processing and manipulation.""" 2 | 3 | import logging 4 | from dataclasses import dataclass 5 | from typing import Optional 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image, ImageEnhance, ImageOps 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @dataclass 15 | class ImagePreprocessingConfig: 16 | """Configuration for image preprocessing.""" 17 | 18 | resize_width: Optional[int] = None 19 | resize_height: Optional[int] = None 20 | contrast_factor: float = 1.0 21 | brightness_factor: float = 1.0 22 | sharpen_factor: float = 1.0 23 | denoise: bool = False 24 | deskew: bool = False 25 | binarize: bool = False 26 | 27 | 28 | def ensure_rgb(image: Image.Image) -> Image.Image: 29 | """ 30 | Ensure image is in RGB format. 31 | 32 | Args: 33 | image: Input PIL Image 34 | 35 | Returns: 36 | PIL Image in RGB format 37 | """ 38 | if image.mode != "RGB": 39 | return image.convert("RGB") 40 | return image 41 | 42 | 43 | def resize_image( 44 | image: Image.Image, 45 | width: Optional[int] = None, 46 | height: Optional[int] = None, 47 | maintain_aspect: bool = True, 48 | ) -> Image.Image: 49 | """ 50 | Resize image while optionally maintaining aspect ratio. 51 | 52 | Args: 53 | image: Input PIL Image 54 | width: Target width in pixels 55 | height: Target height in pixels 56 | maintain_aspect: Whether to maintain aspect ratio 57 | 58 | Returns: 59 | Resized PIL Image 60 | """ 61 | if not width and not height: 62 | return image 63 | 64 | orig_width, orig_height = image.size 65 | 66 | if maintain_aspect: 67 | if width and height: 68 | # Use the dimension that results in a smaller image 69 | width_ratio = width / orig_width 70 | height_ratio = height / orig_height 71 | ratio = min(width_ratio, height_ratio) 72 | new_width = int(orig_width * ratio) 73 | new_height = int(orig_height * ratio) 74 | elif width: 75 | ratio = width / orig_width 76 | new_width = width 77 | new_height = int(orig_height * ratio) 78 | else: 79 | ratio = height / orig_height 80 | new_width = int(orig_width * ratio) 81 | new_height = height 82 | else: 83 | new_width = width if width else orig_width 84 | new_height = height if height else orig_height 85 | 86 | return image.resize((new_width, new_height), Image.Resampling.LANCZOS) 87 | 88 | 89 | def enhance_image( 90 | image: Image.Image, 91 | contrast: float = 1.0, 92 | brightness: float = 1.0, 93 | sharpness: float = 1.0, 94 | ) -> Image.Image: 95 | """ 96 | Enhance image using contrast, brightness, and sharpness adjustments. 
97 | 98 | Args: 99 | image: Input PIL Image 100 | contrast: Contrast enhancement factor 101 | brightness: Brightness enhancement factor 102 | sharpness: Sharpness enhancement factor 103 | 104 | Returns: 105 | Enhanced PIL Image 106 | """ 107 | if contrast != 1.0: 108 | image = ImageEnhance.Contrast(image).enhance(contrast) 109 | if brightness != 1.0: 110 | image = ImageEnhance.Brightness(image).enhance(brightness) 111 | if sharpness != 1.0: 112 | image = ImageEnhance.Sharpness(image).enhance(sharpness) 113 | return image 114 | 115 | 116 | def denoise_image(image: Image.Image) -> Image.Image: 117 | """ 118 | Apply denoising to image. 119 | 120 | Args: 121 | image: Input PIL Image 122 | 123 | Returns: 124 | Denoised PIL Image 125 | """ 126 | # Convert to numpy array for OpenCV processing 127 | img_array = np.array(image) 128 | denoised = cv2.fastNlMeansDenoisingColored(img_array, None, 10, 10, 7, 21) 129 | return Image.fromarray(denoised) 130 | 131 | 132 | def deskew_image(image: Image.Image) -> Image.Image: 133 | """ 134 | Deskew image by detecting and correcting rotation. 135 | 136 | Args: 137 | image: Input PIL Image 138 | 139 | Returns: 140 | Deskewed PIL Image 141 | """ 142 | # Convert to numpy array and grayscale 143 | img_array = np.array(image) 144 | gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY) 145 | 146 | # Detect edges 147 | edges = cv2.Canny(gray, 50, 150, apertureSize=3) 148 | 149 | # Detect lines using Hough transform 150 | lines = cv2.HoughLines(edges, 1, np.pi / 180, 100) 151 | 152 | if lines is not None: 153 | # Calculate the most common angle 154 | angles = [] 155 | for rho, theta in lines[:, 0]: 156 | angle = np.degrees(theta) 157 | if angle < 45: 158 | angles.append(angle) 159 | elif angle > 135: 160 | angles.append(angle - 180) 161 | 162 | if angles: 163 | median_angle = np.median(angles) 164 | if abs(median_angle) > 0.5: # Only rotate if angle is significant 165 | return image.rotate( 166 | median_angle, expand=True, fillcolor=(255, 255, 255) 167 | ) 168 | 169 | return image 170 | 171 | 172 | def binarize_image(image: Image.Image) -> Image.Image: 173 | """ 174 | Convert image to binary (black and white) using adaptive thresholding. 175 | 176 | Args: 177 | image: Input PIL Image 178 | 179 | Returns: 180 | Binarized PIL Image 181 | """ 182 | # Convert to grayscale 183 | gray = ImageOps.grayscale(image) 184 | 185 | # Convert to numpy array for OpenCV processing 186 | img_array = np.array(gray) 187 | 188 | # Apply adaptive thresholding 189 | binary = cv2.adaptiveThreshold( 190 | img_array, 191 | 255, 192 | cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 193 | cv2.THRESH_BINARY, 194 | 11, # Block size 195 | 2, # Constant subtracted from mean 196 | ) 197 | 198 | return Image.fromarray(binary) 199 | 200 | 201 | def preprocess_image( 202 | image: Image.Image, config: ImagePreprocessingConfig 203 | ) -> Image.Image: 204 | """ 205 | Apply a series of preprocessing steps to an image. 
206 | 207 | Args: 208 | image: Input PIL Image 209 | config: Preprocessing configuration 210 | 211 | Returns: 212 | Preprocessed PIL Image 213 | """ 214 | try: 215 | # Ensure RGB format 216 | image = ensure_rgb(image) 217 | 218 | # Resize if needed 219 | if config.resize_width or config.resize_height: 220 | image = resize_image(image, config.resize_width, config.resize_height) 221 | 222 | # Apply enhancements 223 | if any( 224 | factor != 1.0 225 | for factor in [ 226 | config.contrast_factor, 227 | config.brightness_factor, 228 | config.sharpen_factor, 229 | ] 230 | ): 231 | image = enhance_image( 232 | image, 233 | config.contrast_factor, 234 | config.brightness_factor, 235 | config.sharpen_factor, 236 | ) 237 | 238 | # Apply denoising 239 | if config.denoise: 240 | image = denoise_image(image) 241 | 242 | # Apply deskewing 243 | if config.deskew: 244 | image = deskew_image(image) 245 | 246 | # Apply binarization 247 | if config.binarize: 248 | image = binarize_image(image) 249 | 250 | return image 251 | 252 | except Exception as e: 253 | logger.error(f"Error preprocessing image: {str(e)}") 254 | raise 255 | -------------------------------------------------------------------------------- /src/ocr_benchmark/utils/pdf_processing.py: -------------------------------------------------------------------------------- 1 | """Utility functions for PDF processing and conversion.""" 2 | 3 | import logging 4 | import tempfile 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | from typing import List, Optional, Tuple 8 | 9 | import pdf2image 10 | from pdf2image.exceptions import ( 11 | PDFPageCountError, 12 | PDFSyntaxError, 13 | ) 14 | from PIL import Image 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | @dataclass 20 | class PDFConversionSettings: 21 | """Settings for PDF to image conversion.""" 22 | 23 | dpi: int = 300 24 | grayscale: bool = False 25 | use_cropbox: bool = False 26 | strict: bool = False 27 | thread_count: int = 4 28 | raise_on_error: bool = True 29 | 30 | 31 | class PDFProcessingError(Exception): 32 | """Base exception for PDF processing errors.""" 33 | 34 | pass 35 | 36 | 37 | def validate_pdf(pdf_path: Path) -> Tuple[bool, str]: 38 | """ 39 | Validate PDF file existence and readability. 40 | 41 | Args: 42 | pdf_path: Path to PDF file 43 | 44 | Returns: 45 | Tuple of (is_valid, message) 46 | """ 47 | try: 48 | if not pdf_path.exists(): 49 | return False, f"File not found: {pdf_path}" 50 | 51 | if pdf_path.stat().st_size == 0: 52 | return False, f"File is empty: {pdf_path}" 53 | 54 | # Try converting first page to validate format 55 | pdf2image.convert_from_path(str(pdf_path), first_page=1, last_page=1) 56 | return True, "Valid PDF file" 57 | 58 | except PDFSyntaxError: 59 | return False, f"Invalid PDF format or corrupted file: {pdf_path}" 60 | except PDFPageCountError: 61 | return False, f"Error determining page count: {pdf_path}" 62 | except Exception as e: 63 | return False, f"Error validating PDF: {str(e)}" 64 | 65 | 66 | def pdf_to_images( 67 | pdf_path: Path, settings: Optional[PDFConversionSettings] = None 68 | ) -> List[Image.Image]: 69 | """ 70 | Convert PDF file to a list of PIL Images. 
71 | 72 | Args: 73 | pdf_path: Path to PDF file 74 | settings: Optional conversion settings 75 | 76 | Returns: 77 | List of PIL Images, one per page 78 | 79 | Raises: 80 | PDFProcessingError: If conversion fails and raise_on_error is True 81 | """ 82 | if settings is None: 83 | settings = PDFConversionSettings() 84 | 85 | try: 86 | # Validate PDF first 87 | is_valid, message = validate_pdf(pdf_path) 88 | if not is_valid and settings.raise_on_error: 89 | raise PDFProcessingError(message) 90 | 91 | # Create temporary directory for conversion 92 | with tempfile.TemporaryDirectory() as temp_dir: 93 | try: 94 | # Convert PDF to images 95 | images = pdf2image.convert_from_path( 96 | str(pdf_path), 97 | dpi=settings.dpi, 98 | grayscale=settings.grayscale, 99 | use_cropbox=settings.use_cropbox, 100 | strict=settings.strict, 101 | thread_count=settings.thread_count, 102 | output_folder=temp_dir, 103 | ) 104 | 105 | logger.info(f"Successfully converted PDF with {len(images)} pages") 106 | return images 107 | 108 | except Exception as e: 109 | error_msg = f"Error converting PDF to images: {str(e)}" 110 | if settings.raise_on_error: 111 | raise PDFProcessingError(error_msg) 112 | logger.error(error_msg) 113 | return [] 114 | 115 | except Exception as e: 116 | error_msg = f"PDF processing error: {str(e)}" 117 | if settings.raise_on_error: 118 | raise PDFProcessingError(error_msg) 119 | logger.error(error_msg) 120 | return [] 121 | 122 | 123 | def estimate_pdf_size(pdf_path: Path) -> Tuple[int, str]: 124 | """ 125 | Estimate memory requirements for PDF conversion. 126 | 127 | Args: 128 | pdf_path: Path to PDF file 129 | 130 | Returns: 131 | Tuple of (size_in_bytes, human_readable_size) 132 | """ 133 | try: 134 | # Convert first page to estimate size per page 135 | sample = pdf2image.convert_from_path(str(pdf_path), first_page=1, last_page=1)[ 136 | 0 137 | ] 138 | width, height = sample.size 139 | channels = len(sample.getbands()) 140 | 141 | # Get total page count 142 | info = pdf2image.pdfinfo_from_path(str(pdf_path)) 143 | total_pages = info["Pages"] 144 | 145 | # Estimate total memory requirement 146 | bytes_per_page = width * height * channels 147 | total_bytes = bytes_per_page * total_pages 148 | 149 | # Convert to human readable, keeping the raw byte count for the first tuple element 150 | size = float(total_bytes) 151 | for unit in ["B", "KB", "MB", "GB"]: 152 | if size < 1024: 153 | return total_bytes, f"{size:.1f}{unit}" 154 | size /= 1024 155 | 156 | return total_bytes, f"{size:.1f}TB" 157 | 158 | except Exception as e: 159 | logger.error(f"Error estimating PDF size: {str(e)}") 160 | return 0, "Unknown" -------------------------------------------------------------------------------- /src/pdf_embedding_decider/analyze.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict 3 | 4 | from src.pdf_embedding_decider.components.layout_analyzer import LayoutAnalyzer 5 | from src.pdf_embedding_decider.components.table_detector import TableDetector 6 | from src.pdf_embedding_decider.components.text_density_analyzer import ( 7 | TextDensityAnalyzer, 8 | ) 9 | from src.pdf_embedding_decider.components.visual_element_analyzer import ( 10 | VisualElementAnalyzer, 11 | ) 12 | from src.pdf_embedding_decider.datatypes import ( 13 | AnalysisConfig, 14 | AnalysisResult, 15 | EmbeddingType, 16 | ) 17 | 18 | # Configure logging 19 | logging.basicConfig(level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class PDFEmbeddingDecider: 24 | def __init__(self, config:
AnalysisConfig): 25 | self.config = config 26 | self.analyzers = { 27 | "visual": VisualElementAnalyzer(config), 28 | "density": TextDensityAnalyzer(config), 29 | "layout": LayoutAnalyzer(), 30 | "table": TableDetector(config), 31 | } 32 | 33 | def analyze(self, pdf_path: str) -> Dict[str, AnalysisResult]: 34 | """Run all analyses and return detailed results""" 35 | return { 36 | name: analyzer.analyze(pdf_path) 37 | for name, analyzer in self.analyzers.items() 38 | } 39 | 40 | def decide(self, pdf_path: str) -> EmbeddingType: 41 | """Determine the appropriate embedding type based on PDF analysis""" 42 | try: 43 | results = self.analyze(pdf_path) 44 | 45 | # Log detailed analysis results 46 | logger.info("Analysis Results:") 47 | for analyzer_name, result in results.items(): 48 | logger.info(f"{analyzer_name}: {result}") 49 | 50 | # Enhanced decision logic incorporating table analysis 51 | table_score = results["table"].score 52 | 53 | if table_score > 0: 54 | table_influence = min(1.0, table_score * self.config.table_weight) 55 | adjusted_density_threshold = self.config.text_density_threshold * ( 56 | 1 - table_influence 57 | ) 58 | else: 59 | adjusted_density_threshold = self.config.text_density_threshold 60 | 61 | if ( 62 | results["visual"].score > self.config.visual_threshold 63 | or results["density"].score < adjusted_density_threshold 64 | or ( 65 | results["layout"].score > self.config.layout_threshold 66 | and table_score == 0 67 | ) 68 | ): # Only consider complex layout if not tabular 69 | return EmbeddingType.COLPALI 70 | 71 | return EmbeddingType.TRADITIONAL 72 | 73 | except Exception as e: 74 | logger.error(f"Error deciding embedding type: {str(e)}") 75 | raise 76 | 77 | 78 | if __name__ == "__main__": 79 | config = AnalysisConfig( 80 | visual_threshold=15, 81 | text_density_threshold=0.25, 82 | layout_threshold=100, 83 | min_image_size=1000, 84 | table_row_threshold=5, 85 | table_weight=0.3, 86 | ) 87 | 88 | # Initialize decider 89 | decider = PDFEmbeddingDecider(config) 90 | 91 | # Example usage 92 | pdf_path = "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/data/aiminded-extras-octomrbie-decembrie-2023.pdf" 93 | # pdf_path = "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/data/paper01-1-2.pdf" 94 | try: 95 | # Get detailed analysis 96 | analysis_results = decider.analyze(pdf_path) 97 | 98 | # Get final decision 99 | embedding_type = decider.decide(pdf_path) 100 | 101 | logger.info(f"Recommended embedding type: {embedding_type.value}") 102 | logger.info("Detailed analysis results:") 103 | for analyzer_name, result in analysis_results.items(): 104 | logger.info(f"{analyzer_name}: {result}") 105 | 106 | except Exception as e: 107 | logger.error(f"Error processing PDF: {str(e)}") -------------------------------------------------------------------------------- /src/pdf_embedding_decider/components/layout_analyzer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Any, Dict 3 | 4 | import fitz 5 | 6 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer 7 | from src.pdf_embedding_decider.datatypes import AnalysisResult 8 | 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class LayoutAnalyzer(PDFAnalyzer): 14 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult: 15 | total_complexity = 0 16 | page_layouts = [] 17 | 18 | for page_num, page in enumerate(doc): 19 |
layout_info = self._analyze_page_layout(page) 20 | complexity = layout_info["block_count"] + len(layout_info["alignments"]) 21 | total_complexity += complexity 22 | 23 | page_layouts.append( 24 | {"page": page_num + 1, **layout_info, "complexity_score": complexity} 25 | ) 26 | 27 | return AnalysisResult( 28 | score=total_complexity, 29 | details={ 30 | "total_complexity": total_complexity, 31 | "page_layouts": page_layouts, 32 | }, 33 | ) 34 | 35 | def _analyze_page_layout(self, page: fitz.Page) -> Dict[str, Any]: 36 | text_blocks = page.get_text("blocks") 37 | alignments = set() 38 | 39 | for block in text_blocks: 40 | x0, y0, x1, y1, *_ = block 41 | if x0 < page.rect.width * 0.3: 42 | alignments.add("left") 43 | elif x1 > page.rect.width * 0.7: 44 | alignments.add("right") 45 | else: 46 | alignments.add("center") 47 | 48 | return {"block_count": len(text_blocks), "alignments": list(alignments)} -------------------------------------------------------------------------------- /src/pdf_embedding_decider/components/table_detector.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import fitz 4 | 5 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer 6 | from src.pdf_embedding_decider.datatypes import AnalysisConfig, AnalysisResult 7 | 8 | 9 | class TableDetector(PDFAnalyzer): 10 | """Analyzes the presence and structure of tables in the document""" 11 | 12 | def __init__(self, config: AnalysisConfig): 13 | self.config = config 14 | 15 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult: 16 | total_tables = 0 17 | table_details = [] 18 | 19 | for page_num, page in enumerate(doc): 20 | tables = self._detect_tables(page) 21 | total_tables += len(tables) 22 | table_details.extend( 23 | [ 24 | { 25 | "page": page_num + 1, 26 | "rows": table["rows"], 27 | "columns": table["columns"], 28 | "area": table["area"], 29 | } 30 | for table in tables 31 | ] 32 | ) 33 | 34 | return AnalysisResult( 35 | score=total_tables, 36 | details={"total_tables": total_tables, "table_details": table_details}, 37 | ) 38 | 39 | def _detect_tables(self, page: fitz.Page) -> List[Dict[str, Any]]: 40 | """Detect tables based on text block alignment and spacing""" 41 | blocks = page.get_text("blocks") 42 | tables = [] 43 | 44 | # Sort blocks by vertical position 45 | sorted_blocks = sorted(blocks, key=lambda b: (b[1], b[0])) # Sort by y, then x 46 | 47 | current_table = {"rows": [], "y_positions": set()} 48 | 49 | for block in sorted_blocks: 50 | x0, y0, x1, y1, text, *_ = block 51 | 52 | # Check if block is part of current table structure 53 | if current_table["rows"]: 54 | last_y = max(current_table["y_positions"]) 55 | y_gap = y0 - last_y 56 | 57 | if y_gap > 20: # New table or not part of table 58 | if len(current_table["rows"]) >= self.config.table_row_threshold: 59 | tables.append(self._finalize_table(current_table)) 60 | current_table = {"rows": [], "y_positions": set()} 61 | 62 | current_table["rows"].append({"text": text, "bbox": (x0, y0, x1, y1)}) 63 | current_table["y_positions"].add(y0) 64 | 65 | # Check last table 66 | if len(current_table["rows"]) >= self.config.table_row_threshold: 67 | tables.append(self._finalize_table(current_table)) 68 | 69 | return tables 70 | 71 | def _finalize_table(self, table_data: Dict) -> Dict[str, Any]: 72 | """Calculate table metrics""" 73 | rows = table_data["rows"] 74 | x_positions = [] 75 | for row in rows: 76 | x_positions.extend([row["bbox"][0], row["bbox"][2]]) 77 | 78 |
# Estimate columns by analyzing x-position clusters 79 | x_clusters = self._cluster_positions(x_positions) 80 | 81 | return { 82 | "rows": len(rows), 83 | "columns": len(x_clusters) 84 | // 2, # Divide by 2 as we counted start/end positions 85 | "area": self._calculate_table_area(rows), 86 | } 87 | 88 | @staticmethod 89 | def _cluster_positions( 90 | positions: List[float], threshold: float = 10 91 | ) -> List[float]: 92 | """Cluster similar x-positions together""" 93 | positions = sorted(positions) 94 | clusters = [] 95 | current_cluster = [positions[0]] 96 | 97 | for pos in positions[1:]: 98 | if pos - current_cluster[-1] <= threshold: 99 | current_cluster.append(pos) 100 | else: 101 | clusters.append(sum(current_cluster) / len(current_cluster)) 102 | current_cluster = [pos] 103 | 104 | if current_cluster: 105 | clusters.append(sum(current_cluster) / len(current_cluster)) 106 | 107 | return clusters 108 | 109 | @staticmethod 110 | def _calculate_table_area(rows: List[Dict]) -> float: 111 | """Calculate the total area occupied by the table""" 112 | if not rows: 113 | return 0 114 | 115 | x_min = min(row["bbox"][0] for row in rows) 116 | x_max = max(row["bbox"][2] for row in rows) 117 | y_min = min(row["bbox"][1] for row in rows) 118 | y_max = max(row["bbox"][3] for row in rows) 119 | 120 | return (x_max - x_min) * (y_max - y_min) -------------------------------------------------------------------------------- /src/pdf_embedding_decider/components/text_density_analyzer.py: -------------------------------------------------------------------------------- 1 | import fitz 2 | 3 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer 4 | from src.pdf_embedding_decider.datatypes import AnalysisConfig, AnalysisResult 5 | 6 | 7 | class TextDensityAnalyzer(PDFAnalyzer): 8 | def __init__(self, config: AnalysisConfig): 9 | self.config = config 10 | 11 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult: 12 | total_density = 0 13 | page_densities = [] 14 | 15 | for page_num, page in enumerate(doc): 16 | text_area = self._calculate_text_area(page) 17 | page_area = page.rect.width * page.rect.height 18 | density = text_area / page_area if page_area > 0 else 0 19 | 20 | page_densities.append( 21 | { 22 | "page": page_num + 1, 23 | "density": density, 24 | "text_area": text_area, 25 | "page_area": page_area, 26 | } 27 | ) 28 | total_density += density 29 | 30 | avg_density = total_density / len(doc) if len(doc) > 0 else 0 31 | return AnalysisResult( 32 | score=avg_density, 33 | details={"average_density": avg_density, "page_densities": page_densities}, 34 | ) 35 | 36 | @staticmethod 37 | def _calculate_text_area(page: fitz.Page) -> float: 38 | text_area = 0 39 | for block in page.get_text("blocks"): 40 | x0, y0, x1, y1, *_ = block 41 | text_area += (x1 - x0) * (y1 - y0) 42 | return text_area 43 | -------------------------------------------------------------------------------- /src/pdf_embedding_decider/components/visual_element_analyzer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import fitz 3 | 4 | from src.pdf_embedding_decider.interfaces import PDFAnalyzer 5 | from src.pdf_embedding_decider.datatypes import AnalysisConfig, AnalysisResult 6 | 7 | 8 | class VisualElementAnalyzer(PDFAnalyzer): 9 | def __init__(self, config: AnalysisConfig): 10 | self.config = config 11 | 12 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult: 13 | total_visual_elements = 0 14 | image_details = [] 15 | drawing_details = [] 16 | for
page_num, page in enumerate(doc): 17 | # Analyze images 18 | images = page.get_images(full=True) 19 | filtered_images = [ 20 | img for img in images if self._is_significant_image(page, img) 21 | ] 22 | image_details.extend( 23 | [ 24 | {"page": page_num + 1, "size": self._get_image_size(page, img)} 25 | for img in filtered_images 26 | ] 27 | ) 28 | 29 | # Analyze vector graphics 30 | drawings = page.get_drawings() 31 | significant_drawings = self._filter_significant_drawings(drawings) 32 | drawing_details.extend( 33 | [ 34 | {"page": page_num + 1, "complexity": len(draw["items"])} 35 | for draw in significant_drawings 36 | ] 37 | ) 38 | 39 | total_visual_elements += len(filtered_images) + len(significant_drawings) 40 | 41 | return AnalysisResult( 42 | score=total_visual_elements, 43 | details={ 44 | "total_elements": total_visual_elements, 45 | "images": image_details, 46 | "drawings": drawing_details, 47 | }, 48 | ) 49 | 50 | def _is_significant_image(self, page: fitz.Page, image: tuple) -> bool: 51 | """Filter out small or insignificant images""" 52 | xref = image[0] 53 | pix = fitz.Pixmap(page.parent, xref) 54 | area = pix.width * pix.height 55 | return area >= self.config.min_image_size 56 | 57 | def _filter_significant_drawings(self, drawings: List[dict]) -> List[dict]: 58 | """Filter out simple decorative elements""" 59 | return [d for d in drawings if len(d["items"]) > 2] 60 | 61 | @staticmethod 62 | def _get_image_size(page: fitz.Page, image: tuple) -> dict: 63 | xref = image[0] 64 | pix = fitz.Pixmap(page.parent, xref) 65 | return {"width": pix.width, "height": pix.height} 66 | -------------------------------------------------------------------------------- /src/pdf_embedding_decider/datatypes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import Enum 3 | from typing import Any, Dict 4 | 5 | 6 | class EmbeddingType(Enum): 7 | COLPALI = "ColPali" 8 | TRADITIONAL = "Traditional" 9 | 10 | 11 | @dataclass 12 | class AnalysisConfig: 13 | """Configuration parameters for PDF analysis""" 14 | 15 | visual_threshold: int = 15 16 | text_density_threshold: float = 0.25 17 | layout_threshold: int = 100 18 | min_image_size: int = 1000 19 | table_row_threshold: int = 5 20 | table_weight: float = 0.3 21 | 22 | 23 | @dataclass 24 | class AnalysisResult: 25 | """Structured container for analysis results""" 26 | 27 | score: float 28 | details: Dict[str, Any] 29 | confidence: float = 1.0 30 | -------------------------------------------------------------------------------- /src/pdf_embedding_decider/interfaces.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import fitz 4 | 5 | from src.pdf_embedding_decider.datatypes import AnalysisResult 6 | 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class PDFAnalyzer: 12 | """Base class for PDF analysis with error handling and resource management""" 13 | 14 | def analyze(self, pdf_path: str) -> AnalysisResult: 15 | try: 16 | doc = fitz.open(pdf_path) 17 | result = self._analyze_document(doc) 18 | return result 19 | except FileNotFoundError: 20 | logger.error(f"PDF file not found: {pdf_path}") 21 | raise 22 | except Exception as e: 23 | logger.error(f"Error analyzing PDF: {str(e)}") 24 | raise 25 | finally: 26 | if "doc" in locals(): 27 | doc.close() 28 | 29 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult: 30 | raise NotImplementedError 31 | 
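32 | 33 | # Hypothetical usage sketch (illustrative only, not part of the repository): a concrete 34 | # analyzer only needs to override _analyze_document(); analyze() takes care of opening 35 | # the file, error logging, and closing the document. PageCountAnalyzer and the sample 36 | # path below are assumptions for demonstration. 37 | class PageCountAnalyzer(PDFAnalyzer): 38 | def _analyze_document(self, doc: fitz.Document) -> AnalysisResult: 39 | # fitz.Document supports len(), which returns the number of pages 40 | return AnalysisResult(score=float(len(doc)), details={"page_count": len(doc)}) 41 | 42 | 43 | if __name__ == "__main__": 44 | result = PageCountAnalyzer().analyze("data/paper01.pdf") 45 | logger.info(f"page_count={result.details['page_count']}")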
-------------------------------------------------------------------------------- /src/prompts/llama_vision.py: -------------------------------------------------------------------------------- 1 | from src.interfaces.interfaces import BasePromptTemplate 2 | 3 | 4 | class AnalyzePdfImage(BasePromptTemplate): 5 | prompt: str = """You are an OCR assistant specialized in analyzing PDF images. Your tasks are: 6 | 7 | 1. **Full Text Recognition:** Extract all visible text from the page, including equations, special characters, and inline formatting. Ensure accuracy in representing symbols, numbers, and technical terminology. 8 | 2. **Structure and Formatting Preservation:** Recreate the original structure and formatting of the text, including: 9 | - Headings and subheadings (use appropriate markdown syntax for headers). 10 | - Paragraphs (preserve line breaks and text alignment). 11 | - Lists (use bullet points or numbered lists as necessary). 12 | - Inline formatting (e.g., bold, italics, subscript, superscript). 13 | 3. **Table Extraction:** Identify and extract any tables from the page, preserving their structure and content. Include: 14 | - Table headers, rows, and columns. 15 | - Ensure proper alignment and representation using markdown table formatting. 16 | 4. **Figure and Graph Recognition:** Detect any figures, graphs, or charts on the page and provide the following: 17 | - A detailed description of each figure or graph. 18 | - Key elements such as titles, axes labels, data points, and trends. 19 | - Include a markdown-formatted list with clear annotations. 20 | 5. **Metadata Extraction:** Extract and include metadata such as: 21 | - Document title, authors, publication year, and source information (if available). 22 | - Licensing and attribution details. 23 | 6. **Additional Context:** Capture any footnotes, captions, or side notes present on the page. Include these in a separate "Notes" section to preserve supplementary information. 24 | 25 | **Output Format:** 26 | Return the extracted content in the following markdown format: 27 | 28 | - Use `#`, `##`, or `###` for headings to match the original hierarchy. 29 | - Represent lists with `-` or `1.` for bullet points and numbered lists. 30 | - Use markdown table formatting to represent tables, ensuring alignment and clarity. 31 | - Provide descriptions for figures and graphs as a markdown-formatted list. 32 | - Add a "Key Insights" section summarizing critical findings or observations. 33 | - Include a "Notes" section for supplementary information such as footnotes or captions. 34 | 35 | **Example Output:** 36 | ``` 37 | # Document Title 38 | 39 | ## Metadata 40 | - **Title:** A Survey on Image Data Augmentation for Deep Learning 41 | - **Authors:** Connor Shorten and Taghi M. Khoshgoftaar 42 | - **Year:** 2019 43 | - **License:** Creative Commons Attribution 4.0 International License 44 | 45 | ## Key Insights 46 | - Data Augmentation improves model robustness and generalization, addressing overfitting in limited datasets. 47 | - Techniques include geometric transformations, GAN-based augmentations, and meta-learning. 48 | 49 | ## Abstract 50 | Deep convolutional neural networks have performed remarkably well on many Computer Vision tasks. However, these networks rely heavily on big data to avoid overfitting. Data Augmentation encompasses techniques such as geometric transformations, color space augmentations, and adversarial training. 
This paper outlines promising developments in these areas and their impact on deep learning model performance. 51 | 52 | ## Introduction 53 | Deep learning models excel in discriminative tasks like image classification and segmentation, leveraging architectures like AlexNet and ResNet. This paper focuses on augmenting data to expand training datasets, improving performance and generalization. 54 | 55 | ### Methodology 56 | - **Geometric Transformations:** Rotation, scaling, translation. 57 | - **Color Space Augmentation:** RGB to HSV conversion, random jittering. 58 | - **Kernel Filters:** Gaussian blur, median filter. 59 | - **Mixing Images:** Random erasing, GAN-generated augmentations. 60 | 61 | ### Results 62 | The proposed framework was evaluated on datasets like CIFAR-10 and SVHN, outperforming state-of-the-art augmentation techniques in accuracy and robustness. 63 | 64 | ### Graph Descriptions 65 | 1. *Figure 1:* Validation vs. training error graph illustrating overfitting (left) and desired generalization (right). The x-axis represents training epochs, and the y-axis represents error rates. 66 | 67 | ## Notes 68 | - *Figure 1:* Caption reads "Validation vs. training error over epochs." 69 | - All data was sourced from publicly available datasets. 70 | ``` 71 | 72 | **Instructions:** 73 | - Ensure high fidelity in text and structural representation. 74 | - Capture all elements on the page, including supplementary information. 75 | - Use markdown for clarity and consistency in the output. 76 | 77 | """ 78 | 79 | def create_template( 80 | self, 81 | ) -> str: 82 | return self.prompt.format() 83 | -------------------------------------------------------------------------------- /src/vespa/datatypes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, List, Optional 3 | 4 | from torch import Tensor 5 | 6 | 7 | @dataclass 8 | class PDFInput: 9 | """Data class for PDF input information.""" 10 | 11 | title: str 12 | url: str 13 | 14 | 15 | @dataclass 16 | class PDFData: 17 | """Data class to store processed PDF information.""" 18 | 19 | url: str 20 | title: str 21 | images: List[Any] # PIL.Image type 22 | texts: List[str] 23 | embeddings: List[Tensor] 24 | 25 | 26 | @dataclass 27 | class PDFPage: 28 | id: str 29 | url: str 30 | title: str 31 | page_number: int 32 | image: str 33 | text: str 34 | embedding: Dict[int, str] 35 | 36 | 37 | @dataclass 38 | class VespaSchemaConfig: 39 | """Configuration for Vespa schema settings.""" 40 | 41 | max_query_terms: int = 64 42 | hnsw_max_links: int = 32 43 | hnsw_neighbors: int = 400 44 | rerank_count: int = 10 45 | tensor_dimensions: int = 16 46 | 47 | @classmethod 48 | def from_dict(cls, config_dict: Dict[str, Any]) -> "VespaSchemaConfig": 49 | """Create schema config from dictionary with default values.""" 50 | return cls( 51 | max_query_terms=config_dict.get("max_query_terms", 64), 52 | hnsw_max_links=config_dict.get("hnsw_max_links", 32), 53 | hnsw_neighbors=config_dict.get("hnsw_neighbors", 400), 54 | rerank_count=config_dict.get("rerank_count", 10), 55 | tensor_dimensions=config_dict.get("tensor_dimensions", 16), 56 | ) 57 | 58 | 59 | @dataclass 60 | class VespaConfig: 61 | """Configuration for Vespa deployment.""" 62 | 63 | app_name: str 64 | tenant_name: str 65 | connections: int = 1 66 | timeout: int = 180 67 | schema_name: str = "pdf_page" 68 | schema_config: Optional[VespaSchemaConfig] = None 69 | 70 | @classmethod 71 | def from_dict(cls, 
config_dict: Dict[str, Any]) -> "VespaConfig": 72 | """Create VespaConfig from dictionary, handling schema config separately.""" 73 | # Work on a copy so pop() does not mutate the caller's dict 74 | config_dict = dict(config_dict) 75 | schema_dict = config_dict.pop("schema", {}) 76 | schema_config = ( 77 | VespaSchemaConfig.from_dict(schema_dict) if schema_dict else None 78 | ) 79 | 80 | return cls(**config_dict, schema_config=schema_config) 81 | 82 | 83 | @dataclass 84 | class VespaQueryConfig: 85 | app_name: str 86 | tenant_name: str 87 | connections: int = 1 88 | timeout: int = 180 89 | hits_per_query: int = 5 90 | schema_name: str = "pdf_page" 91 | tensor_dimensions: int = 16 92 | schema_config: Optional[VespaSchemaConfig] = None 93 | 94 | @classmethod 95 | def from_dict(cls, config_dict: Dict[str, Any]) -> "VespaQueryConfig": 96 | # Copy first; "schema" is removed by pop(), so no further filtering is needed 97 | config_dict = dict(config_dict) 98 | schema_dict = config_dict.pop("schema", {}) 99 | schema_config = ( 100 | VespaSchemaConfig.from_dict(schema_dict) if schema_dict else None 101 | ) 102 | return cls(**config_dict, schema_config=schema_config) 103 | 104 | 105 | @dataclass 106 | class QueryResult: 107 | """Data class for query results.""" 108 | 109 | title: str 110 | url: str 111 | page_number: int 112 | relevance: float 113 | text: str 114 | source: Dict[str, Any] 115 | -------------------------------------------------------------------------------- /src/vespa/exceptions.py: -------------------------------------------------------------------------------- 1 | class VespaDeploymentError(Exception): 2 | """Custom exception for Vespa deployment errors.""" 3 | 4 | pass 5 | 6 | 7 | class VespaSetupError(Exception): 8 | """Custom exception for Vespa setup errors.""" 9 | 10 | pass 11 | 12 | 13 | class PDFProcessingError(Exception): 14 | """Custom exception for PDF processing errors.""" 15 | 16 | pass 17 | 18 | 19 | class VespaQueryError(Exception): 20 | """Custom exception for Vespa query errors.""" 21 | 22 | pass 23 | -------------------------------------------------------------------------------- /src/vespa/indexing/pdf_processor.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import platform 4 | import tempfile 5 | from io import BytesIO 6 | from typing import Any, Dict, List, Optional, Tuple 7 | 8 | import requests 9 | import torch 10 | from colpali_engine.models import ColQwen2, ColQwen2Processor 11 | from pdf2image import convert_from_path 12 | from pdf2image.exceptions import PDFPageCountError 13 | from pypdf import PdfReader 14 | from torch import Tensor 15 | from torch import device as torch_device 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | 19 | from src.vespa.datatypes import PDFData 20 | from src.vespa.exceptions import PDFProcessingError 21 | 22 | # Configure logging 23 | logging.basicConfig( 24 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" 25 | ) 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | class PDFProcessor: 30 | """Class for processing PDFs and generating embeddings using ColQwen2 model.""" 31 | 32 | def __init__( 33 | self, 34 | model_name: str = "vidore/colqwen2-v0.1", 35 | batch_size: int = 2, 36 | device: Optional[str] = None, 37 | ) -> None: 38 | """ 39 | Initialize the PDF processor with the specified model. 40 | 41 | Args: 42 | model_name: Name of the pretrained model to use 43 | batch_size: Batch size for processing images 44 | device: Optional specific device to use ('cuda', 'mps', 'cpu').
45 | If None, will automatically select the optimal device. 46 | 47 | Raises: 48 | RuntimeError: If model loading fails 49 | """ 50 | self.batch_size = batch_size 51 | self.device = self._get_optimal_device(device) 52 | 53 | try: 54 | logger.info(f"Loading model {model_name} on {self.device}") 55 | self.model = ColQwen2.from_pretrained( 56 | model_name, 57 | torch_dtype=torch.bfloat16, 58 | device_map="auto" if self.device.type == "cuda" else None, 59 | ) 60 | self.processor = ColQwen2Processor.from_pretrained(model_name) 61 | self.model.eval() 62 | 63 | # Move model to device if not using device_map="auto" 64 | if self.device.type != "cuda": 65 | self.model.to(self.device) 66 | 67 | except Exception as e: 68 | logger.error(f"Failed to load model: {str(e)}") 69 | raise RuntimeError(f"Model initialization failed: {str(e)}") 70 | 71 | def _get_optimal_device( 72 | self, requested_device: Optional[str] = None 73 | ) -> torch_device: 74 | """ 75 | Determine the optimal device for model execution. 76 | 77 | Args: 78 | requested_device: Optional specific device to use 79 | 80 | Returns: 81 | torch.device: The optimal device for the current system 82 | """ 83 | if requested_device: 84 | # If a specific device is requested, try to use it 85 | if requested_device == "cuda" and not torch.cuda.is_available(): 86 | logger.warning( 87 | "CUDA requested but not available, falling back to optimal device" 88 | ) 89 | elif requested_device == "mps" and not ( 90 | platform.system() == "Darwin" and torch.backends.mps.is_available() 91 | ): 92 | logger.warning( 93 | "MPS requested but not available, falling back to optimal device" 94 | ) 95 | else: 96 | return torch.device(requested_device) 97 | 98 | # Automatic device selection 99 | if torch.cuda.is_available(): 100 | logger.info("Using CUDA device") 101 | return torch.device("cuda") 102 | elif platform.system() == "Darwin" and torch.backends.mps.is_available(): 103 | try: 104 | # Test MPS allocation 105 | test_tensor = torch.zeros(1, device="mps") 106 | logger.info("Using MPS device") 107 | return torch.device("mps") 108 | except Exception as e: 109 | logger.warning(f"MPS device available but failed allocation test: {e}") 110 | logger.info("Falling back to CPU") 111 | return torch.device("cpu") 112 | else: 113 | logger.info("Using CPU device") 114 | return torch.device("cpu") 115 | 116 | def _handle_device_fallback( 117 | self, batch: Dict[str, Tensor] 118 | ) -> Tuple[Dict[str, Tensor], torch_device]: 119 | """ 120 | Handle device fallback if the current device fails. 
121 | 122 | Args: 123 | batch: The current batch of data 124 | 125 | Returns: 126 | Tuple of (processed batch, new device) 127 | 128 | Raises: 129 | PDFProcessingError: If processing fails on all devices 130 | """ 131 | devices_to_try = [] 132 | 133 | # Add fallback devices in order of preference 134 | if self.device.type != "cpu": 135 | devices_to_try.append("cpu") 136 | if self.device.type != "cuda" and torch.cuda.is_available(): 137 | devices_to_try.append("cuda") 138 | 139 | for device_type in devices_to_try: 140 | try: 141 | new_device = torch.device(device_type) 142 | logger.warning(f"Attempting fallback to {device_type}") 143 | 144 | self.model.to(new_device) 145 | new_batch = {k: v.to(new_device) for k, v in batch.items()} 146 | 147 | # Test the new device with a forward pass 148 | with torch.no_grad(): 149 | _ = self.model(**new_batch) 150 | 151 | self.device = new_device 152 | return new_batch, new_device 153 | 154 | except Exception as e: 155 | logger.warning(f"Fallback to {device_type} failed: {e}") 156 | continue 157 | 158 | raise PDFProcessingError("Failed to process batch on all available devices") 159 | 160 | def process_pdf(self, pdf_metadata: List[Dict[str, str]]) -> List[PDFData]: 161 | """ 162 | Process multiple PDFs and generate their embeddings. 163 | 164 | Args: 165 | pdf_metadata: List of dictionaries containing PDF metadata (must have 'url' and 'title') 166 | 167 | Returns: 168 | List of PDFData objects containing processed information 169 | 170 | Raises: 171 | PDFProcessingError: If processing any PDF fails 172 | """ 173 | 174 | pdf_data: List[PDFData] = [] 175 | 176 | for pdf in pdf_metadata: 177 | try: 178 | logger.info(f"Processing PDF: {pdf['title']}") 179 | images, texts = self.get_pdf_content(pdf["url"]) 180 | embeddings = self.generate_embeddings(images) 181 | 182 | pdf_data.append( 183 | PDFData( 184 | url=pdf["url"], 185 | title=pdf["title"], 186 | images=images, 187 | texts=texts, 188 | embeddings=embeddings, 189 | ) 190 | ) 191 | 192 | except Exception as e: 193 | logger.error(f"Failed to process PDF {pdf['title']}: {str(e)}") 194 | raise PDFProcessingError(f"Failed to process PDF: {str(e)}") 195 | 196 | return pdf_data 197 | 198 | def get_pdf_content(self, pdf_url: str) -> Tuple[List[Any], List[str]]: 199 | """ 200 | Extract images and text content from PDF. 
201 | 202 | Args: 203 | pdf_url: URL of the PDF to process 204 | 205 | Returns: 206 | Tuple containing lists of images and extracted text 207 | 208 | Raises: 209 | PDFProcessingError: If PDF processing fails 210 | """ 211 | try: 212 | logger.info(f"Downloading PDF from {pdf_url}") 213 | response = requests.get(pdf_url, timeout=30) 214 | response.raise_for_status() 215 | pdf_file = BytesIO(response.content) 216 | 217 | with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file: 218 | temp_path = temp_file.name 219 | temp_file.write(pdf_file.read()) 220 | 221 | try: 222 | reader = PdfReader(temp_path) 223 | page_texts = [page.extract_text() for page in reader.pages] 224 | 225 | logger.info("Converting PDF pages to images") 226 | images = convert_from_path(temp_path) 227 | 228 | if len(images) != len(page_texts): 229 | raise PDFProcessingError( 230 | "Mismatch between number of images and texts" 231 | ) 232 | 233 | return images, page_texts 234 | 235 | except PDFPageCountError as e: 236 | raise PDFProcessingError(f"Failed to process PDF pages: {str(e)}") 237 | finally: 238 | # Clean up temporary file 239 | os.unlink(temp_path) 240 | 241 | except requests.exceptions.RequestException as e: 242 | raise PDFProcessingError(f"Failed to download PDF: {str(e)}") 243 | except Exception as e: 244 | raise PDFProcessingError(f"Failed to process PDF: {str(e)}") 245 | 246 | @torch.no_grad() 247 | def generate_embeddings(self, images: List[Any]) -> List[Tensor]: 248 | """ 249 | Generate embeddings for a list of images. 250 | 251 | Args: 252 | images: List of PIL images to process 253 | 254 | Returns: 255 | List of tensor embeddings 256 | 257 | Raises: 258 | PDFProcessingError: If embedding generation fails 259 | """ 260 | try: 261 | embeddings: List[Tensor] = [] 262 | dataloader = DataLoader( 263 | images, 264 | batch_size=self.batch_size, 265 | shuffle=False, 266 | collate_fn=lambda x: self.processor.process_images(x), 267 | ) 268 | 269 | logger.info(f"Generating embeddings using device: {self.device}") 270 | for batch in tqdm(dataloader, desc="Processing batches"): 271 | try: 272 | batch = {k: v.to(self.device) for k, v in batch.items()} 273 | batch_embeddings = self.model(**batch) 274 | embeddings.extend(list(torch.unbind(batch_embeddings.to("cpu")))) 275 | 276 | except RuntimeError as e: 277 | if any(err in str(e) for err in ["MPS", "CUDA", "device"]): 278 | # Try fallback to another device 279 | batch, new_device = self._handle_device_fallback(batch) 280 | batch_embeddings = self.model(**batch) 281 | embeddings.extend( 282 | list(torch.unbind(batch_embeddings.to("cpu"))) 283 | ) 284 | else: 285 | raise 286 | 287 | return embeddings 288 | 289 | except Exception as e: 290 | logger.error(f"Failed to generate embeddings: {str(e)}") 291 | raise PDFProcessingError(f"Embedding generation failed: {str(e)}") 292 | -------------------------------------------------------------------------------- /src/vespa/indexing/prepare_feed.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import hashlib 3 | import logging 4 | from io import BytesIO 5 | from typing import Dict, List, Optional 6 | 7 | import numpy as np 8 | import numpy.typing as npt 9 | from PIL import Image 10 | from torch import Tensor 11 | 12 | from src.vespa.indexing.pdf_processor import PDFData 13 | 14 | logging.basicConfig(level=logging.DEBUG) 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class VespaFeedPreparator: 19 | def __init__(self, max_image_height: int = 800) -> 
None: 20 | self.max_image_height = max_image_height 21 | 22 | def prepare_feed(self, pdf_data: List[PDFData]) -> List[Dict]: 23 | try: 24 | vespa_feed = [] 25 | for pdf in pdf_data: 26 | logger.info(f"Processing PDF: {pdf.title}") 27 | for page_number, (text, embedding, image) in enumerate( 28 | zip(pdf.texts, pdf.embeddings, pdf.images) 29 | ): 30 | page_id = self._generate_page_id(pdf.url, page_number) 31 | processed_image = self._process_image(image) 32 | embedding_dict = self._process_embeddings(embedding) 33 | 34 | doc = { 35 | "fields": { 36 | "id": page_id, 37 | "url": pdf.url, 38 | "title": pdf.title, 39 | "page_number": page_number, 40 | "image": processed_image, 41 | "text": text, 42 | "embedding": { 43 | "blocks": self._convert_to_patch_blocks(embedding_dict) 44 | }, 45 | } 46 | } 47 | vespa_feed.append(doc) 48 | return vespa_feed 49 | except Exception as e: 50 | logger.error(f"Failed to prepare feed: {str(e)}") 51 | raise 52 | 53 | def _convert_to_patch_blocks(self, embedding_dict: Dict[int, str]) -> List[Dict]: 54 | return [ 55 | {"address": {"patch": patch_idx}, "values": vector} 56 | for patch_idx, vector in embedding_dict.items() 57 | if vector != "0" * 32 58 | ] 59 | 60 | def _generate_page_id(self, url: str, page_number: int) -> str: 61 | content = f"{url}{page_number}".encode("utf-8") 62 | return hashlib.sha256(content).hexdigest() 63 | 64 | def _process_embeddings(self, embedding: Tensor) -> Dict[int, str]: 65 | embedding_dict = {} 66 | embedding_float = embedding.detach().cpu().float() 67 | embedding_np = embedding_float.numpy() 68 | for idx, patch_embedding in enumerate(embedding_np): 69 | binary_vector = self._convert_to_binary_vector(patch_embedding) 70 | embedding_dict[idx] = binary_vector 71 | return embedding_dict 72 | 73 | @staticmethod 74 | def _convert_to_binary_vector(patch_embedding: npt.NDArray[np.float32]) -> str: 75 | binary = np.packbits(np.where(patch_embedding > 0, 1, 0)) 76 | return binary.astype(np.int8).tobytes().hex() 77 | 78 | def _process_image(self, image: Image.Image) -> str: 79 | resized_image = self._resize_image(image) 80 | return self._encode_image_base64(resized_image) 81 | 82 | def _resize_image( 83 | self, image: Image.Image, target_width: Optional[int] = 640 84 | ) -> Image.Image: 85 | width, height = image.size 86 | if height > self.max_image_height: 87 | ratio = self.max_image_height / height 88 | new_width = int(width * ratio) 89 | return image.resize((new_width, self.max_image_height), Image.LANCZOS) 90 | if target_width and width > target_width: 91 | ratio = target_width / width 92 | new_height = int(height * ratio) 93 | return image.resize((target_width, new_height), Image.LANCZOS) 94 | return image 95 | 96 | @staticmethod 97 | def _encode_image_base64(image: Image.Image) -> str: 98 | buffered = BytesIO() 99 | image.save(buffered, format="JPEG", quality=85, optimize=True) 100 | return base64.b64encode(buffered.getvalue()).decode("utf-8") 101 | -------------------------------------------------------------------------------- /src/vespa/indexing/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from time import time 4 | from typing import Any, Dict, List, Optional 5 | 6 | import yaml 7 | from tqdm import tqdm 8 | 9 | from src.vespa.datatypes import ( 10 | PDFInput, 11 | VespaConfig, 12 | VespaSchemaConfig, 13 | ) 14 | from src.vespa.indexing.pdf_processor import PDFProcessor 15 | from src.vespa.indexing.prepare_feed import VespaFeedPreparator 
16 | from src.vespa.setup import VespaSetup 17 | from vespa.application import Vespa 18 | from vespa.deployment import VespaCloud 19 | 20 | logging.basicConfig(level=logging.DEBUG) 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class VespaDeployer: 25 | def __init__( 26 | self, 27 | config: VespaConfig, 28 | pdf_processor: Optional[PDFProcessor] = None, 29 | feed_preparator: Optional[VespaFeedPreparator] = None, 30 | ) -> None: 31 | self.config = config 32 | self.pdf_processor = pdf_processor or PDFProcessor() 33 | self.feed_preparator = feed_preparator or VespaFeedPreparator() 34 | 35 | async def _feed_data(self, app: Vespa, vespa_feed: List[Dict]) -> None: 36 | failed_documents = [] 37 | 38 | async with app.asyncio( 39 | connections=self.config.connections, timeout=self.config.timeout 40 | ) as session: 41 | logger.info("Starting data feed process") 42 | 43 | for doc in tqdm(vespa_feed, desc="Feeding documents"): 44 | try: 45 | logger.debug(f"Feeding document: {doc['fields']['url']}") 46 | 47 | response = await session.feed_data_point( 48 | schema=self.config.schema_name, 49 | data_id=doc["fields"]["url"], 50 | fields=doc["fields"], 51 | ) 52 | 53 | if not response.is_successful(): 54 | error_msg = ( 55 | response.get_json() 56 | if response.get_json() 57 | else str(response) 58 | ) 59 | logger.error(f"Feed failed: {error_msg}") 60 | failed_documents.append( 61 | {"id": doc["fields"]["url"], "error": error_msg} 62 | ) 63 | 64 | except Exception as e: 65 | logger.error(f"Feed error for {doc['fields']['url']}: {str(e)}") 66 | failed_documents.append( 67 | {"id": doc["fields"]["url"], "error": str(e)} 68 | ) 69 | 70 | if failed_documents: 71 | self._save_failed_documents(failed_documents) 72 | error_details = "\n".join( 73 | [f"Doc {doc['id']}: {doc['error']}" for doc in failed_documents] 74 | ) 75 | raise Exception(f"Documents failed to feed:\n{error_details}") 76 | 77 | @staticmethod 78 | def _save_failed_documents(failed_docs: List[Dict[str, Any]]) -> None: 79 | output_dir = Path("logs/failed_documents") 80 | output_dir.mkdir(parents=True, exist_ok=True) 81 | output_file = output_dir / f"failed_documents_{int(time())}.yaml" 82 | with open(output_file, "w") as f: 83 | yaml.dump(failed_docs, f) 84 | logger.info(f"Saved failed documents to {output_file}") 85 | 86 | async def deploy_and_feed(self, vespa_feed: List[Dict]) -> Vespa: 87 | try: 88 | vespa_setup = VespaSetup( 89 | app_name=self.config.app_name, 90 | schema_config=self.config.schema_config or VespaSchemaConfig(), 91 | ) 92 | 93 | vespa_cloud = VespaCloud( 94 | tenant=self.config.tenant_name, 95 | application=self.config.app_name, 96 | application_package=vespa_setup.app_package, 97 | ) 98 | 99 | logger.info("Deploying to Vespa Cloud") 100 | app = vespa_cloud.deploy() 101 | await self._feed_data(app, vespa_feed) 102 | return app 103 | 104 | except Exception as e: 105 | logger.error(f"Vespa deployment failed: {str(e)}") 106 | raise 107 | 108 | 109 | async def run_indexing( 110 | config_path: str, 111 | pdfs: List[PDFInput], 112 | pdf_processor: Optional[PDFProcessor] = None, 113 | feed_preparator: Optional[VespaFeedPreparator] = None, 114 | ) -> None: 115 | try: 116 | if not pdfs: 117 | raise ValueError("PDF list cannot be empty") 118 | 119 | config_path = Path(config_path) 120 | if not config_path.exists(): 121 | raise FileNotFoundError(f"Configuration file not found: {config_path}") 122 | 123 | with open(config_path) as f: 124 | config_data = yaml.safe_load(f) 125 | 126 | if not config_data.get("vespa"): 127 | raise 
ValueError("Invalid configuration: 'vespa' section missing") 128 | 129 | schema_config = config_data["vespa"].get("schema") 130 | vespa_config = VespaConfig( 131 | app_name=config_data["vespa"]["app_name"], 132 | tenant_name=config_data["vespa"]["tenant_name"], 133 | connections=config_data["vespa"].get("connections", 1), 134 | timeout=config_data["vespa"].get("timeout", 180), 135 | schema_name=config_data["vespa"].get("schema_name", "pdf_page"), 136 | schema_config=VespaSchemaConfig.from_dict(schema_config) 137 | if schema_config 138 | else None, 139 | ) 140 | 141 | deployer = VespaDeployer( 142 | config=vespa_config, 143 | pdf_processor=pdf_processor, 144 | feed_preparator=feed_preparator, 145 | ) 146 | 147 | processed_data = pdf_processor.process_pdf( 148 | [{"title": pdf.title, "url": pdf.url} for pdf in pdfs] 149 | ) 150 | vespa_feed = feed_preparator.prepare_feed(processed_data) 151 | 152 | await deployer.deploy_and_feed(vespa_feed) 153 | logger.info("Indexing process completed successfully") 154 | 155 | except Exception as e: 156 | logger.error(f"Indexing failed: {str(e)}") 157 | raise 158 | -------------------------------------------------------------------------------- /src/vespa/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from colpali_engine.models import ColQwen2, ColQwen2Processor 3 | 4 | 5 | def create_model(model_name: str = "vidore/colqwen2-v0.1"): 6 | """ 7 | Load a pre-trained ColQwen2 model and processor. 8 | 9 | Args: 10 | model_name: The name of the pre-trained model to load (default: "vidore/colqwen2-v0.1") 11 | 12 | Returns: 13 | A tuple (model, processor) containing the pre-trained model and processor 14 | """ 15 | model = ColQwen2.from_pretrained( 16 | model_name, torch_dtype=torch.bfloat16, device_map="auto" 17 | ) 18 | processor = ColQwen2Processor.from_pretrained(model_name) 19 | model.eval() 20 | return model, processor 21 | -------------------------------------------------------------------------------- /src/vespa/retrieval/run.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | from typing import Dict, List, Optional 4 | 5 | import torch 6 | import yaml 7 | from colpali_engine.models import ColQwen2, ColQwen2Processor 8 | from torch.utils.data import DataLoader 9 | 10 | from src.vespa.datatypes import QueryResult, VespaQueryConfig 11 | from src.vespa.exceptions import VespaQueryError 12 | from src.vespa.setup import VespaSetup 13 | from vespa.deployment import VespaCloud 14 | from vespa.io import VespaQueryResponse 15 | 16 | logging.basicConfig(level=logging.DEBUG) 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class VespaQuerier: 21 | def __init__( 22 | self, 23 | config: VespaQueryConfig, 24 | setup: Optional[VespaSetup] = None, 25 | model_name: str = "vidore/colqwen2-v0.1", 26 | ): 27 | self.config = config 28 | self.setup = setup or VespaSetup( 29 | app_name=config.app_name, schema_config=config.schema_config 30 | ) 31 | 32 | logger.info(f"Loading model {model_name}") 33 | self.model = ColQwen2.from_pretrained( 34 | model_name, torch_dtype=torch.bfloat16, device_map="auto" 35 | ) 36 | self.processor = ColQwen2Processor.from_pretrained(model_name) 37 | self.model.eval() 38 | 39 | self._init_vespa_cloud() 40 | 41 | def _init_vespa_cloud(self) -> None: 42 | try: 43 | self.vespa_cloud = VespaCloud( 44 | tenant=self.config.tenant_name, 45 | application=self.config.app_name, 46 | 
application_package=self.setup.app_package, 47 | ) 48 | except Exception as e: 49 | logger.error(f"Failed to initialize VespaCloud: {str(e)}") 50 | raise VespaQueryError(f"VespaCloud initialization failed: {str(e)}") 51 | 52 | def _prepare_query_tensors(self, query: str) -> Dict[str, List[float]]: 53 | dataloader = DataLoader( 54 | [query], 55 | batch_size=1, 56 | shuffle=False, 57 | collate_fn=lambda x: self.processor.process_queries(x), 58 | ) 59 | 60 | with torch.no_grad(): 61 | batch_query = next(iter(dataloader)) 62 | batch_query = {k: v.to(self.model.device) for k, v in batch_query.items()} 63 | embeddings_query = self.model(**batch_query) 64 | query_embedding = embeddings_query[0].cpu().float() 65 | 66 | # Convert embedding to list format 67 | embedding_list = query_embedding.tolist() 68 | if isinstance(embedding_list[0], list): 69 | # Handle case where embedding is 2D 70 | query_tensors = {i: emb for i, emb in enumerate(embedding_list)} 71 | else: 72 | # Handle case where embedding is 1D 73 | query_tensors = {0: embedding_list} 74 | 75 | return {"input.query(qt)": query_tensors} 76 | 77 | async def execute_queries(self, queries: List[str]) -> Dict[str, List[QueryResult]]: 78 | try: 79 | app = self.vespa_cloud.get_application() 80 | results: Dict[str, List[QueryResult]] = {} 81 | 82 | async with app.asyncio( 83 | connections=self.config.connections, timeout=self.config.timeout 84 | ) as session: 85 | for query in queries: 86 | try: 87 | logger.info(f"Executing query: {query}") 88 | query_tensors = self._prepare_query_tensors(query) 89 | 90 | logger.debug(f"Query tensors: {query_tensors}") 91 | 92 | response = await session.query( 93 | yql="select id, title, url, text, page_number, image from pdf_page where userInput(@userQuery)", 94 | userQuery=query, 95 | hits=self.config.hits_per_query, 96 | body={ 97 | **query_tensors, 98 | "presentation.timing": True, 99 | "timeout": str(self.config.timeout), 100 | }, 101 | ) 102 | 103 | if not response.is_successful(): 104 | error_msg = ( 105 | response.get_json() 106 | if hasattr(response, "get_json") 107 | else str(response) 108 | ) 109 | logger.error(f"Query response error: {error_msg}") 110 | raise VespaQueryError(f"Query failed: {error_msg}") 111 | 112 | results[query] = self._process_response(response) 113 | 114 | except Exception as e: 115 | logger.error(f"Query failed for '{query}': {str(e)}") 116 | results[query] = [] 117 | 118 | return results 119 | 120 | except Exception as e: 121 | logger.error(f"Failed to execute queries: {str(e)}") 122 | raise VespaQueryError(f"Query execution failed: {str(e)}") 123 | 124 | def _process_response(self, response: VespaQueryResponse) -> List[QueryResult]: 125 | try: 126 | results = [] 127 | for hit in response.hits: 128 | fields = hit["fields"] 129 | results.append( 130 | QueryResult( 131 | title=fields.get("title", "N/A"), 132 | url=fields.get("url", "N/A"), 133 | page_number=fields.get("page_number", -1), 134 | relevance=float(hit.get("relevance", 0.0)), 135 | text=fields.get("text", ""), 136 | source=hit, 137 | ) 138 | ) 139 | return results 140 | except Exception as e: 141 | logger.error(f"Error processing response: {str(e)}") 142 | return [] 143 | 144 | def display_results(self, query: str, results: List[QueryResult]) -> None: 145 | print(f"\nQuery: {query}") 146 | print(f"Total Results: {len(results)}") 147 | 148 | for idx, result in enumerate(results, 1): 149 | print(f"\nResult {idx}:") 150 | print(f"Title: {result.title}") 151 | print(f"URL: {result.url}") 152 | print(f"Page: 
{result.page_number}") 153 | print(f"Score: {result.relevance:.4f}") 154 | text_preview = ( 155 | result.text[:200] + "..." if len(result.text) > 200 else result.text 156 | ) 157 | print(f"Text Preview: {text_preview}") 158 | 159 | 160 | async def run_queries( 161 | config_path: str, queries: List[str], display_results: bool = True 162 | ) -> Dict[str, List[QueryResult]]: 163 | try: 164 | config_path = Path(config_path) 165 | if not config_path.exists(): 166 | raise FileNotFoundError(f"Configuration file not found: {config_path}") 167 | 168 | with open(config_path) as f: 169 | config_data = yaml.safe_load(f) 170 | 171 | if not config_data.get("vespa"): 172 | raise ValueError("Invalid configuration: 'vespa' section missing") 173 | 174 | vespa_config = VespaQueryConfig.from_dict(config_data["vespa"]) 175 | querier = VespaQuerier(vespa_config) 176 | 177 | logger.info(f"Executing {len(queries)} queries") 178 | results = await querier.execute_queries(queries) 179 | 180 | if display_results: 181 | for query, query_results in results.items(): 182 | querier.display_results(query, query_results) 183 | 184 | return results 185 | 186 | except Exception as e: 187 | logger.error(f"Failed to run queries: {str(e)}") 188 | raise 189 | 190 | 191 | async def main() -> Dict[str, List[QueryResult]]: 192 | # Path to your config file 193 | config_path = ( 194 | "/Users/vesaalexandru/Workspaces/cube/cube-publication/" 195 | "evaluate-ocr-rag-systems/src/vespa/vespa_config.yaml" 196 | ) 197 | 198 | # Test queries 199 | queries = [ 200 | "Percentage of non-fresh water as source?", 201 | ] 202 | 203 | try: 204 | logger.info("Starting query execution") 205 | results = await run_queries( 206 | config_path=config_path, queries=queries, display_results=True 207 | ) 208 | logger.info("Query execution completed successfully") 209 | return results 210 | 211 | except Exception as e: 212 | logger.error(f"Query execution failed: {str(e)}") 213 | raise 214 | 215 | 216 | if __name__ == "__main__": 217 | import asyncio 218 | 219 | try: 220 | results = asyncio.run(main()) 221 | 222 | # Print summary of results 223 | print("\nResults Summary:") 224 | for query, query_results in results.items(): 225 | print(f"\nQuery: {query}") 226 | print(f"Number of results: {len(query_results)}") 227 | if query_results: 228 | print(f"Top result score: {query_results[0].relevance:.4f}") 229 | print(f"Top result title: {query_results[0].title}") 230 | 231 | except KeyboardInterrupt: 232 | logger.info("Query execution interrupted by user") 233 | except Exception as e: 234 | logger.error(f"Error occurred: {str(e)}") 235 | raise 236 | -------------------------------------------------------------------------------- /src/vespa/setup.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional 3 | 4 | from src.vespa.datatypes import VespaSchemaConfig 5 | from src.vespa.exceptions import VespaSetupError 6 | from vespa.package import ( 7 | HNSW, 8 | ApplicationPackage, 9 | Document, 10 | Field, 11 | FieldSet, 12 | FirstPhaseRanking, 13 | Function, 14 | RankProfile, 15 | Schema, 16 | SecondPhaseRanking, 17 | ) 18 | 19 | # Configure logging 20 | logging.basicConfig( 21 | level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" 22 | ) 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | class VespaSetup: 27 | """Class for setting up Vespa application package and schema.""" 28 | 29 | def __init__( 30 | self, app_name: str, schema_config: 
Optional[VespaSchemaConfig] = None 31 | ) -> None: 32 | """ 33 | Initialize Vespa setup. 34 | 35 | Args: 36 | app_name: Name of the Vespa application 37 | schema_config: Optional configuration for schema settings 38 | """ 39 | self.app_name = app_name 40 | self.config = schema_config or VespaSchemaConfig() 41 | 42 | try: 43 | logger.info(f"Creating Vespa schema for application: {app_name}") 44 | self.schema = self._create_schema() 45 | self.app_package = ApplicationPackage(name=app_name, schema=[self.schema]) 46 | except Exception as e: 47 | logger.error(f"Failed to create Vespa setup: {str(e)}") 48 | raise VespaSetupError(f"Setup failed: {str(e)}") 49 | 50 | def _create_schema(self) -> Schema: 51 | """ 52 | Create the Vespa schema with fields and rank profiles. 53 | 54 | Returns: 55 | Configured Schema object 56 | 57 | Raises: 58 | VespaSetupError: If schema creation fails 59 | """ 60 | try: 61 | schema = Schema( 62 | name="pdf_page", 63 | document=Document(fields=self._create_fields()), 64 | fieldsets=[FieldSet(name="default", fields=["title", "text"])], 65 | ) 66 | 67 | self._add_rank_profiles(schema) 68 | return schema 69 | 70 | except Exception as e: 71 | logger.error(f"Failed to create schema: {str(e)}") 72 | raise VespaSetupError(f"Schema creation failed: {str(e)}") 73 | 74 | def _create_fields(self) -> List[Field]: 75 | """ 76 | Create the fields for the schema. 77 | 78 | Returns: 79 | List of Field objects 80 | """ 81 | return [ 82 | Field( 83 | name="id", type="string", indexing=["summary", "index"], match=["word"] 84 | ), 85 | Field(name="url", type="string", indexing=["summary", "index"]), 86 | Field( 87 | name="title", 88 | type="string", 89 | indexing=["summary", "index"], 90 | match=["text"], 91 | index="enable-bm25", 92 | ), 93 | Field(name="page_number", type="int", indexing=["summary", "attribute"]), 94 | Field(name="image", type="raw", indexing=["summary"]), 95 | Field( 96 | name="text", 97 | type="string", 98 | indexing=["index"], 99 | match=["text"], 100 | index="enable-bm25", 101 | ), 102 | Field( 103 | name="embedding", 104 | type=f"tensor(patch{{}}, v[{self.config.tensor_dimensions}])", 105 | indexing=["attribute", "index"], 106 | ann=HNSW( 107 | distance_metric="hamming", 108 | max_links_per_node=self.config.hnsw_max_links, 109 | neighbors_to_explore_at_insert=self.config.hnsw_neighbors, 110 | ), 111 | ), 112 | ] 113 | 114 | def _add_rank_profiles(self, schema: Schema) -> None: 115 | """ 116 | Add rank profiles to the schema. 117 | 118 | Args: 119 | schema: Schema object to add rank profiles to 120 | """ 121 | # Add default ranking profile 122 | schema.add_rank_profile(self._create_default_profile()) 123 | 124 | # Add retrieval and rerank profile 125 | schema.add_rank_profile(self._create_retrieval_rerank_profile()) 126 | 127 | def _create_default_profile(self) -> RankProfile: 128 | """ 129 | Create the default rank profile. 
130 | 131 | Returns: 132 | Configured RankProfile object 133 | """ 134 | return RankProfile( 135 | name="default", 136 | inputs=[("query(qt)", "tensor(querytoken{}, v[128])")], 137 | functions=[ 138 | Function( 139 | name="max_sim", 140 | expression=""" 141 | sum( 142 | reduce( 143 | sum( 144 | query(qt) * unpack_bits(attribute(embedding)) , v 145 | ), 146 | max, patch 147 | ), 148 | querytoken 149 | ) 150 | """, 151 | ), 152 | Function(name="bm25_score", expression="bm25(title) + bm25(text)"), 153 | ], 154 | first_phase=FirstPhaseRanking(expression="bm25_score"), 155 | second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=100), 156 | ) 157 | 158 | def _create_retrieval_rerank_profile(self) -> RankProfile: 159 | """ 160 | Create the retrieval and rerank profile. 161 | 162 | Returns: 163 | Configured RankProfile object 164 | """ 165 | input_query_tensors = [] 166 | 167 | # Add query tensors for each term 168 | for i in range(self.config.max_query_terms): 169 | input_query_tensors.append( 170 | (f"query(rq{i})", f"tensor(v[{self.config.tensor_dimensions}])") 171 | ) 172 | 173 | # Add additional query tensors 174 | input_query_tensors.extend( 175 | [ 176 | ("query(qt)", "tensor(querytoken{}, v[128])"), 177 | ( 178 | "query(qtb)", 179 | f"tensor(querytoken{{}}, v[{self.config.tensor_dimensions}])", 180 | ), 181 | ] 182 | ) 183 | 184 | return RankProfile( 185 | name="retrieval-and-rerank", 186 | inputs=input_query_tensors, 187 | functions=[ 188 | Function( 189 | name="max_sim", 190 | expression=""" 191 | sum( 192 | reduce( 193 | sum( 194 | query(qt) * unpack_bits(attribute(embedding)) , v 195 | ), 196 | max, patch 197 | ), 198 | querytoken 199 | ) 200 | """, 201 | ), 202 | Function( 203 | name="max_sim_binary", 204 | expression=""" 205 | sum( 206 | reduce( 207 | 1/(1 + sum( 208 | hamming(query(qtb), attribute(embedding)) ,v) 209 | ), 210 | max, 211 | patch 212 | ), 213 | querytoken 214 | ) 215 | """, 216 | ), 217 | ], 218 | first_phase=FirstPhaseRanking(expression="max_sim_binary"), 219 | second_phase=SecondPhaseRanking( 220 | expression="max_sim", rerank_count=self.config.rerank_count 221 | ), 222 | ) 223 | -------------------------------------------------------------------------------- /src/vespa/utils.py: -------------------------------------------------------------------------------- 1 | from IPython.display import HTML, display 2 | 3 | 4 | def display_query_results(query, response, hits=5): 5 | query_time = response.json.get("timing", {}).get("searchtime", -1) 6 | query_time = round(query_time, 2) 7 | count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0) 8 | html_content = f"
<h3>Query text: '{query}', query time {query_time}s, count={count}, top results:</h3>" 9 | 10 | for i, hit in enumerate(response.hits[:hits]): 11 | title = hit["fields"]["title"] 12 | url = hit["fields"]["url"] 13 | page = hit["fields"]["page_number"] 14 | image = hit["fields"]["image"] 15 | score = hit["relevance"] 16 | 17 | html_content += f"<h4>PDF Result {i + 1}</h4>" 18 | html_content += f'<p>Title: <a href="{url}">{title}</a>, page {page+1} with score {score:.2f}</p>' 19 | html_content += ( 20 | f'<img src="data:image/jpeg;base64,{image}" style="max-width:100%;">' 21 | ) 22 | 23 | display(HTML(html_content)) 24 | -------------------------------------------------------------------------------- /src/vespa/vespa_config.yaml: -------------------------------------------------------------------------------- 1 | # vespa: 2 | # tenant_name: "cube-digital" 3 | # app_name: "test4" 4 | # schema_name: "pdf_page" 5 | # connections: 1 6 | # timeout: 180 7 | # hits_per_query: 5 8 | # schema: 9 | # max_query_terms: 64 10 | # hnsw_max_links: 32 11 | # hnsw_neighbors: 400 12 | # rerank_count: 10 13 | # tensor_dimensions: 16 14 | 15 | 16 | 17 | vespa: 18 | app_name: "test" 19 | tenant_name: "ml-vanguards" 20 | connections: 1 21 | timeout: 180 22 | hits_per_query: 5 23 | schema_name: "pdf_page" 24 | tensor_dimensions: 16 25 | schema: 26 | max_query_terms: 64 27 | hnsw_max_links: 32 28 | hnsw_neighbors: 400 29 | rerank_count: 10 -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import tempfile 5 | from typing import Dict, List, Optional 6 | 7 | import requests 8 | from pdf2image import convert_from_path 9 | 10 | from src.interfaces.interfaces import BasePromptTemplate 11 | 12 | 13 | class PDFVisionProcessor: 14 | def __init__(self, temp_dir: Optional[str] = None): 15 | """ 16 | Initialize the PDF Vision Processor. 17 | Args: 18 | temp_dir: Optional path to temporary directory. If None, system temp dir is used. 19 | """ 20 | self.temp_dir = temp_dir or tempfile.mkdtemp() 21 | os.makedirs(self.temp_dir, exist_ok=True) 22 | 23 | def convert_pdf_to_images(self, pdf_path: str) -> List[str]: 24 | """ 25 | Convert PDF pages to images and save them in the temporary directory. 26 | Args: 27 | pdf_path: Path to the PDF file 28 | Returns: 29 | List of paths to generated images 30 | """ 31 | image_paths = [] 32 | try: 33 | # Convert PDF to images 34 | pages = convert_from_path(pdf_path) 35 | 36 | # Save each page as an image 37 | for i, page in enumerate(pages): 38 | image_path = os.path.join(self.temp_dir, f"page_{i}.png") 39 | page.save(image_path, "PNG") 40 | image_paths.append(image_path) 41 | 42 | return image_paths 43 | except Exception as e: 44 | print(f"Error converting PDF to images: {str(e)}") 45 | return [] 46 | 47 | @staticmethod 48 | def encode_image_to_base64(image_path: str) -> str: 49 | """Convert an image file to a base64 encoded string.""" 50 | with open(image_path, "rb") as image_file: 51 | return base64.b64encode(image_file.read()).decode("utf-8") 52 | 53 | def process_pdfs( 54 | self, pdf_paths: List[str], prompt: BasePromptTemplate 55 | ) -> Dict[str, List[Dict]]: 56 | """ 57 | Process multiple PDFs and perform vision OCR on each page.
58 | Args: 59 | pdf_paths: List of paths to PDF files 60 | prompt: Prompt template for vision processing 61 | Returns: 62 | Dictionary with PDF paths as keys and lists of OCR results as values 63 | """ 64 | results = {} 65 | 66 | for pdf_path in pdf_paths: 67 | pdf_results = [] 68 | image_paths = self.convert_pdf_to_images(pdf_path) 69 | 70 | for image_path in image_paths: 71 | ocr_result = self.perform_vision_ocr(image_path, prompt) 72 | if ocr_result: 73 | pdf_results.append(ocr_result) 74 | 75 | results[pdf_path] = pdf_results 76 | 77 | return results 78 | 79 | def perform_vision_ocr( 80 | self, image_path: str, prompt: BasePromptTemplate 81 | ) -> Optional[Dict]: 82 | """Perform OCR on the given image using Llama 3.2-Vision.""" 83 | base64_image = self.encode_image_to_base64(image_path) 84 | 85 | response = requests.post( 86 | "http://localhost:11434/api/chat", 87 | json={ 88 | "model": "llama3.2-vision", 89 | "messages": [ 90 | { 91 | "role": "user", 92 | "content": prompt().create_template(), 93 | "images": [base64_image], 94 | }, 95 | ], 96 | }, 97 | ) 98 | 99 | if response.status_code == 200: 100 | full_content = "" 101 | for line in response.iter_lines(): 102 | if line: 103 | json_obj = json.loads(line) 104 | full_content += json_obj["message"]["content"] 105 | 106 | try: 107 | return json.loads(full_content) 108 | except json.JSONDecodeError: 109 | return {"raw_content": full_content} 110 | else: 111 | print(f"Error: {response.status_code} {response.text}") 112 | return None 113 | 114 | def cleanup(self): 115 | """Remove temporary files and directory.""" 116 | try: 117 | for file in os.listdir(self.temp_dir): 118 | os.remove(os.path.join(self.temp_dir, file)) 119 | os.rmdir(self.temp_dir) 120 | except Exception as e: 121 | print(f"Error during cleanup: {str(e)}") 122 | 123 | 124 | if __name__ == "__main__": 125 | from src.prompts.llama_vision import AnalyzePdfImage 126 | 127 | # Example usage 128 | pdf_paths = [ 129 | "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/data/paper01-1-2.pdf" 130 | ] 131 | 132 | # Initialize processor with custom temp directory (optional) 133 | processor = PDFVisionProcessor() 134 | 135 | try: 136 | # Process PDFs and get results 137 | results = processor.process_pdfs(pdf_paths, AnalyzePdfImage) 138 | 139 | # Print results 140 | for pdf_path, ocr_results in results.items(): 141 | print(f"\nResults for {pdf_path}:") 142 | for i, result in enumerate(ocr_results): 143 | print(f"Page {i + 1}:", result) 144 | 145 | finally: 146 | # Clean up temporary files 147 | processor.cleanup() 148 | -------------------------------------------------------------------------------- /test/schemas/pdf_page.sd: -------------------------------------------------------------------------------- 1 | schema pdf_page { 2 | document pdf_page { 3 | field id type string { 4 | indexing: summary | index 5 | match { 6 | word 7 | } 8 | } 9 | field url type string { 10 | indexing: summary | index 11 | } 12 | field title type string { 13 | indexing: summary | index 14 | index: enable-bm25 15 | match { 16 | text 17 | } 18 | } 19 | field page_number type int { 20 | indexing: summary | attribute 21 | } 22 | field image type raw { 23 | indexing: summary 24 | } 25 | field text type string { 26 | indexing: index 27 | index: enable-bm25 28 | match { 29 | text 30 | } 31 | } 32 | field embedding type tensor(patch{}, v[16]) { 33 | indexing: attribute | index 34 | attribute { 35 | distance-metric: hamming 36 | } 37 | index { 38 | hnsw { 39 | max-links-per-node: 32 40 | 
neighbors-to-explore-at-insert: 400 41 | } 42 | } 43 | } 44 | } 45 | fieldset default { 46 | fields: title, text 47 | } 48 | rank-profile default { 49 | inputs { 50 | query(qt) tensor(querytoken{}, v[128]) 51 | 52 | } 53 | function max_sim() { 54 | expression { 55 | 56 | sum( 57 | reduce( 58 | sum( 59 | query(qt) * unpack_bits(attribute(embedding)) , v 60 | ), 61 | max, patch 62 | ), 63 | querytoken 64 | ) 65 | 66 | } 67 | } 68 | function bm25_score() { 69 | expression { 70 | bm25(title) + bm25(text) 71 | } 72 | } 73 | first-phase { 74 | expression { 75 | bm25_score 76 | } 77 | } 78 | second-phase { 79 | rerank-count: 100 80 | expression { 81 | max_sim 82 | } 83 | } 84 | } 85 | rank-profile retrieval-and-rerank { 86 | inputs { 87 | query(rq0) tensor(v[16]) 88 | query(rq1) tensor(v[16]) 89 | query(rq2) tensor(v[16]) 90 | query(rq3) tensor(v[16]) 91 | query(rq4) tensor(v[16]) 92 | query(rq5) tensor(v[16]) 93 | query(rq6) tensor(v[16]) 94 | query(rq7) tensor(v[16]) 95 | query(rq8) tensor(v[16]) 96 | query(rq9) tensor(v[16]) 97 | query(rq10) tensor(v[16]) 98 | query(rq11) tensor(v[16]) 99 | query(rq12) tensor(v[16]) 100 | query(rq13) tensor(v[16]) 101 | query(rq14) tensor(v[16]) 102 | query(rq15) tensor(v[16]) 103 | query(rq16) tensor(v[16]) 104 | query(rq17) tensor(v[16]) 105 | query(rq18) tensor(v[16]) 106 | query(rq19) tensor(v[16]) 107 | query(rq20) tensor(v[16]) 108 | query(rq21) tensor(v[16]) 109 | query(rq22) tensor(v[16]) 110 | query(rq23) tensor(v[16]) 111 | query(rq24) tensor(v[16]) 112 | query(rq25) tensor(v[16]) 113 | query(rq26) tensor(v[16]) 114 | query(rq27) tensor(v[16]) 115 | query(rq28) tensor(v[16]) 116 | query(rq29) tensor(v[16]) 117 | query(rq30) tensor(v[16]) 118 | query(rq31) tensor(v[16]) 119 | query(rq32) tensor(v[16]) 120 | query(rq33) tensor(v[16]) 121 | query(rq34) tensor(v[16]) 122 | query(rq35) tensor(v[16]) 123 | query(rq36) tensor(v[16]) 124 | query(rq37) tensor(v[16]) 125 | query(rq38) tensor(v[16]) 126 | query(rq39) tensor(v[16]) 127 | query(rq40) tensor(v[16]) 128 | query(rq41) tensor(v[16]) 129 | query(rq42) tensor(v[16]) 130 | query(rq43) tensor(v[16]) 131 | query(rq44) tensor(v[16]) 132 | query(rq45) tensor(v[16]) 133 | query(rq46) tensor(v[16]) 134 | query(rq47) tensor(v[16]) 135 | query(rq48) tensor(v[16]) 136 | query(rq49) tensor(v[16]) 137 | query(rq50) tensor(v[16]) 138 | query(rq51) tensor(v[16]) 139 | query(rq52) tensor(v[16]) 140 | query(rq53) tensor(v[16]) 141 | query(rq54) tensor(v[16]) 142 | query(rq55) tensor(v[16]) 143 | query(rq56) tensor(v[16]) 144 | query(rq57) tensor(v[16]) 145 | query(rq58) tensor(v[16]) 146 | query(rq59) tensor(v[16]) 147 | query(rq60) tensor(v[16]) 148 | query(rq61) tensor(v[16]) 149 | query(rq62) tensor(v[16]) 150 | query(rq63) tensor(v[16]) 151 | query(qt) tensor(querytoken{}, v[128]) 152 | query(qtb) tensor(querytoken{}, v[16]) 153 | 154 | } 155 | function max_sim() { 156 | expression { 157 | 158 | sum( 159 | reduce( 160 | sum( 161 | query(qt) * unpack_bits(attribute(embedding)) , v 162 | ), 163 | max, patch 164 | ), 165 | querytoken 166 | ) 167 | 168 | } 169 | } 170 | function max_sim_binary() { 171 | expression { 172 | 173 | sum( 174 | reduce( 175 | 1/(1 + sum( 176 | hamming(query(qtb), attribute(embedding)) ,v) 177 | ), 178 | max, 179 | patch 180 | ), 181 | querytoken 182 | ) 183 | 184 | } 185 | } 186 | first-phase { 187 | expression { 188 | max_sim_binary 189 | } 190 | } 191 | second-phase { 192 | rerank-count: 10 193 | expression { 194 | max_sim 195 | } 196 | } 197 | } 198 | } 
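Note on the embedding field above: the following is a minimal, illustrative sketch (not a file from this repository) of the bit-packing contract between this schema and the feed. A 128-dimensional ColQwen2 patch vector is binarized and packed into 16 bytes, which is exactly the tensor(patch{}, v[16]) cell stored with the hamming distance metric, and unpack_bits in the rank profiles recovers the 128 bits at query time. The random vector is a stand-in for a real model output; the packing mirrors _convert_to_binary_vector in src/vespa/indexing/prepare_feed.py.

import numpy as np

# Illustrative stand-in for one ColQwen2 patch embedding (128 dims, like query(qt)).
rng = np.random.default_rng(0)
patch_embedding = rng.standard_normal(128).astype(np.float32)

# Binarize (positive component -> 1, else 0), then pack 128 bits into 16 bytes,
# matching the v[16] cell size of the embedding field above.
bits = np.where(patch_embedding > 0, 1, 0)
packed = np.packbits(bits).astype(np.int8)
assert packed.size == 16

# Hex-encoded cell value, the per-patch format VespaFeedPreparator feeds to Vespa.
cell = packed.tobytes().hex()

# unpack_bits(attribute(embedding)) in the rank profiles reverses the packing,
# so max_sim can score the float query tensor against the recovered bits.
recovered = np.unpackbits(packed.astype(np.uint8))
assert np.array_equal(recovered, bits)
print(cell)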
-------------------------------------------------------------------------------- /test/search/query-profiles/default.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test/search/query-profiles/types/root.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test/services.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 1 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /test2/schemas/pdf_page.sd: -------------------------------------------------------------------------------- 1 | schema pdf_page { 2 | document pdf_page { 3 | field id type string { 4 | indexing: summary | index 5 | match { 6 | word 7 | } 8 | } 9 | field url type string { 10 | indexing: summary | index 11 | } 12 | field title type string { 13 | indexing: summary | index 14 | index: enable-bm25 15 | match { 16 | text 17 | } 18 | } 19 | field page_number type int { 20 | indexing: summary | attribute 21 | } 22 | field image type raw { 23 | indexing: summary 24 | } 25 | field text type string { 26 | indexing: index 27 | index: enable-bm25 28 | match { 29 | text 30 | } 31 | } 32 | field embedding type tensor(patch{}, v[16]) { 33 | indexing: attribute | index 34 | attribute { 35 | distance-metric: hamming 36 | } 37 | index { 38 | hnsw { 39 | max-links-per-node: 32 40 | neighbors-to-explore-at-insert: 400 41 | } 42 | } 43 | } 44 | } 45 | fieldset default { 46 | fields: title, text 47 | } 48 | rank-profile default { 49 | inputs { 50 | query(qt) tensor(querytoken{}, v[128]) 51 | 52 | } 53 | function max_sim() { 54 | expression { 55 | 56 | sum( 57 | reduce( 58 | sum( 59 | query(qt) * unpack_bits(attribute(embedding)) , v 60 | ), 61 | max, patch 62 | ), 63 | querytoken 64 | ) 65 | 66 | } 67 | } 68 | function bm25_score() { 69 | expression { 70 | bm25(title) + bm25(text) 71 | } 72 | } 73 | first-phase { 74 | expression { 75 | bm25_score 76 | } 77 | } 78 | second-phase { 79 | rerank-count: 100 80 | expression { 81 | max_sim 82 | } 83 | } 84 | } 85 | } -------------------------------------------------------------------------------- /test2/search/query-profiles/default.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test2/search/query-profiles/types/root.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test2/services.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 1 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /test3/schemas/pdf_page.sd: -------------------------------------------------------------------------------- 1 | schema pdf_page { 2 | document pdf_page { 3 | field id type string { 4 | indexing: summary | index 5 | match { 6 | word 7 | } 8 | } 9 | field url type string { 10 | indexing: summary | index 11 | } 12 | field title type string { 13 | indexing: summary | index 14 | index: enable-bm25 15 | match { 16 | text 17 | } 18 | } 19 | field 
page_number type int { 20 | indexing: summary | attribute 21 | } 22 | field image type raw { 23 | indexing: summary 24 | } 25 | field text type string { 26 | indexing: index 27 | index: enable-bm25 28 | match { 29 | text 30 | } 31 | } 32 | field embedding type tensor(patch{}, v[16]) { 33 | indexing: attribute | index 34 | attribute { 35 | distance-metric: hamming 36 | } 37 | index { 38 | hnsw { 39 | max-links-per-node: 32 40 | neighbors-to-explore-at-insert: 400 41 | } 42 | } 43 | } 44 | } 45 | fieldset default { 46 | fields: title, text 47 | } 48 | rank-profile default { 49 | inputs { 50 | query(qt) tensor(querytoken{}, v[128]) 51 | 52 | } 53 | function max_sim() { 54 | expression { 55 | 56 | sum( 57 | reduce( 58 | sum( 59 | query(qt) * unpack_bits(attribute(embedding)) , v 60 | ), 61 | max, patch 62 | ), 63 | querytoken 64 | ) 65 | 66 | } 67 | } 68 | function bm25_score() { 69 | expression { 70 | bm25(title) + bm25(text) 71 | } 72 | } 73 | first-phase { 74 | expression { 75 | bm25_score 76 | } 77 | } 78 | second-phase { 79 | rerank-count: 100 80 | expression { 81 | max_sim 82 | } 83 | } 84 | } 85 | rank-profile retrieval-and-rerank { 86 | inputs { 87 | query(rq0) tensor(v[16]) 88 | query(rq1) tensor(v[16]) 89 | query(rq2) tensor(v[16]) 90 | query(rq3) tensor(v[16]) 91 | query(rq4) tensor(v[16]) 92 | query(rq5) tensor(v[16]) 93 | query(rq6) tensor(v[16]) 94 | query(rq7) tensor(v[16]) 95 | query(rq8) tensor(v[16]) 96 | query(rq9) tensor(v[16]) 97 | query(rq10) tensor(v[16]) 98 | query(rq11) tensor(v[16]) 99 | query(rq12) tensor(v[16]) 100 | query(rq13) tensor(v[16]) 101 | query(rq14) tensor(v[16]) 102 | query(rq15) tensor(v[16]) 103 | query(rq16) tensor(v[16]) 104 | query(rq17) tensor(v[16]) 105 | query(rq18) tensor(v[16]) 106 | query(rq19) tensor(v[16]) 107 | query(rq20) tensor(v[16]) 108 | query(rq21) tensor(v[16]) 109 | query(rq22) tensor(v[16]) 110 | query(rq23) tensor(v[16]) 111 | query(rq24) tensor(v[16]) 112 | query(rq25) tensor(v[16]) 113 | query(rq26) tensor(v[16]) 114 | query(rq27) tensor(v[16]) 115 | query(rq28) tensor(v[16]) 116 | query(rq29) tensor(v[16]) 117 | query(rq30) tensor(v[16]) 118 | query(rq31) tensor(v[16]) 119 | query(rq32) tensor(v[16]) 120 | query(rq33) tensor(v[16]) 121 | query(rq34) tensor(v[16]) 122 | query(rq35) tensor(v[16]) 123 | query(rq36) tensor(v[16]) 124 | query(rq37) tensor(v[16]) 125 | query(rq38) tensor(v[16]) 126 | query(rq39) tensor(v[16]) 127 | query(rq40) tensor(v[16]) 128 | query(rq41) tensor(v[16]) 129 | query(rq42) tensor(v[16]) 130 | query(rq43) tensor(v[16]) 131 | query(rq44) tensor(v[16]) 132 | query(rq45) tensor(v[16]) 133 | query(rq46) tensor(v[16]) 134 | query(rq47) tensor(v[16]) 135 | query(rq48) tensor(v[16]) 136 | query(rq49) tensor(v[16]) 137 | query(rq50) tensor(v[16]) 138 | query(rq51) tensor(v[16]) 139 | query(rq52) tensor(v[16]) 140 | query(rq53) tensor(v[16]) 141 | query(rq54) tensor(v[16]) 142 | query(rq55) tensor(v[16]) 143 | query(rq56) tensor(v[16]) 144 | query(rq57) tensor(v[16]) 145 | query(rq58) tensor(v[16]) 146 | query(rq59) tensor(v[16]) 147 | query(rq60) tensor(v[16]) 148 | query(rq61) tensor(v[16]) 149 | query(rq62) tensor(v[16]) 150 | query(rq63) tensor(v[16]) 151 | query(qt) tensor(querytoken{}, v[128]) 152 | query(qtb) tensor(querytoken{}, v[16]) 153 | 154 | } 155 | function max_sim() { 156 | expression { 157 | 158 | sum( 159 | reduce( 160 | sum( 161 | query(qt) * unpack_bits(attribute(embedding)) , v 162 | ), 163 | max, patch 164 | ), 165 | querytoken 166 | ) 167 | 168 | } 169 | } 170 | function max_sim_binary() 
{ 171 | expression { 172 | 173 | sum( 174 | reduce( 175 | 1/(1 + sum( 176 | hamming(query(qtb), attribute(embedding)) ,v) 177 | ), 178 | max, 179 | patch 180 | ), 181 | querytoken 182 | ) 183 | 184 | } 185 | } 186 | first-phase { 187 | expression { 188 | max_sim_binary 189 | } 190 | } 191 | second-phase { 192 | rerank-count: 10 193 | expression { 194 | max_sim 195 | } 196 | } 197 | } 198 | } -------------------------------------------------------------------------------- /test3/search/query-profiles/default.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test3/search/query-profiles/types/root.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test3/services.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 1 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /test4/schemas/pdf_page.sd: -------------------------------------------------------------------------------- 1 | schema pdf_page { 2 | document pdf_page { 3 | field id type string { 4 | indexing: summary | index 5 | match { 6 | word 7 | } 8 | } 9 | field url type string { 10 | indexing: summary | index 11 | } 12 | field title type string { 13 | indexing: summary | index 14 | index: enable-bm25 15 | match { 16 | text 17 | } 18 | } 19 | field page_number type int { 20 | indexing: summary | attribute 21 | } 22 | field image type raw { 23 | indexing: summary 24 | } 25 | field text type string { 26 | indexing: index 27 | index: enable-bm25 28 | match { 29 | text 30 | } 31 | } 32 | field embedding type tensor(patch{}, v[16]) { 33 | indexing: attribute | index 34 | attribute { 35 | distance-metric: hamming 36 | } 37 | index { 38 | hnsw { 39 | max-links-per-node: 32 40 | neighbors-to-explore-at-insert: 400 41 | } 42 | } 43 | } 44 | } 45 | fieldset default { 46 | fields: title, text 47 | } 48 | rank-profile default { 49 | inputs { 50 | query(qt) tensor(querytoken{}, v[128]) 51 | 52 | } 53 | function max_sim() { 54 | expression { 55 | 56 | sum( 57 | reduce( 58 | sum( 59 | query(qt) * unpack_bits(attribute(embedding)) , v 60 | ), 61 | max, patch 62 | ), 63 | querytoken 64 | ) 65 | 66 | } 67 | } 68 | function bm25_score() { 69 | expression { 70 | bm25(title) + bm25(text) 71 | } 72 | } 73 | first-phase { 74 | expression { 75 | bm25_score 76 | } 77 | } 78 | second-phase { 79 | rerank-count: 100 80 | expression { 81 | max_sim 82 | } 83 | } 84 | } 85 | rank-profile retrieval-and-rerank { 86 | inputs { 87 | query(rq0) tensor(v[16]) 88 | query(rq1) tensor(v[16]) 89 | query(rq2) tensor(v[16]) 90 | query(rq3) tensor(v[16]) 91 | query(rq4) tensor(v[16]) 92 | query(rq5) tensor(v[16]) 93 | query(rq6) tensor(v[16]) 94 | query(rq7) tensor(v[16]) 95 | query(rq8) tensor(v[16]) 96 | query(rq9) tensor(v[16]) 97 | query(rq10) tensor(v[16]) 98 | query(rq11) tensor(v[16]) 99 | query(rq12) tensor(v[16]) 100 | query(rq13) tensor(v[16]) 101 | query(rq14) tensor(v[16]) 102 | query(rq15) tensor(v[16]) 103 | query(rq16) tensor(v[16]) 104 | query(rq17) tensor(v[16]) 105 | query(rq18) tensor(v[16]) 106 | query(rq19) tensor(v[16]) 107 | query(rq20) tensor(v[16]) 108 | query(rq21) tensor(v[16]) 109 | query(rq22) tensor(v[16]) 110 | query(rq23) tensor(v[16]) 111 | query(rq24) tensor(v[16]) 
112 | query(rq25) tensor(v[16]) 113 | query(rq26) tensor(v[16]) 114 | query(rq27) tensor(v[16]) 115 | query(rq28) tensor(v[16]) 116 | query(rq29) tensor(v[16]) 117 | query(rq30) tensor(v[16]) 118 | query(rq31) tensor(v[16]) 119 | query(rq32) tensor(v[16]) 120 | query(rq33) tensor(v[16]) 121 | query(rq34) tensor(v[16]) 122 | query(rq35) tensor(v[16]) 123 | query(rq36) tensor(v[16]) 124 | query(rq37) tensor(v[16]) 125 | query(rq38) tensor(v[16]) 126 | query(rq39) tensor(v[16]) 127 | query(rq40) tensor(v[16]) 128 | query(rq41) tensor(v[16]) 129 | query(rq42) tensor(v[16]) 130 | query(rq43) tensor(v[16]) 131 | query(rq44) tensor(v[16]) 132 | query(rq45) tensor(v[16]) 133 | query(rq46) tensor(v[16]) 134 | query(rq47) tensor(v[16]) 135 | query(rq48) tensor(v[16]) 136 | query(rq49) tensor(v[16]) 137 | query(rq50) tensor(v[16]) 138 | query(rq51) tensor(v[16]) 139 | query(rq52) tensor(v[16]) 140 | query(rq53) tensor(v[16]) 141 | query(rq54) tensor(v[16]) 142 | query(rq55) tensor(v[16]) 143 | query(rq56) tensor(v[16]) 144 | query(rq57) tensor(v[16]) 145 | query(rq58) tensor(v[16]) 146 | query(rq59) tensor(v[16]) 147 | query(rq60) tensor(v[16]) 148 | query(rq61) tensor(v[16]) 149 | query(rq62) tensor(v[16]) 150 | query(rq63) tensor(v[16]) 151 | query(qt) tensor(querytoken{}, v[128]) 152 | query(qtb) tensor(querytoken{}, v[16]) 153 | 154 | } 155 | function max_sim() { 156 | expression { 157 | 158 | sum( 159 | reduce( 160 | sum( 161 | query(qt) * unpack_bits(attribute(embedding)) , v 162 | ), 163 | max, patch 164 | ), 165 | querytoken 166 | ) 167 | 168 | } 169 | } 170 | function max_sim_binary() { 171 | expression { 172 | 173 | sum( 174 | reduce( 175 | 1/(1 + sum( 176 | hamming(query(qtb), attribute(embedding)) ,v) 177 | ), 178 | max, 179 | patch 180 | ), 181 | querytoken 182 | ) 183 | 184 | } 185 | } 186 | first-phase { 187 | expression { 188 | max_sim_binary 189 | } 190 | } 191 | second-phase { 192 | rerank-count: 10 193 | expression { 194 | max_sim 195 | } 196 | } 197 | } 198 | } -------------------------------------------------------------------------------- /test4/search/query-profiles/default.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test4/search/query-profiles/types/root.xml: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /test4/services.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 1 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /test_vespa_indexing.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from src.vespa.datatypes import PDFInput 4 | from src.vespa.indexing.pdf_processor import PDFProcessor 5 | from src.vespa.indexing.prepare_feed import VespaFeedPreparator 6 | from src.vespa.indexing.run import run_indexing 7 | 8 | 9 | async def index_documents(): 10 | """Example function showing how to use run_indexing.""" 11 | config_path = "/Users/vesaalexandru/Workspaces/cube/cube-publication/evaluate-ocr-rag-systems/src/vespa/vespa_config.yaml" 12 | 13 | pdfs = [ 14 | PDFInput( 15 | title="Building a Resilient Strategy", 16 | url="https://static.conocophillips.com/files/resources/24-0976-sustainability-highlights_nature.pdf", 17 | ) 18 
| ] 19 | 20 | # pdfs = [ 21 | # PDFInput( 22 | # title="Building a Resilient Strategy", 23 | # url="https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf", 24 | # ) 25 | # ] 26 | 27 | # Explicitly create the processors 28 | pdf_processor = PDFProcessor() 29 | feed_preparator = VespaFeedPreparator() 30 | 31 | try: 32 | await run_indexing( 33 | config_path=config_path, 34 | pdfs=pdfs, 35 | pdf_processor=pdf_processor, 36 | feed_preparator=feed_preparator, 37 | ) 38 | except Exception as e: 39 | print(f"Error occurred: {e}") 40 | raise 41 | 42 | 43 | def main(): 44 | """Entry point for the application.""" 45 | asyncio.run(index_documents()) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /test_vespa_inference.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Dict, List 4 | 5 | from src.vespa.datatypes import QueryResult 6 | from src.vespa.retrieval.run import run_queries 7 | 8 | logging.basicConfig(level=logging.INFO) 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | async def main() -> Dict[str, List[QueryResult]]: 13 | # Path to your config file 14 | config_path = ( 15 | "/Users/vesaalexandru/Workspaces/cube/cube-publication/" 16 | "evaluate-ocr-rag-systems/src/vespa/vespa_config.yaml" 17 | ) 18 | 19 | # Test queries 20 | queries = [ 21 | "Percentage of non-fresh water as source?", 22 | ] 23 | 24 | try: 25 | logger.info("Starting query execution") 26 | results = await run_queries( 27 | config_path=config_path, queries=queries, display_results=True 28 | ) 29 | logger.info("Query execution completed successfully") 30 | return results 31 | 32 | except Exception as e: 33 | logger.error(f"Query execution failed: {str(e)}") 34 | raise 35 | 36 | 37 | if __name__ == "__main__": 38 | import base64 39 | import os 40 | from io import BytesIO 41 | 42 | import google.generativeai as genai 43 | from PIL import Image 44 | from vidore_benchmark.utils.image_utils import scale_image 45 | 46 | results = asyncio.run(main()) 47 | 48 | # Read the API key from the environment rather than hardcoding it in source. 49 | genai.configure(api_key=os.environ["GOOGLE_API_KEY"]) 50 | 51 | queries = [ 52 | "Percentage of non-fresh water as source?", 53 | # "Policies related to nature risk?", 54 | # "How much of produced water is recycled?", 55 | ] 56 | 57 | # Take the top hit for the query and decode its stored page image. 58 | best_hit = results["Percentage of non-fresh water as source?"][0] 59 | pdf_url = best_hit.url 60 | pdf_title = best_hit.title 61 | image_data = base64.b64decode(best_hit.source["fields"]["image"]) 62 | image = Image.open(BytesIO(image_data)) 63 | scaled_image = scale_image(image, 720) 64 | # display(scaled_image) 65 | 66 | model = genai.GenerativeModel(model_name="gemini-1.5-flash") 67 | response = model.generate_content([queries[0], image]) 68 | print(response) 69 | -------------------------------------------------------------------------------- /vespa_inference.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | from vespa.deployment import VespaCloud 4 | from vespa.io import VespaQueryResponse 5 | 6 | from vespa_vision_rag import VespaSetup 7 | 8 | cloud = VespaSetup("test") 9
--------------------------------------------------------------------------------
/vespa_inference.py:
--------------------------------------------------------------------------------
import asyncio

from vespa.deployment import VespaCloud
from vespa.io import VespaQueryResponse

from vespa_vision_rag import VespaSetup

cloud = VespaSetup("test")


async def query_vespa(queries, vespa_cloud):
    """
    Query Vespa for each user query.
    """
    app = vespa_cloud.get_application()

    async with app.asyncio(connections=1, timeout=180) as session:
        for query in queries:
            # NOTE: the "default" profile reranks with max_sim, which reads a
            # query(qt) tensor; without supplying input.query(qt) only the
            # BM25 first phase is meaningful (see the sketch after this file).
            response: VespaQueryResponse = await session.query(
                yql="select title,url,image,page_number from pdf_page where userInput(@userQuery)",
                ranking="default",
                userQuery=query,
                timeout=120,
                hits=5,  # adjust the number of hits returned
                body={"presentation.timing": True},
            )
            if response.is_successful():
                display_query_results(query, response)
            else:
                # VespaQueryResponse.json is a dict property, not a method.
                print(f"Query failed for '{query}': {response.json}")


def display_query_results(query, response):
    """
    Display query results in a readable format.
    """
    query_time = response.json.get("timing", {}).get("searchtime", -1)
    total_count = response.json.get("root", {}).get("fields", {}).get("totalCount", 0)
    print(f"Query: {query}")
    print(f"Query Time: {query_time}s, Total Results: {total_count}")
    for idx, hit in enumerate(response.hits):
        title = hit["fields"].get("title", "N/A")
        url = hit["fields"].get("url", "N/A")
        page_number = hit["fields"].get("page_number", "N/A")
        print(f"  Result {idx + 1}:")
        print(f"    Title: {title}")
        print(f"    URL: {url}")
        print(f"    Page: {page_number}")


async def main():
    # Define the queries you want to execute
    queries = [
        "Percentage of non-fresh water as source?",
        "Policies related to nature risk?",
        "How much of produced water is recycled?",
    ]

    # Connect to the already-deployed application; VespaCloud still needs the
    # application package built above to resolve the endpoint.
    vespa_cloud = VespaCloud(
        tenant="cube-digital",
        application="test",
        application_package=cloud.app_package,
    )

    # Run queries against the Vespa app
    await query_vespa(queries, vespa_cloud)


if __name__ == "__main__":
    asyncio.run(main())
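
A sketch (not in the repo) of supplying the missing query tensor: embed the
query with the same ColQwen2 model used at indexing time and pass it to Vespa
as input.query(qt). The processor and model names follow PDFProcessor in
vespa_vision_rag.py; the float cells match the rank profile's
tensor(querytoken{}, v[128]) input.

import torch

def embed_query(processor, model, query: str):
    # ColQwen2Processor exposes process_queries analogous to process_images.
    batch = processor.process_queries([query])
    with torch.no_grad():
        batch = {k: v.to(model.device) for k, v in batch.items()}
        # One row of shape (num_tokens, 128) for the single query.
        return model(**batch)[0].cpu().float()

# qt = embed_query(processor, model, "Percentage of non-fresh water as source?")
# body = {
#     "presentation.timing": True,
#     "input.query(qt)": {str(i): v.tolist() for i, v in enumerate(qt)},
# }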
--------------------------------------------------------------------------------
/vespa_vision_rag.py:
--------------------------------------------------------------------------------
import base64
import hashlib
import os
from io import BytesIO

import numpy as np
import requests
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from pdf2image import convert_from_path
from pypdf import PdfReader
from torch.utils.data import DataLoader
from tqdm import tqdm
from vespa.deployment import VespaCloud
from vespa.package import (
    HNSW,
    ApplicationPackage,
    Document,
    Field,
    FieldSet,
    FirstPhaseRanking,
    Function,
    RankProfile,
    Schema,
    SecondPhaseRanking,
)


class PDFProcessor:
    def __init__(self, model_name="vidore/colqwen2-v0.1"):
        self.model = ColQwen2.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto"
        )
        self.processor = ColQwen2Processor.from_pretrained(model_name)
        self.model.eval()

    def download_pdf(self, url):
        response = requests.get(url)
        if response.status_code == 200:
            return BytesIO(response.content)
        raise Exception(f"Failed to download PDF: Status code {response.status_code}")

    def get_pdf_content(self, pdf_url):
        pdf_file = self.download_pdf(pdf_url)
        temp_file = "temp.pdf"
        with open(temp_file, "wb") as f:
            f.write(pdf_file.read())

        reader = PdfReader(temp_file)
        page_texts = [page.extract_text() for page in reader.pages]
        images = convert_from_path(temp_file)
        os.remove(temp_file)  # clean up the scratch file
        assert len(images) == len(page_texts)
        return images, page_texts

    def process_pdf(self, pdf_metadata):
        pdf_data = []
        for pdf in pdf_metadata:
            images, texts = self.get_pdf_content(pdf["url"])
            embeddings = self.generate_embeddings(images)
            pdf_data.append(
                {
                    "url": pdf["url"],
                    "title": pdf["title"],
                    "images": images,
                    "texts": texts,
                    "embeddings": embeddings,
                }
            )
        return pdf_data

    def generate_embeddings(self, images):
        embeddings = []
        dataloader = DataLoader(
            images,
            batch_size=2,
            shuffle=False,
            collate_fn=lambda x: self.processor.process_images(x),
        )

        for batch in tqdm(dataloader):
            with torch.no_grad():
                batch = {k: v.to(self.model.device) for k, v in batch.items()}
                batch_embeddings = self.model(**batch)
                embeddings.extend(list(torch.unbind(batch_embeddings.to("cpu"))))
        return embeddings


class VespaSetup:
    def __init__(self, app_name):
        self.app_name = app_name
        self.schema = self._create_schema()
        self.app_package = ApplicationPackage(name=app_name, schema=[self.schema])

    def _create_schema(self):
        schema = Schema(
            name="pdf_page",
            document=Document(
                fields=[
                    Field(
                        name="id",
                        type="string",
                        indexing=["summary", "index"],
                        match=["word"],
                    ),
                    Field(name="url", type="string", indexing=["summary", "index"]),
                    Field(
                        name="title",
                        type="string",
                        indexing=["summary", "index"],
                        match=["text"],
                        index="enable-bm25",
                    ),
                    Field(
                        name="page_number",
                        type="int",
                        indexing=["summary", "attribute"],
                    ),
                    Field(name="image", type="raw", indexing=["summary"]),
                    Field(
                        name="text",
                        type="string",
                        indexing=["index"],
                        match=["text"],
                        index="enable-bm25",
                    ),
                    Field(
                        name="embedding",
                        # The int8 cell type is required for the hex-encoded
                        # binary feed format and the hamming distance metric.
                        type="tensor<int8>(patch{}, v[16])",
                        indexing=["attribute", "index"],
                        ann=HNSW(
                            distance_metric="hamming",
                            max_links_per_node=32,
                            neighbors_to_explore_at_insert=400,
                        ),
                    ),
                ]
            ),
            fieldsets=[FieldSet(name="default", fields=["title", "text"])],
        )
        self._add_rank_profiles(schema)
        return schema

    def _add_rank_profiles(self, schema):
        default_profile = RankProfile(
            name="default",
            inputs=[("query(qt)", "tensor(querytoken{}, v[128])")],
            functions=[
                Function(
                    name="max_sim",
                    expression="""
                        sum(
                            reduce(
                                sum(
                                    query(qt) * unpack_bits(attribute(embedding)), v
                                ),
                                max, patch
                            ),
                            querytoken
                        )
                    """,
                ),
                Function(name="bm25_score", expression="bm25(title) + bm25(text)"),
            ],
            first_phase=FirstPhaseRanking(expression="bm25_score"),
            second_phase=SecondPhaseRanking(expression="max_sim", rerank_count=100),
        )
        schema.add_rank_profile(default_profile)
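
# --- Note (illustrative, not in the original file) ---------------------------
# Why tensor<int8>(patch{}, v[16]): each ColQwen2 patch embedding has 128
# float dimensions. Binarizing by sign and packing 8 bits per byte yields
# 128 / 8 = 16 int8 cells, which is what the hamming-distance HNSW index
# above operates on. A worked example, using the numpy import at the top:
#
#     patch = np.random.randn(128)                 # one patch embedding
#     bits = np.where(patch > 0, 1, 0)             # sign binarization
#     packed = np.packbits(bits).astype(np.int8)   # shape (16,)
#     assert packed.shape == (16,)
# ------------------------------------------------------------------------------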


def prepare_vespa_feed(pdf_data):
    vespa_feed = []
    for pdf in pdf_data:
        for page_number, (text, embedding, image) in enumerate(
            zip(pdf["texts"], pdf["embeddings"], pdf["images"])
        ):
            embedding_dict = {}
            for idx, patch_embedding in enumerate(embedding):
                # Binarize each 128-dim patch vector by sign, pack it into 16
                # bytes, and hex-encode for the int8 tensor feed format.
                binary_vector = (
                    np.packbits(np.where(patch_embedding > 0, 1, 0))
                    .astype(np.int8)
                    .tobytes()
                    .hex()
                )
                embedding_dict[idx] = binary_vector

            page = {
                # hashlib gives an id that is stable across runs; the built-in
                # hash() is salted per process, so re-feeding would create
                # duplicate documents.
                "id": hashlib.md5(f"{pdf['url']}{page_number}".encode()).hexdigest(),
                "url": pdf["url"],
                "title": pdf["title"],
                "page_number": page_number,
                "image": get_base64_image(resize_image(image, 640)),
                "text": text,
                "embedding": embedding_dict,
            }
            vespa_feed.append(page)
    return vespa_feed


def resize_image(image, max_height=800):
    width, height = image.size
    if height > max_height:
        ratio = max_height / height
        return image.resize((int(width * ratio), int(height * ratio)))
    return image


def get_base64_image(image):
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return str(base64.b64encode(buffered.getvalue()), "utf-8")


async def deploy_and_feed(vespa_feed):
    vespa_setup = VespaSetup("test")

    vespa_cloud = VespaCloud(
        tenant="cube-digital",
        application="test",
        application_package=vespa_setup.app_package,
    )

    app = vespa_cloud.deploy()

    async with app.asyncio(connections=1, timeout=180) as session:
        for page in tqdm(vespa_feed):
            response = await session.feed_data_point(
                data_id=page["id"], fields=page, schema="pdf_page"
            )
            if not response.is_successful():
                print(response.json)
    return app


async def main():
    # Example usage
    sample_pdfs = [
        {
            "title": "Building a Resilient Strategy for the Energy Transition",
            "url": "https://static.conocophillips.com/files/resources/conocophillips-2023-managing-climate-related-risks.pdf",
        }
    ]

    processor = PDFProcessor()
    pdf_data = processor.process_pdf(sample_pdfs)
    vespa_feed = prepare_vespa_feed(pdf_data)

    # Deploy to Vespa Cloud (requires configuration)
    await deploy_and_feed(vespa_feed=vespa_feed)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
--------------------------------------------------------------------------------
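
Aside (not part of the repository): for local experiments the same application
package can be deployed to a Docker container instead of Vespa Cloud. A minimal
sketch using pyvespa's VespaDocker; it assumes Docker is running locally and
reuses the VespaSetup class from vespa_vision_rag.py.

from vespa.deployment import VespaDocker

from vespa_vision_rag import VespaSetup

# Deploys the pdf_page schema to a local Vespa container on port 8080.
setup = VespaSetup("test")
app = VespaDocker().deploy(application_package=setup.app_package)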