├── .gitignore ├── LICENSE.md ├── README.md ├── application ├── extraction │ ├── Dockerfile │ ├── models │ │ └── models.py │ ├── requirements.txt │ ├── service │ │ ├── extraction_handler.py │ │ ├── extraction_worker.py │ │ └── processing_handler.py │ └── start_extraction.py ├── pipeline │ ├── Dockerfile │ ├── models │ │ └── models.py │ ├── requirements.txt │ ├── routes │ │ └── pipeline_routes.py │ ├── service │ │ └── pipeline_service.py │ └── start_pipeline.py └── transformation │ ├── Dockerfile │ ├── models │ └── models.py │ ├── requirements.txt │ ├── service │ ├── transformation_handler.py │ └── transformation_worker.py │ └── start_transformation.py ├── common ├── agents │ ├── agent_prompt_enums.py │ └── prs_agent.py ├── api │ └── 3.1.0-marly-spec.yml ├── destinations │ ├── base │ │ └── base_destination.py │ ├── destination_factory.py │ ├── enums │ │ └── destination_enums.py │ └── sqlite_destination.py ├── models │ ├── azure_model.py │ ├── base │ │ └── base_model.py │ ├── cerebras_model.py │ ├── enums │ │ └── model_enums.py │ ├── groq_model.py │ ├── mistral_model.py │ ├── model_factory.py │ └── openai_model.py ├── prompts │ └── prompt_enums.py ├── redis │ └── redis_config.py ├── sources │ ├── base │ │ └── base_source.py │ ├── enums │ │ └── source_enums.py │ ├── local_fs_source.py │ ├── s3_source.py │ └── source_factory.py └── text_extraction │ └── text_extractor.py ├── docker-compose.yml ├── examples ├── ai-workers │ └── ai-sdr │ │ ├── auth │ │ └── anon_helper.py │ │ ├── contacts.db │ │ ├── main.py │ │ ├── output_source │ │ └── sql_helper.py │ │ └── transformation │ │ └── marly_helper.py ├── example_files │ ├── lacers.pdf │ └── lacers_reduced.pdf ├── notebooks │ ├── autogen_example │ │ ├── OAI_CONFIG_LIST.json │ │ ├── autogen.ipynb │ │ ├── lacers_reduced.pdf │ │ └── plot.png │ └── langgraph_example │ │ ├── data_loading_workflow.ipynb │ │ ├── diagram.jpeg │ │ ├── lacers_reduced.pdf │ │ ├── notebooklm_workflow.ipynb │ │ ├── notebooklm_workflow.jpeg │ │ ├── pdf_table_to_chart_workflow.ipynb │ │ └── wkflow.jpeg └── scripts │ ├── api_example.py │ ├── azure_example.py │ ├── cerebras_example.py │ ├── data_loading_example.py │ ├── groq_example.py │ ├── markdown_example.py │ ├── mistral_example.py │ ├── non_marly_examples │ ├── lacers_reduced.pdf │ └── llamaindex_pinecone.py │ ├── requirements.txt │ ├── s3_example.py │ └── web_and_document_example.py ├── requirements.txt └── start-oe.sh /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | .chroma 3 | .idea 4 | .ipynb_checkpoints 5 | .mypy_cache 6 | .vscode 7 | __pycache__ 8 | .pytest_cache 9 | htmlcov 10 | dist 11 | site 12 | .coverage 13 | coverage.xml 14 | .netlify 15 | *_venv/ 16 | test.db 17 | log.txt 18 | Pipfile.lock 19 | env3.* 20 | env 21 | docs_build 22 | venv 23 | docs.zip 24 | archive.zip 25 | __init__.py 26 | 27 | # vim temporary files 28 | *~ 29 | .*.sw? 
30 | 31 | */myenv 32 | */venv 33 | .DS_Store 34 | 35 | 36 | example_notebooks/*/.cache/ 37 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Murtaza Meerza 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # open-extract 4 | 5 | [Features](#-features) • [What is a Schema?](#-what-is-a-schema) • [Use Cases](#-use-cases) • [Getting Started](#-getting-started) • [Documentation](#-documentation) 6 | 7 |
8 | 9 | --- 10 | 11 | open-extract simplifies the ingestion and processing of unstructured data for those building AI Agents/Agentic Workflows using frameworks such as LangGraph, AG2, and CrewAI. 12 | 13 | --- 14 | 15 | ## 🚀 Features 16 | 17 | 📄 Extract Relevant Information Seamlessly: Give your applications the ability to identify and extract relevant data from one or many large documents and websites with just a single API call. Get the content back in JSON or Markdown formats, making it easy to integrate into your workflows. 18 | 19 | 🔍 Multi-Schema/Multi-Document Support: Extract data based on one or many predefined schemas from a variety of document types, without needing a vector database or specifying page numbers. 20 | 21 | 🔄 Built-in Caching: With built-in caching, previously extracted results can be instantly retrieved, enabling rapid repeat extractions without having to reprocess the original documents. 22 | 23 | 🚫 No Vendor Lock-In: Enjoy complete flexibility with your choice of model provider. Whether you use open-source or closed-source models, you're never tied to a specific vendor, ensuring full control. 24 | 25 | --- 26 | 27 | ## 🧰 What is a Schema? 28 | 29 | A schema is a set of key-value pairs describing what needs to be extracted from a particular document. 30 | 31 |
32 | **📋 Example Schema** 33 | 34 | ```json 35 | { 36 | "Firm": "The name of the firm", 37 | "Number of Funds": "The number of funds managed by the firm", 38 | "Commitment": "The commitment amount in millions of dollars", 39 | "% of Total Comm": "The percentage of total commitment", 40 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 41 | "% of Total Exposure": "The percentage of total exposure", 42 | "TVPI": "Total Value to Paid-In multiple", 43 | "Net IRR": "Net Internal Rate of Return as a percentage" 44 | } 45 | ``` 46 | 47 |
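To go from a schema to extracted data, serialize the schema to a JSON string inside a workload and send it to the pipeline API, then poll for the result. A minimal sketch (the endpoint paths, port, and field names come from `application/pipeline`; the provider values, API key, and file path are placeholders to adjust for your deployment):

```python
import base64
import json
import zlib

import requests

schema = {
    "Firm": "The name of the firm",
    "Net IRR": "Net Internal Rate of Return as a percentage",
}

# Inline documents travel as zlib-compressed, base64-encoded bytes,
# mirroring what handle_raw_data() in pipeline_service.py decodes.
with open("examples/example_files/lacers_reduced.pdf", "rb") as f:
    raw_data = base64.b64encode(zlib.compress(f.read())).decode("utf-8")

payload = {
    "workloads": [{"raw_data": raw_data, "schemas": [json.dumps(schema)]}],
    "provider_type": "openai",        # placeholder: any provider under common/models
    "provider_model_name": "gpt-4o",  # placeholder: a model your provider serves
    "api_key": "YOUR_API_KEY",
}

# The pipeline container runs uvicorn on port 8100.
resp = requests.post("http://localhost:8100/pipelines", json=payload)
task_id = resp.json()["task_id"]

# Poll until the job status is COMPLETED, then read the extracted metrics.
print(requests.get(f"http://localhost:8100/pipelines/{task_id}").json())
```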
48 | 49 | 50 | 51 | --- 52 | 53 | ## 🎯 Use Cases 54 |
| 💼 Financial Report Analysis | 📊 Customer Feedback Processing | 🔬 Research Assistant | 🧠 Legal Contract Parsing |
| --- | --- | --- | --- |
| Extract key financial metrics from quarterly PDF reports | Categorize feedback from various document types | Process research papers, extracting methodologies and findings | Extract key legal terms and conditions from contracts |
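All four use cases run through the same pipeline; what differs is how each workload points at its document. A sketch of the three workload shapes accepted by `WorkloadItem` (see `application/pipeline/models/models.py`) — the `"web"` source name appears in the pipeline service, while the `"s3"` value and its `documents_location` semantics are assumptions, since the exact names live in `common/sources/enums/source_enums.py`:

```python
import json

schemas = [json.dumps({"Firm": "The name of the firm"})]

# 1. Inline bytes: raw_data is a zlib-compressed, base64-encoded document.
inline_workload = {"raw_data": "<base64 of zlib-compressed PDF>", "schemas": schemas}

# 2. Web page: the pipeline service fetches the URL; the extraction
#    service strips the HTML before running schema extraction.
web_workload = {
    "data_source": "web",
    "documents_location": "https://example.com/quarterly-report",
    "schemas": schemas,
}

# 3. File store such as S3: an LLM picks the stored file whose name best
#    matches file_name (see get_relevant_file_via_llm in pipeline_service.py).
s3_workload = {
    "data_source": "s3",                # assumption: see source_enums.py for exact values
    "documents_location": "my-bucket",  # assumption: bucket/folder the source reads from
    "file_name": "Q3 fund performance report",
    "schemas": schemas,
}
```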
69 | 70 | --- 71 | 72 | ## 🛠️ Getting Started 73 | 74 | ### Build the Platform 75 | 76 | --- 77 | 78 | To build the platform from source, run the following command: 79 | 80 | ```bash 81 | ./start-oe.sh 82 | ``` 83 | 84 | --- 85 | 86 | ### Run an example script or notebook 87 | 88 | Once the platform is running, you can test it out by trying one of our examples. 89 | 90 | 1. Navigate to the examples folder: 91 | 92 | ```bash 93 | cd examples 94 | ``` 95 | 2. Navigate to the scripts or notebooks folder: 96 | 97 | ```bash 98 | cd scripts 99 | ``` 100 | or 101 | ```bash 102 | cd notebooks/autogen_example 103 | ``` 104 | 3. Run one of our example scripts: 105 | ```bash 106 | python azure_example.py 107 | ``` 108 | 109 | --- 110 | 111 | -------------------------------------------------------------------------------- /application/extraction/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | # Copy the common directory 6 | COPY ./common /app/common 7 | 8 | # Copy the extraction application directory 9 | COPY ./application/extraction /app/application/extraction 10 | 11 | # Copy the requirements file 12 | COPY ./application/extraction/requirements.txt /app/requirements.txt 13 | 14 | RUN apt-get update && apt-get install -y \ 15 | build-essential \ 16 | make \ 17 | && rm -rf /var/lib/apt/lists/* 18 | 19 | # Install dependencies 20 | RUN pip install --no-cache-dir -r /app/requirements.txt 21 | 22 | # Set the Python path 23 | ENV PYTHONPATH /app 24 | 25 | # Run the extraction service 26 | CMD ["python", "application/extraction/start_extraction.py"] 27 | -------------------------------------------------------------------------------- /application/extraction/models/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Dict, Optional 3 | from enum import Enum 4 | 5 | class ExtractionRequestModel(BaseModel): 6 | task_id: str 7 | pdf_key: str 8 | schemas: List[Dict] 9 | source_type: str = "pdf" 10 | destination: Optional[str] = None 11 | 12 | class SchemaResult(BaseModel): 13 | schema_id: str 14 | metrics: Dict[str, str] 15 | schema_data: Dict[str, str] 16 | 17 | class ExtractionResponseModel(BaseModel): 18 | task_id: str 19 | pdf_key: str 20 | results: List[SchemaResult] 21 | source_type: str = "pdf" 22 | 23 | class JobStatus(str, Enum): 24 | PENDING = "PENDING" 25 | IN_PROGRESS = "IN_PROGRESS" 26 | COMPLETED = "COMPLETED" 27 | FAILED = "FAILED" 28 | 29 | class ModelDetails(BaseModel): 30 | provider_type: str 31 | provider_model_name: str 32 | api_key: str 33 | markdown_mode: bool 34 | additional_params: Dict[str, str] 35 | -------------------------------------------------------------------------------- /application/extraction/requirements.txt: -------------------------------------------------------------------------------- 1 | python-dotenv 2 | pydantic 3 | langchain 4 | langchainhub 5 | langsmith 6 | langgraph 7 | openai 8 | groq 9 | cerebras_cloud_sdk 10 | PyPDF2 11 | pdfminer.six 12 | pymupdf 13 | aiohttp 14 | requests==2.32.3 15 | redis[hiredis] 16 | boto3 17 | beautifulsoup4 18 | markitdown 19 | celery 20 | tiktoken -------------------------------------------------------------------------------- /application/extraction/service/extraction_handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | import time 4 | import gc 5 |
from typing import List, Dict, NamedTuple, Optional 6 | from io import BytesIO 7 | from redis.asyncio import Redis 8 | from common.redis.redis_config import get_redis_connection 9 | from common.text_extraction.text_extractor import extract_page_as_markdown, find_common_pages 10 | import base64 11 | from dotenv import load_dotenv 12 | from common.models.model_factory import ModelFactory 13 | from langsmith import Client as LangSmithClient 14 | from common.prompts.prompt_enums import PromptType 15 | import tiktoken 16 | from dataclasses import dataclass 17 | from celery import Celery 18 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 19 | from application.extraction.service.processing_handler import ( 20 | get_latest_model_details, 21 | preprocess_messages, 22 | process_web_content 23 | ) 24 | from common.agents.prs_agent import process_extraction, AgentMode 25 | 26 | celery_app = Celery('extraction_tasks') 27 | celery_app.conf.update( 28 | broker_url='redis://redis:6379/0', 29 | result_backend='redis://redis:6379/0', 30 | task_serializer='json', 31 | result_serializer='json', 32 | accept_content=['json'], 33 | task_routes={ 34 | 'process_pdf_chunk': {'queue': 'pdf_processing'}, 35 | 'process_batch': {'queue': 'batch_processing'} 36 | }, 37 | task_time_limit=3600, 38 | worker_prefetch_multiplier=1, 39 | worker_max_tasks_per_child=100 40 | ) 41 | 42 | MAX_WORKERS = 10 43 | thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS) 44 | process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS) 45 | 46 | logging.basicConfig(level=logging.INFO) 47 | logger = logging.getLogger(__name__) 48 | load_dotenv() 49 | 50 | langsmith_client = LangSmithClient() 51 | 52 | class PageNumbers(NamedTuple): 53 | pages: List[int] 54 | 55 | @dataclass 56 | class ProcessingProgress: 57 | current: int 58 | total: int 59 | stage: str 60 | timestamp: float 61 | status: str 62 | 63 | class ProcessingQueue: 64 | def __init__(self): 65 | self.queue = asyncio.Queue() 66 | self.results = {} 67 | self.progress = {} 68 | 69 | async def add_job(self, pdf_key: str, job_id: str): 70 | await self.queue.put((pdf_key, job_id)) 71 | self.progress[job_id] = ProcessingProgress(0, 0, "queued", time.time(), "pending") 72 | 73 | async def get_progress(self, job_id: str) -> Optional[ProcessingProgress]: 74 | return self.progress.get(job_id) 75 | 76 | processing_queue = ProcessingQueue() 77 | 78 | async def get_model_client(): 79 | """Get model client for batch processing.""" 80 | redis = await get_redis_connection() 81 | model_details = await get_latest_model_details(redis) 82 | if not model_details: 83 | raise Exception("Could not get model details") 84 | 85 | return ModelFactory.create_model( 86 | model_type=model_details.provider_type, 87 | model_name=model_details.provider_model_name, 88 | api_key=model_details.api_key, 89 | additional_params=model_details.additional_params 90 | ) 91 | 92 | @celery_app.task(name='process_pdf_chunk') 93 | def process_pdf_chunk(pdf_chunk: str, start_page: int, end_page: int, job_id: str): 94 | """Celery task for processing a chunk of PDF pages. The chunk arrives as a base64 string because the Celery JSON serializer cannot carry raw bytes.""" 95 | file_stream = BytesIO(base64.b64decode(pdf_chunk)) 96 | return asyncio.run(_process_pdf_chunk(file_stream, start_page, end_page, job_id)) 97 | 98 | @celery_app.task(name='process_batch') 99 | def process_batch(batch_content: str, keywords: str, examples: str, job_id: str): 100 | """Celery task for processing a batch of content.""" 101 | return asyncio.run(_process_batch(batch_content, keywords, examples, job_id)) 102 | 103 | async def
_process_pdf_chunk(file_stream: BytesIO, start_page: int, end_page: int, job_id: str) -> Dict: 104 | """Process a chunk of PDF pages asynchronously.""" 105 | try: 106 | extracted_contents = [] 107 | for page in range(start_page, end_page + 1): 108 | content = await process_page(file_stream, page) 109 | if content: 110 | token_count = estimate_tokens(content) 111 | extracted_contents.append((page, content, token_count)) 112 | return {'success': True, 'contents': extracted_contents} 113 | except Exception as e: 114 | logger.error(f"Error processing PDF chunk {start_page}-{end_page}: {e}") 115 | return {'success': False, 'error': str(e)} 116 | 117 | async def _process_batch(batch_content: str, keywords: str, examples: str, job_id: str) -> str: 118 | """Process a batch of content asynchronously.""" 119 | try: 120 | client = await get_model_client() 121 | result = await call_llm_with_file_content(batch_content, keywords, examples, client) 122 | return result if result else "" 123 | except Exception as e: 124 | logger.error(f"Error processing batch in job {job_id}: {e}") 125 | return "" 126 | 127 | async def run_extraction(pdf_key: str, schemas: List[Dict[str, str]], job_id: str = None) -> List[str]: 128 | """Enhanced extraction with smart processing selection.""" 129 | if not job_id: 130 | job_id = f"job_{int(time.time())}" 131 | 132 | logger.info(f"Starting extraction for pdf_key: {pdf_key}, job_id: {job_id}") 133 | await track_progress(job_id, 0, len(schemas), "initialization") 134 | 135 | redis = await get_redis_connection() 136 | try: 137 | file_stream = await get_file_stream(redis, pdf_key) 138 | file_size = file_stream.getbuffer().nbytes 139 | 140 | client = await get_model_client() 141 | examples = await get_examples(client, str(schemas)) 142 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schemas[0].items()]) 143 | 144 | page_numbers = await get_relevant_page_numbers(client, file_stream, formatted_keywords) 145 | total_relevant_pages = len(page_numbers.pages) 146 | 147 | DISTRIBUTED_THRESHOLD = 10 148 | 149 | if total_relevant_pages <= DISTRIBUTED_THRESHOLD: 150 | logger.info(f"Using direct processing for {total_relevant_pages} relevant pages") 151 | metrics = await retrieve_multi_page_metrics( 152 | page_numbers.pages, formatted_keywords, file_stream, 153 | examples, client, job_id 154 | ) 155 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success") 156 | return [metrics] if metrics else [] 157 | else: 158 | logger.info(f"Using distributed processing for {total_relevant_pages} relevant pages") 159 | 160 | chunk_size = max(1, total_relevant_pages // MAX_WORKERS) 161 | chunks = [page_numbers.pages[i:i + chunk_size] for i in range(0, total_relevant_pages, chunk_size)] 162 | 163 | chunk_tasks = [] 164 | for chunk_pages in chunks: 165 | task = process_pdf_chunk.delay( 166 | base64.b64encode(file_stream.getvalue()).decode(), 167 | min(chunk_pages), 168 | max(chunk_pages), 169 | job_id 170 | ) 171 | chunk_tasks.append(task) 172 | 173 | results = [] 174 | for task in chunk_tasks: 175 | result = task.get() 176 | if result['success']: 177 | results.extend(result['contents']) 178 | 179 | if not results: 180 | logger.error("No results from distributed processing") 181 | return [] 182 | 183 | combined_content = "\n=== BATCH BREAK ===\n".join([content for _, content, _ in sorted(results, key=lambda x: x[0])]) 184 | final_result = await validate_metrics(combined_content, examples, client) 185 | 186 | await track_progress(job_id, len(schemas), len(schemas), 
"completed", "success") 187 | await cleanup_processed_files(redis, pdf_key) 188 | 189 | return [final_result] if final_result else [] 190 | 191 | except Exception as e: 192 | logger.error(f"Error in extraction process: {e}") 193 | await track_progress(job_id, 0, len(schemas), "failed", "error") 194 | return [] 195 | 196 | async def process_small_pdf(file_stream: BytesIO, schemas: List[Dict[str, str]], examples: str, client, job_id: str) -> List[str]: 197 | """Direct processing for small PDFs.""" 198 | try: 199 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schemas[0].items()]) 200 | page_numbers = await get_relevant_page_numbers(client, file_stream, formatted_keywords) 201 | 202 | metrics = await retrieve_multi_page_metrics( 203 | page_numbers.pages, formatted_keywords, file_stream, 204 | examples, client, job_id 205 | ) 206 | 207 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success") 208 | return [metrics] if metrics else [] 209 | 210 | except Exception as e: 211 | logger.error(f"Error in direct processing: {e}") 212 | return [] 213 | 214 | async def track_progress(job_id: str, current: int, total: int, stage: str, status: str = "running"): 215 | processing_queue.progress[job_id] = ProcessingProgress( 216 | current=current, 217 | total=total, 218 | stage=stage, 219 | timestamp=time.time(), 220 | status=status 221 | ) 222 | 223 | async def cleanup_processed_files(redis: Redis, pdf_key: str): 224 | """Cleanup temporary files and memory after processing.""" 225 | try: 226 | await redis.delete(f"processed:{pdf_key}") 227 | gc.collect() 228 | except Exception as e: 229 | logger.error(f"Error during cleanup: {e}") 230 | 231 | def estimate_tokens(text: str) -> int: 232 | """More accurate token estimation using tiktoken.""" 233 | try: 234 | encoding = tiktoken.get_encoding("cl100k_base") 235 | return len(encoding.encode(text)) 236 | except Exception: 237 | return len(text) // 4 238 | 239 | async def calculate_optimal_batch_size(content_length: int, max_tokens: int = 3000) -> int: 240 | """Calculate optimal batch size based on content length.""" 241 | return min(max_tokens, max(1000, content_length // 2)) 242 | 243 | async def process_schema(client, file_stream: BytesIO, schema: Dict[str, str], job_id: str, schema_idx: int, total_schemas: int) -> str: 244 | """Enhanced schema processing with progress tracking.""" 245 | try: 246 | await track_progress(job_id, schema_idx, total_schemas, f"processing_schema_{schema_idx}") 247 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schema.items()]) 248 | 249 | page_numbers = await get_relevant_page_numbers(client, file_stream, formatted_keywords) 250 | examples = await get_examples(client, formatted_keywords) 251 | 252 | await track_progress(job_id, schema_idx + 0.5, total_schemas, f"extracting_metrics_{schema_idx}") 253 | metrics = await retrieve_multi_page_metrics( 254 | page_numbers.pages, formatted_keywords, file_stream, 255 | examples if examples else "", client, job_id 256 | ) 257 | 258 | return metrics if metrics else "" 259 | except Exception as e: 260 | logger.error(f"Error processing schema {schema_idx}: {e}") 261 | return "" 262 | 263 | async def get_file_stream(redis: Redis, pdf_key: str) -> BytesIO: 264 | try: 265 | base64_string: str = await redis.get(pdf_key) 266 | decoded_bytes = base64.b64decode(base64_string) 267 | return BytesIO(decoded_bytes) 268 | except Exception as e: 269 | logger.error(f"Failed to process file stream: {e}") 270 | raise e 271 | 272 | async def get_examples(client, 
formatted_keywords: str) -> str: 273 | try: 274 | prompt = langsmith_client.pull_prompt(PromptType.EXAMPLE_GENERATION.value) 275 | messages = prompt.invoke({"first_value": formatted_keywords}) 276 | processed_messages = preprocess_messages(messages) 277 | if processed_messages: 278 | return client.do_completion(processed_messages) 279 | except Exception as e: 280 | logger.error(f"Error generating examples: {e}") 281 | return "" 282 | 283 | async def get_relevant_page_numbers(client, file_stream: BytesIO, formatted_keywords: str) -> PageNumbers: 284 | try: 285 | common_pages = await find_common_pages(client, file_stream, formatted_keywords) 286 | return PageNumbers(pages=common_pages) 287 | except Exception as e: 288 | logger.error(f"Error with find_common_pages: {e}") 289 | raise Exception("Cannot process: No common pages found or unreadable file.") 290 | 291 | async def retrieve_multi_page_metrics( 292 | pages: List[int], keywords: str, file_stream: BytesIO, examples: str, client, job_id: str 293 | ) -> str: 294 | try: 295 | MAX_TOKENS = await calculate_optimal_batch_size(3000) 296 | logger.info(f"Using dynamic token limit: {MAX_TOKENS}") 297 | 298 | extracted_contents = [] 299 | total_pages = len(pages) 300 | 301 | for idx, page in enumerate(pages): 302 | await track_progress(job_id, idx, total_pages, f"extracting_page_{page}") 303 | content = await process_page(file_stream, page) 304 | if content: 305 | token_count = estimate_tokens(content) 306 | extracted_contents.append((page, content, token_count)) 307 | logger.info(f"Page {page}: {token_count} tokens") 308 | 309 | if not extracted_contents: 310 | logger.error("No content extracted from pages") 311 | return "" 312 | 313 | all_results = [] 314 | current_batch = [] 315 | current_token_count = 0 316 | 317 | for page_num, content, token_count in extracted_contents: 318 | if current_token_count + token_count > MAX_TOKENS: 319 | if current_batch: 320 | batch_content = "\n=== PAGE BREAK ===\n".join([c for _, c, _ in current_batch]) 321 | batch_result = await call_llm_with_file_content(batch_content, keywords, examples, client) 322 | if batch_result: 323 | all_results.append(batch_result) 324 | 325 | del batch_content 326 | gc.collect() 327 | 328 | current_batch = [(page_num, content, token_count)] 329 | current_token_count = token_count 330 | else: 331 | current_batch.append((page_num, content, token_count)) 332 | current_token_count += token_count 333 | 334 | if current_batch: 335 | batch_content = "\n=== PAGE BREAK ===\n".join([c for _, c, _ in current_batch]) 336 | batch_result = await call_llm_with_file_content(batch_content, keywords, examples, client) 337 | if batch_result: 338 | all_results.append(batch_result) 339 | 340 | del batch_content 341 | gc.collect() 342 | 343 | combined_results = "\n=== BATCH BREAK ===\n".join(all_results) 344 | await track_progress(job_id, total_pages, total_pages, "validating_results") 345 | 346 | validation_result = await validate_metrics(combined_results, examples, client) 347 | 348 | del all_results 349 | del combined_results 350 | gc.collect() 351 | 352 | return validation_result 353 | 354 | except Exception as e: 355 | logger.error(f"Error in retrieve_multi_page_metrics: {e}") 356 | await track_progress(job_id, 0, len(pages), "failed", "error") 357 | return "" 358 | 359 | async def process_page(file_stream: BytesIO, page: int) -> str: 360 | try: 361 | content = extract_page_as_markdown(file_stream, page) 362 | return content.decode('utf-8') if isinstance(content, bytes) else content 363 | except 
Exception as e: 364 | logger.error(f"Error extracting page {page} as markdown: {e}") 365 | return "" 366 | 367 | async def call_llm_with_file_content(formatted_content: str, keywords: str, examples: str, client) -> str: 368 | try: 369 | text = f"""SOURCE DOCUMENT: 370 | {formatted_content} 371 | 372 | EXAMPLE FORMAT: 373 | {examples} 374 | 375 | METRICS TO EXTRACT: 376 | {keywords} 377 | 378 | Note: The source document may contain multiple pages separated by '=== PAGE BREAK ==='. 379 | Extract all relevant metrics from each page section while maintaining accuracy.""" 380 | return process_extraction(text, client, AgentMode.EXTRACTION) 381 | except Exception as e: 382 | logger.error(f"Error calling LLM with file content: {e}") 383 | return "" 384 | 385 | async def validate_metrics( 386 | llm_results: str, 387 | examples: str, 388 | client 389 | ) -> str: 390 | try: 391 | batched_results = llm_results.split("=== BATCH BREAK ===") 392 | MAX_VALIDATION_TOKENS = await calculate_optimal_batch_size(3000) 393 | 394 | validated_chunks = [] 395 | current_chunk = [] 396 | current_token_count = 0 397 | 398 | for batch in batched_results: 399 | if not batch.strip(): 400 | continue 401 | 402 | batch_token_count = estimate_tokens(batch.strip()) 403 | 404 | if current_token_count + batch_token_count > MAX_VALIDATION_TOKENS: 405 | if current_chunk: 406 | try: 407 | chunk_text = "\n".join(current_chunk) 408 | prompt = langsmith_client.pull_prompt(PromptType.VALIDATION.value) 409 | messages = prompt.invoke({ 410 | "first_value": chunk_text, 411 | "second_value": examples 412 | }) 413 | processed_messages = preprocess_messages(messages) 414 | if processed_messages: 415 | chunk_validation = client.do_completion(processed_messages) 416 | validated_chunks.append(chunk_validation) 417 | 418 | del chunk_text 419 | del messages 420 | gc.collect() 421 | except Exception as e: 422 | logger.error(f"Error validating chunk: {e}") 423 | 424 | current_chunk = [batch.strip()] 425 | current_token_count = batch_token_count 426 | else: 427 | current_chunk.append(batch.strip()) 428 | current_token_count += batch_token_count 429 | 430 | if current_chunk: 431 | try: 432 | chunk_text = "\n".join(current_chunk) 433 | prompt = langsmith_client.pull_prompt(PromptType.VALIDATION.value) 434 | messages = prompt.invoke({ 435 | "first_value": chunk_text, 436 | "second_value": examples 437 | }) 438 | processed_messages = preprocess_messages(messages) 439 | if processed_messages: 440 | chunk_validation = client.do_completion(processed_messages) 441 | validated_chunks.append(chunk_validation) 442 | 443 | del chunk_text 444 | del messages 445 | gc.collect() 446 | except Exception as e: 447 | logger.error(f"Error validating final chunk: {e}") 448 | 449 | if len(validated_chunks) > 1: 450 | try: 451 | final_consolidation = "\n".join(validated_chunks) 452 | if estimate_tokens(final_consolidation) > MAX_VALIDATION_TOKENS: 453 | logger.warning("Final consolidation exceeds token limit, will process in chunks") 454 | return await validate_metrics(final_consolidation, examples, client) 455 | 456 | prompt = langsmith_client.pull_prompt(PromptType.VALIDATION.value) 457 | messages = prompt.invoke({ 458 | "first_value": final_consolidation, 459 | "second_value": examples 460 | }) 461 | processed_messages = preprocess_messages(messages) 462 | if processed_messages: 463 | final_result = client.do_completion(processed_messages) 464 | return final_result 465 | except Exception as e: 466 | logger.error(f"Error in final consolidation: {e}") 467 | return 
validated_chunks[0] if validated_chunks else "" 468 | elif validated_chunks: 469 | return validated_chunks[0] 470 | 471 | return "" 472 | 473 | except Exception as e: 474 | logger.error(f"Error validating metrics: {e}") 475 | return "" 476 | 477 | async def run_web_extraction(url: str, schemas: List[Dict[str, str]], job_id: str = None) -> List[str]: 478 | """Handle extraction from web content using processing_handler.""" 479 | if not job_id: 480 | job_id = f"job_{int(time.time())}" 481 | 482 | logger.info(f"Starting web extraction for url: {url}, job_id: {job_id}") 483 | await track_progress(job_id, 0, len(schemas), "initialization") 484 | 485 | try: 486 | redis = await get_redis_connection() 487 | results = await process_web_content(redis, url, schemas) 488 | 489 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success") 490 | return results 491 | 492 | except Exception as e: 493 | logger.error(f"Error in web extraction: {e}") 494 | await track_progress(job_id, 0, len(schemas), "failed", "error") 495 | return [] -------------------------------------------------------------------------------- /application/extraction/service/extraction_worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import json 4 | import logging 5 | from typing import Optional, Dict, Any 6 | from redis.asyncio import Redis 7 | from redis.exceptions import RedisError 8 | from datetime import datetime 9 | from application.extraction.models.models import ( ExtractionRequestModel, 10 | ExtractionResponseModel, 11 | SchemaResult, 12 | JobStatus 13 | ) 14 | from application.extraction.service.extraction_handler import run_extraction, run_web_extraction 15 | from common.redis.redis_config import get_redis_connection 16 | 17 | logging.basicConfig(level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | async def run_extractions() -> None: 21 | logger.info("Starting extraction worker") 22 | redis = await get_redis_connection() 23 | last_id = "0-0" 24 | 25 | while True: 26 | try: 27 | result = await redis.xread( 28 | streams={"extraction-stream": last_id}, 29 | count=1, 30 | block=0 31 | ) 32 | logger.info(f"Result: {result}") 33 | 34 | for stream_name, messages in result: 35 | for message_id, message in messages: 36 | logger.info(f"Received message from stream {stream_name}: ID {message_id}") 37 | payload = message.get(b"payload") 38 | if payload: 39 | try: 40 | logger.info(f"Payload value: {payload}") 41 | extraction_request = ExtractionRequestModel(**json.loads(payload.decode('utf-8'))) 42 | extraction_result = await process_extraction(extraction_request) 43 | serialized_result = json.dumps(extraction_result.model_dump()) 44 | logger.info(f"Pushing result to transformation-stream: {serialized_result}") 45 | await redis.xadd("transformation-stream", {"payload": serialized_result}) 46 | logger.info("Successfully pushed result to transformation-stream") 47 | except json.JSONDecodeError as e: 48 | logger.error(f"Failed to parse payload JSON: {e}") 49 | except Exception as e: 50 | logger.error(f"Error processing extraction task: {e}") 51 | else: 52 | logger.error("Message does not contain 'payload' field") 53 | last_id = message_id 54 | 55 | except RedisError as e: 56 | logger.error(f"Error reading from Redis stream: {e}") 57 | await asyncio.sleep(1) 58 | 59 | async def process_extraction(extraction_request: ExtractionRequestModel) -> ExtractionResponseModel: 60 | try: 61 | if extraction_request.source_type == "web": 
62 | results = await run_web_extraction(extraction_request.pdf_key, extraction_request.schemas) 63 | else: 64 | results = await run_extraction(extraction_request.pdf_key, extraction_request.schemas) 65 | schema_results = [ 66 | SchemaResult( 67 | schema_id=f"schema_{index}", 68 | metrics={f"schema_{index}": result}, 69 | schema_data=schema 70 | ) 71 | for index, (schema, result) in enumerate(zip(extraction_request.schemas, results)) 72 | ] 73 | 74 | response = ExtractionResponseModel( 75 | task_id=extraction_request.task_id, 76 | pdf_key=extraction_request.pdf_key, 77 | results=schema_results, 78 | source_type=extraction_request.source_type 79 | ) 80 | 81 | redis = await get_redis_connection() 82 | await update_job_status(redis, extraction_request.task_id, JobStatus.PENDING, None) 83 | 84 | return response 85 | 86 | except Exception as e: 87 | logger.error(f"Error processing extraction task: {e}") 88 | redis = await get_redis_connection() 89 | # Log the type of the key causing the error 90 | key_type = await redis.type(extraction_request.pdf_key) 91 | logger.error(f"Key type for {extraction_request.pdf_key}: {key_type}") 92 | 93 | # Handle unexpected key types 94 | if key_type != b'string': 95 | logger.error(f"Unexpected key type for {extraction_request.pdf_key}. Deleting the key.") 96 | await redis.delete(extraction_request.pdf_key) 97 | 98 | await update_job_status(redis, extraction_request.task_id, JobStatus.FAILED, str(e)) 99 | raise e 100 | 101 | async def update_job_status( 102 | redis: Redis, 103 | task_id: str, 104 | status: JobStatus, 105 | error_message: Optional[str] 106 | ) -> None: 107 | fields: Dict[str, Any] = {"status": json.dumps(status.value)} 108 | if error_message: 109 | fields["error_message"] = error_message 110 | 111 | start_time_str = await redis.get(f"job-start-time:{task_id}") 112 | start_time = int(start_time_str) if start_time_str else 0 113 | current_time = int(datetime.utcnow().timestamp()) 114 | total_run_time = current_time - start_time 115 | run_time_str = f"{total_run_time // 60} minutes" if total_run_time >= 60 else f"{total_run_time} seconds" 116 | fields["total_run_time"] = run_time_str 117 | 118 | await redis.xadd(f"job-status:{task_id}", fields) 119 | 120 | async def clear_extraction_stream() -> None: 121 | retries = 0 122 | max_retries = 5 123 | delay = 1 124 | 125 | while retries < max_retries: 126 | try: 127 | redis = await get_redis_connection() 128 | await redis.xtrim("extraction-stream", approximate=True, maxlen=0) 129 | logger.info("Cleared extraction-stream") 130 | return 131 | except RedisError as e: 132 | if "LOADING" in str(e): 133 | logger.info(f"Redis is still loading. 
Retrying in {delay} seconds...") 134 | await asyncio.sleep(delay) 135 | retries += 1 136 | delay *= 2 137 | else: 138 | raise e 139 | 140 | raise Exception("Failed to clear extraction-stream after maximum retries") -------------------------------------------------------------------------------- /application/extraction/service/processing_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict, Optional 3 | from redis.asyncio import Redis 4 | from bs4 import BeautifulSoup 5 | import re 6 | from common.models.model_factory import ModelFactory 7 | from application.extraction.models.models import ModelDetails 8 | from langsmith import Client as LangSmithClient 9 | from common.prompts.prompt_enums import PromptType 10 | from langchain_core.messages import SystemMessage, HumanMessage 11 | import json 12 | from common.agents.prs_agent import process_extraction 13 | from common.agents.agent_prompt_enums import AgentMode 14 | 15 | logger = logging.getLogger(__name__) 16 | langsmith_client = LangSmithClient() 17 | 18 | def web_preprocessing(html_content: str) -> str: 19 | # Parse the HTML 20 | soup = BeautifulSoup(html_content, 'html.parser') 21 | 22 | # Remove script and style elements 23 | for script_or_style in soup(["script", "style", "meta", "link"]): 24 | script_or_style.decompose() 25 | 26 | # Remove navigation, footer, ads, and sidebars 27 | for element in soup(["nav", "footer", "aside"]): 28 | element.decompose() 29 | 30 | # Remove elements with common ad-related class names 31 | ad_classes = ["ad", "advertisement", "banner", "sidebar"] 32 | for element in soup.find_all(class_=lambda x: x and any(cls in x for cls in ad_classes)): 33 | element.decompose() 34 | 35 | # Try to find the main content 36 | main_content = soup.find("main") or soup.find("article") 37 | if not main_content: 38 | main_content = max( 39 | soup.find_all("div", text=True), 40 | key=lambda div: len(div.get_text()), 41 | default=soup 42 | ) 43 | 44 | # Process links in the main content 45 | for a in main_content.find_all('a', href=True): 46 | href = a['href'] 47 | if not a.string: 48 | a.string = href 49 | else: 50 | a.string = f"{a.string} ({href})" 51 | 52 | # Extract text from the main content 53 | text = main_content.get_text(separator=' ', strip=True) 54 | 55 | # Clean the text 56 | text = re.sub(r'\s+', ' ', text) # Normalize whitespace 57 | text = re.sub(r'&[a-zA-Z]+;', '', text) # Remove HTML entities 58 | logger.info(f"Length of text: {len(text)}") 59 | return text 60 | 61 | async def get_latest_model_details(redis: Redis) -> Optional[ModelDetails]: 62 | try: 63 | model_details_json = await redis.get("model-details") 64 | if not model_details_json: 65 | logger.error("No data found for model-details") 66 | return None 67 | 68 | model_details = ModelDetails(**json.loads(model_details_json)) 69 | return model_details 70 | 71 | except Exception as e: 72 | logger.error(f"Failed to get or parse model details: {e}") 73 | return None 74 | 75 | async def process_web_content(redis: Redis, pdf_key: str, schemas: List[Dict[str, str]]) -> List[str]: 76 | logger.info(f"Starting web extraction process for pdf_key: {pdf_key}") 77 | 78 | # Retrieve the HTML content from Redis 79 | html_content = await redis.get(pdf_key) 80 | 81 | if not html_content: 82 | logger.error(f"No HTML content found for key: {pdf_key}") 83 | return [] 84 | 85 | # Preprocess the HTML content 86 | preprocessed_text = web_preprocessing(html_content.decode('utf-8')) 87 | 
88 | # Get the latest model details 89 | model_details = await get_latest_model_details(redis) 90 | if not model_details: 91 | return [] 92 | 93 | # Create model instance 94 | try: 95 | model_instance = ModelFactory.create_model( 96 | model_type=model_details.provider_type, 97 | model_name=model_details.provider_model_name, 98 | api_key=model_details.api_key, 99 | additional_params=model_details.additional_params 100 | ) 101 | except ValueError as e: 102 | logger.error(f"Model creation error: {e}") 103 | return [] 104 | 105 | # Process each schema using PRS agent 106 | results = [] 107 | for schema in schemas: 108 | try: 109 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schema.items()]) 110 | # Get example format from the schema 111 | example_format = await get_example_format(model_instance, formatted_keywords) 112 | # Use PRS agent for extraction 113 | text = f"""SOURCE DOCUMENT: 114 | {preprocessed_text} 115 | 116 | EXAMPLE FORMAT: 117 | {example_format} 118 | 119 | METRICS TO EXTRACT: 120 | {formatted_keywords}""" 121 | 122 | result = process_extraction(text, model_instance, AgentMode.EXTRACTION) 123 | results.append(result) 124 | except Exception as e: 125 | logger.error(f"Error processing schema: {e}") 126 | 127 | return results 128 | 129 | async def get_example_format(client, formatted_keywords: str) -> str: 130 | try: 131 | prompt = langsmith_client.pull_prompt(PromptType.EXAMPLE_GENERATION.value) 132 | messages = prompt.invoke({"first_value": formatted_keywords}) 133 | processed_messages = preprocess_messages(messages) 134 | if processed_messages: 135 | return client.do_completion(processed_messages) 136 | except Exception as e: 137 | logger.error(f"Error generating example format: {e}") 138 | return "" 139 | 140 | def preprocess_messages(raw_payload) -> List[Dict[str, str]]: 141 | messages = [] 142 | if hasattr(raw_payload, 'to_messages'): 143 | for message in raw_payload.to_messages(): 144 | if isinstance(message, SystemMessage): 145 | messages.append({"role": "system", "content": message.content}) 146 | elif isinstance(message, HumanMessage): 147 | messages.append({"role": "user", "content": message.content}) 148 | else: 149 | logger.warning(f"Unexpected message type: {type(message)}") 150 | else: 151 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}") 152 | return messages -------------------------------------------------------------------------------- /application/extraction/start_extraction.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from dotenv import load_dotenv 4 | import asyncio 5 | from application.extraction.service.extraction_worker import clear_extraction_stream, run_extractions 6 | 7 | load_dotenv() 8 | 9 | logging.basicConfig(level=logging.INFO, 10 | format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | async def main(): 13 | try: 14 | await clear_extraction_stream() 15 | except Exception as e: 16 | logging.error("Failed to clear extraction-stream: %s", e) 17 | 18 | try: 19 | logging.info("Started extraction service...") 20 | await run_extractions() 21 | except Exception as e: 22 | logging.error("Application error: %s", e) 23 | exit(1) 24 | 25 | if __name__ == "__main__": 26 | asyncio.run(main()) 27 | -------------------------------------------------------------------------------- /application/pipeline/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | # 
Copy the common directory 6 | COPY ./common /app/common 7 | 8 | # Copy the pipeline application directory 9 | COPY ./application/pipeline /app/application/pipeline 10 | 11 | # Copy the requirements file 12 | COPY ./application/pipeline/requirements.txt /app/requirements.txt 13 | 14 | # Copy example files 15 | COPY ./examples/example_files /app/example_files 16 | 17 | RUN apt-get update && apt-get install -y \ 18 | build-essential \ 19 | make \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | # Install dependencies 23 | RUN pip install --no-cache-dir -r /app/requirements.txt 24 | 25 | # Set the Python path 26 | ENV PYTHONPATH /app 27 | 28 | # Expose the port uvicorn listens on 29 | EXPOSE 8100 30 | 31 | # Run the application 32 | CMD ["uvicorn", "application.pipeline.start_pipeline:app", "--host", "0.0.0.0", "--port", "8100"] -------------------------------------------------------------------------------- /application/pipeline/models/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | from typing import List, Dict, Any, Optional 3 | from enum import Enum 4 | 5 | class JobStatus(str, Enum): 6 | PENDING = "PENDING" 7 | IN_PROGRESS = "IN_PROGRESS" 8 | COMPLETED = "COMPLETED" 9 | FAILED = "FAILED" 10 | 11 | class WorkloadItem(BaseModel): 12 | raw_data: Optional[str] = Field(default=None) 13 | schemas: List[str] 14 | data_source: Optional[str] = Field(default=None) 15 | documents_location: Optional[str] = Field(default=None) 16 | file_name: Optional[str] = Field(default=None) 17 | additional_params: Dict[str, Any] = Field(default_factory=dict) 18 | destination: Optional[str] = Field(default=None) 19 | 20 | class PipelineRequestModel(BaseModel): 21 | workloads: List[WorkloadItem] 22 | provider_type: str 23 | provider_model_name: str 24 | api_key: str 25 | markdown_mode: bool = False 26 | additional_params: Dict[str, Any] = Field(default_factory=dict) 27 | 28 | class PipelineResponseModel(BaseModel): 29 | message: str 30 | task_id: str 31 | 32 | class PipelineResult(BaseModel): 33 | task_id: str 34 | status: JobStatus 35 | results: List[Dict] 36 | total_run_time: str 37 | 38 | class ExtractionRequestModel(BaseModel): 39 | task_id: str 40 | pdf_key: str 41 | schemas: List[Dict] 42 | source_type: str = "pdf" 43 | -------------------------------------------------------------------------------- /application/pipeline/requirements.txt: -------------------------------------------------------------------------------- 1 | uvicorn 2 | fastapi 3 | redis[hiredis] 4 | pydantic 5 | PyPDF2 6 | groq 7 | openai 8 | langchain 9 | langchainhub 10 | python-dotenv 11 | langsmith 12 | pdfminer.six 13 | boto3 14 | pymupdf 15 | cerebras_cloud_sdk 16 | aiohttp 17 | requests==2.32.3 18 | markitdown 19 | -------------------------------------------------------------------------------- /application/pipeline/routes/pipeline_routes.py: -------------------------------------------------------------------------------- 1 | from fastapi import APIRouter, HTTPException, status 2 | from application.pipeline.models.models import PipelineRequestModel, PipelineResponseModel, PipelineResult 3 | from application.pipeline.service.pipeline_service import run_pipeline, get_pipeline_results 4 | from fastapi.responses import JSONResponse 5 | import logging 6 | 7 | api_router = APIRouter() 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 10 | logger = logging.getLogger(__name__) 11 | 12 | @api_router.post("/pipelines", 13 | response_model=PipelineResponseModel, 14 |
status_code=status.HTTP_202_ACCEPTED, 15 | summary="Run pipeline", 16 | description="Initiates a pipeline processing job for the given PDF and schemas.", 17 | responses={ 18 | 202: {"description": "Pipeline processing started"}, 19 | 500: {"description": "Internal server error"} 20 | }) 21 | async def run_pipeline_route(pipeline_request: PipelineRequestModel): 22 | try: 23 | response = await run_pipeline(pipeline_request) 24 | if isinstance(response, dict) and "error" in response: 25 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=response["error"]) 26 | return JSONResponse(content={"task_id": response["task_id"], "message": "Pipeline processing started"}, status_code=status.HTTP_202_ACCEPTED) 27 | except Exception as e: 28 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)) 29 | 30 | @api_router.get("/pipelines/{task_id}", 31 | response_model=PipelineResult, 32 | summary="Get pipeline results", 33 | description="Retrieves the results of a pipeline processing job.", 34 | responses={ 35 | 200: {"description": "Pipeline results retrieved successfully"}, 36 | 404: {"description": "Task not found"}, 37 | 500: {"description": "Internal server error"} 38 | }) 39 | async def get_pipeline_results_route(task_id: str): 40 | try: 41 | results = await get_pipeline_results(task_id) 42 | if isinstance(results, tuple) and results[1] == 500: 43 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=results[0]['error']) 44 | if not results: 45 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found") 46 | return JSONResponse(content=results, status_code=status.HTTP_200_OK) 47 | except HTTPException as he: 48 | raise he 49 | except Exception as e: 50 | logger.error(f"Unexpected error in get_pipeline_results_route: {str(e)}") 51 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred") 52 | -------------------------------------------------------------------------------- /application/pipeline/service/pipeline_service.py: -------------------------------------------------------------------------------- 1 | import json 2 | import uuid 3 | import base64 4 | import zlib 5 | import hashlib 6 | from io import BytesIO 7 | from typing import List, Dict, Optional 8 | import redis.asyncio as redis 9 | import logging 10 | import time 11 | import asyncio 12 | from urllib.parse import urlparse 13 | import aiohttp 14 | 15 | from application.pipeline.models.models import ( 16 | PipelineRequestModel, 17 | PipelineResponseModel, 18 | JobStatus, 19 | PipelineResult, 20 | ExtractionRequestModel, 21 | WorkloadItem 22 | ) 23 | from common.text_extraction.text_extractor import get_pdf_page_count 24 | from common.models.model_factory import ModelFactory 25 | from common.sources.source_factory import SourceFactory 26 | from langsmith import Client as LangSmithClient 27 | from langchain.schema import SystemMessage, HumanMessage 28 | 29 | logging.basicConfig(level=logging.INFO) 30 | logger = logging.getLogger(__name__) 31 | 32 | langsmith_client = LangSmithClient() 33 | 34 | async def run_pipeline(customer_input: PipelineRequestModel): 35 | logger.info("Starting Pipeline Run...") 36 | 37 | try: 38 | con = await redis.from_url("redis://redis:6379/0", encoding="utf-8", decode_responses=True) 39 | except Exception as e: 40 | logger.error(f"Redis connection error: {e}") 41 | return { 42 | "error": "Service temporarily unavailable", 43 | "details": "Unable to connect 
to the database. Please try again later." 44 | } 45 | 46 | task_id = str(uuid.uuid4()) 47 | start_time = int(time.time()) 48 | await con.set(f"job-start-time:{task_id}", start_time) 49 | await con.xadd( 50 | f"job-status:{task_id}", 51 | {"status": json.dumps(JobStatus.PENDING.value), "start_time": str(start_time)} 52 | ) 53 | 54 | try: 55 | ModelFactory.create_model( 56 | model_type=customer_input.provider_type, 57 | model_name=customer_input.provider_model_name, 58 | api_key=customer_input.api_key, 59 | additional_params=customer_input.additional_params 60 | ) 61 | except ValueError as e: 62 | logger.error(f"Model creation error: {e}") 63 | return { 64 | "error": "Invalid model configuration", 65 | "details": str(e) 66 | } 67 | 68 | # Publish details to be used by other workers 69 | model_details_json = json.dumps({ 70 | "provider_type": customer_input.provider_type, 71 | "provider_model_name": customer_input.provider_model_name, 72 | "api_key": customer_input.api_key, 73 | "markdown_mode": customer_input.markdown_mode, 74 | "additional_params": customer_input.additional_params 75 | }) 76 | await con.set("model-details", model_details_json) 77 | 78 | if is_transformation_only_job(customer_input.workloads): 79 | logger.info("Transformation-only job detected") 80 | return await handle_transformation(customer_input, con, task_id) 81 | else: 82 | logger.info("Full pipeline job detected") 83 | return await handle_full_pipeline(customer_input, con, task_id) 84 | 85 | async def handle_transformation(customer_input: PipelineRequestModel, con: redis.Redis, task_id: str): 86 | workload = customer_input.workloads[0] 87 | transformation_payload = { 88 | "task_id": task_id, 89 | "data_location_key": workload.documents_location, 90 | "schemas": workload.schemas, 91 | "destination": workload.destination, 92 | "raw_data": workload.raw_data 93 | } 94 | await con.xadd("transformation-only-stream", {"payload": json.dumps(transformation_payload)}) 95 | 96 | response = PipelineResponseModel( 97 | message="Transformation task submitted successfully", 98 | task_id=task_id 99 | ) 100 | return {"task_id": response.task_id, "message": response.message} 101 | 102 | async def handle_full_pipeline(customer_input: PipelineRequestModel, con: redis.Redis, task_id: str): 103 | await con.set(f"workload-count:{task_id}", len(customer_input.workloads)) 104 | 105 | pdf_hash = hashlib.sha256(json.dumps([w.dict() for w in customer_input.workloads]).encode()).hexdigest() 106 | cache_key = f"cache:{pdf_hash}" 107 | cached_hash = await con.get(cache_key) 108 | if cached_hash: 109 | logger.info(f"Cache hit for key: {cache_key}") 110 | cached_response = await con.get(cached_hash) 111 | if cached_response: 112 | return json.loads(cached_response) 113 | 114 | async def process_workload(index: int, workload_combo: WorkloadItem) -> int: 115 | try: 116 | if workload_combo.raw_data and workload_combo.data_source: 117 | logger.error(f"Workload {index} cannot have both raw_data and data_source.") 118 | raise ValueError("Workload cannot have both raw_data and data_source.") 119 | if workload_combo.raw_data: 120 | return await handle_raw_data(index, workload_combo, con, task_id) 121 | elif workload_combo.data_source == "web": 122 | return await handle_web_source(index, workload_combo, con, task_id) 123 | elif workload_combo.data_source: 124 | return await handle_data_source(index, workload_combo, con, task_id) 125 | else: 126 | logger.error(f"Workload {index} must have either raw_data or data_source.") 127 | return 0 128 | except Exception 
as e: 129 | logger.error(f"Error processing workload {index}: {e}") 130 | return 0 131 | 132 | workload_results = await asyncio.gather( 133 | *[process_workload(index, workload_combo) for index, workload_combo in enumerate(customer_input.workloads)], 134 | return_exceptions=False 135 | ) 136 | 137 | total_pages = sum(workload_results) 138 | logger.info(f"Total pages processed: {total_pages}") 139 | 140 | await con.xadd( 141 | f"job-status:{task_id}", 142 | {"status": json.dumps(JobStatus.IN_PROGRESS.value)} 143 | ) 144 | 145 | response = PipelineResponseModel( 146 | message="Tasks submitted successfully", 147 | task_id=task_id 148 | ) 149 | 150 | response_hash = hashlib.sha256(json.dumps(response.dict()).encode()).hexdigest() 151 | await con.setex(cache_key, 3600, response_hash) 152 | await con.setex(response_hash, 3600, json.dumps(response.dict())) 153 | 154 | return {"task_id": response.task_id, "message": response.message} 155 | 156 | def is_transformation_only_job(workloads: List[WorkloadItem]) -> bool: 157 | if len(workloads) != 1: 158 | return False 159 | workload = workloads[0] 160 | return ( 161 | workload.destination is not None and 162 | workload.documents_location is not None and 163 | workload.schemas is not None and 164 | workload.raw_data is not None and 165 | workload.data_source is None 166 | ) 167 | 168 | async def handle_raw_data(index: int, workload_combo: WorkloadItem, con: redis.Redis, task_id: str) -> int: 169 | logger.info(f"Processing workload {index} with raw_data.") 170 | # Decode the base64 encoded data stream 171 | try: 172 | decompressed_data = zlib.decompress(base64.b64decode(workload_combo.raw_data)) 173 | except zlib.error as e: 174 | logger.error(f"Decompression failed for workload {index}: {e}") 175 | return 0 176 | 177 | data_cursor = BytesIO(decompressed_data) 178 | logger.info(f"Decompressed data for workload {index}") 179 | 180 | try: 181 | page_count = get_pdf_page_count(data_cursor) 182 | logger.debug(f"Page count for workload {index}: {page_count}") 183 | except Exception as e: 184 | logger.error(f"Error getting page count for workload {index}: {e}") 185 | return 0 186 | 187 | data_key = f"data:{task_id}:{index}" 188 | await con.set(data_key, base64.b64encode(decompressed_data).decode()) 189 | logger.info(f"Data stored in Redis with key: {data_key} as base64 string") 190 | 191 | schemas = [json.loads(schema) for schema in workload_combo.schemas] 192 | 193 | task_payload = ExtractionRequestModel( 194 | task_id=task_id, 195 | pdf_key=data_key, 196 | schemas=schemas 197 | ) 198 | 199 | await con.xadd("extraction-stream", {"payload": json.dumps(task_payload.dict())}) 200 | 201 | return page_count 202 | 203 | async def fetch_document(session, workload_combo): 204 | try: 205 | async with session.get(workload_combo.documents_location) as response: 206 | response.raise_for_status() # Raises aiohttp.ClientResponseError for 4xx/5xx responses 207 | content = await response.read() 208 | return content 209 | except aiohttp.ClientError as e: 210 | logger.error(f"Error fetching document: {str(e)}") 211 | raise 212 | 213 | 214 | async def handle_web_source(index: int, workload_combo: WorkloadItem, con: redis.Redis, task_id: str) -> int: 215 | logger.info(f"Processing workload {index} with web source: {workload_combo.documents_location}") 216 | 217 | # Fetch the web content with aiohttp, validating the URL first. 218 | # The raw HTML is stored as-is; preprocessing happens in the extraction service. 219 | 220 | # Validate URL 221 | try: 222 |
result = urlparse(workload_combo.documents_location) 223 | if not all([result.scheme, result.netloc]): 224 | logger.error(f"Invalid URL: {workload_combo.documents_location}") 225 | return 0 226 | except ValueError: 227 | logger.error(f"Invalid URL format: {workload_combo.documents_location}") 228 | return 0 229 | logger.info("URL is valid") 230 | logger.info("Fetching document content") 231 | async with aiohttp.ClientSession() as session: 232 | try: 233 | document_content = await fetch_document(session, workload_combo) 234 | except aiohttp.ClientError as e: 235 | logger.error(f"Failed to connect to URL: {str(e)}") 236 | return 0 237 | 238 | # The raw HTML is stored as-is; the extraction service strips scripts, 239 | # ads, and boilerplate later (see web_preprocessing in processing_handler.py) 240 | # before running schema extraction. 241 | # Store the fetched content in Redis 242 | content_key = f"web:{task_id}:{index}" 243 | await con.set(content_key, document_content) 244 | 245 | # Create and add the extraction task 246 | schemas = [json.loads(schema) for schema in workload_combo.schemas] 247 | task_payload = ExtractionRequestModel( 248 | task_id=task_id, 249 | pdf_key=content_key, 250 | schemas=schemas, 251 | source_type="web" 252 | ) 253 | await con.xadd("extraction-stream", {"payload": json.dumps(task_payload.dict())}) 254 | 255 | return 1 256 | 257 | 258 | async def handle_data_source(index: int, workload_combo: WorkloadItem, con: redis.Redis, task_id: str) -> int: 259 | logger.info(f"Processing workload {index} with data_source: {workload_combo.data_source}") 260 | logger.info(f"Documents location: {workload_combo.documents_location}") 261 | source = SourceFactory.create_source( 262 | source_type=workload_combo.data_source, 263 | documents_location=workload_combo.documents_location, 264 | additional_params=workload_combo.additional_params 265 | ) 266 | logger.info(f"Created source: {source}") 267 | 268 | all_files = source.read_all() 269 | logger.info(f"List of all files: {all_files}") 270 | if not all_files: 271 | logger.warning(f"No files found in data source: {workload_combo.data_source}") 272 | return 0 273 | 274 | relevant_file = await get_relevant_file_via_llm(all_files, workload_combo.file_name) 275 | if not relevant_file: 276 | logger.warning(f"LLM did not return a valid file for workload {index}") 277 | return 0 278 | 279 | logger.info(f"Selected relevant file: {relevant_file}") 280 | 281 | file_stream = source.read({"file_key": relevant_file}) 282 | if not file_stream: 283 | logger.warning(f"Failed to read the selected file: {relevant_file}") 284 | return 0 285 | 286 | compressed_data = base64.b64encode(zlib.compress(file_stream.getvalue())).decode('utf-8') 287 | 288 | try: 289 | decompressed_pdf = zlib.decompress(base64.b64decode(compressed_data)) 290 | except zlib.error as e: 291 | logger.error(f"Decompression failed for workload {index}: {e}") 292 | return 0 293 | 294 | pdf_cursor = BytesIO(decompressed_pdf) 295 | logger.info(f"Decompressed PDF for workload {index}") 296 | 297 | try: 298 | page_count = get_pdf_page_count(pdf_cursor) 299 | logger.debug(f"Page count for workload {index}: {page_count}") 300 | except Exception as e: 301 | logger.error(f"Error getting page count for workload {index}: {e}") 302 | return 0 303 | 304 | pdf_key = f"pdf:{task_id}:{index}" 305 | await con.set(pdf_key, base64.b64encode(decompressed_pdf).decode()) 306 | logger.info(f"PDF stored in Redis with key: {pdf_key} as base64 string") 307 | 308 | schemas = [json.loads(schema) for schema in
workload_combo.schemas] 309 | 310 | task_payload = ExtractionRequestModel( 311 | task_id=task_id, 312 | pdf_key=pdf_key, 313 | schemas=schemas 314 | ) 315 | 316 | await con.xadd("extraction-stream", {"payload": json.dumps(task_payload.dict())}) 317 | 318 | return page_count 319 | 320 | def preprocess_messages(raw_payload): 321 | messages = [] 322 | if hasattr(raw_payload, 'to_messages'): 323 | for message in raw_payload.to_messages(): 324 | if isinstance(message, SystemMessage): 325 | messages.append({"role": "system", "content": message.content}) 326 | elif isinstance(message, HumanMessage): 327 | messages.append({"role": "user", "content": message.content}) 328 | else: 329 | logger.warning(f"Unexpected message type: {type(message)}") 330 | else: 331 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}") 332 | return messages 333 | 334 | async def get_relevant_file_via_llm(filenames: List[str], file_name: str) -> Optional[str]: 335 | if not filenames: 336 | logger.error("No filenames provided to determine relevance.") 337 | return None 338 | 339 | try: 340 | con = await redis.from_url("redis://redis:6379/0", encoding="utf-8", decode_responses=True) 341 | customer_input_str = await con.get("model-details") 342 | logger.info(f"Model details JSON: {customer_input_str}") 343 | if not customer_input_str: 344 | logger.error("Model details not found in Redis.") 345 | return None 346 | 347 | customer_input = json.loads(customer_input_str) 348 | 349 | model_instance = ModelFactory.create_model( 350 | model_type=customer_input["provider_type"], 351 | model_name=customer_input["provider_model_name"], 352 | api_key=customer_input["api_key"], 353 | additional_params=customer_input["additional_params"] 354 | ) 355 | 356 | prompt = langsmith_client.pull_prompt("marly/get-relevant-file") 357 | 358 | messages = prompt.invoke({ 359 | "first_value": file_name, 360 | "second_value": filenames 361 | }) 362 | 363 | processed_messages = preprocess_messages(messages) 364 | 365 | if not processed_messages: 366 | logger.error("No messages to process for determining relevant file.") 367 | return None 368 | 369 | relevant_file = model_instance.do_completion(processed_messages) 370 | 371 | if relevant_file in filenames: 372 | return relevant_file 373 | else: 374 | logger.error(f"LLM returned an invalid filename: {relevant_file}") 375 | return None 376 | 377 | except Exception: 378 | logger.exception("An error occurred while determining the relevant file via LLM.") 379 | return None 380 | 381 | async def get_results_from_stream(con: redis.Redis, task_id: str) -> List[Dict]: 382 | entries = await con.xrange(f"job-status:{task_id}") 383 | results = [] 384 | for _, entry in entries: 385 | if 'result' in entry: # keys are str, not bytes: connections here use decode_responses=True 386 | result = json.loads(entry['result']) 387 | if 'results' in result: 388 | results.extend(result['results']) 389 | return results 390 | 391 | async def get_pipeline_results(task_id: str): 392 | try: 393 | con = await redis.from_url("redis://redis:6379/0", encoding="utf-8", decode_responses=True) 394 | except Exception as e: 395 | logger.error(f"Failed to connect to Redis: {e}") 396 | return {"error": "Failed to connect to Redis"}, 500 397 | 398 | try: 399 | status_stream = await con.xrange(f"job-status:{task_id}") 400 | except Exception as e: 401 | logger.error(f"Failed to fetch job status: {e}") 402 | return {"error": "Failed to fetch job status"}, 500 403 | 404 | if not status_stream: 405 | return {"error": "Task not found"}, 404 406 | 407 | latest_status = JobStatus.PENDING 408 | 
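# The loop below folds the job-status stream into one response. A hedged
# sketch of what entries might look like (illustrative shapes, not captured
# output): early entries carry only {"status": "\"IN_PROGRESS\""}, while the
# final entry written by the transformation worker also carries "result"
# (a JSON blob of schema results) and "total_run_time" (e.g. "42 seconds").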
total_run_time = 'N/A' 409 | all_results = [] 410 | 411 | for _, entry in status_stream: 412 | try: 413 | entry_status = JobStatus(json.loads(entry['status'])) 414 | latest_status = entry_status 415 | except (json.JSONDecodeError, ValueError): 416 | logger.error(f"Invalid status value: {entry['status']}") 417 | 418 | if 'total_run_time' in entry: 419 | total_run_time = entry['total_run_time'] 420 | 421 | if 'result' in entry: 422 | try: 423 | result_data = json.loads(entry['result']) 424 | all_results.append(result_data) 425 | except json.JSONDecodeError: 426 | logger.error(f"Failed to parse result JSON: {entry['result']}") 427 | 428 | response = PipelineResult( 429 | task_id=task_id, 430 | status=latest_status, 431 | results=all_results, 432 | total_run_time=total_run_time 433 | ) 434 | 435 | return response.dict() 436 | 437 | 438 | 439 | 440 | 441 | -------------------------------------------------------------------------------- /application/pipeline/start_pipeline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | # Add the parent directory of 'application' to sys.path 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) 6 | 7 | from fastapi import FastAPI 8 | from fastapi.middleware.cors import CORSMiddleware 9 | from application.pipeline.routes.pipeline_routes import api_router as pipeline_api_router 10 | import uvicorn 11 | 12 | app = FastAPI( 13 | title="Marly API", 14 | description="The Data Processor for Agents", 15 | docs_url='/docs', 16 | openapi_url='/openapi.json' 17 | ) 18 | 19 | app.add_middleware( 20 | CORSMiddleware, 21 | allow_origins=["*"], 22 | allow_methods=["*"], 23 | allow_headers=["*"], 24 | allow_credentials=True, 25 | ) 26 | 27 | 28 | 29 | app.include_router(pipeline_api_router, prefix="", tags=["Pipeline Execution"]) 30 | 31 | if __name__ == "__main__": 32 | uvicorn.run(app, host="127.0.0.1", port=8100) 33 | -------------------------------------------------------------------------------- /application/transformation/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim 2 | 3 | WORKDIR /app 4 | 5 | # Copy the common directory 6 | COPY ./common /app/common 7 | 8 | # Copy the transformation application directory 9 | COPY ./application/transformation /app/application/transformation 10 | 11 | # Copy the requirements file 12 | COPY ./application/transformation/requirements.txt /app/requirements.txt 13 | 14 | # Install dependencies 15 | RUN pip install --no-cache-dir -r /app/requirements.txt 16 | 17 | # Set the Python path 18 | ENV PYTHONPATH /app 19 | 20 | # Run the transformation service 21 | CMD ["python", "application/transformation/start_transformation.py"] -------------------------------------------------------------------------------- /application/transformation/models/models.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel 2 | from typing import List, Dict, Optional 3 | from enum import Enum 4 | 5 | class TransformationRequestModel(BaseModel): 6 | task_id: str 7 | pdf_key: str 8 | results: List['SchemaResult'] 9 | source_type: str = "pdf" 10 | destination: Optional[str] = None 11 | 12 | class TransformationOnlyRequestModel(BaseModel): 13 | task_id: str 14 | data_location_key: str 15 | schemas: List[str] 16 | destination: str 17 | raw_data: str 18 | 19 | class TransformationResponseModel(BaseModel): 20 | task_id: str 21 | pdf_key: str 
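# 'SchemaResult' below is a forward reference to the model declared further
# down in this module; assuming a current pydantic (requirements.txt does not
# pin a version), the reference is resolved when the model is first used.
# Illustrative instance (made-up values):
#   TransformationResponseModel(task_id="t1", pdf_key="pdf:t1:0",
#       results=[SchemaResult(schema_id="0", metrics={"Revenue": "$15 million"},
#                             schema_data={"Revenue": "Total revenue"})])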
22 | results: List['SchemaResult'] 23 | 24 | class SchemaResult(BaseModel): 25 | schema_id: str 26 | metrics: Dict[str, str] 27 | schema_data: Dict[str, str] 28 | 29 | class JobStatus(str, Enum): 30 | PENDING = "PENDING" 31 | IN_PROGRESS = "IN_PROGRESS" 32 | COMPLETED = "COMPLETED" 33 | FAILED = "FAILED" 34 | 35 | class ExtractionResponseModel(BaseModel): 36 | task_id: str 37 | pdf_key: str 38 | results: List[SchemaResult] 39 | 40 | class ModelDetails(BaseModel): 41 | provider_type: str 42 | provider_model_name: str 43 | api_key: str 44 | markdown_mode: bool 45 | additional_params: Dict[str, str] 46 | -------------------------------------------------------------------------------- /application/transformation/requirements.txt: -------------------------------------------------------------------------------- 1 | asyncio 2 | typing 3 | redis[hiredis] 4 | python-dotenv 5 | pydantic 6 | PyPDF2 7 | aiohttp 8 | openai 9 | groq 10 | langchain 11 | langchainhub 12 | langsmith 13 | pdfminer.six 14 | cerebras_cloud_sdk 15 | boto3 16 | requests==2.32.3 17 | markitdown -------------------------------------------------------------------------------- /application/transformation/service/transformation_handler.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import logging 3 | from typing import Dict, List 4 | from redis.asyncio import Redis 5 | from common.redis.redis_config import get_redis_connection 6 | from common.models.model_factory import ModelFactory 7 | from application.transformation.models.models import ModelDetails 8 | from langsmith import Client as LangSmithClient 9 | from common.prompts.prompt_enums import PromptType 10 | from langchain.schema import SystemMessage, HumanMessage 11 | from common.destinations.destination_factory import DestinationFactory 12 | from common.destinations.enums.destination_enums import DestinationType 13 | import json 14 | 15 | logging.basicConfig(level=logging.INFO) 16 | logger = logging.getLogger(__name__) 17 | 18 | langsmith_client = LangSmithClient() 19 | 20 | def preprocess_messages(raw_payload): 21 | messages = [] 22 | if hasattr(raw_payload, 'to_messages'): 23 | for message in raw_payload.to_messages(): 24 | if isinstance(message, SystemMessage): 25 | messages.append({"role": "system", "content": message.content}) 26 | elif isinstance(message, HumanMessage): 27 | messages.append({"role": "user", "content": message.content}) 28 | else: 29 | logger.warning(f"Unexpected message type: {type(message)}") 30 | else: 31 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}") 32 | return messages 33 | 34 | async def get_latest_model_details(redis: Redis) -> ModelDetails: 35 | try: 36 | model_details_json = await redis.get("model-details") 37 | if not model_details_json: 38 | logger.error("No data found for model-details") 39 | return None 40 | 41 | model_details = ModelDetails(**json.loads(model_details_json)) 42 | return model_details 43 | 44 | except Exception as e: 45 | logger.error(f"Failed to get or parse model details: {e}") 46 | return None 47 | 48 | async def run_transformation(metrics: Dict[str, str], schema: Dict[str, str], source_type: str) -> Dict[str, str]: 49 | logger.info("Starting transformation process") 50 | 51 | redis: Redis = await get_redis_connection() 52 | model_details = await get_latest_model_details(redis) 53 | if not model_details: 54 | return {} 55 | 56 | try: 57 | model_instance = ModelFactory.create_model( 58 | model_type=model_details.provider_type, 59 | 
model_name=model_details.provider_model_name, 60 | api_key=model_details.api_key, 61 | additional_params=model_details.additional_params 62 | ) 63 | markdown_mode = model_details.markdown_mode 64 | logger.info(f"Model instance created with type: {model_details.provider_type}, name: {model_details.provider_model_name}") 65 | except ValueError as e: 66 | logger.error(f"Model creation error: {e}") 67 | return {} 68 | 69 | schema_keys = ",".join(schema.keys()) 70 | 71 | tasks = [ 72 | asyncio.create_task(process_schema(model_instance, schema_id, metric_value, schema_keys, markdown_mode, source_type)) 73 | for schema_id, metric_value in metrics.items() 74 | ] 75 | 76 | results = await asyncio.gather(*tasks, return_exceptions=True) 77 | transformed_metrics = {schema_id: result for schema_id, result in zip(metrics.keys(), results) if isinstance(result, str)} 78 | logger.info(f"Transformed metrics: {transformed_metrics}") 79 | 80 | return transformed_metrics 81 | 82 | async def process_schema(client, schema_id: str, metric_value: str, schema_keys: str, markdown_mode: bool, source_type: str) -> str: 83 | try: 84 | if source_type == 'web': 85 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION_WEB.value) 86 | elif markdown_mode: 87 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION_MARKDOWN.value) 88 | else: 89 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION.value) 90 | 91 | messages = prompt.invoke({ 92 | "first_value": metric_value, 93 | "second_value": schema_keys 94 | }) 95 | processed_messages = preprocess_messages(messages) 96 | if not processed_messages: 97 | logger.error("No messages to process for transformation") 98 | return "" 99 | if markdown_mode: 100 | transformed_metric = client.do_completion(processed_messages) 101 | else: 102 | transformed_metric = client.do_completion(processed_messages, response_format={"type": "json_object"}) 103 | return transformed_metric 104 | except Exception as e: 105 | logger.error(f"Error transforming metric for schema {schema_id}: {e}") 106 | return "" 107 | 108 | async def run_transformation_only(task_id: str, data_location_key: str, schemas: List[str], destination: str, raw_data: str) -> Dict[str, str]: 109 | # eventually we will want to do destination writes here 110 | logger.info(f"Starting transformation-only process for task_id: {task_id}") 111 | 112 | redis: Redis = await get_redis_connection() 113 | model_details = await get_latest_model_details(redis) 114 | if not model_details: 115 | return {} 116 | 117 | try: 118 | model_instance = ModelFactory.create_model( 119 | model_type=model_details.provider_type, 120 | model_name=model_details.provider_model_name, 121 | api_key=model_details.api_key, 122 | additional_params=model_details.additional_params 123 | ) 124 | markdown_mode = model_details.markdown_mode 125 | logger.info(f"Model instance created with type: {model_details.provider_type}, name: {model_details.provider_model_name}") 126 | except ValueError as e: 127 | logger.error(f"Model creation error: {e}") 128 | return {} 129 | 130 | destination_config = { 131 | "db_path": destination, 132 | "additional_params": {} 133 | } 134 | destination_instance = DestinationFactory.create_destination(DestinationType.SQLITE.value, destination_config) 135 | 136 | try: 137 | database_schema = destination_instance.get_table_structure(data_location_key) 138 | except AttributeError: 139 | logger.error("The destination does not support get_table_structure method") 140 | database_schema = {} 141 | 142 | 
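# The loop below issues one TRANSFORMATION_ONLY prompt per schema. A hedged
# sketch of the template variables (names taken from the invoke call below;
# the values here are made up):
#   prompt.invoke({
#       "first_value": '{"amount": "Invoice total"}',        # the schema
#       "second_value": "raw text pulled from the source",   # raw_data
#       "third_value": [{"name": "amount", "type": "TEXT"}], # table structure
#   })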
transformed_metrics = {} 143 | for schema in schemas: 144 | try: 145 | logger.info(f"schema: {schema}, raw_data: {raw_data}, database_schema: {database_schema}") 146 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION_ONLY.value) 147 | messages = prompt.invoke({ 148 | "first_value": schema, 149 | "second_value": raw_data, 150 | "third_value": database_schema 151 | }) 152 | processed_messages = preprocess_messages(messages) 153 | if not processed_messages: 154 | logger.error(f"No messages to process for transformation of schema: {schema}") 155 | continue 156 | if markdown_mode: 157 | result = model_instance.do_completion(processed_messages) 158 | else: 159 | logger.info(f"Transforming schema: {schema} with JSON response format") 160 | result = model_instance.do_completion(processed_messages, response_format={"type": "json_object"}) 161 | transformed_metrics[schema] = result 162 | except Exception as e: 163 | logger.error(f"Error transforming metric for schema {schema}: {e}") 164 | 165 | logger.info(f"Transformed metrics for task_id {task_id}: {transformed_metrics}") 166 | return transformed_metrics 167 | -------------------------------------------------------------------------------- /application/transformation/service/transformation_worker.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import json 3 | import logging 4 | from typing import Optional, Dict, Any, Union 5 | from redis.asyncio import Redis 6 | from redis.exceptions import RedisError 7 | from datetime import datetime 8 | from application.transformation.models.models import ( 9 | TransformationRequestModel, 10 | TransformationResponseModel, 11 | SchemaResult, 12 | JobStatus, 13 | TransformationOnlyRequestModel 14 | ) 15 | from application.transformation.service.transformation_handler import run_transformation, run_transformation_only 16 | from common.redis.redis_config import get_redis_connection 17 | 18 | logging.basicConfig(level=logging.INFO) 19 | logger = logging.getLogger(__name__) 20 | 21 | async def run_transformations() -> None: 22 | logger.info("Starting transformation worker") 23 | redis = await get_redis_connection() 24 | last_id_transformation = "0-0" 25 | last_id_transformation_only = "0-0" 26 | 27 | task_workloads = {} 28 | 29 | while True: 30 | try: 31 | result = await redis.xread( 32 | streams={ 33 | "transformation-stream": last_id_transformation, 34 | "transformation-only-stream": last_id_transformation_only 35 | }, 36 | count=1, 37 | block=0 38 | ) 39 | 40 | for stream_name, messages in result: 41 | for message_id, message in messages: 42 | logger.info(f"Received message from stream {stream_name}: ID {message_id}") 43 | payload = message.get(b"payload") 44 | if payload: 45 | try: 46 | payload_dict = json.loads(payload.decode('utf-8')) 47 | 48 | if stream_name == b"transformation-stream": 49 | request = TransformationRequestModel(**payload_dict) 50 | else: 51 | request = TransformationOnlyRequestModel(**payload_dict) 52 | 53 | if request.task_id not in task_workloads: 54 | total_workloads = await get_total_workloads(redis, request.task_id) 55 | task_workloads[request.task_id] = { 56 | 'total': total_workloads, 57 | 'completed': 0 58 | } 59 | 60 | existing_results = await get_existing_results(redis, request.task_id) 61 | transformation_result = await process_transformation(request) 62 | merged_results = merge_results(existing_results, transformation_result) 63 | 64 | serialized_result = json.dumps(merged_results.dict()) 65 | await 
redis.xadd(f"results-stream:{request.task_id}", {"payload": serialized_result}) 66 | 67 | task_workloads[request.task_id]['completed'] += 1 68 | 69 | if task_workloads[request.task_id]['completed'] == task_workloads[request.task_id]['total']: 70 | logger.info(f"All workloads completed for task {request.task_id}") 71 | await update_job_status(redis, request.task_id, JobStatus.COMPLETED, serialized_result) 72 | del task_workloads[request.task_id] # Cleanup 73 | else: 74 | await update_job_status(redis, request.task_id, JobStatus.IN_PROGRESS, None) 75 | 76 | except Exception as e: 77 | logger.error(f"Error processing transformation task: {e}") 78 | await update_job_status(redis, request.task_id, JobStatus.FAILED, str(e)) 79 | if request.task_id in task_workloads: 80 | del task_workloads[request.task_id] 81 | 82 | if stream_name == b"transformation-stream": 83 | last_id_transformation = message_id 84 | elif stream_name == b"transformation-only-stream": 85 | last_id_transformation_only = message_id 86 | 87 | except RedisError as e: 88 | logger.error(f"Error reading from Redis stream: {e}") 89 | await asyncio.sleep(1) 90 | 91 | async def process_transformation( 92 | transformation_request: Union[TransformationRequestModel, TransformationOnlyRequestModel] 93 | ) -> TransformationResponseModel: 94 | redis = await get_redis_connection() 95 | try: 96 | transformed_results = [] 97 | 98 | if isinstance(transformation_request, TransformationOnlyRequestModel): 99 | transformed_metrics = await run_transformation_only( 100 | task_id=transformation_request.task_id, 101 | data_location_key=transformation_request.data_location_key, 102 | schemas=transformation_request.schemas, 103 | destination=transformation_request.destination, 104 | raw_data=transformation_request.raw_data 105 | ) 106 | if isinstance(transformed_metrics, str): 107 | transformed_metrics = json.loads(transformed_metrics) 108 | 109 | transformed_results.append(SchemaResult( 110 | schema_id=f"{transformation_request.task_id}-transformation_only", 111 | schema_data={}, 112 | metrics=transformed_metrics 113 | )) 114 | else: 115 | for schema_result in transformation_request.results: 116 | transformed_metrics = await run_transformation( 117 | metrics=schema_result.metrics, 118 | schema=schema_result.schema_data, 119 | source_type=transformation_request.source_type 120 | ) 121 | if isinstance(transformed_metrics, str): 122 | transformed_metrics = json.loads(transformed_metrics) 123 | 124 | transformed_results.append(SchemaResult( 125 | schema_id=schema_result.schema_id, 126 | schema_data=schema_result.schema_data, 127 | metrics=transformed_metrics 128 | )) 129 | 130 | response = TransformationResponseModel( 131 | task_id=transformation_request.task_id, 132 | pdf_key=transformation_request.data_location_key if isinstance(transformation_request, TransformationOnlyRequestModel) else transformation_request.pdf_key, 133 | results=transformed_results 134 | ) 135 | 136 | return response 137 | except Exception as e: 138 | logger.error(f"Error processing transformation task {transformation_request.task_id}: {e}") 139 | await update_job_status(redis, transformation_request.task_id, JobStatus.FAILED, str(e)) 140 | raise 141 | 142 | async def update_job_status( 143 | redis: Redis, 144 | task_id: str, 145 | status: JobStatus, 146 | result: Optional[str] 147 | ) -> None: 148 | fields: Dict[str, Any] = {"status": json.dumps(status.value)} 149 | if result: 150 | fields["result"] = result 151 | 152 | start_time_str = await redis.get(f"job-start-time:{task_id}") 
153 | start_time = int(start_time_str) if start_time_str else 0 154 | current_time = int(datetime.utcnow().timestamp()) 155 | total_run_time = current_time - start_time 156 | run_time_str = f"{total_run_time // 60} minutes" if total_run_time >= 60 else f"{total_run_time} seconds" 157 | fields["total_run_time"] = run_time_str 158 | 159 | await redis.xadd(f"job-status:{task_id}", fields) 160 | 161 | async def get_existing_results(redis: Redis, task_id: str) -> Optional[TransformationResponseModel]: 162 | try: 163 | entries = await redis.xrevrange(f"results-stream:{task_id}", count=1) 164 | if entries: 165 | _, message = entries[0] 166 | payload = message.get(b"payload") 167 | if payload: 168 | return TransformationResponseModel(**json.loads(payload.decode('utf-8'))) 169 | except Exception as e: 170 | logger.error(f"Error getting existing results: {e}") 171 | return None 172 | 173 | def merge_results(existing: Optional[TransformationResponseModel], new: TransformationResponseModel) -> TransformationResponseModel: 174 | if not existing: 175 | return new 176 | 177 | return TransformationResponseModel( 178 | task_id=new.task_id, 179 | pdf_key=new.pdf_key, 180 | results=existing.results + new.results 181 | ) 182 | 183 | async def get_total_workloads(redis: Redis, task_id: str) -> int: 184 | """Get the total number of workloads for a task from Redis.""" 185 | try: 186 | workload_count = await redis.get(f"workload-count:{task_id}") 187 | return int(workload_count) if workload_count else 1 188 | except Exception as e: 189 | logger.error(f"Error getting workload count: {e}") 190 | return 1 191 | 192 | async def clear_transformation_streams() -> None: 193 | retries = 0 194 | max_retries = 5 195 | delay = 1 196 | 197 | while retries < max_retries: 198 | try: 199 | redis = await get_redis_connection() 200 | await redis.xtrim("transformation-stream", approximate=True, maxlen=0) 201 | await redis.xtrim("transformation-only-stream", approximate=True, maxlen=0) 202 | logger.info("Cleared transformation-stream and transformation-only-stream") 203 | return 204 | except RedisError as e: 205 | if "LOADING" in str(e): 206 | logger.info(f"Redis is still loading. 
Retrying in {delay} seconds...") 207 | await asyncio.sleep(delay) 208 | retries += 1 209 | delay *= 2 210 | else: 211 | raise e 212 | 213 | raise Exception("Failed to clear transformation streams after maximum retries") 214 | -------------------------------------------------------------------------------- /application/transformation/start_transformation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from dotenv import load_dotenv 4 | import asyncio 5 | from application.transformation.service.transformation_worker import clear_transformation_streams, run_transformations 6 | 7 | load_dotenv() 8 | 9 | logging.basicConfig(level=logging.INFO, 10 | format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | async def main(): 13 | try: 14 | await clear_transformation_streams() 15 | except Exception as e: 16 | logging.error("Failed to clear transformation streams: %s", e) 17 | 18 | try: 19 | logging.info("Started transformation service...") 20 | await run_transformations() 21 | except Exception as e: 22 | logging.error("Application error: %s", e) 23 | exit(1) 24 | 25 | if __name__ == "__main__": 26 | asyncio.run(main()) 27 | -------------------------------------------------------------------------------- /common/agents/agent_prompt_enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class AgentMode(Enum): 4 | EXTRACTION = "extraction" 5 | PAGE_FINDER = "page_finder" 6 | 7 | class ExtractionPrompts(Enum): 8 | SYSTEM = """You are a precise data extraction assistant focused on consolidating and deduplicating information. Your task is to: 9 | 1. Extract and consolidate metrics from all sources 10 | 2. Remove duplicate entries and redundant information 11 | 3. Preserve EXACT units and symbols as they appear 12 | 13 | Guidelines: 14 | - Consolidate repeated information into single, definitive entries 15 | - When conflicting values exist, prefer confirmed/definitive information over rumors 16 | - Remove duplicate entries that refer to the same event/metric 17 | - Keep ALL units and symbols EXACTLY as they appear (e.g., keep "$15 million" as is) 18 | - Never perform unit conversions or reformatting 19 | - Mark as "Not Available" only when no reliable value exists 20 | - For multiple genuine distinct instances (e.g., different funding rounds), maintain separate entries 21 | 22 | Remember: Focus on providing clean, deduplicated data while maintaining accuracy.""" 23 | 24 | ANALYSIS = """You are a precise analysis expert focused on deduplication and consolidation. Your task is to: 25 | 1. Find duplicate or redundant information 26 | 2. Identify consolidation opportunities 27 | 3. 
List improvements already made 28 | 29 | Format your response using ONLY these markers: 30 | - For duplicates/conflicts: Start lines with "⚠ Duplicate:" or "⚠ Conflict:" 31 | - For improvements/consolidations: Start lines with "✓ Consolidated:" 32 | 33 | Example format: 34 | ⚠ Duplicate: Found duplicate funding amount "$50M" in both Series A and B sections 35 | ⚠ Conflict: Different dates for same event - "March 2021" vs "03/2021" 36 | ✓ Consolidated: Combined Series A details into single entry 37 | 38 | Guidelines: 39 | - Each issue must start with "⚠" 40 | - Each improvement must start with "✓" 41 | - Be specific about what needs to be consolidated 42 | - Identify exact duplicates or conflicts 43 | - Keep each line focused on one issue/improvement 44 | 45 | Remember: Focus on finding information that can be combined or deduplicated.""" 46 | 47 | FIX = """You are a precise consolidation and deduplication expert. Your task is to: 48 | 1. Review the current extraction result 49 | 2. Apply consolidation and deduplication fixes 50 | 3. Verify the accuracy of merged information 51 | 52 | Guidelines: 53 | - Merge duplicate entries into single, authoritative records 54 | - When consolidating conflicting values: 55 | * Prefer confirmed facts over rumors 56 | * Use the most specific/detailed version 57 | * Maintain original formatting and units 58 | - Keep entries separate only if they are genuinely distinct events 59 | - Document your consolidation decisions 60 | - Verify that no information is lost during consolidation 61 | 62 | Remember: Each consolidation must be verified and justified.""" 63 | 64 | CONFIDENCE = """Rate the extraction confidence on a scale of 0.0 to 1.0. 65 | Score based on: 66 | - Consolidation: Information is properly deduplicated 67 | - Accuracy: Confirmed facts are prioritized 68 | - Unit Preservation: Original units and symbols are maintained 69 | - Format Alignment: Output matches expected format 70 | 71 | Lower score if: 72 | - Duplicate entries remain 73 | - Rumors are mixed with confirmed facts 74 | - Units or symbols were modified 75 | - Information is not properly consolidated 76 | 77 | Respond with ONLY a number between 0.0 and 1.0.""" 78 | 79 | SYNTHESIS = """You are a synthesis expert focused on consolidation and deduplication. Your task is to: 80 | 1. Review all previous extraction attempts 81 | 2. Create a final answer that: 82 | - Consolidates duplicate information into single entries 83 | - Removes redundant entries referring to same events 84 | - Prioritizes confirmed facts over speculative information 85 | - Preserves ALL original units and symbols exactly as they appear 86 | - Never converts or reformats units 87 | - Maintains separate entries only for genuinely distinct instances 88 | - Marks values as "Not Available" only when no reliable data exists 89 | 90 | Start your response with 'FINAL ANSWER:' and ensure each entry represents unique, consolidated information.""" 91 | 92 | VERIFICATION = """You are a verification expert focused on consolidation quality. Your task is to: 93 | 1. Compare before and after states of fixes 94 | 2. 
Verify that consolidation was successful by checking: 95 | - All duplicate entries were properly merged 96 | - No information was lost during consolidation 97 | - Confirmed facts were prioritized over rumors 98 | - Original formatting and units were preserved 99 | - Only genuinely distinct entries remain separate 100 | 101 | Output as JSON: 102 | { 103 | "consolidation_checks": [ 104 | { 105 | "aspect": "description", 106 | "success": boolean, 107 | "details": "explanation" 108 | } 109 | ], 110 | "remaining_duplicates": [ 111 | { 112 | "location": "where", 113 | "description": "what duplicates remain" 114 | } 115 | ], 116 | "verification_score": float 117 | }""" 118 | 119 | class PageFinderPrompts(Enum): 120 | SYSTEM = """You are a precise page relevance analyzer. Your task is to: 121 | 1. Analyze the given page content 122 | 2. Find pages containing relevant metric information 123 | 3. Provide clear evidence for your decisions 124 | 125 | Guidelines: 126 | - Look for clear mentions of metric names or values 127 | - Consider both direct mentions and strong contextual evidence 128 | - Evaluate relevance based on metric content quality 129 | - Provide specific text evidence for your decisions 130 | - When confident, start with 'FINAL ANSWER:' 131 | 132 | Focus on finding meaningful metric information.""" 133 | 134 | REFLECTION = """Analyze the page relevance assessment with these points: 135 | - Metric Presence: Is the metric information clear and meaningful? 136 | - Supporting Evidence: What context shows metric relevance? 137 | - Contextual Clarity: Is the information reliable and well-supported? 138 | - Validation: Can the metric information be verified? 139 | 140 | Keep reflection focused on information quality (1-2 sentences).""" 141 | 142 | CONFIDENCE = """Rate the page relevance confidence on a scale of 0.0 to 1.0. 143 | Score based on: 144 | - Information Quality: Clear and reliable metric data 145 | - Context Support: Strong evidence in surrounding text 146 | - Relevance: Direct connection to requested metrics 147 | 148 | Lower score if: 149 | - Information is unclear or ambiguous 150 | - Context is weak or misleading 151 | - Metric connection is tenuous 152 | 153 | Respond with ONLY a number between 0.0 and 1.0.""" 154 | 155 | SYNTHESIS = """You are a synthesis expert balancing precision with completeness. Your task is to: 156 | 1. Review all previous page assessments 157 | 2. Identify pages with valuable metric information 158 | 3. 
Create a final answer that: 159 | - Lists pages with strong metric evidence 160 | - Provides supporting context for relevance 161 | - Notes quality of metric information 162 | - Considers contextual relationships 163 | - Maintains high information standards 164 | 165 | Start your response with 'FINAL ANSWER:' and provide clear evidence for included pages.""" -------------------------------------------------------------------------------- /common/agents/prs_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, Sequence, TypedDict, Literal 2 | import operator 3 | from dotenv import load_dotenv 4 | from langchain_core.messages import BaseMessage, AIMessage, HumanMessage 5 | from langgraph.graph import StateGraph, END 6 | import functools 7 | import logging 8 | import redis 9 | import json 10 | import uuid 11 | from .agent_prompt_enums import AgentMode, ExtractionPrompts, PageFinderPrompts 12 | 13 | load_dotenv() 14 | 15 | redis_client = redis.Redis(host='redis', port=6379, db=0) 16 | REDIS_EXPIRE = 60 * 60 17 | 18 | class AgentState(TypedDict): 19 | messages: Annotated[Sequence[BaseMessage], operator.add] 20 | sender: str 21 | confidence_score: float 22 | session_id: str 23 | iterations: int 24 | 25 | def get_redis_key(session_id: str, key_type: str) -> str: 26 | """Generate Redis key for different state types.""" 27 | return f"prs:{session_id}:{key_type}" 28 | 29 | def store_list(session_id: str, key_type: str, items: list) -> None: 30 | """Store a list in Redis with expiry.""" 31 | key = get_redis_key(session_id, key_type) 32 | if items: 33 | redis_client.delete(key) 34 | redis_client.rpush(key, *[json.dumps(item) for item in items]) 35 | redis_client.expire(key, REDIS_EXPIRE) 36 | 37 | def get_list(session_id: str, key_type: str, start: int = 0, end: int = -1) -> list: 38 | """Get a list from Redis with optional range.""" 39 | key = get_redis_key(session_id, key_type) 40 | items = redis_client.lrange(key, start, end) 41 | return [json.loads(item) for item in items] if items else [] 42 | 43 | def get_prompts(mode: AgentMode): 44 | """Get the appropriate prompts for the specified mode.""" 45 | if mode == AgentMode.EXTRACTION: 46 | return ExtractionPrompts 47 | return PageFinderPrompts 48 | 49 | def create_agent(client, system_message: str): 50 | """Create an agent with a specific system message.""" 51 | def agent_fn(inputs): 52 | improvements_made = "\n".join(inputs.get("improvements", [])) 53 | pending_fixes = "\n".join(inputs.get("pending_fixes", [])) 54 | 55 | messages = [ 56 | {"role": "system", "content": system_message}, 57 | {"role": "system", "content": f"""Previous reflections identified these issues to fix: 58 | {pending_fixes} 59 | 60 | Improvements already made: 61 | {improvements_made} 62 | 63 | Focus on addressing the pending issues while maintaining previous improvements."""}, 64 | {"role": "user", "content": inputs["messages"][-1].content} 65 | ] 66 | return client.do_completion(messages) 67 | return agent_fn 68 | 69 | def store_message(session_id: str, content: str, role: str = "ai") -> None: 70 | """Store a message in Redis.""" 71 | key = get_redis_key(session_id, "messages") 72 | redis_client.rpush(key, json.dumps({"role": role, "content": content})) 73 | redis_client.expire(key, REDIS_EXPIRE) 74 | 75 | def get_last_message(session_id: str) -> str: 76 | """Get the last message content from Redis.""" 77 | key = get_redis_key(session_id, "messages") 78 | last_msg = redis_client.lindex(key, -1) 79 | if 
last_msg: 80 | return json.loads(last_msg)["content"] 81 | return "" 82 | 83 | def get_all_messages(session_id: str) -> list: 84 | """Get all messages from Redis.""" 85 | key = get_redis_key(session_id, "messages") 86 | messages = redis_client.lrange(key, 0, -1) 87 | return [json.loads(msg) for msg in messages] if messages else [] 88 | 89 | def agent_node(state, agent, name): 90 | """Process the agent's response and update the state.""" 91 | session_id = state["session_id"] 92 | messages = state["messages"] 93 | 94 | improvements = get_list(session_id, "improvements", -5) 95 | pending_fixes = get_list(session_id, "pending_fixes") 96 | 97 | result = agent({ 98 | "messages": messages, 99 | "improvements": improvements, 100 | "pending_fixes": pending_fixes 101 | }) 102 | 103 | store_message(session_id, result) 104 | 105 | return { 106 | "messages": messages + [AIMessage(content=result)], 107 | "sender": name, 108 | "confidence_score": state["confidence_score"], 109 | "session_id": session_id, 110 | "iterations": state["iterations"] + 1 111 | } 112 | 113 | def analyze_node(state, agent, name, prompts): 114 | """Analyze output and identify consolidation opportunities.""" 115 | logger = logging.getLogger(__name__) 116 | session_id = state["session_id"] 117 | 118 | logger.info(f"\nStarting analysis iteration {state['iterations'] + 1}") 119 | 120 | last_message = get_last_message(session_id) 121 | 122 | prev_improvements = get_list(session_id, "improvements") 123 | 124 | analysis_messages = [ 125 | {"role": "system", "content": prompts.ANALYSIS.value}, 126 | {"role": "user", "content": f"""Current extraction to analyze: 127 | {last_message} 128 | 129 | Previous improvements made: 130 | {chr(10).join(prev_improvements)} 131 | 132 | Analyze this extraction focusing on duplicate entries and consolidation opportunities. 133 | Be thorough and identify ALL potential duplicates and conflicts."""} 134 | ] 135 | 136 | analysis = agent.do_completion(analysis_messages, temperature=0.2) 137 | 138 | verification_messages = [ 139 | {"role": "system", "content": """Verify the previous analysis for: 140 | 1. Missed duplicates or conflicts 141 | 2. False positives in identified issues 142 | 3. Completeness of consolidation opportunities 143 | 4. 
Accuracy of proposed improvements"""}, 144 | {"role": "user", "content": f"""Previous analysis: 145 | {analysis} 146 | 147 | Original content: 148 | {last_message} 149 | 150 | Verify and identify any missed issues or inaccuracies."""} 151 | ] 152 | 153 | verification = agent.do_completion(verification_messages, temperature=0.1) 154 | 155 | current_improvements = get_list(session_id, "improvements") 156 | current_pending_fixes = get_list(session_id, "pending_fixes") 157 | 158 | new_improvements = [] 159 | new_fixes = [] 160 | 161 | for content in [analysis, verification]: 162 | for line in content.split('\n'): 163 | line = line.strip() 164 | if not line: 165 | continue 166 | if line.startswith('✓'): 167 | if line not in current_improvements and line not in new_improvements: 168 | new_improvements.append(line) 169 | elif line.startswith('⚠'): 170 | if line not in current_pending_fixes and line not in new_fixes: 171 | new_fixes.append(line) 172 | 173 | if new_improvements: 174 | store_list(session_id, "improvements", current_improvements + new_improvements) 175 | if new_fixes: 176 | store_list(session_id, "pending_fixes", new_fixes) 177 | 178 | reflection = f"""Analysis {redis_client.llen(get_redis_key(session_id, 'reflections')) + 1}: 179 | 180 | Initial Analysis: 181 | {analysis} 182 | 183 | Verification: 184 | {verification} 185 | 186 | New Issues: {len(new_fixes)} 187 | New Improvements: {len(new_improvements)}""" 188 | 189 | redis_client.rpush(get_redis_key(session_id, "reflections"), json.dumps(reflection)) 190 | 191 | logger.info(f"Analysis complete: {len(new_fixes)} new issues, {len(new_improvements)} improvements") 192 | 193 | return { 194 | "messages": state["messages"], 195 | "sender": name, 196 | "confidence_score": state["confidence_score"], 197 | "session_id": session_id, 198 | "iterations": state["iterations"] 199 | } 200 | 201 | def confidence_node(state, agent, name, prompts): 202 | """Score the confidence of the current analysis.""" 203 | session_id = state["session_id"] 204 | messages = state["messages"] 205 | last_message = messages[-1].content if messages else "" 206 | 207 | confidence_messages = [ 208 | {"role": "system", "content": prompts.CONFIDENCE.value}, 209 | {"role": "user", "content": f"Analysis to score:\n{last_message}"} 210 | ] 211 | 212 | score = agent.do_completion(confidence_messages, temperature=0.0) 213 | 214 | try: 215 | confidence = float(score.strip()) 216 | except: 217 | confidence = 0.5 218 | 219 | return { 220 | "messages": messages, 221 | "sender": name, 222 | "confidence_score": confidence, 223 | "session_id": session_id, 224 | "iterations": state["iterations"] 225 | } 226 | 227 | def fix_node(state, agent, name, prompts): 228 | """Fix identified issues focusing on deduplication.""" 229 | logger = logging.getLogger(__name__) 230 | session_id = state["session_id"] 231 | last_message = get_last_message(session_id) 232 | 233 | pending_fixes = get_list(session_id, "pending_fixes") 234 | if pending_fixes: 235 | logger.info(f"\nAttempting to fix {len(pending_fixes)} issues:") 236 | for fix in pending_fixes: 237 | logger.info(f"⚠ {fix}") 238 | 239 | fix_messages = [ 240 | {"role": "system", "content": prompts.FIX.value}, 241 | {"role": "user", "content": f"""Content to fix: 242 | {last_message} 243 | 244 | Issues to address: 245 | {chr(10).join(pending_fixes)} 246 | 247 | Apply fixes systematically and verify each change."""} 248 | ] 249 | 250 | fixed_result = agent.do_completion(fix_messages, temperature=0.2) 251 | 252 | verify_messages = [ 253 | 
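# Second pass: the same model re-reads the applied fixes at low temperature
# (0.1 in the do_completion call below) to confirm each issue was addressed
# and that no new duplicates were introduced.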
{"role": "system", "content": """Verify that all fixes were properly applied: 254 | 1. Check each issue was addressed 255 | 2. Verify no information was lost 256 | 3. Confirm all consolidations are accurate 257 | 4. Ensure no new duplicates were created"""}, 258 | {"role": "user", "content": f"""Original content: 259 | {last_message} 260 | 261 | Applied fixes: 262 | {fixed_result} 263 | 264 | Original issues: 265 | {chr(10).join(pending_fixes)} 266 | 267 | Verify all fixes were properly applied."""} 268 | ] 269 | 270 | verification = agent.do_completion(verify_messages, temperature=0.1) 271 | 272 | store_message(session_id, fixed_result) 273 | redis_client.rpush(get_redis_key(session_id, "fix_verifications"), 274 | json.dumps({"fixes": pending_fixes, "verification": verification})) 275 | 276 | redis_client.delete(get_redis_key(session_id, "pending_fixes")) 277 | 278 | logger.info("Fix attempt complete") 279 | logger.info(f"Verification result length: {len(verification.split())}") 280 | 281 | return { 282 | "messages": state["messages"] + [AIMessage(content=fixed_result)], 283 | "sender": name, 284 | "confidence_score": state["confidence_score"], 285 | "session_id": session_id, 286 | "iterations": state["iterations"] 287 | } 288 | 289 | def create_router(mode: AgentMode): 290 | """Create a router function based on the agent mode.""" 291 | def router(state) -> Literal["process", "analyze", "fix", "score", "synthesize", "__end__"]: 292 | iterations = state["iterations"] 293 | confidence = state["confidence_score"] 294 | 295 | max_iterations = 2 296 | min_confidence = 0.8 297 | 298 | if iterations >= max_iterations or confidence >= min_confidence: 299 | if state["sender"] != "synthesizer": 300 | return "synthesize" 301 | return "__end__" 302 | 303 | if state["sender"] == "user": 304 | return "process" 305 | elif state["sender"] == "processor": 306 | return "analyze" 307 | elif state["sender"] == "analyzer": 308 | return "fix" 309 | elif state["sender"] == "fixer": 310 | return "score" 311 | elif state["sender"] == "scorer": 312 | if confidence < min_confidence and iterations < max_iterations: 313 | return "process" 314 | return "synthesize" 315 | else: 316 | return "process" 317 | 318 | return router 319 | 320 | def synthesize_node(state, agent, name, prompts): 321 | """Create final answer by synthesizing all iterations and improvements.""" 322 | session_id = state["session_id"] 323 | 324 | messages = get_all_messages(session_id) 325 | improvements = get_list(session_id, "improvements", -5) 326 | 327 | final_result = agent.do_completion([ 328 | {"role": "system", "content": prompts.SYNTHESIS.value}, 329 | {"role": "user", "content": f"""Best response so far: 330 | {messages[-1]['content'] if messages else ''} 331 | 332 | Key improvements: 333 | {chr(10).join(improvements)}"""} 334 | ], temperature=0.0) 335 | 336 | store_message(session_id, final_result) 337 | 338 | return { 339 | "sender": "synthesizer", 340 | "confidence_score": state["confidence_score"], 341 | "session_id": session_id, 342 | "iterations": state["iterations"] 343 | } 344 | 345 | def create_graph(client, mode: AgentMode): 346 | """Create the workflow graph with the specified mode.""" 347 | prompts = get_prompts(mode) 348 | 349 | text_processor = create_agent(client, prompts.SYSTEM.value) 350 | 351 | processor_node = functools.partial(agent_node, agent=text_processor, name="processor") 352 | analyzer = functools.partial(analyze_node, agent=client, name="analyzer", prompts=prompts) 353 | issue_fixer = functools.partial(fix_node, 
agent=client, name="fixer", prompts=prompts) 354 | confidence_scorer = functools.partial(confidence_node, agent=client, name="scorer", prompts=prompts) 355 | synthesizer = functools.partial(synthesize_node, agent=client, name="synthesizer", prompts=prompts) 356 | 357 | workflow = StateGraph(AgentState) 358 | 359 | workflow.add_node("process", processor_node) 360 | workflow.add_node("analyze", analyzer) 361 | workflow.add_node("fix", issue_fixer) 362 | workflow.add_node("score", confidence_scorer) 363 | workflow.add_node("synthesize", synthesizer) 364 | 365 | router = create_router(mode) 366 | 367 | for node in ["process", "analyze", "fix", "score", "synthesize"]: 368 | workflow.add_conditional_edges( 369 | node, 370 | router, 371 | { 372 | "process": "process", 373 | "analyze": "analyze", 374 | "fix": "fix", 375 | "score": "score", 376 | "synthesize": "synthesize", 377 | "__end__": END, 378 | }, 379 | ) 380 | 381 | workflow.set_entry_point("process") 382 | 383 | return workflow.compile() 384 | 385 | def process_extraction(text: str, client, mode: AgentMode) -> str: 386 | """Process text through the agent workflow with extraction handler format.""" 387 | logger = logging.getLogger(__name__) 388 | session_id = str(uuid.uuid4()) 389 | 390 | try: 391 | store_message(session_id, text, "human") 392 | 393 | graph = create_graph(client, mode) 394 | 395 | result = graph.invoke({ 396 | "messages": [HumanMessage(content=text)], 397 | "sender": "user", 398 | "confidence_score": 0.0, 399 | "session_id": session_id, 400 | "iterations": 0 401 | }) 402 | 403 | logger.info("\n" + "="*50) 404 | logger.info(f"AGENT MODE: {mode.value}") 405 | logger.info("="*50) 406 | 407 | reflections = get_list(session_id, "reflections") 408 | logger.info("\nPROCESS REFLECTIONS:") 409 | for i, reflection in enumerate(reflections, 1): 410 | logger.info(f"\nITERATION {i}:") 411 | logger.info(reflection) 412 | 413 | improvements = get_list(session_id, "improvements") 414 | logger.info("\nIMPROVEMENTS MADE:") 415 | if improvements: 416 | for imp in improvements: 417 | logger.info(f"✓ {imp}") 418 | else: 419 | logger.info("(No improvements needed)") 420 | 421 | fix_verifications = get_list(session_id, "fix_verifications") 422 | logger.info("\nFIX VERIFICATIONS:") 423 | for i, verification in enumerate(fix_verifications, 1): 424 | logger.info(f"\nFix Attempt {i}:") 425 | logger.info(f"Issues Addressed: {len(verification.get('fixes', []))}") 426 | logger.info(f"Verification Result: {verification.get('verification', '')}") 427 | 428 | pending_fixes = get_list(session_id, "pending_fixes") 429 | logger.info("\nFINAL STATE:") 430 | logger.info(f"Total Iterations: {result['iterations']}") 431 | logger.info(f"Final Confidence Score: {result['confidence_score']:.2f}") 432 | logger.info("\nRemaining Issues:") 433 | if pending_fixes: 434 | for issue in pending_fixes: 435 | logger.info(f"⚠ {issue}") 436 | else: 437 | logger.info("(No remaining issues)") 438 | 439 | logger.info("\nFINAL RESULT:") 440 | final_message = result["messages"][-1].content if result["messages"] else "" 441 | logger.info(final_message) 442 | logger.info("="*50 + "\n") 443 | 444 | for key_type in ["messages", "improvements", "pending_fixes", "reflections", "fix_verifications"]: 445 | redis_client.delete(get_redis_key(session_id, key_type)) 446 | 447 | return final_message 448 | 449 | except Exception as e: 450 | logger.error(f"Extraction failed: {str(e)}") 451 | try: 452 | reflections = get_list(session_id, "reflections") 453 | if reflections: 454 | 
logger.error("\nLast known state before error:") 455 | logger.error(reflections[-1]) 456 | except: 457 | pass 458 | 459 | for key_type in ["messages", "improvements", "pending_fixes", "reflections", "fix_verifications"]: 460 | redis_client.delete(get_redis_key(session_id, key_type)) 461 | return "Extraction failed. Please try again." -------------------------------------------------------------------------------- /common/api/3.1.0-marly-spec.yml: -------------------------------------------------------------------------------- 1 | openapi: 3.1.0 2 | info: 3 | title: Marly API 4 | description: The Data Processor for Agents 5 | version: 1.0.0 6 | 7 | paths: 8 | /pipelines: 9 | post: 10 | summary: Run pipeline 11 | description: Initiates a pipeline processing job for the given PDF and schemas. 12 | operationId: runPipeline 13 | requestBody: 14 | required: true 15 | content: 16 | application/json: 17 | schema: 18 | $ref: "#/components/schemas/PipelineRequestModel" 19 | responses: 20 | "202": 21 | description: Pipeline processing started 22 | content: 23 | application/json: 24 | schema: 25 | $ref: "#/components/schemas/PipelineResponseModel" 26 | "500": 27 | $ref: "#/components/responses/InternalServerError" 28 | 29 | /pipelines/{task_id}: 30 | get: 31 | summary: Get pipeline results 32 | description: Retrieves the results of a pipeline processing job. 33 | operationId: getPipelineResults 34 | parameters: 35 | - name: task_id 36 | in: path 37 | required: true 38 | schema: 39 | type: string 40 | description: Unique identifier for the pipeline task 41 | responses: 42 | "200": 43 | description: Pipeline results retrieved successfully 44 | content: 45 | application/json: 46 | schema: 47 | $ref: "#/components/schemas/PipelineResult" 48 | "404": 49 | description: Task not found 50 | "500": 51 | $ref: "#/components/responses/InternalServerError" 52 | 53 | components: 54 | schemas: 55 | WorkloadItem: 56 | type: object 57 | properties: 58 | raw_data: 59 | type: string 60 | description: string version of raw data (can be a pdf, html, text, etc.) 
61 | schemas: 62 | type: array 63 | items: 64 | type: string 65 | description: List of schema strings 66 | data_source: 67 | type: string 68 | description: Type of data source 69 | documents_location: 70 | type: string 71 | description: Location of documents 72 | file_name: 73 | type: string 74 | description: Name of the file 75 | additional_params: 76 | type: object 77 | additionalProperties: true 78 | description: Additional parameters for the workload 79 | destination: 80 | type: string 81 | description: Destination for the processed data 82 | required: 83 | - schemas 84 | 85 | PipelineRequestModel: 86 | type: object 87 | properties: 88 | workloads: 89 | type: array 90 | items: 91 | $ref: "#/components/schemas/WorkloadItem" 92 | provider_type: 93 | type: string 94 | provider_model_name: 95 | type: string 96 | api_key: 97 | type: string 98 | markdown_mode: 99 | type: boolean 100 | default: false 101 | additional_params: 102 | type: object 103 | additionalProperties: true 104 | required: 105 | - workloads 106 | - provider_type 107 | - provider_model_name 108 | - api_key 109 | 110 | PipelineResponseModel: 111 | type: object 112 | properties: 113 | task_id: 114 | type: string 115 | description: Unique identifier for the pipeline task 116 | message: 117 | type: string 118 | description: Status message 119 | required: 120 | - task_id 121 | - message 122 | 123 | PipelineResult: 124 | type: object 125 | properties: 126 | task_id: 127 | type: string 128 | description: Unique identifier for the pipeline task 129 | status: 130 | $ref: "#/components/schemas/JobStatus" 131 | results: 132 | type: array 133 | items: 134 | $ref: "#/components/schemas/SchemaResult" 135 | description: Array of schema results 136 | total_run_time: 137 | type: string 138 | description: Total execution time of the pipeline 139 | required: 140 | - task_id 141 | - status 142 | - results 143 | - total_run_time 144 | 145 | SchemaResult: 146 | type: object 147 | properties: 148 | schema_id: 149 | type: string 150 | description: Identifier for the schema used 151 | metrics: 152 | type: object 153 | additionalProperties: 154 | type: string 155 | description: Metrics related to the schema extraction 156 | schema_data: 157 | type: object 158 | additionalProperties: 159 | type: string 160 | description: Extracted data based on the schema 161 | required: 162 | - schema_id 163 | - metrics 164 | - schema_data 165 | 166 | JobStatus: 167 | type: string 168 | enum: 169 | - PENDING 170 | - IN_PROGRESS 171 | - COMPLETED 172 | - FAILED 173 | description: Current status of the pipeline job 174 | 175 | responses: 176 | InternalServerError: 177 | description: Internal Server Error 178 | content: 179 | application/json: 180 | schema: 181 | type: object 182 | properties: 183 | detail: 184 | type: string 185 | description: Error details 186 | -------------------------------------------------------------------------------- /common/destinations/base/base_destination.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict, Any, List, Optional 3 | 4 | class BaseDestination(ABC): 5 | @abstractmethod 6 | def connect(self) -> None: 7 | """Establish a connection to the destination.""" 8 | pass 9 | 10 | @abstractmethod 11 | def insert(self, data: List[Dict[str, Any]]) -> None: 12 | """Insert data into the destination.""" 13 | pass 14 | 15 | @abstractmethod 16 | def close(self) -> None: 17 | """Close the connection to the destination.""" 18 | pass 19 | 20 | 
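# Hedged usage sketch against this interface (SQLiteDestination below is the
# only concrete implementation in this section). Note that its insert() takes
# a leading table_name argument, which this base signature does not declare:
#   dest = SQLiteDestination(db_path="contacts.db")   # hypothetical path
#   dest.connect()
#   dest.insert("contacts", [{"name": "Ada", "email": "ada@example.com"}])
#   dest.close()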
@abstractmethod 21 | def get_table_structure(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]: 22 | """Get the structure of a table.""" 23 | pass 24 | -------------------------------------------------------------------------------- /common/destinations/destination_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from common.destinations.enums.destination_enums import DestinationType 3 | from common.destinations.sqlite_destination import SQLiteDestination 4 | 5 | class DestinationFactory: 6 | @staticmethod 7 | def create_destination(destination_type: str, config: Dict[str, Any] = None): 8 | if not destination_type: 9 | raise ValueError("destination_type must be provided.") 10 | 11 | try: 12 | destination_type_enum = DestinationType(destination_type.lower()) 13 | except ValueError: 14 | raise ValueError(f"Invalid destination type. Allowed values are: {', '.join([d.value for d in DestinationType])}") 15 | 16 | if destination_type_enum == DestinationType.SQLITE: 17 | return SQLiteDestination( 18 | db_path=config.get("db_path", ":memory:"), 19 | additional_params=config.get("additional_params", {}) 20 | ) 21 | else: 22 | raise ValueError(f"Unsupported destination type: {destination_type_enum}") 23 | 24 | -------------------------------------------------------------------------------- /common/destinations/enums/destination_enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class DestinationType(Enum): 4 | SQLITE = "sqlite" 5 | 6 | -------------------------------------------------------------------------------- /common/destinations/sqlite_destination.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | import os 3 | from typing import Dict, Any, List, Optional 4 | from common.destinations.base.base_destination import BaseDestination 5 | 6 | class SQLiteDestination(BaseDestination): 7 | def __init__(self, db_path: str, additional_params: Dict[str, Any] = None): 8 | self.db_path = db_path 9 | self.additional_params = additional_params or {} 10 | self.schema = self.additional_params.get("schema", "main") 11 | self.conn = None 12 | self.cursor = None 13 | 14 | def connect(self) -> None: 15 | if not self.conn: 16 | try: 17 | self.conn = sqlite3.connect(self.db_path) 18 | self.cursor = self.conn.cursor() 19 | except sqlite3.OperationalError: 20 | # If the directory doesn't exist, create it 21 | os.makedirs(os.path.dirname(self.db_path), exist_ok=True) 22 | self.conn = sqlite3.connect(self.db_path) 23 | self.cursor = self.conn.cursor() 24 | print(f"Created new SQLite database at {self.db_path}") 25 | 26 | def insert(self, table_name: str, data: List[Dict[str, Any]]) -> None: 27 | self.connect() 28 | if not data: 29 | raise ValueError("No data provided for insertion") 30 | 31 | if not self.table_exists(table_name): 32 | self.create_table(table_name, data[0].keys()) 33 | 34 | columns = list(data[0].keys()) 35 | placeholders = ', '.join(['?' 
for _ in columns]) 36 | insert_query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})" 37 | 38 | try: 39 | for row in data: 40 | values = [row.get(col, '') for col in columns] 41 | self.cursor.execute(insert_query, values) 42 | self.conn.commit() 43 | except sqlite3.Error as e: 44 | self.conn.rollback() 45 | raise ValueError(f"Error inserting data: {str(e)}") 46 | 47 | def table_exists(self, table_name: str) -> bool: 48 | self.connect() 49 | self.cursor.execute(f"SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) 50 | return self.cursor.fetchone() is not None 51 | 52 | def create_table(self, table_name: str, columns: List[str]) -> None: 53 | column_defs = ', '.join([f"{col} TEXT" for col in columns]) 54 | create_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({column_defs})" 55 | self.cursor.execute(create_query) 56 | self.conn.commit() 57 | 58 | def close(self) -> None: 59 | if self.cursor: 60 | self.cursor.close() 61 | if self.conn: 62 | self.conn.close() 63 | self.conn = None 64 | self.cursor = None 65 | 66 | def get_table_structure(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]: 67 | self.connect() 68 | schema = schema or self.schema 69 | query = f"PRAGMA {schema}.table_info({table_name})" 70 | self.cursor.execute(query) 71 | columns = self.cursor.fetchall() 72 | 73 | structure = [] 74 | for column in columns: 75 | structure.append({ 76 | "name": column[1], 77 | "type": column[2], 78 | "notnull": bool(column[3]), 79 | "default_value": column[4], 80 | "pk": bool(column[5]) 81 | }) 82 | 83 | return structure 84 | -------------------------------------------------------------------------------- /common/models/azure_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union, Any 2 | from common.models.base.base_model import BaseModel 3 | from openai import AzureOpenAI 4 | from common.models.enums.model_enums import AzureModelName 5 | from langsmith import traceable 6 | 7 | class AzureModel(BaseModel): 8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None): 9 | if not api_key: 10 | raise ValueError("API key must be provided.") 11 | self.api_key = api_key 12 | 13 | if not additional_params: 14 | additional_params = {} 15 | 16 | self.api_version = additional_params.get("api_version") 17 | if not self.api_version: 18 | raise ValueError("API version must be provided in additional_params.") 19 | 20 | self.azure_endpoint = additional_params.get("azure_endpoint") 21 | if not self.azure_endpoint: 22 | raise ValueError("Azure endpoint must be provided in additional_params.") 23 | 24 | self.azure_deployment = additional_params.get("azure_deployment") 25 | if not self.azure_deployment: 26 | raise ValueError("Azure deployment must be provided in additional_params.") 27 | 28 | self.model_name = self.validate_model_name(model_name) 29 | 30 | self.client = AzureOpenAI( 31 | api_key=self.api_key, 32 | api_version=self.api_version, 33 | azure_endpoint=self.azure_endpoint, 34 | azure_deployment=self.azure_deployment 35 | ) 36 | 37 | @staticmethod 38 | def validate_model_name(model_name: str) -> str: 39 | try: 40 | return AzureModelName(model_name).value 41 | except ValueError: 42 | raise ValueError(f"Invalid model name. 
Allowed values are: {', '.join([m.value for m in AzureModelName])}")
43 | 
44 |     @traceable(run_type="llm")
45 |     def do_completion(self,
46 |                       messages: List[Dict[str, str]],
47 |                       model_name: Optional[str] = None,
48 |                       max_tokens: Optional[int] = None,
49 |                       temperature: Optional[float] = None,
50 |                       top_p: Optional[float] = None,
51 |                       n: Optional[int] = None,
52 |                       stop: Optional[Union[str, List[str]]] = None,
53 |                       response_format: Optional[Dict[str, str]] = None) -> str:
54 |         if not messages:
55 |             raise ValueError("'messages' must be provided.")
56 | 
57 |         params = {
58 |             "model": self.validate_model_name(model_name) if model_name else self.model_name,
59 |             "messages": messages,
60 |         }
61 | 
62 |         if max_tokens is not None:
63 |             params["max_tokens"] = max_tokens
64 |         if temperature is not None:
65 |             params["temperature"] = temperature
66 |         if top_p is not None:
67 |             params["top_p"] = top_p
68 |         if n is not None:
69 |             params["n"] = n
70 |         if stop is not None:
71 |             params["stop"] = stop
72 |         if response_format is not None:
73 |             params["response_format"] = response_format
74 | 
75 |         response = self.client.chat.completions.create(**params)
76 | 
77 |         return response.choices[0].message.content
78 | 
--------------------------------------------------------------------------------
/common/models/base/base_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | 
3 | class BaseModel:
4 |     def do_completion(self, messages: List[Dict[str, str]], **kwargs) -> str:
5 |         raise NotImplementedError
--------------------------------------------------------------------------------
/common/models/cerebras_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union, Any
2 | from common.models.base.base_model import BaseModel
3 | from common.models.enums.model_enums import CerebrasModelName
4 | from cerebras.cloud.sdk import Cerebras
5 | from langsmith import traceable
6 | 
7 | class CerebrasModel(BaseModel):
8 |     def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None):
9 |         if not api_key:
10 |             raise ValueError("API key must be provided.")
11 |         self.api_key = api_key
12 | 
13 |         if not additional_params:
14 |             additional_params = {}
15 | 
16 |         self.model_name = self.validate_model_name(model_name)
17 | 
18 |         self.client = Cerebras(api_key=self.api_key)
19 | 
20 |     @staticmethod
21 |     def validate_model_name(model_name: str) -> str:
22 |         try:
23 |             return CerebrasModelName(model_name).value
24 |         except ValueError:
25 |             raise ValueError(f"Invalid model name. 
Allowed values are: {', '.join([m.value for m in CerebrasModelName])}") 26 | 27 | @traceable(run_type="llm") 28 | def do_completion(self, 29 | messages: List[Dict[str, str]], 30 | model_name: Optional[str] = None, 31 | max_tokens: Optional[int] = None, 32 | temperature: Optional[float] = None, 33 | top_p: Optional[float] = None, 34 | n: Optional[int] = None, 35 | stop: Optional[Union[str, List[str]]] = None, 36 | response_format: Optional[Dict[str, str]] = None) -> str: 37 | if not messages: 38 | raise ValueError("'messages' must be provided.") 39 | 40 | params = { 41 | "model": self.validate_model_name(model_name) if model_name else self.model_name, 42 | "messages": messages, 43 | } 44 | 45 | if max_tokens is not None: 46 | params["max_tokens"] = max_tokens 47 | if temperature is not None: 48 | params["temperature"] = temperature 49 | if top_p is not None: 50 | params["top_p"] = top_p 51 | if n is not None: 52 | params["n"] = n 53 | if stop is not None: 54 | params["stop"] = stop 55 | if response_format is not None: 56 | params["response_format"] = response_format 57 | 58 | response = self.client.chat.completions.create(**params) 59 | 60 | return response.choices[0].message.content 61 | -------------------------------------------------------------------------------- /common/models/enums/model_enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class ModelType(Enum): 4 | OPENAI = "openai" 5 | AZURE = "azure" 6 | GROQ = "groq" 7 | CEREBRAS = "cerebras" 8 | MISTRAL = "mistral" 9 | 10 | class OpenAIModelName(Enum): 11 | GPT_3_5_TURBO = "gpt-3.5-turbo" 12 | GPT_4 = "gpt-4" 13 | GPT_4O = "gpt-4o" 14 | CLAUDE_3_HAIKU = "claude-3-haiku" 15 | 16 | class AzureModelName(Enum): 17 | GPT_35_TURBO = "gpt-35-turbo" 18 | GPT_4 = "gpt-4" 19 | GPT_4_32K = "gpt-4-32k" 20 | GPT_4O = "gpt-4o" 21 | 22 | class GroqModelName(Enum): 23 | DISTIL_WHISPER_ENGLISH = "distil-whisper-large-v3-en" 24 | GEMMA2_9B = "gemma2-9b-it" 25 | GEMMA_7B = "gemma-7b-it" 26 | LLAMA3_GROQ_70B_TOOL_USE = "llama3-groq-70b-8192-tool-use-preview" 27 | LLAMA3_GROQ_8B_TOOL_USE = "llama3-groq-8b-8192-tool-use-preview" 28 | LLAMA_3_1_70B = "llama-3.1-70b-versatile" 29 | LLAMA_3_1_8B = "llama-3.1-8b-instant" 30 | LLAMA_GUARD_3_8B = "llama-guard-3-8b" 31 | LLAVA_1_5_7B = "llava-v1.5-7b-4096-preview" 32 | META_LLAMA3_70B = "llama3-70b-8192" 33 | META_LLAMA3_8B = "llama3-8b-8192" 34 | MIXTRAL_8X7B = "mixtral-8x7b-32768" 35 | WHISPER = "whisper-large-v3" 36 | 37 | class CerebrasModelName(Enum): 38 | LLAMA_3_1_70B = "llama3.1-70b" 39 | LLAMA_3_3_70B = "llama3.3-70b" 40 | LLAMA_3_1_8B = "llama3.1-8b" 41 | 42 | class MistralModelName(Enum): 43 | MISTRAL_LARGE_LATEST = "mistral-large-latest" 44 | 45 | class MistralAPIURL(Enum): 46 | CHAT_COMPLETIONS = "https://api.mistral.ai/v1/chat/completions" 47 | -------------------------------------------------------------------------------- /common/models/groq_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union, Any 2 | from common.models.base.base_model import BaseModel 3 | from common.models.enums.model_enums import GroqModelName 4 | from groq import Groq 5 | from langsmith import traceable 6 | 7 | class GroqModel(BaseModel): 8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None): 9 | if not api_key: 10 | raise ValueError("API key must be provided.") 11 | self.api_key = api_key 12 | 13 | if not 
additional_params: 14 | additional_params = {} 15 | 16 | self.model_name = self.validate_model_name(model_name) 17 | self.client = Groq(api_key=self.api_key) 18 | 19 | @staticmethod 20 | def validate_model_name(model_name: str) -> str: 21 | try: 22 | return GroqModelName(model_name).value 23 | except ValueError: 24 | raise ValueError(f"Invalid model name. Allowed values are: {', '.join([m.value for m in GroqModelName])}") 25 | 26 | @traceable(run_type="llm") 27 | def do_completion(self, 28 | messages: List[Dict[str, str]], 29 | model_name: Optional[str] = None, 30 | max_tokens: Optional[int] = None, 31 | temperature: Optional[float] = None, 32 | top_p: Optional[float] = None, 33 | stop: Optional[Union[str, List[str]]] = None, 34 | response_format: Optional[Dict[str, str]] = None) -> str: 35 | if not messages: 36 | raise ValueError("'messages' must be provided.") 37 | 38 | params = { 39 | "model": self.validate_model_name(model_name) if model_name else self.model_name, 40 | "messages": messages, 41 | } 42 | 43 | if max_tokens is not None: 44 | params["max_tokens"] = max_tokens 45 | if temperature is not None: 46 | params["temperature"] = temperature 47 | if top_p is not None: 48 | params["top_p"] = top_p 49 | if stop is not None: 50 | params["stop"] = stop 51 | if response_format is not None: 52 | params["response_format"] = response_format 53 | 54 | response = self.client.chat.completions.create(**params) 55 | 56 | return response.choices[0].message.content 57 | -------------------------------------------------------------------------------- /common/models/mistral_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union, Any 2 | from common.models.base.base_model import BaseModel 3 | from common.models.enums.model_enums import MistralModelName, MistralAPIURL 4 | from langsmith import traceable 5 | import requests 6 | 7 | class MistralModel(BaseModel): 8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None): 9 | if not api_key: 10 | raise ValueError("API key must be provided.") 11 | self.api_key = api_key 12 | 13 | if not additional_params: 14 | additional_params = {} 15 | 16 | self.model_name = self.validate_model_name(model_name) 17 | self.api_url = MistralAPIURL.CHAT_COMPLETIONS.value 18 | 19 | @staticmethod 20 | def validate_model_name(model_name: str) -> str: 21 | try: 22 | return MistralModelName(model_name).value 23 | except ValueError: 24 | raise ValueError(f"Invalid model name. 
Allowed values are: {', '.join([m.value for m in MistralModelName])}") 25 | 26 | @traceable(run_type="llm") 27 | def do_completion(self, 28 | messages: List[Dict[str, str]], 29 | model_name: Optional[str] = None, 30 | max_tokens: Optional[int] = None, 31 | temperature: Optional[float] = None, 32 | top_p: Optional[float] = None, 33 | stop: Optional[Union[str, List[str]]] = None, 34 | response_format: Optional[Dict[str, str]] = None) -> str: 35 | if not messages: 36 | raise ValueError("'messages' must be provided.") 37 | 38 | headers = { 39 | "Content-Type": "application/json", 40 | "Accept": "application/json", 41 | "Authorization": f"Bearer {self.api_key}" 42 | } 43 | 44 | data = { 45 | "model": self.validate_model_name(model_name) if model_name else self.model_name, 46 | "messages": messages, 47 | } 48 | 49 | if max_tokens is not None: 50 | data["max_tokens"] = max_tokens 51 | if temperature is not None: 52 | data["temperature"] = temperature 53 | if top_p is not None: 54 | data["top_p"] = top_p 55 | if stop is not None: 56 | data["stop"] = stop 57 | if response_format is not None: 58 | data["response_format"] = response_format 59 | 60 | response = requests.post(self.api_url, headers=headers, json=data) 61 | 62 | if response.status_code == 200: 63 | result = response.json() 64 | return result['choices'][0]['message']['content'] 65 | else: 66 | raise Exception(f"Error: {response.status_code}\n{response.text}") 67 | -------------------------------------------------------------------------------- /common/models/model_factory.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Dict, Any 3 | from common.models.enums.model_enums import ModelType, OpenAIModelName, AzureModelName, GroqModelName, CerebrasModelName, MistralModelName 4 | from common.models.openai_model import OpenaiModel 5 | from common.models.azure_model import AzureModel 6 | from common.models.groq_model import GroqModel 7 | from common.models.cerebras_model import CerebrasModel 8 | from common.models.mistral_model import MistralModel 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | class ModelFactory: 13 | @staticmethod 14 | def create_model(model_type: str, model_name: str, api_key: str, additional_params: Dict[str, Any] = None): 15 | if not model_type: 16 | raise ValueError("model_type must be provided.") 17 | if not model_name: 18 | raise ValueError("model_name must be provided.") 19 | if not api_key: 20 | raise ValueError("API key must be provided.") 21 | 22 | try: 23 | model_type_enum = ModelType(model_type.lower()) 24 | except ValueError: 25 | raise ValueError(f"Invalid model type. 
Allowed values are: {', '.join([m.value for m in ModelType])}") 26 | 27 | model_config = { 28 | "api_key": api_key, 29 | "model_name": model_name, 30 | "additional_params": additional_params or {} 31 | } 32 | 33 | required_params = { 34 | ModelType.AZURE: ["api_version", "azure_endpoint", "azure_deployment"], 35 | ModelType.OPENAI: [], 36 | ModelType.GROQ: [], 37 | ModelType.CEREBRAS: [], 38 | ModelType.MISTRAL: [] 39 | } 40 | 41 | missing_params = [] 42 | for param in required_params.get(model_type_enum, []): 43 | if param not in model_config["additional_params"]: 44 | missing_params.append(param) 45 | 46 | if missing_params: 47 | raise ValueError(f"Missing required additional_params for {model_type}: {', '.join(missing_params)}") 48 | if model_type_enum == ModelType.OPENAI: 49 | OpenAIModelName(model_name) 50 | model_instance = OpenaiModel(**model_config) 51 | elif model_type_enum == ModelType.AZURE: 52 | AzureModelName(model_name) 53 | model_instance = AzureModel(**model_config) 54 | elif model_type_enum == ModelType.GROQ: 55 | GroqModelName(model_name) 56 | model_instance = GroqModel(**model_config) 57 | elif model_type_enum == ModelType.CEREBRAS: 58 | CerebrasModelName(model_name) 59 | model_instance = CerebrasModel(**model_config) 60 | elif model_type_enum == ModelType.MISTRAL: 61 | MistralModelName(model_name) 62 | model_instance = MistralModel(**model_config) 63 | else: 64 | raise ValueError(f"Unsupported model type: {model_type_enum}") 65 | 66 | logger.info(f"Returning model instance of type: {model_type_enum.value}") 67 | return model_instance -------------------------------------------------------------------------------- /common/models/openai_model.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union, Any 2 | from common.models.base.base_model import BaseModel 3 | from common.models.enums.model_enums import OpenAIModelName 4 | from openai import OpenAI 5 | from langsmith import traceable 6 | 7 | class OpenaiModel(BaseModel): 8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None): 9 | if not api_key: 10 | raise ValueError("API key must be provided.") 11 | self.api_key = api_key 12 | 13 | if not additional_params: 14 | additional_params = {} 15 | 16 | self.model_name = self.validate_model_name(model_name) 17 | self.base_url = additional_params.get('base_url') 18 | 19 | client_params = {'api_key': self.api_key} 20 | if self.base_url: 21 | client_params['base_url'] = self.base_url 22 | 23 | self.client = OpenAI(**client_params) 24 | 25 | @staticmethod 26 | def validate_model_name(model_name: str) -> str: 27 | try: 28 | return OpenAIModelName(model_name).value 29 | except ValueError: 30 | raise ValueError(f"Invalid model name. 
Allowed values are: {', '.join([m.value for m in OpenAIModelName])}")
31 | 
32 |     @traceable(run_type="llm")
33 |     def do_completion(self,
34 |                       messages: List[Dict[str, str]],
35 |                       model_name: Optional[str] = None,
36 |                       max_tokens: Optional[int] = None,
37 |                       temperature: Optional[float] = None,
38 |                       top_p: Optional[float] = None,
39 |                       n: Optional[int] = None,
40 |                       stop: Optional[Union[str, List[str]]] = None,
41 |                       response_format: Optional[Dict[str, str]] = None) -> str:
42 |         if not messages:
43 |             raise ValueError("'messages' must be provided.")
44 | 
45 |         params = {
46 |             "model": self.validate_model_name(model_name) if model_name else self.model_name,
47 |             "messages": messages,
48 |         }
49 | 
50 |         if max_tokens is not None:
51 |             params["max_tokens"] = max_tokens
52 |         if temperature is not None:
53 |             params["temperature"] = temperature
54 |         if top_p is not None:
55 |             params["top_p"] = top_p
56 |         if n is not None:
57 |             params["n"] = n
58 |         if stop is not None:
59 |             params["stop"] = stop
60 |         if response_format is not None:
61 |             params["response_format"] = response_format
62 | 
63 |         response = self.client.chat.completions.create(**params)
64 | 
65 |         return response.choices[0].message.content
66 | 
--------------------------------------------------------------------------------
/common/prompts/prompt_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | 
3 | class PromptType(Enum):
4 |     EXAMPLE_GENERATION = "marly/example-generation"
5 |     EXTRACTION = "marly/extraction"
6 |     TRANSFORMATION = "marly/transformation"
7 |     TRANSFORMATION_MARKDOWN = "noahegg/tranformation-markdown"
8 |     TRANSFORMATION_WEB = "noahegg/web_scraper"
9 |     TRANSFORMATION_ONLY = "marly/transformation-only"
10 |     VALIDATION = "marly/validation"
11 |     RELEVANT_PAGE_FINDER = "marly/relevant-page-finder"
12 |     PLAN = "marly/plan"
13 |     RELEVANT_PAGE_FINDER_V2 = "marly/relevant-page-finder-with-plan"
14 | 
--------------------------------------------------------------------------------
/common/redis/redis_config.py:
--------------------------------------------------------------------------------
1 | import redis.asyncio as redis
2 | from redis.asyncio import ConnectionPool
3 | import os
4 | 
5 | class RedisClient:
6 |     def __init__(self):
7 |         self.host = os.getenv('REDIS_HOST', '127.0.0.1')
8 |         # Cast to int: environment variables always come back as strings
9 |         self.port = int(os.getenv('REDIS_PORT', 6379))
10 |         self.db = int(os.getenv('REDIS_DB', 0))
11 |         self.pool = ConnectionPool(host=self.host, port=self.port, db=self.db)
12 |         self.client = redis.Redis(connection_pool=self.pool)
13 | 
14 |     def pipeline(self):
15 |         return self.client.pipeline()
16 | 
17 | redis_client = RedisClient().client
18 | 
19 | async def get_redis_connection():
20 |     return redis_client
21 | 
--------------------------------------------------------------------------------
/common/sources/base/base_source.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Any, Optional, List
3 | from io import BytesIO
4 | 
5 | class BaseSource(ABC):
6 |     @abstractmethod
7 |     def connect(self) -> None:
8 |         """Establish a connection to the data source."""
9 |         pass
10 | 
11 |     @abstractmethod
12 |     def read(self, data: Dict[str, Any]) -> Optional[BytesIO]:
13 |         """Read a specific file from the data source."""
14 |         pass
15 | 
16 |     @abstractmethod
17 |     def read_all(self) -> List[str]:
18 |         """Retrieve a list of all valid files in the data source."""
19 |         pass
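
To illustrate how this contract is driven, a short sketch using the SourceFactory defined later in this dump (the documents_location value is illustrative):

# Illustrative usage of the BaseSource contract; LocalFSIntegration below
# implements it, and SourceFactory wires it up. The directory path is a placeholder.
from common.sources.source_factory import SourceFactory

source = SourceFactory.create_source(
    source_type="local_fs",
    documents_location="./examples/example_files",  # assumed local directory
)
for file_key in source.read_all():                  # every valid .pdf/.pptx/.docx
    stream = source.read({"file_key": file_key})    # BytesIO, or None on failure
    if stream is not None:
        print(file_key, len(stream.getbuffer()), "bytes")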
-------------------------------------------------------------------------------- /common/sources/enums/source_enums.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | class SourceType(Enum): 4 | LOCAL_FS = "local_fs" 5 | S3 = "s3" 6 | -------------------------------------------------------------------------------- /common/sources/local_fs_source.py: -------------------------------------------------------------------------------- 1 | from common.sources.base.base_source import BaseSource 2 | from typing import Dict, Optional, Any, List, Union 3 | import os 4 | from io import BytesIO 5 | import logging 6 | import mimetypes 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class LocalFSIntegration(BaseSource): 11 | ALLOWED_EXTENSIONS = {'.pdf', '.pptx', '.docx'} 12 | ALLOWED_MIMETYPES = { 13 | 'application/pdf', 14 | 'application/vnd.openxmlformats-officedocument.presentationml.presentation', 15 | 'application/vnd.openxmlformats-officedocument.wordprocessingml.document' 16 | } 17 | 18 | def __init__(self, base_path: str, additional_params: Optional[Dict[str, Any]] = None) -> None: 19 | self.base_path: str = base_path 20 | self.additional_params = additional_params or {} 21 | self.connect() 22 | 23 | def connect(self) -> None: 24 | if not os.path.exists(self.base_path): 25 | raise ValueError(f"The specified base path does not exist: {self.base_path}") 26 | 27 | @staticmethod 28 | def is_valid_file(file_path: str) -> bool: 29 | _, ext = os.path.splitext(file_path) 30 | if ext.lower() not in LocalFSIntegration.ALLOWED_EXTENSIONS: 31 | return False 32 | 33 | mime_type, _ = mimetypes.guess_type(file_path) 34 | return mime_type in LocalFSIntegration.ALLOWED_MIMETYPES 35 | 36 | def read(self, data: Dict[str, Any]) -> Optional[BytesIO]: 37 | file_key: Optional[str] = data.get('file_key') 38 | if not file_key: 39 | raise ValueError("The 'file_key' must be provided in the data dictionary.") 40 | 41 | file_path = os.path.join(self.base_path, file_key) 42 | if not os.path.exists(file_path): 43 | logger.warning(f"File {file_path} not found.") 44 | return None 45 | 46 | if not self.is_valid_file(file_path): 47 | logger.warning(f"File {file_path} is not a valid PDF, PowerPoint, or Word document.") 48 | return None 49 | 50 | try: 51 | with open(file_path, 'rb') as file: 52 | return BytesIO(file.read()) 53 | except Exception as e: 54 | logger.error(f"Error reading file {file_path}: {str(e)}") 55 | return None 56 | 57 | def read_all(self) -> List[str]: 58 | """Retrieve a list of all valid files in the base directory.""" 59 | try: 60 | files = os.listdir(self.base_path) 61 | valid_files = [ 62 | f for f in files 63 | if self.is_valid_file(os.path.join(self.base_path, f)) 64 | ] 65 | logger.info(f"Retrieved {len(valid_files)} valid files from {self.base_path}") 66 | return valid_files 67 | except Exception as e: 68 | logger.error(f"Error listing files in {self.base_path}: {e}") 69 | return [] -------------------------------------------------------------------------------- /common/sources/s3_source.py: -------------------------------------------------------------------------------- 1 | from common.sources.base.base_source import BaseSource 2 | from typing import Dict, Optional, Any, List 3 | from io import BytesIO 4 | import boto3 5 | import os 6 | import logging 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | class S3Integration(BaseSource): 11 | def __init__(self, bucket_name: str, additional_params: Optional[Dict[str, Any]] = None) 
-> None: 12 | self.bucket_name = bucket_name 13 | self.additional_params = additional_params or {} 14 | self.s3_client = boto3.client( 15 | 's3', 16 | region_name=self.additional_params.get('region_name'), 17 | aws_access_key_id=self.additional_params.get('aws_access_key_id'), 18 | aws_secret_access_key=self.additional_params.get('aws_secret_access_key'), 19 | aws_session_token=self.additional_params.get('aws_session_token') 20 | ) 21 | self.connect() 22 | 23 | def connect(self) -> None: 24 | try: 25 | self.s3_client.head_bucket(Bucket=self.bucket_name) 26 | logger.info(f"Successfully connected to S3 bucket: {self.bucket_name}") 27 | except Exception as e: 28 | logger.error(f"Failed to connect to S3 bucket {self.bucket_name}: {e}") 29 | raise 30 | 31 | def read(self, data: Dict[str, Any]) -> Optional[BytesIO]: 32 | file_key: Optional[str] = data.get('file_key') 33 | if not file_key: 34 | raise ValueError("The 'file_key' must be provided in the data dictionary.") 35 | 36 | try: 37 | response = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_key) 38 | return BytesIO(response['Body'].read()) 39 | except Exception as e: 40 | logger.error(f"Error reading file from S3: {e}") 41 | return None 42 | 43 | def read_all(self) -> List[str]: 44 | """Retrieve a list of all valid files in the S3 bucket.""" 45 | try: 46 | paginator = self.s3_client.get_paginator('list_objects_v2') 47 | page_iterator = paginator.paginate(Bucket=self.bucket_name) 48 | valid_files = [] 49 | for page in page_iterator: 50 | contents = page.get('Contents', []) 51 | for obj in contents: 52 | key = obj['Key'] 53 | if self.is_valid_file(key): 54 | valid_files.append(key) 55 | logger.info(f"Retrieved {len(valid_files)} valid files from S3 bucket {self.bucket_name}") 56 | return valid_files 57 | except Exception as e: 58 | logger.error(f"Error listing files in S3 bucket {self.bucket_name}: {e}") 59 | return [] 60 | 61 | @staticmethod 62 | def is_valid_file(file_key: str) -> bool: 63 | allowed_extensions = {'.pdf', '.pptx', '.docx'} 64 | _, ext = os.path.splitext(file_key) 65 | return ext.lower() in allowed_extensions -------------------------------------------------------------------------------- /common/sources/source_factory.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any 2 | from common.sources.enums.source_enums import SourceType 3 | from common.sources.local_fs_source import LocalFSIntegration 4 | from common.sources.s3_source import S3Integration 5 | 6 | class SourceFactory: 7 | @staticmethod 8 | def create_source(source_type: str, documents_location: str = None, additional_params: Dict[str, Any] = None): 9 | if not source_type: 10 | raise ValueError("source_type must be provided.") 11 | 12 | try: 13 | source_type_enum = SourceType(source_type.lower()) 14 | except ValueError: 15 | raise ValueError(f"Invalid source type. 
Allowed values are: {', '.join([s.value for s in SourceType])}") 16 | 17 | if source_type_enum == SourceType.LOCAL_FS: 18 | return LocalFSIntegration( 19 | base_path=documents_location, 20 | additional_params=additional_params 21 | ) 22 | elif source_type_enum == SourceType.S3: 23 | return S3Integration( 24 | bucket_name=documents_location, 25 | additional_params=additional_params 26 | ) 27 | else: 28 | raise ValueError(f"Unsupported source type: {source_type_enum}") -------------------------------------------------------------------------------- /common/text_extraction/text_extractor.py: -------------------------------------------------------------------------------- 1 | import PyPDF2 2 | from typing import List 3 | from io import BytesIO 4 | import logging 5 | from dotenv import load_dotenv 6 | from langchain.schema import SystemMessage, HumanMessage 7 | from langsmith import Client 8 | import asyncio 9 | import time 10 | from concurrent.futures import ThreadPoolExecutor 11 | from common.prompts.prompt_enums import PromptType 12 | import json 13 | from datetime import datetime 14 | from common.redis.redis_config import get_redis_connection 15 | 16 | load_dotenv() 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | def get_pdf_page_count(pdf_stream): 21 | try: 22 | pdf_reader = PyPDF2.PdfReader(pdf_stream) 23 | return len(pdf_reader.pages) 24 | except PyPDF2.errors.PdfReadError as e: 25 | raise PyPDF2.errors.PdfReadError(f"Error reading PDF file: {str(e)}") 26 | 27 | def extract_page_as_markdown(file_stream: BytesIO, page_number: int) -> str: 28 | import os 29 | from PyPDF2 import PdfReader, PdfWriter 30 | from markitdown import MarkItDown 31 | 32 | temp_file = "temp.pdf" 33 | 34 | try: 35 | if not isinstance(page_number, int) or page_number < 0: 36 | raise ValueError(f"Invalid page number: {page_number}") 37 | 38 | reader = PdfReader(file_stream) 39 | 40 | if page_number >= len(reader.pages): 41 | raise ValueError(f"Page number {page_number} exceeds document length of {len(reader.pages)} pages") 42 | 43 | writer = PdfWriter() 44 | writer.add_page(reader.pages[page_number]) 45 | 46 | try: 47 | with open(temp_file, 'wb') as output_file: 48 | writer.write(output_file) 49 | except IOError as e: 50 | raise IOError(f"Failed to write temporary PDF file: {str(e)}") 51 | 52 | try: 53 | markitdown = MarkItDown() 54 | result = markitdown.convert(temp_file) 55 | 56 | if not result or not result.text_content: 57 | logger.warning(f"No text content extracted from page {page_number}") 58 | return "" 59 | 60 | return result.text_content 61 | 62 | except Exception as e: 63 | raise Exception(f"Error converting PDF to markdown: {str(e)}") 64 | 65 | except Exception as e: 66 | logger.error(f"Error in extract_page_markdown: {str(e)}") 67 | raise 68 | 69 | finally: 70 | try: 71 | if os.path.exists(temp_file): 72 | os.remove(temp_file) 73 | except Exception as e: 74 | logger.error(f"Failed to remove temporary file {temp_file}: {str(e)}") 75 | 76 | def preprocess_messages(raw_payload): 77 | messages = [] 78 | if hasattr(raw_payload, 'to_messages'): 79 | for message in raw_payload.to_messages(): 80 | if isinstance(message, SystemMessage): 81 | messages.append({"role": "system", "content": message.content}) 82 | elif isinstance(message, HumanMessage): 83 | messages.append({"role": "user", "content": message.content}) 84 | else: 85 | logger.warning(f"Unexpected message type: {type(message)}") 86 | else: 87 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}") 88 | return messages 89 | 90 | 
def process_page(client, prompt, page_number: int, page_text: str, formatted_keywords: str) -> tuple[int, dict]:
91 |     try:
92 |         raw_payload = prompt.invoke({
93 |             "first_value": page_text,
94 |             "second_value": formatted_keywords
95 |         })
96 | 
97 |         processed_messages = preprocess_messages(raw_payload)
98 |         if processed_messages:
99 |             response = client.do_completion(processed_messages)
100 |             logger.info(f"Response for page {page_number}: {response}")
101 | 
102 |             # Evaluate the relevance check once and reuse it below
103 |             is_relevant = "yes" in response[:20].lower()
104 |             response_data = {
105 |                 'response': response,
106 |                 'keywords': formatted_keywords,
107 |                 'timestamp': datetime.now().isoformat(),
108 |                 'is_relevant': is_relevant
109 |             }
110 | 
111 |             return (page_number if is_relevant else -1), response_data
112 | 
113 |         # No messages could be built from the payload; report the page as irrelevant
114 |         # instead of falling through and returning None, which callers cannot unpack
115 |         return -1, {
116 |             'response': "",
117 |             'keywords': formatted_keywords,
118 |             'timestamp': datetime.now().isoformat(),
119 |             'is_relevant': False
120 |         }
121 | 
122 |     except Exception as e:
123 |         logger.error(f"Error processing page {page_number}: {e}")
124 |         return -1, {
125 |             'response': f"Error: {str(e)}",
126 |             'keywords': formatted_keywords,
127 |             'page_text': page_text,
128 |             'timestamp': datetime.now().isoformat(),
129 |             'is_relevant': False,
130 |             'error': True
131 |         }
132 | 
133 | async def find_common_pages(client, file_stream: BytesIO, formatted_keywords: str) -> List[int]:
134 |     try:
135 |         start_time = time.time()
136 |         pdf_reader = PyPDF2.PdfReader(file_stream)
137 |         langsmith_client = Client()
138 |         prompt = langsmith_client.pull_prompt(PromptType.RELEVANT_PAGE_FINDER_V2.value)
139 |         logger.info(f"MODEL TYPE: {type(client)}")
140 | 
141 |         run_id = datetime.now().strftime('%Y%m%d_%H%M%S')
142 | 
143 |         loop = asyncio.get_event_loop()
144 |         with ThreadPoolExecutor(max_workers=10) as executor:
145 |             tasks = []
146 |             for page_number in range(len(pdf_reader.pages)):
147 |                 page_text = extract_page_as_markdown(file_stream, page_number)
148 |                 task = loop.run_in_executor(
149 |                     executor,
150 |                     process_page,
151 |                     client,
152 |                     prompt,
153 |                     page_number,
154 |                     page_text,
155 |                     formatted_keywords
156 |                 )
157 |                 tasks.append(task)
158 | 
159 |             results = await asyncio.gather(*tasks)
160 | 
161 |         # Process results and store in Redis
162 |         page_responses = {}
163 |         relevant_pages = []
164 | 
165 |         logger.info(f"Processing {len(results)} results")
166 | 
167 |         for idx, (page_number, response_data) in enumerate(results):
168 |             # Key irrelevant pages by their position in the results list;
169 |             # list.index() would return the first equal entry and mis-key duplicates
170 |             page_responses[str(page_number if page_number != -1 else idx)] = response_data
171 | 
172 |             if page_number != -1:
173 |                 relevant_pages.append(page_number)
174 | 
175 |         logger.info(f"Collected responses for {len(page_responses)} pages")
176 | 
177 |         if page_responses:
178 |             try:
179 |                 redis_client = await get_redis_connection()
180 |                 redis_key = f"page_responses:{run_id}"
181 |                 data_to_store = json.dumps(page_responses)
182 | 
183 |                 await redis_client.set(redis_key, data_to_store)
184 |                 await redis_client.expire(redis_key, 86400)  # 1 day TTL
185 | 
186 |                 stored_data = await redis_client.get(redis_key)
187 |                 if stored_data:
188 |                     logger.info(f"Successfully stored and verified data in Redis for run_id: {run_id}")
189 |                 else:
190 |                     logger.error("Failed to verify data storage in Redis")
191 | 
192 |             except Exception as e:
193 |                 logger.error(f"Redis storage error: {str(e)}")
194 |         else:
195 |             logger.warning("No page responses to store in Redis")
196 | 
197 |         end_time = time.time()
198 |         total_time = end_time - start_time
199 |         logger.info(f"Processed {len(pdf_reader.pages)} pages in {total_time:.2f} seconds")
200 |         logger.info(f"Relevant Pages: {relevant_pages}")
201 | 
202 |         return relevant_pages
203 | 
204 |     except Exception as e:
205 |         logger.error(f"Error finding common pages: {e}")
206 |         return []
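
A sketch of how find_common_pages might be driven end to end, assuming a configured model client; the provider, model name, API key, and keyword string here are illustrative placeholders:

# Illustrative driver for find_common_pages. ModelFactory is defined in
# common/models/model_factory.py; all concrete values below are placeholders.
import asyncio
from io import BytesIO
from common.models.model_factory import ModelFactory
from common.text_extraction.text_extractor import find_common_pages

async def main() -> None:
    client = ModelFactory.create_model(
        model_type="groq",
        model_name="llama-3.1-70b-versatile",
        api_key="YOUR_API_KEY",  # placeholder
    )
    with open("examples/example_files/lacers_reduced.pdf", "rb") as f:
        stream = BytesIO(f.read())
    # Keywords are passed as a single formatted string to the relevance prompt
    pages = await find_common_pages(client, stream, "Firm, Commitment, Net IRR")
    print("Relevant pages:", pages)

asyncio.run(main())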
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 |   pipeline:
3 |     build:
4 |       context: .
5 |       dockerfile: application/pipeline/Dockerfile
6 |     ports:
7 |       - "8100:8100"
8 |     volumes:
9 |       - ${PWD}/common:/app/common
10 |       - ${PWD}/application/pipeline:/app/application/pipeline
11 |     env_file: .env
12 |     environment:
13 |       - REDIS_HOST=redis
14 |       - REDIS_PORT=6379
15 |       - PYTHONPATH=/app
16 |     depends_on:
17 |       - redis
18 |     networks:
19 |       - marly_default
20 | 
21 |   extraction:
22 |     build:
23 |       context: .
24 |       dockerfile: application/extraction/Dockerfile
25 |     volumes:
26 |       - ${PWD}/common:/app/common
27 |       - ${PWD}/application/extraction:/app/application/extraction
28 |     env_file: .env
29 |     environment:
30 |       - REDIS_HOST=redis
31 |       - REDIS_PORT=6379
32 |       - PYTHONPATH=/app
33 |     depends_on:
34 |       - redis
35 |     networks:
36 |       - marly_default
37 | 
38 |   transformation:
39 |     build:
40 |       context: .
41 |       dockerfile: application/transformation/Dockerfile
42 |     volumes:
43 |       - ${PWD}/common:/app/common
44 |       - ${PWD}/application/transformation:/app/application/transformation
45 |     env_file: .env
46 |     environment:
47 |       - REDIS_HOST=redis
48 |       - REDIS_PORT=6379
49 |       - PYTHONPATH=/app
50 |     depends_on:
51 |       - redis
52 |     networks:
53 |       - marly_default
54 | 
55 |   redis:
56 |     image: "redis:alpine"
57 |     ports:
58 |       - "6379:6379"
59 |     command: redis-server --save "" --appendonly no
60 |     networks:
61 |       - marly_default
62 | 
63 | networks:
64 |   marly_default:
65 |     name: marly_default
66 | 
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/auth/anon_helper.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from datetime import datetime, timedelta
3 | import json
4 | from dotenv import load_dotenv
5 | import os
6 | import logging
7 | 
8 | load_dotenv()
9 | 
10 | BASE_URL = "https://svc.sandbox.anon.com/actions/linkedin"
11 | ANON_USER_ID = os.environ.get("ANON_USER_ID")
12 | 
13 | def get_headers():
14 |     return {
15 |         "Authorization": f"Bearer {os.getenv('ANON_TOKEN')}"
16 |     }
17 | 
18 | def make_api_request(endpoint, method="GET", params=None):
19 |     url = f"{BASE_URL}/{endpoint}"
20 |     response = requests.request(method, url, headers=get_headers(), params=params)
21 |     return json.loads(response.text)
22 | 
23 | def get_recent_conversations(days=7):
24 |     params = {"appUserId": ANON_USER_ID}
25 |     data = make_api_request("listConversations", params=params)
26 |     current_date = datetime.now()
27 |     cutoff_date = current_date - timedelta(days=days)
28 | 
29 |     recent_profile_ids = []
30 |     for conversation in data['conversations']:
31 |         conversation_date = datetime.strptime(conversation['timestamp'], "%Y-%m-%dT%H:%M:%S.%fZ")
32 |         if conversation_date > cutoff_date:
33 |             for profile in conversation['profiles']:
34 |                 if not profile['isSelf']:
35 |                     profile_id = profile.get('id', '')
36 |                     if profile_id.startswith('profile-'):
37 |                         profile_id = profile_id[8:]  # Remove 'profile-' prefix
38 |                     if profile_id:
39 |                         recent_profile_ids.append(profile_id)
40 | 
41 |     return recent_profile_ids
42 | 
43 | def get_profile(profile_id):
44 |     params = {"id": profile_id, "appUserId": ANON_USER_ID}
45 |     return make_api_request("getProfile", params=params)
46 | 
47 | def raw_profile_data():
48 |     recent_profile_ids = get_recent_conversations()
49 |     profiles = []
50 | 
51 |     for profile_id in 
recent_profile_ids: 52 | try: 53 | if profile_id.isdigit(): 54 | profile_id = f"ACoAA{profile_id}" 55 | 56 | profile_data = get_profile(profile_id) 57 | 58 | if profile_data: 59 | profiles.append(profile_data) 60 | else: 61 | logging.warning(f"No data returned for profile ID: {profile_id}") 62 | except Exception as e: 63 | logging.error(f"Error fetching profile {profile_id}: {str(e)}") 64 | 65 | return profiles 66 | -------------------------------------------------------------------------------- /examples/ai-workers/ai-sdr/contacts.db: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/ai-workers/ai-sdr/contacts.db -------------------------------------------------------------------------------- /examples/ai-workers/ai-sdr/main.py: -------------------------------------------------------------------------------- 1 | from transformation.marly_helper import process_data 2 | from auth.anon_helper import raw_profile_data 3 | from output_source.sql_helper import SQLiteHelper 4 | 5 | import json 6 | 7 | def extract_metrics_data(results): 8 | if not results or not results[0].results: 9 | return None 10 | 11 | metrics = results[0].results[0]['metrics'] 12 | if not metrics: 13 | return None 14 | 15 | json_string = metrics[next(iter(metrics.keys()))] 16 | 17 | data = json.loads(json_string) 18 | 19 | if 'data_entries' in data: 20 | records = data['data_entries'] 21 | else: 22 | records = [data] 23 | 24 | cleaned_data = [] 25 | for record in records: 26 | cleaned_record = { 27 | 'id': record.get('id'), 28 | 'first_name': record.get('firstName'), 29 | 'last_name': record.get('lastName'), 30 | 'headline': record.get('headline'), 31 | 'location': record.get('location'), 32 | 'summary': record.get('summary'), 33 | 'connections_count': record.get('connectionsCount') 34 | } 35 | cleaned_data.append(cleaned_record) 36 | 37 | return cleaned_data 38 | 39 | if __name__ == "__main__": 40 | raw_data = raw_profile_data() 41 | processed_data = process_data(raw_data) 42 | cleaned_data = extract_metrics_data(processed_data) 43 | db = SQLiteHelper() 44 | success = db.insert_contact(cleaned_data) 45 | db.get_all_contacts() 46 | -------------------------------------------------------------------------------- /examples/ai-workers/ai-sdr/output_source/sql_helper.py: -------------------------------------------------------------------------------- 1 | import sqlite3 2 | from typing import Dict, List, Union 3 | from tabulate import tabulate 4 | import json 5 | 6 | class SQLiteHelper: 7 | def __init__(self, db_path: str = 'contacts.db'): 8 | self.db_path = db_path 9 | self.create_tables() 10 | 11 | def create_tables(self): 12 | """Create the necessary tables if they don't exist""" 13 | conn = sqlite3.connect(self.db_path) 14 | cursor = conn.cursor() 15 | 16 | cursor.execute(""" 17 | CREATE TABLE IF NOT EXISTS contacts ( 18 | id TEXT PRIMARY KEY, 19 | first_name TEXT, 20 | last_name TEXT, 21 | headline TEXT, 22 | location TEXT, 23 | summary TEXT, 24 | connections_count INTEGER 25 | ) 26 | """) 27 | 28 | conn.commit() 29 | conn.close() 30 | 31 | def insert_contact(self, contact_data: Union[Dict, List[Dict]]) -> bool: 32 | """Insert one or more contacts into the database""" 33 | conn = sqlite3.connect(self.db_path) 34 | cursor = conn.cursor() 35 | 36 | insert_sql = """ 37 | INSERT OR REPLACE INTO contacts ( 38 | id, first_name, last_name, headline, location, summary, 39 | connections_count 40 
|             ) VALUES (?, ?, ?, ?, ?, ?, ?)
41 |         """
42 | 
43 |         try:
44 |             # Handle both single dict and list of dicts
45 |             contacts_to_insert = contact_data if isinstance(contact_data, list) else [contact_data]
46 | 
47 |             for contact in contacts_to_insert:
48 |                 cursor.execute(insert_sql, (
49 |                     contact['id'],
50 |                     contact['first_name'],
51 |                     contact['last_name'],
52 |                     contact['headline'],
53 |                     contact['location'],
54 |                     contact['summary'],
55 |                     contact['connections_count'],
56 |                 ))
57 |             conn.commit()
58 |             return True
59 |         except sqlite3.Error as e:
60 |             print(f"Error inserting contact(s): {e}")
61 |             return False
62 |         finally:
63 |             conn.close()
64 | 
65 |     def get_all_contacts(self) -> List[Dict]:
66 |         """Retrieve and display all contacts from the database in a formatted table"""
67 |         conn = sqlite3.connect(self.db_path)
68 |         cursor = conn.cursor()
69 | 
70 |         cursor.execute("SELECT id, first_name, last_name, headline, location, summary, connections_count FROM contacts")
71 |         results = cursor.fetchall()
72 | 
73 |         conn.close()
74 | 
75 |         if results:
76 |             keys = ['id', 'first_name', 'last_name', 'headline', 'location',
77 |                     'summary', 'connections_count']
78 | 
79 |             print(tabulate(results,
80 |                            headers=keys,
81 |                            tablefmt='grid',
82 |                            maxcolwidths=[None, None, None, 30, 20, 40, None]))
83 | 
84 |             return [dict(zip(keys, result)) for result in results]
85 | 
86 |         print("No contacts found in database")
87 |         return []
88 | 
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/transformation/marly_helper.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dotenv import load_dotenv
5 | import os
6 | from marly import Marly
7 | 
8 | load_dotenv()
9 | 
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11 | 
12 | BASE_URL = "http://localhost:8100"
13 | 
14 | def process_data(raw_data):
15 |     schema_1 = {
16 |         "id": "Unique identifier for the profile for all contacts",
17 |         "firstName": "First name of the person for all contacts",
18 |         "lastName": "Last name of the person for all contacts",
19 |         "headline": "Professional headline or tagline for all contacts",
20 |         "location": "Geographic location of the person for all contacts",
21 |         "summary": "Brief professional summary or bio for all contacts",
22 |         "connectionsCount": "Number of LinkedIn connections for all contacts"
23 |     }
24 | 
25 |     client = Marly(base_url=BASE_URL)
26 | 
27 |     try:
28 |         pipeline_response_model = client.pipelines.create(
29 |             api_key=os.getenv("AZURE_OPENAI_API_KEY"),
30 |             provider_model_name=os.getenv("AZURE_MODEL_NAME"),
31 |             provider_type="azure",
32 |             workloads=[
33 |                 {
34 |                     "destination": "sqlite",
35 |                     "documents_location": "contacts_table",
36 |                     "schemas": [json.dumps(schema_1)],
37 |                     "raw_data": json.dumps(raw_data),
38 |                 }
39 |             ],
40 |             additional_params={
41 |                 "azure_endpoint": os.getenv("AZURE_ENDPOINT"),
42 |                 "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"),
43 |                 "api_version": os.getenv("AZURE_API_VERSION")
44 |             }
45 |         )
46 | 
47 |         logging.debug(f"Task ID: {pipeline_response_model.task_id}")
48 | 
49 |         max_attempts = 5
50 |         attempt = 0
51 |         while attempt < max_attempts:
52 |             time.sleep(30)
53 | 
54 |             results = client.pipelines.retrieve(pipeline_response_model.task_id)
55 |             logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
56 | 
57 |             if results.status == 'COMPLETED':
58 |                 logging.debug(f"Pipeline completed with results: {results.results}")
59 |                 return results.results
60 |             elif results.status == 'FAILED':
61 |                 logging.error(f"Error: {results.error_message}")
62 |                 return None
63 | 
64 |             attempt += 1
65 | 
66 |         logging.warning("Timeout: Pipeline execution took too long.")
67 |         return None
68 | 
69 |     except Exception as e:
70 |         logging.error(f"Error in pipeline process: {e}")
71 |         return None
72 | 
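
For context, a sketch of calling process_data directly; the profile fields below are fabricated sample data in the shape that anon_helper.raw_profile_data returns:

# Illustrative call into process_data; the sample profile is made up.
from transformation.marly_helper import process_data

sample_profiles = [{
    "id": "ACoAA0001",          # placeholder LinkedIn profile id
    "firstName": "Ada",
    "lastName": "Lovelace",
    "headline": "Engineer",
    "location": "London",
    "summary": "Works on engines.",
    "connectionsCount": 500,
}]

results = process_data(sample_profiles)  # polls until COMPLETED, FAILED, or timeout
print(results)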
--------------------------------------------------------------------------------
/examples/example_files/lacers.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/example_files/lacers.pdf
--------------------------------------------------------------------------------
/examples/example_files/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/example_files/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/notebooks/autogen_example/OAI_CONFIG_LIST.json:
--------------------------------------------------------------------------------
1 | [
2 |     {
3 |         "model": "llama-3.1-70b-versatile",
4 |         "api_key": "",
5 |         "api_type": "groq",
6 |         "tags": ["tool", "llama-3.1-70b-versatile"]
7 |     }
8 | ]
--------------------------------------------------------------------------------
/examples/notebooks/autogen_example/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/autogen_example/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/notebooks/autogen_example/plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/autogen_example/plot.png
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/diagram.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/diagram.jpeg
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/notebooklm_workflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/notebooklm_workflow.jpeg
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/wkflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/wkflow.jpeg
--------------------------------------------------------------------------------
/examples/scripts/api_example.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import zlib
4 | import logging
5 | import time
6 | import requests
7 | 
8 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
9 | 
10 | PDF_FILE = ""
11 | API_KEY = ""
12 | BASE_URL = "https://api.marly.ai"
13 | 
14 | def read_and_encode_pdf(file_path):
15 |     with open(file_path, "rb") as file:
16 |         pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8')
17 |     logging.debug(f"{file_path} read and encoded")
18 |     return pdf_content
19 | 
20 | def get_pipeline_results(task_id):
21 |     logging.debug(f"Fetching results for task ID: {task_id}")
22 |     response = requests.get(f"{BASE_URL}/pipelines/{task_id}", headers={"marly-api-key": API_KEY})
23 |     return response.json()
24 | 
25 | def process_pdf(pdf_file):
26 |     pdf_content = read_and_encode_pdf(pdf_file)
27 | 
28 |     schema_1 = {
29 |         "Type of change": "Description of the type of change",
30 |         "Location": "Location of the change",
31 |         "Item": "Description of the item"
32 |     }
33 | 
34 |     pipeline_request = {
35 |         "license_key": "1234567890",
36 |         "workloads": [
37 |             {
38 |                 "pdf_stream": pdf_content,
39 |                 "schemas": [json.dumps(schema_1)]
40 |             }
41 |         ]
42 |     }
43 | 
44 |     logging.debug("Sending POST request to pipeline endpoint")
45 |     response = requests.post(f"{BASE_URL}/pipelines", json=pipeline_request, headers={"marly-api-key": API_KEY})
46 | 
47 |     logging.debug(f"Response status code: {response.status_code}")
48 |     logging.debug(f"Response headers: {response.headers}")
49 |     logging.debug(f"Response content: {response.text}")
50 | 
51 |     result = response.json()
52 |     task_id = result.get("task_id")
53 |     if not task_id:
54 |         raise ValueError("Invalid task_id: task_id is None or empty")
55 |     logging.debug(f"Task ID: {task_id}")
56 | 
57 |     max_attempts = 5
58 |     attempt = 0
59 |     while attempt < max_attempts:
60 |         time.sleep(35)
61 | 
62 |         results = get_pipeline_results(task_id)
63 |         logging.debug(f"Poll attempt {attempt + 1}: Status - {results['status']}")
64 | 
65 |         if results['status'] == 'Completed':
66 |             return results
67 |         elif results['status'] == 'Failed':
68 |             logging.error(f"Error: {results.get('error_message', 'Unknown error')}")
69 |             return None
70 | 
71 |         attempt += 1
72 | 
73 |     logging.warning("Timeout: Pipeline execution took too long.")
74 |     return None
75 | 
76 | def main():
77 |     results = process_pdf(PDF_FILE)
78 | 
79 |     if results:
80 |         print("Raw API Response:")
81 |         print(json.dumps(results, indent=2))
82 | 
83 |         if 'results' in results and results['results']:
84 |             for i, result in enumerate(results['results']):
85 |                 print(f"\nResult {i + 1}:")
86 |                 metrics = result.get('metrics', {})
87 |                 print("Metrics:")
88 |                 print(json.dumps(metrics, indent=2))
89 | 
90 |                 schema_0 = metrics.get('schema_0', '{}')
91 |                 print("\nSchema 0:")
92 |                 print(schema_0)
93 | 
94 |                 try:
95 |                     schema_0_json = json.loads(schema_0)
96 |                     print("\nSchema 0 (parsed):")
97 |                     print(json.dumps(schema_0_json, indent=2))
98 | 
99 |                     for key, value in schema_0_json.items():
100 |                         print(f"\n{key}:")
101 |                         print(value)
102 | 
103 |                 except json.JSONDecodeError:
104 |                     print("\nFailed to parse schema_0 as JSON")
105 |         else:
106 |             print("No results found in the API 
response") 107 | else: 108 | print("Failed to process PDF. Please check the logs for more information.") 109 | 110 | if __name__ == "__main__": 111 | main() 112 | -------------------------------------------------------------------------------- /examples/scripts/azure_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from dotenv import load_dotenv 5 | import os 6 | from marly import Marly 7 | 8 | load_dotenv() 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | BASE_URL = "http://localhost:8100" 13 | 14 | def process_pdf(): 15 | schema_1 = { 16 | "Firm": "The name of the firm", 17 | "Number of Funds": "The number of funds managed by the firm", 18 | "Commitment": "The commitment amount in millions of dollars", 19 | "Percent of Total Comm": "The percentage of total commitment", 20 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 21 | "Percent of Total Exposure": "The percentage of total exposure", 22 | "TVPI": "Total Value to Paid-In multiple", 23 | "Net IRR": "Net Internal Rate of Return as a percentage" 24 | } 25 | 26 | client = Marly(base_url=BASE_URL) 27 | 28 | try: 29 | pipeline_response_model = client.pipelines.create( 30 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 31 | provider_model_name=os.getenv("AZURE_MODEL_NAME"), 32 | provider_type="azure", 33 | workloads=[ 34 | { 35 | "file_name": "lacers reduced", 36 | "data_source": "local_fs", 37 | "documents_location": "/app/example_files", 38 | "schemas": [json.dumps(schema_1)], 39 | } 40 | ], 41 | additional_params={ 42 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"), 43 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"), 44 | "api_version": os.getenv("AZURE_API_VERSION") 45 | } 46 | ) 47 | 48 | logging.debug(f"Task ID: {pipeline_response_model.task_id}") 49 | 50 | max_attempts = 5 51 | attempt = 0 52 | while attempt < max_attempts: 53 | time.sleep(30) 54 | 55 | results = client.pipelines.retrieve(pipeline_response_model.task_id) 56 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}") 57 | 58 | if results.status == 'COMPLETED': 59 | logging.debug(f"Pipeline completed with results: {results.results}") 60 | return results.results 61 | elif results.status == 'FAILED': 62 | logging.error(f"Error: {results.error_message}") 63 | return None 64 | 65 | attempt += 1 66 | 67 | logging.warning("Timeout: Pipeline execution took too long.") 68 | return None 69 | 70 | except Exception as e: 71 | logging.error(f"Error in pipeline process: {e}") 72 | return None 73 | 74 | if __name__ == "__main__": 75 | results = process_pdf() 76 | print(results) -------------------------------------------------------------------------------- /examples/scripts/cerebras_example.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import zlib 4 | import logging 5 | import os 6 | from dotenv import load_dotenv 7 | from marly import Marly 8 | import time 9 | 10 | load_dotenv() 11 | 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 13 | 14 | PDF_FILE_PATH = "../example_files/lacers_reduced.pdf" 15 | BASE_URL = "http://localhost:8100" 16 | 17 | SCHEMA_1 = { 18 | "Firm": "The name of the firm", 19 | "Number of Funds": "The number of funds managed by the firm", 20 | "Commitment": "The commitment amount in 
millions of dollars", 21 | "Percent of Total Comm": "The percentage of total commitment", 22 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 23 | "Percent of Total Exposure": "The percentage of total exposure", 24 | "TVPI": "Total Value to Paid-In multiple", 25 | "Net IRR": "Net Internal Rate of Return as a percentage" 26 | } 27 | 28 | def read_and_encode_pdf(file_path): 29 | with open(file_path, "rb") as file: 30 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8') 31 | logging.debug(f"{file_path} read and encoded") 32 | return pdf_content 33 | 34 | def process_pdf(pdf_file): 35 | pdf_content = read_and_encode_pdf(pdf_file) 36 | 37 | client = Marly(base_url="http://localhost:8100") 38 | 39 | try: 40 | pipeline_response = client.pipelines.create( 41 | api_key=os.getenv("CEREBRAS_API_KEY"), 42 | provider_model_name="llama3.3-70b", 43 | provider_type="cerebras", 44 | workloads=[{"raw_data": pdf_content, "schemas": [json.dumps(SCHEMA_1)]}], 45 | ) 46 | 47 | while True: 48 | results = client.pipelines.retrieve(pipeline_response.task_id) 49 | if results.status == 'COMPLETED': 50 | return results.results 51 | elif results.status == 'FAILED': 52 | return None 53 | time.sleep(15) 54 | 55 | except Exception as e: 56 | logging.error(f"Error in pipeline process: {e}") 57 | return None 58 | 59 | # Usage 60 | if __name__ == "__main__": 61 | result = process_pdf(PDF_FILE_PATH) 62 | print(result) 63 | -------------------------------------------------------------------------------- /examples/scripts/data_loading_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from dotenv import load_dotenv 5 | import os 6 | from marly import Marly 7 | 8 | load_dotenv() 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | BASE_URL = "http://localhost:8100" 13 | 14 | def process_pdf(): 15 | schema_1 = { 16 | "Firm": "The name of the firm", 17 | "Number of Funds": "The number of funds managed by the firm", 18 | "Commitment": "The commitment amount in millions of dollars", 19 | "Percent of Total Comm": "The percentage of total commitment", 20 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 21 | "Percent of Total Exposure": "The percentage of total exposure", 22 | "TVPI": "Total Value to Paid-In multiple", 23 | "Net IRR": "Net Internal Rate of Return as a percentage" 24 | } 25 | raw_data = json.dumps({ 26 | "Firm": "Evergreen Capital Partners", 27 | "Number of Funds": 7, 28 | "Commitment": 150.5, 29 | "Percent of Total Comm": 12.3, 30 | "Exposure (FMV + Unfunded)": 175.2, 31 | "Percent of Total Exposure": 14.8, 32 | "TVPI": 1.45, 33 | "Net IRR": 18.7, 34 | "Founded Year": 2005, 35 | "Headquarters": "New York, NY", 36 | "AUM": 3500.0, 37 | "Investment Strategy": "Growth Equity", 38 | "Target Industries": ["Technology", "Healthcare", "Consumer"], 39 | "Managing Partners": ["John Smith", "Sarah Johnson"], 40 | "Last Fund Closing Date": "2022-09-15", 41 | "ESG Focus": True 42 | }) 43 | 44 | client = Marly(base_url=BASE_URL) 45 | 46 | try: 47 | pipeline_response_model = client.pipelines.create( 48 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 49 | provider_model_name=os.getenv("AZURE_MODEL_NAME"), 50 | provider_type="azure", 51 | workloads=[ 52 | { 53 | "destination": "sqlite", 54 | "documents_location": 
"contacts_table", 55 | "schemas": [json.dumps(schema_1)], 56 | "raw_data": raw_data, 57 | } 58 | ], 59 | additional_params={ 60 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"), 61 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"), 62 | "api_version": os.getenv("AZURE_API_VERSION") 63 | } 64 | ) 65 | 66 | logging.debug(f"Task ID: {pipeline_response_model.task_id}") 67 | 68 | max_attempts = 5 69 | attempt = 0 70 | while attempt < max_attempts: 71 | time.sleep(5) 72 | 73 | results = client.pipelines.retrieve(pipeline_response_model.task_id) 74 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}") 75 | 76 | if results.status == 'COMPLETED': 77 | logging.debug(f"Pipeline completed with results: {results.results}") 78 | return results.results 79 | elif results.status == 'FAILED': 80 | logging.error(f"Error: {results.error_message}") 81 | return None 82 | 83 | attempt += 1 84 | 85 | logging.warning("Timeout: Pipeline execution took too long.") 86 | return None 87 | 88 | except Exception as e: 89 | logging.error(f"Error in pipeline process: {e}") 90 | return None 91 | 92 | if __name__ == "__main__": 93 | results = process_pdf() 94 | print(results) 95 | -------------------------------------------------------------------------------- /examples/scripts/groq_example.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import zlib 4 | import logging 5 | import requests 6 | from dotenv import load_dotenv 7 | import os 8 | import time 9 | load_dotenv() 10 | 11 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') 12 | 13 | BASE_URL = "http://localhost:8100" 14 | PDF_FILE_PATH = "../example_files/lacers_reduced.pdf" 15 | 16 | def read_and_encode_pdf(file_path): 17 | with open(file_path, "rb") as file: 18 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8') 19 | logging.debug(f"{file_path} read and encoded") 20 | return pdf_content 21 | 22 | def get_pipeline_results(task_id): 23 | logging.debug(f"Fetching results for task ID: {task_id}") 24 | response = requests.get(f"{BASE_URL}/pipelines/{task_id}") 25 | return response.json() 26 | 27 | def process_pdf(pdf_file): 28 | pdf_content = read_and_encode_pdf(pdf_file) 29 | 30 | schema_1 = { 31 | "Firm": "The name of the firm", 32 | "Number of Funds": "The number of funds managed by the firm", 33 | "Commitment": "The commitment amount in millions of dollars", 34 | "Percent of Total Comm": "The percentage of total commitment", 35 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 36 | "Percent of Total Exposure": "The percentage of total exposure", 37 | "TVPI": "Total Value to Paid-In multiple", 38 | "Net IRR": "Net Internal Rate of Return as a percentage" 39 | } 40 | 41 | pipeline_request = { 42 | "workloads": [ 43 | { 44 | "raw_data": pdf_content, 45 | "schemas": [json.dumps(schema_1)] 46 | } 47 | ], 48 | "provider_type": "groq", 49 | "provider_model_name": "llama-3.1-70b-versatile", 50 | "api_key": os.getenv("GROQ_API_KEY"), 51 | "additional_params": {} 52 | } 53 | 54 | logging.debug("Sending POST request to pipeline endpoint") 55 | try: 56 | response = requests.post(f"{BASE_URL}/pipelines", json=pipeline_request) 57 | response.raise_for_status() 58 | except requests.exceptions.RequestException as e: 59 | logging.error(f"Error sending request: {e}") 60 | return 61 | 62 | logging.debug(f"Response status code: {response.status_code}") 
63 | logging.debug(f"Response headers: {response.headers}") 64 | logging.debug(f"Response content: {response.text}") 65 | 66 | try: 67 | result = response.json() 68 | except json.JSONDecodeError: 69 | logging.error("Failed to decode JSON response") 70 | return 71 | 72 | task_id = result.get("task_id") 73 | if not task_id: 74 | logging.error("Invalid task_id: task_id is None or empty") 75 | return 76 | logging.debug(f"Task ID: {task_id}") 77 | 78 | max_attempts = 5 79 | attempt = 0 80 | while attempt < max_attempts: 81 | logging.debug(f"Waiting for pipeline to complete. Attempt {attempt + 1} of {max_attempts}") 82 | time.sleep(30) 83 | 84 | results = get_pipeline_results(task_id) 85 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results['status']}") 86 | 87 | if results['status'] == 'COMPLETED': 88 | logging.debug(f"Pipeline completed with results: {results['results']}") 89 | return results['results'] 90 | elif results['status'] == 'FAILED': 91 | logging.error(f"Error: {results.get('error_message', 'Unknown error')}") 92 | return None 93 | 94 | attempt += 1 95 | 96 | logging.warning("Timeout: Pipeline execution took too long.") 97 | return None 98 | 99 | if __name__ == "__main__": 100 | process_pdf(PDF_FILE_PATH) 101 | -------------------------------------------------------------------------------- /examples/scripts/markdown_example.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import zlib 4 | import logging 5 | import os 6 | import time 7 | from dotenv import load_dotenv 8 | from marly import Marly 9 | # This script is the same as cerebras_example.py but returns data in Markdown instead of JSON 10 | # To extract data in Markdown, you need to add markdown_mode to the POST pipelines request, as seen 11 | # on line 47 12 | load_dotenv() 13 | 14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 15 | 16 | PDF_FILE_PATH = "../example_files/lacers_reduced.pdf" 17 | 18 | def read_and_encode_pdf(file_path): 19 | with open(file_path, "rb") as file: 20 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8') 21 | logging.debug(f"{file_path} read and encoded") 22 | return pdf_content 23 | 24 | def process_pdf(pdf_file): 25 | pdf_content = read_and_encode_pdf(pdf_file) 26 | 27 | schema_1 = { 28 | "Firm": "The name of the firm", 29 | "Number of Funds": "The number of funds managed by the firm", 30 | "Commitment": "The commitment amount in millions of dollars", 31 | "Percent of Total Comm": "The percentage of total commitment", 32 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 33 | "Percent of Total Exposure": "The percentage of total exposure", 34 | "TVPI": "Total Value to Paid-In multiple", 35 | "Net IRR": "Net Internal Rate of Return as a percentage" 36 | } 37 | 38 | # running locally 39 | # the pipeline service listens on port 8100 by default 40 | client = Marly(base_url="http://localhost:8100") 41 | 42 | try: 43 | pipeline_response_model = client.pipelines.create( 44 | api_key=os.getenv("CEREBRAS_API_KEY"), 45 | provider_model_name="llama3.1-70b", 46 | provider_type="cerebras", 47 | markdown_mode=True, 48 | workloads=[ 49 | { 50 | "raw_data": pdf_content, 51 | "schemas": [json.dumps(schema_1)], 52 | } 53 | ], 54 | ) 55 | logging.debug(f"Task ID: {pipeline_response_model.task_id}") 56 | except Exception as e: 57 | logging.error(f"Error creating pipeline: {e}") 58 | return None 59 | 60 | # 
Quick check for cached results 61 | time.sleep(5) 62 | try: 63 | results = client.pipelines.retrieve(pipeline_response_model.task_id) 64 | if results.status == 'COMPLETED': 65 | parsed_results = [results.results[0].metrics[f'schema_{i}'] for i in range(len(results.results[0].metrics))] 66 | logging.info(f"Cached results found and returned {parsed_results}") 67 | return parsed_results  # Markdown strings, so no json.loads()/json.dumps() here 68 | except Exception as e: 69 | logging.debug(f"No cached results available: {e}") 70 | 71 | max_attempts = 5 72 | attempt = 0 73 | while attempt < max_attempts: 74 | logging.debug(f"Waiting for pipeline to complete. Attempt {attempt + 1} of {max_attempts}") 75 | time.sleep(30) 76 | 77 | try: 78 | results = client.pipelines.retrieve(pipeline_response_model.task_id) 79 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}") 80 | 81 | if results.status == 'COMPLETED': 82 | parsed_results = [results.results[0].metrics[f'schema_{i}'] for i in range(len(results.results[0].metrics))] 83 | logging.info(f"Results: {parsed_results}") 84 | return parsed_results # No need to json.dumps() here if we're returning Markdown 85 | elif results.status == 'FAILED': 86 | logging.error(f"Error: {results.error_message or 'Unknown error'}") 87 | return None 88 | 89 | except Exception as e: 90 | logging.error(f"Error fetching pipeline results: {e}") 91 | return None 92 | 93 | attempt += 1 94 | 95 | logging.warning("Timeout: Pipeline execution took too long.") 96 | return None 97 | 98 | if __name__ == "__main__": 99 | process_pdf(PDF_FILE_PATH) 100 | -------------------------------------------------------------------------------- /examples/scripts/mistral_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from dotenv import load_dotenv 5 | import os 6 | from marly import Marly 7 | 8 | load_dotenv() 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 11 | 12 | BASE_URL = "http://localhost:8100" 13 | 14 | def process_pdf(): 15 | schema_1 = { 16 | "Firm": "The name of the firm", 17 | "Number of Funds": "The number of funds managed by the firm", 18 | "Commitment": "The commitment amount in millions of dollars", 19 | "Percent of Total Comm": "The percentage of total commitment", 20 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 21 | "Percent of Total Exposure": "The percentage of total exposure", 22 | "TVPI": "Total Value to Paid-In multiple", 23 | "Net IRR": "Net Internal Rate of Return as a percentage" 24 | } 25 | 26 | client = Marly(base_url=BASE_URL) 27 | 28 | try: 29 | pipeline_response_model = client.pipelines.create( 30 | api_key=os.getenv("MISTRAL_API_KEY"), 31 | provider_model_name="mistral-large-latest", 32 | provider_type="mistral", 33 | workloads=[ 34 | { 35 | "file_name": "lacers reduced", 36 | "data_source": "local_fs", 37 | "documents_location": "/app/example_files", 38 | "schemas": [json.dumps(schema_1)], 39 | } 40 | ] 41 | ) 42 | 43 | logging.debug(f"Task ID: {pipeline_response_model.task_id}") 44 | 45 | max_attempts = 5 46 | attempt = 0 47 | while attempt < max_attempts: 48 | time.sleep(30) 49 | 50 | results = client.pipelines.retrieve(pipeline_response_model.task_id) 51 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}") 52 | 53 | if results.status == 'COMPLETED': 54 | logging.debug(f"Pipeline completed with results: 
{results.results}") 55 | return results.results 56 | elif results.status == 'FAILED': 57 | logging.error(f"Error: {results.error_message}") 58 | return None 59 | 60 | attempt += 1 61 | 62 | logging.warning("Timeout: Pipeline execution took too long.") 63 | return None 64 | 65 | except Exception as e: 66 | logging.error(f"Error in pipeline process: {e}") 67 | return None 68 | 69 | if __name__ == "__main__": 70 | results = process_pdf() 71 | print(results) -------------------------------------------------------------------------------- /examples/scripts/non_marly_examples/lacers_reduced.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/scripts/non_marly_examples/lacers_reduced.pdf -------------------------------------------------------------------------------- /examples/scripts/non_marly_examples/llamaindex_pinecone.py: -------------------------------------------------------------------------------- 1 | from llama_index.readers.web import SimpleWebPageReader 2 | from llama_index.core import VectorStoreIndex 3 | from llama_index.vector_stores.pinecone import PineconeVectorStore 4 | from llama_index.core.storage.storage_context import StorageContext 5 | from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor 6 | from llama_index.core.settings import Settings 7 | from llama_index.core import SimpleDirectoryReader 8 | from dotenv import load_dotenv 9 | from pinecone import Pinecone, ServerlessSpec 10 | from llama_index.llms.azure_openai import AzureOpenAI 11 | from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding 12 | 13 | import os 14 | 15 | load_dotenv() 16 | 17 | SCHEMA = { 18 | "Names": "The first name and last name of the company founders", 19 | "Company Name": "Name of the Company", 20 | "Round": "The round of funding", 21 | "Round Size": "How much money has the company raised", 22 | "Investors": "The names of the investors in the companies (names of investors and firms)", 23 | "Company Valuation": "The current valuation of the company", 24 | "Summary": "Three sentence summary of the company" 25 | } 26 | 27 | SCHEMA_2 = { 28 | "Firm": "The name of the firm", 29 | "Number of Funds": "The number of funds managed by the firm", 30 | "Commitment": "The commitment amount in millions of dollars", 31 | "Percent of Total Comm": "The percentage of total commitment", 32 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 33 | "Percent of Total Exposure": "The percentage of total exposure", 34 | "TVPI": "Total Value to Paid-In multiple", 35 | "Net IRR": "Net Internal Rate of Return as a percentage" 36 | } 37 | 38 | class LlamaIndexAzure: 39 | def __init__(self): 40 | self.llm = AzureOpenAI( 41 | model=os.environ["AZURE_MODEL_NAME"], 42 | deployment_name=os.environ["AZURE_DEPLOYMENT_ID"], 43 | api_key=os.environ["AZURE_OPENAI_API_KEY"], 44 | azure_endpoint=os.environ["AZURE_ENDPOINT"], 45 | api_version=os.environ["AZURE_API_VERSION"], 46 | ) 47 | 48 | self.embed_model = AzureOpenAIEmbedding( 49 | model=os.environ["EMBED_MODEL_NAME"], 50 | deployment_name=os.environ["EMBED_DEPLOYMENT_NAME"], 51 | api_key=os.environ["AZURE_OPENAI_API_KEY"], 52 | azure_endpoint=os.environ["EMBED_AZURE_ENDPOINT"], 53 | api_version=os.environ["EMBED_API_VERSION"], 54 | ) 55 | Settings.llm = self.llm 56 | Settings.embed_model = 
self.embed_model 57 | 58 | azure_models = LlamaIndexAzure()  # instantiate once; __init__ already configures Settings 59 | llamaindex_azure_llm = azure_models.llm 60 | llamaindex_azure_embed = azure_models.embed_model 61 | # Settings.llm and Settings.embed_model were set in LlamaIndexAzure.__init__ above 62 | 63 | def ask_question(question): 64 | index = pinecone_setup() 65 | query_engine = index.as_query_engine( 66 | similarity_top_k=9, 67 | node_postprocessors=[ 68 | MetadataReplacementPostProcessor(target_metadata_key="window") 69 | ], 70 | ) 71 | 72 | return query_engine.query(question) 73 | 74 | def pinecone_setup(): 75 | print("Setting up Pinecone...") 76 | pc = Pinecone( 77 | api_key=os.environ.get("PINECONE_API_KEY"), 78 | environment="gcp-starter" 79 | ) 80 | index_name = "example" 81 | if index_name not in pc.list_indexes().names(): 82 | print("Creating new Pinecone index") 83 | load_docs_into_pinecone(pc, index_name=index_name) 84 | print("Pinecone setup completed and index created") 85 | else: 86 | print("Index already in the environment") 87 | pinecone_index = pc.Index(index_name) 88 | vector_store = PineconeVectorStore(pinecone_index=pinecone_index) 89 | index = VectorStoreIndex.from_vector_store(vector_store=vector_store) 90 | 91 | return index 92 | 93 | def load_docs_into_pinecone(pc, index_name): 94 | WEBSITE_URL = "https://techcrunch.com/2013/02/08/snapchat-raises-13-5m-series-a-led-by-benchmark-now-sees-60m-snaps-sent-per-day/" 95 | WEBSITE_URL2 = "https://techcrunch.com/2024/08/09/anysphere-a-github-copilot-rival-has-raised-60m-series-a-at-400m-valuation-from-a16z-thrive-sources-say/" 96 | 97 | 98 | documents = SimpleWebPageReader().load_data( 99 | [WEBSITE_URL, WEBSITE_URL2] 100 | ) 101 | 102 | print(f"Loaded {len(documents)} document(s)") 103 | 104 | reader = SimpleDirectoryReader(input_files=["./lacers_reduced.pdf"]) 105 | documents2 = reader.load_data() 106 | 107 | print(f"Loaded {len(documents2)} document(s)") 108 | 109 | pc.create_index( 110 | name=index_name, 111 | dimension=1536, 112 | metric='euclidean', 113 | spec=ServerlessSpec( 114 | cloud="aws", 115 | region="us-east-1", 116 | ) 117 | ) 118 | 119 | 120 | pc = Pinecone( 121 | api_key=os.environ.get("PINECONE_API_KEY"), 122 | environment="gcp-starter" 123 | ) 124 | 125 | pinecone_index = pc.Index(index_name) 126 | 127 | vector_store = PineconeVectorStore(pinecone_index=pinecone_index) 128 | storage_context = StorageContext.from_defaults(vector_store=vector_store) 129 | VectorStoreIndex.from_documents(documents, storage_context=storage_context) 130 | VectorStoreIndex.from_documents(documents2, storage_context=storage_context) 131 | 132 | if __name__ == "__main__": 133 | question1 = f"Extract values for {SCHEMA} for every company and extract {SCHEMA_2} for the 10 Largest Sponsor Relationships then return the result as JSON. Please include all the data!" 
134 | answer = ask_question(question1) 135 | print("Answer:", answer) -------------------------------------------------------------------------------- /examples/scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | certifi==2024.8.30 2 | charset-normalizer==3.3.2 3 | idna==3.10 4 | python-dotenv==1.0.1 5 | requests==2.32.3 6 | urllib3==2.2.3 7 | -------------------------------------------------------------------------------- /examples/scripts/s3_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from dotenv import load_dotenv 5 | from marly import Marly 6 | import time 7 | 8 | 9 | load_dotenv() 10 | 11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 12 | 13 | S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") 14 | S3_FILE_KEY = os.getenv("S3_FILE_KEY") 15 | BASE_URL = "http://localhost:8100" 16 | 17 | SCHEMA_1 = { 18 | "Firm": "The name of the firm", 19 | "Number of Funds": "The number of funds managed by the firm", 20 | "Commitment": "The commitment amount in millions of dollars", 21 | "Percent of Total Comm": "The percentage of total commitment", 22 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 23 | "Percent of Total Exposure": "The percentage of total exposure", 24 | "TVPI": "Total Value to Paid-In multiple", 25 | "Net IRR": "Net Internal Rate of Return as a percentage" 26 | } 27 | 28 | def process_pdf(): 29 | client = Marly(base_url=BASE_URL) 30 | 31 | try: 32 | pipeline_response = client.pipelines.create( 33 | api_key=os.getenv("GROQ_API_KEY"), 34 | provider_model_name="llama-3.1-70b-versatile", 35 | provider_type="groq", 36 | workloads=[ 37 | { 38 | "file_name": S3_FILE_KEY, 39 | "data_source": "s3", 40 | "documents_location": S3_BUCKET_NAME, 41 | "schemas": [json.dumps(SCHEMA_1)] 42 | } 43 | ], 44 | additional_params={ 45 | "bucket_name": S3_BUCKET_NAME, 46 | "region_name": "us-east-1", 47 | "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"), 48 | "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") 49 | } 50 | ) 51 | logging.debug(f"Task ID: {pipeline_response.task_id}") 52 | 53 | max_attempts = 5 54 | attempt = 0 55 | while attempt < max_attempts: 56 | time.sleep(10) 57 | 58 | results = client.pipelines.retrieve(pipeline_response.task_id) 59 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}") 60 | 61 | if results.status == 'COMPLETED': 62 | logging.debug(f"Pipeline completed with results: {results.results}") 63 | return results.results 64 | elif results.status == 'FAILED': 65 | logging.error(f"Error: {results.error_message}") 66 | return None 67 | 68 | attempt += 1 69 | 70 | logging.warning("Timeout: Pipeline execution took too long.") 71 | return None 72 | 73 | except Exception as e: 74 | logging.error(f"Error in pipeline process: {e}") 75 | return None 76 | 77 | if __name__ == "__main__": 78 | results = process_pdf() 79 | print(results) 80 | -------------------------------------------------------------------------------- /examples/scripts/web_and_document_example.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from dotenv import load_dotenv 5 | import os 6 | from marly import Marly 7 | 8 | load_dotenv() 9 | 10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - 
%(levelname)s - %(message)s') 11 | 12 | BASE_URL = "http://localhost:8100" 13 | WEBSITE_URL = "https://techcrunch.com/2013/02/08/snapchat-raises-13-5m-series-a-led-by-benchmark-now-sees-60m-snaps-sent-per-day/" 14 | WEBSITE_URL2 = "https://techcrunch.com/2024/08/09/anysphere-a-github-copilot-rival-has-raised-60m-series-a-at-400m-valuation-from-a16z-thrive-sources-say/" 15 | 16 | SCHEMA = { 17 | "Names": "The first name and last name of the company founders", 18 | "Company Name": "Name of the Company", 19 | "Round": "The round of funding", 20 | "Round Size": "How much money has the company raised", 21 | "Investors": "The names of the investors in the companies (names of investors and firms)", 22 | "Company Valuation": "The current valuation of the company", 23 | "Summary": "Three sentence summary of the company" 24 | } 25 | 26 | SCHEMA_2 = { 27 | "Firm": "The name of the firm", 28 | "Number of Funds": "The number of funds managed by the firm", 29 | "Commitment": "The commitment amount in millions of dollars", 30 | "Percent of Total Comm": "The percentage of total commitment", 31 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars", 32 | "Percent of Total Exposure": "The percentage of total exposure", 33 | "TVPI": "Total Value to Paid-In multiple", 34 | "Net IRR": "Net Internal Rate of Return as a percentage" 35 | } 36 | 37 | def process_website(): 38 | client = Marly(base_url=BASE_URL) 39 | 40 | try: 41 | pipeline_response = client.pipelines.create( 42 | api_key=os.getenv("AZURE_OPENAI_API_KEY"), 43 | provider_model_name=os.getenv("AZURE_MODEL_NAME"), 44 | provider_type="azure", 45 | workloads=[{ 46 | "data_source": "web", 47 | "documents_location": WEBSITE_URL, 48 | "schemas": [json.dumps(SCHEMA)], 49 | }, 50 | { 51 | "data_source": "web", 52 | "documents_location": WEBSITE_URL2, 53 | "schemas": [json.dumps(SCHEMA)], 54 | }, 55 | { 56 | "file_name": "lacers reduced", 57 | "data_source": "local_fs", 58 | "documents_location": "/app/example_files", 59 | "schemas": [json.dumps(SCHEMA_2)], 60 | } 61 | ], 62 | additional_params={ 63 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"), 64 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"), 65 | "api_version": os.getenv("AZURE_API_VERSION") 66 | } 67 | ) 68 | 69 | while True: 70 | results = client.pipelines.retrieve(pipeline_response.task_id) 71 | if results.status == 'COMPLETED': 72 | processed_results = [] 73 | for result in results.results: 74 | for schema_result in result.results: 75 | processed_result = json.loads(schema_result['metrics']['schema_0']) 76 | processed_results.append(processed_result) 77 | return json.dumps(processed_results, indent=2) 78 | elif results.status == 'FAILED': 79 | return None 80 | time.sleep(15) 81 | 82 | except Exception as e: 83 | logging.error(f"Error in pipeline process: {e}") 84 | return None 85 | 86 | if __name__ == "__main__": 87 | print(process_website()) 88 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | redis==5.0.1 2 | beautifulsoup4==4.12.2 3 | python-dotenv==1.0.0 4 | langsmith==0.0.83 5 | PyPDF2==3.0.1 6 | markitdown==0.1.0 7 | langchain-core==0.1.27 8 | langgraph==0.0.27 -------------------------------------------------------------------------------- /start-oe.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ ! 
-x "$0" ]; then 4 | echo -e "${RED}Error: Script doesn't have execute permissions${NC}" 5 | echo "Please run: chmod +x $0" 6 | exit 1 7 | fi 8 | 9 | RED='\033[0;31m' 10 | GREEN='\033[0;32m' 11 | BLUE='\033[0;34m' 12 | YELLOW='\033[1;33m' 13 | NC='\033[0m' 14 | 15 | echo -e "${BLUE}=== Checking System Requirements ===${NC}" 16 | 17 | if [ -f "examples/scripts/requirements.txt" ]; then 18 | 19 | if command -v pip3 &> /dev/null; then 20 | PIP_CMD="pip3" 21 | elif command -v pip &> /dev/null; then 22 | PIP_CMD="pip" 23 | else 24 | echo -e "${RED}Error: Neither pip nor pip3 is installed${NC}" 25 | exit 1 26 | fi 27 | 28 | missing_packages=0 29 | while IFS= read -r package || [ -n "$package" ]; do 30 | [[ $package =~ ^[[:space:]]*$ || $package =~ ^#.*$ ]] && continue 31 | 32 | package_name=$(echo "$package" | cut -d'=' -f1 | cut -d'>' -f1 | cut -d'<' -f1 | tr -d ' ') 33 | 34 | if ! $PIP_CMD show "$package_name" >/dev/null 2>&1; then 35 | echo -e "${RED}❌ Missing package: ${package_name}${NC}" 36 | missing_packages=1 37 | else 38 | echo -e "${GREEN}✓ ${package_name} is installed${NC}" 39 | fi 40 | done < "examples/scripts/requirements.txt" 41 | 42 | if [ $missing_packages -eq 1 ]; then 43 | echo -e "\n${RED}Error: Missing Python requirements${NC}" 44 | echo -e "Please install required packages with:" 45 | echo -e "$PIP_CMD install -r examples/scripts/requirements.txt" 46 | exit 1 47 | fi 48 | else 49 | echo -e "${RED}Error: examples/scripts/requirements.txt not found${NC}" 50 | exit 1 51 | fi 52 | 53 | if ! command -v docker &> /dev/null; then 54 | echo -e "${RED}Error: Docker is not installed${NC}" 55 | echo "Please install Docker first: https://docs.docker.com/get-docker/" 56 | exit 1 57 | else 58 | echo -e "${GREEN}✓ Docker is installed${NC}" 59 | fi 60 | 61 | if ! command -v docker-compose &> /dev/null; then 62 | echo -e "${RED}Error: Docker Compose is not installed${NC}" 63 | echo "Please install Docker Compose: https://docs.docker.com/compose/install/" 64 | exit 1 65 | else 66 | echo -e "${GREEN}✓ Docker Compose is installed${NC}" 67 | fi 68 | 69 | if ! docker info &> /dev/null; then 70 | echo -e "\n${RED}Error: Docker daemon is not running${NC}" 71 | echo "Please start Docker daemon first" 72 | exit 1 73 | else 74 | echo -e "${GREEN}✓ Docker daemon is running${NC}" 75 | fi 76 | 77 | echo -e "\n${BLUE}=== Attempting to load .env file... ===${NC}" 78 | 79 | if [ ! -f .env ]; then 80 | echo -e "${RED}Error: .env file not found${NC}" 81 | echo -e "Please create a .env file with your provider credentials" 82 | exit 1 83 | fi 84 | 85 | set -a 86 | source .env >/dev/null 2>&1 87 | set +a 88 | 89 | provider_configured=false 90 | 91 | if [ ! -z "${AZURE_OPENAI_API_KEY}" ]; then 92 | provider_configured=true 93 | missing_vars=0 94 | 95 | azure_vars=( 96 | "AZURE_RESOURCE_NAME" 97 | "AZURE_DEPLOYMENT_ID" 98 | "AZURE_MODEL_NAME" 99 | "AZURE_API_VERSION" 100 | "AZURE_OPENAI_API_KEY" 101 | "AZURE_ENDPOINT" 102 | ) 103 | 104 | for var in "${azure_vars[@]}"; do 105 | if [ -z "${!var}" ]; then 106 | echo -e "${RED}❌ Missing ${var}${NC}" 107 | missing_vars=1 108 | else 109 | echo -e "${GREEN}✓ ${var} is set${NC}" 110 | fi 111 | done 112 | 113 | if [ $missing_vars -eq 1 ]; then 114 | echo -e "${RED}Error: Azure OpenAI requires all related environment variables${NC}" 115 | exit 1 116 | fi 117 | fi 118 | 119 | if [ ! -z "${OPENAI_API_KEY}" ]; then 120 | provider_configured=true 121 | echo -e "${GREEN}✓ OpenAI configured${NC}" 122 | fi 123 | 124 | if [ ! 
-z "${CEREBRAS_API_KEY}" ]; then 125 | provider_configured=true 126 | echo -e "${GREEN}✓ Cerebras configured${NC}" 127 | fi 128 | 129 | if [ ! -z "${GROQ_API_KEY}" ]; then 130 | provider_configured=true 131 | echo -e "${GREEN}✓ Groq configured${NC}" 132 | fi 133 | 134 | if [ ! -z "${MISTRAL_API_KEY}" ]; then 135 | provider_configured=true 136 | echo -e "${GREEN}✓ Mistral configured${NC}" 137 | fi 138 | 139 | if [ "$provider_configured" = false ]; then 140 | echo -e "${YELLOW}Warning: No AI Model provider credentials found in .env file${NC}" 141 | echo -e "These are some of the model providers that are supported:" 142 | echo -e "- Azure OpenAI (AZURE_OPENAI_API_KEY + related vars)" 143 | echo -e "- OpenAI (OPENAI_API_KEY)" 144 | echo -e "- Cerebras (CEREBRAS_API_KEY)" 145 | echo -e "- Groq (GROQ_API_KEY)" 146 | echo -e "- Mistral (MISTRAL_API_KEY)" 147 | echo -e "\n${YELLOW}Are you sure you want to continue without any of these providers configured? (yes/no)${NC}" 148 | read -r answer 149 | 150 | if [ "$(echo $answer | tr '[:upper:]' '[:lower:]')" = "yes" ]; then 151 | echo -e "${YELLOW}Continuing without provider configuration. Please ensure your desired AI model provider is supported by Marly.${NC}" 152 | else 153 | echo -e "${RED}Please configure at least one provider in your .env file${NC}" 154 | exit 1 155 | fi 156 | fi 157 | 158 | 159 | if ! grep -q "^LANGCHAIN_TRACING_V2=true" .env; then 160 | echo -e "${YELLOW}Warning: LANGCHAIN_TRACING_V2 is not enabled${NC}" 161 | echo -e "\n${YELLOW}Are you sure you want to continue without LangSmith tracing? (yes/no)${NC}" 162 | read -r answer 163 | 164 | if [ "$(echo $answer | tr '[:upper:]' '[:lower:]')" = "yes" ]; then 165 | echo -e "${YELLOW}Continuing without LangSmith tracing...${NC}" 166 | else 167 | echo -e "${RED}Please configure LangSmith tracing in your .env file${NC}" 168 | echo -e "\nRequired variables:" 169 | echo -e "LANGCHAIN_TRACING_V2=true" 170 | echo -e "LANGCHAIN_ENDPOINT=https://api.smith.langchain.com" 171 | echo -e "LANGCHAIN_API_KEY=" 172 | echo -e "LANGCHAIN_PROJECT=" 173 | exit 1 174 | fi 175 | else 176 | echo -e "\n${BLUE}Checking LangSmith configuration:${NC}" 177 | missing_vars=0 178 | 179 | langsmith_vars=( 180 | "LANGCHAIN_ENDPOINT" 181 | "LANGCHAIN_API_KEY" 182 | "LANGCHAIN_PROJECT" 183 | ) 184 | 185 | for var in "${langsmith_vars[@]}"; do 186 | if ! grep -q "^${var}=" .env; then 187 | echo -e "${RED}❌ ${var} must be explicitly set in .env file${NC}" 188 | missing_vars=1 189 | else 190 | echo -e "${GREEN}✓ ${var} is set${NC}" 191 | fi 192 | done 193 | 194 | if [ $missing_vars -eq 1 ]; then 195 | echo -e "${RED}Error: LangSmith tracing requires all related environment variables${NC}" 196 | exit 1 197 | else 198 | echo -e "${GREEN}✓ LangSmith tracing has been correctly configured${NC}" 199 | fi 200 | fi 201 | 202 | echo -e "\n${BLUE}=== Checking for Running Containers ===${NC}" 203 | 204 | if docker ps --format '{{.Names}}' | grep -qE '(pipeline-1|extraction-1|transformation-1)'; then 205 | echo -e "${RED}Error: Marly containers are already running${NC}" 206 | echo -e "\nRunning containers:" 207 | docker ps --format 'table {{.Names}}\t{{.Status}}' | grep -E '(pipeline-1|extraction-1|transformation-1)' 208 | exit 1 209 | else 210 | echo -e "\n${GREEN}✓ All checks passed! Starting services...${NC}" 211 | docker-compose up --build 212 | fi 213 | --------------------------------------------------------------------------------