├── .gitignore
├── LICENSE.md
├── README.md
├── application
│   ├── extraction
│   │   ├── Dockerfile
│   │   ├── models
│   │   │   └── models.py
│   │   ├── requirements.txt
│   │   ├── service
│   │   │   ├── extraction_handler.py
│   │   │   ├── extraction_worker.py
│   │   │   └── processing_handler.py
│   │   └── start_extraction.py
│   ├── pipeline
│   │   ├── Dockerfile
│   │   ├── models
│   │   │   └── models.py
│   │   ├── requirements.txt
│   │   ├── routes
│   │   │   └── pipeline_routes.py
│   │   ├── service
│   │   │   └── pipeline_service.py
│   │   └── start_pipeline.py
│   └── transformation
│       ├── Dockerfile
│       ├── models
│       │   └── models.py
│       ├── requirements.txt
│       ├── service
│       │   ├── transformation_handler.py
│       │   └── transformation_worker.py
│       └── start_transformation.py
├── common
│   ├── agents
│   │   ├── agent_prompt_enums.py
│   │   └── prs_agent.py
│   ├── api
│   │   └── 3.1.0-marly-spec.yml
│   ├── destinations
│   │   ├── base
│   │   │   └── base_destination.py
│   │   ├── destination_factory.py
│   │   ├── enums
│   │   │   └── destination_enums.py
│   │   └── sqlite_destination.py
│   ├── models
│   │   ├── azure_model.py
│   │   ├── base
│   │   │   └── base_model.py
│   │   ├── cerebras_model.py
│   │   ├── enums
│   │   │   └── model_enums.py
│   │   ├── groq_model.py
│   │   ├── mistral_model.py
│   │   ├── model_factory.py
│   │   └── openai_model.py
│   ├── prompts
│   │   └── prompt_enums.py
│   ├── redis
│   │   └── redis_config.py
│   ├── sources
│   │   ├── base
│   │   │   └── base_source.py
│   │   ├── enums
│   │   │   └── source_enums.py
│   │   ├── local_fs_source.py
│   │   ├── s3_source.py
│   │   └── source_factory.py
│   └── text_extraction
│       └── text_extractor.py
├── docker-compose.yml
├── examples
│   ├── ai-workers
│   │   └── ai-sdr
│   │       ├── auth
│   │       │   └── anon_helper.py
│   │       ├── contacts.db
│   │       ├── main.py
│   │       ├── output_source
│   │       │   └── sql_helper.py
│   │       └── transformation
│   │           └── marly_helper.py
│   ├── example_files
│   │   ├── lacers.pdf
│   │   └── lacers_reduced.pdf
│   ├── notebooks
│   │   ├── autogen_example
│   │   │   ├── OAI_CONFIG_LIST.json
│   │   │   ├── autogen.ipynb
│   │   │   ├── lacers_reduced.pdf
│   │   │   └── plot.png
│   │   └── langgraph_example
│   │       ├── data_loading_workflow.ipynb
│   │       ├── diagram.jpeg
│   │       ├── lacers_reduced.pdf
│   │       ├── notebooklm_workflow.ipynb
│   │       ├── notebooklm_workflow.jpeg
│   │       ├── pdf_table_to_chart_workflow.ipynb
│   │       └── wkflow.jpeg
│   └── scripts
│       ├── api_example.py
│       ├── azure_example.py
│       ├── cerebras_example.py
│       ├── data_loading_example.py
│       ├── groq_example.py
│       ├── markdown_example.py
│       ├── mistral_example.py
│       ├── non_marly_examples
│       │   ├── lacers_reduced.pdf
│       │   └── llamaindex_pinecone.py
│       ├── requirements.txt
│       ├── s3_example.py
│       └── web_and_document_example.py
├── requirements.txt
└── start-oe.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | .env
2 | .chroma
3 | .idea
4 | .ipynb_checkpoints
5 | .mypy_cache
6 | .vscode
7 | __pycache__
8 | .pytest_cache
9 | htmlcov
10 | dist
11 | site
12 | .coverage
13 | coverage.xml
14 | .netlify
15 | *_venv/
16 | test.db
17 | log.txt
18 | Pipfile.lock
19 | env3.*
20 | env
21 | docs_build
22 | venv
23 | docs.zip
24 | archive.zip
25 | __init__.py
26 |
27 | # vim temporary files
28 | *~
29 | .*.sw?
30 |
31 | */myenv
32 | */venv
33 | .DS_Store
34 |
35 |
36 | example_notebooks/*/.cache/
37 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Murtaza Meerza
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # open-extract
4 |
5 | [Features](#-features) • [What is a Schema?](#-what-is-a-schema) • [Use Cases](#-use-cases) • [Getting Started](#-getting-started) • [Documentation](#-documentation)
6 |
7 |
8 |
9 | ---
10 |
11 | open-extract simplifies the ingestion and processing of unstructured data for those building AI Agents/Agentic Workflows using frameworks such as LangGraph, AG2, and CrewAI.
12 |
13 | ---
14 |
15 | ## 🚀 Features
16 |
17 | 📄 Extract Relevant Information Seamlessly: Give your applications the ability to identify and extract relevant data from one or many large documents and websites with just a single API call. Get the content back in JSON or Markdown formats, making it easy to integrate into your workflows.
18 |
19 | 🔍 Multi-Schema/Multi-Document Support: Extract data based on one or many predefined schemas from a variety of document types, without needing a vector database or specifying page numbers.
20 |
21 | 🔄 Built-in Caching: Previously extracted schemas are cached, so repeat extractions return instantly without reprocessing the original documents.
22 |
23 | 🚫 No Vendor Lock-In: Enjoy complete flexibility with your choice of model provider. Whether using open-source or closed-source models, you're never tied to a specific vendor, ensuring full control.
24 |
25 | ---
26 |
27 | ## 🧰 What is a Schema?
28 |
29 | A schema is a set of key-value pairs describing what needs to be extracted from a particular document.
30 |
31 |
32 | 📋 Example Schema
33 |
34 | ```json
35 | {
36 | "Firm": "The name of the firm",
37 | "Number of Funds": "The number of funds managed by the firm",
38 | "Commitment": "The commitment amount in millions of dollars",
39 | "% of Total Comm": "The percentage of total commitment",
40 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
41 | "% of Total Exposure": "The percentage of total exposure",
42 | "TVPI": "Total Value to Paid-In multiple",
43 | "Net IRR": "Net Internal Rate of Return as a percentage"
44 | }
45 | ```
46 |
47 |
48 |
49 |
50 |
51 | ---
52 |
53 | ## 🎯 Use Cases
54 |
55 |
56 |
57 | | 💼 Financial Report Analysis | 📊 Customer Feedback Processing | 🔬 Research Assistant | 🧠 Legal Contract Parsing |
58 | | --- | --- | --- | --- |
59 | | Extract key financial metrics from quarterly PDF reports | Categorize feedback from various document types | Process research papers, extracting methodologies and findings | Extract key legal terms and conditions from contracts |
67 |
68 |
69 |
70 | ---
71 |
72 | ## 🛠️ Getting Started
73 |
74 | ### Build the Platform
75 |
76 | ---
77 |
78 | To build the platform from source, run the following command:
79 |
80 | ```bash
81 | ./start-oe.sh
82 | ```
83 |
84 | ---
85 |
86 | ### Run an example script or notebook
87 |
88 | Once the platform is running, you can test it out by trying one of our examples:
89 |
90 | 1. Navigate to the examples folder:
91 |
92 | ```bash
93 | cd examples
94 | ```
95 | 2. Navigate to the scripts or notebooks folder:
96 |
97 | ```bash
98 | cd scripts
99 | ```
100 | or
101 | ```bash
102 | cd notebooks/autogen_example
103 | ```
104 | 3. Run one of our example scripts:
105 | ```bash
106 | python azure_example.py
107 | ```
108 |
109 | ---
110 |
111 |
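112 | ### Calling the API directly
113 |
114 | Under the hood, the example scripts talk to the pipeline service's REST API (`POST /pipelines` to submit work, `GET /pipelines/{task_id}` to poll results; see `application/pipeline/routes/pipeline_routes.py`). The sketch below is a minimal, illustrative client. It assumes the service listens at `http://localhost:8100` (the port used in the pipeline Dockerfile), that the routes are mounted at the root path, and that you substitute your own provider credentials:
115 |
116 | ```python
117 | import json
118 | import time
119 |
120 | import requests
121 |
122 | BASE_URL = "http://localhost:8100"  # assumed host/port; adjust to your deployment
123 |
124 | schema = {
125 |     "Firm": "The name of the firm",
126 |     "Number of Funds": "The number of funds managed by the firm",
127 | }
128 |
129 | payload = {
130 |     "workloads": [{
131 |         "data_source": "local_fs",          # assumed source type; see common/sources
132 |         "documents_location": "example_files",
133 |         "file_name": "lacers_reduced.pdf",
134 |         "schemas": [json.dumps(schema)],    # each schema travels as a JSON string
135 |     }],
136 |     "provider_type": "openai",
137 |     "provider_model_name": "gpt-4o",        # any model your provider supports
138 |     "api_key": "YOUR_API_KEY",
139 | }
140 |
141 | task_id = requests.post(f"{BASE_URL}/pipelines", json=payload).json()["task_id"]
142 |
143 | # Poll until the job reaches a terminal state.
144 | while True:
145 |     result = requests.get(f"{BASE_URL}/pipelines/{task_id}").json()
146 |     if result.get("status") in ("COMPLETED", "FAILED"):
147 |         break
148 |     time.sleep(5)
149 | print(result)
150 | ```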
--------------------------------------------------------------------------------
/application/extraction/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim
2 |
3 | WORKDIR /app
4 |
5 | # Copy the common directory
6 | COPY ./common /app/common
7 |
8 | # Copy the extraction application directory
9 | COPY ./application/extraction /app/application/extraction
10 |
11 | # Copy the requirements file
12 | COPY ./application/extraction/requirements.txt /app/requirements.txt
13 |
14 | RUN apt-get update && apt-get install -y \
15 | build-essential \
16 | make \
17 | && rm -rf /var/lib/apt/lists/*
18 |
19 | # Install dependencies
20 | RUN pip install --no-cache-dir -r /app/requirements.txt
21 |
22 | # Set the Python path
23 | ENV PYTHONPATH /app
24 |
25 | # Run the extraction service
26 | CMD ["python", "application/extraction/start_extraction.py"]
27 |
--------------------------------------------------------------------------------
/application/extraction/models/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import List, Dict, Optional
3 | from enum import Enum
4 |
5 | class ExtractionRequestModel(BaseModel):
6 | task_id: str
7 | pdf_key: str
8 | schemas: List[Dict]
9 | source_type: str = "pdf"
10 |     destination: Optional[str] = None
11 |
12 | class SchemaResult(BaseModel):
13 | schema_id: str
14 | metrics: Dict[str, str]
15 | schema_data: Dict[str, str]
16 |
17 | class ExtractionResponseModel(BaseModel):
18 | task_id: str
19 | pdf_key: str
20 | results: List[SchemaResult]
21 | source_type: str = "pdf"
22 |
23 | class JobStatus(str, Enum):
24 | PENDING = "PENDING"
25 | IN_PROGRESS = "IN_PROGRESS"
26 | COMPLETED = "COMPLETED"
27 | FAILED = "FAILED"
28 |
29 | class ModelDetails(BaseModel):
30 | provider_type: str
31 | provider_model_name: str
32 | api_key: str
33 | markdown_mode: bool
34 | additional_params: Dict[str, str]
35 |
--------------------------------------------------------------------------------
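For reference, a minimal sketch of how the request model above is populated. All values are illustrative; `pdf_key` must name a Redis key that already holds a base64-encoded PDF:

```python
from application.extraction.models.models import ExtractionRequestModel

# Illustrative values only.
request = ExtractionRequestModel(
    task_id="demo-task",
    pdf_key="pdf:demo-task:0",  # Redis key holding the base64-encoded PDF
    schemas=[{"Firm": "The name of the firm"}],
)
print(request.model_dump())  # source_type defaults to "pdf", destination to None
```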
/application/extraction/requirements.txt:
--------------------------------------------------------------------------------
1 | python-dotenv
2 | pydantic
4 | langchain
5 | langchainhub
6 | langsmith
7 | langgraph
8 | openai
9 | groq
10 | cerebras_cloud_sdk
11 | PyPDF2
12 | pdfminer.six
13 | pymupdf
14 | aiohttp
15 | requests==2.32.3
16 | redis[hiredis]
17 | boto3
18 | beautifulsoup4
19 | markitdown
21 | celery
22 | tiktoken
--------------------------------------------------------------------------------
/application/extraction/service/extraction_handler.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | import time
4 | import gc
5 | from typing import List, Dict, NamedTuple, Optional
6 | from io import BytesIO
7 | from redis.asyncio import Redis
8 | from common.redis.redis_config import get_redis_connection
9 | from common.text_extraction.text_extractor import extract_page_as_markdown, find_common_pages
10 | import base64
11 | from dotenv import load_dotenv
12 | from common.models.model_factory import ModelFactory
13 | from langsmith import Client as LangSmithClient
14 | from common.prompts.prompt_enums import PromptType
15 | import tiktoken
16 | from dataclasses import dataclass
17 | from celery import Celery
18 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
19 | from application.extraction.service.processing_handler import (
20 | get_latest_model_details,
21 | preprocess_messages,
22 | process_web_content
23 | )
24 | from common.agents.prs_agent import process_extraction, AgentMode
25 |
26 | celery_app = Celery('extraction_tasks')
27 | celery_app.conf.update(
28 | broker_url='redis://redis:6379/0',
29 | result_backend='redis://redis:6379/0',
30 | task_serializer='json',
31 | result_serializer='json',
32 | accept_content=['json'],
33 | task_routes={
34 | 'process_pdf_chunk': {'queue': 'pdf_processing'},
35 | 'process_batch': {'queue': 'batch_processing'}
36 | },
37 | task_time_limit=3600,
38 | worker_prefetch_multiplier=1,
39 | worker_max_tasks_per_child=100
40 | )
41 |
42 | MAX_WORKERS = 10
43 | thread_pool = ThreadPoolExecutor(max_workers=MAX_WORKERS)
44 | process_pool = ProcessPoolExecutor(max_workers=MAX_WORKERS)
45 |
46 | logging.basicConfig(level=logging.INFO)
47 | logger = logging.getLogger(__name__)
48 | load_dotenv()
49 |
50 | langsmith_client = LangSmithClient()
51 |
52 | class PageNumbers(NamedTuple):
53 | pages: List[int]
54 |
55 | @dataclass
56 | class ProcessingProgress:
57 | current: int
58 | total: int
59 | stage: str
60 | timestamp: float
61 | status: str
62 |
63 | class ProcessingQueue:
64 | def __init__(self):
65 | self.queue = asyncio.Queue()
66 | self.results = {}
67 | self.progress = {}
68 |
69 | async def add_job(self, pdf_key: str, job_id: str):
70 | await self.queue.put((pdf_key, job_id))
71 | self.progress[job_id] = ProcessingProgress(0, 0, "queued", time.time(), "pending")
72 |
73 | async def get_progress(self, job_id: str) -> Optional[ProcessingProgress]:
74 | return self.progress.get(job_id)
75 |
76 | processing_queue = ProcessingQueue()
77 |
78 | async def get_model_client():
79 | """Get model client for batch processing."""
80 | redis = await get_redis_connection()
81 | model_details = await get_latest_model_details(redis)
82 | if not model_details:
83 | raise Exception("Could not get model details")
84 |
85 | return ModelFactory.create_model(
86 | model_type=model_details.provider_type,
87 | model_name=model_details.provider_model_name,
88 | api_key=model_details.api_key,
89 | additional_params=model_details.additional_params
90 | )
91 |
92 | @celery_app.task(name='process_pdf_chunk')
 93 | def process_pdf_chunk(pdf_chunk: str, start_page: int, end_page: int, job_id: str):
 94 |     """Celery task for processing a chunk of PDF pages (pdf_chunk is a base64 string)."""
 95 |     file_stream = BytesIO(base64.b64decode(pdf_chunk))
 96 |     return asyncio.run(_process_pdf_chunk(file_stream, start_page, end_page, job_id))
97 |
98 | @celery_app.task(name='process_batch')
99 | def process_batch(batch_content: str, keywords: str, examples: str, job_id: str):
100 | """Celery task for processing a batch of content."""
101 | return asyncio.run(_process_batch(batch_content, keywords, examples, job_id))
102 |
103 | async def _process_pdf_chunk(file_stream: BytesIO, start_page: int, end_page: int, job_id: str) -> Dict:
104 | """Process a chunk of PDF pages asynchronously."""
105 | try:
106 | extracted_contents = []
107 | for page in range(start_page, end_page + 1):
108 | content = await process_page(file_stream, page)
109 | if content:
110 | token_count = estimate_tokens(content)
111 | extracted_contents.append((page, content, token_count))
112 | return {'success': True, 'contents': extracted_contents}
113 | except Exception as e:
114 | logger.error(f"Error processing PDF chunk {start_page}-{end_page}: {e}")
115 | return {'success': False, 'error': str(e)}
116 |
117 | async def _process_batch(batch_content: str, keywords: str, examples: str, job_id: str) -> str:
118 | """Process a batch of content asynchronously."""
119 | try:
120 | client = await get_model_client()
121 | result = await call_llm_with_file_content(batch_content, keywords, examples, client)
122 | return result if result else ""
123 | except Exception as e:
124 | logger.error(f"Error processing batch in job {job_id}: {e}")
125 | return ""
126 |
127 | async def run_extraction(pdf_key: str, schemas: List[Dict[str, str]], job_id: Optional[str] = None) -> List[str]:
128 | """Enhanced extraction with smart processing selection."""
129 | if not job_id:
130 | job_id = f"job_{int(time.time())}"
131 |
132 | logger.info(f"Starting extraction for pdf_key: {pdf_key}, job_id: {job_id}")
133 | await track_progress(job_id, 0, len(schemas), "initialization")
134 |
135 | redis = await get_redis_connection()
136 | try:
137 | file_stream = await get_file_stream(redis, pdf_key)
138 | file_size = file_stream.getbuffer().nbytes
139 |
140 | client = await get_model_client()
141 | examples = await get_examples(client, str(schemas))
142 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schemas[0].items()])
143 |
144 | page_numbers = await get_relevant_page_numbers(client, file_stream, formatted_keywords)
145 | total_relevant_pages = len(page_numbers.pages)
146 |
147 | DISTRIBUTED_THRESHOLD = 10
148 |
149 | if total_relevant_pages <= DISTRIBUTED_THRESHOLD:
150 | logger.info(f"Using direct processing for {total_relevant_pages} relevant pages")
151 | metrics = await retrieve_multi_page_metrics(
152 | page_numbers.pages, formatted_keywords, file_stream,
153 | examples, client, job_id
154 | )
155 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success")
156 | return [metrics] if metrics else []
157 | else:
158 | logger.info(f"Using distributed processing for {total_relevant_pages} relevant pages")
159 |
160 | chunk_size = max(1, total_relevant_pages // MAX_WORKERS)
161 | chunks = [page_numbers.pages[i:i + chunk_size] for i in range(0, total_relevant_pages, chunk_size)]
162 |
163 | chunk_tasks = []
164 | for chunk_pages in chunks:
165 | task = process_pdf_chunk.delay(
166 | base64.b64encode(file_stream.getvalue()).decode(),
167 | min(chunk_pages),
168 | max(chunk_pages),
169 | job_id
170 | )
171 | chunk_tasks.append(task)
172 |
173 | results = []
174 | for task in chunk_tasks:
175 | result = task.get()
176 | if result['success']:
177 | results.extend(result['contents'])
178 |
179 | if not results:
180 | logger.error("No results from distributed processing")
181 | return []
182 |
183 | combined_content = "\n=== BATCH BREAK ===\n".join([content for _, content, _ in sorted(results, key=lambda x: x[0])])
184 | final_result = await validate_metrics(combined_content, examples, client)
185 |
186 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success")
187 | await cleanup_processed_files(redis, pdf_key)
188 |
189 | return [final_result] if final_result else []
190 |
191 | except Exception as e:
192 | logger.error(f"Error in extraction process: {e}")
193 | await track_progress(job_id, 0, len(schemas), "failed", "error")
194 | return []
195 |
196 | async def process_small_pdf(file_stream: BytesIO, schemas: List[Dict[str, str]], examples: str, client, job_id: str) -> List[str]:
197 | """Direct processing for small PDFs."""
198 | try:
199 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schemas[0].items()])
200 | page_numbers = await get_relevant_page_numbers(client, file_stream, formatted_keywords)
201 |
202 | metrics = await retrieve_multi_page_metrics(
203 | page_numbers.pages, formatted_keywords, file_stream,
204 | examples, client, job_id
205 | )
206 |
207 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success")
208 | return [metrics] if metrics else []
209 |
210 | except Exception as e:
211 | logger.error(f"Error in direct processing: {e}")
212 | return []
213 |
214 | async def track_progress(job_id: str, current: int, total: int, stage: str, status: str = "running"):
215 | processing_queue.progress[job_id] = ProcessingProgress(
216 | current=current,
217 | total=total,
218 | stage=stage,
219 | timestamp=time.time(),
220 | status=status
221 | )
222 |
223 | async def cleanup_processed_files(redis: Redis, pdf_key: str):
224 | """Cleanup temporary files and memory after processing."""
225 | try:
226 | await redis.delete(f"processed:{pdf_key}")
227 | gc.collect()
228 | except Exception as e:
229 | logger.error(f"Error during cleanup: {e}")
230 |
231 | def estimate_tokens(text: str) -> int:
232 | """More accurate token estimation using tiktoken."""
233 | try:
234 | encoding = tiktoken.get_encoding("cl100k_base")
235 | return len(encoding.encode(text))
236 | except Exception:
237 | return len(text) // 4
238 |
239 | async def calculate_optimal_batch_size(content_length: int, max_tokens: int = 3000) -> int:
240 | """Calculate optimal batch size based on content length."""
241 | return min(max_tokens, max(1000, content_length // 2))
242 |
243 | async def process_schema(client, file_stream: BytesIO, schema: Dict[str, str], job_id: str, schema_idx: int, total_schemas: int) -> str:
244 | """Enhanced schema processing with progress tracking."""
245 | try:
246 | await track_progress(job_id, schema_idx, total_schemas, f"processing_schema_{schema_idx}")
247 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schema.items()])
248 |
249 | page_numbers = await get_relevant_page_numbers(client, file_stream, formatted_keywords)
250 | examples = await get_examples(client, formatted_keywords)
251 |
252 | await track_progress(job_id, schema_idx + 0.5, total_schemas, f"extracting_metrics_{schema_idx}")
253 | metrics = await retrieve_multi_page_metrics(
254 | page_numbers.pages, formatted_keywords, file_stream,
255 | examples if examples else "", client, job_id
256 | )
257 |
258 | return metrics if metrics else ""
259 | except Exception as e:
260 | logger.error(f"Error processing schema {schema_idx}: {e}")
261 | return ""
262 |
263 | async def get_file_stream(redis: Redis, pdf_key: str) -> BytesIO:
264 | try:
265 | base64_string: str = await redis.get(pdf_key)
266 | decoded_bytes = base64.b64decode(base64_string)
267 | return BytesIO(decoded_bytes)
268 | except Exception as e:
269 | logger.error(f"Failed to process file stream: {e}")
270 | raise e
271 |
272 | async def get_examples(client, formatted_keywords: str) -> str:
273 | try:
274 | prompt = langsmith_client.pull_prompt(PromptType.EXAMPLE_GENERATION.value)
275 | messages = prompt.invoke({"first_value": formatted_keywords})
276 | processed_messages = preprocess_messages(messages)
277 | if processed_messages:
278 | return client.do_completion(processed_messages)
279 |     except Exception as e:
280 |         logger.error(f"Error generating examples: {e}")
281 |     return ""
282 |
283 | async def get_relevant_page_numbers(client, file_stream: BytesIO, formatted_keywords: str) -> PageNumbers:
284 | try:
285 | common_pages = await find_common_pages(client, file_stream, formatted_keywords)
286 | return PageNumbers(pages=common_pages)
287 | except Exception as e:
288 | logger.error(f"Error with find_common_pages: {e}")
289 | raise Exception("Cannot process: No common pages found or unreadable file.")
290 |
291 | async def retrieve_multi_page_metrics(
292 | pages: List[int], keywords: str, file_stream: BytesIO, examples: str, client, job_id: str
293 | ) -> str:
294 | try:
295 | MAX_TOKENS = await calculate_optimal_batch_size(3000)
296 | logger.info(f"Using dynamic token limit: {MAX_TOKENS}")
297 |
298 | extracted_contents = []
299 | total_pages = len(pages)
300 |
301 | for idx, page in enumerate(pages):
302 | await track_progress(job_id, idx, total_pages, f"extracting_page_{page}")
303 | content = await process_page(file_stream, page)
304 | if content:
305 | token_count = estimate_tokens(content)
306 | extracted_contents.append((page, content, token_count))
307 | logger.info(f"Page {page}: {token_count} tokens")
308 |
309 | if not extracted_contents:
310 | logger.error("No content extracted from pages")
311 | return ""
312 |
313 | all_results = []
314 | current_batch = []
315 | current_token_count = 0
316 |
317 | for page_num, content, token_count in extracted_contents:
318 | if current_token_count + token_count > MAX_TOKENS:
319 | if current_batch:
320 | batch_content = "\n=== PAGE BREAK ===\n".join([c for _, c, _ in current_batch])
321 | batch_result = await call_llm_with_file_content(batch_content, keywords, examples, client)
322 | if batch_result:
323 | all_results.append(batch_result)
324 |
325 | del batch_content
326 | gc.collect()
327 |
328 | current_batch = [(page_num, content, token_count)]
329 | current_token_count = token_count
330 | else:
331 | current_batch.append((page_num, content, token_count))
332 | current_token_count += token_count
333 |
334 | if current_batch:
335 | batch_content = "\n=== PAGE BREAK ===\n".join([c for _, c, _ in current_batch])
336 | batch_result = await call_llm_with_file_content(batch_content, keywords, examples, client)
337 | if batch_result:
338 | all_results.append(batch_result)
339 |
340 | del batch_content
341 | gc.collect()
342 |
343 | combined_results = "\n=== BATCH BREAK ===\n".join(all_results)
344 | await track_progress(job_id, total_pages, total_pages, "validating_results")
345 |
346 | validation_result = await validate_metrics(combined_results, examples, client)
347 |
348 | del all_results
349 | del combined_results
350 | gc.collect()
351 |
352 | return validation_result
353 |
354 | except Exception as e:
355 | logger.error(f"Error in retrieve_multi_page_metrics: {e}")
356 | await track_progress(job_id, 0, len(pages), "failed", "error")
357 | return ""
358 |
359 | async def process_page(file_stream: BytesIO, page: int) -> str:
360 | try:
361 | content = extract_page_as_markdown(file_stream, page)
362 | return content.decode('utf-8') if isinstance(content, bytes) else content
363 | except Exception as e:
364 | logger.error(f"Error extracting page {page} as markdown: {e}")
365 | return ""
366 |
367 | async def call_llm_with_file_content(formatted_content: str, keywords: str, examples: str, client) -> str:
368 | try:
369 | text = f"""SOURCE DOCUMENT:
370 | {formatted_content}
371 |
372 | EXAMPLE FORMAT:
373 | {examples}
374 |
375 | METRICS TO EXTRACT:
376 | {keywords}
377 |
378 | Note: The source document may contain multiple pages separated by '=== PAGE BREAK ==='.
379 | Extract all relevant metrics from each page section while maintaining accuracy."""
380 | return process_extraction(text, client, AgentMode.EXTRACTION)
381 | except Exception as e:
382 | logger.error(f"Error calling LLM with file content: {e}")
383 | return ""
384 |
385 | async def validate_metrics(
386 | llm_results: str,
387 | examples: str,
388 | client
389 | ) -> str:
390 | try:
391 | batched_results = llm_results.split("=== BATCH BREAK ===")
392 | MAX_VALIDATION_TOKENS = await calculate_optimal_batch_size(3000)
393 |
394 | validated_chunks = []
395 | current_chunk = []
396 | current_token_count = 0
397 |
398 | for batch in batched_results:
399 | if not batch.strip():
400 | continue
401 |
402 | batch_token_count = estimate_tokens(batch.strip())
403 |
404 | if current_token_count + batch_token_count > MAX_VALIDATION_TOKENS:
405 | if current_chunk:
406 | try:
407 | chunk_text = "\n".join(current_chunk)
408 | prompt = langsmith_client.pull_prompt(PromptType.VALIDATION.value)
409 | messages = prompt.invoke({
410 | "first_value": chunk_text,
411 | "second_value": examples
412 | })
413 | processed_messages = preprocess_messages(messages)
414 | if processed_messages:
415 | chunk_validation = client.do_completion(processed_messages)
416 | validated_chunks.append(chunk_validation)
417 |
418 | del chunk_text
419 | del messages
420 | gc.collect()
421 | except Exception as e:
422 | logger.error(f"Error validating chunk: {e}")
423 |
424 | current_chunk = [batch.strip()]
425 | current_token_count = batch_token_count
426 | else:
427 | current_chunk.append(batch.strip())
428 | current_token_count += batch_token_count
429 |
430 | if current_chunk:
431 | try:
432 | chunk_text = "\n".join(current_chunk)
433 | prompt = langsmith_client.pull_prompt(PromptType.VALIDATION.value)
434 | messages = prompt.invoke({
435 | "first_value": chunk_text,
436 | "second_value": examples
437 | })
438 | processed_messages = preprocess_messages(messages)
439 | if processed_messages:
440 | chunk_validation = client.do_completion(processed_messages)
441 | validated_chunks.append(chunk_validation)
442 |
443 | del chunk_text
444 | del messages
445 | gc.collect()
446 | except Exception as e:
447 | logger.error(f"Error validating final chunk: {e}")
448 |
449 | if len(validated_chunks) > 1:
450 | try:
451 | final_consolidation = "\n".join(validated_chunks)
452 | if estimate_tokens(final_consolidation) > MAX_VALIDATION_TOKENS:
453 | logger.warning("Final consolidation exceeds token limit, will process in chunks")
454 | return await validate_metrics(final_consolidation, examples, client)
455 |
456 | prompt = langsmith_client.pull_prompt(PromptType.VALIDATION.value)
457 | messages = prompt.invoke({
458 | "first_value": final_consolidation,
459 | "second_value": examples
460 | })
461 | processed_messages = preprocess_messages(messages)
462 | if processed_messages:
463 | final_result = client.do_completion(processed_messages)
464 | return final_result
465 | except Exception as e:
466 | logger.error(f"Error in final consolidation: {e}")
467 | return validated_chunks[0] if validated_chunks else ""
468 | elif validated_chunks:
469 | return validated_chunks[0]
470 |
471 | return ""
472 |
473 | except Exception as e:
474 | logger.error(f"Error validating metrics: {e}")
475 | return ""
476 |
477 | async def run_web_extraction(url: str, schemas: List[Dict[str, str]], job_id: Optional[str] = None) -> List[str]:
478 | """Handle extraction from web content using processing_handler."""
479 | if not job_id:
480 | job_id = f"job_{int(time.time())}"
481 |
482 | logger.info(f"Starting web extraction for url: {url}, job_id: {job_id}")
483 | await track_progress(job_id, 0, len(schemas), "initialization")
484 |
485 | try:
486 | redis = await get_redis_connection()
487 | results = await process_web_content(redis, url, schemas)
488 |
489 | await track_progress(job_id, len(schemas), len(schemas), "completed", "success")
490 | return results
491 |
492 | except Exception as e:
493 | logger.error(f"Error in web extraction: {e}")
494 | await track_progress(job_id, 0, len(schemas), "failed", "error")
495 | return []
--------------------------------------------------------------------------------
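A minimal driver for `run_extraction`, sketched under the assumption that Redis is reachable via `get_redis_connection()` and that the pipeline service has already published `model-details` to Redis (which `get_model_client()` reads). In normal operation jobs arrive via the extraction worker instead:

```python
import asyncio
import base64

from common.redis.redis_config import get_redis_connection
from application.extraction.service.extraction_handler import run_extraction

async def main() -> None:
    redis = await get_redis_connection()
    # Store the PDF the same way the pipeline service does: as a base64 string.
    with open("lacers_reduced.pdf", "rb") as f:
        await redis.set("pdf:demo:0", base64.b64encode(f.read()).decode())
    results = await run_extraction(
        "pdf:demo:0",
        [{"Firm": "The name of the firm"}],
        job_id="job_demo",
    )
    print(results)

asyncio.run(main())
```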
/application/extraction/service/extraction_worker.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | import json
4 | import logging
5 | from typing import Optional, Dict, Any
6 | from redis.asyncio import Redis
7 | from redis.exceptions import RedisError
8 | from datetime import datetime
 9 | from application.extraction.models.models import (ExtractionRequestModel,
10 | ExtractionResponseModel,
11 | SchemaResult,
12 | JobStatus
13 | )
14 | from application.extraction.service.extraction_handler import run_extraction, run_web_extraction
15 | from common.redis.redis_config import get_redis_connection
16 |
17 | logging.basicConfig(level=logging.INFO)
18 | logger = logging.getLogger(__name__)
19 |
20 | async def run_extractions() -> None:
21 | logger.info("Starting extraction worker")
22 | redis = await get_redis_connection()
23 | last_id = "0-0"
24 |
25 | while True:
26 | try:
27 | result = await redis.xread(
28 | streams={"extraction-stream": last_id},
29 | count=1,
30 | block=0
31 | )
32 | logger.info(f"Result: {result}")
33 |
34 | for stream_name, messages in result:
35 | for message_id, message in messages:
36 | logger.info(f"Received message from stream {stream_name}: ID {message_id}")
37 | payload = message.get(b"payload")
38 | if payload:
39 | try:
40 | logger.info(f"Payload value: {payload}")
41 | extraction_request = ExtractionRequestModel(**json.loads(payload.decode('utf-8')))
42 | extraction_result = await process_extraction(extraction_request)
43 | serialized_result = json.dumps(extraction_result.model_dump())
44 | logger.info(f"Pushing result to transformation-stream: {serialized_result}")
45 | await redis.xadd("transformation-stream", {"payload": serialized_result})
46 | logger.info("Successfully pushed result to transformation-stream")
47 | except json.JSONDecodeError as e:
48 | logger.error(f"Failed to parse payload JSON: {e}")
49 | except Exception as e:
50 | logger.error(f"Error processing extraction task: {e}")
51 | else:
52 | logger.error("Message does not contain 'payload' field")
53 | last_id = message_id
54 |
55 | except RedisError as e:
56 | logger.error(f"Error reading from Redis stream: {e}")
57 | await asyncio.sleep(1)
58 |
59 | async def process_extraction(extraction_request: ExtractionRequestModel) -> ExtractionResponseModel:
60 | try:
61 | if extraction_request.source_type == "web":
62 | results = await run_web_extraction(extraction_request.pdf_key, extraction_request.schemas)
63 | else:
64 | results = await run_extraction(extraction_request.pdf_key, extraction_request.schemas)
65 | schema_results = [
66 | SchemaResult(
67 | schema_id=f"schema_{index}",
68 | metrics={f"schema_{index}": result},
69 | schema_data=schema
70 | )
71 | for index, (schema, result) in enumerate(zip(extraction_request.schemas, results))
72 | ]
73 |
74 | response = ExtractionResponseModel(
75 | task_id=extraction_request.task_id,
76 | pdf_key=extraction_request.pdf_key,
77 | results=schema_results,
78 | source_type=extraction_request.source_type
79 | )
80 |
81 | redis = await get_redis_connection()
 82 |         await update_job_status(redis, extraction_request.task_id, JobStatus.PENDING, None)  # remains PENDING until the transformation stage finishes
83 |
84 | return response
85 |
86 | except Exception as e:
87 | logger.error(f"Error processing extraction task: {e}")
88 | redis = await get_redis_connection()
89 | # Log the type of the key causing the error
90 | key_type = await redis.type(extraction_request.pdf_key)
91 | logger.error(f"Key type for {extraction_request.pdf_key}: {key_type}")
92 |
93 | # Handle unexpected key types
94 | if key_type != b'string':
95 | logger.error(f"Unexpected key type for {extraction_request.pdf_key}. Deleting the key.")
96 | await redis.delete(extraction_request.pdf_key)
97 |
98 | await update_job_status(redis, extraction_request.task_id, JobStatus.FAILED, str(e))
99 | raise e
100 |
101 | async def update_job_status(
102 | redis: Redis,
103 | task_id: str,
104 | status: JobStatus,
105 | error_message: Optional[str]
106 | ) -> None:
107 | fields: Dict[str, Any] = {"status": json.dumps(status.value)}
108 | if error_message:
109 | fields["error_message"] = error_message
110 |
111 | start_time_str = await redis.get(f"job-start-time:{task_id}")
112 | start_time = int(start_time_str) if start_time_str else 0
113 | current_time = int(datetime.utcnow().timestamp())
114 | total_run_time = current_time - start_time
115 | run_time_str = f"{total_run_time // 60} minutes" if total_run_time >= 60 else f"{total_run_time} seconds"
116 | fields["total_run_time"] = run_time_str
117 |
118 | await redis.xadd(f"job-status:{task_id}", fields)
119 |
120 | async def clear_extraction_stream() -> None:
121 | retries = 0
122 | max_retries = 5
123 | delay = 1
124 |
125 | while retries < max_retries:
126 | try:
127 | redis = await get_redis_connection()
128 | await redis.xtrim("extraction-stream", approximate=True, maxlen=0)
129 | logger.info("Cleared extraction-stream")
130 | return
131 | except RedisError as e:
132 | if "LOADING" in str(e):
133 | logger.info(f"Redis is still loading. Retrying in {delay} seconds...")
134 | await asyncio.sleep(delay)
135 | retries += 1
136 | delay *= 2
137 | else:
138 | raise e
139 |
140 | raise Exception("Failed to clear extraction-stream after maximum retries")
--------------------------------------------------------------------------------
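The worker above blocks on the `extraction-stream` Redis stream. A sketch of how a job message gets enqueued, mirroring what the pipeline service does (values are illustrative):

```python
import asyncio
import json

from common.redis.redis_config import get_redis_connection

async def enqueue_demo_job() -> None:
    redis = await get_redis_connection()
    payload = {
        "task_id": "demo-task",          # illustrative values
        "pdf_key": "pdf:demo-task:0",    # must already hold a base64 PDF in Redis
        "schemas": [{"Firm": "The name of the firm"}],
        "source_type": "pdf",
    }
    # ExtractionRequestModel serialized under the "payload" field, as the worker expects.
    await redis.xadd("extraction-stream", {"payload": json.dumps(payload)})

asyncio.run(enqueue_demo_job())
```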
/application/extraction/service/processing_handler.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Dict, Optional
3 | from redis.asyncio import Redis
4 | from bs4 import BeautifulSoup
5 | import re
6 | from common.models.model_factory import ModelFactory
7 | from application.extraction.models.models import ModelDetails
8 | from langsmith import Client as LangSmithClient
9 | from common.prompts.prompt_enums import PromptType
10 | from langchain_core.messages import SystemMessage, HumanMessage
11 | import json
12 | from common.agents.prs_agent import process_extraction
13 | from common.agents.agent_prompt_enums import AgentMode
14 |
15 | logger = logging.getLogger(__name__)
16 | langsmith_client = LangSmithClient()
17 |
18 | def web_preprocessing(html_content: str) -> str:
19 | # Parse the HTML
20 | soup = BeautifulSoup(html_content, 'html.parser')
21 |
22 | # Remove script and style elements
23 | for script_or_style in soup(["script", "style", "meta", "link"]):
24 | script_or_style.decompose()
25 |
26 | # Remove navigation, footer, ads, and sidebars
27 | for element in soup(["nav", "footer", "aside"]):
28 | element.decompose()
29 |
30 | # Remove elements with common ad-related class names
31 | ad_classes = ["ad", "advertisement", "banner", "sidebar"]
32 | for element in soup.find_all(class_=lambda x: x and any(cls in x for cls in ad_classes)):
33 | element.decompose()
34 |
35 | # Try to find the main content
36 | main_content = soup.find("main") or soup.find("article")
37 | if not main_content:
38 | main_content = max(
39 | soup.find_all("div", text=True),
40 | key=lambda div: len(div.get_text()),
41 | default=soup
42 | )
43 |
44 | # Process links in the main content
45 | for a in main_content.find_all('a', href=True):
46 | href = a['href']
47 | if not a.string:
48 | a.string = href
49 | else:
50 | a.string = f"{a.string} ({href})"
51 |
52 | # Extract text from the main content
53 | text = main_content.get_text(separator=' ', strip=True)
54 |
55 | # Clean the text
56 | text = re.sub(r'\s+', ' ', text) # Normalize whitespace
57 | text = re.sub(r'&[a-zA-Z]+;', '', text) # Remove HTML entities
58 | logger.info(f"Length of text: {len(text)}")
59 | return text
60 |
61 | async def get_latest_model_details(redis: Redis) -> Optional[ModelDetails]:
62 | try:
63 | model_details_json = await redis.get("model-details")
64 | if not model_details_json:
65 | logger.error("No data found for model-details")
66 | return None
67 |
68 | model_details = ModelDetails(**json.loads(model_details_json))
69 | return model_details
70 |
71 | except Exception as e:
72 | logger.error(f"Failed to get or parse model details: {e}")
73 | return None
74 |
75 | async def process_web_content(redis: Redis, pdf_key: str, schemas: List[Dict[str, str]]) -> List[str]:
76 | logger.info(f"Starting web extraction process for pdf_key: {pdf_key}")
77 |
78 | # Retrieve the HTML content from Redis
79 | html_content = await redis.get(pdf_key)
80 |
81 | if not html_content:
82 | logger.error(f"No HTML content found for key: {pdf_key}")
83 | return []
84 |
85 | # Preprocess the HTML content
86 | preprocessed_text = web_preprocessing(html_content.decode('utf-8'))
87 |
88 | # Get the latest model details
89 | model_details = await get_latest_model_details(redis)
90 | if not model_details:
91 | return []
92 |
93 | # Create model instance
94 | try:
95 | model_instance = ModelFactory.create_model(
96 | model_type=model_details.provider_type,
97 | model_name=model_details.provider_model_name,
98 | api_key=model_details.api_key,
99 | additional_params=model_details.additional_params
100 | )
101 | except ValueError as e:
102 | logger.error(f"Model creation error: {e}")
103 | return []
104 |
105 | # Process each schema using PRS agent
106 | results = []
107 | for schema in schemas:
108 | try:
109 | formatted_keywords = "\n".join([f"{k}: {v}" for k, v in schema.items()])
110 | # Get example format from the schema
111 | example_format = await get_example_format(model_instance, formatted_keywords)
112 | # Use PRS agent for extraction
113 | text = f"""SOURCE DOCUMENT:
114 | {preprocessed_text}
115 |
116 | EXAMPLE FORMAT:
117 | {example_format}
118 |
119 | METRICS TO EXTRACT:
120 | {formatted_keywords}"""
121 |
122 | result = process_extraction(text, model_instance, AgentMode.EXTRACTION)
123 | results.append(result)
124 | except Exception as e:
125 | logger.error(f"Error processing schema: {e}")
126 |
127 | return results
128 |
129 | async def get_example_format(client, formatted_keywords: str) -> str:
130 | try:
131 | prompt = langsmith_client.pull_prompt(PromptType.EXAMPLE_GENERATION.value)
132 | messages = prompt.invoke({"first_value": formatted_keywords})
133 | processed_messages = preprocess_messages(messages)
134 | if processed_messages:
135 | return client.do_completion(processed_messages)
136 |     except Exception as e:
137 |         logger.error(f"Error generating example format: {e}")
138 |     return ""
139 |
140 | def preprocess_messages(raw_payload) -> List[Dict[str, str]]:
141 | messages = []
142 | if hasattr(raw_payload, 'to_messages'):
143 | for message in raw_payload.to_messages():
144 | if isinstance(message, SystemMessage):
145 | messages.append({"role": "system", "content": message.content})
146 | elif isinstance(message, HumanMessage):
147 | messages.append({"role": "user", "content": message.content})
148 | else:
149 | logger.warning(f"Unexpected message type: {type(message)}")
150 | else:
151 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}")
152 | return messages
--------------------------------------------------------------------------------
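A quick illustration of what `web_preprocessing` does with a typical page: scripts and page chrome are stripped, the main content is kept, and link targets are inlined next to their anchor text (output shown approximately; exact whitespace may differ):

```python
from application.extraction.service.processing_handler import web_preprocessing

html = """
<html><head><script>track()</script></head>
<body><nav>Menu</nav>
<main><h1>Q3 Report</h1><p>Revenue grew 12%. See the
<a href="https://example.com/q3.pdf">full report</a>.</p></main>
<footer>Copyright 2025</footer></body></html>
"""
print(web_preprocessing(html))
# Roughly: "Q3 Report Revenue grew 12%. See the full report (https://example.com/q3.pdf) ."
```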
/application/extraction/start_extraction.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from dotenv import load_dotenv
4 | import asyncio
5 | from application.extraction.service.extraction_worker import clear_extraction_stream, run_extractions
6 |
7 | load_dotenv()
8 |
9 | logging.basicConfig(level=logging.INFO,
10 | format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | async def main():
13 | try:
14 | await clear_extraction_stream()
15 | except Exception as e:
16 | logging.error("Failed to clear extraction-stream: %s", e)
17 |
18 | try:
19 | logging.info("Started extraction service...")
20 | await run_extractions()
21 | except Exception as e:
22 | logging.error("Application error: %s", e)
23 | exit(1)
24 |
25 | if __name__ == "__main__":
26 | asyncio.run(main())
27 |
--------------------------------------------------------------------------------
/application/pipeline/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim
2 |
3 | WORKDIR /app
4 |
5 | # Copy the common directory
6 | COPY ./common /app/common
7 |
8 | # Copy the pipeline application directory
9 | COPY ./application/pipeline /app/application/pipeline
10 |
11 | # Copy the requirements file
12 | COPY ./application/pipeline/requirements.txt /app/requirements.txt
13 |
14 | # Copy example files
15 | COPY ./examples/example_files /app/example_files
16 |
17 | RUN apt-get update && apt-get install -y \
18 | build-essential \
19 | make \
20 | && rm -rf /var/lib/apt/lists/*
21 |
22 | # Install dependencies
23 | RUN pip install --no-cache-dir -r /app/requirements.txt
24 |
25 | # Set the Python path
26 | ENV PYTHONPATH /app
27 |
28 | # Expose the port
29 | EXPOSE 8100
30 |
31 | # Run the application
32 | CMD ["uvicorn", "application.pipeline.start_pipeline:app", "--host", "0.0.0.0", "--port", "8100"]
--------------------------------------------------------------------------------
/application/pipeline/models/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 | from typing import List, Dict, Any, Optional
3 | from enum import Enum
4 |
5 | class JobStatus(str, Enum):
6 | PENDING = "PENDING"
7 | IN_PROGRESS = "IN_PROGRESS"
8 | COMPLETED = "COMPLETED"
9 | FAILED = "FAILED"
10 |
11 | class WorkloadItem(BaseModel):
12 |     raw_data: Optional[str] = Field(default=None)
13 |     schemas: List[str]
14 |     data_source: Optional[str] = Field(default=None)
15 |     documents_location: Optional[str] = Field(default=None)
16 |     file_name: Optional[str] = Field(default=None)
17 |     additional_params: Dict[str, Any] = Field(default_factory=dict)
18 |     destination: Optional[str] = Field(default=None)
19 |
20 | class PipelineRequestModel(BaseModel):
21 | workloads: List[WorkloadItem]
22 | provider_type: str
23 | provider_model_name: str
24 | api_key: str
25 | markdown_mode: bool = False
26 | additional_params: Dict[str, Any] = Field(default_factory=dict)
27 |
28 | class PipelineResponseModel(BaseModel):
29 | message: str
30 | task_id: str
31 |
32 | class PipelineResult(BaseModel):
33 | task_id: str
34 | status: JobStatus
35 | results: List[Dict]
36 | total_run_time: str
37 |
38 | class ExtractionRequestModel(BaseModel):
39 | task_id: str
40 | pdf_key: str
41 | schemas: List[Dict]
42 | source_type: str = "pdf"
43 |
--------------------------------------------------------------------------------
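A sketch of a populated request, under the assumption that `"s3"` is a valid source type (an `s3_source.py` exists under `common/sources`, but the exact enum values live in `source_enums.py`). Note that each schema is passed as a JSON *string*, which the service `json.loads()` later:

```python
import json

from application.pipeline.models.models import PipelineRequestModel, WorkloadItem

request = PipelineRequestModel(
    workloads=[
        WorkloadItem(
            data_source="s3",                # assumed source type
            documents_location="my-bucket",  # illustrative bucket name
            file_name="lacers_reduced.pdf",
            schemas=[json.dumps({"Firm": "The name of the firm"})],
        )
    ],
    provider_type="openai",
    provider_model_name="gpt-4o",
    api_key="YOUR_API_KEY",
)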
/application/pipeline/requirements.txt:
--------------------------------------------------------------------------------
1 | uvicorn
2 | fastapi
3 | redis[hiredis]
4 | pydantic
5 | PyPDF2
6 | groq
7 | openai
8 | langchain
9 | langchainhub
10 | python-dotenv
11 | langsmith
12 | pdfminer.six
13 | boto3
14 | pymupdf
15 | cerebras_cloud_sdk
16 | aiohttp
17 | requests==2.32.3
18 | markitdown
19 |
--------------------------------------------------------------------------------
/application/pipeline/routes/pipeline_routes.py:
--------------------------------------------------------------------------------
1 | from fastapi import APIRouter, HTTPException, status
2 | from application.pipeline.models.models import PipelineRequestModel, PipelineResponseModel, PipelineResult
3 | from application.pipeline.service.pipeline_service import run_pipeline, get_pipeline_results
4 | from fastapi.responses import JSONResponse
5 | import logging
6 |
7 | api_router = APIRouter()
8 |
9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
10 | logger = logging.getLogger(__name__)
11 |
12 | @api_router.post("/pipelines",
13 | response_model=PipelineResponseModel,
14 | status_code=status.HTTP_202_ACCEPTED,
15 | summary="Run pipeline",
16 |                  description="Initiates a pipeline processing job for the given documents and schemas.",
17 | responses={
18 | 202: {"description": "Pipeline processing started"},
19 | 500: {"description": "Internal server error"}
20 | })
21 | async def run_pipeline_route(pipeline_request: PipelineRequestModel):
22 | try:
23 | response = await run_pipeline(pipeline_request)
24 | if isinstance(response, dict) and "error" in response:
25 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=response["error"])
26 | return JSONResponse(content={"task_id": response["task_id"], "message": "Pipeline processing started"}, status_code=status.HTTP_202_ACCEPTED)
27 | except Exception as e:
28 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
29 |
30 | @api_router.get("/pipelines/{task_id}",
31 | response_model=PipelineResult,
32 | summary="Get pipeline results",
33 | description="Retrieves the results of a pipeline processing job.",
34 | responses={
35 | 200: {"description": "Pipeline results retrieved successfully"},
36 | 404: {"description": "Task not found"},
37 | 500: {"description": "Internal server error"}
38 | })
39 | async def get_pipeline_results_route(task_id: str):
40 | try:
41 | results = await get_pipeline_results(task_id)
42 | if isinstance(results, tuple) and results[1] == 500:
43 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=results[0]['error'])
44 | if not results:
45 | raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Task not found")
46 | return JSONResponse(content=results, status_code=status.HTTP_200_OK)
47 | except HTTPException as he:
48 | raise he
49 | except Exception as e:
50 | logger.error(f"Unexpected error in get_pipeline_results_route: {str(e)}")
51 | raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred")
52 |
--------------------------------------------------------------------------------
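For workloads that inline the document instead of naming a data source, `handle_raw_data` in the service below expects `raw_data` to be zlib-compressed and then base64-encoded. A client-side sketch of preparing such a workload:

```python
import base64
import json
import zlib

# handle_raw_data() reverses this: base64-decode, then zlib-decompress.
with open("lacers_reduced.pdf", "rb") as f:
    raw_data = base64.b64encode(zlib.compress(f.read())).decode("utf-8")

workload = {
    "raw_data": raw_data,
    "schemas": [json.dumps({"Firm": "The name of the firm"})],
}
```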
/application/pipeline/service/pipeline_service.py:
--------------------------------------------------------------------------------
1 | import json
2 | import uuid
3 | import base64
4 | import zlib
5 | import hashlib
6 | from io import BytesIO
7 | from typing import List, Dict, Optional
8 | import redis.asyncio as redis
9 | import logging
10 | import time
11 | import asyncio
12 | from urllib.parse import urlparse
13 | import aiohttp
14 |
15 | from application.pipeline.models.models import (
16 | PipelineRequestModel,
17 | PipelineResponseModel,
18 | JobStatus,
19 | PipelineResult,
20 | ExtractionRequestModel,
21 | WorkloadItem
22 | )
23 | from common.text_extraction.text_extractor import get_pdf_page_count
24 | from common.models.model_factory import ModelFactory
25 | from common.sources.source_factory import SourceFactory
26 | from langsmith import Client as LangSmithClient
27 | from langchain.schema import SystemMessage, HumanMessage
28 |
29 | logging.basicConfig(level=logging.INFO)
30 | logger = logging.getLogger(__name__)
31 |
32 | langsmith_client = LangSmithClient()
33 |
34 | async def run_pipeline(customer_input: PipelineRequestModel):
35 | logger.info("Starting Pipeline Run...")
36 |
37 | try:
38 |         con = redis.from_url("redis://redis:6379/0", encoding="utf-8", decode_responses=True)
39 | except Exception as e:
40 | logger.error(f"Redis connection error: {e}")
41 | return {
42 | "error": "Service temporarily unavailable",
43 | "details": "Unable to connect to the database. Please try again later."
44 | }
45 |
46 | task_id = str(uuid.uuid4())
47 | start_time = int(time.time())
48 | await con.set(f"job-start-time:{task_id}", start_time)
49 | await con.xadd(
50 | f"job-status:{task_id}",
51 | {"status": json.dumps(JobStatus.PENDING.value), "start_time": str(start_time)}
52 | )
53 |
54 | try:
55 | ModelFactory.create_model(
56 | model_type=customer_input.provider_type,
57 | model_name=customer_input.provider_model_name,
58 | api_key=customer_input.api_key,
59 | additional_params=customer_input.additional_params
60 | )
61 | except ValueError as e:
62 | logger.error(f"Model creation error: {e}")
63 | return {
64 | "error": "Invalid model configuration",
65 | "details": str(e)
66 | }
67 |
68 | # Publish details to be used by other workers
69 | model_details_json = json.dumps({
70 | "provider_type": customer_input.provider_type,
71 | "provider_model_name": customer_input.provider_model_name,
72 | "api_key": customer_input.api_key,
73 | "markdown_mode": customer_input.markdown_mode,
74 | "additional_params": customer_input.additional_params
75 | })
76 | await con.set("model-details", model_details_json)
77 |
78 | if is_transformation_only_job(customer_input.workloads):
79 | logger.info("Transformation-only job detected")
80 | return await handle_transformation(customer_input, con, task_id)
81 | else:
82 | logger.info("Full pipeline job detected")
83 | return await handle_full_pipeline(customer_input, con, task_id)
84 |
85 | async def handle_transformation(customer_input: PipelineRequestModel, con: redis.Redis, task_id: str):
86 | workload = customer_input.workloads[0]
87 | transformation_payload = {
88 | "task_id": task_id,
89 | "data_location_key": workload.documents_location,
90 | "schemas": workload.schemas,
91 | "destination": workload.destination,
92 | "raw_data": workload.raw_data
93 | }
94 | await con.xadd("transformation-only-stream", {"payload": json.dumps(transformation_payload)})
95 |
96 | response = PipelineResponseModel(
97 | message="Transformation task submitted successfully",
98 | task_id=task_id
99 | )
100 | return {"task_id": response.task_id, "message": response.message}
101 |
102 | async def handle_full_pipeline(customer_input: PipelineRequestModel, con: redis.Redis, task_id: str):
103 | await con.set(f"workload-count:{task_id}", len(customer_input.workloads))
104 |
105 | pdf_hash = hashlib.sha256(json.dumps([w.dict() for w in customer_input.workloads]).encode()).hexdigest()
106 | cache_key = f"cache:{pdf_hash}"
107 | cached_hash = await con.get(cache_key)
108 | if cached_hash:
109 | logger.info(f"Cache hit for key: {cache_key}")
110 | cached_response = await con.get(cached_hash)
111 | if cached_response:
112 | return json.loads(cached_response)
113 |
114 | async def process_workload(index: int, workload_combo: WorkloadItem) -> int:
115 | try:
116 | if workload_combo.raw_data and workload_combo.data_source:
117 | logger.error(f"Workload {index} cannot have both raw_data and data_source.")
118 | raise ValueError("Workload cannot have both raw_data and data_source.")
119 | if workload_combo.raw_data:
120 | return await handle_raw_data(index, workload_combo, con, task_id)
121 | elif workload_combo.data_source == "web":
122 | return await handle_web_source(index, workload_combo, con, task_id)
123 | elif workload_combo.data_source:
124 | return await handle_data_source(index, workload_combo, con, task_id)
125 | else:
126 | logger.error(f"Workload {index} must have either raw_data or data_source.")
127 | return 0
128 | except Exception as e:
129 | logger.error(f"Error processing workload {index}: {e}")
130 | return 0
131 |
132 | workload_results = await asyncio.gather(
133 | *[process_workload(index, workload_combo) for index, workload_combo in enumerate(customer_input.workloads)],
134 | return_exceptions=False
135 | )
136 |
137 | total_pages = sum(workload_results)
138 | logger.info(f"Total pages processed: {total_pages}")
139 |
140 | await con.xadd(
141 | f"job-status:{task_id}",
142 | {"status": json.dumps(JobStatus.IN_PROGRESS.value)}
143 | )
144 |
145 | response = PipelineResponseModel(
146 | message="Tasks submitted successfully",
147 | task_id=task_id
148 | )
149 |
150 | response_hash = hashlib.sha256(json.dumps(response.dict()).encode()).hexdigest()
151 | await con.setex(cache_key, 3600, response_hash)
152 | await con.setex(response_hash, 3600, json.dumps(response.dict()))
153 |
154 | return {"task_id": response.task_id, "message": response.message}
155 |
156 | def is_transformation_only_job(workloads: List[WorkloadItem]) -> bool:
157 | if len(workloads) != 1:
158 | return False
159 | workload = workloads[0]
160 | return (
161 | workload.destination is not None and
162 | workload.documents_location is not None and
163 | workload.schemas is not None and
164 | workload.raw_data is not None and
165 | workload.data_source is None
166 | )
167 |
168 | async def handle_raw_data(index: int, workload_combo: WorkloadItem, con: redis.Redis, task_id: str) -> int:
169 | logger.info(f"Processing workload {index} with raw_data.")
170 |     # Base64-decode the payload, then decompress it
171 | try:
172 | decompressed_data = zlib.decompress(base64.b64decode(workload_combo.raw_data))
173 | except zlib.error as e:
174 | logger.error(f"Decompression failed for workload {index}: {e}")
175 | return 0
176 |
177 | data_cursor = BytesIO(decompressed_data)
178 | logger.info(f"Decompressed data for workload {index}")
179 |
180 | try:
181 | page_count = get_pdf_page_count(data_cursor)
182 | logger.debug(f"Page count for workload {index}: {page_count}")
183 | except Exception as e:
184 | logger.error(f"Error getting page count for workload {index}: {e}")
185 | return 0
186 |
187 | data_key = f"data:{task_id}:{index}"
188 | await con.set(data_key, base64.b64encode(decompressed_data).decode())
189 | logger.info(f"Data stored in Redis with key: {data_key} as base64 string")
190 |
191 | schemas = [json.loads(schema) for schema in workload_combo.schemas]
192 |
193 | task_payload = ExtractionRequestModel(
194 | task_id=task_id,
195 | pdf_key=data_key,
196 | schemas=schemas
197 | )
198 |
199 | await con.xadd("extraction-stream", {"payload": json.dumps(task_payload.dict())})
200 |
201 | return page_count
202 |
203 | async def fetch_document(session, workload_combo):
204 | try:
205 | async with session.get(workload_combo.documents_location) as response:
206 |             response.raise_for_status()  # raises ClientResponseError for 4xx/5xx responses
207 | content = await response.read()
208 | return content
209 | except aiohttp.ClientError as e:
210 | logger.error(f"Error fetching document: {str(e)}")
211 | raise
212 |
213 |
214 | async def handle_web_source(index: int, workload_combo: WorkloadItem, con: redis.Redis, task_id: str) -> int:
215 | logger.info(f"Processing workload {index} with web source: {workload_combo.documents_location}")
216 |
217 |     # Validate the URL before fetching its content
221 | try:
222 | result = urlparse(workload_combo.documents_location)
223 | if not all([result.scheme, result.netloc]):
224 | logger.error(f"Invalid URL: {workload_combo.documents_location}")
225 | return 0
226 | except ValueError:
227 | logger.error(f"Invalid URL format: {workload_combo.documents_location}")
228 | return 0
229 | logger.info("URL is valid")
230 | logger.info("Fetching document content")
231 | async with aiohttp.ClientSession() as session:
232 | try:
233 | document_content = await fetch_document(session, workload_combo)
234 | except aiohttp.ClientError as e:
235 | logger.error(f"Failed to connect to URL: {str(e)}")
236 | return 0
237 |
238 |     # Store the raw HTML in Redis; the extraction worker preprocesses it later
242 | content_key = f"web:{task_id}:{index}"
243 | await con.set(content_key, document_content)
244 |
245 | # Create and add the extraction task
246 | schemas = [json.loads(schema) for schema in workload_combo.schemas]
247 | task_payload = ExtractionRequestModel(
248 | task_id=task_id,
249 | pdf_key=content_key,
250 | schemas=schemas,
251 | source_type="web"
252 | )
253 | await con.xadd("extraction-stream", {"payload": json.dumps(task_payload.dict())})
254 |
255 |     return 1  # a web document counts as one "page" for workload accounting
256 |
257 |
258 | async def handle_data_source(index: int, workload_combo: WorkloadItem, con: redis.Redis, task_id: str) -> int:
259 | logger.info(f"Processing workload {index} with data_source: {workload_combo.data_source}")
260 | logger.info(f"Documents location: {workload_combo.documents_location}")
261 |     source = SourceFactory.create_source(
262 | source_type=workload_combo.data_source,
263 | documents_location=workload_combo.documents_location,
264 | additional_params=workload_combo.additional_params
265 | )
266 | logger.info(f"Created source: {source}")
267 |
268 | all_files = source.read_all()
269 | logger.info(f"List of all files: {all_files}")
270 | if not all_files:
271 | logger.warning(f"No files found in data source: {workload_combo.data_source}")
272 | return 0
273 |
274 | relevant_file = await get_relevant_file_via_llm(all_files, workload_combo.file_name)
275 | if not relevant_file:
276 | logger.warning(f"LLM did not return a valid file for workload {index}")
277 | return 0
278 |
279 | logger.info(f"Selected relevant file: {relevant_file}")
280 |
281 | file_stream = source.read({"file_key": relevant_file})
282 | if not file_stream:
283 | logger.warning(f"Failed to read the selected file: {relevant_file}")
284 | return 0
285 |
286 |     # The source returns the file as an in-memory stream; use those bytes
287 |     # directly (they are not compressed, so no zlib round trip is needed).
288 |     pdf_bytes = file_stream.getvalue()
289 |     if not pdf_bytes:
290 |         logger.error(f"Empty file stream for workload {index}")
291 |         return 0
292 |
293 |     pdf_cursor = BytesIO(pdf_bytes)
294 |     logger.info(f"Read {len(pdf_bytes)} bytes for workload {index}")
296 |
297 | try:
298 | page_count = get_pdf_page_count(pdf_cursor)
299 | logger.debug(f"Page count for workload {index}: {page_count}")
300 | except Exception as e:
301 | logger.error(f"Error getting page count for workload {index}: {e}")
302 | return 0
303 |
304 | pdf_key = f"pdf:{task_id}:{index}"
305 |     await con.set(pdf_key, base64.b64encode(pdf_bytes).decode())
306 | logger.info(f"PDF stored in Redis with key: {pdf_key} as base64 string")
307 |
308 | schemas = [json.loads(schema) for schema in workload_combo.schemas]
309 |
310 | task_payload = ExtractionRequestModel(
311 | task_id=task_id,
312 | pdf_key=pdf_key,
313 | schemas=schemas
314 | )
315 |
316 | await con.xadd("extraction-stream", {"payload": json.dumps(task_payload.dict())})
317 |
318 | return page_count
319 |
320 | def preprocess_messages(raw_payload):
321 | messages = []
322 | if hasattr(raw_payload, 'to_messages'):
323 | for message in raw_payload.to_messages():
324 | if isinstance(message, SystemMessage):
325 | messages.append({"role": "system", "content": message.content})
326 | elif isinstance(message, HumanMessage):
327 | messages.append({"role": "user", "content": message.content})
328 | else:
329 | logger.warning(f"Unexpected message type: {type(message)}")
330 | else:
331 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}")
332 | return messages
333 |
334 | async def get_relevant_file_via_llm(filenames: List[str], file_name: str) -> Optional[str]:
335 | if not filenames:
336 | logger.error("No filenames provided to determine relevance.")
337 | return None
338 |
339 | try:
340 | con = await redis.from_url("redis://redis:6379/0", encoding="utf-8", decode_responses=True)
341 | customer_input_str = await con.get("model-details")
342 |         logger.info("Fetched model details from Redis (value withheld from logs: it contains the API key)")
343 | if not customer_input_str:
344 | logger.error("Model details not found in Redis.")
345 | return None
346 |
347 | customer_input = json.loads(customer_input_str)
348 |
349 | model_instance = ModelFactory.create_model(
350 | model_type=customer_input["provider_type"],
351 | model_name=customer_input["provider_model_name"],
352 | api_key=customer_input["api_key"],
353 | additional_params=customer_input["additional_params"]
354 | )
355 |
356 | prompt = langsmith_client.pull_prompt("marly/get-relevant-file")
357 |
358 | messages = prompt.invoke({
359 | "first_value": file_name,
360 | "second_value": filenames
361 | })
362 |
363 | processed_messages = preprocess_messages(messages)
364 |
365 | if not processed_messages:
366 | logger.error("No messages to process for determining relevant file.")
367 | return None
368 |
369 | relevant_file = model_instance.do_completion(processed_messages)
370 |
371 | if relevant_file in filenames:
372 | return relevant_file
373 | else:
374 | logger.error(f"LLM returned an invalid filename: {relevant_file}")
375 | return None
376 |
377 |     except Exception:
378 |         logger.exception("An error occurred while determining the relevant file via LLM.")
379 | return None
380 |
381 | async def get_results_from_stream(con: redis.Redis, task_id: str) -> List[Dict]:
382 | entries = await con.xrange(f"job-status:{task_id}")
383 | results = []
384 |     for _, entry in entries:
385 |         # Connections in this module are created with decode_responses=True,
386 |         # so entry keys and values arrive as str (bytes keys would never match).
387 |         if 'result' in entry:
388 |             result = json.loads(entry['result'])
389 |             if 'results' in result:
390 |                 results.extend(result['results'])
389 | return results
390 |
391 | async def get_pipeline_results(task_id: str):
392 | try:
393 | con = await redis.from_url("redis://redis:6379/0", encoding="utf-8", decode_responses=True)
394 | except Exception as e:
395 | logger.error(f"Failed to connect to Redis: {e}")
396 | return {"error": "Failed to connect to Redis"}, 500
397 |
398 | try:
399 | status_stream = await con.xrange(f"job-status:{task_id}")
400 | except Exception as e:
401 | logger.error(f"Failed to fetch job status: {e}")
402 | return {"error": "Failed to fetch job status"}, 500
403 |
404 | if not status_stream:
405 | return {"error": "Task not found"}, 404
406 |
407 | latest_status = JobStatus.PENDING
408 | total_run_time = 'N/A'
409 | all_results = []
410 |
411 | for _, entry in status_stream:
412 | try:
413 | entry_status = JobStatus(json.loads(entry['status']))
414 | latest_status = entry_status
415 | except (json.JSONDecodeError, ValueError):
416 | logger.error(f"Invalid status value: {entry['status']}")
417 |
418 | if 'total_run_time' in entry:
419 | total_run_time = entry['total_run_time']
420 |
421 | if 'result' in entry:
422 | try:
423 | result_data = json.loads(entry['result'])
424 | all_results.append(result_data)
425 | except json.JSONDecodeError:
426 | logger.error(f"Failed to parse result JSON: {entry['result']}")
427 |
428 | response = PipelineResult(
429 | task_id=task_id,
430 | status=latest_status,
431 | results=all_results,
432 | total_run_time=total_run_time
433 | )
434 |
435 | return response.dict()
436 |
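437 | # Illustrative usage sketch (hypothetical caller, not part of the service):
438 | # poll get_pipeline_results until the job settles. Assumes it runs somewhere
439 | # redis://redis:6379/0 is reachable and that `task_id` came from a prior
440 | # pipeline submission.
441 | #
442 | #     async def wait_for_results(task_id: str, poll_seconds: float = 2.0):
443 | #         while True:
444 | #             result = await get_pipeline_results(task_id)
445 | #             if not isinstance(result, dict):
446 | #                 return result  # (error_body, http_status) tuple from the error paths
447 | #             if result.get("status") in (JobStatus.COMPLETED, JobStatus.FAILED):
448 | #                 return result
449 | #             await asyncio.sleep(poll_seconds)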
--------------------------------------------------------------------------------
/application/pipeline/start_pipeline.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | # Add the parent directory of 'application' to sys.path
5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
6 |
7 | from fastapi import FastAPI
8 | from fastapi.middleware.cors import CORSMiddleware
9 | from application.pipeline.routes.pipeline_routes import api_router as pipeline_api_router
10 | import uvicorn
11 |
12 | app = FastAPI(
13 | title="Marly API",
14 | description="The Data Processor for Agents",
15 | docs_url='/docs',
16 | openapi_url='/openapi.json'
17 | )
18 |
19 | app.add_middleware(
20 | CORSMiddleware,
21 | allow_origins=["*"],
22 | allow_methods=["*"],
23 | allow_headers=["*"],
24 | allow_credentials=True,
25 | )
26 |
29 | app.include_router(pipeline_api_router, prefix="", tags=["Pipeline Execution"])
30 |
31 | if __name__ == "__main__":
32 |     # Bind to 0.0.0.0 so the API is reachable from outside the container
33 |     # (the pipeline Dockerfile runs this script directly).
34 |     uvicorn.run(app, host="0.0.0.0", port=8100)
35 |
--------------------------------------------------------------------------------
/application/transformation/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim
2 |
3 | WORKDIR /app
4 |
5 | # Copy the common directory
6 | COPY ./common /app/common
7 |
8 | # Copy the transformation application directory
9 | COPY ./application/transformation /app/application/transformation
10 |
11 | # Copy the requirements file
12 | COPY ./application/transformation/requirements.txt /app/requirements.txt
13 |
14 | # Install dependencies
15 | RUN pip install --no-cache-dir -r /app/requirements.txt
16 |
17 | # Set the Python path
18 | ENV PYTHONPATH /app
19 |
20 | # Run the transformation service
21 | CMD ["python", "application/transformation/start_transformation.py"]
--------------------------------------------------------------------------------
/application/transformation/models/models.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel
2 | from typing import List, Dict, Optional
3 | from enum import Enum
4 |
5 | class TransformationRequestModel(BaseModel):
6 | task_id: str
7 | pdf_key: str
8 | results: List['SchemaResult']
9 | source_type: str = "pdf"
10 |     destination: Optional[str] = None
11 |
12 | class TransformationOnlyRequestModel(BaseModel):
13 | task_id: str
14 | data_location_key: str
15 | schemas: List[str]
16 | destination: str
17 | raw_data: str
18 |
19 | class TransformationResponseModel(BaseModel):
20 | task_id: str
21 | pdf_key: str
22 | results: List['SchemaResult']
23 |
24 | class SchemaResult(BaseModel):
25 | schema_id: str
26 | metrics: Dict[str, str]
27 | schema_data: Dict[str, str]
28 |
29 | class JobStatus(str, Enum):
30 | PENDING = "PENDING"
31 | IN_PROGRESS = "IN_PROGRESS"
32 | COMPLETED = "COMPLETED"
33 | FAILED = "FAILED"
34 |
35 | class ExtractionResponseModel(BaseModel):
36 | task_id: str
37 | pdf_key: str
38 | results: List[SchemaResult]
39 |
40 | class ModelDetails(BaseModel):
41 | provider_type: str
42 | provider_model_name: str
43 | api_key: str
44 | markdown_mode: bool
45 | additional_params: Dict[str, str]
46 |
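47 | # Illustrative payload (field values are invented): what a serialized
48 | # TransformationRequestModel looks like on the transformation-stream.
49 | #
50 | #     {
51 | #         "task_id": "123e4567-e89b-12d3-a456-426614174000",
52 | #         "pdf_key": "pdf:123e4567-e89b-12d3-a456-426614174000:0",
53 | #         "results": [
54 | #             {
55 | #                 "schema_id": "schema_1",
56 | #                 "metrics": {"Revenue": "$15 million"},
57 | #                 "schema_data": {"Revenue": "Total revenue for the period"}
58 | #             }
59 | #         ],
60 | #         "source_type": "pdf",
61 | #         "destination": null
62 | #     }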
--------------------------------------------------------------------------------
/application/transformation/requirements.txt:
--------------------------------------------------------------------------------
1 | # asyncio and typing are standard-library modules; their PyPI backports can
2 | # shadow the stdlib, so they are deliberately not listed here.
3 | redis[hiredis]
4 | python-dotenv
5 | pydantic
6 | PyPDF2
7 | aiohttp
8 | openai
9 | groq
10 | langchain
11 | langchainhub
12 | langsmith
13 | pdfminer.six
14 | cerebras_cloud_sdk
15 | boto3
16 | requests==2.32.3
17 | markitdown
--------------------------------------------------------------------------------
/application/transformation/service/transformation_handler.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import logging
3 | from typing import Dict, List, Optional
4 | from redis.asyncio import Redis
5 | from common.redis.redis_config import get_redis_connection
6 | from common.models.model_factory import ModelFactory
7 | from application.transformation.models.models import ModelDetails
8 | from langsmith import Client as LangSmithClient
9 | from common.prompts.prompt_enums import PromptType
10 | from langchain.schema import SystemMessage, HumanMessage
11 | from common.destinations.destination_factory import DestinationFactory
12 | from common.destinations.enums.destination_enums import DestinationType
13 | import json
14 |
15 | logging.basicConfig(level=logging.INFO)
16 | logger = logging.getLogger(__name__)
17 |
18 | langsmith_client = LangSmithClient()
19 |
20 | def preprocess_messages(raw_payload):
21 | messages = []
22 | if hasattr(raw_payload, 'to_messages'):
23 | for message in raw_payload.to_messages():
24 | if isinstance(message, SystemMessage):
25 | messages.append({"role": "system", "content": message.content})
26 | elif isinstance(message, HumanMessage):
27 | messages.append({"role": "user", "content": message.content})
28 | else:
29 | logger.warning(f"Unexpected message type: {type(message)}")
30 | else:
31 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}")
32 | return messages
33 |
34 | async def get_latest_model_details(redis: Redis) -> Optional[ModelDetails]:
35 | try:
36 | model_details_json = await redis.get("model-details")
37 | if not model_details_json:
38 | logger.error("No data found for model-details")
39 | return None
40 |
41 | model_details = ModelDetails(**json.loads(model_details_json))
42 | return model_details
43 |
44 | except Exception as e:
45 | logger.error(f"Failed to get or parse model details: {e}")
46 | return None
47 |
48 | async def run_transformation(metrics: Dict[str, str], schema: Dict[str, str], source_type: str) -> Dict[str, str]:
49 | logger.info("Starting transformation process")
50 |
51 | redis: Redis = await get_redis_connection()
52 | model_details = await get_latest_model_details(redis)
53 | if not model_details:
54 | return {}
55 |
56 | try:
57 | model_instance = ModelFactory.create_model(
58 | model_type=model_details.provider_type,
59 | model_name=model_details.provider_model_name,
60 | api_key=model_details.api_key,
61 | additional_params=model_details.additional_params
62 | )
63 | markdown_mode = model_details.markdown_mode
64 | logger.info(f"Model instance created with type: {model_details.provider_type}, name: {model_details.provider_model_name}")
65 | except ValueError as e:
66 | logger.error(f"Model creation error: {e}")
67 | return {}
68 |
69 | schema_keys = ",".join(schema.keys())
70 |
71 | tasks = [
72 | asyncio.create_task(process_schema(model_instance, schema_id, metric_value, schema_keys, markdown_mode, source_type))
73 | for schema_id, metric_value in metrics.items()
74 | ]
75 |
76 | results = await asyncio.gather(*tasks, return_exceptions=True)
77 | transformed_metrics = {schema_id: result for schema_id, result in zip(metrics.keys(), results) if isinstance(result, str)}
78 | logger.info(f"Transformed metrics: {transformed_metrics}")
79 |
80 | return transformed_metrics
81 |
82 | async def process_schema(client, schema_id: str, metric_value: str, schema_keys: str, markdown_mode: bool, source_type: str) -> str:
83 | try:
84 | if source_type == 'web':
85 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION_WEB.value)
86 | elif markdown_mode:
87 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION_MARKDOWN.value)
88 | else:
89 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION.value)
90 |
91 | messages = prompt.invoke({
92 | "first_value": metric_value,
93 | "second_value": schema_keys
94 | })
95 | processed_messages = preprocess_messages(messages)
96 | if not processed_messages:
97 | logger.error("No messages to process for transformation")
98 | return ""
99 | if markdown_mode:
100 | transformed_metric = client.do_completion(processed_messages)
101 | else:
102 | transformed_metric = client.do_completion(processed_messages, response_format={"type": "json_object"})
103 | return transformed_metric
104 | except Exception as e:
105 | logger.error(f"Error transforming metric for schema {schema_id}: {e}")
106 | return ""
107 |
108 | async def run_transformation_only(task_id: str, data_location_key: str, schemas: List[str], destination: str, raw_data: str) -> Dict[str, str]:
109 |     # TODO: write the transformed results to the destination once destination writes are supported here
110 | logger.info(f"Starting transformation-only process for task_id: {task_id}")
111 |
112 | redis: Redis = await get_redis_connection()
113 | model_details = await get_latest_model_details(redis)
114 | if not model_details:
115 | return {}
116 |
117 | try:
118 | model_instance = ModelFactory.create_model(
119 | model_type=model_details.provider_type,
120 | model_name=model_details.provider_model_name,
121 | api_key=model_details.api_key,
122 | additional_params=model_details.additional_params
123 | )
124 | markdown_mode = model_details.markdown_mode
125 | logger.info(f"Model instance created with type: {model_details.provider_type}, name: {model_details.provider_model_name}")
126 | except ValueError as e:
127 | logger.error(f"Model creation error: {e}")
128 | return {}
129 |
130 | destination_config = {
131 | "db_path": destination,
132 | "additional_params": {}
133 | }
134 | destination_instance = DestinationFactory.create_destination(DestinationType.SQLITE.value, destination_config)
135 |
136 | try:
137 | database_schema = destination_instance.get_table_structure(data_location_key)
138 | except AttributeError:
139 | logger.error("The destination does not support get_table_structure method")
140 | database_schema = {}
141 |
142 | transformed_metrics = {}
143 | for schema in schemas:
144 | try:
145 | logger.info(f"schema: {schema}, raw_data: {raw_data}, database_schema: {database_schema}")
146 | prompt = langsmith_client.pull_prompt(PromptType.TRANSFORMATION_ONLY.value)
147 | messages = prompt.invoke({
148 | "first_value": schema,
149 | "second_value": raw_data,
150 | "third_value": database_schema
151 | })
152 | processed_messages = preprocess_messages(messages)
153 | if not processed_messages:
154 | logger.error(f"No messages to process for transformation of schema: {schema}")
155 | continue
156 | if markdown_mode:
157 | result = model_instance.do_completion(processed_messages)
158 | else:
159 | logger.info(f"Transforming schema: {schema} with JSON response format")
160 | result = model_instance.do_completion(processed_messages, response_format={"type": "json_object"})
161 | transformed_metrics[schema] = result
162 | except Exception as e:
163 | logger.error(f"Error transforming metric for schema {schema}: {e}")
164 |
165 | logger.info(f"Transformed metrics for task_id {task_id}: {transformed_metrics}")
166 | return transformed_metrics
167 |
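168 | # Illustrative usage sketch (assumes model-details has already been written
169 | # to Redis by the pipeline service; metric and schema values are invented):
170 | #
171 | #     transformed = await run_transformation(
172 | #         metrics={"schema_1": "Revenue: $15 million"},
173 | #         schema={"Revenue": "Total revenue for the period"},
174 | #         source_type="pdf",
175 | #     )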
--------------------------------------------------------------------------------
/application/transformation/service/transformation_worker.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import logging
4 | from typing import Optional, Dict, Any, Union
5 | from redis.asyncio import Redis
6 | from redis.exceptions import RedisError
7 | from datetime import datetime
8 | from application.transformation.models.models import (
9 | TransformationRequestModel,
10 | TransformationResponseModel,
11 | SchemaResult,
12 | JobStatus,
13 | TransformationOnlyRequestModel
14 | )
15 | from application.transformation.service.transformation_handler import run_transformation, run_transformation_only
16 | from common.redis.redis_config import get_redis_connection
17 |
18 | logging.basicConfig(level=logging.INFO)
19 | logger = logging.getLogger(__name__)
20 |
21 | async def run_transformations() -> None:
22 | logger.info("Starting transformation worker")
23 | redis = await get_redis_connection()
24 | last_id_transformation = "0-0"
25 | last_id_transformation_only = "0-0"
26 |
27 | task_workloads = {}
28 |
29 | while True:
30 | try:
31 | result = await redis.xread(
32 | streams={
33 | "transformation-stream": last_id_transformation,
34 | "transformation-only-stream": last_id_transformation_only
35 | },
36 | count=1,
37 | block=0
38 | )
39 |
40 | for stream_name, messages in result:
41 | for message_id, message in messages:
42 | logger.info(f"Received message from stream {stream_name}: ID {message_id}")
43 |                 payload = message.get(b"payload")
44 |                 request = None  # lets the error handler below detect parse failures
45 |                 if payload:
45 | try:
46 | payload_dict = json.loads(payload.decode('utf-8'))
47 |
48 | if stream_name == b"transformation-stream":
49 | request = TransformationRequestModel(**payload_dict)
50 | else:
51 | request = TransformationOnlyRequestModel(**payload_dict)
52 |
53 | if request.task_id not in task_workloads:
54 | total_workloads = await get_total_workloads(redis, request.task_id)
55 | task_workloads[request.task_id] = {
56 | 'total': total_workloads,
57 | 'completed': 0
58 | }
59 |
60 | existing_results = await get_existing_results(redis, request.task_id)
61 | transformation_result = await process_transformation(request)
62 | merged_results = merge_results(existing_results, transformation_result)
63 |
64 | serialized_result = json.dumps(merged_results.dict())
65 | await redis.xadd(f"results-stream:{request.task_id}", {"payload": serialized_result})
66 |
67 | task_workloads[request.task_id]['completed'] += 1
68 |
69 | if task_workloads[request.task_id]['completed'] == task_workloads[request.task_id]['total']:
70 | logger.info(f"All workloads completed for task {request.task_id}")
71 | await update_job_status(redis, request.task_id, JobStatus.COMPLETED, serialized_result)
72 | del task_workloads[request.task_id] # Cleanup
73 | else:
74 | await update_job_status(redis, request.task_id, JobStatus.IN_PROGRESS, None)
75 |
76 |                         except Exception as e:
77 |                             logger.error(f"Error processing transformation task: {e}")
78 |                             # `request` stays None if the payload never parsed,
79 |                             # in which case there is no task to mark as failed.
80 |                             if request is not None:
81 |                                 await update_job_status(redis, request.task_id, JobStatus.FAILED, str(e))
82 |                                 task_workloads.pop(request.task_id, None)
81 |
82 | if stream_name == b"transformation-stream":
83 | last_id_transformation = message_id
84 | elif stream_name == b"transformation-only-stream":
85 | last_id_transformation_only = message_id
86 |
87 | except RedisError as e:
88 | logger.error(f"Error reading from Redis stream: {e}")
89 | await asyncio.sleep(1)
90 |
91 | async def process_transformation(
92 | transformation_request: Union[TransformationRequestModel, TransformationOnlyRequestModel]
93 | ) -> TransformationResponseModel:
94 | redis = await get_redis_connection()
95 | try:
96 | transformed_results = []
97 |
98 | if isinstance(transformation_request, TransformationOnlyRequestModel):
99 | transformed_metrics = await run_transformation_only(
100 | task_id=transformation_request.task_id,
101 | data_location_key=transformation_request.data_location_key,
102 | schemas=transformation_request.schemas,
103 | destination=transformation_request.destination,
104 | raw_data=transformation_request.raw_data
105 | )
106 | if isinstance(transformed_metrics, str):
107 | transformed_metrics = json.loads(transformed_metrics)
108 |
109 | transformed_results.append(SchemaResult(
110 | schema_id=f"{transformation_request.task_id}-transformation_only",
111 | schema_data={},
112 | metrics=transformed_metrics
113 | ))
114 | else:
115 | for schema_result in transformation_request.results:
116 | transformed_metrics = await run_transformation(
117 | metrics=schema_result.metrics,
118 | schema=schema_result.schema_data,
119 | source_type=transformation_request.source_type
120 | )
121 | if isinstance(transformed_metrics, str):
122 | transformed_metrics = json.loads(transformed_metrics)
123 |
124 | transformed_results.append(SchemaResult(
125 | schema_id=schema_result.schema_id,
126 | schema_data=schema_result.schema_data,
127 | metrics=transformed_metrics
128 | ))
129 |
130 | response = TransformationResponseModel(
131 | task_id=transformation_request.task_id,
132 | pdf_key=transformation_request.data_location_key if isinstance(transformation_request, TransformationOnlyRequestModel) else transformation_request.pdf_key,
133 | results=transformed_results
134 | )
135 |
136 | return response
137 | except Exception as e:
138 | logger.error(f"Error processing transformation task {transformation_request.task_id}: {e}")
139 | await update_job_status(redis, transformation_request.task_id, JobStatus.FAILED, str(e))
140 | raise
141 |
142 | async def update_job_status(
143 | redis: Redis,
144 | task_id: str,
145 | status: JobStatus,
146 | result: Optional[str]
147 | ) -> None:
148 | fields: Dict[str, Any] = {"status": json.dumps(status.value)}
149 | if result:
150 | fields["result"] = result
151 |
152 |     start_time_str = await redis.get(f"job-start-time:{task_id}")
153 |     if start_time_str:
154 |         total_run_time = int(datetime.utcnow().timestamp()) - int(start_time_str)
155 |         run_time_str = f"{total_run_time // 60} minutes" if total_run_time >= 60 else f"{total_run_time} seconds"
156 |         fields["total_run_time"] = run_time_str
157 |     else:
158 |         # Without a recorded start time the elapsed time would be epoch-relative nonsense.
159 |         fields["total_run_time"] = "N/A"
158 |
159 | await redis.xadd(f"job-status:{task_id}", fields)
160 |
161 | async def get_existing_results(redis: Redis, task_id: str) -> Optional[TransformationResponseModel]:
162 | try:
163 | entries = await redis.xrevrange(f"results-stream:{task_id}", count=1)
164 | if entries:
165 | _, message = entries[0]
166 | payload = message.get(b"payload")
167 | if payload:
168 | return TransformationResponseModel(**json.loads(payload.decode('utf-8')))
169 | except Exception as e:
170 | logger.error(f"Error getting existing results: {e}")
171 | return None
172 |
173 | def merge_results(existing: Optional[TransformationResponseModel], new: TransformationResponseModel) -> TransformationResponseModel:
174 | if not existing:
175 | return new
176 |
177 | return TransformationResponseModel(
178 | task_id=new.task_id,
179 | pdf_key=new.pdf_key,
180 | results=existing.results + new.results
181 | )
182 |
183 | async def get_total_workloads(redis: Redis, task_id: str) -> int:
184 | """Get the total number of workloads for a task from Redis."""
185 | try:
186 | workload_count = await redis.get(f"workload-count:{task_id}")
187 | return int(workload_count) if workload_count else 1
188 | except Exception as e:
189 | logger.error(f"Error getting workload count: {e}")
190 | return 1
191 |
192 | async def clear_transformation_streams() -> None:
193 | retries = 0
194 | max_retries = 5
195 | delay = 1
196 |
197 | while retries < max_retries:
198 | try:
199 | redis = await get_redis_connection()
200 | await redis.xtrim("transformation-stream", approximate=True, maxlen=0)
201 | await redis.xtrim("transformation-only-stream", approximate=True, maxlen=0)
202 | logger.info("Cleared transformation-stream and transformation-only-stream")
203 | return
204 | except RedisError as e:
205 | if "LOADING" in str(e):
206 | logger.info(f"Redis is still loading. Retrying in {delay} seconds...")
207 | await asyncio.sleep(delay)
208 | retries += 1
209 | delay *= 2
210 | else:
211 | raise e
212 |
213 | raise Exception("Failed to clear transformation streams after maximum retries")
214 |
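215 | # Redis key contract assumed by this worker (the keys are written by the
216 | # pipeline service; shown here as an illustrative sketch):
217 | #
218 | #     await redis.set(f"workload-count:{task_id}", n_workloads)   # read by get_total_workloads
219 | #     await redis.set(f"job-start-time:{task_id}", start_epoch)   # read by update_job_status
220 | #
221 | # If either key is missing, the worker falls back to a count of 1 and an
222 | # "N/A" run time rather than failing.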
--------------------------------------------------------------------------------
/application/transformation/start_transformation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from dotenv import load_dotenv
4 | import asyncio
5 | from application.transformation.service.transformation_worker import clear_transformation_streams, run_transformations
6 |
7 | load_dotenv()
8 |
9 | logging.basicConfig(level=logging.INFO,
10 | format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | async def main():
13 | try:
14 | await clear_transformation_streams()
15 | except Exception as e:
16 | logging.error("Failed to clear transformation streams: %s", e)
17 |
18 | try:
19 | logging.info("Started transformation service...")
20 | await run_transformations()
21 | except Exception as e:
22 | logging.error("Application error: %s", e)
23 | exit(1)
24 |
25 | if __name__ == "__main__":
26 | asyncio.run(main())
27 |
--------------------------------------------------------------------------------
/common/agents/agent_prompt_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class AgentMode(Enum):
4 | EXTRACTION = "extraction"
5 | PAGE_FINDER = "page_finder"
6 |
7 | class ExtractionPrompts(Enum):
8 | SYSTEM = """You are a precise data extraction assistant focused on consolidating and deduplicating information. Your task is to:
9 | 1. Extract and consolidate metrics from all sources
10 | 2. Remove duplicate entries and redundant information
11 | 3. Preserve EXACT units and symbols as they appear
12 |
13 | Guidelines:
14 | - Consolidate repeated information into single, definitive entries
15 | - When conflicting values exist, prefer confirmed/definitive information over rumors
16 | - Remove duplicate entries that refer to the same event/metric
17 | - Keep ALL units and symbols EXACTLY as they appear (e.g., keep "$15 million" as is)
18 | - Never perform unit conversions or reformatting
19 | - Mark as "Not Available" only when no reliable value exists
20 | - For multiple genuine distinct instances (e.g., different funding rounds), maintain separate entries
21 |
22 | Remember: Focus on providing clean, deduplicated data while maintaining accuracy."""
23 |
24 | ANALYSIS = """You are a precise analysis expert focused on deduplication and consolidation. Your task is to:
25 | 1. Find duplicate or redundant information
26 | 2. Identify consolidation opportunities
27 | 3. List improvements already made
28 |
29 | Format your response using ONLY these markers:
30 | - For duplicates/conflicts: Start lines with "⚠ Duplicate:" or "⚠ Conflict:"
31 | - For improvements/consolidations: Start lines with "✓ Consolidated:"
32 |
33 | Example format:
34 | ⚠ Duplicate: Found duplicate funding amount "$50M" in both Series A and B sections
35 | ⚠ Conflict: Different dates for same event - "March 2021" vs "03/2021"
36 | ✓ Consolidated: Combined Series A details into single entry
37 |
38 | Guidelines:
39 | - Each issue must start with "⚠"
40 | - Each improvement must start with "✓"
41 | - Be specific about what needs to be consolidated
42 | - Identify exact duplicates or conflicts
43 | - Keep each line focused on one issue/improvement
44 |
45 | Remember: Focus on finding information that can be combined or deduplicated."""
46 |
47 | FIX = """You are a precise consolidation and deduplication expert. Your task is to:
48 | 1. Review the current extraction result
49 | 2. Apply consolidation and deduplication fixes
50 | 3. Verify the accuracy of merged information
51 |
52 | Guidelines:
53 | - Merge duplicate entries into single, authoritative records
54 | - When consolidating conflicting values:
55 | * Prefer confirmed facts over rumors
56 | * Use the most specific/detailed version
57 | * Maintain original formatting and units
58 | - Keep entries separate only if they are genuinely distinct events
59 | - Document your consolidation decisions
60 | - Verify that no information is lost during consolidation
61 |
62 | Remember: Each consolidation must be verified and justified."""
63 |
64 | CONFIDENCE = """Rate the extraction confidence on a scale of 0.0 to 1.0.
65 | Score based on:
66 | - Consolidation: Information is properly deduplicated
67 | - Accuracy: Confirmed facts are prioritized
68 | - Unit Preservation: Original units and symbols are maintained
69 | - Format Alignment: Output matches expected format
70 |
71 | Lower score if:
72 | - Duplicate entries remain
73 | - Rumors are mixed with confirmed facts
74 | - Units or symbols were modified
75 | - Information is not properly consolidated
76 |
77 | Respond with ONLY a number between 0.0 and 1.0."""
78 |
79 | SYNTHESIS = """You are a synthesis expert focused on consolidation and deduplication. Your task is to:
80 | 1. Review all previous extraction attempts
81 | 2. Create a final answer that:
82 | - Consolidates duplicate information into single entries
83 | - Removes redundant entries referring to same events
84 | - Prioritizes confirmed facts over speculative information
85 | - Preserves ALL original units and symbols exactly as they appear
86 | - Never converts or reformats units
87 | - Maintains separate entries only for genuinely distinct instances
88 | - Marks values as "Not Available" only when no reliable data exists
89 |
90 | Start your response with 'FINAL ANSWER:' and ensure each entry represents unique, consolidated information."""
91 |
92 | VERIFICATION = """You are a verification expert focused on consolidation quality. Your task is to:
93 | 1. Compare before and after states of fixes
94 | 2. Verify that consolidation was successful by checking:
95 | - All duplicate entries were properly merged
96 | - No information was lost during consolidation
97 | - Confirmed facts were prioritized over rumors
98 | - Original formatting and units were preserved
99 | - Only genuinely distinct entries remain separate
100 |
101 | Output as JSON:
102 | {
103 | "consolidation_checks": [
104 | {
105 | "aspect": "description",
106 | "success": boolean,
107 | "details": "explanation"
108 | }
109 | ],
110 | "remaining_duplicates": [
111 | {
112 | "location": "where",
113 | "description": "what duplicates remain"
114 | }
115 | ],
116 | "verification_score": float
117 | }"""
118 |
119 | class PageFinderPrompts(Enum):
120 | SYSTEM = """You are a precise page relevance analyzer. Your task is to:
121 | 1. Analyze the given page content
122 | 2. Find pages containing relevant metric information
123 | 3. Provide clear evidence for your decisions
124 |
125 | Guidelines:
126 | - Look for clear mentions of metric names or values
127 | - Consider both direct mentions and strong contextual evidence
128 | - Evaluate relevance based on metric content quality
129 | - Provide specific text evidence for your decisions
130 | - When confident, start with 'FINAL ANSWER:'
131 |
132 | Focus on finding meaningful metric information."""
133 |
134 | REFLECTION = """Analyze the page relevance assessment with these points:
135 | - Metric Presence: Is the metric information clear and meaningful?
136 | - Supporting Evidence: What context shows metric relevance?
137 | - Contextual Clarity: Is the information reliable and well-supported?
138 | - Validation: Can the metric information be verified?
139 |
140 | Keep reflection focused on information quality (1-2 sentences)."""
141 |
142 | CONFIDENCE = """Rate the page relevance confidence on a scale of 0.0 to 1.0.
143 | Score based on:
144 | - Information Quality: Clear and reliable metric data
145 | - Context Support: Strong evidence in surrounding text
146 | - Relevance: Direct connection to requested metrics
147 |
148 | Lower score if:
149 | - Information is unclear or ambiguous
150 | - Context is weak or misleading
151 | - Metric connection is tenuous
152 |
153 | Respond with ONLY a number between 0.0 and 1.0."""
154 |
155 | SYNTHESIS = """You are a synthesis expert balancing precision with completeness. Your task is to:
156 | 1. Review all previous page assessments
157 | 2. Identify pages with valuable metric information
158 | 3. Create a final answer that:
159 | - Lists pages with strong metric evidence
160 | - Provides supporting context for relevance
161 | - Notes quality of metric information
162 | - Considers contextual relationships
163 | - Maintains high information standards
164 |
165 | Start your response with 'FINAL ANSWER:' and provide clear evidence for included pages."""
--------------------------------------------------------------------------------
/common/agents/prs_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Sequence, TypedDict, Literal
2 | import operator
3 | from dotenv import load_dotenv
4 | from langchain_core.messages import BaseMessage, AIMessage, HumanMessage
5 | from langgraph.graph import StateGraph, END
6 | import functools
7 | import logging
8 | import redis
9 | import json
10 | import uuid
11 | from .agent_prompt_enums import AgentMode, ExtractionPrompts, PageFinderPrompts
12 |
13 | load_dotenv()
14 |
15 | redis_client = redis.Redis(host='redis', port=6379, db=0)
16 | REDIS_EXPIRE = 60 * 60
17 |
18 | class AgentState(TypedDict):
19 | messages: Annotated[Sequence[BaseMessage], operator.add]
20 | sender: str
21 | confidence_score: float
22 | session_id: str
23 | iterations: int
24 |
25 | def get_redis_key(session_id: str, key_type: str) -> str:
26 | """Generate Redis key for different state types."""
27 | return f"prs:{session_id}:{key_type}"
28 |
29 | def store_list(session_id: str, key_type: str, items: list) -> None:
30 | """Store a list in Redis with expiry."""
31 | key = get_redis_key(session_id, key_type)
32 | if items:
33 | redis_client.delete(key)
34 | redis_client.rpush(key, *[json.dumps(item) for item in items])
35 | redis_client.expire(key, REDIS_EXPIRE)
36 |
37 | def get_list(session_id: str, key_type: str, start: int = 0, end: int = -1) -> list:
38 | """Get a list from Redis with optional range."""
39 | key = get_redis_key(session_id, key_type)
40 | items = redis_client.lrange(key, start, end)
41 | return [json.loads(item) for item in items] if items else []
42 |
43 | def get_prompts(mode: AgentMode):
44 | """Get the appropriate prompts for the specified mode."""
45 | if mode == AgentMode.EXTRACTION:
46 | return ExtractionPrompts
47 | return PageFinderPrompts
48 |
49 | def create_agent(client, system_message: str):
50 | """Create an agent with a specific system message."""
51 | def agent_fn(inputs):
52 | improvements_made = "\n".join(inputs.get("improvements", []))
53 | pending_fixes = "\n".join(inputs.get("pending_fixes", []))
54 |
55 | messages = [
56 | {"role": "system", "content": system_message},
57 | {"role": "system", "content": f"""Previous reflections identified these issues to fix:
58 | {pending_fixes}
59 |
60 | Improvements already made:
61 | {improvements_made}
62 |
63 | Focus on addressing the pending issues while maintaining previous improvements."""},
64 | {"role": "user", "content": inputs["messages"][-1].content}
65 | ]
66 | return client.do_completion(messages)
67 | return agent_fn
68 |
69 | def store_message(session_id: str, content: str, role: str = "ai") -> None:
70 | """Store a message in Redis."""
71 | key = get_redis_key(session_id, "messages")
72 | redis_client.rpush(key, json.dumps({"role": role, "content": content}))
73 | redis_client.expire(key, REDIS_EXPIRE)
74 |
75 | def get_last_message(session_id: str) -> str:
76 | """Get the last message content from Redis."""
77 | key = get_redis_key(session_id, "messages")
78 | last_msg = redis_client.lindex(key, -1)
79 | if last_msg:
80 | return json.loads(last_msg)["content"]
81 | return ""
82 |
83 | def get_all_messages(session_id: str) -> list:
84 | """Get all messages from Redis."""
85 | key = get_redis_key(session_id, "messages")
86 | messages = redis_client.lrange(key, 0, -1)
87 | return [json.loads(msg) for msg in messages] if messages else []
88 |
89 | def agent_node(state, agent, name):
90 | """Process the agent's response and update the state."""
91 | session_id = state["session_id"]
92 | messages = state["messages"]
93 |
94 | improvements = get_list(session_id, "improvements", -5)
95 | pending_fixes = get_list(session_id, "pending_fixes")
96 |
97 | result = agent({
98 | "messages": messages,
99 | "improvements": improvements,
100 | "pending_fixes": pending_fixes
101 | })
102 |
103 | store_message(session_id, result)
104 |
105 | return {
106 | "messages": messages + [AIMessage(content=result)],
107 | "sender": name,
108 | "confidence_score": state["confidence_score"],
109 | "session_id": session_id,
110 | "iterations": state["iterations"] + 1
111 | }
112 |
113 | def analyze_node(state, agent, name, prompts):
114 | """Analyze output and identify consolidation opportunities."""
115 | logger = logging.getLogger(__name__)
116 | session_id = state["session_id"]
117 |
118 | logger.info(f"\nStarting analysis iteration {state['iterations'] + 1}")
119 |
120 | last_message = get_last_message(session_id)
121 |
122 | prev_improvements = get_list(session_id, "improvements")
123 |
124 | analysis_messages = [
125 | {"role": "system", "content": prompts.ANALYSIS.value},
126 | {"role": "user", "content": f"""Current extraction to analyze:
127 | {last_message}
128 |
129 | Previous improvements made:
130 | {chr(10).join(prev_improvements)}
131 |
132 | Analyze this extraction focusing on duplicate entries and consolidation opportunities.
133 | Be thorough and identify ALL potential duplicates and conflicts."""}
134 | ]
135 |
136 | analysis = agent.do_completion(analysis_messages, temperature=0.2)
137 |
138 | verification_messages = [
139 | {"role": "system", "content": """Verify the previous analysis for:
140 | 1. Missed duplicates or conflicts
141 | 2. False positives in identified issues
142 | 3. Completeness of consolidation opportunities
143 | 4. Accuracy of proposed improvements"""},
144 | {"role": "user", "content": f"""Previous analysis:
145 | {analysis}
146 |
147 | Original content:
148 | {last_message}
149 |
150 | Verify and identify any missed issues or inaccuracies."""}
151 | ]
152 |
153 | verification = agent.do_completion(verification_messages, temperature=0.1)
154 |
155 | current_improvements = get_list(session_id, "improvements")
156 | current_pending_fixes = get_list(session_id, "pending_fixes")
157 |
158 | new_improvements = []
159 | new_fixes = []
160 |
161 | for content in [analysis, verification]:
162 | for line in content.split('\n'):
163 | line = line.strip()
164 | if not line:
165 | continue
166 | if line.startswith('✓'):
167 | if line not in current_improvements and line not in new_improvements:
168 | new_improvements.append(line)
169 | elif line.startswith('⚠'):
170 | if line not in current_pending_fixes and line not in new_fixes:
171 | new_fixes.append(line)
172 |
173 | if new_improvements:
174 | store_list(session_id, "improvements", current_improvements + new_improvements)
175 | if new_fixes:
176 | store_list(session_id, "pending_fixes", new_fixes)
177 |
178 | reflection = f"""Analysis {redis_client.llen(get_redis_key(session_id, 'reflections')) + 1}:
179 |
180 | Initial Analysis:
181 | {analysis}
182 |
183 | Verification:
184 | {verification}
185 |
186 | New Issues: {len(new_fixes)}
187 | New Improvements: {len(new_improvements)}"""
188 |
189 | redis_client.rpush(get_redis_key(session_id, "reflections"), json.dumps(reflection))
190 |
191 | logger.info(f"Analysis complete: {len(new_fixes)} new issues, {len(new_improvements)} improvements")
192 |
193 | return {
194 | "messages": state["messages"],
195 | "sender": name,
196 | "confidence_score": state["confidence_score"],
197 | "session_id": session_id,
198 | "iterations": state["iterations"]
199 | }
200 |
201 | def confidence_node(state, agent, name, prompts):
202 | """Score the confidence of the current analysis."""
203 | session_id = state["session_id"]
204 | messages = state["messages"]
205 | last_message = messages[-1].content if messages else ""
206 |
207 | confidence_messages = [
208 | {"role": "system", "content": prompts.CONFIDENCE.value},
209 | {"role": "user", "content": f"Analysis to score:\n{last_message}"}
210 | ]
211 |
212 | score = agent.do_completion(confidence_messages, temperature=0.0)
213 |
214 | try:
215 | confidence = float(score.strip())
216 |     except (TypeError, ValueError):
217 |         confidence = 0.5  # neutral fallback when the model's reply isn't a number
218 |
219 | return {
220 | "messages": messages,
221 | "sender": name,
222 | "confidence_score": confidence,
223 | "session_id": session_id,
224 | "iterations": state["iterations"]
225 | }
226 |
227 | def fix_node(state, agent, name, prompts):
228 | """Fix identified issues focusing on deduplication."""
229 | logger = logging.getLogger(__name__)
230 | session_id = state["session_id"]
231 | last_message = get_last_message(session_id)
232 |
233 | pending_fixes = get_list(session_id, "pending_fixes")
234 | if pending_fixes:
235 | logger.info(f"\nAttempting to fix {len(pending_fixes)} issues:")
236 | for fix in pending_fixes:
237 | logger.info(f"⚠ {fix}")
238 |
239 | fix_messages = [
240 | {"role": "system", "content": prompts.FIX.value},
241 | {"role": "user", "content": f"""Content to fix:
242 | {last_message}
243 |
244 | Issues to address:
245 | {chr(10).join(pending_fixes)}
246 |
247 | Apply fixes systematically and verify each change."""}
248 | ]
249 |
250 | fixed_result = agent.do_completion(fix_messages, temperature=0.2)
251 |
252 | verify_messages = [
253 | {"role": "system", "content": """Verify that all fixes were properly applied:
254 | 1. Check each issue was addressed
255 | 2. Verify no information was lost
256 | 3. Confirm all consolidations are accurate
257 | 4. Ensure no new duplicates were created"""},
258 | {"role": "user", "content": f"""Original content:
259 | {last_message}
260 |
261 | Applied fixes:
262 | {fixed_result}
263 |
264 | Original issues:
265 | {chr(10).join(pending_fixes)}
266 |
267 | Verify all fixes were properly applied."""}
268 | ]
269 |
270 | verification = agent.do_completion(verify_messages, temperature=0.1)
271 |
272 | store_message(session_id, fixed_result)
273 | redis_client.rpush(get_redis_key(session_id, "fix_verifications"),
274 | json.dumps({"fixes": pending_fixes, "verification": verification}))
275 |
276 | redis_client.delete(get_redis_key(session_id, "pending_fixes"))
277 |
278 | logger.info("Fix attempt complete")
279 | logger.info(f"Verification result length: {len(verification.split())}")
280 |
281 | return {
282 | "messages": state["messages"] + [AIMessage(content=fixed_result)],
283 | "sender": name,
284 | "confidence_score": state["confidence_score"],
285 | "session_id": session_id,
286 | "iterations": state["iterations"]
287 | }
288 |
289 | def create_router(mode: AgentMode):
290 | """Create a router function based on the agent mode."""
291 | def router(state) -> Literal["process", "analyze", "fix", "score", "synthesize", "__end__"]:
292 | iterations = state["iterations"]
293 | confidence = state["confidence_score"]
294 |
295 | max_iterations = 2
296 | min_confidence = 0.8
297 |
298 | if iterations >= max_iterations or confidence >= min_confidence:
299 | if state["sender"] != "synthesizer":
300 | return "synthesize"
301 | return "__end__"
302 |
303 | if state["sender"] == "user":
304 | return "process"
305 | elif state["sender"] == "processor":
306 | return "analyze"
307 | elif state["sender"] == "analyzer":
308 | return "fix"
309 | elif state["sender"] == "fixer":
310 | return "score"
311 | elif state["sender"] == "scorer":
312 | if confidence < min_confidence and iterations < max_iterations:
313 | return "process"
314 | return "synthesize"
315 | else:
316 | return "process"
317 |
318 | return router
319 |
320 | def synthesize_node(state, agent, name, prompts):
321 | """Create final answer by synthesizing all iterations and improvements."""
322 | session_id = state["session_id"]
323 |
324 | messages = get_all_messages(session_id)
325 | improvements = get_list(session_id, "improvements", -5)
326 |
327 | final_result = agent.do_completion([
328 | {"role": "system", "content": prompts.SYNTHESIS.value},
329 | {"role": "user", "content": f"""Best response so far:
330 | {messages[-1]['content'] if messages else ''}
331 |
332 | Key improvements:
333 | {chr(10).join(improvements)}"""}
334 | ], temperature=0.0)
335 |
336 | store_message(session_id, final_result)
337 |
338 | return {
339 | "sender": "synthesizer",
340 | "confidence_score": state["confidence_score"],
341 | "session_id": session_id,
342 | "iterations": state["iterations"]
343 | }
344 |
345 | def create_graph(client, mode: AgentMode):
346 | """Create the workflow graph with the specified mode."""
347 | prompts = get_prompts(mode)
348 |
349 | text_processor = create_agent(client, prompts.SYSTEM.value)
350 |
351 | processor_node = functools.partial(agent_node, agent=text_processor, name="processor")
352 | analyzer = functools.partial(analyze_node, agent=client, name="analyzer", prompts=prompts)
353 | issue_fixer = functools.partial(fix_node, agent=client, name="fixer", prompts=prompts)
354 | confidence_scorer = functools.partial(confidence_node, agent=client, name="scorer", prompts=prompts)
355 | synthesizer = functools.partial(synthesize_node, agent=client, name="synthesizer", prompts=prompts)
356 |
357 | workflow = StateGraph(AgentState)
358 |
359 | workflow.add_node("process", processor_node)
360 | workflow.add_node("analyze", analyzer)
361 | workflow.add_node("fix", issue_fixer)
362 | workflow.add_node("score", confidence_scorer)
363 | workflow.add_node("synthesize", synthesizer)
364 |
365 | router = create_router(mode)
366 |
367 | for node in ["process", "analyze", "fix", "score", "synthesize"]:
368 | workflow.add_conditional_edges(
369 | node,
370 | router,
371 | {
372 | "process": "process",
373 | "analyze": "analyze",
374 | "fix": "fix",
375 | "score": "score",
376 | "synthesize": "synthesize",
377 | "__end__": END,
378 | },
379 | )
380 |
381 | workflow.set_entry_point("process")
382 |
383 | return workflow.compile()
384 |
385 | def process_extraction(text: str, client, mode: AgentMode) -> str:
386 | """Process text through the agent workflow with extraction handler format."""
387 | logger = logging.getLogger(__name__)
388 | session_id = str(uuid.uuid4())
389 |
390 | try:
391 | store_message(session_id, text, "human")
392 |
393 | graph = create_graph(client, mode)
394 |
395 | result = graph.invoke({
396 | "messages": [HumanMessage(content=text)],
397 | "sender": "user",
398 | "confidence_score": 0.0,
399 | "session_id": session_id,
400 | "iterations": 0
401 | })
402 |
403 | logger.info("\n" + "="*50)
404 | logger.info(f"AGENT MODE: {mode.value}")
405 | logger.info("="*50)
406 |
407 | reflections = get_list(session_id, "reflections")
408 | logger.info("\nPROCESS REFLECTIONS:")
409 | for i, reflection in enumerate(reflections, 1):
410 | logger.info(f"\nITERATION {i}:")
411 | logger.info(reflection)
412 |
413 | improvements = get_list(session_id, "improvements")
414 | logger.info("\nIMPROVEMENTS MADE:")
415 | if improvements:
416 | for imp in improvements:
417 | logger.info(f"✓ {imp}")
418 | else:
419 | logger.info("(No improvements needed)")
420 |
421 | fix_verifications = get_list(session_id, "fix_verifications")
422 | logger.info("\nFIX VERIFICATIONS:")
423 | for i, verification in enumerate(fix_verifications, 1):
424 | logger.info(f"\nFix Attempt {i}:")
425 | logger.info(f"Issues Addressed: {len(verification.get('fixes', []))}")
426 | logger.info(f"Verification Result: {verification.get('verification', '')}")
427 |
428 | pending_fixes = get_list(session_id, "pending_fixes")
429 | logger.info("\nFINAL STATE:")
430 | logger.info(f"Total Iterations: {result['iterations']}")
431 | logger.info(f"Final Confidence Score: {result['confidence_score']:.2f}")
432 | logger.info("\nRemaining Issues:")
433 | if pending_fixes:
434 | for issue in pending_fixes:
435 | logger.info(f"⚠ {issue}")
436 | else:
437 | logger.info("(No remaining issues)")
438 |
439 | logger.info("\nFINAL RESULT:")
440 | final_message = result["messages"][-1].content if result["messages"] else ""
441 | logger.info(final_message)
442 | logger.info("="*50 + "\n")
443 |
444 | for key_type in ["messages", "improvements", "pending_fixes", "reflections", "fix_verifications"]:
445 | redis_client.delete(get_redis_key(session_id, key_type))
446 |
447 | return final_message
448 |
449 | except Exception as e:
450 | logger.error(f"Extraction failed: {str(e)}")
451 | try:
452 | reflections = get_list(session_id, "reflections")
453 | if reflections:
454 | logger.error("\nLast known state before error:")
455 | logger.error(reflections[-1])
456 |         except Exception:
457 | pass
458 |
459 | for key_type in ["messages", "improvements", "pending_fixes", "reflections", "fix_verifications"]:
460 | redis_client.delete(get_redis_key(session_id, key_type))
461 | return "Extraction failed. Please try again."
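462 |
463 | # Illustrative usage sketch (assumes `client` exposes the do_completion
464 | # interface the nodes above rely on, e.g. a model built via
465 | # common.models.model_factory.ModelFactory; the provider values are invented):
466 | #
467 | #     client = ModelFactory.create_model(model_type="openai",
468 | #                                        model_name="gpt-4o",
469 | #                                        api_key="...",
470 | #                                        additional_params={})
471 | #     final_answer = process_extraction(page_text, client, AgentMode.EXTRACTION)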
--------------------------------------------------------------------------------
/common/api/3.1.0-marly-spec.yml:
--------------------------------------------------------------------------------
1 | openapi: 3.1.0
2 | info:
3 | title: Marly API
4 | description: The Data Processor for Agents
5 | version: 1.0.0
6 |
7 | paths:
8 | /pipelines:
9 | post:
10 | summary: Run pipeline
11 | description: Initiates a pipeline processing job for the given PDF and schemas.
12 | operationId: runPipeline
13 | requestBody:
14 | required: true
15 | content:
16 | application/json:
17 | schema:
18 | $ref: "#/components/schemas/PipelineRequestModel"
19 | responses:
20 | "202":
21 | description: Pipeline processing started
22 | content:
23 | application/json:
24 | schema:
25 | $ref: "#/components/schemas/PipelineResponseModel"
26 | "500":
27 | $ref: "#/components/responses/InternalServerError"
28 |
29 | /pipelines/{task_id}:
30 | get:
31 | summary: Get pipeline results
32 | description: Retrieves the results of a pipeline processing job.
33 | operationId: getPipelineResults
34 | parameters:
35 | - name: task_id
36 | in: path
37 | required: true
38 | schema:
39 | type: string
40 | description: Unique identifier for the pipeline task
41 | responses:
42 | "200":
43 | description: Pipeline results retrieved successfully
44 | content:
45 | application/json:
46 | schema:
47 | $ref: "#/components/schemas/PipelineResult"
48 | "404":
49 | description: Task not found
50 | "500":
51 | $ref: "#/components/responses/InternalServerError"
52 |
53 | components:
54 | schemas:
55 | WorkloadItem:
56 | type: object
57 | properties:
58 | raw_data:
59 | type: string
60 |           description: String form of the raw data (e.g. PDF, HTML, or plain text)
61 | schemas:
62 | type: array
63 | items:
64 | type: string
65 | description: List of schema strings
66 | data_source:
67 | type: string
68 | description: Type of data source
69 | documents_location:
70 | type: string
71 | description: Location of documents
72 | file_name:
73 | type: string
74 | description: Name of the file
75 | additional_params:
76 | type: object
77 | additionalProperties: true
78 | description: Additional parameters for the workload
79 | destination:
80 | type: string
81 | description: Destination for the processed data
82 | required:
83 | - schemas
84 |
85 | PipelineRequestModel:
86 | type: object
87 | properties:
88 | workloads:
89 | type: array
90 | items:
91 | $ref: "#/components/schemas/WorkloadItem"
92 | provider_type:
93 | type: string
94 | provider_model_name:
95 | type: string
96 | api_key:
97 | type: string
98 | markdown_mode:
99 | type: boolean
100 | default: false
101 | additional_params:
102 | type: object
103 | additionalProperties: true
104 | required:
105 | - workloads
106 | - provider_type
107 | - provider_model_name
108 | - api_key
109 |
110 | PipelineResponseModel:
111 | type: object
112 | properties:
113 | task_id:
114 | type: string
115 | description: Unique identifier for the pipeline task
116 | message:
117 | type: string
118 | description: Status message
119 | required:
120 | - task_id
121 | - message
122 |
123 | PipelineResult:
124 | type: object
125 | properties:
126 | task_id:
127 | type: string
128 | description: Unique identifier for the pipeline task
129 | status:
130 | $ref: "#/components/schemas/JobStatus"
131 | results:
132 | type: array
133 | items:
134 | $ref: "#/components/schemas/SchemaResult"
135 | description: Array of schema results
136 | total_run_time:
137 | type: string
138 | description: Total execution time of the pipeline
139 | required:
140 | - task_id
141 | - status
142 | - results
143 | - total_run_time
144 |
145 | SchemaResult:
146 | type: object
147 | properties:
148 | schema_id:
149 | type: string
150 | description: Identifier for the schema used
151 | metrics:
152 | type: object
153 | additionalProperties:
154 | type: string
155 | description: Metrics related to the schema extraction
156 | schema_data:
157 | type: object
158 | additionalProperties:
159 | type: string
160 | description: Extracted data based on the schema
161 | required:
162 | - schema_id
163 | - metrics
164 | - schema_data
165 |
166 | JobStatus:
167 | type: string
168 | enum:
169 | - PENDING
170 | - IN_PROGRESS
171 | - COMPLETED
172 | - FAILED
173 | description: Current status of the pipeline job
174 |
175 | responses:
176 | InternalServerError:
177 | description: Internal Server Error
178 | content:
179 | application/json:
180 | schema:
181 | type: object
182 | properties:
183 | detail:
184 | type: string
185 | description: Error details
186 |
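187 | # Illustrative requests (not part of the spec; host and port assumed from the
188 | # local docker-compose setup, where the pipeline API listens on 8100):
189 | #
190 | #   curl -X POST http://localhost:8100/pipelines \
191 | #     -H "Content-Type: application/json" \
192 | #     -d '{"workloads": [{"schemas": ["{\"Revenue\": \"Total revenue\"}"],
193 | #                        "documents_location": "s3://bucket/report.pdf",
194 | #                        "data_source": "s3"}],
195 | #          "provider_type": "openai",
196 | #          "provider_model_name": "gpt-4o",
197 | #          "api_key": "sk-..."}'
198 | #
199 | #   curl http://localhost:8100/pipelines/<task_id>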
--------------------------------------------------------------------------------
/common/destinations/base/base_destination.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Any, List, Optional
3 |
4 | class BaseDestination(ABC):
5 | @abstractmethod
6 | def connect(self) -> None:
7 | """Establish a connection to the destination."""
8 | pass
9 |
10 | @abstractmethod
11 |     def insert(self, table_name: str, data: List[Dict[str, Any]]) -> None:
12 | """Insert data into the destination."""
13 | pass
14 |
15 | @abstractmethod
16 | def close(self) -> None:
17 | """Close the connection to the destination."""
18 | pass
19 |
20 | @abstractmethod
21 | def get_table_structure(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]:
22 | """Get the structure of a table."""
23 | pass
24 |
--------------------------------------------------------------------------------
/common/destinations/destination_factory.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 | from common.destinations.enums.destination_enums import DestinationType
3 | from common.destinations.sqlite_destination import SQLiteDestination
4 |
5 | class DestinationFactory:
6 | @staticmethod
7 | def create_destination(destination_type: str, config: Dict[str, Any] = None):
8 | if not destination_type:
9 | raise ValueError("destination_type must be provided.")
10 |
11 | try:
12 | destination_type_enum = DestinationType(destination_type.lower())
13 | except ValueError:
14 | raise ValueError(f"Invalid destination type. Allowed values are: {', '.join([d.value for d in DestinationType])}")
15 |
16 | if destination_type_enum == DestinationType.SQLITE:
17 | return SQLiteDestination(
18 | db_path=config.get("db_path", ":memory:"),
19 | additional_params=config.get("additional_params", {})
20 | )
21 | else:
22 | raise ValueError(f"Unsupported destination type: {destination_type_enum}")
23 |
24 |
--------------------------------------------------------------------------------
/common/destinations/enums/destination_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class DestinationType(Enum):
4 | SQLITE = "sqlite"
5 |
6 |
--------------------------------------------------------------------------------
/common/destinations/sqlite_destination.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | import os
3 | from typing import Dict, Any, List, Optional
4 | from common.destinations.base.base_destination import BaseDestination
5 |
6 | class SQLiteDestination(BaseDestination):
7 | def __init__(self, db_path: str, additional_params: Dict[str, Any] = None):
8 | self.db_path = db_path
9 | self.additional_params = additional_params or {}
10 | self.schema = self.additional_params.get("schema", "main")
11 | self.conn = None
12 | self.cursor = None
13 |
14 | def connect(self) -> None:
15 | if not self.conn:
16 | try:
17 | self.conn = sqlite3.connect(self.db_path)
18 | self.cursor = self.conn.cursor()
19 | except sqlite3.OperationalError:
20 |                 # If the directory doesn't exist, create it ("or '.'" handles bare filenames with no directory part)
21 |                 os.makedirs(os.path.dirname(self.db_path) or ".", exist_ok=True)
22 | self.conn = sqlite3.connect(self.db_path)
23 | self.cursor = self.conn.cursor()
24 | print(f"Created new SQLite database at {self.db_path}")
25 |
26 | def insert(self, table_name: str, data: List[Dict[str, Any]]) -> None:
27 | self.connect()
28 | if not data:
29 | raise ValueError("No data provided for insertion")
30 |
31 | if not self.table_exists(table_name):
32 | self.create_table(table_name, data[0].keys())
33 |
34 | columns = list(data[0].keys())
35 | placeholders = ', '.join(['?' for _ in columns])
36 | insert_query = f"INSERT INTO {table_name} ({', '.join(columns)}) VALUES ({placeholders})"
37 |
38 | try:
39 | for row in data:
40 | values = [row.get(col, '') for col in columns]
41 | self.cursor.execute(insert_query, values)
42 | self.conn.commit()
43 | except sqlite3.Error as e:
44 | self.conn.rollback()
45 | raise ValueError(f"Error inserting data: {str(e)}")
46 |
47 | def table_exists(self, table_name: str) -> bool:
48 | self.connect()
49 |         self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,))
50 | return self.cursor.fetchone() is not None
51 |
52 | def create_table(self, table_name: str, columns: List[str]) -> None:
53 | column_defs = ', '.join([f"{col} TEXT" for col in columns])
54 | create_query = f"CREATE TABLE IF NOT EXISTS {table_name} ({column_defs})"
55 | self.cursor.execute(create_query)
56 | self.conn.commit()
57 |
58 | def close(self) -> None:
59 | if self.cursor:
60 | self.cursor.close()
61 | if self.conn:
62 | self.conn.close()
63 | self.conn = None
64 | self.cursor = None
65 |
66 | def get_table_structure(self, table_name: str, schema: Optional[str] = None) -> List[Dict[str, Any]]:
67 | self.connect()
68 | schema = schema or self.schema
69 | query = f"PRAGMA {schema}.table_info({table_name})"
70 | self.cursor.execute(query)
71 | columns = self.cursor.fetchall()
72 |
73 | structure = []
74 | for column in columns:
75 | structure.append({
76 | "name": column[1],
77 | "type": column[2],
78 | "notnull": bool(column[3]),
79 | "default_value": column[4],
80 | "pk": bool(column[5])
81 | })
82 |
83 | return structure
84 |
--------------------------------------------------------------------------------
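A quick usage sketch for the destinations package above; the db_path, table name, and row contents are illustrative:

    from common.destinations.destination_factory import DestinationFactory

    dest = DestinationFactory.create_destination("sqlite", {"db_path": "example.db"})
    dest.connect()
    # insert() creates the table on first use; every column is typed TEXT
    dest.insert("contacts", [{"id": "1", "first_name": "Ada"}])
    print(dest.get_table_structure("contacts"))
    dest.close()
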
/common/models/azure_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union, Any
2 | from common.models.base.base_model import BaseModel
3 | from openai import AzureOpenAI
4 | from common.models.enums.model_enums import AzureModelName
5 | from langsmith import traceable
6 |
7 | class AzureModel(BaseModel):
8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None):
9 | if not api_key:
10 | raise ValueError("API key must be provided.")
11 | self.api_key = api_key
12 |
13 | if not additional_params:
14 | additional_params = {}
15 |
16 | self.api_version = additional_params.get("api_version")
17 | if not self.api_version:
18 | raise ValueError("API version must be provided in additional_params.")
19 |
20 | self.azure_endpoint = additional_params.get("azure_endpoint")
21 | if not self.azure_endpoint:
22 | raise ValueError("Azure endpoint must be provided in additional_params.")
23 |
24 | self.azure_deployment = additional_params.get("azure_deployment")
25 | if not self.azure_deployment:
26 | raise ValueError("Azure deployment must be provided in additional_params.")
27 |
28 | self.model_name = self.validate_model_name(model_name)
29 |
30 | self.client = AzureOpenAI(
31 | api_key=self.api_key,
32 | api_version=self.api_version,
33 | azure_endpoint=self.azure_endpoint,
34 | azure_deployment=self.azure_deployment
35 | )
36 |
37 | @staticmethod
38 | def validate_model_name(model_name: str) -> str:
39 | try:
40 | return AzureModelName(model_name).value
41 | except ValueError:
42 | raise ValueError(f"Invalid model name. Allowed values are: {', '.join([m.value for m in AzureModelName])}")
43 |
44 | @traceable(run_type="llm")
45 | def do_completion(self,
46 | messages: List[Dict[str, str]],
47 | model_name: Optional[str] = None,
48 | max_tokens: Optional[int] = None,
49 | temperature: Optional[float] = None,
50 | top_p: Optional[float] = None,
51 | n: Optional[int] = None,
52 | stop: Optional[Union[str, List[str]]] = None,
53 | response_format: Optional[Dict[str, str]] = None) -> str:
54 | if not messages:
55 | raise ValueError("'messages' must be provided.")
56 |
57 | params = {
58 | "model": self.validate_model_name(model_name) if model_name else self.model_name,
59 | "messages": messages,
60 | }
61 |
62 | if max_tokens is not None:
63 | params["max_tokens"] = max_tokens
64 | if temperature is not None:
65 | params["temperature"] = temperature
66 | if top_p is not None:
67 | params["top_p"] = top_p
68 | if n is not None:
69 | params["n"] = n
70 | if stop is not None:
71 | params["stop"] = stop
72 | if response_format is not None:
73 | params["response_format"] = response_format
74 |
75 | response = self.client.chat.completions.create(**params)
76 |
77 | return response.choices[0].message.content
78 |
--------------------------------------------------------------------------------
/common/models/base/base_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | 
3 | class BaseModel:
4 |     def do_completion(self, messages: List[Dict[str, str]], **kwargs) -> str:
5 |         raise NotImplementedError
--------------------------------------------------------------------------------
/common/models/cerebras_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union, Any
2 | from common.models.base.base_model import BaseModel
3 | from common.models.enums.model_enums import CerebrasModelName
4 | from cerebras.cloud.sdk import Cerebras
5 | from langsmith import traceable
6 |
7 | class CerebrasModel(BaseModel):
8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None):
9 | if not api_key:
10 | raise ValueError("API key must be provided.")
11 | self.api_key = api_key
12 |
13 | if not additional_params:
14 | additional_params = {}
15 |
16 | self.model_name = self.validate_model_name(model_name)
17 |
18 | self.client = Cerebras(api_key=self.api_key)
19 |
20 | @staticmethod
21 | def validate_model_name(model_name: str) -> str:
22 | try:
23 | return CerebrasModelName(model_name).value
24 | except ValueError:
25 | raise ValueError(f"Invalid model name. Allowed values are: {', '.join([m.value for m in CerebrasModelName])}")
26 |
27 | @traceable(run_type="llm")
28 | def do_completion(self,
29 | messages: List[Dict[str, str]],
30 | model_name: Optional[str] = None,
31 | max_tokens: Optional[int] = None,
32 | temperature: Optional[float] = None,
33 | top_p: Optional[float] = None,
34 | n: Optional[int] = None,
35 | stop: Optional[Union[str, List[str]]] = None,
36 | response_format: Optional[Dict[str, str]] = None) -> str:
37 | if not messages:
38 | raise ValueError("'messages' must be provided.")
39 |
40 | params = {
41 | "model": self.validate_model_name(model_name) if model_name else self.model_name,
42 | "messages": messages,
43 | }
44 |
45 | if max_tokens is not None:
46 | params["max_tokens"] = max_tokens
47 | if temperature is not None:
48 | params["temperature"] = temperature
49 | if top_p is not None:
50 | params["top_p"] = top_p
51 | if n is not None:
52 | params["n"] = n
53 | if stop is not None:
54 | params["stop"] = stop
55 | if response_format is not None:
56 | params["response_format"] = response_format
57 |
58 | response = self.client.chat.completions.create(**params)
59 |
60 | return response.choices[0].message.content
61 |
--------------------------------------------------------------------------------
/common/models/enums/model_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class ModelType(Enum):
4 | OPENAI = "openai"
5 | AZURE = "azure"
6 | GROQ = "groq"
7 | CEREBRAS = "cerebras"
8 | MISTRAL = "mistral"
9 |
10 | class OpenAIModelName(Enum):
11 | GPT_3_5_TURBO = "gpt-3.5-turbo"
12 | GPT_4 = "gpt-4"
13 | GPT_4O = "gpt-4o"
14 | CLAUDE_3_HAIKU = "claude-3-haiku"
15 |
16 | class AzureModelName(Enum):
17 | GPT_35_TURBO = "gpt-35-turbo"
18 | GPT_4 = "gpt-4"
19 | GPT_4_32K = "gpt-4-32k"
20 | GPT_4O = "gpt-4o"
21 |
22 | class GroqModelName(Enum):
23 | DISTIL_WHISPER_ENGLISH = "distil-whisper-large-v3-en"
24 | GEMMA2_9B = "gemma2-9b-it"
25 | GEMMA_7B = "gemma-7b-it"
26 | LLAMA3_GROQ_70B_TOOL_USE = "llama3-groq-70b-8192-tool-use-preview"
27 | LLAMA3_GROQ_8B_TOOL_USE = "llama3-groq-8b-8192-tool-use-preview"
28 | LLAMA_3_1_70B = "llama-3.1-70b-versatile"
29 | LLAMA_3_1_8B = "llama-3.1-8b-instant"
30 | LLAMA_GUARD_3_8B = "llama-guard-3-8b"
31 | LLAVA_1_5_7B = "llava-v1.5-7b-4096-preview"
32 | META_LLAMA3_70B = "llama3-70b-8192"
33 | META_LLAMA3_8B = "llama3-8b-8192"
34 | MIXTRAL_8X7B = "mixtral-8x7b-32768"
35 | WHISPER = "whisper-large-v3"
36 |
37 | class CerebrasModelName(Enum):
38 | LLAMA_3_1_70B = "llama3.1-70b"
39 | LLAMA_3_3_70B = "llama3.3-70b"
40 | LLAMA_3_1_8B = "llama3.1-8b"
41 |
42 | class MistralModelName(Enum):
43 | MISTRAL_LARGE_LATEST = "mistral-large-latest"
44 |
45 | class MistralAPIURL(Enum):
46 | CHAT_COMPLETIONS = "https://api.mistral.ai/v1/chat/completions"
47 |
--------------------------------------------------------------------------------
/common/models/groq_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union, Any
2 | from common.models.base.base_model import BaseModel
3 | from common.models.enums.model_enums import GroqModelName
4 | from groq import Groq
5 | from langsmith import traceable
6 |
7 | class GroqModel(BaseModel):
8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None):
9 | if not api_key:
10 | raise ValueError("API key must be provided.")
11 | self.api_key = api_key
12 |
13 | if not additional_params:
14 | additional_params = {}
15 |
16 | self.model_name = self.validate_model_name(model_name)
17 | self.client = Groq(api_key=self.api_key)
18 |
19 | @staticmethod
20 | def validate_model_name(model_name: str) -> str:
21 | try:
22 | return GroqModelName(model_name).value
23 | except ValueError:
24 | raise ValueError(f"Invalid model name. Allowed values are: {', '.join([m.value for m in GroqModelName])}")
25 |
26 | @traceable(run_type="llm")
27 | def do_completion(self,
28 | messages: List[Dict[str, str]],
29 | model_name: Optional[str] = None,
30 | max_tokens: Optional[int] = None,
31 | temperature: Optional[float] = None,
32 | top_p: Optional[float] = None,
33 | stop: Optional[Union[str, List[str]]] = None,
34 | response_format: Optional[Dict[str, str]] = None) -> str:
35 | if not messages:
36 | raise ValueError("'messages' must be provided.")
37 |
38 | params = {
39 | "model": self.validate_model_name(model_name) if model_name else self.model_name,
40 | "messages": messages,
41 | }
42 |
43 | if max_tokens is not None:
44 | params["max_tokens"] = max_tokens
45 | if temperature is not None:
46 | params["temperature"] = temperature
47 | if top_p is not None:
48 | params["top_p"] = top_p
49 | if stop is not None:
50 | params["stop"] = stop
51 | if response_format is not None:
52 | params["response_format"] = response_format
53 |
54 | response = self.client.chat.completions.create(**params)
55 |
56 | return response.choices[0].message.content
57 |
--------------------------------------------------------------------------------
/common/models/mistral_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union, Any
2 | from common.models.base.base_model import BaseModel
3 | from common.models.enums.model_enums import MistralModelName, MistralAPIURL
4 | from langsmith import traceable
5 | import requests
6 |
7 | class MistralModel(BaseModel):
8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None):
9 | if not api_key:
10 | raise ValueError("API key must be provided.")
11 | self.api_key = api_key
12 |
13 | if not additional_params:
14 | additional_params = {}
15 |
16 | self.model_name = self.validate_model_name(model_name)
17 | self.api_url = MistralAPIURL.CHAT_COMPLETIONS.value
18 |
19 | @staticmethod
20 | def validate_model_name(model_name: str) -> str:
21 | try:
22 | return MistralModelName(model_name).value
23 | except ValueError:
24 | raise ValueError(f"Invalid model name. Allowed values are: {', '.join([m.value for m in MistralModelName])}")
25 |
26 | @traceable(run_type="llm")
27 | def do_completion(self,
28 | messages: List[Dict[str, str]],
29 | model_name: Optional[str] = None,
30 | max_tokens: Optional[int] = None,
31 | temperature: Optional[float] = None,
32 | top_p: Optional[float] = None,
33 | stop: Optional[Union[str, List[str]]] = None,
34 | response_format: Optional[Dict[str, str]] = None) -> str:
35 | if not messages:
36 | raise ValueError("'messages' must be provided.")
37 |
38 | headers = {
39 | "Content-Type": "application/json",
40 | "Accept": "application/json",
41 | "Authorization": f"Bearer {self.api_key}"
42 | }
43 |
44 | data = {
45 | "model": self.validate_model_name(model_name) if model_name else self.model_name,
46 | "messages": messages,
47 | }
48 |
49 | if max_tokens is not None:
50 | data["max_tokens"] = max_tokens
51 | if temperature is not None:
52 | data["temperature"] = temperature
53 | if top_p is not None:
54 | data["top_p"] = top_p
55 | if stop is not None:
56 | data["stop"] = stop
57 | if response_format is not None:
58 | data["response_format"] = response_format
59 |
60 |         response = requests.post(self.api_url, headers=headers, json=data, timeout=120)  # bounded wait; avoids hanging on a stalled request
61 |
62 | if response.status_code == 200:
63 | result = response.json()
64 | return result['choices'][0]['message']['content']
65 | else:
66 | raise Exception(f"Error: {response.status_code}\n{response.text}")
67 |
--------------------------------------------------------------------------------
/common/models/model_factory.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Dict, Any
3 | from common.models.enums.model_enums import ModelType, OpenAIModelName, AzureModelName, GroqModelName, CerebrasModelName, MistralModelName
4 | from common.models.openai_model import OpenaiModel
5 | from common.models.azure_model import AzureModel
6 | from common.models.groq_model import GroqModel
7 | from common.models.cerebras_model import CerebrasModel
8 | from common.models.mistral_model import MistralModel
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | class ModelFactory:
13 | @staticmethod
14 | def create_model(model_type: str, model_name: str, api_key: str, additional_params: Dict[str, Any] = None):
15 | if not model_type:
16 | raise ValueError("model_type must be provided.")
17 | if not model_name:
18 | raise ValueError("model_name must be provided.")
19 | if not api_key:
20 | raise ValueError("API key must be provided.")
21 |
22 | try:
23 | model_type_enum = ModelType(model_type.lower())
24 | except ValueError:
25 | raise ValueError(f"Invalid model type. Allowed values are: {', '.join([m.value for m in ModelType])}")
26 |
27 | model_config = {
28 | "api_key": api_key,
29 | "model_name": model_name,
30 | "additional_params": additional_params or {}
31 | }
32 |
33 | required_params = {
34 | ModelType.AZURE: ["api_version", "azure_endpoint", "azure_deployment"],
35 | ModelType.OPENAI: [],
36 | ModelType.GROQ: [],
37 | ModelType.CEREBRAS: [],
38 | ModelType.MISTRAL: []
39 | }
40 |
41 | missing_params = []
42 | for param in required_params.get(model_type_enum, []):
43 | if param not in model_config["additional_params"]:
44 | missing_params.append(param)
45 |
46 | if missing_params:
47 | raise ValueError(f"Missing required additional_params for {model_type}: {', '.join(missing_params)}")
48 | if model_type_enum == ModelType.OPENAI:
49 | OpenAIModelName(model_name)
50 | model_instance = OpenaiModel(**model_config)
51 | elif model_type_enum == ModelType.AZURE:
52 | AzureModelName(model_name)
53 | model_instance = AzureModel(**model_config)
54 | elif model_type_enum == ModelType.GROQ:
55 | GroqModelName(model_name)
56 | model_instance = GroqModel(**model_config)
57 | elif model_type_enum == ModelType.CEREBRAS:
58 | CerebrasModelName(model_name)
59 | model_instance = CerebrasModel(**model_config)
60 | elif model_type_enum == ModelType.MISTRAL:
61 | MistralModelName(model_name)
62 | model_instance = MistralModel(**model_config)
63 | else:
64 | raise ValueError(f"Unsupported model type: {model_type_enum}")
65 |
66 | logger.info(f"Returning model instance of type: {model_type_enum.value}")
67 | return model_instance
--------------------------------------------------------------------------------
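A minimal sketch of ModelFactory usage; the key is a placeholder, and an "azure" provider would additionally need api_version, azure_endpoint, and azure_deployment in additional_params:

    from common.models.model_factory import ModelFactory

    model = ModelFactory.create_model(
        model_type="openai",
        model_name="gpt-4o",   # must be a value of OpenAIModelName
        api_key="sk-...",      # placeholder
    )
    print(model.do_completion([{"role": "user", "content": "Say hello"}]))
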
/common/models/openai_model.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Union, Any
2 | from common.models.base.base_model import BaseModel
3 | from common.models.enums.model_enums import OpenAIModelName
4 | from openai import OpenAI
5 | from langsmith import traceable
6 |
7 | class OpenaiModel(BaseModel):
8 | def __init__(self, api_key: str, model_name: str, additional_params: Dict[str, Any] = None):
9 | if not api_key:
10 | raise ValueError("API key must be provided.")
11 | self.api_key = api_key
12 |
13 | if not additional_params:
14 | additional_params = {}
15 |
16 | self.model_name = self.validate_model_name(model_name)
17 | self.base_url = additional_params.get('base_url')
18 |
19 | client_params = {'api_key': self.api_key}
20 | if self.base_url:
21 | client_params['base_url'] = self.base_url
22 |
23 | self.client = OpenAI(**client_params)
24 |
25 | @staticmethod
26 | def validate_model_name(model_name: str) -> str:
27 | try:
28 | return OpenAIModelName(model_name).value
29 | except ValueError:
30 | raise ValueError(f"Invalid model name. Allowed values are: {', '.join([m.value for m in OpenAIModelName])}")
31 |
32 | @traceable(run_type="llm")
33 | def do_completion(self,
34 | messages: List[Dict[str, str]],
35 | model_name: Optional[str] = None,
36 | max_tokens: Optional[int] = None,
37 | temperature: Optional[float] = None,
38 | top_p: Optional[float] = None,
39 | n: Optional[int] = None,
40 | stop: Optional[Union[str, List[str]]] = None,
41 |                       response_format: Optional[Dict[str, str]] = None) -> str:
42 | if not messages:
43 | raise ValueError("'messages' must be provided.")
44 |
45 | params = {
46 | "model": self.validate_model_name(model_name) if model_name else self.model_name,
47 | "messages": messages,
48 | }
49 |
50 | if max_tokens is not None:
51 | params["max_tokens"] = max_tokens
52 | if temperature is not None:
53 | params["temperature"] = temperature
54 | if top_p is not None:
55 | params["top_p"] = top_p
56 | if n is not None:
57 | params["n"] = n
58 | if stop is not None:
59 | params["stop"] = stop
60 | if response_format is not None:
61 | params["response_format"] = response_format
62 |
63 | response = self.client.chat.completions.create(**params)
64 |
65 | return response.choices[0].message.content
66 |
--------------------------------------------------------------------------------
/common/prompts/prompt_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class PromptType(Enum):
4 | EXAMPLE_GENERATION = "marly/example-generation"
5 | EXTRACTION = "marly/extraction"
6 | TRANSFORMATION = "marly/transformation"
7 | TRANSFORMATION_MARKDOWN = "noahegg/tranformation-markdown"
8 | TRANSFORMATION_WEB = "noahegg/web_scraper"
9 | TRANSFORMATION_ONLY = "marly/transformation-only"
10 | VALIDATION = "marly/validation"
11 | RELEVANT_PAGE_FINDER = "marly/relevant-page-finder"
12 | PLAN = "marly/plan"
13 | RELEVANT_PAGE_FINDER_V2 = "marly/relevant-page-finder-with-plan"
14 |
--------------------------------------------------------------------------------
/common/redis/redis_config.py:
--------------------------------------------------------------------------------
1 | import redis.asyncio as redis
2 | from redis.asyncio import ConnectionPool
3 | import os
4 |
5 | class RedisClient:
6 | def __init__(self):
7 | self.host = os.getenv('REDIS_HOST', '127.0.0.1')
8 |         self.port = int(os.getenv('REDIS_PORT', 6379))
9 |         self.db = int(os.getenv('REDIS_DB', 0))
10 | self.pool = ConnectionPool(host=self.host, port=self.port, db=self.db)
11 | self.client = redis.Redis(connection_pool=self.pool)
12 |
13 | def pipeline(self):
14 | return self.client.pipeline()
15 |
16 | redis_client = RedisClient().client
17 |
18 | async def get_redis_connection():
19 | return redis_client
20 |
--------------------------------------------------------------------------------
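A small async sketch for the shared client above, assuming a Redis instance reachable at REDIS_HOST/REDIS_PORT (the redis service in docker-compose.yml satisfies this):

    import asyncio
    from common.redis.redis_config import get_redis_connection

    async def main():
        r = await get_redis_connection()
        await r.set("example_key", "example_value")
        print(await r.get("example_key"))  # b'example_value'

    asyncio.run(main())
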
/common/sources/base/base_source.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Dict, Any, Optional, List
3 | from io import BytesIO
4 |
5 | class BaseSource(ABC):
6 | @abstractmethod
7 | def connect(self) -> None:
8 | """Establish a connection to the data source."""
9 | pass
10 |
11 | @abstractmethod
12 | def read(self, data: Dict[str, Any]) -> Optional[BytesIO]:
13 | """Read a specific file from the data source."""
14 | pass
15 |
16 | @abstractmethod
17 | def read_all(self) -> List[str]:
18 | """Retrieve a list of all valid files in the data source."""
19 | pass
--------------------------------------------------------------------------------
/common/sources/enums/source_enums.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | class SourceType(Enum):
4 | LOCAL_FS = "local_fs"
5 | S3 = "s3"
6 |
--------------------------------------------------------------------------------
/common/sources/local_fs_source.py:
--------------------------------------------------------------------------------
1 | from common.sources.base.base_source import BaseSource
2 | from typing import Dict, Optional, Any, List, Union
3 | import os
4 | from io import BytesIO
5 | import logging
6 | import mimetypes
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | class LocalFSIntegration(BaseSource):
11 | ALLOWED_EXTENSIONS = {'.pdf', '.pptx', '.docx'}
12 | ALLOWED_MIMETYPES = {
13 | 'application/pdf',
14 | 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
15 | 'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
16 | }
17 |
18 | def __init__(self, base_path: str, additional_params: Optional[Dict[str, Any]] = None) -> None:
19 | self.base_path: str = base_path
20 | self.additional_params = additional_params or {}
21 | self.connect()
22 |
23 | def connect(self) -> None:
24 | if not os.path.exists(self.base_path):
25 | raise ValueError(f"The specified base path does not exist: {self.base_path}")
26 |
27 | @staticmethod
28 | def is_valid_file(file_path: str) -> bool:
29 | _, ext = os.path.splitext(file_path)
30 | if ext.lower() not in LocalFSIntegration.ALLOWED_EXTENSIONS:
31 | return False
32 |
33 | mime_type, _ = mimetypes.guess_type(file_path)
34 | return mime_type in LocalFSIntegration.ALLOWED_MIMETYPES
35 |
36 | def read(self, data: Dict[str, Any]) -> Optional[BytesIO]:
37 | file_key: Optional[str] = data.get('file_key')
38 | if not file_key:
39 | raise ValueError("The 'file_key' must be provided in the data dictionary.")
40 |
41 | file_path = os.path.join(self.base_path, file_key)
42 | if not os.path.exists(file_path):
43 | logger.warning(f"File {file_path} not found.")
44 | return None
45 |
46 | if not self.is_valid_file(file_path):
47 | logger.warning(f"File {file_path} is not a valid PDF, PowerPoint, or Word document.")
48 | return None
49 |
50 | try:
51 | with open(file_path, 'rb') as file:
52 | return BytesIO(file.read())
53 | except Exception as e:
54 | logger.error(f"Error reading file {file_path}: {str(e)}")
55 | return None
56 |
57 | def read_all(self) -> List[str]:
58 | """Retrieve a list of all valid files in the base directory."""
59 | try:
60 | files = os.listdir(self.base_path)
61 | valid_files = [
62 | f for f in files
63 | if self.is_valid_file(os.path.join(self.base_path, f))
64 | ]
65 | logger.info(f"Retrieved {len(valid_files)} valid files from {self.base_path}")
66 | return valid_files
67 | except Exception as e:
68 | logger.error(f"Error listing files in {self.base_path}: {e}")
69 | return []
--------------------------------------------------------------------------------
/common/sources/s3_source.py:
--------------------------------------------------------------------------------
1 | from common.sources.base.base_source import BaseSource
2 | from typing import Dict, Optional, Any, List
3 | from io import BytesIO
4 | import boto3
5 | import os
6 | import logging
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 | class S3Integration(BaseSource):
11 | def __init__(self, bucket_name: str, additional_params: Optional[Dict[str, Any]] = None) -> None:
12 | self.bucket_name = bucket_name
13 | self.additional_params = additional_params or {}
14 | self.s3_client = boto3.client(
15 | 's3',
16 | region_name=self.additional_params.get('region_name'),
17 | aws_access_key_id=self.additional_params.get('aws_access_key_id'),
18 | aws_secret_access_key=self.additional_params.get('aws_secret_access_key'),
19 | aws_session_token=self.additional_params.get('aws_session_token')
20 | )
21 | self.connect()
22 |
23 | def connect(self) -> None:
24 | try:
25 | self.s3_client.head_bucket(Bucket=self.bucket_name)
26 | logger.info(f"Successfully connected to S3 bucket: {self.bucket_name}")
27 | except Exception as e:
28 | logger.error(f"Failed to connect to S3 bucket {self.bucket_name}: {e}")
29 | raise
30 |
31 | def read(self, data: Dict[str, Any]) -> Optional[BytesIO]:
32 | file_key: Optional[str] = data.get('file_key')
33 | if not file_key:
34 | raise ValueError("The 'file_key' must be provided in the data dictionary.")
35 |
36 | try:
37 | response = self.s3_client.get_object(Bucket=self.bucket_name, Key=file_key)
38 | return BytesIO(response['Body'].read())
39 | except Exception as e:
40 | logger.error(f"Error reading file from S3: {e}")
41 | return None
42 |
43 | def read_all(self) -> List[str]:
44 | """Retrieve a list of all valid files in the S3 bucket."""
45 | try:
46 | paginator = self.s3_client.get_paginator('list_objects_v2')
47 | page_iterator = paginator.paginate(Bucket=self.bucket_name)
48 | valid_files = []
49 | for page in page_iterator:
50 | contents = page.get('Contents', [])
51 | for obj in contents:
52 | key = obj['Key']
53 | if self.is_valid_file(key):
54 | valid_files.append(key)
55 | logger.info(f"Retrieved {len(valid_files)} valid files from S3 bucket {self.bucket_name}")
56 | return valid_files
57 | except Exception as e:
58 | logger.error(f"Error listing files in S3 bucket {self.bucket_name}: {e}")
59 | return []
60 |
61 | @staticmethod
62 | def is_valid_file(file_key: str) -> bool:
63 | allowed_extensions = {'.pdf', '.pptx', '.docx'}
64 | _, ext = os.path.splitext(file_key)
65 | return ext.lower() in allowed_extensions
--------------------------------------------------------------------------------
/common/sources/source_factory.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Any
2 | from common.sources.enums.source_enums import SourceType
3 | from common.sources.local_fs_source import LocalFSIntegration
4 | from common.sources.s3_source import S3Integration
5 |
6 | class SourceFactory:
7 | @staticmethod
8 | def create_source(source_type: str, documents_location: str = None, additional_params: Dict[str, Any] = None):
9 | if not source_type:
10 | raise ValueError("source_type must be provided.")
11 |
12 | try:
13 | source_type_enum = SourceType(source_type.lower())
14 | except ValueError:
15 | raise ValueError(f"Invalid source type. Allowed values are: {', '.join([s.value for s in SourceType])}")
16 |
17 | if source_type_enum == SourceType.LOCAL_FS:
18 | return LocalFSIntegration(
19 | base_path=documents_location,
20 | additional_params=additional_params
21 | )
22 | elif source_type_enum == SourceType.S3:
23 | return S3Integration(
24 | bucket_name=documents_location,
25 | additional_params=additional_params
26 | )
27 | else:
28 | raise ValueError(f"Unsupported source type: {source_type_enum}")
--------------------------------------------------------------------------------
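A usage sketch for the sources package, assuming a local directory of PDFs; for "s3" the documents_location becomes the bucket name and credentials go in additional_params:

    from common.sources.source_factory import SourceFactory

    source = SourceFactory.create_source("local_fs", documents_location="./example_files")
    for file_key in source.read_all():  # only .pdf/.pptx/.docx pass the filter
        stream = source.read({"file_key": file_key})
        if stream:
            print(file_key, stream.getbuffer().nbytes, "bytes")
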
/common/text_extraction/text_extractor.py:
--------------------------------------------------------------------------------
1 | import PyPDF2
2 | from typing import List
3 | from io import BytesIO
4 | import logging
5 | from dotenv import load_dotenv
6 | from langchain.schema import SystemMessage, HumanMessage
7 | from langsmith import Client
8 | import asyncio
9 | import time
10 | from concurrent.futures import ThreadPoolExecutor
11 | from common.prompts.prompt_enums import PromptType
12 | import json
13 | from datetime import datetime
14 | from common.redis.redis_config import get_redis_connection
15 |
16 | load_dotenv()
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 | def get_pdf_page_count(pdf_stream):
21 | try:
22 | pdf_reader = PyPDF2.PdfReader(pdf_stream)
23 | return len(pdf_reader.pages)
24 | except PyPDF2.errors.PdfReadError as e:
25 | raise PyPDF2.errors.PdfReadError(f"Error reading PDF file: {str(e)}")
26 |
27 | def extract_page_as_markdown(file_stream: BytesIO, page_number: int) -> str:
28 | import os
29 | from PyPDF2 import PdfReader, PdfWriter
30 | from markitdown import MarkItDown
31 |
32 |     temp_file = f"temp_{page_number}.pdf"  # page-scoped name avoids collisions if extractions ever run concurrently
33 |
34 | try:
35 | if not isinstance(page_number, int) or page_number < 0:
36 | raise ValueError(f"Invalid page number: {page_number}")
37 |
38 | reader = PdfReader(file_stream)
39 |
40 | if page_number >= len(reader.pages):
41 | raise ValueError(f"Page number {page_number} exceeds document length of {len(reader.pages)} pages")
42 |
43 | writer = PdfWriter()
44 | writer.add_page(reader.pages[page_number])
45 |
46 | try:
47 | with open(temp_file, 'wb') as output_file:
48 | writer.write(output_file)
49 | except IOError as e:
50 | raise IOError(f"Failed to write temporary PDF file: {str(e)}")
51 |
52 | try:
53 | markitdown = MarkItDown()
54 | result = markitdown.convert(temp_file)
55 |
56 | if not result or not result.text_content:
57 | logger.warning(f"No text content extracted from page {page_number}")
58 | return ""
59 |
60 | return result.text_content
61 |
62 | except Exception as e:
63 | raise Exception(f"Error converting PDF to markdown: {str(e)}")
64 |
65 | except Exception as e:
66 |         logger.error(f"Error in extract_page_as_markdown: {str(e)}")
67 | raise
68 |
69 | finally:
70 | try:
71 | if os.path.exists(temp_file):
72 | os.remove(temp_file)
73 | except Exception as e:
74 | logger.error(f"Failed to remove temporary file {temp_file}: {str(e)}")
75 |
76 | def preprocess_messages(raw_payload):
77 | messages = []
78 | if hasattr(raw_payload, 'to_messages'):
79 | for message in raw_payload.to_messages():
80 | if isinstance(message, SystemMessage):
81 | messages.append({"role": "system", "content": message.content})
82 | elif isinstance(message, HumanMessage):
83 | messages.append({"role": "user", "content": message.content})
84 | else:
85 | logger.warning(f"Unexpected message type: {type(message)}")
86 | else:
87 | logger.warning(f"Unexpected raw_payload format: {type(raw_payload)}")
88 | return messages
89 |
90 | def process_page(client, prompt, page_number: int, page_text: str, formatted_keywords: str) -> tuple[int, dict]:
91 | try:
92 | raw_payload = prompt.invoke({
93 | "first_value": page_text,
94 | "second_value": formatted_keywords
95 | })
96 |
97 |         processed_messages = preprocess_messages(raw_payload)
98 |         if not processed_messages:
99 |             raise ValueError(f"No messages could be built for page {page_number}")
100 |         response = client.do_completion(processed_messages)
101 |         logger.info(f"Response for page {page_number}: {response}")
102 |         is_relevant = "yes" in response[:20].lower()
103 |         response_data = {
104 |             'response': response,
105 |             'keywords': formatted_keywords,
106 |             'timestamp': datetime.now().isoformat(),
107 |             'is_relevant': is_relevant
108 |         }
109 |         return (page_number if is_relevant else -1), response_data
110 |
111 | except Exception as e:
112 | logger.error(f"Error processing page {page_number}: {e}")
113 | return -1, {
114 | 'response': f"Error: {str(e)}",
115 | 'keywords': formatted_keywords,
116 | 'page_text': page_text,
117 | 'timestamp': datetime.now().isoformat(),
118 | 'is_relevant': False,
119 | 'error': True
120 | }
121 |
122 | async def find_common_pages(client, file_stream: BytesIO, formatted_keywords: str) -> List[int]:
123 | try:
124 | start_time = time.time()
125 | pdf_reader = PyPDF2.PdfReader(file_stream)
126 | langsmith_client = Client()
127 | prompt = langsmith_client.pull_prompt(PromptType.RELEVANT_PAGE_FINDER_V2.value)
128 | logger.info(f"MODEL TYPE: {type(client)}")
129 |
130 | run_id = datetime.now().strftime('%Y%m%d_%H%M%S')
131 |
132 | loop = asyncio.get_event_loop()
133 | with ThreadPoolExecutor(max_workers=10) as executor:
134 | tasks = []
135 | for page_number in range(len(pdf_reader.pages)):
136 | page_text = extract_page_as_markdown(file_stream, page_number)
137 | task = loop.run_in_executor(
138 | executor,
139 | process_page,
140 | client,
141 | prompt,
142 | page_number,
143 | page_text,
144 | formatted_keywords
145 | )
146 | tasks.append(task)
147 |
148 | results = await asyncio.gather(*tasks)
149 |
150 | # Process results and store in Redis
151 | page_responses = {}
152 | relevant_pages = []
153 |
154 | logger.info(f"Processing {len(results)} results")
155 |
156 |         for idx, (page_number, response_data) in enumerate(results):
157 |             page_responses[str(page_number if page_number != -1 else idx)] = response_data
158 |
159 | if page_number != -1:
160 | relevant_pages.append(page_number)
161 |
162 | logger.info(f"Collected responses for {len(page_responses)} pages")
163 |
164 | if page_responses:
165 | try:
166 | redis_client = await get_redis_connection()
167 | redis_key = f"page_responses:{run_id}"
168 | data_to_store = json.dumps(page_responses)
169 |
170 | await redis_client.set(redis_key, data_to_store)
171 | await redis_client.expire(redis_key, 86400) # 1 day TTL
172 |
173 | stored_data = await redis_client.get(redis_key)
174 | if stored_data:
175 | logger.info(f"Successfully stored and verified data in Redis for run_id: {run_id}")
176 | else:
177 | logger.error("Failed to verify data storage in Redis")
178 |
179 | except Exception as e:
180 | logger.error(f"Redis storage error: {str(e)}")
181 | else:
182 | logger.warning("No page responses to store in Redis")
183 |
184 | end_time = time.time()
185 | total_time = end_time - start_time
186 | logger.info(f"Processed {len(pdf_reader.pages)} pages in {total_time:.2f} seconds")
187 | logger.info(f"Relevant Pages: {relevant_pages}")
188 |
189 | return relevant_pages
190 |
191 | except Exception as e:
192 | logger.error(f"Error finding common pages: {e}")
193 | return []
--------------------------------------------------------------------------------
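A hedged driver for find_common_pages above, assuming a model from ModelFactory, a running Redis, and LangSmith credentials in the environment for the prompt pull; the path, key, and keywords are illustrative:

    import asyncio
    from io import BytesIO
    from common.models.model_factory import ModelFactory
    from common.text_extraction.text_extractor import find_common_pages

    model = ModelFactory.create_model("openai", "gpt-4o", api_key="sk-...")  # placeholder key
    with open("examples/example_files/lacers_reduced.pdf", "rb") as f:
        stream = BytesIO(f.read())

    pages = asyncio.run(find_common_pages(model, stream, "Firm, Commitment, Net IRR"))
    print("Relevant pages:", pages)
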
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | services:
2 | pipeline:
3 | build:
4 | context: .
5 | dockerfile: application/pipeline/Dockerfile
6 | ports:
7 | - "8100:8100"
8 | volumes:
9 | - ${PWD}/common:/app/common
10 | - ${PWD}/application/pipeline:/app/application/pipeline
11 | env_file: .env
12 | environment:
13 | - REDIS_HOST=redis
14 | - REDIS_PORT=6379
15 | - PYTHONPATH=/app
16 | depends_on:
17 | - redis
18 | networks:
19 | - marly_default
20 |
21 | extraction:
22 | build:
23 | context: .
24 | dockerfile: application/extraction/Dockerfile
25 | volumes:
26 | - ${PWD}/common:/app/common
27 | - ${PWD}/application/extraction:/app/application/extraction
28 | env_file: .env
29 | environment:
30 | - REDIS_HOST=redis
31 | - REDIS_PORT=6379
32 | - PYTHONPATH=/app
33 | depends_on:
34 | - redis
35 | networks:
36 | - marly_default
37 |
38 | transformation:
39 | build:
40 | context: .
41 | dockerfile: application/transformation/Dockerfile
42 | volumes:
43 | - ${PWD}/common:/app/common
44 | - ${PWD}/application/transformation:/app/application/transformation
45 | env_file: .env
46 | environment:
47 | - REDIS_HOST=redis
48 | - REDIS_PORT=6379
49 | - PYTHONPATH=/app
50 | depends_on:
51 | - redis
52 | networks:
53 | - marly_default
54 |
55 | redis:
56 | image: "redis:alpine"
57 | ports:
58 | - "6379:6379"
59 | command: redis-server --save "" --appendonly no
60 | networks:
61 | - marly_default
62 |
63 | networks:
64 | marly_default:
65 | name: marly_default
66 |
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/auth/anon_helper.py:
--------------------------------------------------------------------------------
1 | import requests
2 | from datetime import datetime, timedelta
3 | import json
4 | from dotenv import load_dotenv
5 | import os
6 | import logging
7 |
8 | load_dotenv()
9 |
10 | BASE_URL = "https://svc.sandbox.anon.com/actions/linkedin"  # no trailing slash; endpoint paths are joined with "/" below
11 | ANON_USER_ID = os.environ.get("ANON_USER_ID")
12 |
13 | def get_headers():
14 | return {
15 | "Authorization": f"Bearer {os.getenv('ANON_TOKEN')}"
16 | }
17 |
18 | def make_api_request(endpoint, method="GET", params=None):
19 | url = f"{BASE_URL}/{endpoint}"
20 | response = requests.request(method, url, headers=get_headers(), params=params)
21 | return json.loads(response.text)
22 |
23 | def get_recent_conversations(days=7):
24 | params = {"appUserId": ANON_USER_ID}
25 | data = make_api_request("listConversations", params=params)
26 | current_date = datetime.now()
27 | cutoff_date = current_date - timedelta(days=days)
28 |
29 | recent_profile_ids = []
30 | for conversation in data['conversations']:
31 | conversation_date = datetime.strptime(conversation['timestamp'], "%Y-%m-%dT%H:%M:%S.%fZ")
32 | if conversation_date > cutoff_date:
33 | for profile in conversation['profiles']:
34 | if not profile['isSelf']:
35 | profile_id = profile.get('id', '')
36 | if profile_id.startswith('profile-'):
37 | profile_id = profile_id[8:] # Remove 'profile-' prefix
38 | if profile_id:
39 | recent_profile_ids.append(profile_id)
40 |
41 | return recent_profile_ids
42 |
43 | def get_profile(profile_id):
44 | params = {"id": profile_id, "appUserId": ANON_USER_ID}
45 | return make_api_request("getProfile", params=params)
46 |
47 | def raw_profile_data():
48 | recent_profile_ids = get_recent_conversations()
49 | profiles = []
50 |
51 | for profile_id in recent_profile_ids:
52 | try:
53 | if profile_id.isdigit():
54 | profile_id = f"ACoAA{profile_id}"
55 |
56 | profile_data = get_profile(profile_id)
57 |
58 | if profile_data:
59 | profiles.append(profile_data)
60 | else:
61 | logging.warning(f"No data returned for profile ID: {profile_id}")
62 | except Exception as e:
63 | logging.error(f"Error fetching profile {profile_id}: {str(e)}")
64 |
65 | return profiles
66 |
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/contacts.db:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/ai-workers/ai-sdr/contacts.db
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/main.py:
--------------------------------------------------------------------------------
1 | from transformation.marly_helper import process_data
2 | from auth.anon_helper import raw_profile_data
3 | from output_source.sql_helper import SQLiteHelper
4 |
5 | import json
6 |
7 | def extract_metrics_data(results):
8 | if not results or not results[0].results:
9 | return None
10 |
11 | metrics = results[0].results[0]['metrics']
12 | if not metrics:
13 | return None
14 |
15 | json_string = metrics[next(iter(metrics.keys()))]
16 |
17 | data = json.loads(json_string)
18 |
19 | if 'data_entries' in data:
20 | records = data['data_entries']
21 | else:
22 | records = [data]
23 |
24 | cleaned_data = []
25 | for record in records:
26 | cleaned_record = {
27 | 'id': record.get('id'),
28 | 'first_name': record.get('firstName'),
29 | 'last_name': record.get('lastName'),
30 | 'headline': record.get('headline'),
31 | 'location': record.get('location'),
32 | 'summary': record.get('summary'),
33 | 'connections_count': record.get('connectionsCount')
34 | }
35 | cleaned_data.append(cleaned_record)
36 |
37 | return cleaned_data
38 |
39 | if __name__ == "__main__":
40 | raw_data = raw_profile_data()
41 | processed_data = process_data(raw_data)
42 | cleaned_data = extract_metrics_data(processed_data)
43 | db = SQLiteHelper()
44 |     success = db.insert_contact(cleaned_data) if cleaned_data else False  # skip insert when extraction produced nothing
45 | db.get_all_contacts()
46 |
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/output_source/sql_helper.py:
--------------------------------------------------------------------------------
1 | import sqlite3
2 | from typing import Dict, List, Union
3 | from tabulate import tabulate
4 | import json
5 |
6 | class SQLiteHelper:
7 | def __init__(self, db_path: str = 'contacts.db'):
8 | self.db_path = db_path
9 | self.create_tables()
10 |
11 | def create_tables(self):
12 | """Create the necessary tables if they don't exist"""
13 | conn = sqlite3.connect(self.db_path)
14 | cursor = conn.cursor()
15 |
16 | cursor.execute("""
17 | CREATE TABLE IF NOT EXISTS contacts (
18 | id TEXT PRIMARY KEY,
19 | first_name TEXT,
20 | last_name TEXT,
21 | headline TEXT,
22 | location TEXT,
23 | summary TEXT,
24 | connections_count INTEGER
25 | )
26 | """)
27 |
28 | conn.commit()
29 | conn.close()
30 |
31 | def insert_contact(self, contact_data: Union[Dict, List[Dict]]) -> bool:
32 | """Insert one or more contacts into the database"""
33 | conn = sqlite3.connect(self.db_path)
34 | cursor = conn.cursor()
35 |
36 | insert_sql = """
37 | INSERT OR REPLACE INTO contacts (
38 | id, first_name, last_name, headline, location, summary,
39 | connections_count
40 | ) VALUES (?, ?, ?, ?, ?, ?, ?)
41 | """
42 |
43 | try:
44 | # Handle both single dict and list of dicts
45 | contacts_to_insert = contact_data if isinstance(contact_data, list) else [contact_data]
46 |
47 | for contact in contacts_to_insert:
48 | cursor.execute(insert_sql, (
49 | contact['id'],
50 | contact['first_name'],
51 | contact['last_name'],
52 | contact['headline'],
53 | contact['location'],
54 | contact['summary'],
55 | contact['connections_count'],
56 | ))
57 | conn.commit()
58 | return True
59 | except sqlite3.Error as e:
60 | print(f"Error inserting contact(s): {e}")
61 | return False
62 | finally:
63 | conn.close()
64 |
65 |     def get_all_contacts(self) -> List[Dict]:
66 |         """Retrieve, display, and return all contacts from the database as a formatted table"""
67 | conn = sqlite3.connect(self.db_path)
68 | cursor = conn.cursor()
69 |
70 | cursor.execute("SELECT id, first_name, last_name, headline, location, summary, connections_count FROM contacts")
71 | results = cursor.fetchall()
72 |
73 | conn.close()
74 |
75 | if results:
76 | keys = ['id', 'first_name', 'last_name', 'headline', 'location',
77 | 'summary', 'connections_count']
78 |
79 | print(tabulate(results,
80 | headers=keys,
81 | tablefmt='grid',
82 | maxcolwidths=[None, None, None, 30, 20, 40, None]))
83 |
84 | return [dict(zip(keys, result)) for result in results]
85 |
86 | print("No contacts found in database")
87 | return []
88 |
--------------------------------------------------------------------------------
/examples/ai-workers/ai-sdr/transformation/marly_helper.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dotenv import load_dotenv
5 | import os
6 | from marly import Marly
7 |
8 | load_dotenv()
9 |
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | BASE_URL = "http://localhost:8100"
13 |
14 | def process_data(raw_data):
15 | schema_1 = {
16 | "id": "Unique identifier for the profile for all contacts",
17 | "firstName": "First name of the person for all contacts",
18 | "lastName": "Last name of the person for all contacts",
19 | "headline": "Professional headline or tagline for all contacts",
20 | "location": "Geographic location of the person for all contacts",
21 | "summary": "Brief professional summary or bio for all contacts",
22 | "connectionsCount": "Number of LinkedIn connections for all contacts"
23 | }
24 |
25 | client = Marly(base_url=BASE_URL)
26 |
27 | try:
28 | pipeline_response_model = client.pipelines.create(
29 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
30 | provider_model_name=os.getenv("AZURE_MODEL_NAME"),
31 | provider_type="azure",
32 | workloads=[
33 | {
34 | "destination": "sqlite",
35 | "documents_location": "contacts_table",
36 | "schemas": [json.dumps(schema_1)],
37 | "raw_data": json.dumps(raw_data),
38 | }
39 | ],
40 | additional_params={
41 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"),
42 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"),
43 | "api_version": os.getenv("AZURE_API_VERSION")
44 | }
45 | )
46 |
47 | logging.debug(f"Task ID: {pipeline_response_model.task_id}")
48 |
49 | max_attempts = 5
50 | attempt = 0
51 | while attempt < max_attempts:
52 | time.sleep(30)
53 |
54 | results = client.pipelines.retrieve(pipeline_response_model.task_id)
55 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
56 |
57 | if results.status == 'COMPLETED':
58 | logging.debug(f"Pipeline completed with results: {results.results}")
59 | return results.results
60 | elif results.status == 'FAILED':
61 | logging.error(f"Error: {results.error_message}")
62 | return None
63 |
64 | attempt += 1
65 |
66 | logging.warning("Timeout: Pipeline execution took too long.")
67 | return None
68 |
69 | except Exception as e:
70 | logging.error(f"Error in pipeline process: {e}")
71 | return None
72 |
--------------------------------------------------------------------------------
/examples/example_files/lacers.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/example_files/lacers.pdf
--------------------------------------------------------------------------------
/examples/example_files/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/example_files/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/notebooks/autogen_example/OAI_CONFIG_LIST.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "model": "llama-3.1-70b-versatile",
4 | "api_key": "",
5 | "api_type": "groq",
6 | "tags": ["tool", "llama-3.1-70b-versatile"]
7 | }
8 | ]
--------------------------------------------------------------------------------
/examples/notebooks/autogen_example/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/autogen_example/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/notebooks/autogen_example/plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/autogen_example/plot.png
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/diagram.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/diagram.jpeg
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/notebooklm_workflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/notebooklm_workflow.jpeg
--------------------------------------------------------------------------------
/examples/notebooks/langgraph_example/wkflow.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/notebooks/langgraph_example/wkflow.jpeg
--------------------------------------------------------------------------------
/examples/scripts/api_example.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import zlib
4 | import logging
5 | import time
6 | import requests
7 |
8 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
9 |
10 | PDF_FILE = ""
11 | API_KEY = ""
12 | BASE_URL = "https://api.marly.ai"
13 |
14 | def read_and_encode_pdf(file_path):
15 | with open(file_path, "rb") as file:
16 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8')
17 | logging.debug(f"{file_path} read and encoded")
18 | return pdf_content
19 |
20 | def get_pipeline_results(task_id):
21 | logging.debug(f"Fetching results for task ID: {task_id}")
22 | response = requests.get(f"{BASE_URL}/pipelines/{task_id}", headers={"marly-api-key": API_KEY})
23 | return response.json()
24 |
25 | def process_pdf(pdf_file):
26 | pdf_content = read_and_encode_pdf(pdf_file)
27 |
28 | schema_1 = {
29 | "Type of change": "Description of the type of change",
30 | "Location": "Location of the change",
31 | "Item": "Description of the item"
32 | }
33 |
34 | pipeline_request = {
35 | "license_key": "1234567890",
36 | "workloads": [
37 | {
38 | "pdf_stream": pdf_content,
39 | "schemas": [json.dumps(schema_1)]
40 | }
41 | ]
42 | }
43 |
44 | logging.debug("Sending POST request to pipeline endpoint")
45 | response = requests.post(f"{BASE_URL}/pipelines", json=pipeline_request, headers={"marly-api-key": API_KEY})
46 |
47 | logging.debug(f"Response status code: {response.status_code}")
48 | logging.debug(f"Response headers: {response.headers}")
49 | logging.debug(f"Response content: {response.text}")
50 |
51 | result = response.json()
52 | task_id = result.get("task_id")
53 | if not task_id:
54 | raise ValueError("Invalid task_id: task_id is None or empty")
55 | logging.debug(f"Task ID: {task_id}")
56 |
57 | max_attempts = 5
58 | attempt = 0
59 | while attempt < max_attempts:
60 | time.sleep(35)
61 |
62 | results = get_pipeline_results(task_id)
63 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results['status']}")
64 |
65 | if results['status'] == 'Completed':
66 | return results
67 | elif results['status'] == 'Failed':
68 | logging.error(f"Error: {results.get('error_message', 'Unknown error')}")
69 | return None
70 |
71 | attempt += 1
72 |
73 | logging.warning("Timeout: Pipeline execution took too long.")
74 | return None
75 |
76 | def main():
77 | results = process_pdf(PDF_FILE)
78 |
79 | if results:
80 | print("Raw API Response:")
81 | print(json.dumps(results, indent=2))
82 |
83 | if 'results' in results and results['results']:
84 | for i, result in enumerate(results['results']):
85 | print(f"\nResult {i + 1}:")
86 | metrics = result.get('metrics', {})
87 | print("Metrics:")
88 | print(json.dumps(metrics, indent=2))
89 |
90 |             schema_0 = metrics.get('schema_0', '{}')
91 |             print("\nSchema 0:")
92 |             print(schema_0)
93 | 
94 |             try:
95 |                 schema_0_json = json.loads(schema_0)
96 |                 print("\nSchema 0 (parsed):")
97 |                 print(json.dumps(schema_0_json, indent=2))
98 | 
99 |                 for key, value in schema_0_json.items():
100 |                     print(f"\n{key}:")
101 |                     print(value)
102 | 
103 |             except json.JSONDecodeError:
104 |                 print("\nFailed to parse schema_0 as JSON")
105 | else:
106 | print("No results found in the API response")
107 | else:
108 | print("Failed to process PDF. Please check the logs for more information.")
109 |
110 | if __name__ == "__main__":
111 | main()
112 |
--------------------------------------------------------------------------------
/examples/scripts/azure_example.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dotenv import load_dotenv
5 | import os
6 | from marly import Marly
7 |
8 | load_dotenv()
9 |
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | BASE_URL = "http://localhost:8100"
13 |
14 | def process_pdf():
15 | schema_1 = {
16 | "Firm": "The name of the firm",
17 | "Number of Funds": "The number of funds managed by the firm",
18 | "Commitment": "The commitment amount in millions of dollars",
19 | "Percent of Total Comm": "The percentage of total commitment",
20 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
21 | "Percent of Total Exposure": "The percentage of total exposure",
22 | "TVPI": "Total Value to Paid-In multiple",
23 | "Net IRR": "Net Internal Rate of Return as a percentage"
24 | }
25 |
26 | client = Marly(base_url=BASE_URL)
27 |
28 | try:
29 | pipeline_response_model = client.pipelines.create(
30 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
31 | provider_model_name=os.getenv("AZURE_MODEL_NAME"),
32 | provider_type="azure",
33 | workloads=[
34 | {
35 | "file_name": "lacers reduced",
36 | "data_source": "local_fs",
37 | "documents_location": "/app/example_files",
38 | "schemas": [json.dumps(schema_1)],
39 | }
40 | ],
41 | additional_params={
42 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"),
43 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"),
44 | "api_version": os.getenv("AZURE_API_VERSION")
45 | }
46 | )
47 |
48 | logging.debug(f"Task ID: {pipeline_response_model.task_id}")
49 |
50 | max_attempts = 5
51 | attempt = 0
52 | while attempt < max_attempts:
53 | time.sleep(30)
54 |
55 | results = client.pipelines.retrieve(pipeline_response_model.task_id)
56 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
57 |
58 | if results.status == 'COMPLETED':
59 | logging.debug(f"Pipeline completed with results: {results.results}")
60 | return results.results
61 | elif results.status == 'FAILED':
62 | logging.error(f"Error: {results.error_message}")
63 | return None
64 |
65 | attempt += 1
66 |
67 | logging.warning("Timeout: Pipeline execution took too long.")
68 | return None
69 |
70 | except Exception as e:
71 | logging.error(f"Error in pipeline process: {e}")
72 | return None
73 |
74 | if __name__ == "__main__":
75 | results = process_pdf()
76 | print(results)
--------------------------------------------------------------------------------
/examples/scripts/cerebras_example.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import zlib
4 | import logging
5 | import os
6 | from dotenv import load_dotenv
7 | from marly import Marly
8 | import time
9 |
10 | load_dotenv()
11 |
12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
13 |
14 | PDF_FILE_PATH = "../example_files/lacers_reduced.pdf"
15 | BASE_URL = "http://localhost:8100"
16 |
17 | SCHEMA_1 = {
18 | "Firm": "The name of the firm",
19 | "Number of Funds": "The number of funds managed by the firm",
20 | "Commitment": "The commitment amount in millions of dollars",
21 | "Percent of Total Comm": "The percentage of total commitment",
22 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
23 | "Percent of Total Exposure": "The percentage of total exposure",
24 | "TVPI": "Total Value to Paid-In multiple",
25 | "Net IRR": "Net Internal Rate of Return as a percentage"
26 | }
27 |
28 | def read_and_encode_pdf(file_path):
29 | with open(file_path, "rb") as file:
30 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8')
31 | logging.debug(f"{file_path} read and encoded")
32 | return pdf_content
33 |
34 | def process_pdf(pdf_file):
35 | pdf_content = read_and_encode_pdf(pdf_file)
36 |
37 |     client = Marly(base_url=BASE_URL)
38 |
39 | try:
40 | pipeline_response = client.pipelines.create(
41 | api_key=os.getenv("CEREBRAS_API_KEY"),
42 | provider_model_name="llama3.3-70b",
43 | provider_type="cerebras",
44 | workloads=[{"raw_data": pdf_content, "schemas": [json.dumps(SCHEMA_1)]}],
45 | )
46 |
47 | while True:
48 | results = client.pipelines.retrieve(pipeline_response.task_id)
49 | if results.status == 'COMPLETED':
50 | return results.results
51 | elif results.status == 'FAILED':
52 | return None
53 | time.sleep(15)
54 |
55 | except Exception as e:
56 | logging.error(f"Error in pipeline process: {e}")
57 | return None
58 |
59 | # Usage
60 | if __name__ == "__main__":
61 | result = process_pdf(PDF_FILE_PATH)
62 | print(result)
63 |
--------------------------------------------------------------------------------
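The read_and_encode_pdf helper above compresses the PDF with zlib and then base64-encodes it so the raw bytes can travel inside a JSON payload. For reference, the inverse transform, which a consumer of the raw_data field would apply to recover the original bytes, is the same two standard-library calls in reverse order:

```python
import base64
import zlib

def decode_pdf(encoded: str) -> bytes:
    """Inverse of read_and_encode_pdf: base64-decode, then zlib-decompress."""
    return zlib.decompress(base64.b64decode(encoded))
```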
/examples/scripts/data_loading_example.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dotenv import load_dotenv
5 | import os
6 | from marly import Marly
7 |
8 | load_dotenv()
9 |
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | BASE_URL = "http://localhost:8100"
13 |
14 | def load_data():
15 | schema_1 = {
16 | "Firm": "The name of the firm",
17 | "Number of Funds": "The number of funds managed by the firm",
18 | "Commitment": "The commitment amount in millions of dollars",
19 | "Percent of Total Comm": "The percentage of total commitment",
20 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
21 | "Percent of Total Exposure": "The percentage of total exposure",
22 | "TVPI": "Total Value to Paid-In multiple",
23 | "Net IRR": "Net Internal Rate of Return as a percentage"
24 | }
25 | raw_data = json.dumps({
26 | "Firm": "Evergreen Capital Partners",
27 | "Number of Funds": 7,
28 | "Commitment": 150.5,
29 | "Percent of Total Comm": 12.3,
30 | "Exposure (FMV + Unfunded)": 175.2,
31 | "Percent of Total Exposure": 14.8,
32 | "TVPI": 1.45,
33 | "Net IRR": 18.7,
34 | "Founded Year": 2005,
35 | "Headquarters": "New York, NY",
36 | "AUM": 3500.0,
37 | "Investment Strategy": "Growth Equity",
38 | "Target Industries": ["Technology", "Healthcare", "Consumer"],
39 | "Managing Partners": ["John Smith", "Sarah Johnson"],
40 | "Last Fund Closing Date": "2022-09-15",
41 | "ESG Focus": True
42 | })
43 |
44 | client = Marly(base_url=BASE_URL)
45 |
46 | try:
47 | pipeline_response_model = client.pipelines.create(
48 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
49 | provider_model_name=os.getenv("AZURE_MODEL_NAME"),
50 | provider_type="azure",
51 | workloads=[
52 | {
53 | "destination": "sqlite",
54 | "documents_location": "contacts_table",
55 | "schemas": [json.dumps(schema_1)],
56 | "raw_data": raw_data,
57 | }
58 | ],
59 | additional_params={
60 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"),
61 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"),
62 | "api_version": os.getenv("AZURE_API_VERSION")
63 | }
64 | )
65 |
66 | logging.debug(f"Task ID: {pipeline_response_model.task_id}")
67 |
68 | max_attempts = 5
69 | attempt = 0
70 | while attempt < max_attempts:
71 | time.sleep(5)
72 |
73 | results = client.pipelines.retrieve(pipeline_response_model.task_id)
74 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
75 |
76 | if results.status == 'COMPLETED':
77 | logging.debug(f"Pipeline completed with results: {results.results}")
78 | return results.results
79 | elif results.status == 'FAILED':
80 | logging.error(f"Error: {results.error_message}")
81 | return None
82 |
83 | attempt += 1
84 |
85 | logging.warning("Timeout: Pipeline execution took too long.")
86 | return None
87 |
88 | except Exception as e:
89 | logging.error(f"Error in pipeline process: {e}")
90 | return None
91 |
92 | if __name__ == "__main__":
93 |     results = load_data()
94 | print(results)
95 |
--------------------------------------------------------------------------------
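Since this workload writes to the sqlite destination with documents_location set to contacts_table, the loaded row can be inspected with a plain sqlite3 query afterwards. A minimal sketch; the database file name here is an assumption and depends on how the SQLite destination is configured in your deployment:

```python
import sqlite3

# Hypothetical database path; substitute the file your sqlite destination writes to.
conn = sqlite3.connect("contacts.db")
for row in conn.execute("SELECT * FROM contacts_table"):
    print(row)
conn.close()
```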
/examples/scripts/groq_example.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import zlib
4 | import logging
5 | import requests
6 | from dotenv import load_dotenv
7 | import os
8 | import time
9 | load_dotenv()
10 |
11 | logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
12 |
13 | BASE_URL = "http://localhost:8100"
14 | PDF_FILE_PATH = "../example_files/lacers_reduced.pdf"
15 |
16 | def read_and_encode_pdf(file_path):
17 | with open(file_path, "rb") as file:
18 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8')
19 | logging.debug(f"{file_path} read and encoded")
20 | return pdf_content
21 |
22 | def get_pipeline_results(task_id):
23 | logging.debug(f"Fetching results for task ID: {task_id}")
24 | response = requests.get(f"{BASE_URL}/pipelines/{task_id}")
25 | return response.json()
26 |
27 | def process_pdf(pdf_file):
28 | pdf_content = read_and_encode_pdf(pdf_file)
29 |
30 | schema_1 = {
31 | "Firm": "The name of the firm",
32 | "Number of Funds": "The number of funds managed by the firm",
33 | "Commitment": "The commitment amount in millions of dollars",
34 | "Percent of Total Comm": "The percentage of total commitment",
35 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
36 | "Percent of Total Exposure": "The percentage of total exposure",
37 | "TVPI": "Total Value to Paid-In multiple",
38 | "Net IRR": "Net Internal Rate of Return as a percentage"
39 | }
40 |
41 | pipeline_request = {
42 | "workloads": [
43 | {
44 | "raw_data": pdf_content,
45 | "schemas": [json.dumps(schema_1)]
46 | }
47 | ],
48 | "provider_type": "groq",
49 | "provider_model_name": "llama-3.1-70b-versatile",
50 | "api_key": os.getenv("GROQ_API_KEY"),
51 | "additional_params": {}
52 | }
53 |
54 | logging.debug("Sending POST request to pipeline endpoint")
55 | try:
56 | response = requests.post(f"{BASE_URL}/pipelines", json=pipeline_request)
57 | response.raise_for_status()
58 | except requests.exceptions.RequestException as e:
59 | logging.error(f"Error sending request: {e}")
60 | return
61 |
62 | logging.debug(f"Response status code: {response.status_code}")
63 | logging.debug(f"Response headers: {response.headers}")
64 | logging.debug(f"Response content: {response.text}")
65 |
66 | try:
67 | result = response.json()
68 | except json.JSONDecodeError:
69 | logging.error("Failed to decode JSON response")
70 | return
71 |
72 | task_id = result.get("task_id")
73 | if not task_id:
74 | logging.error("Invalid task_id: task_id is None or empty")
75 | return
76 | logging.debug(f"Task ID: {task_id}")
77 |
78 | max_attempts = 5
79 | attempt = 0
80 | while attempt < max_attempts:
81 | logging.debug(f"Waiting for pipeline to complete. Attempt {attempt + 1} of {max_attempts}")
82 | time.sleep(30)
83 |
84 | results = get_pipeline_results(task_id)
85 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results['status']}")
86 |
87 | if results['status'] == 'COMPLETED':
88 | logging.debug(f"Pipeline completed with results: {results['results']}")
89 | return results['results']
90 | elif results['status'] == 'FAILED':
91 | logging.error(f"Error: {results.get('error_message', 'Unknown error')}")
92 | return None
93 |
94 | attempt += 1
95 |
96 | logging.warning("Timeout: Pipeline execution took too long.")
97 | return None
98 |
99 | if __name__ == "__main__":
100 | process_pdf(PDF_FILE_PATH)
101 |
--------------------------------------------------------------------------------
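groq_example.py drives the pipeline with raw requests calls; the same workload can also be expressed with the Marly SDK client used by the other scripts in this directory. A minimal sketch, reusing the BASE_URL, pdf_content, and schema_1 values defined in the script above:

```python
from marly import Marly

client = Marly(base_url=BASE_URL)
pipeline_response = client.pipelines.create(
    api_key=os.getenv("GROQ_API_KEY"),
    provider_model_name="llama-3.1-70b-versatile",
    provider_type="groq",
    workloads=[{"raw_data": pdf_content, "schemas": [json.dumps(schema_1)]}],
)
# Poll with client.pipelines.retrieve(pipeline_response.task_id) as in the other scripts.
```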
/examples/scripts/markdown_example.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import zlib
4 | import logging
5 | import os
6 | import time
7 | from dotenv import load_dotenv
8 | from marly import Marly
9 | # This script mirrors cerebras_example.py but returns data as markdown instead of JSON.
10 | # To extract data as markdown, set markdown_mode=True in the pipelines.create request,
11 | # as seen on line 47 below.
12 | load_dotenv()
13 |
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
15 |
16 | PDF_FILE_PATH = "../example_files/lacers_reduced.pdf"
17 |
18 | def read_and_encode_pdf(file_path):
19 | with open(file_path, "rb") as file:
20 | pdf_content = base64.b64encode(zlib.compress(file.read())).decode('utf-8')
21 | logging.debug(f"{file_path} read and encoded")
22 | return pdf_content
23 |
24 | def process_pdf(pdf_file):
25 | pdf_content = read_and_encode_pdf(pdf_file)
26 |
27 | schema_1 = {
28 | "Firm": "The name of the firm",
29 | "Number of Funds": "The number of funds managed by the firm",
30 | "Commitment": "The commitment amount in millions of dollars",
31 | "Percent of Total Comm": "The percentage of total commitment",
32 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
33 | "Percent of Total Exposure": "The percentage of total exposure",
34 | "TVPI": "Total Value to Paid-In multiple",
35 | "Net IRR": "Net Internal Rate of Return as a percentage"
36 | }
37 |
38 |     # Running against a local Marly instance;
39 |     # the default pipeline service listens on port 8100
40 |     client = Marly(base_url="http://localhost:8100")
41 |
42 | try:
43 | pipeline_response_model = client.pipelines.create(
44 | api_key=os.getenv("CEREBRAS_API_KEY"),
45 | provider_model_name="llama3.1-70b",
46 | provider_type="cerebras",
47 |             markdown_mode=True,
48 | workloads=[
49 | {
50 | "raw_data": pdf_content,
51 | "schemas": [json.dumps(schema_1)],
52 | }
53 | ],
54 | )
55 | logging.debug(f"Task ID: {pipeline_response_model.task_id}")
56 | except Exception as e:
57 | logging.error(f"Error creating pipeline: {e}")
58 | return None
59 |
60 | # Quick check for cached results
61 | time.sleep(5)
62 | try:
63 | results = client.pipelines.retrieve(pipeline_response_model.task_id)
64 | if results.status == 'COMPLETED':
65 |             parsed_results = [results.results[0].metrics[f'schema_{i}'] for i in range(len(results.results[0].metrics))]
66 |             logging.info(f"Cached results found and returned {parsed_results}")
67 |             return parsed_results  # markdown strings, so no JSON parsing needed
68 | except Exception as e:
69 | logging.debug(f"No cached results available: {e}")
70 |
71 | max_attempts = 5
72 | attempt = 0
73 | while attempt < max_attempts:
74 | logging.debug(f"Waiting for pipeline to complete. Attempt {attempt + 1} of {max_attempts}")
75 | time.sleep(30)
76 |
77 | try:
78 | results = client.pipelines.retrieve(pipeline_response_model.task_id)
79 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
80 |
81 | if results.status == 'COMPLETED':
82 | parsed_results = [results.results[0].metrics[f'schema_{i}'] for i in range(len(results.results[0].metrics))]
83 | logging.info(f"Results: {parsed_results}")
84 | return parsed_results # No need to json.dumps() here if we're returning Markdown
85 | elif results.status == 'FAILED':
86 | logging.error(f"Error: {results.error_message or 'Unknown error'}")
87 | return None
88 |
89 | except Exception as e:
90 | logging.error(f"Error fetching pipeline results: {e}")
91 | return None
92 |
93 | attempt += 1
94 |
95 | logging.warning("Timeout: Pipeline execution took too long.")
96 | return None
97 |
98 | if __name__ == "__main__":
99 | process_pdf(PDF_FILE_PATH)
100 |
--------------------------------------------------------------------------------
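Because markdown_mode=True makes process_pdf return markdown strings rather than JSON, the output can be written straight to a .md file. A minimal sketch, assuming the list-of-strings return shape produced by the polling loop above; the output file name is arbitrary:

```python
results = process_pdf(PDF_FILE_PATH)
if results:
    with open("extraction_results.md", "w") as f:
        # One extracted schema per section, separated by horizontal rules.
        f.write("\n\n---\n\n".join(results))
```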
/examples/scripts/mistral_example.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dotenv import load_dotenv
5 | import os
6 | from marly import Marly
7 |
8 | load_dotenv()
9 |
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | BASE_URL = "http://localhost:8100"
13 |
14 | def process_pdf():
15 | schema_1 = {
16 | "Firm": "The name of the firm",
17 | "Number of Funds": "The number of funds managed by the firm",
18 | "Commitment": "The commitment amount in millions of dollars",
19 | "Percent of Total Comm": "The percentage of total commitment",
20 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
21 | "Percent of Total Exposure": "The percentage of total exposure",
22 | "TVPI": "Total Value to Paid-In multiple",
23 | "Net IRR": "Net Internal Rate of Return as a percentage"
24 | }
25 |
26 | client = Marly(base_url=BASE_URL)
27 |
28 | try:
29 | pipeline_response_model = client.pipelines.create(
30 | api_key=os.getenv("MISTRAL_API_KEY"),
31 | provider_model_name="mistral-large-latest",
32 | provider_type="mistral",
33 | workloads=[
34 | {
35 | "file_name": "lacers reduced",
36 | "data_source": "local_fs",
37 | "documents_location": "/app/example_files",
38 | "schemas": [json.dumps(schema_1)],
39 | }
40 | ]
41 | )
42 |
43 | logging.debug(f"Task ID: {pipeline_response_model.task_id}")
44 |
45 | max_attempts = 5
46 | attempt = 0
47 | while attempt < max_attempts:
48 | time.sleep(30)
49 |
50 | results = client.pipelines.retrieve(pipeline_response_model.task_id)
51 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
52 |
53 | if results.status == 'COMPLETED':
54 | logging.debug(f"Pipeline completed with results: {results.results}")
55 | return results.results
56 | elif results.status == 'FAILED':
57 | logging.error(f"Error: {results.error_message}")
58 | return None
59 |
60 | attempt += 1
61 |
62 | logging.warning("Timeout: Pipeline execution took too long.")
63 | return None
64 |
65 | except Exception as e:
66 | logging.error(f"Error in pipeline process: {e}")
67 | return None
68 |
69 | if __name__ == "__main__":
70 | results = process_pdf()
71 | print(results)
--------------------------------------------------------------------------------
/examples/scripts/non_marly_examples/lacers_reduced.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/velocitybolt/open-extract/64d06b8f79e18d299767ee5fcf9746e9c9425bb4/examples/scripts/non_marly_examples/lacers_reduced.pdf
--------------------------------------------------------------------------------
/examples/scripts/non_marly_examples/llamaindex_pinecone.py:
--------------------------------------------------------------------------------
1 | from llama_index.readers.web import SimpleWebPageReader
2 | from llama_index.core import VectorStoreIndex
3 | from llama_index.vector_stores.pinecone import PineconeVectorStore
4 | from llama_index.core.storage.storage_context import StorageContext
5 | from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
6 | from llama_index.core.settings import Settings
7 | from llama_index.core import SimpleDirectoryReader
8 | from dotenv import load_dotenv
9 | from pinecone import Pinecone, ServerlessSpec
10 | from llama_index.llms.azure_openai import AzureOpenAI
11 | from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
12 | 
13 | import os
14 |
15 | load_dotenv()
16 |
17 | SCHEMA = {
18 | "Names": "The first name and last name of the company founders",
19 | "Company Name": "Name of the Company",
20 | "Round": "The round of funding",
21 | "Round Size": "How much money has the company raised",
22 | "Investors": "The names of the investors in the companies (names of investors and firms)",
23 | "Company Valuation": "The current valuation of the company",
24 | "Summary": "Three sentence summary of the company"
25 | }
26 |
27 | SCHEMA_2 = {
28 | "Firm": "The name of the firm",
29 | "Number of Funds": "The number of funds managed by the firm",
30 | "Commitment": "The commitment amount in millions of dollars",
31 | "Percent of Total Comm": "The percentage of total commitment",
32 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
33 | "Percent of Total Exposure": "The percentage of total exposure",
34 | "TVPI": "Total Value to Paid-In multiple",
35 | "Net IRR": "Net Internal Rate of Return as a percentage"
36 | }
37 |
38 | class LlamaIndexAzure:
39 | def __init__(self):
40 |         self.llm = AzureOpenAI(
41 | model=os.environ["AZURE_MODEL_NAME"],
42 | deployment_name=os.environ["AZURE_DEPLOYMENT_ID"],
43 | api_key=os.environ["AZURE_OPENAI_API_KEY"],
44 | azure_endpoint=os.environ["AZURE_ENDPOINT"],
45 | api_version=os.environ["AZURE_API_VERSION"],
46 | )
47 |
48 |         self.embed_model = AzureOpenAIEmbedding(
49 | model=os.environ["EMBED_MODEL_NAME"],
50 | deployment_name=os.environ["EMBED_DEPLOYMENT_NAME"],
51 | api_key=os.environ["AZURE_OPENAI_API_KEY"],
52 | azure_endpoint=os.environ["EMBED_AZURE_ENDPOINT"],
53 | api_version=os.environ["EMBED_API_VERSION"],
54 | )
55 | Settings.llm = self.llm
56 | Settings.embed_model = self.embed_model
57 |
58 | # Instantiate once; __init__ already assigns Settings.llm / Settings.embed_model
59 | _azure = LlamaIndexAzure()
60 | Settings.llm = _azure.llm
61 | Settings.embed_model = _azure.embed_model
62 |
63 | def ask_question(question):
64 | index = pinecone_setup()
65 | query_engine = index.as_query_engine(
66 | similarity_top_k=9,
67 | node_postprocessors=[
68 | MetadataReplacementPostProcessor(target_metadata_key="window")
69 | ],
70 | )
71 |
72 | return query_engine.query(question)
73 |
74 | def pinecone_setup():
75 | print("Setting up Pinecone...")
76 | pc = Pinecone(
77 | api_key=os.environ.get("PINECONE_API_KEY"),
78 | environment="gcp-starter"
79 | )
80 | index_name = "example"
81 |     if index_name not in pc.list_indexes().names():
82 |         print("Creating new Pinecone index")
83 |         load_docs_into_pinecone(pc, index_name=index_name)
84 |         print("Pinecone setup completed and index created")
85 |     else:
86 |         print("Index already in the environment")
87 | pinecone_index = pc.Index(index_name)
88 | vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
89 | index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
90 |
91 | return index
92 |
93 | def load_docs_into_pinecone(pc, index_name):
94 | WEBSITE_URL = "https://techcrunch.com/2013/02/08/snapchat-raises-13-5m-series-a-led-by-benchmark-now-sees-60m-snaps-sent-per-day/"
95 | WEBSITE_URL2 = "https://techcrunch.com/2024/08/09/anysphere-a-github-copilot-rival-has-raised-60m-series-a-at-400m-valuation-from-a16z-thrive-sources-say/"
96 |
97 |
98 | documents = SimpleWebPageReader().load_data(
99 | [WEBSITE_URL, WEBSITE_URL2]
100 | )
101 |
102 | print(f"Loaded {len(documents)} document(s)")
103 |
104 | reader = SimpleDirectoryReader(input_files=["./lacers_reduced.pdf"])
105 | documents2 = reader.load_data()
106 |
107 | print(f"Loaded {len(documents2)} document(s)")
108 |
109 | pc.create_index(
110 | name=index_name,
111 | dimension=1536,
112 | metric='euclidean',
113 | spec=ServerlessSpec(
114 | cloud="aws",
115 | region="us-east-1",
116 | )
117 | )
118 |
119 |
120 |     # Reuse the Pinecone client passed in by the caller; it already points
121 |     # at the project that owns the newly created index, so no fresh client
122 |     # is needed before wiring up the vector store below.
123 | 
124 |
125 | pinecone_index = pc.Index(index_name)
126 |
127 | vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
128 | storage_context = StorageContext.from_defaults(vector_store=vector_store)
129 | VectorStoreIndex.from_documents(documents, storage_context=storage_context)
130 | VectorStoreIndex.from_documents(documents2, storage_context=storage_context)
131 |
132 | if __name__ == "__main__":
133 | question1 = f"Extract values for {SCHEMA} for every company and extract {SCHEMA_2} for the 10 Largest Sponsor Relationships then return the result as JSON. Please include all the data!"
134 | answer = ask_question(question1)
135 | print("Answer:", answer)
--------------------------------------------------------------------------------
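pinecone_setup only ingests documents when the "example" index does not yet exist, so changing the source URLs or PDF requires dropping the index first. A minimal sketch using the Pinecone client's delete_index call, mirroring the client setup in the script above:

```python
import os
from pinecone import Pinecone

pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY"))
if "example" in pc.list_indexes().names():
    pc.delete_index("example")  # the next pinecone_setup() call re-creates and re-ingests
```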
/examples/scripts/requirements.txt:
--------------------------------------------------------------------------------
1 | certifi==2024.8.30
2 | charset-normalizer==3.3.2
3 | idna==3.10
4 | python-dotenv==1.0.1
5 | requests==2.32.3
6 | urllib3==2.2.3
7 |
--------------------------------------------------------------------------------
/examples/scripts/s3_example.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | from dotenv import load_dotenv
5 | from marly import Marly
6 | import time
7 | import os
8 |
9 | load_dotenv()
10 |
11 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12 |
13 | S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME")
14 | S3_FILE_KEY = os.getenv("S3_FILE_KEY")
15 | BASE_URL = "http://localhost:8100"
16 |
17 | SCHEMA_1 = {
18 | "Firm": "The name of the firm",
19 | "Number of Funds": "The number of funds managed by the firm",
20 | "Commitment": "The commitment amount in millions of dollars",
21 | "Percent of Total Comm": "The percentage of total commitment",
22 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
23 | "Percent of Total Exposure": "The percentage of total exposure",
24 | "TVPI": "Total Value to Paid-In multiple",
25 | "Net IRR": "Net Internal Rate of Return as a percentage"
26 | }
27 |
28 | def process_pdf():
29 | client = Marly(base_url=BASE_URL)
30 |
31 | try:
32 | pipeline_response = client.pipelines.create(
33 | api_key=os.getenv("GROQ_API_KEY"),
34 | provider_model_name="llama-3.1-70b-versatile",
35 | provider_type="groq",
36 | workloads=[
37 | {
38 | "file_name": S3_FILE_KEY,
39 | "data_source": "s3",
40 | "documents_location": S3_BUCKET_NAME,
41 | "schemas": [json.dumps(SCHEMA_1)]
42 | }
43 | ],
44 | additional_params={
45 | "bucket_name": S3_BUCKET_NAME,
46 | "region_name": "us-east-1",
47 | "aws_access_key_id": os.getenv("AWS_ACCESS_KEY_ID"),
48 | "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY")
49 | }
50 | )
51 | logging.debug(f"Task ID: {pipeline_response.task_id}")
52 |
53 | max_attempts = 5
54 | attempt = 0
55 | while attempt < max_attempts:
56 | time.sleep(10)
57 |
58 | results = client.pipelines.retrieve(pipeline_response.task_id)
59 | logging.debug(f"Poll attempt {attempt + 1}: Status - {results.status}")
60 |
61 | if results.status == 'COMPLETED':
62 | logging.debug(f"Pipeline completed with results: {results.results}")
63 | return results.results
64 | elif results.status == 'FAILED':
65 | logging.error(f"Error: {results.error_message}")
66 | return None
67 |
68 | attempt += 1
69 |
70 | logging.warning("Timeout: Pipeline execution took too long.")
71 | return None
72 |
73 | except Exception as e:
74 | logging.error(f"Error in pipeline process: {e}")
75 | return None
76 |
77 | if __name__ == "__main__":
78 | results = process_pdf()
79 | print(results)
80 |
--------------------------------------------------------------------------------
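s3_example.py assumes the PDF already exists in the bucket under S3_FILE_KEY. A minimal sketch for staging it with boto3 (not a project dependency here; shown only for convenience), reusing the same environment variables the script reads:

```python
import os
import boto3

s3 = boto3.client(
    "s3",
    region_name="us-east-1",
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)
# upload_file(Filename, Bucket, Key)
s3.upload_file("../example_files/lacers_reduced.pdf", os.getenv("S3_BUCKET_NAME"), os.getenv("S3_FILE_KEY"))
```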
/examples/scripts/web_and_document_example.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from dotenv import load_dotenv
5 | import os
6 | from marly import Marly
7 |
8 | load_dotenv()
9 |
10 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
11 |
12 | BASE_URL = "http://localhost:8100"
13 | WEBSITE_URL = "https://techcrunch.com/2013/02/08/snapchat-raises-13-5m-series-a-led-by-benchmark-now-sees-60m-snaps-sent-per-day/"
14 | WEBSITE_URL2 = "https://techcrunch.com/2024/08/09/anysphere-a-github-copilot-rival-has-raised-60m-series-a-at-400m-valuation-from-a16z-thrive-sources-say/"
15 |
16 | SCHEMA = {
17 | "Names": "The first name and last name of the company founders",
18 | "Company Name": "Name of the Company",
19 | "Round": "The round of funding",
20 | "Round Size": "How much money has the company raised",
21 | "Investors": "The names of the investors in the companies (names of investors and firms)",
22 | "Company Valuation": "The current valuation of the company",
23 | "Summary": "Three sentence summary of the company"
24 | }
25 |
26 | SCHEMA_2 = {
27 | "Firm": "The name of the firm",
28 | "Number of Funds": "The number of funds managed by the firm",
29 | "Commitment": "The commitment amount in millions of dollars",
30 | "Percent of Total Comm": "The percentage of total commitment",
31 | "Exposure (FMV + Unfunded)": "The exposure including fair market value and unfunded commitments in millions of dollars",
32 | "Percent of Total Exposure": "The percentage of total exposure",
33 | "TVPI": "Total Value to Paid-In multiple",
34 | "Net IRR": "Net Internal Rate of Return as a percentage"
35 | }
36 |
37 | def process_website():
38 | client = Marly(base_url=BASE_URL)
39 |
40 | try:
41 | pipeline_response = client.pipelines.create(
42 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
43 | provider_model_name=os.getenv("AZURE_MODEL_NAME"),
44 | provider_type="azure",
45 | workloads=[{
46 | "data_source": "web",
47 | "documents_location": WEBSITE_URL,
48 | "schemas": [json.dumps(SCHEMA)],
49 | },
50 | {
51 | "data_source": "web",
52 | "documents_location": WEBSITE_URL2,
53 | "schemas": [json.dumps(SCHEMA)],
54 | },
55 | {
56 | "file_name": "lacers reduced",
57 | "data_source": "local_fs",
58 | "documents_location": "/app/example_files",
59 | "schemas": [json.dumps(SCHEMA_2)],
60 | }
61 | ],
62 | additional_params={
63 | "azure_endpoint": os.getenv("AZURE_ENDPOINT"),
64 | "azure_deployment": os.getenv("AZURE_DEPLOYMENT_ID"),
65 | "api_version": os.getenv("AZURE_API_VERSION")
66 | }
67 | )
68 |
69 | while True:
70 | results = client.pipelines.retrieve(pipeline_response.task_id)
71 | if results.status == 'COMPLETED':
72 | processed_results = []
73 | for result in results.results:
74 | for schema_result in result.results:
75 | processed_result = json.loads(schema_result['metrics']['schema_0'])
76 | processed_results.append(processed_result)
77 | return json.dumps(processed_results, indent=2)
78 | elif results.status == 'FAILED':
79 | return None
80 | time.sleep(15)
81 |
82 | except Exception as e:
83 | logging.error(f"Error in pipeline process: {e}")
84 | return None
85 |
86 | if __name__ == "__main__":
87 | print(process_website())
88 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | redis==5.0.1
2 | beautifulsoup4==4.12.2
3 | python-dotenv==1.0.0
4 | langsmith==0.0.83
5 | PyPDF2==3.0.1
6 | markitdown==0.1.0
7 | langchain-core==0.1.27
8 | langgraph==0.0.27
--------------------------------------------------------------------------------
/start-oe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | RED='\033[0;31m'
4 | GREEN='\033[0;32m'
5 | BLUE='\033[0;34m'
6 | YELLOW='\033[1;33m'
7 | NC='\033[0m'
8 | 
9 | if [ ! -x "$0" ]; then
10 |     echo -e "${RED}Error: Script doesn't have execute permissions${NC}"
11 |     echo "Please run: chmod +x $0"
12 |     exit 1
13 | fi
14 |
15 | echo -e "${BLUE}=== Checking System Requirements ===${NC}"
16 |
17 | if [ -f "examples/scripts/requirements.txt" ]; then
18 |
19 | if command -v pip3 &> /dev/null; then
20 | PIP_CMD="pip3"
21 | elif command -v pip &> /dev/null; then
22 | PIP_CMD="pip"
23 | else
24 | echo -e "${RED}Error: Neither pip nor pip3 is installed${NC}"
25 | exit 1
26 | fi
27 |
28 | missing_packages=0
29 | while IFS= read -r package || [ -n "$package" ]; do
30 | [[ $package =~ ^[[:space:]]*$ || $package =~ ^#.*$ ]] && continue
31 |
32 | package_name=$(echo "$package" | cut -d'=' -f1 | cut -d'>' -f1 | cut -d'<' -f1 | tr -d ' ')
33 |
34 | if ! $PIP_CMD show "$package_name" >/dev/null 2>&1; then
35 | echo -e "${RED}❌ Missing package: ${package_name}${NC}"
36 | missing_packages=1
37 | else
38 | echo -e "${GREEN}✓ ${package_name} is installed${NC}"
39 | fi
40 | done < "examples/scripts/requirements.txt"
41 |
42 | if [ $missing_packages -eq 1 ]; then
43 | echo -e "\n${RED}Error: Missing Python requirements${NC}"
44 | echo -e "Please install required packages with:"
45 | echo -e "$PIP_CMD install -r examples/scripts/requirements.txt"
46 | exit 1
47 | fi
48 | else
49 | echo -e "${RED}Error: examples/scripts/requirements.txt not found${NC}"
50 | exit 1
51 | fi
52 |
53 | if ! command -v docker &> /dev/null; then
54 | echo -e "${RED}Error: Docker is not installed${NC}"
55 | echo "Please install Docker first: https://docs.docker.com/get-docker/"
56 | exit 1
57 | else
58 | echo -e "${GREEN}✓ Docker is installed${NC}"
59 | fi
60 |
61 | if ! command -v docker-compose &> /dev/null; then
62 | echo -e "${RED}Error: Docker Compose is not installed${NC}"
63 | echo "Please install Docker Compose: https://docs.docker.com/compose/install/"
64 | exit 1
65 | else
66 | echo -e "${GREEN}✓ Docker Compose is installed${NC}"
67 | fi
68 |
69 | if ! docker info &> /dev/null; then
70 | echo -e "\n${RED}Error: Docker daemon is not running${NC}"
71 | echo "Please start Docker daemon first"
72 | exit 1
73 | else
74 | echo -e "${GREEN}✓ Docker daemon is running${NC}"
75 | fi
76 |
77 | echo -e "\n${BLUE}=== Attempting to load .env file... ===${NC}"
78 |
79 | if [ ! -f .env ]; then
80 | echo -e "${RED}Error: .env file not found${NC}"
81 | echo -e "Please create a .env file with your provider credentials"
82 | exit 1
83 | fi
84 |
85 | set -a
86 | source .env >/dev/null 2>&1
87 | set +a
88 |
89 | provider_configured=false
90 |
91 | if [ ! -z "${AZURE_OPENAI_API_KEY}" ]; then
92 | provider_configured=true
93 | missing_vars=0
94 |
95 | azure_vars=(
96 | "AZURE_RESOURCE_NAME"
97 | "AZURE_DEPLOYMENT_ID"
98 | "AZURE_MODEL_NAME"
99 | "AZURE_API_VERSION"
100 | "AZURE_OPENAI_API_KEY"
101 | "AZURE_ENDPOINT"
102 | )
103 |
104 | for var in "${azure_vars[@]}"; do
105 | if [ -z "${!var}" ]; then
106 | echo -e "${RED}❌ Missing ${var}${NC}"
107 | missing_vars=1
108 | else
109 | echo -e "${GREEN}✓ ${var} is set${NC}"
110 | fi
111 | done
112 |
113 | if [ $missing_vars -eq 1 ]; then
114 | echo -e "${RED}Error: Azure OpenAI requires all related environment variables${NC}"
115 | exit 1
116 | fi
117 | fi
118 |
119 | if [ ! -z "${OPENAI_API_KEY}" ]; then
120 | provider_configured=true
121 | echo -e "${GREEN}✓ OpenAI configured${NC}"
122 | fi
123 |
124 | if [ ! -z "${CEREBRAS_API_KEY}" ]; then
125 | provider_configured=true
126 | echo -e "${GREEN}✓ Cerebras configured${NC}"
127 | fi
128 |
129 | if [ ! -z "${GROQ_API_KEY}" ]; then
130 | provider_configured=true
131 | echo -e "${GREEN}✓ Groq configured${NC}"
132 | fi
133 |
134 | if [ ! -z "${MISTRAL_API_KEY}" ]; then
135 | provider_configured=true
136 | echo -e "${GREEN}✓ Mistral configured${NC}"
137 | fi
138 |
139 | if [ "$provider_configured" = false ]; then
140 | echo -e "${YELLOW}Warning: No AI Model provider credentials found in .env file${NC}"
141 | echo -e "These are some of the model providers that are supported:"
142 | echo -e "- Azure OpenAI (AZURE_OPENAI_API_KEY + related vars)"
143 | echo -e "- OpenAI (OPENAI_API_KEY)"
144 | echo -e "- Cerebras (CEREBRAS_API_KEY)"
145 | echo -e "- Groq (GROQ_API_KEY)"
146 | echo -e "- Mistral (MISTRAL_API_KEY)"
147 | echo -e "\n${YELLOW}Are you sure you want to continue without any of these providers configured? (yes/no)${NC}"
148 | read -r answer
149 |
150 | if [ "$(echo $answer | tr '[:upper:]' '[:lower:]')" = "yes" ]; then
151 | echo -e "${YELLOW}Continuing without provider configuration. Please ensure your desired AI model provider is supported by Marly.${NC}"
152 | else
153 | echo -e "${RED}Please configure at least one provider in your .env file${NC}"
154 | exit 1
155 | fi
156 | fi
157 |
158 |
159 | if ! grep -q "^LANGCHAIN_TRACING_V2=true" .env; then
160 | echo -e "${YELLOW}Warning: LANGCHAIN_TRACING_V2 is not enabled${NC}"
161 | echo -e "\n${YELLOW}Are you sure you want to continue without LangSmith tracing? (yes/no)${NC}"
162 | read -r answer
163 |
164 | if [ "$(echo $answer | tr '[:upper:]' '[:lower:]')" = "yes" ]; then
165 | echo -e "${YELLOW}Continuing without LangSmith tracing...${NC}"
166 | else
167 | echo -e "${RED}Please configure LangSmith tracing in your .env file${NC}"
168 | echo -e "\nRequired variables:"
169 | echo -e "LANGCHAIN_TRACING_V2=true"
170 | echo -e "LANGCHAIN_ENDPOINT=https://api.smith.langchain.com"
171 | echo -e "LANGCHAIN_API_KEY="
172 | echo -e "LANGCHAIN_PROJECT="
173 | exit 1
174 | fi
175 | else
176 | echo -e "\n${BLUE}Checking LangSmith configuration:${NC}"
177 | missing_vars=0
178 |
179 | langsmith_vars=(
180 | "LANGCHAIN_ENDPOINT"
181 | "LANGCHAIN_API_KEY"
182 | "LANGCHAIN_PROJECT"
183 | )
184 |
185 | for var in "${langsmith_vars[@]}"; do
186 | if ! grep -q "^${var}=" .env; then
187 | echo -e "${RED}❌ ${var} must be explicitly set in .env file${NC}"
188 | missing_vars=1
189 | else
190 | echo -e "${GREEN}✓ ${var} is set${NC}"
191 | fi
192 | done
193 |
194 | if [ $missing_vars -eq 1 ]; then
195 | echo -e "${RED}Error: LangSmith tracing requires all related environment variables${NC}"
196 | exit 1
197 | else
198 | echo -e "${GREEN}✓ LangSmith tracing has been correctly configured${NC}"
199 | fi
200 | fi
201 |
202 | echo -e "\n${BLUE}=== Checking for Running Containers ===${NC}"
203 |
204 | if docker ps --format '{{.Names}}' | grep -qE '(pipeline-1|extraction-1|transformation-1)'; then
205 | echo -e "${RED}Error: Marly containers are already running${NC}"
206 | echo -e "\nRunning containers:"
207 | docker ps --format 'table {{.Names}}\t{{.Status}}' | grep -E '(pipeline-1|extraction-1|transformation-1)'
208 | exit 1
209 | else
210 | echo -e "\n${GREEN}✓ All checks passed! Starting services...${NC}"
211 | docker-compose up --build
212 | fi
213 |
--------------------------------------------------------------------------------
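start-oe.sh validates provider credentials and LangSmith settings from a .env file at the repo root. A sketch of a template covering the variables the script checks; only one provider block is required, and all values shown are placeholders:

```
# Azure OpenAI (all six required if AZURE_OPENAI_API_KEY is set)
AZURE_OPENAI_API_KEY=...
AZURE_RESOURCE_NAME=...
AZURE_DEPLOYMENT_ID=...
AZURE_MODEL_NAME=...
AZURE_API_VERSION=...
AZURE_ENDPOINT=...

# Or any one of these standalone providers
OPENAI_API_KEY=...
CEREBRAS_API_KEY=...
GROQ_API_KEY=...
MISTRAL_API_KEY=...

# Optional LangSmith tracing (the script prompts before continuing without it)
LANGCHAIN_TRACING_V2=true
LANGCHAIN_ENDPOINT=https://api.smith.langchain.com
LANGCHAIN_API_KEY=...
LANGCHAIN_PROJECT=...
```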