├── .dockerignore
├── .gitignore
├── Dockerfile
├── README.md
├── docker-compose.yaml
├── dynamicbatch_ragpipeline
│   ├── __init__.py
│   ├── cache.py
│   ├── doc_layout.py
│   ├── env.py
│   ├── function.py
│   ├── main.py
│   ├── ocr.py
│   ├── ocr_utils.py
│   └── playwright_utils.py
├── html-to-pdf-only.yaml
├── notebook
│   ├── doc-layout.png
│   ├── document-layout.ipynb
│   ├── huggingface.pdf
│   ├── ocr.ipynb
│   ├── page1.png
│   ├── page2.png
│   ├── page3.png
│   ├── page4.png
│   ├── url-to-pdf.ipynb
│   └── url-to-pdf.png
├── push-dockerhub.sh
├── requirements.txt
├── runs
│   └── detect
│       └── train
│           └── args.yaml
├── setup.py
└── stress-test
    ├── 2310.01889v4.pdf
    ├── README.md
    ├── doc_layout.png
    ├── doc_layout.py
    ├── ocr.png
    ├── ocr.py
    ├── table1.png
    ├── table2.png
    └── title.png
/.dockerignore:
--------------------------------------------------------------------------------
1 | huggingface
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
2 | huggingface
3 | *ipynb_checkpoints
4 | dist
5 | build
6 | *.egg-info
7 | *.whl
8 | *Untitled*.ipynb
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04 AS base
2 |
3 | RUN apt update
4 | RUN apt install python3 python3-dev python3-pip -y
5 | RUN apt install sudo -y
6 | RUN adduser --quiet --disabled-password --shell /bin/bash --home /home/ubuntu --gecos "User" ubuntu
7 | RUN usermod -aG sudo ubuntu
8 | RUN echo "ubuntu ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers
9 | USER ubuntu
10 | WORKDIR /home/ubuntu
11 |
12 | RUN pip3 install pip -U
13 | RUN sudo apt update
14 |
15 | ADD requirements.txt .
16 | RUN pip3 install -r requirements.txt
17 |
18 | RUN ~/.local/bin/playwright install-deps
19 | RUN ~/.local/bin/playwright install
20 |
21 | COPY ./dynamicbatch_ragpipeline/ /home/ubuntu/dynamicbatch_ragpipeline
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dynamic-batch-RAG-pipeline
2 |
3 | Dynamic batching for Document Layout and OCR, suitable for RAG.
4 |
5 | 1. Dynamic batching for SOTA Document Layout and OCR models, to serve better concurrency.
6 | 2. Continuous batching for causal-LM based OCR models.
7 | 3. Serves a user-defined maximum concurrency.
8 | 4. Detects client disconnects, to ensure early stopping during continuous batching.
9 | 5. Extra tool: convert any URL to a PDF file.
10 |
11 | ## Available models
12 |
13 | ### Document Layout
14 |
15 | 1. https://github.com/opendatalab/DocLayout-YOLO
16 |
17 | ### OCR
18 |
19 | 1. https://huggingface.co/stepfun-ai/GOT-OCR2_0
20 |
21 | ## how to install
22 |
23 | Using PIP with git,
24 |
25 | ```bash
26 | pip3 install git+https://github.com/mesolitica/dynamic-batch-RAG-pipeline
27 | ```
28 |
29 | Or you can git clone,
30 |
31 | ```bash
32 | git clone https://github.com/mesolitica/dynamic-batch-RAG-pipeline && cd dynamic-batch-RAG-pipeline
33 | ```
34 |
35 | ## how to
36 |
37 | ### Supported parameters
38 |
39 | ```bash
40 | python3 -m dynamicbatch_ragpipeline.main --help
41 | ```
42 |
43 | ```text
44 | usage: main.py [-h] [--host HOST] [--port PORT] [--loglevel LOGLEVEL] [--reload RELOAD]
45 | [--enable-doc-layout ENABLE_DOC_LAYOUT] [--model-doc-layout MODEL_DOC_LAYOUT] [--enable-ocr ENABLE_OCR]
46 | [--model-ocr MODEL_OCR] [--dynamic-batching-microsleep DYNAMIC_BATCHING_MICROSLEEP]
47 | [--dynamic-batching-doc-layout-batch-size DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE]
48 | [--dynamic-batching-ocr-batch-size DYNAMIC_BATCHING_OCR_BATCH_SIZE] [--accelerator-type ACCELERATOR_TYPE]
49 | [--max-concurrent MAX_CONCURRENT] [--static-cache STATIC_CACHE]
50 | [--static-cache-max-length STATIC_CACHE_MAX_LENGTH] [--enable-url-to-pdf ENABLE_URL_TO_PDF]
51 | [--playwright-max-concurrency PLAYWRIGHT_MAX_CONCURRENCY]
52 |
53 | Configuration parser
54 |
55 | options:
56 | -h, --help show this help message and exit
57 | --host HOST host name to host the app (default: 0.0.0.0, env: HOSTNAME)
58 | --port PORT port to host the app (default: 7088, env: PORT)
59 | --loglevel LOGLEVEL Logging level (default: INFO, env: LOGLEVEL)
60 | --reload RELOAD Enable hot loading (default: False, env: RELOAD)
61 | --enable-doc-layout ENABLE_DOC_LAYOUT
62 | Enable document layout detection (default: True, env: ENABLE_DOC_LAYOUT)
63 | --model-doc-layout MODEL_DOC_LAYOUT
64 | Model type (default: yolo10, env: MODEL_DOC_LAYOUT)
65 | --enable-ocr ENABLE_OCR
66 | Enable OCR (default: True, env: ENABLE_OCR)
67 | --model-ocr MODEL_OCR
68 | Model type (default: got_ocr2_0, env: MODEL_OCR)
69 | --dynamic-batching-microsleep DYNAMIC_BATCHING_MICROSLEEP
70 |                         microsleep to group dynamic batching, 1 / 1e-4 = 10k steps per second (default: 0.0001, env:
71 | DYNAMIC_BATCHING_MICROSLEEP)
72 | --dynamic-batching-doc-layout-batch-size DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE
73 | maximum of batch size for document layout during dynamic batching (default: 16, env:
74 | DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE)
75 | --dynamic-batching-ocr-batch-size DYNAMIC_BATCHING_OCR_BATCH_SIZE
76 | maximum of batch size for OCR during dynamic batching (default: 16, env:
77 | DYNAMIC_BATCHING_OCR_BATCH_SIZE)
78 | --accelerator-type ACCELERATOR_TYPE
79 | Accelerator type (default: cuda, env: ACCELERATOR_TYPE)
80 | --max-concurrent MAX_CONCURRENT
81 | Maximum concurrent requests (default: 100, env: MAX_CONCURRENT)
82 | --static-cache STATIC_CACHE
83 | Preallocate KV Cache for faster inference (default: False, env: STATIC_CACHE)
84 | --static-cache-max-length STATIC_CACHE_MAX_LENGTH
85 |                         Maximum sequence length for the preallocated KV Cache (default: 8192, env: STATIC_CACHE_MAX_LENGTH)
86 | --enable-url-to-pdf ENABLE_URL_TO_PDF
87 | Enable URL to PDF using Playwright (default: True, env: ENABLE_URL_TO_PDF)
88 | --playwright-max-concurrency PLAYWRIGHT_MAX_CONCURRENCY
89 |                         Maximum number of concurrent Playwright browsers for URL to PDF (default: 1, env: PLAYWRIGHT_MAX_CONCURRENCY)
90 | ```
91 |
92 | **We support both CLI arguments and OS environment variables**.
93 |
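For example, a minimal sketch using environment variables instead of flags (same effect as passing `--enable-ocr false`),

```bash
ENABLE_OCR=false python3 -m dynamicbatch_ragpipeline.main
```
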
94 | ### Run
95 |
96 | ```bash
97 | python3 -m dynamicbatch_ragpipeline.main \
98 | --host 0.0.0.0 --port 7088
99 | ```
100 |
101 | #### Example document layout
102 |
103 | ```bash
104 | curl -X 'POST' \
105 | 'http://localhost:7088/doc_layout' \
106 | -H 'accept: application/json' \
107 | -H 'Content-Type: multipart/form-data' \
108 | -F 'file=@stress-test/2310.01889v4.pdf;type=application/pdf' \
109 | -F 'iou_threshold=0.45'
110 | ```
111 |
112 | Check out [notebook/document-layout.ipynb](notebook/document-layout.ipynb).
113 |
114 | ![doc-layout](notebook/doc-layout.png)
115 |
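The response contains, for every page, the detected `classes`, their `coordinates`, and the rendered page as a base64 JPEG in `img`. A minimal sketch to decode it (assuming the server above is running locally),

```python
import base64

import requests

with open('stress-test/2310.01889v4.pdf', 'rb') as fopen:
    files = {'file': ('2310.01889v4.pdf', fopen, 'application/pdf')}
    r = requests.post('http://localhost:7088/doc_layout', files=files)

for no, page in enumerate(r.json()['result']):
    # boxes come back sorted top-to-bottom, left-to-right
    print(no, list(zip(page['classes'], page['coordinates']))[:3])
    # the rendered page is a base64-encoded JPEG
    with open(f'page{no}.jpg', 'wb') as fopen:
        fopen.write(base64.b64decode(page['img']))
```
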
116 | #### Example OCR
117 |
118 | ```bash
119 | curl -X 'POST' \
120 | 'http://localhost:7088/ocr' \
121 | -H 'accept: application/json' \
122 | -H 'Content-Type: multipart/form-data' \
123 | -F 'image=@stress-test/table2.png;type=image/png' \
124 | -F 'max_tokens=4096' \
125 | -F 'stream=false'
126 | ```
127 |
128 | **Because the backend uses continuous batching, we support streaming**.
129 |
130 | Check out [notebook/ocr.ipynb](notebook/ocr.ipynb).
131 |
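A minimal sketch to consume the stream with `requests` (each SSE event carries a JSON payload with a single `token`),

```python
import json

import requests

with open('stress-test/table2.png', 'rb') as fopen:
    files = {'image': ('table2.png', fopen, 'image/png')}
    data = {'max_tokens': '4096', 'stream': 'true'}
    r = requests.post('http://localhost:7088/ocr', files=files, data=data, stream=True)

for line in r.iter_lines():
    # SSE lines look like `data: {"token": "..."}`
    if line.startswith(b'data: '):
        print(json.loads(line[6:])['token'], end='', flush=True)
```
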
140 | #### Example URL to PDF
141 |
142 | ```bash
143 | curl -X 'POST' \
144 | 'http://localhost:7088/url_to_pdf' \
145 | -H 'accept: application/json' \
146 | -H 'Content-Type: application/json' \
147 | -d '{
148 | "url": "https://huggingface.co/",
149 | "viewport_weight": 1470,
150 | "viewport_height": 956
151 | }'
152 | ```
153 |
154 | Check out [notebook/url-to-pdf.ipynb](notebook/url-to-pdf.ipynb).
155 |
156 | ![url-to-pdf](notebook/url-to-pdf.png)
157 | 
158 | **To support more concurrency for URL to PDF, make sure to set `--playwright-max-concurrency` to more than 1**.
159 |
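A minimal sketch (each Playwright browser renders one URL at a time, so 5 browsers can render 5 URLs concurrently),

```bash
python3 -m dynamicbatch_ragpipeline.main \
--host 0.0.0.0 --port 7088 \
--playwright-max-concurrency 5
```
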
160 | ## [Stress test](stress-test)
161 |
162 | ### Document layout
163 |
164 | Spawn rate of 10 users per second, up to 100 concurrent users, for 60 seconds on an RTX 3090 Ti,
165 |
166 | 
167 |
168 | ### OCR
169 |
170 | Spawn rate of 5 users per second, up to 50 concurrent users, for 60 seconds,
171 |
172 | 
173 |
--------------------------------------------------------------------------------
/docker-compose.yaml:
--------------------------------------------------------------------------------
1 | version: "3.3"
2 |
3 | services:
4 | dynamicbatch_ragpipeline:
5 | build:
6 | context: .
7 | deploy:
8 | resources:
9 | reservations:
10 | devices:
11 | - driver: nvidia
12 | count: 1
13 | capabilities: [gpu]
14 | container_name: dynamicbatch_ragpipeline
15 | environment:
16 | - PYTHONUNBUFFERED=1
17 | - HF_HUB_ENABLE_HF_TRANSFER=1
18 |
19 | volumes:
20 | - "./dynamicbatch_ragpipeline:/home/ubuntu/dynamicbatch_ragpipeline"
21 | - "~/.cache/huggingface:/home/ubuntu/.cache/huggingface"
22 | ports:
23 | - "7088:7088"
24 | command: python3 -m dynamicbatch_ragpipeline.main --host 0.0.0.0 --port 7088 --reload true
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/__init__.py:
--------------------------------------------------------------------------------
1 | #
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/cache.py:
--------------------------------------------------------------------------------
1 | from typing import List, Tuple, Optional, Dict, Any
2 | from transformers.cache_utils import Cache
3 | import torch
4 | import torch.nn.functional as F
5 | import time
6 |
7 | def pad_kv(caches):
8 | """
9 | List[head, seq, dims]
10 | """
11 |
12 | shapes = [caches[i].shape[2] for i in range(len(caches))]
13 | maxlen = max(shapes)
14 | if all(s == maxlen for s in shapes):
15 | return torch.concat(caches)
16 |
17 | new_caches = []
18 | for i in range(len(caches)):
19 | pad_val = (0, 0, 0, maxlen - caches[i].shape[2], 0, 0, 0, 0)
20 | pad = F.pad(caches[i], pad_val, value=0.0)
21 | new_caches.append(pad)
22 | return torch.concat(new_caches)
23 |
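# A minimal sketch of pad_kv behaviour (hypothetical tensors, not part of the original code):
# given k1 = torch.zeros(1, 16, 10, 64) and k2 = torch.zeros(1, 16, 7, 64),
# pad_kv([k1, k2]).shape == (2, 16, 10, 64); the shorter cache is zero-padded on the seq dimension.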
24 |
25 | class DynamicLengthDecoderCache(Cache):
26 |
27 | def __init__(self) -> None:
28 | self.key_cache: List[torch.Tensor] = []
29 | self.value_cache: List[torch.Tensor] = []
30 | self.current_uuid = []
31 |
32 | def batch_size(self):
33 | if len(self.key_cache) > 0:
34 | return len(self.key_cache[0])
35 | return 0
36 |
37 | def __len__(self):
38 | """
39 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
40 | to the number of layers in the model.
41 | """
42 | return len(self.key_cache)
43 |
44 | def update(
45 | self,
46 | key_states: torch.Tensor,
47 | value_states: torch.Tensor,
48 | layer_idx: int,
49 | cache_kwargs: Optional[Dict[str, Any]] = None,
50 | ) -> Tuple[torch.Tensor, torch.Tensor]:
51 |
52 | keys, values = [], []
53 | for i, k in enumerate(self.current_uuid):
54 | self.key_cache[layer_idx][k] = torch.cat(
55 | [self.key_cache[layer_idx][k], key_states[i: i + 1]], dim=-2)
56 | self.value_cache[layer_idx][k] = torch.cat(
57 | [self.value_cache[layer_idx][k], value_states[i: i + 1]], dim=-2)
58 | keys.append(self.key_cache[layer_idx][k])
59 | values.append(self.value_cache[layer_idx][k])
60 |
61 | k = pad_kv(keys)
62 | v = pad_kv(values)
63 |
64 | return k, v
65 |
66 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
67 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
68 | # TODO: deprecate this function in favor of `cache_position`
69 | if len(self.key_cache) <= layer_idx:
70 | return 0
71 |
72 | lengths = [self.key_cache[0][k].shape[2] for k in self.current_uuid]
73 | return max(lengths)
74 |
75 | def get_max_length(self) -> Optional[int]:
76 | """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
77 | return None
78 |
79 | class StaticLengthDecoderCache(Cache):
80 |
81 | def __init__(
82 | self,
83 | batch_size = 20,
84 | max_length = 8192,
85 | device = 'cuda',
86 | head_size = 16,
87 | dim_size = 64,
88 | num_hidden_layers = 24,
89 | dtype = torch.bfloat16,
90 | ) -> None:
91 |
92 | self.key_cache, self.value_cache = [], []
93 | for _ in range(num_hidden_layers):
94 | self.key_cache.append(
95 | torch.zeros(
96 | batch_size,
97 | head_size,
98 | max_length,
99 | dim_size,
100 |                     dtype = dtype).to(device)
101 | )
102 | self.value_cache.append(
103 | torch.zeros(
104 | batch_size,
105 | head_size,
106 | max_length,
107 | dim_size,
108 |                     dtype = dtype).to(device)
109 | )
110 | self.lengths = []
111 |
112 | def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
113 | """
114 | Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
115 | sequence length.
116 | """
117 | if layer_idx < len(self):
118 | return self.key_cache[layer_idx], self.value_cache[layer_idx]
119 | else:
120 | raise KeyError(
121 | f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
122 |
123 | def __len__(self):
124 | """
125 | Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
126 | to the number of layers in the model.
127 | """
128 | return len(self.key_cache)
129 |
130 | def update(
131 | self,
132 | key_states: torch.Tensor,
133 | value_states: torch.Tensor,
134 | layer_idx: int,
135 | cache_kwargs: Optional[Dict[str, Any]] = None,
136 | ) -> Tuple[torch.Tensor, torch.Tensor]:
137 |
138 | cache_position = cache_kwargs['cache_position']
139 | maxlen = max(self.lengths)
140 |
141 | for i in range(len(key_states)):
142 | self.key_cache[layer_idx][i, :, self.lengths[i]] = key_states[i,:,0]
143 | self.value_cache[layer_idx][i, :, self.lengths[i]] = value_states[i,:,0]
144 |
145 | k = self.key_cache[layer_idx][:len(key_states), :, :maxlen]
146 | v = self.value_cache[layer_idx][:len(key_states), :, :maxlen]
147 |
148 | return k, v
149 |
150 | def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
151 | """Returns the sequence length of the cached states. A layer index can be optionally passed."""
152 | # TODO: deprecate this function in favor of `cache_position`
153 | if len(self.key_cache) <= layer_idx:
154 | return 0
155 | return max(self.lengths)
156 |
157 | def get_max_length(self) -> Optional[int]:
158 | """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
159 | return None
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/doc_layout.py:
--------------------------------------------------------------------------------
1 | from doclayout_yolo import YOLOv10
2 | from huggingface_hub import snapshot_download
3 | from dynamicbatch_ragpipeline.env import args
4 | from datetime import datetime
5 | from io import BytesIO
6 | from PIL import Image
7 | import base64
8 | import pymupdf
9 | import numpy as np
10 | import asyncio
11 | import torch
12 | import torchvision
13 | import os
14 | import logging
15 | import time
16 | import sys
17 |
18 | id_to_names = {
19 | 0: 'title',
20 | 1: 'plain text',
21 | 2: 'abandon',
22 | 3: 'figure',
23 | 4: 'figure_caption',
24 | 5: 'table',
25 | 6: 'table_caption',
26 | 7: 'table_footnote',
27 | 8: 'isolate_formula',
28 | 9: 'formula_caption'
29 | }
30 |
31 | model = None
32 | device = 'cpu'
33 | if args.accelerator_type == 'cuda':
34 | if not torch.cuda.is_available():
35 | logging.warning('CUDA is not available, fallback to CPU.')
36 | else:
37 | device = 'cuda'
38 |
39 | step_queue = asyncio.Queue()
40 |
41 | def load_model():
42 | global model
43 |
44 | model_dir = snapshot_download('juliozhao/DocLayout-YOLO-DocStructBench')
45 | model = YOLOv10(os.path.join(model_dir, 'doclayout_yolo_docstructbench_imgsz1024.pt'))
46 |
47 | if args.torch_compile:
48 | logging.info('enabling torch compile for doc layout')
49 | model.compile()
50 |
51 | async def step():
52 | need_sleep = True
53 | while True:
54 | if need_sleep:
55 | await asyncio.sleep(args.dynamic_batching_microsleep)
56 |
57 | try:
58 | need_sleep = True
59 | batch = []
60 | while not step_queue.empty():
61 | try:
62 | request = await asyncio.wait_for(step_queue.get(), timeout=1e-9)
63 | batch.append(request)
64 | if len(batch) >= args.dynamic_batching_doc_layout_batch_size:
65 | need_sleep = False
66 | break
67 |
68 | except asyncio.TimeoutError:
69 | break
70 |
71 | if not len(batch):
72 | continue
73 |
74 | futures = [batch[i][0] for i in range(len(batch))]
75 | input_img = [batch[i][1] for i in range(len(batch))]
76 |
77 | logging.debug(f'{str(datetime.now())} document layout step batch size of {len(input_img)}')
78 |
79 | with torch.no_grad():
80 | det_res = model.predict(
81 | input_img,
82 | imgsz=1024,
83 | conf=0.25,
84 | device=device,
85 | batch=len(input_img)
86 | )
87 |
88 | for i in range(len(futures)):
89 |                 boxes = det_res[i].boxes.xyxy
90 |                 classes = det_res[i].boxes.cls
91 |                 scores = det_res[i].boxes.conf
92 | futures[i].set_result((boxes, classes, scores))
93 |
94 | except Exception as e:
95 | logging.error(e)
96 | try:
97 | futures = [batch[i][0] for i in range(len(batch))]
98 | for i in range(len(futures)):
99 | if not futures[i].done():
100 | futures[i].set_exception(e)
101 | except:
102 | pass
103 |
104 | async def predict(
105 | file,
106 | iou_threshold = 0.45,
107 | ratio_x = 2.0,
108 | ratio_y = 2.0,
109 | request = None,
110 | ):
111 | doc = pymupdf.open(file)
112 | mat = pymupdf.Matrix(ratio_x, ratio_y)
113 | futures, images = [], []
114 |
115 | for page in doc:
116 | pix = page.get_pixmap(matrix=mat)
117 | image = np.frombuffer(pix.samples_mv, dtype=np.uint8).reshape((pix.height, pix.width, 3)).copy()
118 | images.append(image)
119 |
120 | future = asyncio.Future()
121 | await step_queue.put((future, image))
122 | futures.append(future)
123 |
124 | before = time.time()
125 | results = await asyncio.gather(*futures)
126 |
127 | actual_results = []
128 |
129 | for i in range(len(results)):
130 | boxes, classes, scores = results[i]
131 | indices = torchvision.ops.nms(
132 | boxes=torch.Tensor(boxes),
133 | scores=torch.Tensor(scores),
134 | iou_threshold=iou_threshold,
135 | )
136 | boxes, scores, classes = boxes[indices], scores[indices], classes[indices]
137 | if len(boxes.shape) == 1:
138 | boxes = np.expand_dims(boxes, 0)
139 | scores = np.expand_dims(scores, 0)
140 | classes = np.expand_dims(classes, 0)
141 |
142 | image = Image.fromarray(images[i])
143 | buffered = BytesIO()
144 | image.save(buffered, format="JPEG")
145 | img = base64.b64encode(buffered.getvalue()).decode("utf-8")
146 |
147 | coordinates = boxes.int().cpu().numpy().tolist()
148 | classes = [id_to_names[int(c)] for c in classes]
149 | boxes = []
150 | for c in coordinates:
151 | x_min, y_min, x_max, y_max = c
152 | boxes.append({
153 | 'x_min': x_min,
154 | 'y_min': y_min,
155 | 'x_max': x_max,
156 | 'y_max': y_max,
157 | })
158 | sorted_indices = sorted(range(len(boxes)), key=lambda i: (boxes[i]['y_min'], boxes[i]['x_min']))
159 | sorted_boxes = [boxes[i] for i in sorted_indices]
160 | sorted_classes = [classes[i] for i in sorted_indices]
161 | d = {
162 | 'classes': sorted_classes,
163 | 'coordinates': sorted_boxes,
164 | 'img': img,
165 | }
166 | actual_results.append(d)
167 |
168 | after = time.time()
169 |
170 | stats = {
171 | 'total_page': len(images),
172 | 'page_per_second': len(images) / (after - before),
173 | }
174 | return {
175 | 'result': actual_results,
176 | 'stats': stats,
177 | }
178 |
179 |
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/env.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import os
4 | import torch
5 |
6 |
7 | def parse_arguments():
8 | parser = argparse.ArgumentParser(description='Configuration parser')
9 |
10 | parser.add_argument(
11 | '--host', type=str, default=os.environ.get('HOSTNAME', '0.0.0.0'),
12 | help='host name to host the app (default: %(default)s, env: HOSTNAME)'
13 | )
14 | parser.add_argument(
15 | '--port', type=int, default=int(os.environ.get('PORT', '7088')),
16 | help='port to host the app (default: %(default)s, env: PORT)'
17 | )
18 | parser.add_argument(
19 | '--loglevel', default=os.environ.get('LOGLEVEL', 'INFO').upper(),
20 | help='Logging level (default: %(default)s, env: LOGLEVEL)'
21 | )
22 | parser.add_argument(
23 | '--reload', type=lambda x: x.lower() == 'true',
24 |         default=os.environ.get('RELOAD', 'false').lower() == 'true',
25 | help='Enable hot loading (default: %(default)s, env: RELOAD)'
26 | )
27 | parser.add_argument(
28 | '--enable-doc-layout', type=lambda x: x.lower() == 'true',
29 | default=os.environ.get('ENABLE_DOC_LAYOUT', 'true').lower() == 'true',
30 | help='Enable document layout detection (default: %(default)s, env: ENABLE_DOC_LAYOUT)'
31 | )
32 | parser.add_argument(
33 | '--model-doc-layout',
34 | default=os.environ.get('MODEL_DOC_LAYOUT', 'yolo10'),
35 | help='Model type (default: %(default)s, env: MODEL_DOC_LAYOUT)'
36 | )
37 | parser.add_argument(
38 | '--enable-ocr', type=lambda x: x.lower() == 'true',
39 | default=os.environ.get('ENABLE_OCR', 'true').lower() == 'true',
40 | help='Enable OCR (default: %(default)s, env: ENABLE_OCR)'
41 | )
42 | parser.add_argument(
43 | '--model-ocr',
44 | default=os.environ.get('MODEL_OCR', 'got_ocr2_0'),
45 | help='Model type (default: %(default)s, env: MODEL_OCR)'
46 | )
47 | parser.add_argument(
48 | '--dynamic-batching-microsleep', type=float,
49 | default=float(os.environ.get('DYNAMIC_BATCHING_MICROSLEEP', '1e-4')),
50 |         help='microsleep to group dynamic batching, 1 / 1e-4 = 10k steps per second (default: %(default)s, env: DYNAMIC_BATCHING_MICROSLEEP)'
51 | )
52 | parser.add_argument(
53 | '--dynamic-batching-doc-layout-batch-size', type=int,
54 | default=int(os.environ.get('DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE', '16')),
55 | help='maximum of batch size for document layout during dynamic batching (default: %(default)s, env: DYNAMIC_BATCHING_DOC_LAYOUT_BATCH_SIZE)'
56 | )
57 | parser.add_argument(
58 | '--dynamic-batching-ocr-batch-size', type=int,
59 | default=int(os.environ.get('DYNAMIC_BATCHING_OCR_BATCH_SIZE', '16')),
60 | help='maximum of batch size for OCR during dynamic batching (default: %(default)s, env: DYNAMIC_BATCHING_OCR_BATCH_SIZE)'
61 | )
62 | parser.add_argument(
63 | '--accelerator-type', default=os.environ.get('ACCELERATOR_TYPE', 'cuda'),
64 | help='Accelerator type (default: %(default)s, env: ACCELERATOR_TYPE)'
65 | )
66 | parser.add_argument(
67 | '--max-concurrent',
68 | type=int,
69 | default=int(os.environ.get('MAX_CONCURRENT', '100')),
70 | help='Maximum concurrent requests (default: %(default)s, env: MAX_CONCURRENT)'
71 | )
72 | parser.add_argument(
73 | '--static-cache', type=lambda x: x.lower() == 'true',
74 | default=os.environ.get('STATIC_CACHE', 'false').lower() == 'true',
75 | help='Preallocate KV Cache for faster inference (default: %(default)s, env: STATIC_CACHE)'
76 | )
77 | parser.add_argument(
78 | '--static-cache-max-length',
79 | type=int,
80 | default=int(os.environ.get('STATIC_CACHE_MAX_LENGTH', '8192')),
81 |         help='Maximum sequence length for the preallocated KV Cache (default: %(default)s, env: STATIC_CACHE_MAX_LENGTH)'
82 | )
83 | parser.add_argument(
84 | '--enable-url-to-pdf', type=lambda x: x.lower() == 'true',
85 | default=os.environ.get('ENABLE_URL_TO_PDF', 'true').lower() == 'true',
86 | help='Enable URL to PDF using Playwright (default: %(default)s, env: ENABLE_URL_TO_PDF)'
87 | )
88 | parser.add_argument(
89 | '--playwright-max-concurrency', type=int,
90 | default=int(os.environ.get('PLAYWRIGHT_MAX_CONCURRENCY', '1')),
91 |         help='Maximum number of concurrent Playwright browsers for URL to PDF (default: %(default)s, env: PLAYWRIGHT_MAX_CONCURRENCY)'
92 | )
93 | parser.add_argument(
94 | '--torch-compile', type=lambda x: x.lower() == 'true',
95 |         default=os.environ.get('TORCH_COMPILE', 'true').lower() == 'true',
96 | help='Torch compile necessary forwards, can speed up at least 1.5X (default: %(default)s, env: TORCH_COMPILE)'
97 | )
98 |
99 | args = parser.parse_args()
100 |
101 | if args.model_doc_layout not in {'yolo10'}:
102 | raise ValueError('Currently document layout, `--model-doc-layout` or `MODEL_DOC_LAYOUT` environment variable, only support https://github.com/opendatalab/DocLayout-YOLO')
103 |
104 | if args.model_ocr not in {'got_ocr2_0'}:
105 | raise ValueError('Currently OCR, `--model-ocr` or `MODEL_OCR` environment variable, only support https://huggingface.co/stepfun-ai/GOT-OCR2_0')
106 |
107 | device = 'cpu'
108 | if args.accelerator_type == 'cuda':
109 | if not torch.cuda.is_available():
110 | logging.warning('CUDA is not available, fallback to CPU.')
111 | else:
112 | device = 'cuda'
113 |
114 | args.device = device
115 | return args
116 |
117 |
118 | args = parse_arguments()
119 |
120 | logging.basicConfig(level=args.loglevel)
121 |
122 | logging.info(f'Serving app using {args}')
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import logging
4 |
5 | def efficient_attention_mask(batch_size, max_len, lengths, device, dtype, ones=True):
6 | lengths = torch.tensor(lengths)
7 | left = torch.arange(max_len).expand(
8 | batch_size, 1, 1, max_len)
9 | right = lengths.view(
10 | batch_size, 1, 1, 1)
11 | if ones:
12 | mask = left < right
13 | mask = mask.float()
14 | else:
15 | mask = left > right
16 | mask = mask.float().masked_fill_(mask, torch.finfo(dtype).min)
17 | return mask.to(device).type(dtype)
18 |
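# A minimal sketch (hypothetical values, not part of the original code):
# efficient_attention_mask(batch_size=2, max_len=5, lengths=[3, 5], device='cpu', dtype=torch.float32)
# returns a [2, 1, 1, 5] mask with ones on the first `lengths[i]` positions of each row,
# broadcastable over attention heads and query positions.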
19 | def cleanup_cache(cache):
20 | try:
21 | if isinstance(cache, tuple) or isinstance(cache, list):
22 | cache = list(cache)
23 | for i in range(len(cache)):
24 | cache[i] = list(cache[i])
25 | for _ in range(len(cache[i])):
26 | del cache[i][0]
27 |
28 | else:
29 | for _ in range(len(cache.key_cache)):
30 | del cache.key_cache[0]
31 | for _ in range(len(cache.value_cache)):
32 | del cache.value_cache[0]
33 | except Exception as e:
34 |         logging.warning(f'failed to clear cache, {e}')
35 |
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/main.py:
--------------------------------------------------------------------------------
1 | from fastapi import FastAPI, Request
2 | from fastapi import File, Form, UploadFile
3 | from fastapi import HTTPException
4 | from fastapi.responses import StreamingResponse
5 | from sse_starlette import EventSourceResponse
6 | from dynamicbatch_ragpipeline.env import args
7 | from dynamicbatch_ragpipeline.doc_layout import (
8 | load_model as doc_layout_load_model,
9 | predict as doc_layout_predict,
10 | step as doc_layout_step,
11 | )
12 | from dynamicbatch_ragpipeline.ocr import (
13 | load_model as ocr_load_model,
14 | predict as ocr_predict,
15 | prefill as ocr_prefill,
16 | step as ocr_step,
17 | )
18 | from dynamicbatch_ragpipeline.function import cleanup_cache
19 | from dynamicbatch_ragpipeline.playwright_utils import (
20 | to_pdf,
21 | initialize_browser,
22 | )
23 | from transformers_openai.middleware import InsertMiddleware
24 | from pydantic import BaseModel
25 | import asyncio
26 | import logging
27 | import uvicorn
28 | import tempfile
29 |
30 |
31 | app = FastAPI()
32 |
33 | app.add_middleware(InsertMiddleware, max_concurrent=args.max_concurrent)
34 |
35 | @app.get('/')
36 | async def hello_world():
37 | return {'hello': 'world'}
38 |
39 | if args.enable_doc_layout:
40 | logging.info('enabling document layout')
41 |
42 | @app.post('/doc_layout')
43 | async def doc_layout(
44 | file: bytes = File(),
45 | iou_threshold: float = Form(0.45),
46 | ratio_x: float = Form(2.0),
47 | ratio_y: float = Form(2.0),
48 | request: Request = None,
49 | ):
50 | """
51 |         Supports PDF files; one file can contain multiple pages. Returns a list of page images in base64 together with the detected layouts.
52 | """
53 |
54 | with tempfile.NamedTemporaryFile(suffix='.pdf') as temp_file:
55 |             temp_file.write(file)
56 |             temp_file.flush()
57 | r = await doc_layout_predict(
58 | temp_file,
59 | iou_threshold = iou_threshold,
60 | ratio_x = ratio_x,
61 | ratio_y = ratio_y,
62 | request = request
63 | )
64 | return r
65 |
66 | doc_layout_load_model()
67 |
68 | @app.on_event("startup")
69 | async def startup_event():
70 | app.state.background_doc_layout_step = asyncio.create_task(doc_layout_step())
71 |
72 | @app.on_event("shutdown")
73 | async def shutdown_event():
74 | app.state.background_doc_layout_step.cancel()
75 | try:
76 | await app.state.background_doc_layout_step
77 | except asyncio.CancelledError:
78 | pass
79 |
80 | if args.enable_ocr:
81 | logging.info('enabling OCR')
82 |
83 | @app.post('/ocr')
84 | async def ocr(
85 | image: bytes = File(),
86 | mode: str = Form('format'),
87 | max_tokens: int = Form(4096),
88 | stream: bool = Form(False),
89 | request: Request = None,
90 | ):
91 | """
92 | Convert image to text using OCR.
93 |
94 |         Supports 2 modes,
95 |
96 | 1. `plain`, plain text.
97 |
98 |         2. `format`, formats the output into LaTeX.
99 |
100 | """
101 | mode = mode.lower()
102 | if mode not in {'plain', 'format'}:
103 | raise HTTPException(status_code=400, detail='mode only support `plain` or `format`.')
104 |
105 | generator = ocr_predict(image, mode = mode, max_tokens = max_tokens, stream = stream, request = request)
106 | r = await generator
107 | if stream:
108 |             return EventSourceResponse(r)
109 | else:
110 | return r
111 |
112 | ocr_load_model()
113 |
114 | @app.on_event("startup")
115 | async def startup_event():
116 | app.state.background_ocr_prefill = asyncio.create_task(ocr_prefill())
117 | app.state.background_ocr_step = asyncio.create_task(ocr_step())
118 |
119 | @app.on_event("shutdown")
120 | async def shutdown_event():
121 | app.state.background_ocr_prefill.cancel()
122 | app.state.background_ocr_step.cancel()
123 | try:
124 | await app.state.background_ocr_prefill
125 | await app.state.background_ocr_step
126 | except asyncio.CancelledError:
127 | pass
128 |
129 | if args.enable_url_to_pdf:
130 | logging.info('enabling URL to PDF')
131 |
132 | class URL(BaseModel):
133 | url: str = 'https://screenresolutiontest.com/screenresolution/'
134 |     viewport_width: int = 1470
135 | viewport_height: int = 956
136 |
137 | @app.post('/url_to_pdf')
138 | async def url_to_pdf(url: URL):
139 | pdf_file = await to_pdf(**url.dict())
140 | return StreamingResponse(
141 | pdf_file,
142 | media_type="application/pdf",
143 | headers={"Content-Disposition": "attachment; filename=mydocument.pdf"}
144 | )
145 |
146 | @app.on_event('startup')
147 | async def warmup():
148 | tasks = []
149 | for index in range(args.playwright_max_concurrency):
150 | task = asyncio.create_task(initialize_browser(index=index))
151 | tasks.append(task)
152 |
153 | await asyncio.gather(*tasks)
154 |
155 | if __name__ == "__main__":
156 | uvicorn.run(
157 | 'dynamicbatch_ragpipeline.main:app',
158 | host=args.host,
159 | port=args.port,
160 | log_level=args.loglevel.lower(),
161 | access_log=True,
162 | reload=args.reload,
163 | )
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/ocr.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModel, AutoTokenizer
2 | from sse_starlette import ServerSentEvent
3 | from dynamicbatch_ragpipeline import ocr_utils
4 | from dynamicbatch_ragpipeline.env import args
5 | from transformers_openai.function import efficient_attention_mask
6 | from transformers_openai.cache import (
7 | DynamicLengthDecoderCache,
8 | )
9 | from datetime import datetime
10 | from PIL import Image
11 | from io import BytesIO
12 | import numpy as np
13 | import torch
14 | import asyncio
15 | import time
16 | import logging
17 | import json
18 |
19 | model = None
20 | tokenizer = None
21 | global_cache = None
22 |
23 | torch_dtype = torch.bfloat16
24 |
25 | device = 'cpu'
26 | if args.accelerator_type == 'cuda':
27 | if not torch.cuda.is_available():
28 | logging.warning('CUDA is not available, fallback to CPU.')
29 | else:
30 | device = 'cuda'
31 |
32 | prefill_queue = asyncio.Queue()
33 | step_queue = asyncio.Queue()
34 |
35 | def load_model():
36 | global model, tokenizer, global_cache
37 |
38 | tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
39 | model = AutoModel.from_pretrained(
40 | 'ucaslcl/GOT-OCR2_0',
41 | trust_remote_code=True,
42 | torch_dtype = torch_dtype,
43 | attn_implementation = 'sdpa'
44 | )
45 | model = model.eval().to(device)
46 | global_cache = DynamicLengthDecoderCache()
47 |
48 | if args.torch_compile:
49 | logging.info('enabling torch compile for OCR')
50 |
51 | model.model.vision_tower_high.forward = torch.compile(
52 | model.model.vision_tower_high.forward,
53 | )
54 | model.model.mm_projector_vary.forward = torch.compile(
55 | model.model.mm_projector_vary.forward,
56 | )
57 | ocr_utils.image_processor_high.transform = torch.compile(
58 | ocr_utils.image_processor_high.transform,
59 | )
60 |
61 | with torch.no_grad():
62 | for i in range(3):
63 | logging.info(f'{i}, warming up vision tower')
64 | image = torch.zeros(3, 1000, 1000).to(device)
65 | image = ocr_utils.image_processor_high(image).unsqueeze(0).type(model.dtype)
66 | cnn_feature = model.model.vision_tower_high(image)
67 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
68 | model.model.mm_projector_vary(cnn_feature)
69 | del cnn_feature
70 |
71 |
72 | async def prefill():
73 | need_sleep = True
74 | while True:
75 | if need_sleep:
76 | await asyncio.sleep(args.dynamic_batching_microsleep)
77 |
78 | try:
79 | need_sleep = True
80 | batch = []
81 | while not prefill_queue.empty():
82 | try:
83 | request = await asyncio.wait_for(prefill_queue.get(), timeout=1e-6)
84 | batch.append(request)
85 | if len(batch) >= args.dynamic_batching_ocr_batch_size:
86 | need_sleep = False
87 | break
88 |
89 | except asyncio.TimeoutError:
90 | break
91 |
92 | if not len(batch):
93 | continue
94 |
95 | futures = [batch[i][0] for i in range(len(batch))]
96 | input_img = [batch[i][1] for i in range(len(batch))]
97 | modes = [batch[i][3] for i in range(len(batch))]
98 | uuids = [batch[i][4] for i in range(len(batch))]
99 |
100 | logging.debug(f'{str(datetime.now())} OCR prefill batch size of {len(uuids)}')
101 |
102 | prompts = []
103 | for f in modes:
104 | if f == 'format':
105 | qs = 'OCR with format: '
106 | else:
107 | qs = 'OCR: '
108 |
109 | qs = ocr_utils.qs + qs
110 | conv = ocr_utils.conv_mpt.copy()
111 | conv.append_message(conv.roles[0], qs)
112 | conv.append_message(conv.roles[1], None)
113 | prompt = conv.get_prompt()
114 | prompts.append(prompt)
115 |
116 | with torch.no_grad():
117 | images = []
118 | for i in range(len(input_img)):
119 | image = Image.open(BytesIO(input_img[i])).convert('RGB')
120 | image = ocr_utils.image_processor_high.to_tensor(image).to(device)
121 | image_tensor = ocr_utils.image_processor_high(image).unsqueeze(0).type(model.dtype)
122 | images.append(image_tensor)
123 |
124 | input_ids = tokenizer(prompts, return_tensors = 'pt', padding = 'longest')
125 | input_ids.pop('token_type_ids', None)
126 | lengths = input_ids['attention_mask'].sum(axis = 1)
127 | for k in input_ids.keys():
128 | input_ids[k] = input_ids[k].to(device)
129 |
130 | out = model(
131 | **input_ids,
132 | images = images,
133 | past_key_values = None,
134 | use_cache = True,
135 | return_dict = False,
136 | )
137 | out_logits = out[0]
138 | out_caches = out[1]
139 |
140 | cache_exists = len(global_cache.key_cache)
141 |
142 | for k in range(len(out_caches)):
143 |
144 | key_cache = {}
145 | value_cache = {}
146 | for i in range(len(batch)):
147 | key_cache[uuids[i]] = out_caches[k][0][i: i + 1, :, :lengths[i]]
148 | value_cache[uuids[i]] = out_caches[k][1][i: i + 1, :, :lengths[i]]
149 |
150 | if cache_exists:
151 | global_cache.key_cache[k].update(key_cache)
152 | global_cache.value_cache[k].update(value_cache)
153 | else:
154 | global_cache.key_cache.append(key_cache)
155 | global_cache.value_cache.append(value_cache)
156 |
157 | for i in range(len(futures)):
158 |                     futures[i].set_result((out_logits[i, lengths[i] - 1: lengths[i]],))
159 |
160 | for k in range(len(out_caches)):
161 | temp = list(out_caches[k])
162 | for j in range(len(out_caches[k])):
163 | del temp[0]
164 |
165 | except Exception as e:
166 | logging.error(f'error in prefill {e}')
167 | try:
168 | futures = [batch[i][0] for i in range(len(batch))]
169 | for i in range(len(futures)):
170 | if not futures[i].done():
171 | futures[i].set_exception(e)
172 | except:
173 | pass
174 |
175 | async def step():
176 | need_sleep = True
177 | while True:
178 | if need_sleep:
179 | await asyncio.sleep(args.dynamic_batching_microsleep)
180 |
181 | try:
182 | need_sleep = True
183 | batch = []
184 | while not step_queue.empty():
185 | try:
186 | request = await asyncio.wait_for(step_queue.get(), timeout=1e-6)
187 | batch.append(request)
188 |
189 | if len(batch) >= args.dynamic_batching_ocr_batch_size:
190 | need_sleep = False
191 | break
192 |
193 | except asyncio.TimeoutError:
194 | break
195 |
196 | if not len(batch):
197 | continue
198 |
199 | futures = [batch[i][0] for i in range(len(batch))]
200 | inputs = [batch[i][1] for i in range(len(batch))]
201 | lengths = [batch[i][2] for i in range(len(batch))]
202 | uuids = [batch[i][4] for i in range(len(batch))]
203 |
204 | logging.debug(f'{str(datetime.now())} OCR step batch size of {len(uuids)}')
205 |
206 | global_cache.current_uuid = uuids
207 |
208 | max_len_lengths = max(lengths)
209 | with torch.no_grad():
210 | inputs = torch.concat(inputs, dim=0)
211 | attention_mask = efficient_attention_mask(
212 | batch_size=len(lengths),
213 | max_len=max_len_lengths,
214 | lengths=lengths,
215 | device=device,
216 | dtype=torch_dtype,
217 | )
218 | position_ids = torch.tensor([[l - 1 for l in lengths]]).T.to(device)
219 | out = model(
220 | inputs,
221 | images = None,
222 | attention_mask=attention_mask,
223 | position_ids=position_ids,
224 | past_key_values=global_cache,
225 | use_cache=True,
226 | return_dict=False
227 | )
228 |
229 | out_logits = out[0]
230 |
231 | for i in range(len(futures)):
232 | futures[i].set_result((out_logits[i, -1:],))
233 |
234 | except Exception as e:
235 | logging.error(f'error in step {e}')
236 | try:
237 | futures = [batch[i][0] for i in range(len(batch))]
238 | for i in range(len(futures)):
239 | if not futures[i].done():
240 | futures[i].set_exception(e)
241 | except:
242 | pass
243 |
244 | async def streaming(image, mode, max_tokens, request):
245 |
246 | cache = None
247 | length = None
248 | inputs = image
249 | uuid = request.scope['request']['uuid']
250 |
251 | try:
252 | for k in range(max_tokens):
253 |
254 | if k == 0:
255 | q = prefill_queue
256 | l = length
257 | else:
258 | q = step_queue
259 | l = length + k
260 |
261 | future = asyncio.Future()
262 | await q.put((future, inputs, l, mode, uuid))
263 | out = await future
264 |
265 | logits = out[0]
266 |
267 | if length is None:
268 | length = global_cache.key_cache[0][uuid].shape[2]
269 |
270 | idx_next = logits.argmax(-1)
271 | token = tokenizer.decode(idx_next)
272 |
273 | if k == 0:
274 | request.scope['request']['time_first_token'] = time.time()
275 |
276 | if token == ocr_utils.stop_str:
277 | break
278 |
279 | del logits
280 | inputs = idx_next.unsqueeze(0)
281 |
282 | data = {
283 | 'token': token
284 | }
285 | yield json.dumps(data)
286 | await asyncio.sleep(0)
287 |
288 | request.scope['request']['time_max_tokens'] = time.time()
289 | request.scope['request']['total_tokens'] = k
290 |
291 | except asyncio.CancelledError as e:
292 | logging.warning(f"model step cancelled {uuid}")
293 | yield ServerSentEvent(**{"data": str(e)})
294 |
295 | except Exception as e:
296 | logging.error(f"model step exception {e} {uuid}")
297 | yield ServerSentEvent(**{"data": str(e)})
298 |
299 | finally:
300 | logging.debug(f'purging {uuid} KV cache')
301 | for i in range(len(global_cache.key_cache)):
302 | global_cache.key_cache[i].pop(uuid, None)
303 | global_cache.value_cache[i].pop(uuid, None)
304 |
305 | async def predict(image, mode = 'format', max_tokens = 4096, stream = False, request = None):
306 | if model is None:
307 | load_model()
308 |
309 | func = streaming(image=image, mode=mode, max_tokens=max_tokens, request=request)
310 | if stream:
311 | return func
312 | else:
313 | tokens = []
314 | async for data in func:
315 | if isinstance(data, ServerSentEvent):
316 | continue
317 | data = json.loads(data)
318 | tokens.append(data['token'])
319 |
320 | return {
321 | 'result': ''.join(tokens)
322 | }
323 |
324 |
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/ocr_utils.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from torchvision.transforms.functional import InterpolationMode
3 | from enum import auto, Enum
4 | from typing import List, Optional, Tuple, Union
5 | import dataclasses
6 |
7 | class SeparatorStyle(Enum):
8 | """Different separator style."""
9 | SINGLE = auto()
10 | TWO = auto()
11 | MPT = auto()
12 |
13 |
14 | @dataclasses.dataclass
15 | class Conversation:
16 | """A class that keeps all conversation history."""
17 | system: str
18 | roles: List[str]
19 | messages: List[List[str]]
20 | offset: int
21 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
22 | sep: str = "<|im_end|>"
23 | sep2: str = None
24 | version: str = "Unknown"
25 |
26 | skip_next: bool = False
27 |
28 | def get_prompt(self):
29 | if self.sep_style == SeparatorStyle.SINGLE:
30 | ret = self.system + self.sep + '\n'
31 | for role, message in self.messages:
32 | if message:
33 | if type(message) is tuple:
34 | message, _, _ = message
35 | ret += role + ": " + message + self.sep
36 | else:
37 | ret += role + ":"
38 | return ret
39 | elif self.sep_style == SeparatorStyle.TWO:
40 | seps = [self.sep, self.sep2]
41 | ret = self.system + seps[0]
42 | for i, (role, message) in enumerate(self.messages):
43 | if message:
44 | if type(message) is tuple:
45 | message, _, _ = message
46 | ret += role + ": " + message + seps[i % 2]
47 | else:
48 | ret += role + ":"
49 | return ret
50 | if self.sep_style == SeparatorStyle.MPT:
51 | if self.system:
52 | ret = self.system + self.sep
53 | else:
54 | ret = ''
55 | for role, message in self.messages:
56 | if message:
57 | if type(message) is tuple:
58 | message, _, _ = message
59 | ret += role + message + self.sep
60 | else:
61 | ret += role
62 | return ret
63 | else:
64 | raise ValueError(f"Invalid style: {self.sep_style}")
65 |
66 |
67 | def append_message(self, role, message):
68 | self.messages.append([role, message])
69 |
70 | def copy(self):
71 | return Conversation(
72 | system=self.system,
73 | roles=self.roles,
74 | messages=[[x, y] for x, y in self.messages],
75 | offset=self.offset,
76 | sep_style=self.sep_style,
77 | sep=self.sep,
78 | sep2=self.sep2)
79 |
80 | class GOTImageEvalProcessor:
81 | def __init__(self, image_size=384, mean=None, std=None):
82 | if mean is None:
83 | mean = (0.48145466, 0.4578275, 0.40821073)
84 | if std is None:
85 | std = (0.26862954, 0.26130258, 0.27577711)
86 |
87 | self.normalize = transforms.Normalize(mean, std)
88 | self.to_tensor = transforms.ToTensor()
89 |
90 | self.transform = transforms.Compose(
91 | [
92 | transforms.Resize(
93 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
94 | ),
95 | self.normalize,
96 | ]
97 | )
98 | def __call__(self, item):
99 | return self.transform(item)
100 |
101 | conv_mpt = Conversation(
102 | system="""<|im_start|>system
103 | You should follow the instructions carefully and explain your answers in detail.""",
104 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
105 | version="mpt",
106 | messages=(),
107 | offset=0,
108 | sep_style=SeparatorStyle.MPT,
109 | sep="<|im_end|>",
110 | )
111 |
112 | DEFAULT_IMAGE_TOKEN = '<image>'
113 | DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
114 | DEFAULT_IM_START_TOKEN = '<img>'
115 | DEFAULT_IM_END_TOKEN = '</img>'
116 | image_processor_high = GOTImageEvalProcessor(image_size=1024)
117 |
118 | image_token_len = 256
119 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n'
120 | stop_str = '<|im_end|>'
--------------------------------------------------------------------------------
/dynamicbatch_ragpipeline/playwright_utils.py:
--------------------------------------------------------------------------------
1 | from dynamicbatch_ragpipeline.env import args
2 | from playwright.async_api import async_playwright
3 | from datetime import datetime
4 | from fastapi import HTTPException
5 | from io import BytesIO
6 | import time
7 | import logging
8 | import asyncio
9 |
10 | playwrights = {}
11 |
12 | async def initialize_browser(index, clear_first = False):
13 | global playwrights
14 |
15 | if clear_first:
16 | logging.info(f'clearing playwright {index}')
17 | for k in list(playwrights[index]):
18 | try:
19 | await playwrights[index][k].close()
20 | except:
21 | pass
22 | try:
23 | del playwrights[index][k]
24 | except:
25 | pass
26 |
27 | logging.info(f'initializing playwright {index}')
28 |
29 | playwright = await async_playwright().start()
30 | browser = await playwright.chromium.launch(headless = True)
31 | page = await browser.new_page()
32 | playwrights[index] = {
33 | 'playwright': playwright,
34 | 'browser': browser,
35 | 'page': page,
36 | 'available': True,
37 | 'last_emit': datetime.now()
38 | }
39 |
40 | async def dead(index):
41 |
42 | died = False
43 | try:
44 | if playwrights[index]['page'].is_closed():
45 | died = True
46 | except:
47 | pass
48 |
49 | try:
50 | if not playwrights[index]['browser'].is_connected():
51 | died = True
52 | except:
53 | pass
54 |
55 | if died:
56 | await initialize_browser(index=index, clear_first=True)
57 |
58 | return died
59 |
60 | async def to_pdf(url, viewport_width, viewport_height):
61 | index = 0
62 | found = False
63 | try:
64 | while True:
65 | for index in range(args.playwright_max_concurrency):
66 | if playwrights[index]['available']:
67 | playwrights[index]['available'] = False
68 | found = True
69 | break
70 |
71 | await asyncio.sleep(1e-9)
72 |
73 | if found:
74 | break
75 |
76 |         await playwrights[index]['page'].set_viewport_size({"width": viewport_width, "height": viewport_height})
77 | await playwrights[index]['page'].goto(url)
78 | pdf = await playwrights[index]['page'].pdf()
79 | playwrights[index]['available'] = True
80 | return BytesIO(pdf)
81 |
82 | except asyncio.CancelledError as e:
83 | await dead(index)
84 | playwrights[index]['available'] = True
85 |
86 | except Exception as e:
87 | await dead(index)
88 | playwrights[index]['available'] = True
89 | raise HTTPException(status_code=500, detail=f'failed, {e}, please retry.')
90 |
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/html-to-pdf-only.yaml:
--------------------------------------------------------------------------------
1 | version: "3.3"
2 |
3 | services:
4 | dynamicbatch_ragpipeline:
5 | build:
6 | context: .
7 | deploy:
8 | resources:
9 | reservations:
10 | devices:
11 | - driver: nvidia
12 | count: 1
13 | capabilities: [gpu]
14 | container_name: dynamicbatch_ragpipeline
15 | environment:
16 | - PYTHONUNBUFFERED=1
17 | - HF_HUB_ENABLE_HF_TRANSFER=1
18 | - ENABLE_DOC_LAYOUT=false
19 | - ENABLE_OCR=false
20 |
21 | volumes:
22 | - "./dynamicbatch_ragpipeline:/home/ubuntu/dynamicbatch_ragpipeline"
23 | - "~/.cache/huggingface:/home/ubuntu/.cache/huggingface"
24 | ports:
25 | - "7088:7088"
26 | command: python3 -m dynamicbatch_ragpipeline.main --host 0.0.0.0 --port 7088 --reload true
--------------------------------------------------------------------------------
/notebook/doc-layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/doc-layout.png
--------------------------------------------------------------------------------
/notebook/huggingface.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/huggingface.pdf
--------------------------------------------------------------------------------
/notebook/page1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page1.png
--------------------------------------------------------------------------------
/notebook/page2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page2.png
--------------------------------------------------------------------------------
/notebook/page3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page3.png
--------------------------------------------------------------------------------
/notebook/page4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/page4.png
--------------------------------------------------------------------------------
/notebook/url-to-pdf.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "3808afa9",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import requests\n",
11 | "data = {\n",
12 | " 'url': 'https://huggingface.co/'\n",
13 | "}\n",
14 | "\n",
15 | "r = requests.post('http://localhost:7088/url_to_pdf', json = data)\n",
16 | "with open('huggingface.pdf', 'wb') as fopen:\n",
17 | " fopen.write(r._content)"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "id": "6700b25a",
24 | "metadata": {},
25 | "outputs": [
26 | {
27 | "data": {
28 | "text/html": [
29 | "\n",
30 | " \n",
38 | " "
39 | ],
40 | "text/plain": [
41 | ""
42 | ]
43 | },
44 | "execution_count": 2,
45 | "metadata": {},
46 | "output_type": "execute_result"
47 | }
48 | ],
49 | "source": [
50 | "from IPython.display import IFrame\n",
51 | "\n",
52 | "IFrame('huggingface.pdf', width=600, height=600)"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": null,
58 | "id": "a9bfd597",
59 | "metadata": {},
60 | "outputs": [],
61 | "source": []
62 | }
63 | ],
64 | "metadata": {
65 | "kernelspec": {
66 | "display_name": "python3.10",
67 | "language": "python",
68 | "name": "python3.10"
69 | },
70 | "language_info": {
71 | "codemirror_mode": {
72 | "name": "ipython",
73 | "version": 3
74 | },
75 | "file_extension": ".py",
76 | "mimetype": "text/x-python",
77 | "name": "python",
78 | "nbconvert_exporter": "python",
79 | "pygments_lexer": "ipython3",
80 | "version": "3.10.15"
81 | }
82 | },
83 | "nbformat": 4,
84 | "nbformat_minor": 5
85 | }
86 |
--------------------------------------------------------------------------------
/notebook/url-to-pdf.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/notebook/url-to-pdf.png
--------------------------------------------------------------------------------
/push-dockerhub.sh:
--------------------------------------------------------------------------------
1 | docker build -t mesoliticadev/dynamic-batch-rag-pipeline .
2 | docker push mesoliticadev/dynamic-batch-rag-pipeline
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fastapi
2 | uvicorn
3 | doclayout-yolo
4 | huggingface-hub
5 | torch
6 | torchvision
7 | python-multipart
8 | pymupdf
9 | verovio
10 | sse_starlette
11 | playwright
12 | transformers
13 | tiktoken
14 | git+https://github.com/mesolitica/transformers-openai-api.git@0990077672793866128d7ec7d6d5f24778bf999f
--------------------------------------------------------------------------------
/runs/detect/train/args.yaml:
--------------------------------------------------------------------------------
1 | task: detect
2 | mode: train
3 | model: /home/husein/.cache/huggingface/hub/models--juliozhao--DocLayout-YOLO-DocStructBench/snapshots/8c3299a30b8ff29a1503c4431b035b93220f7b11/doclayout_yolo_docstructbench_imgsz1024.pt
4 | data: coco8.yaml
5 | epochs: 100
6 | time: null
7 | patience: 100
8 | batch: 16
9 | imgsz: 1024
10 | save: true
11 | save_period: 10
12 | val_period: 1
13 | cache: false
14 | device: null
15 | workers: 8
16 | project: null
17 | name: train
18 | exist_ok: true
19 | pretrained: true
20 | optimizer: auto
21 | verbose: true
22 | seed: 0
23 | deterministic: true
24 | single_cls: false
25 | rect: false
26 | cos_lr: false
27 | close_mosaic: 10
28 | resume: false
29 | amp: true
30 | fraction: 1.0
31 | profile: false
32 | freeze: null
33 | multi_scale: false
34 | overlap_mask: true
35 | mask_ratio: 4
36 | dropout: 0.0
37 | val: true
38 | split: val
39 | save_json: false
40 | save_hybrid: false
41 | conf: null
42 | iou: 0.7
43 | max_det: 300
44 | half: false
45 | dnn: false
46 | plots: true
47 | source: null
48 | vid_stride: 1
49 | stream_buffer: false
50 | visualize: false
51 | augment: false
52 | agnostic_nms: false
53 | classes: null
54 | retina_masks: false
55 | embed: null
56 | show: false
57 | save_frames: false
58 | save_txt: false
59 | save_conf: false
60 | save_crop: false
61 | show_labels: true
62 | show_conf: true
63 | show_boxes: true
64 | line_width: null
65 | format: torchscript
66 | keras: false
67 | optimize: false
68 | int8: false
69 | dynamic: false
70 | simplify: false
71 | opset: null
72 | workspace: 4
73 | nms: false
74 | lr0: 0.01
75 | lrf: 0.01
76 | momentum: 0.937
77 | weight_decay: 0.0005
78 | warmup_epochs: 3.0
79 | warmup_momentum: 0.8
80 | warmup_bias_lr: 0.1
81 | box: 7.5
82 | cls: 0.5
83 | dfl: 1.5
84 | pose: 12.0
85 | kobj: 1.0
86 | label_smoothing: 0.0
87 | nbs: 64
88 | hsv_h: 0.015
89 | hsv_s: 0.7
90 | hsv_v: 0.4
91 | degrees: 0.0
92 | translate: 0.1
93 | scale: 0.5
94 | shear: 0.0
95 | perspective: 0.0
96 | flipud: 0.0
97 | fliplr: 0.5
98 | bgr: 0.0
99 | mosaic: 1.0
100 | mixup: 0.0
101 | copy_paste: 0.0
102 | auto_augment: randaugment
103 | erasing: 0.4
104 | crop_fraction: 1.0
105 | cfg: null
106 | tracker: botsort.yaml
107 | save_dir: runs/detect/train
108 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 |
4 | __packagename__ = 'dynamic-batch-rag-pipeline'
5 |
6 | setuptools.setup(
7 | name=__packagename__,
8 | packages=setuptools.find_packages(),
9 | version='0.1',
10 | python_requires='>=3.8',
11 | description='Dynamic batching for Document Layout and OCR, suitable for RAG',
12 | author='huseinzol05',
13 | url='https://github.com/mesolitica/dynamic-batch-rag-pipeline',
14 | )
--------------------------------------------------------------------------------
/stress-test/2310.01889v4.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/2310.01889v4.pdf
--------------------------------------------------------------------------------
/stress-test/README.md:
--------------------------------------------------------------------------------
1 | # Stress test
2 |
3 | ## Document layout on RTX 3090 Ti
4 |
5 | Spawn rate of 10 users per second, up to 100 concurrent users, for 60 seconds,
6 |
7 | ```bash
8 | locust -f doc_layout.py -P 7001 -H http://localhost:7088 -r 10 -u 100 -t 60
9 | ```
10 |
11 | 
12 |
13 | ## OCR on RTX 3090 Ti
14 |
15 | Spawn rate of 5 users per second, up to 50 concurrent users, for 60 seconds,
16 |
17 | ```bash
18 | locust -f ocr.py -P 7001 -H http://localhost:7088 -r 5 -u 50 -t 60
19 | ```
20 |
21 | ### Continuous batching
22 |
23 | 
--------------------------------------------------------------------------------
/stress-test/doc_layout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/doc_layout.png
--------------------------------------------------------------------------------
/stress-test/doc_layout.py:
--------------------------------------------------------------------------------
1 | from locust import HttpUser, task
2 | from locust import events
3 | import itertools
4 | import time
5 |
6 | """
7 | Make sure this is already running,
8 |
9 | CUDA_VISIBLE_DEVICES=0 \
10 | python3.10 -m dynamicbatch_ragpipeline.main \
11 | --host 0.0.0.0 --port 7088 \
12 | --dynamic-batching-doc-layout-batch-size 64
14 | """
15 |
16 | class HelloWorldUser(HttpUser):
17 |
18 | host = "http://127.0.0.1:7088"
19 |
20 | @task
21 | def hello_world(self):
22 |
23 | files = {
24 | 'file': ('2310.01889v4.pdf', open('2310.01889v4.pdf', 'rb'), 'application/pdf'),
25 | }
26 | r = self.client.post('/doc_layout', files=files)
--------------------------------------------------------------------------------
/stress-test/ocr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/ocr.png
--------------------------------------------------------------------------------
/stress-test/ocr.py:
--------------------------------------------------------------------------------
1 | from locust import HttpUser, task
2 | from locust import events
3 | import itertools
4 | import time
5 |
6 | """
7 | Make sure this is already running,
8 |
9 | CUDA_VISIBLE_DEVICES=2 \
10 | python3.10 -m dynamicbatch_ragpipeline.main \
11 | --host 0.0.0.0 --port 7088 \
12 | --dynamic-batching-ocr-batch-size 32
14 | """
15 |
16 | class HelloWorldUser(HttpUser):
17 |
18 | host = "http://127.0.0.1:7088"
19 |
20 | @task
21 | def hello_world(self):
22 |
23 | files = {
24 | 'image': ('table1.png', open('table1.png', 'rb'), 'image/png'),
25 | }
26 | r = self.client.post('/ocr', files=files)
--------------------------------------------------------------------------------
/stress-test/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/table1.png
--------------------------------------------------------------------------------
/stress-test/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/table2.png
--------------------------------------------------------------------------------
/stress-test/title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mesolitica/dynamic-batch-RAG-pipeline/a25f1af3c686799c4171d3dd94e49961840acc61/stress-test/title.png
--------------------------------------------------------------------------------