├── .dockerignore
├── .gitignore
├── Dockerfile
├── README.md
├── docker-compose.yaml
├── dynamicbatch_ragpipeline
│   ├── __init__.py
│   ├── cache.py
│   ├── doc_layout.py
│   ├── env.py
│   ├── function.py
│   ├── main.py
│   ├── ocr.py
│   ├── ocr_utils.py
│   └── playwright_utils.py
├── html-to-pdf-only.yaml
├── notebook
│   ├── doc-layout.png
│   ├── document-layout.ipynb
│   ├── huggingface.pdf
│   ├── ocr.ipynb
│   ├── page1.png
│   ├── page2.png
│   ├── page3.png
│   ├── page4.png
│   ├── url-to-pdf.ipynb
│   └── url-to-pdf.png
├── push-dockerhub.sh
├── requirements.txt
├── runs
│   └── detect
│       └── train
│           └── args.yaml
├── setup.py
└── stress-test
    ├── 2310.01889v4.pdf
    ├── README.md
    ├── doc_layout.png
    ├── doc_layout.py
    ├── ocr.png
    ├── ocr.py
    ├── table1.png
    ├── table2.png
    └── title.png

/.dockerignore:
--------------------------------------------------------------------------------
 1 | huggingface
 2 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | *__pycache__
 2 | huggingface
 3 | *ipynb_checkpoints
 4 | dist
 5 | build
 6 | *.egg-info
 7 | *.whl
 8 | *Untitled*.ipynb
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
 1 | FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 as base
 2 | 
 3 | RUN apt update
 4 | RUN apt install python3 python3-dev python3-pip -y
 5 | RUN apt install sudo -y
 6 | RUN adduser --quiet --disabled-password --shell /bin/bash --home /home/ubuntu --gecos "User" ubuntu
 7 | RUN usermod -aG sudo ubuntu
 8 | RUN echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers
 9 | USER ubuntu
10 | WORKDIR /home/ubuntu
11 | 
12 | RUN pip3 install pip -U
13 | RUN sudo apt update
14 | 
15 | ADD requirements.txt .
16 | RUN pip3 install -r requirements.txt
17 | 
18 | RUN ~/.local/bin/playwright install-deps
19 | RUN ~/.local/bin/playwright install
20 | 
21 | COPY ./dynamicbatch_ragpipeline/ /home/ubuntu/dynamicbatch_ragpipeline
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # dynamic-batch-RAG-pipeline
 2 | 
 3 | Dynamic batching for Document Layout and OCR, suitable for RAG.
 4 | 
 5 | 1. Dynamic batching for SOTA Document Layout and OCR models, to serve higher concurrency.
 6 | 2. Continuous batching for causal OCR models.
 7 | 3. Serves a user-defined maximum concurrency.
 8 | 4. Listens for client disconnect signals, to stop continuous batching early.
 9 | 5. Extra tool: convert any URL to a PDF file.
10 | 
11 | ## Available models
12 | 
13 | ### Document Layout
14 | 
15 | 1. https://github.com/opendatalab/DocLayout-YOLO
16 | 
17 | ### OCR
18 | 
19 | 1. https://huggingface.co/stepfun-ai/GOT-OCR2_0
20 | 
21 | ## how to install
22 | 
23 | Using PIP with git,
24 | 
25 | ```bash
26 | pip3 install git+https://github.com/mesolitica/dynamic-batch-RAG-pipeline
27 | ```
28 | 
29 | Or you can git clone,
30 | 
31 | ```bash
32 | git clone https://github.com/mesolitica/dynamic-batch-RAG-pipeline && cd dynamic-batch-RAG-pipeline
33 | ```
34 | 
35 | ## how to
36 | 
37 | ### Supported parameters
38 | 
39 | ```bash
40 | python3 -m dynamicbatch_ragpipeline.main --help
41 | ```
42 | 
43 | ```text
44 | usage: main.py [-h] [--host HOST] [--port PORT] [--loglevel LOGLEVEL] [--reload RELOAD]
45 |                [--enable-doc-layout ENABLE_DOC_LAYOUT] [--model-doc-layout MODEL_DOC_LAYOUT] [--enable-ocr ENABLE_OCR]
46 |                [--model-ocr MODEL_OCR] [--dynamic-batching-microsleep DYNAMIC_BATCHING_MICROSLEEP]
47 |                [--dynamic-batching-doc-layout-batch-size DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE]
48 |                [--dynamic-batching-ocr-batch-size DYNAMIC_BATCHING_OCR_BATCH_SIZE] [--accelerator-type ACCELERATOR_TYPE]
49 |                [--max-concurrent MAX_CONCURRENT] [--static-cache STATIC_CACHE]
50 |                [--static-cache-max-length STATIC_CACHE_MAX_LENGTH] [--enable-url-to-pdf ENABLE_URL_TO_PDF]
51 |                [--playwright-max-concurrency PLAYWRIGHT_MAX_CONCURRENCY] [--torch-compile TORCH_COMPILE]
52 | 
53 | Configuration parser
54 | 
55 | options:
56 |   -h, --help            show this help message and exit
57 |   --host HOST           host name to host the app (default: 0.0.0.0, env: HOSTNAME)
58 |   --port PORT           port to host the app (default: 7088, env: PORT)
59 |   --loglevel LOGLEVEL   Logging level (default: INFO, env: LOGLEVEL)
60 |   --reload RELOAD       Enable hot loading (default: False, env: RELOAD)
61 |   --enable-doc-layout ENABLE_DOC_LAYOUT
62 |                         Enable document layout detection (default: True, env: ENABLE_DOC_LAYOUT)
63 |   --model-doc-layout MODEL_DOC_LAYOUT
64 |                         Model type (default: yolo10, env: MODEL_DOC_LAYOUT)
65 |   --enable-ocr ENABLE_OCR
66 |                         Enable OCR (default: True, env: ENABLE_OCR)
67 |   --model-ocr MODEL_OCR
68 |                         Model type (default: got_ocr2_0, env: MODEL_OCR)
69 |   --dynamic-batching-microsleep DYNAMIC_BATCHING_MICROSLEEP
70 |                         microsleep to group dynamic batching, 1 / 1e-4 = 10k steps per second (default: 0.0001, env:
71 |                         DYNAMIC_BATCHING_MICROSLEEP)
72 |   --dynamic-batching-doc-layout-batch-size DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE
73 |                         maximum batch size for document layout during dynamic batching (default: 16, env:
74 |                         DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE)
75 |   --dynamic-batching-ocr-batch-size DYNAMIC_BATCHING_OCR_BATCH_SIZE
76 |                         maximum batch size for OCR during dynamic batching (default: 16, env:
77 |                         DYNAMIC_BATCHING_OCR_BATCH_SIZE)
78 |   --accelerator-type ACCELERATOR_TYPE
79 |                         Accelerator type (default: cuda, env: ACCELERATOR_TYPE)
80 |   --max-concurrent MAX_CONCURRENT
81 |                         Maximum concurrent requests (default: 100, env: MAX_CONCURRENT)
82 |   --static-cache STATIC_CACHE
83 |                         Preallocate KV Cache for faster inference (default: False, env: STATIC_CACHE)
84 |   --static-cache-max-length STATIC_CACHE_MAX_LENGTH
85 |                         Maximum sequence length for the preallocated KV Cache (default: 8192, env: STATIC_CACHE_MAX_LENGTH)
86 |   --enable-url-to-pdf ENABLE_URL_TO_PDF
87 |                         Enable URL to PDF using Playwright (default: True, env: ENABLE_URL_TO_PDF)
88 |   --playwright-max-concurrency PLAYWRIGHT_MAX_CONCURRENCY
89 |                         Maximum concurrent Playwright browser pages for URL to PDF (default: 1, env: PLAYWRIGHT_MAX_CONCURRENCY)
90 |   --torch-compile TORCH_COMPILE
91 |                         Torch compile necessary forwards, can speed up at least 1.5X (default: True, env: TORCH_COMPILE)
92 | ```
93 | 
94 | **We support both CLI arguments and OS environment variables**.
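 95 | 
 96 | For example, a minimal sketch of configuring through environment variables only, using the env names shown above (all parsed in `dynamicbatch_ragpipeline/env.py`),
 97 | 
 98 | ```bash
 99 | ENABLE_OCR=false ENABLE_DOC_LAYOUT=true PORT=7088 \
100 | python3 -m dynamicbatch_ragpipeline.main
101 | ```
102 | 
103 | This serves only the document layout and URL to PDF endpoints, without passing any CLI arguments.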
104 | 
105 | ### Run
106 | 
107 | ```
108 | python3 -m dynamicbatch_ragpipeline.main \
109 | --host 0.0.0.0 --port 7088
110 | ```
111 | 
112 | #### Example document layout
113 | 
114 | ```bash
115 | curl -X 'POST' \
116 |   'http://localhost:7088/doc_layout' \
117 |   -H 'accept: application/json' \
118 |   -H 'Content-Type: multipart/form-data' \
119 |   -F 'file=@stress-test/2310.01889v4.pdf;type=application/pdf' \
120 |   -F 'iou_threshold=0.45'
121 | ```
122 | 
123 | Check out [notebook/document-layout.ipynb](notebook/document-layout.ipynb).
124 | 
125 | #### Example OCR
126 | 
127 | ```bash
128 | curl -X 'POST' \
129 |   'http://localhost:7088/ocr' \
130 |   -H 'accept: application/json' \
131 |   -H 'Content-Type: multipart/form-data' \
132 |   -F 'image=@stress-test/table2.png;type=image/png' \
133 |   -F 'max_tokens=4096' \
134 |   -F 'stream=false'
135 | ```
136 | 
137 | **Because the backend uses continuous batching, we also support streaming**.
138 | 
139 | Check out [notebook/ocr.ipynb](notebook/ocr.ipynb).
140 | 
141 | #### Example URL to PDF
142 | 
143 | ```bash
144 | curl -X 'POST' \
145 |   'http://localhost:7088/url_to_pdf' \
146 |   -H 'accept: application/json' \
147 |   -H 'Content-Type: application/json' \
148 |   -d '{
149 |   "url": "https://huggingface.co/",
150 |   "viewport_weight": 1470,
151 |   "viewport_height": 956
152 | }'
153 | ```
154 | 
155 | Check out [notebook/url-to-pdf.ipynb](notebook/url-to-pdf.ipynb).
156 | 
157 | **To support more concurrency for URL to PDF, make sure to set `--playwright-max-concurrency` to more than 1**.
158 | 
159 | ## [Stress test](stress-test)
160 | 
161 | ### Document layout
162 | 
163 | Spawn rate of 10 users per second, up to 100 concurrent users, for 60 seconds, on an RTX 3090 Ti,
164 | 
165 | ![alt text](stress-test/doc_layout.png)
166 | 
167 | ### OCR
168 | 
169 | Spawn rate of 5 users per second, up to 50 concurrent users, for 60 seconds,
170 | 
171 | ![alt text](stress-test/ocr.png)
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
 1 | version: "3.3"
 2 | 
 3 | services:
 4 |   dynamicbatch_ragpipeline:
 5 |     build:
 6 |       context: .
7 | deploy: 8 | resources: 9 | reservations: 10 | devices: 11 | - driver: nvidia 12 | count: 1 13 | capabilities: [gpu] 14 | container_name: dynamicbatch_ragpipeline 15 | environment: 16 | - PYTHONUNBUFFERED=1 17 | - HF_HUB_ENABLE_HF_TRANSFER=1 18 | 19 | volumes: 20 | - "./dynamicbatch_ragpipeline:/home/ubuntu/dynamicbatch_ragpipeline" 21 | - "~/.cache/huggingface:/home/ubuntu/.cache/huggingface" 22 | ports: 23 | - "7088:7088" 24 | command: python3 -m dynamicbatch_ragpipeline.main --host 0.0.0.0 --port 7088 --reload true -------------------------------------------------------------------------------- /dynamicbatch_ragpipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /dynamicbatch_ragpipeline/cache.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Optional, Dict, Any 2 | from transformers.cache_utils import Cache 3 | import torch 4 | import torch.nn.functional as F 5 | import time 6 | 7 | def pad_kv(caches): 8 | """ 9 | List[head, seq, dims] 10 | """ 11 | 12 | shapes = [caches[i].shape[2] for i in range(len(caches))] 13 | maxlen = max(shapes) 14 | if all(s == maxlen for s in shapes): 15 | return torch.concat(caches) 16 | 17 | new_caches = [] 18 | for i in range(len(caches)): 19 | pad_val = (0, 0, 0, maxlen - caches[i].shape[2], 0, 0, 0, 0) 20 | pad = F.pad(caches[i], pad_val, value=0.0) 21 | new_caches.append(pad) 22 | return torch.concat(new_caches) 23 | 24 | 25 | class DynamicLengthDecoderCache(Cache): 26 | 27 | def __init__(self) -> None: 28 | self.key_cache: List[torch.Tensor] = [] 29 | self.value_cache: List[torch.Tensor] = [] 30 | self.current_uuid = [] 31 | 32 | def batch_size(self): 33 | if len(self.key_cache) > 0: 34 | return len(self.key_cache[0]) 35 | return 0 36 | 37 | def __len__(self): 38 | """ 39 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds 40 | to the number of layers in the model. 41 | """ 42 | return len(self.key_cache) 43 | 44 | def update( 45 | self, 46 | key_states: torch.Tensor, 47 | value_states: torch.Tensor, 48 | layer_idx: int, 49 | cache_kwargs: Optional[Dict[str, Any]] = None, 50 | ) -> Tuple[torch.Tensor, torch.Tensor]: 51 | 52 | keys, values = [], [] 53 | for i, k in enumerate(self.current_uuid): 54 | self.key_cache[layer_idx][k] = torch.cat( 55 | [self.key_cache[layer_idx][k], key_states[i: i + 1]], dim=-2) 56 | self.value_cache[layer_idx][k] = torch.cat( 57 | [self.value_cache[layer_idx][k], value_states[i: i + 1]], dim=-2) 58 | keys.append(self.key_cache[layer_idx][k]) 59 | values.append(self.value_cache[layer_idx][k]) 60 | 61 | k = pad_kv(keys) 62 | v = pad_kv(values) 63 | 64 | return k, v 65 | 66 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: 67 | """Returns the sequence length of the cached states. A layer index can be optionally passed.""" 68 | # TODO: deprecate this function in favor of `cache_position` 69 | if len(self.key_cache) <= layer_idx: 70 | return 0 71 | 72 | lengths = [self.key_cache[0][k].shape[2] for k in self.current_uuid] 73 | return max(lengths) 74 | 75 | def get_max_length(self) -> Optional[int]: 76 | """Returns the maximum sequence length of the cached states. 
DynamicCache does not have a maximum length."""
 77 |         return None
 78 | 
 79 | class StaticLengthDecoderCache(Cache):
 80 | 
 81 |     def __init__(
 82 |         self,
 83 |         batch_size = 20,
 84 |         max_length = 8192,
 85 |         device = 'cuda',
 86 |         head_size = 16,
 87 |         dim_size = 64,
 88 |         num_hidden_layers = 24,
 89 |         dtype = torch.bfloat16,
 90 |     ) -> None:
 91 | 
 92 |         self.key_cache, self.value_cache = [], []
 93 |         for _ in range(num_hidden_layers):
 94 |             self.key_cache.append(
 95 |                 torch.zeros(
 96 |                     batch_size,
 97 |                     head_size,
 98 |                     max_length,
 99 |                     dim_size,
100 |                     dtype = dtype).to(device)
101 |             )
102 |             self.value_cache.append(
103 |                 torch.zeros(
104 |                     batch_size,
105 |                     head_size,
106 |                     max_length,
107 |                     dim_size,
108 |                     dtype = dtype).to(device)
109 |             )
110 |         self.lengths = []
111 | 
112 |     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
113 |         """
114 |         Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
115 |         sequence length.
116 |         """
117 |         if layer_idx < len(self):
118 |             return self.key_cache[layer_idx], self.value_cache[layer_idx]
119 |         else:
120 |             raise KeyError(
121 |                 f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
122 | 
123 |     def __len__(self):
124 |         """
125 |         Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
126 |         to the number of layers in the model.
127 |         """
128 |         return len(self.key_cache)
129 | 
130 |     def update(
131 |         self,
132 |         key_states: torch.Tensor,
133 |         value_states: torch.Tensor,
134 |         layer_idx: int,
135 |         cache_kwargs: Optional[Dict[str, Any]] = None,
136 |     ) -> Tuple[torch.Tensor, torch.Tensor]:
137 | 
138 |         cache_position = cache_kwargs['cache_position']
139 |         maxlen = max(self.lengths)
140 | 
141 |         for i in range(len(key_states)):  # write the single new decode token at each sequence's current length
142 |             self.key_cache[layer_idx][i, :, self.lengths[i]] = key_states[i,:,0]
143 |             self.value_cache[layer_idx][i, :, self.lengths[i]] = value_states[i,:,0]
144 | 
145 |         k = self.key_cache[layer_idx][:len(key_states), :, :maxlen]
146 |         v = self.value_cache[layer_idx][:len(key_states), :, :maxlen]
147 | 
148 |         return k, v
149 | 
150 |     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
151 |         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
152 |         # TODO: deprecate this function in favor of `cache_position`
153 |         if len(self.key_cache) <= layer_idx:
154 |             return 0
155 |         return max(self.lengths)
156 | 
157 |     def get_max_length(self) -> Optional[int]:
158 |         """Returns the maximum sequence length of the cached states. 
DynamicCache does not have a maximum length.""" 159 | return None -------------------------------------------------------------------------------- /dynamicbatch_ragpipeline/doc_layout.py: -------------------------------------------------------------------------------- 1 | from doclayout_yolo import YOLOv10 2 | from huggingface_hub import snapshot_download 3 | from dynamicbatch_ragpipeline.env import args 4 | from datetime import datetime 5 | from io import BytesIO 6 | from PIL import Image 7 | import base64 8 | import pymupdf 9 | import numpy as np 10 | import asyncio 11 | import torch 12 | import torchvision 13 | import os 14 | import logging 15 | import time 16 | import sys 17 | 18 | id_to_names = { 19 | 0: 'title', 20 | 1: 'plain text', 21 | 2: 'abandon', 22 | 3: 'figure', 23 | 4: 'figure_caption', 24 | 5: 'table', 25 | 6: 'table_caption', 26 | 7: 'table_footnote', 27 | 8: 'isolate_formula', 28 | 9: 'formula_caption' 29 | } 30 | 31 | model = None 32 | device = 'cpu' 33 | if args.accelerator_type == 'cuda': 34 | if not torch.cuda.is_available(): 35 | logging.warning('CUDA is not available, fallback to CPU.') 36 | else: 37 | device = 'cuda' 38 | 39 | step_queue = asyncio.Queue() 40 | 41 | def load_model(): 42 | global model 43 | 44 | model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench') 45 | model = YOLOv10(os.path.join(model_dir, 'doclayout_yolo_docstructbench_imgsz1024.pt')) 46 | 47 | if args.torch_compile: 48 | logging.info('enabling torch compile for doc layout') 49 | model.compile() 50 | 51 | async def step(): 52 | need_sleep = True 53 | while True: 54 | if need_sleep: 55 | await asyncio.sleep(args.dynamic_batching_microsleep) 56 | 57 | try: 58 | need_sleep = True 59 | batch = [] 60 | while not step_queue.empty(): 61 | try: 62 | request = await asyncio.wait_for(step_queue.get(), timeout=1e-9) 63 | batch.append(request) 64 | if len(batch) >= args.dynamic_batching_doc_layout_batch_size: 65 | need_sleep = False 66 | break 67 | 68 | except asyncio.TimeoutError: 69 | break 70 | 71 | if not len(batch): 72 | continue 73 | 74 | futures = [batch[i][0] for i in range(len(batch))] 75 | input_img = [batch[i][1] for i in range(len(batch))] 76 | 77 | logging.debug(f'{str(datetime.now())} document layout step batch size of {len(input_img)}') 78 | 79 | with torch.no_grad(): 80 | det_res = model.predict( 81 | input_img, 82 | imgsz=1024, 83 | conf=0.25, 84 | device=device, 85 | batch=len(input_img) 86 | ) 87 | 88 | for i in range(len(futures)): 89 | boxes = det_res[i].__dict__['boxes'].xyxy 90 | classes = det_res[i].__dict__['boxes'].cls 91 | scores = det_res[i].__dict__['boxes'].conf 92 | futures[i].set_result((boxes, classes, scores)) 93 | 94 | except Exception as e: 95 | logging.error(e) 96 | try: 97 | futures = [batch[i][0] for i in range(len(batch))] 98 | for i in range(len(futures)): 99 | if not futures[i].done(): 100 | futures[i].set_exception(e) 101 | except: 102 | pass 103 | 104 | async def predict( 105 | file, 106 | iou_threshold = 0.45, 107 | ratio_x = 2.0, 108 | ratio_y = 2.0, 109 | request = None, 110 | ): 111 | doc = pymupdf.open(file) 112 | mat = pymupdf.Matrix(ratio_x, ratio_y) 113 | futures, images = [], [] 114 | 115 | for page in doc: 116 | pix = page.get_pixmap(matrix=mat) 117 | image = np.frombuffer(pix.samples_mv, dtype=np.uint8).reshape((pix.height, pix.width, 3)).copy() 118 | images.append(image) 119 | 120 | future = asyncio.Future() 121 | await step_queue.put((future, image)) 122 | futures.append(future) 123 | 124 | before = time.time() 125 | results = await 
asyncio.gather(*futures)
126 | 
127 |     actual_results = []
128 | 
129 |     for i in range(len(results)):
130 |         boxes, classes, scores = results[i]
131 |         indices = torchvision.ops.nms(
132 |             boxes=torch.Tensor(boxes),
133 |             scores=torch.Tensor(scores),
134 |             iou_threshold=iou_threshold,
135 |         )
136 |         boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
137 |         if len(boxes.shape) == 1:
138 |             boxes = np.expand_dims(boxes, 0)
139 |             scores = np.expand_dims(scores, 0)
140 |             classes = np.expand_dims(classes, 0)
141 | 
142 |         image = Image.fromarray(images[i])
143 |         buffered = BytesIO()
144 |         image.save(buffered, format="JPEG")
145 |         img = base64.b64encode(buffered.getvalue()).decode("utf-8")
146 | 
147 |         coordinates = boxes.int().cpu().numpy().tolist()
148 |         classes = [id_to_names[int(c)] for c in classes]
149 |         boxes = []
150 |         for c in coordinates:
151 |             x_min, y_min, x_max, y_max = c
152 |             boxes.append({
153 |                 'x_min': x_min,
154 |                 'y_min': y_min,
155 |                 'x_max': x_max,
156 |                 'y_max': y_max,
157 |             })
158 |         sorted_indices = sorted(range(len(boxes)), key=lambda i: (boxes[i]['y_min'], boxes[i]['x_min']))
159 |         sorted_boxes = [boxes[i] for i in sorted_indices]
160 |         sorted_classes = [classes[i] for i in sorted_indices]
161 |         d = {
162 |             'classes': sorted_classes,
163 |             'coordinates': sorted_boxes,
164 |             'img': img,
165 |         }
166 |         actual_results.append(d)
167 | 
168 |     after = time.time()
169 | 
170 |     stats = {
171 |         'total_page': len(images),
172 |         'page_per_second': len(images) / (after - before),
173 |     }
174 |     return {
175 |         'result': actual_results,
176 |         'stats': stats,
177 |     }
178 | 
179 | 
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/env.py:
--------------------------------------------------------------------------------
  1 | import argparse
  2 | import logging
  3 | import os
  4 | import torch
  5 | 
  6 | 
  7 | def parse_arguments():
  8 |     parser = argparse.ArgumentParser(description='Configuration parser')
  9 | 
 10 |     parser.add_argument(
 11 |         '--host', type=str, default=os.environ.get('HOSTNAME', '0.0.0.0'),
 12 |         help='host name to host the app (default: %(default)s, env: HOSTNAME)'
 13 |     )
 14 |     parser.add_argument(
 15 |         '--port', type=int, default=int(os.environ.get('PORT', '7088')),
 16 |         help='port to host the app (default: %(default)s, env: PORT)'
 17 |     )
 18 |     parser.add_argument(
 19 |         '--loglevel', default=os.environ.get('LOGLEVEL', 'INFO').upper(),
 20 |         help='Logging level (default: %(default)s, env: LOGLEVEL)'
 21 |     )
 22 |     parser.add_argument(
 23 |         '--reload', type=lambda x: x.lower() == 'true',
 24 |         default=os.environ.get('RELOAD', 'false').lower() == 'true',
 25 |         help='Enable hot loading (default: %(default)s, env: RELOAD)'
 26 |     )
 27 |     parser.add_argument(
 28 |         '--enable-doc-layout', type=lambda x: x.lower() == 'true',
 29 |         default=os.environ.get('ENABLE_DOC_LAYOUT', 'true').lower() == 'true',
 30 |         help='Enable document layout detection (default: %(default)s, env: ENABLE_DOC_LAYOUT)'
 31 |     )
 32 |     parser.add_argument(
 33 |         '--model-doc-layout',
 34 |         default=os.environ.get('MODEL_DOC_LAYOUT', 'yolo10'),
 35 |         help='Model type (default: %(default)s, env: MODEL_DOC_LAYOUT)'
 36 |     )
 37 |     parser.add_argument(
 38 |         '--enable-ocr', type=lambda x: x.lower() == 'true',
 39 |         default=os.environ.get('ENABLE_OCR', 'true').lower() == 'true',
 40 |         help='Enable OCR (default: %(default)s, env: ENABLE_OCR)'
 41 |     )
 42 |     parser.add_argument(
 43 |         '--model-ocr',
 44 |         default=os.environ.get('MODEL_OCR', 'got_ocr2_0'),
 45 |         help='Model type (default: %(default)s, env: MODEL_OCR)'
 46 |     )
 47 |     parser.add_argument(
 48 |         '--dynamic-batching-microsleep', type=float,
 49 |         default=float(os.environ.get('DYNAMIC_BATCHING_MICROSLEEP', '1e-4')),
 50 |         help='microsleep to group dynamic batching, 1 / 1e-4 = 10k steps per second (default: %(default)s, env: DYNAMIC_BATCHING_MICROSLEEP)'
 51 |     )
 52 |     parser.add_argument(
 53 |         '--dynamic-batching-doc-layout-batch-size', type=int,
 54 |         default=int(os.environ.get('DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE', '16')),
 55 |         help='maximum batch size for document layout during dynamic batching (default: %(default)s, env: DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE)'
 56 |     )
 57 |     parser.add_argument(
 58 |         '--dynamic-batching-ocr-batch-size', type=int,
 59 |         default=int(os.environ.get('DYNAMIC_BATCHING_OCR_BATCH_SIZE', '16')),
 60 |         help='maximum batch size for OCR during dynamic batching (default: %(default)s, env: DYNAMIC_BATCHING_OCR_BATCH_SIZE)'
 61 |     )
 62 |     parser.add_argument(
 63 |         '--accelerator-type', default=os.environ.get('ACCELERATOR_TYPE', 'cuda'),
 64 |         help='Accelerator type (default: %(default)s, env: ACCELERATOR_TYPE)'
 65 |     )
 66 |     parser.add_argument(
 67 |         '--max-concurrent',
 68 |         type=int,
 69 |         default=int(os.environ.get('MAX_CONCURRENT', '100')),
 70 |         help='Maximum concurrent requests (default: %(default)s, env: MAX_CONCURRENT)'
 71 |     )
 72 |     parser.add_argument(
 73 |         '--static-cache', type=lambda x: x.lower() == 'true',
 74 |         default=os.environ.get('STATIC_CACHE', 'false').lower() == 'true',
 75 |         help='Preallocate KV Cache for faster inference (default: %(default)s, env: STATIC_CACHE)'
 76 |     )
 77 |     parser.add_argument(
 78 |         '--static-cache-max-length',
 79 |         type=int,
 80 |         default=int(os.environ.get('STATIC_CACHE_MAX_LENGTH', '8192')),
 81 |         help='Maximum sequence length for the preallocated KV Cache (default: %(default)s, env: STATIC_CACHE_MAX_LENGTH)'
 82 |     )
 83 |     parser.add_argument(
 84 |         '--enable-url-to-pdf', type=lambda x: x.lower() == 'true',
 85 |         default=os.environ.get('ENABLE_URL_TO_PDF', 'true').lower() == 'true',
 86 |         help='Enable URL to PDF using Playwright (default: %(default)s, env: ENABLE_URL_TO_PDF)'
 87 |     )
 88 |     parser.add_argument(
 89 |         '--playwright-max-concurrency', type=int,
 90 |         default=int(os.environ.get('PLAYWRIGHT_MAX_CONCURRENCY', '1')),
 91 |         help='Maximum concurrent Playwright browser pages for URL to PDF (default: %(default)s, env: PLAYWRIGHT_MAX_CONCURRENCY)'
 92 |     )
 93 |     parser.add_argument(
 94 |         '--torch-compile', type=lambda x: x.lower() == 'true',
 95 |         default=os.environ.get('TORCH_COMPILE', 'true').lower() == 'true',
 96 |         help='Torch compile necessary forwards, can speed up at least 1.5X (default: %(default)s, env: TORCH_COMPILE)'
 97 |     )
 98 | 
 99 |     args = parser.parse_args()
100 | 
101 |     if args.model_doc_layout not in {'yolo10'}:
102 |         raise ValueError('Currently document layout, `--model-doc-layout` or `MODEL_DOC_LAYOUT` environment variable, only support https://github.com/opendatalab/DocLayout-YOLO')
103 | 
104 |     if args.model_ocr not in {'got_ocr2_0'}:
105 |         raise ValueError('Currently OCR, `--model-ocr` or `MODEL_OCR` environment variable, only support https://huggingface.co/stepfun-ai/GOT-OCR2_0')
106 | 
107 |     device = 'cpu'
108 |     if args.accelerator_type == 'cuda':
109 |         if not torch.cuda.is_available():
110 |             logging.warning('CUDA is not available, fallback to CPU.')
111 |         else:
112 |             device = 'cuda'
113 | 
114 |     args.device = device
115 |     return args
116 | 
117 | 
118 | args = parse_arguments()
119 | 
120 | logging.basicConfig(level=args.loglevel)
121 | 
122 | logging.info(f'Serving app using {args}')
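123 | 
124 | # NOTE: `args` is a module-level singleton; the rest of the package reads configuration
125 | # through `from dynamicbatch_ragpipeline.env import args` (see doc_layout.py, ocr.py, main.py).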
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/function.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | import torch.nn.functional as F
 3 | import logging
 4 | 
 5 | def efficient_attention_mask(batch_size, max_len, lengths, device, dtype, ones=True):
 6 |     lengths = torch.tensor(lengths)
 7 |     left = torch.arange(max_len).expand(
 8 |         batch_size, 1, 1, max_len)
 9 |     right = lengths.view(
10 |         batch_size, 1, 1, 1)
11 |     if ones:
12 |         mask = left < right
13 |         mask = mask.float()
14 |     else:
15 |         mask = left > right
16 |         mask = mask.float().masked_fill_(mask, torch.finfo(dtype).min)
17 |     return mask.to(device).type(dtype)
18 | 
19 | def cleanup_cache(cache):
20 |     try:
21 |         if isinstance(cache, tuple) or isinstance(cache, list):
22 |             cache = list(cache)
23 |             for i in range(len(cache)):
24 |                 cache[i] = list(cache[i])
25 |                 for _ in range(len(cache[i])):
26 |                     del cache[i][0]
27 | 
28 |         else:
29 |             for _ in range(len(cache.key_cache)):
30 |                 del cache.key_cache[0]
31 |             for _ in range(len(cache.value_cache)):
32 |                 del cache.value_cache[0]
33 |     except Exception as e:
34 |         logging.warning(f'failed to clear cache, {e}')
35 | 
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/main.py:
--------------------------------------------------------------------------------
 1 | from fastapi import FastAPI, Request
 2 | from fastapi import File, Form, UploadFile
 3 | from fastapi import HTTPException
 4 | from fastapi.responses import StreamingResponse
 5 | from sse_starlette import EventSourceResponse
 6 | from dynamicbatch_ragpipeline.env import args
 7 | from dynamicbatch_ragpipeline.doc_layout import (
 8 |     load_model as doc_layout_load_model,
 9 |     predict as doc_layout_predict,
10 |     step as doc_layout_step,
11 | )
12 | from dynamicbatch_ragpipeline.ocr import (
13 |     load_model as ocr_load_model,
14 |     predict as ocr_predict,
15 |     prefill as ocr_prefill,
16 |     step as ocr_step,
17 | )
18 | from dynamicbatch_ragpipeline.function import cleanup_cache
19 | from dynamicbatch_ragpipeline.playwright_utils import (
20 |     to_pdf,
21 |     initialize_browser,
22 | )
23 | from transformers_openai.middleware import InsertMiddleware
24 | from pydantic import BaseModel
25 | import asyncio
26 | import logging
27 | import uvicorn
28 | import tempfile
29 | 
30 | 
31 | app = FastAPI()
32 | 
33 | app.add_middleware(InsertMiddleware, max_concurrent=args.max_concurrent)
34 | 
35 | @app.get('/')
36 | async def hello_world():
37 |     return {'hello': 'world'}
38 | 
39 | if args.enable_doc_layout:
40 |     logging.info('enabling document layout')
41 | 
42 |     @app.post('/doc_layout')
43 |     async def doc_layout(
44 |         file: bytes = File(),
45 |         iou_threshold: float = Form(0.45),
46 |         ratio_x: float = Form(2.0),
47 |         ratio_y: float = Form(2.0),
48 |         request: Request = None,
49 |     ):
50 |         """
51 |         Supports a PDF file with multiple pages. Returns a list of base64 images together with the detected layout per page.
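52 | 
53 |         `iou_threshold` is the NMS threshold used to merge overlapping boxes; `ratio_x` and `ratio_y` scale the PDF page rendering resolution before detection.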
 54 |         """
 55 | 
 56 |         with tempfile.NamedTemporaryFile(suffix='.pdf') as temp_file:
 57 |             temp_file.write(file)
 58 |             temp_file.flush()
 59 |             r = await doc_layout_predict(
 60 |                 temp_file.name,
 61 |                 iou_threshold = iou_threshold,
 62 |                 ratio_x = ratio_x,
 63 |                 ratio_y = ratio_y,
 64 |                 request = request
 65 |             )
 66 |         return r
 67 | 
 68 |     doc_layout_load_model()
 69 | 
 70 |     @app.on_event("startup")
 71 |     async def startup_event():
 72 |         app.state.background_doc_layout_step = asyncio.create_task(doc_layout_step())
 73 | 
 74 |     @app.on_event("shutdown")
 75 |     async def shutdown_event():
 76 |         app.state.background_doc_layout_step.cancel()
 77 |         try:
 78 |             await app.state.background_doc_layout_step
 79 |         except asyncio.CancelledError:
 80 |             pass
 81 | 
 82 | if args.enable_ocr:
 83 |     logging.info('enabling OCR')
 84 | 
 85 |     @app.post('/ocr')
 86 |     async def ocr(
 87 |         image: bytes = File(),
 88 |         mode: str = Form('format'),
 89 |         max_tokens: int = Form(4096),
 90 |         stream: bool = Form(False),
 91 |         request: Request = None,
 92 |     ):
 93 |         """
 94 |         Convert an image to text using OCR.
 95 | 
 96 |         Supports 2 modes,
 97 | 
 98 |         1. `plain`, plain text.
 99 | 
100 |         2. `format`, formats the output into LaTeX.
101 | 
102 |         """
103 |         mode = mode.lower()
104 |         if mode not in {'plain', 'format'}:
105 |             raise HTTPException(status_code=400, detail='mode only supports `plain` or `format`.')
106 | 
107 |         generator = ocr_predict(image, mode = mode, max_tokens = max_tokens, stream = stream, request = request)
108 |         r = await generator
109 |         if stream:
110 |             return EventSourceResponse(r)
111 |         else:
112 |             return r
113 | 
114 |     ocr_load_model()
115 | 
116 |     @app.on_event("startup")
117 |     async def startup_event():
118 |         app.state.background_ocr_prefill = asyncio.create_task(ocr_prefill())
119 |         app.state.background_ocr_step = asyncio.create_task(ocr_step())
120 | 
121 |     @app.on_event("shutdown")
122 |     async def shutdown_event():
123 |         app.state.background_ocr_prefill.cancel()
124 |         app.state.background_ocr_step.cancel()
125 |         try:
126 |             await app.state.background_ocr_prefill
127 |             await app.state.background_ocr_step
128 |         except asyncio.CancelledError:
129 |             pass
130 | 
131 | if args.enable_url_to_pdf:
132 |     logging.info('enabling URL to PDF')
133 | 
134 |     class URL(BaseModel):
135 |         url: str = 'https://screenresolutiontest.com/screenresolution/'
136 |         viewport_weight: int = 1470
137 |         viewport_height: int = 956
138 | 
139 |     @app.post('/url_to_pdf')
140 |     async def url_to_pdf(url: URL):
141 |         pdf_file = await to_pdf(**url.dict())
142 |         return StreamingResponse(
143 |             pdf_file,
144 |             media_type="application/pdf",
145 |             headers={"Content-Disposition": "attachment; filename=mydocument.pdf"}
146 |         )
147 | 
148 |     @app.on_event('startup')
149 |     async def warmup():
150 |         tasks = []
151 |         for index in range(args.playwright_max_concurrency):
152 |             task = asyncio.create_task(initialize_browser(index=index))
153 |             tasks.append(task)
154 | 
155 |         await asyncio.gather(*tasks)
156 | 
157 | if __name__ == "__main__":
158 |     uvicorn.run(
159 |         'dynamicbatch_ragpipeline.main:app',
160 |         host=args.host,
161 |         port=args.port,
162 |         log_level=args.loglevel.lower(),
163 |         access_log=True,
164 |         reload=args.reload,
165 |     )
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/ocr.py:
--------------------------------------------------------------------------------
 1 | from transformers import AutoModel, AutoTokenizer
 2 | from sse_starlette import ServerSentEvent
 3 | from dynamicbatch_ragpipeline import ocr_utils
 4 | from dynamicbatch_ragpipeline.env import args
5 | from transformers_openai.function import efficient_attention_mask 6 | from transformers_openai.cache import ( 7 | DynamicLengthDecoderCache, 8 | ) 9 | from datetime import datetime 10 | from PIL import Image 11 | from io import BytesIO 12 | import numpy as np 13 | import torch 14 | import asyncio 15 | import time 16 | import logging 17 | import json 18 | 19 | model = None 20 | tokenizer = None 21 | global_cache = None 22 | 23 | torch_dtype = torch.bfloat16 24 | 25 | device = 'cpu' 26 | if args.accelerator_type == 'cuda': 27 | if not torch.cuda.is_available(): 28 | logging.warning('CUDA is not available, fallback to CPU.') 29 | else: 30 | device = 'cuda' 31 | 32 | prefill_queue = asyncio.Queue() 33 | step_queue = asyncio.Queue() 34 | 35 | def load_model(): 36 | global model, tokenizer, global_cache 37 | 38 | tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True) 39 | model = AutoModel.from_pretrained( 40 | 'ucaslcl/GOT-OCR2_0', 41 | trust_remote_code=True, 42 | torch_dtype = torch_dtype, 43 | attn_implementation = 'sdpa' 44 | ) 45 | model = model.eval().to(device) 46 | global_cache = DynamicLengthDecoderCache() 47 | 48 | if args.torch_compile: 49 | logging.info('enabling torch compile for OCR') 50 | 51 | model.model.vision_tower_high.forward = torch.compile( 52 | model.model.vision_tower_high.forward, 53 | ) 54 | model.model.mm_projector_vary.forward = torch.compile( 55 | model.model.mm_projector_vary.forward, 56 | ) 57 | ocr_utils.image_processor_high.transform = torch.compile( 58 | ocr_utils.image_processor_high.transform, 59 | ) 60 | 61 | with torch.no_grad(): 62 | for i in range(3): 63 | logging.info(f'{i}, warming up vision tower') 64 | image = torch.zeros(3, 1000, 1000).to(device) 65 | image = ocr_utils.image_processor_high(image).unsqueeze(0).type(model.dtype) 66 | cnn_feature = model.model.vision_tower_high(image) 67 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) 68 | model.model.mm_projector_vary(cnn_feature) 69 | del cnn_feature 70 | 71 | 72 | async def prefill(): 73 | need_sleep = True 74 | while True: 75 | if need_sleep: 76 | await asyncio.sleep(args.dynamic_batching_microsleep) 77 | 78 | try: 79 | need_sleep = True 80 | batch = [] 81 | while not prefill_queue.empty(): 82 | try: 83 | request = await asyncio.wait_for(prefill_queue.get(), timeout=1e-6) 84 | batch.append(request) 85 | if len(batch) >= args.dynamic_batching_ocr_batch_size: 86 | need_sleep = False 87 | break 88 | 89 | except asyncio.TimeoutError: 90 | break 91 | 92 | if not len(batch): 93 | continue 94 | 95 | futures = [batch[i][0] for i in range(len(batch))] 96 | input_img = [batch[i][1] for i in range(len(batch))] 97 | modes = [batch[i][3] for i in range(len(batch))] 98 | uuids = [batch[i][4] for i in range(len(batch))] 99 | 100 | logging.debug(f'{str(datetime.now())} OCR prefill batch size of {len(uuids)}') 101 | 102 | prompts = [] 103 | for f in modes: 104 | if f == 'format': 105 | qs = 'OCR with format: ' 106 | else: 107 | qs = 'OCR: ' 108 | 109 | qs = ocr_utils.qs + qs 110 | conv = ocr_utils.conv_mpt.copy() 111 | conv.append_message(conv.roles[0], qs) 112 | conv.append_message(conv.roles[1], None) 113 | prompt = conv.get_prompt() 114 | prompts.append(prompt) 115 | 116 | with torch.no_grad(): 117 | images = [] 118 | for i in range(len(input_img)): 119 | image = Image.open(BytesIO(input_img[i])).convert('RGB') 120 | image = ocr_utils.image_processor_high.to_tensor(image).to(device) 121 | image_tensor = 
ocr_utils.image_processor_high(image).unsqueeze(0).type(model.dtype) 122 | images.append(image_tensor) 123 | 124 | input_ids = tokenizer(prompts, return_tensors = 'pt', padding = 'longest') 125 | input_ids.pop('token_type_ids', None) 126 | lengths = input_ids['attention_mask'].sum(axis = 1) 127 | for k in input_ids.keys(): 128 | input_ids[k] = input_ids[k].to(device) 129 | 130 | out = model( 131 | **input_ids, 132 | images = images, 133 | past_key_values = None, 134 | use_cache = True, 135 | return_dict = False, 136 | ) 137 | out_logits = out[0] 138 | out_caches = out[1] 139 | 140 | cache_exists = len(global_cache.key_cache) 141 | 142 | for k in range(len(out_caches)): 143 | 144 | key_cache = {} 145 | value_cache = {} 146 | for i in range(len(batch)): 147 | key_cache[uuids[i]] = out_caches[k][0][i: i + 1, :, :lengths[i]] 148 | value_cache[uuids[i]] = out_caches[k][1][i: i + 1, :, :lengths[i]] 149 | 150 | if cache_exists: 151 | global_cache.key_cache[k].update(key_cache) 152 | global_cache.value_cache[k].update(value_cache) 153 | else: 154 | global_cache.key_cache.append(key_cache) 155 | global_cache.value_cache.append(value_cache) 156 | 157 | for i in range(len(futures)): 158 | futures[i].set_result((out_logits[i, -1:],)) 159 | 160 | for k in range(len(out_caches)): 161 | temp = list(out_caches[k]) 162 | for j in range(len(out_caches[k])): 163 | del temp[0] 164 | 165 | except Exception as e: 166 | logging.error(f'error in prefill {e}') 167 | try: 168 | futures = [batch[i][0] for i in range(len(batch))] 169 | for i in range(len(futures)): 170 | if not futures[i].done(): 171 | futures[i].set_exception(e) 172 | except: 173 | pass 174 | 175 | async def step(): 176 | need_sleep = True 177 | while True: 178 | if need_sleep: 179 | await asyncio.sleep(args.dynamic_batching_microsleep) 180 | 181 | try: 182 | need_sleep = True 183 | batch = [] 184 | while not step_queue.empty(): 185 | try: 186 | request = await asyncio.wait_for(step_queue.get(), timeout=1e-6) 187 | batch.append(request) 188 | 189 | if len(batch) >= args.dynamic_batching_ocr_batch_size: 190 | need_sleep = False 191 | break 192 | 193 | except asyncio.TimeoutError: 194 | break 195 | 196 | if not len(batch): 197 | continue 198 | 199 | futures = [batch[i][0] for i in range(len(batch))] 200 | inputs = [batch[i][1] for i in range(len(batch))] 201 | lengths = [batch[i][2] for i in range(len(batch))] 202 | uuids = [batch[i][4] for i in range(len(batch))] 203 | 204 | logging.debug(f'{str(datetime.now())} OCR step batch size of {len(uuids)}') 205 | 206 | global_cache.current_uuid = uuids 207 | 208 | max_len_lengths = max(lengths) 209 | with torch.no_grad(): 210 | inputs = torch.concat(inputs, dim=0) 211 | attention_mask = efficient_attention_mask( 212 | batch_size=len(lengths), 213 | max_len=max_len_lengths, 214 | lengths=lengths, 215 | device=device, 216 | dtype=torch_dtype, 217 | ) 218 | position_ids = torch.tensor([[l - 1 for l in lengths]]).T.to(device) 219 | out = model( 220 | inputs, 221 | images = None, 222 | attention_mask=attention_mask, 223 | position_ids=position_ids, 224 | past_key_values=global_cache, 225 | use_cache=True, 226 | return_dict=False 227 | ) 228 | 229 | out_logits = out[0] 230 | 231 | for i in range(len(futures)): 232 | futures[i].set_result((out_logits[i, -1:],)) 233 | 234 | except Exception as e: 235 | logging.error(f'error in step {e}') 236 | try: 237 | futures = [batch[i][0] for i in range(len(batch))] 238 | for i in range(len(futures)): 239 | if not futures[i].done(): 240 | futures[i].set_exception(e) 241 | 
except: 242 | pass 243 | 244 | async def streaming(image, mode, max_tokens, request): 245 | 246 | cache = None 247 | length = None 248 | inputs = image 249 | uuid = request.scope['request']['uuid'] 250 | 251 | try: 252 | for k in range(max_tokens): 253 | 254 | if k == 0: 255 | q = prefill_queue 256 | l = length 257 | else: 258 | q = step_queue 259 | l = length + k 260 | 261 | future = asyncio.Future() 262 | await q.put((future, inputs, l, mode, uuid)) 263 | out = await future 264 | 265 | logits = out[0] 266 | 267 | if length is None: 268 | length = global_cache.key_cache[0][uuid].shape[2] 269 | 270 | idx_next = logits.argmax(-1) 271 | token = tokenizer.decode(idx_next) 272 | 273 | if k == 0: 274 | request.scope['request']['time_first_token'] = time.time() 275 | 276 | if token == ocr_utils.stop_str: 277 | break 278 | 279 | del logits 280 | inputs = idx_next.unsqueeze(0) 281 | 282 | data = { 283 | 'token': token 284 | } 285 | yield json.dumps(data) 286 | await asyncio.sleep(0) 287 | 288 | request.scope['request']['time_max_tokens'] = time.time() 289 | request.scope['request']['total_tokens'] = k 290 | 291 | except asyncio.CancelledError as e: 292 | logging.warning(f"model step cancelled {uuid}") 293 | yield ServerSentEvent(**{"data": str(e)}) 294 | 295 | except Exception as e: 296 | logging.error(f"model step exception {e} {uuid}") 297 | yield ServerSentEvent(**{"data": str(e)}) 298 | 299 | finally: 300 | logging.debug(f'purging {uuid} KV cache') 301 | for i in range(len(global_cache.key_cache)): 302 | global_cache.key_cache[i].pop(uuid, None) 303 | global_cache.value_cache[i].pop(uuid, None) 304 | 305 | async def predict(image, mode = 'format', max_tokens = 4096, stream = False, request = None): 306 | if model is None: 307 | load_model() 308 | 309 | func = streaming(image=image, mode=mode, max_tokens=max_tokens, request=request) 310 | if stream: 311 | return func 312 | else: 313 | tokens = [] 314 | async for data in func: 315 | if isinstance(data, ServerSentEvent): 316 | continue 317 | data = json.loads(data) 318 | tokens.append(data['token']) 319 | 320 | return { 321 | 'result': ''.join(tokens) 322 | } 323 | 324 | -------------------------------------------------------------------------------- /dynamicbatch_ragpipeline/ocr_utils.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.transforms.functional import InterpolationMode 3 | from enum import auto, Enum 4 | from typing import List, Optional, Tuple, Union 5 | import dataclasses 6 | 7 | class SeparatorStyle(Enum): 8 | """Different separator style.""" 9 | SINGLE = auto() 10 | TWO = auto() 11 | MPT = auto() 12 | 13 | 14 | @dataclasses.dataclass 15 | class Conversation: 16 | """A class that keeps all conversation history.""" 17 | system: str 18 | roles: List[str] 19 | messages: List[List[str]] 20 | offset: int 21 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE 22 | sep: str = "<|im_end|>" 23 | sep2: str = None 24 | version: str = "Unknown" 25 | 26 | skip_next: bool = False 27 | 28 | def get_prompt(self): 29 | if self.sep_style == SeparatorStyle.SINGLE: 30 | ret = self.system + self.sep + '\n' 31 | for role, message in self.messages: 32 | if message: 33 | if type(message) is tuple: 34 | message, _, _ = message 35 | ret += role + ": " + message + self.sep 36 | else: 37 | ret += role + ":" 38 | return ret 39 | elif self.sep_style == SeparatorStyle.TWO: 40 | seps = [self.sep, self.sep2] 41 | ret = self.system + seps[0] 42 | for i, (role, message) in 
enumerate(self.messages):
 43 |                 if message:
 44 |                     if type(message) is tuple:
 45 |                         message, _, _ = message
 46 |                     ret += role + ": " + message + seps[i % 2]
 47 |                 else:
 48 |                     ret += role + ":"
 49 |             return ret
 50 |         if self.sep_style == SeparatorStyle.MPT:
 51 |             if self.system:
 52 |                 ret = self.system + self.sep
 53 |             else:
 54 |                 ret = ''
 55 |             for role, message in self.messages:
 56 |                 if message:
 57 |                     if type(message) is tuple:
 58 |                         message, _, _ = message
 59 |                     ret += role + message + self.sep
 60 |                 else:
 61 |                     ret += role
 62 |             return ret
 63 |         else:
 64 |             raise ValueError(f"Invalid style: {self.sep_style}")
 65 | 
 66 | 
 67 |     def append_message(self, role, message):
 68 |         self.messages.append([role, message])
 69 | 
 70 |     def copy(self):
 71 |         return Conversation(
 72 |             system=self.system,
 73 |             roles=self.roles,
 74 |             messages=[[x, y] for x, y in self.messages],
 75 |             offset=self.offset,
 76 |             sep_style=self.sep_style,
 77 |             sep=self.sep,
 78 |             sep2=self.sep2)
 79 | 
 80 | class GOTImageEvalProcessor:
 81 |     def __init__(self, image_size=384, mean=None, std=None):
 82 |         if mean is None:
 83 |             mean = (0.48145466, 0.4578275, 0.40821073)
 84 |         if std is None:
 85 |             std = (0.26862954, 0.26130258, 0.27577711)
 86 | 
 87 |         self.normalize = transforms.Normalize(mean, std)
 88 |         self.to_tensor = transforms.ToTensor()
 89 | 
 90 |         self.transform = transforms.Compose(
 91 |             [
 92 |                 transforms.Resize(
 93 |                     (image_size, image_size), interpolation=InterpolationMode.BICUBIC
 94 |                 ),
 95 |                 self.normalize,
 96 |             ]
 97 |         )
 98 |     def __call__(self, item):
 99 |         return self.transform(item)
100 | 
101 | conv_mpt = Conversation(
102 |     system="""<|im_start|>system
103 | You should follow the instructions carefully and explain your answers in detail.""",
104 |     roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
105 |     version="mpt",
106 |     messages=(),
107 |     offset=0,
108 |     sep_style=SeparatorStyle.MPT,
109 |     sep="<|im_end|>",
110 | )
111 | 
112 | # image span delimiter tokens from the GOT-OCR2_0 reference implementation
113 | DEFAULT_IMAGE_TOKEN = '<image>'
114 | DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
115 | DEFAULT_IM_START_TOKEN = '<img>'
116 | DEFAULT_IM_END_TOKEN = '</img>'
117 | image_processor_high = GOTImageEvalProcessor(image_size=1024)
118 | 
119 | image_token_len = 256
120 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n'
121 | stop_str = '<|im_end|>'
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/playwright_utils.py:
--------------------------------------------------------------------------------
 1 | from dynamicbatch_ragpipeline.env import args
 2 | from playwright.async_api import async_playwright
 3 | from datetime import datetime
 4 | from fastapi import HTTPException
 5 | from io import BytesIO
 6 | import time
 7 | import logging
 8 | import asyncio
 9 | 
10 | playwrights = {}
11 | 
12 | async def initialize_browser(index, clear_first = False):
13 |     global playwrights
14 | 
15 |     if clear_first:
16 |         logging.info(f'clearing playwright {index}')
17 |         for k in list(playwrights[index]):
18 |             try:
19 |                 await playwrights[index][k].close()
20 |             except:
21 |                 pass
22 |             try:
23 |                 del playwrights[index][k]
24 |             except:
25 |                 pass
26 | 
27 |     logging.info(f'initializing playwright {index}')
28 | 
29 |     playwright = await async_playwright().start()
30 |     browser = await playwright.chromium.launch(headless = True)
31 |     page = await browser.new_page()
32 |     playwrights[index] = {
33 |         'playwright': playwright,
34 |         'browser': browser,
35 |         'page': page,
36 |         'available': True,
37 |         'last_emit': datetime.now()
38 |     }
39 | 
40 | async def dead(index):
41 | 
42 |     died = False
43 |     
try: 44 | if playwrights[index]['page'].is_closed(): 45 | died = True 46 | except: 47 | pass 48 | 49 | try: 50 | if not playwrights[index]['browser'].is_connected(): 51 | died = True 52 | except: 53 | pass 54 | 55 | if died: 56 | await initialize_browser(index=index, clear_first=True) 57 | 58 | return died 59 | 60 | async def to_pdf(url, viewport_weight, viewport_height): 61 | index = 0 62 | found = False 63 | try: 64 | while True: 65 | for index in range(args.playwright_max_concurrency): 66 | if playwrights[index]['available']: 67 | playwrights[index]['available'] = False 68 | found = True 69 | break 70 | 71 | await asyncio.sleep(1e-9) 72 | 73 | if found: 74 | break 75 | 76 | await playwrights[index]['page'].set_viewport_size({"width": viewport_weight, "height": viewport_height}) 77 | await playwrights[index]['page'].goto(url) 78 | pdf = await playwrights[index]['page'].pdf() 79 | playwrights[index]['available'] = True 80 | return BytesIO(pdf) 81 | 82 | except asyncio.CancelledError as e: 83 | await dead(index) 84 | playwrights[index]['available'] = True 85 | 86 | except Exception as e: 87 | await dead(index) 88 | playwrights[index]['available'] = True 89 | raise HTTPException(status_code=500, detail=f'failed, {e}, please retry.') 90 | 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /html-to-pdf-only.yaml: -------------------------------------------------------------------------------- 1 | version: "3.3" 2 | 3 | services: 4 | dynamicbatch_ragpipeline: 5 | build: 6 | context: . 7 | deploy: 8 | resources: 9 | reservations: 10 | devices: 11 | - driver: nvidia 12 | count: 1 13 | capabilities: [gpu] 14 | container_name: dynamicbatch_ragpipeline 15 | environment: 16 | - PYTHONUNBUFFERED=1 17 | - HF_HUB_ENABLE_HF_TRANSFER=1 18 | - ENABLE_DOC_LAYOUT=false 19 | - ENABLE_OCR=false 20 | 21 | volumes: 22 | - "./dynamicbatch_ragpipeline:/home/ubuntu/dynamicbatch_ragpipeline" 23 | - "~/.cache/huggingface:/home/ubuntu/.cache/huggingface" 24 | ports: 25 | - "7088:7088" 26 | command: python3 -m dynamicbatch_ragpipeline.main --host 0.0.0.0 --port 7088 --reload true -------------------------------------------------------------------------------- /notebook/doc-layout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/doc-layout.png -------------------------------------------------------------------------------- /notebook/huggingface.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/huggingface.pdf -------------------------------------------------------------------------------- /notebook/page1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page1.png -------------------------------------------------------------------------------- /notebook/page2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page2.png -------------------------------------------------------------------------------- /notebook/page3.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page3.png -------------------------------------------------------------------------------- /notebook/page4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page4.png -------------------------------------------------------------------------------- /notebook/url-to-pdf.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "3808afa9", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import requests\n", 11 | "data = {\n", 12 | " 'url': 'https://huggingface.co/'\n", 13 | "}\n", 14 | "\n", 15 | "r = requests.post('http://localhost:7088/url_to_pdf', json = data)\n", 16 | "with open('huggingface.pdf', 'wb') as fopen:\n", 17 | " fopen.write(r._content)" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "id": "6700b25a", 24 | "metadata": {}, 25 | "outputs": [ 26 | { 27 | "data": { 28 | "text/html": [ 29 | "\n", 30 | " \n", 38 | " " 39 | ], 40 | "text/plain": [ 41 | "" 42 | ] 43 | }, 44 | "execution_count": 2, 45 | "metadata": {}, 46 | "output_type": "execute_result" 47 | } 48 | ], 49 | "source": [ 50 | "from IPython.display import IFrame\n", 51 | "\n", 52 | "IFrame('huggingface.pdf', width=600, height=600)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "a9bfd597", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [] 62 | } 63 | ], 64 | "metadata": { 65 | "kernelspec": { 66 | "display_name": "python3.10", 67 | "language": "python", 68 | "name": "python3.10" 69 | }, 70 | "language_info": { 71 | "codemirror_mode": { 72 | "name": "ipython", 73 | "version": 3 74 | }, 75 | "file_extension": ".py", 76 | "mimetype": "text/x-python", 77 | "name": "python", 78 | "nbconvert_exporter": "python", 79 | "pygments_lexer": "ipython3", 80 | "version": "3.10.15" 81 | } 82 | }, 83 | "nbformat": 4, 84 | "nbformat_minor": 5 85 | } 86 | -------------------------------------------------------------------------------- /notebook/url-to-pdf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/url-to-pdf.png -------------------------------------------------------------------------------- /push-dockerhub.sh: -------------------------------------------------------------------------------- 1 | docker build -t mesoliticadev/dynamic-batch-rag-pipeline . 
2 | docker push mesoliticadev/dynamic-batch-rag-pipeline -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | uvicorn 3 | doclayout-yolo 4 | huggingface-hub 5 | torch 6 | torchvision 7 | python-multipart 8 | pymupdf 9 | verovio 10 | sse_starlette 11 | playwright 12 | transformers 13 | tiktoken 14 | git+https://github.com/mesolitica/transformers-openai-api.git@0990077672793866128d7ec7d6d5f24778bf999f -------------------------------------------------------------------------------- /runs/detect/train/args.yaml: -------------------------------------------------------------------------------- 1 | task: detect 2 | mode: train 3 | model: /home/husein/.cache/huggingface/hub/models--juliozhao--DocLayout-YOLO-DocStructBench/snapshots/8c3299a30b8ff29a1503c4431b035b93220f7b11/doclayout_yolo_docstructbench_imgsz1024.pt 4 | data: coco8.yaml 5 | epochs: 100 6 | time: null 7 | patience: 100 8 | batch: 16 9 | imgsz: 1024 10 | save: true 11 | save_period: 10 12 | val_period: 1 13 | cache: false 14 | device: null 15 | workers: 8 16 | project: null 17 | name: train 18 | exist_ok: true 19 | pretrained: true 20 | optimizer: auto 21 | verbose: true 22 | seed: 0 23 | deterministic: true 24 | single_cls: false 25 | rect: false 26 | cos_lr: false 27 | close_mosaic: 10 28 | resume: false 29 | amp: true 30 | fraction: 1.0 31 | profile: false 32 | freeze: null 33 | multi_scale: false 34 | overlap_mask: true 35 | mask_ratio: 4 36 | dropout: 0.0 37 | val: true 38 | split: val 39 | save_json: false 40 | save_hybrid: false 41 | conf: null 42 | iou: 0.7 43 | max_det: 300 44 | half: false 45 | dnn: false 46 | plots: true 47 | source: null 48 | vid_stride: 1 49 | stream_buffer: false 50 | visualize: false 51 | augment: false 52 | agnostic_nms: false 53 | classes: null 54 | retina_masks: false 55 | embed: null 56 | show: false 57 | save_frames: false 58 | save_txt: false 59 | save_conf: false 60 | save_crop: false 61 | show_labels: true 62 | show_conf: true 63 | show_boxes: true 64 | line_width: null 65 | format: torchscript 66 | keras: false 67 | optimize: false 68 | int8: false 69 | dynamic: false 70 | simplify: false 71 | opset: null 72 | workspace: 4 73 | nms: false 74 | lr0: 0.01 75 | lrf: 0.01 76 | momentum: 0.937 77 | weight_decay: 0.0005 78 | warmup_epochs: 3.0 79 | warmup_momentum: 0.8 80 | warmup_bias_lr: 0.1 81 | box: 7.5 82 | cls: 0.5 83 | dfl: 1.5 84 | pose: 12.0 85 | kobj: 1.0 86 | label_smoothing: 0.0 87 | nbs: 64 88 | hsv_h: 0.015 89 | hsv_s: 0.7 90 | hsv_v: 0.4 91 | degrees: 0.0 92 | translate: 0.1 93 | scale: 0.5 94 | shear: 0.0 95 | perspective: 0.0 96 | flipud: 0.0 97 | fliplr: 0.5 98 | bgr: 0.0 99 | mosaic: 1.0 100 | mixup: 0.0 101 | copy_paste: 0.0 102 | auto_augment: randaugment 103 | erasing: 0.4 104 | crop_fraction: 1.0 105 | cfg: null 106 | tracker: botsort.yaml 107 | save_dir: runs/detect/train 108 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | 4 | __packagename__ = 'dynamic-batch-rag-pipeline' 5 | 6 | setuptools.setup( 7 | name=__packagename__, 8 | packages=setuptools.find_packages(), 9 | version='0.1', 10 | python_requires='>=3.8', 11 | description='Dynamic batching for Document Layout and OCR, suitable for RAG', 12 | author='huseinzol05', 13 | 
url='https://github.com/mesolitica/dynamic-batch-rag-pipeline',
14 | )
--------------------------------------------------------------------------------
/stress-test/2310.01889v4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/2310.01889v4.pdf
--------------------------------------------------------------------------------
/stress-test/README.md:
--------------------------------------------------------------------------------
 1 | # Stress test
 2 | 
 3 | ## Document layout on RTX 3090 Ti
 4 | 
 5 | Spawn rate of 10 users per second, up to 100 concurrent users, for 60 seconds,
 6 | 
 7 | ```bash
 8 | locust -f doc_layout.py -P 7001 -H http://localhost:7088 -r 10 -u 100 -t 60
 9 | ```
10 | 
11 | ![alt text](doc_layout.png)
12 | 
13 | ## OCR on RTX 3090 Ti
14 | 
15 | Spawn rate of 5 users per second, up to 50 concurrent users, for 60 seconds,
16 | 
17 | ```bash
18 | locust -f ocr.py -P 7001 -H http://localhost:7088 -r 5 -u 50 -t 60
19 | ```
20 | 
21 | ### Continuous batching
22 | 
23 | ![alt text](ocr.png)
--------------------------------------------------------------------------------
/stress-test/doc_layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/doc_layout.png
--------------------------------------------------------------------------------
/stress-test/doc_layout.py:
--------------------------------------------------------------------------------
 1 | from locust import HttpUser, task
 2 | from locust import events
 3 | import itertools
 4 | import time
 5 | 
 6 | """
 7 | Make sure this is already running,
 8 | 
 9 | CUDA_VISIBLE_DEVICES=0 \
10 | python3.10 -m dynamicbatch_ragpipeline.main \
11 | --host 0.0.0.0 --port 7088 \
12 | --dynamic-batching-doc-layout-batch-size 64
13 | """
14 | 
15 | class HelloWorldUser(HttpUser):
16 | 
17 |     host = "http://127.0.0.1:7088"
18 | 
19 |     @task
20 |     def hello_world(self):
21 | 
22 |         files = {
23 |             'file': ('2310.01889v4.pdf', open('2310.01889v4.pdf', 'rb'), 'application/pdf'),
24 |         }
25 |         r = self.client.post('/doc_layout', files=files)
--------------------------------------------------------------------------------
/stress-test/ocr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/ocr.png
--------------------------------------------------------------------------------
/stress-test/ocr.py:
--------------------------------------------------------------------------------
 1 | from locust import HttpUser, task
 2 | from locust import events
 3 | import itertools
 4 | import time
 5 | 
 6 | """
 7 | Make sure this is already running,
 8 | 
 9 | CUDA_VISIBLE_DEVICES=2 \
10 | python3.10 -m dynamicbatch_ragpipeline.main \
11 | --host 0.0.0.0 --port 7088 \
12 | --dynamic-batching-ocr-batch-size 32
13 | """
14 | 
15 | class HelloWorldUser(HttpUser):
16 | 
17 |     host = "http://127.0.0.1:7088"
18 | 
19 |     @task
20 |     def hello_world(self):
21 | 
22 |         files = {
23 |             'image': ('table1.png', open('table1.png', 'rb'), 'image/png'),
24 |         }
25 |         r = self.client.post('/ocr', files=files)
--------------------------------------------------------------------------------
/stress-test/table1.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/table1.png -------------------------------------------------------------------------------- /stress-test/table2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/table2.png -------------------------------------------------------------------------------- /stress-test/title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/title.png --------------------------------------------------------------------------------