├── .gitattributes
├── LICENSE
├── README.md
├── README_en.md
├── logo
│   └── logo.png
├── models
│   ├── kimi_ocr.py
│   ├── llava_ocr.py
│   ├── ocr.py
│   ├── openai_ocr.py
│   ├── paddle_ocr.py
│   └── qwen_ocr.py
├── start_ocr.py
└── utils
    ├── mm_utils.py
    └── read_pdf_to_text.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 jackfsuia
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # LLM-Data-Cleaner
6 | 简体中文 | [English](README_en.md)
7 |
8 | ## Updates
9 | - It looks like people are now training LLMs specifically for this kind of cleaning work. See [refuel-llm-2](https://www.refuel.ai/blog-posts/announcing-refuel-llm-2).
10 | - For OCR, you can also combine [Nougat](https://github.com/facebookresearch/nougat), [Marker](https://github.com/VikParuchuri/marker), or the multimodal model [MiniCPM-V](https://github.com/OpenBMB/MiniCPM-V).
11 | ## Background
12 | In the future, humans will use large models to preprocess all data. This project combines large models to batch-preprocess data for research purposes. It currently supports OCR, and the supported models are qwen (Tongyi Qianwen), moonshot (Moonshot AI), PaddleOCR (Baidu PaddleOCR), openai, and Llava.
13 | ## Getting started
14 | Clone and enter the repository
15 | ```bash
16 | git clone https://github.com/jackfsuia/LLM-Data-Cleaner.git && cd LLM-Data-Cleaner
17 | ```
18 | then run the following command inside the repository to start OCR
19 | ```bash
20 | python start_ocr.py --model MODEL --key YOUR_API_KEY --img_path /path/to/images/ --outdir /path/to/output/ --lang language --batchsize batchsize
21 | ```
22 | **MODEL** can be ["qwen" (Tongyi Qianwen)](https://help.aliyun.com/zh/dashscope/developer-reference/activate-dashscope-and-create-an-api-key), ["moonshot" (Moonshot AI)](https://platform.moonshot.cn/console/api-keys), ["paddle" (Baidu PaddleOCR)](https://github.com/PaddlePaddle/PaddleOCR), ["openai"](https://platform.openai.com/docs/models/overview) or ["llava"](https://github.com/haotian-liu/LLaVA). **YOUR_API_KEY** is your API key; if you don't have one, apply via the model links above (paddle and llava don't need a key). **/path/to/images/** is the image directory; every image in it will be OCRed, and the results are saved to data.jsonl under **/path/to/output/**. **language** is the language to recognize and can be ch (Chinese), en (English), fr (French), german (German), korean (Korean), or japan (Japanese); only Baidu PaddleOCR may use it. **batchsize** is the batch size and also the number of threads; the larger the better, as far as your compute allows, and it defaults to the dataset size.
23 | ## Examples
24 | If you want to do OCR with Tongyi Qianwen's qwen-vl-plus model, your API key is sbadgassjda, the images are in /images/, and you want the output data.jsonl in /images/ too, then whatever the language to recognize is, run
25 | ```bash
26 | python start_ocr.py --model qwen-vl-plus --key sbadgassjda --img_path /images/ --outdir /images/
27 | ```
28 | If you want to do OCR with Baidu PaddleOCR, the images are in /images/, you want the output data.jsonl in /images/, and the language is Chinese, run
29 | ```bash
30 | python start_ocr.py --model paddle --img_path /images/ --outdir /images/ --lang ch
31 | ```
32 | If you use `llava`, run
33 | ```bash
34 | python start_ocr.py --model LLAVA_PATH --img_path /images/ --outdir /images/
35 | ```
36 | `LLAVA_PATH` is the path to your llava model (a HuggingFace-style model path).
37 | ## Appendix
38 | The OCR prompt is stored in [ocr.py](models/ocr.py).
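39 | If you want to use a different prompt without editing [ocr.py](models/ocr.py), one option is to override the `prompt` attribute that every backend inherits from `base_ocr`. A minimal sketch, run from the repository root (the replacement prompt text here is only an example):
40 | ```python
41 | from models.qwen_ocr import qwen_ocr  # any LLM backend that sends base_ocr.prompt
42 | 
43 | ocr = qwen_ocr("qwen-vl-plus", "YOUR_API_KEY")
44 | # The default prompt (in Chinese) asks for a verbatim extraction of text and formulas.
45 | ocr.prompt = "Extract every piece of text and every formula verbatim, without translation. Reply with the result only."
46 | print(ocr.ocr_image("/path/to/page.png"))
47 | ```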
48 | ## License
49 | 
50 | The project's license is [LICENSE](LICENSE).
51 | 
--------------------------------------------------------------------------------
/README_en.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # LLM-Data-Cleaner
6 | English | [简体中文](README.md)
7 | ## Updates
8 | - It looks like people have been training LLMs specifically for these cleaning jobs; see [refuel-llm-2](https://www.refuel.ai/blog-posts/announcing-refuel-llm-2).
9 | - As for OCR, you can use [Nougat](https://github.com/facebookresearch/nougat), [Marker](https://github.com/VikParuchuri/marker), or multimodal models like [MiniCPM-V](https://github.com/OpenBMB/MiniCPM-V); they work pretty well.
10 | ## Background
11 | In the future, humans will use LLMs to preprocess all data. This project assembles LLMs and older tools to generate or clean data for academic use. For now it supports OCR, using models such as PaddleOCR, OpenAI, Llava, qwen, and moonshot.
12 | ## Start
13 | Clone and enter the repo
14 | ```bash
15 | git clone https://github.com/jackfsuia/LLM-Data-Cleaner.git && cd LLM-Data-Cleaner
16 | ```
17 | then to start OCR, run
18 | ```bash
19 | python start_ocr.py --model MODEL --key YOUR_API_KEY --img_path /path/to/images/ --outdir /path/to/output/ --lang language --batchsize batchsize
20 | ```
21 | **MODEL** can be [qwen](https://help.aliyun.com/zh/dashscope/developer-reference/activate-dashscope-and-create-an-api-key), [moonshot](https://platform.moonshot.cn/console/api-keys), [paddle](https://github.com/PaddlePaddle/PaddleOCR), [openai](https://platform.openai.com/docs/models/overview) or [llava](https://github.com/haotian-liu/LLaVA). **YOUR_API_KEY** is your API key; it is not needed for paddle and llava. **/path/to/images/** is the image folder: every image under that path will be OCRed, and the results are saved to data.jsonl under **/path/to/output/**. **language** can be ch (Chinese), en (English), fr (French), german (German), korean (Korean), or japan (Japanese); it is only used by paddle. **batchsize** is the batch size and also the number of threads used to process the images; it defaults to the size of the target dataset.
22 | ## Examples
23 | If you use `gpt-4-turbo` for OCR, your API key is `sbadgassjda`, the images are in `/images/`, and you want the output `data.jsonl` in `/images/` too, then whatever the language is, run
24 | ```bash
25 | python start_ocr.py --model gpt-4-turbo --key sbadgassjda --img_path /images/ --outdir /images/
26 | ```
27 | If you use `PaddleOCR`, the images are in `/images/`, you want the output `data.jsonl` in `/images/` too, and the OCR target language is English, run
28 | ```bash
29 | python start_ocr.py --model paddle --img_path /images/ --outdir /images/ --lang en
30 | ```
31 | If you use `llava`, run
32 | ```bash
33 | python start_ocr.py --model LLAVA_PATH --img_path /images/ --outdir /images/
34 | ```
35 | `LLAVA_PATH` is your HuggingFace-style llava model path.
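36 | Every run writes one JSON object per image to `data.jsonl` in the output directory, with the fields `name` and `ocr_result` (see `start_ocr.py`). A minimal sketch for loading the results afterwards, assuming the `/images/` output directory from the examples above:
37 | ```python
38 | import json
39 | 
40 | # data.jsonl: one record per image, e.g. {"name": "page1.png", "ocr_result": "..."}
41 | with open("/images/data.jsonl", encoding="utf-8") as f:
42 |     for line in f:
43 |         record = json.loads(line)
44 |         print(record["name"], record["ocr_result"])
45 | ```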
46 | ## Appendix
47 | The prompt used for OCR is in [ocr.py](models/ocr.py) in case you want to change it.
48 | ## License
49 | 
50 | LLM-Data-Cleaner is licensed under the MIT License found in the [LICENSE](LICENSE) file in the root directory of this repository.
51 | 
--------------------------------------------------------------------------------
/logo/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jackfsuia/LLM-Data-Cleaner/19b8d261f9ed69782cc694aca419a58c96ff207e/logo/logo.png
--------------------------------------------------------------------------------
/models/kimi_ocr.py:
--------------------------------------------------------------------------------
1 | from models.ocr import base_ocr
2 | from pathlib import Path
3 | from openai import OpenAI
4 |
5 |
6 | class kimi_ocr(base_ocr):
7 | def __init__(self, MODEL, KEY):
8 | super().__init__()
9 | if MODEL == "kimi":
10 | MODEL = "moonshot-v1-32k"
11 | self.client = OpenAI(
12 | api_key=KEY,
13 | base_url="https://api.moonshot.cn/v1",
14 | )
15 | self.MODEL=MODEL
16 |
17 | def ocr_image(self, img):
18 |         # xlnet.pdf is an example file; pdf, doc, and image formats are supported, with OCR provided for images and pdf files
19 | file_object = self.client.files.create(file=Path(img), purpose="file-extract")
20 |
21 | file_content = self.client.files.content(file_id=file_object.id).text
22 |
23 |         # Put the extracted file content into the request
24 | messages = [
25 | {
26 | "role": "system",
27 | "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。你会为用户提供安全,有帮助,准确的回答。同时,你会拒绝一切涉及恐怖主义,种族歧视,黄色暴力等问题的回答。Moonshot AI 为专有名词,不可翻译成其他语言。",
28 | },
29 | {
30 | "role": "system",
31 | "content": file_content,
32 | },
33 | {"role": "user", "content": self.prompt},
34 | ]
35 |
36 |         # Then call chat completion to get Kimi's answer
37 | completion = self.client.chat.completions.create(
38 | model= self.MODEL,
39 | messages=messages,
40 | temperature=0.3,
41 | )
42 |
43 | return completion.choices[0].message.content
--------------------------------------------------------------------------------
/models/llava_ocr.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from models.ocr import base_ocr
4 | from PIL import Image
5 |
6 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
7 | from utils.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
8 |
9 | # Much of this code is copied from https://github.com/haotian-liu/LLaVA/blob/3e337ad269da3245643a2724a1d694b5839c37f9/llava/conversation.py
10 | import dataclasses
11 | from enum import auto, Enum
12 | from typing import List
13 | import base64
14 | from io import BytesIO
15 | from PIL import Image
16 |
17 |
18 | class SeparatorStyle(Enum):
19 | """Different separator style."""
20 | SINGLE = auto()
21 | TWO = auto()
22 | MPT = auto()
23 | PLAIN = auto()
24 | LLAMA_2 = auto()
25 |
26 |
27 | @dataclasses.dataclass
28 | class Conversation:
29 | """A class that keeps all conversation history."""
30 | system: str
31 | roles: List[str]
32 | messages: List[List[str]]
33 | offset: int
34 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
35 | sep: str = "###"
36 | sep2: str = None
37 | version: str = "Unknown"
38 |
39 | skip_next: bool = False
40 |
41 | def get_prompt(self):
42 | messages = self.messages
43 | if len(messages) > 0 and type(messages[0][1]) is tuple:
44 | messages = self.messages.copy()
45 | init_role, init_msg = messages[0].copy()
46 |             init_msg = init_msg[0].replace("<image>", "").strip()
47 | if 'mmtag' in self.version:
48 | messages[0] = (init_role, init_msg)
49 |                 messages.insert(0, (self.roles[0], "<Image><image></Image>"))
50 | messages.insert(1, (self.roles[1], "Received."))
51 | else:
52 |                 messages[0] = (init_role, "<image>\n" + init_msg)
53 |
54 | if self.sep_style == SeparatorStyle.SINGLE:
55 | ret = self.system + self.sep
56 | for role, message in messages:
57 | if message:
58 | if type(message) is tuple:
59 | message, _, _ = message
60 | ret += role + ": " + message + self.sep
61 | else:
62 | ret += role + ":"
63 | elif self.sep_style == SeparatorStyle.TWO:
64 | seps = [self.sep, self.sep2]
65 | ret = self.system + seps[0]
66 | for i, (role, message) in enumerate(messages):
67 | if message:
68 | if type(message) is tuple:
69 | message, _, _ = message
70 | ret += role + ": " + message + seps[i % 2]
71 | else:
72 | ret += role + ":"
73 | elif self.sep_style == SeparatorStyle.MPT:
74 | ret = self.system + self.sep
75 | for role, message in messages:
76 | if message:
77 | if type(message) is tuple:
78 | message, _, _ = message
79 | ret += role + message + self.sep
80 | else:
81 | ret += role
82 | elif self.sep_style == SeparatorStyle.LLAMA_2:
83 |             wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
84 | wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
85 | ret = ""
86 |
87 | for i, (role, message) in enumerate(messages):
88 | if i == 0:
89 | assert message, "first message should not be none"
90 | assert role == self.roles[0], "first message should come from user"
91 | if message:
92 | if type(message) is tuple:
93 | message, _, _ = message
94 | if i == 0: message = wrap_sys(self.system) + message
95 | if i % 2 == 0:
96 | message = wrap_inst(message)
97 | ret += self.sep + message
98 | else:
99 | ret += " " + message + " " + self.sep2
100 | else:
101 | ret += ""
102 | ret = ret.lstrip(self.sep)
103 | elif self.sep_style == SeparatorStyle.PLAIN:
104 | seps = [self.sep, self.sep2]
105 | ret = self.system
106 | for i, (role, message) in enumerate(messages):
107 | if message:
108 | if type(message) is tuple:
109 | message, _, _ = message
110 | ret += message + seps[i % 2]
111 | else:
112 | ret += ""
113 | else:
114 | raise ValueError(f"Invalid style: {self.sep_style}")
115 |
116 | return ret
117 |
118 | def append_message(self, role, message):
119 | self.messages.append([role, message])
120 |
121 | def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
122 | if image_process_mode == "Pad":
123 | def expand2square(pil_img, background_color=(122, 116, 104)):
124 | width, height = pil_img.size
125 | if width == height:
126 | return pil_img
127 | elif width > height:
128 | result = Image.new(pil_img.mode, (width, width), background_color)
129 | result.paste(pil_img, (0, (width - height) // 2))
130 | return result
131 | else:
132 | result = Image.new(pil_img.mode, (height, height), background_color)
133 | result.paste(pil_img, ((height - width) // 2, 0))
134 | return result
135 | image = expand2square(image)
136 | elif image_process_mode in ["Default", "Crop"]:
137 | pass
138 | elif image_process_mode == "Resize":
139 | image = image.resize((336, 336))
140 | else:
141 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
142 | if max(image.size) > max_len:
143 | max_hw, min_hw = max(image.size), min(image.size)
144 | aspect_ratio = max_hw / min_hw
145 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
146 | longest_edge = int(shortest_edge * aspect_ratio)
147 | W, H = image.size
148 | if H > W:
149 | H, W = longest_edge, shortest_edge
150 | else:
151 | H, W = shortest_edge, longest_edge
152 | image = image.resize((W, H))
153 | if return_pil:
154 | return image
155 | else:
156 | buffered = BytesIO()
157 | image.save(buffered, format=image_format)
158 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
159 | return img_b64_str
160 |
161 | def get_images(self, return_pil=False):
162 | images = []
163 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
164 | if i % 2 == 0:
165 | if type(msg) is tuple:
166 | msg, image, image_process_mode = msg
167 | image = self.process_image(image, image_process_mode, return_pil=return_pil)
168 | images.append(image)
169 | return images
170 |
171 |
172 | def copy(self):
173 | return Conversation(
174 | system=self.system,
175 | roles=self.roles,
176 | messages=[[x, y] for x, y in self.messages],
177 | offset=self.offset,
178 | sep_style=self.sep_style,
179 | sep=self.sep,
180 | sep2=self.sep2,
181 | version=self.version)
182 |
183 | def dict(self):
184 | if len(self.get_images()) > 0:
185 | return {
186 | "system": self.system,
187 | "roles": self.roles,
188 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
189 | "offset": self.offset,
190 | "sep": self.sep,
191 | "sep2": self.sep2,
192 | }
193 | return {
194 | "system": self.system,
195 | "roles": self.roles,
196 | "messages": self.messages,
197 | "offset": self.offset,
198 | "sep": self.sep,
199 | "sep2": self.sep2,
200 | }
201 |
202 |
203 | conv_vicuna_v0 = Conversation(
204 | system="A chat between a curious human and an artificial intelligence assistant. "
205 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
206 | roles=("Human", "Assistant"),
207 | messages=(
208 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
209 | ("Assistant",
210 | "Renewable energy sources are those that can be replenished naturally in a relatively "
211 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
212 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
213 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
214 | "renewable and non-renewable energy sources:\n"
215 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
216 | "energy sources are finite and will eventually run out.\n"
217 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
218 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
219 | "and other negative effects.\n"
220 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
221 | "have lower operational costs than non-renewable sources.\n"
222 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
223 | "locations than non-renewable sources.\n"
224 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
225 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
226 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
227 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
228 | ),
229 | offset=2,
230 | sep_style=SeparatorStyle.SINGLE,
231 | sep="###",
232 | )
233 |
234 | conv_vicuna_v1 = Conversation(
235 | system="A chat between a curious user and an artificial intelligence assistant. "
236 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
237 | roles=("USER", "ASSISTANT"),
238 | version="v1",
239 | messages=(),
240 | offset=0,
241 | sep_style=SeparatorStyle.TWO,
242 | sep=" ",
243 |     sep2="</s>",
244 | )
245 |
246 | conv_llama_2 = Conversation(
247 | system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
248 |
249 | If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
250 | roles=("USER", "ASSISTANT"),
251 | version="llama_v2",
252 | messages=(),
253 | offset=0,
254 | sep_style=SeparatorStyle.LLAMA_2,
255 |     sep="<s>",
256 |     sep2="</s>",
257 | )
258 |
259 | conv_llava_llama_2 = Conversation(
260 | system="You are a helpful language and vision assistant. "
261 | "You are able to understand the visual content that the user provides, "
262 | "and assist the user with a variety of tasks using natural language.",
263 | roles=("USER", "ASSISTANT"),
264 | version="llama_v2",
265 | messages=(),
266 | offset=0,
267 | sep_style=SeparatorStyle.LLAMA_2,
268 |     sep="<s>",
269 |     sep2="</s>",
270 | )
271 |
272 | conv_mpt = Conversation(
273 | system="""<|im_start|>system
274 | A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
275 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
276 | version="mpt",
277 | messages=(),
278 | offset=0,
279 | sep_style=SeparatorStyle.MPT,
280 | sep="<|im_end|>",
281 | )
282 |
283 | conv_llava_plain = Conversation(
284 | system="",
285 | roles=("", ""),
286 | messages=(
287 | ),
288 | offset=0,
289 | sep_style=SeparatorStyle.PLAIN,
290 | sep="\n",
291 | )
292 |
293 | conv_llava_v0 = Conversation(
294 | system="A chat between a curious human and an artificial intelligence assistant. "
295 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
296 | roles=("Human", "Assistant"),
297 | messages=(
298 | ),
299 | offset=0,
300 | sep_style=SeparatorStyle.SINGLE,
301 | sep="###",
302 | )
303 |
304 | conv_llava_v0_mmtag = Conversation(
305 | system="A chat between a curious user and an artificial intelligence assistant. "
306 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
307 |            "The visual content will be provided with the following format: <Image>visual content</Image>.",
308 | roles=("Human", "Assistant"),
309 | messages=(
310 | ),
311 | offset=0,
312 | sep_style=SeparatorStyle.SINGLE,
313 | sep="###",
314 | version="v0_mmtag",
315 | )
316 |
317 | conv_llava_v1 = Conversation(
318 | system="A chat between a curious human and an artificial intelligence assistant. "
319 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
320 | roles=("USER", "ASSISTANT"),
321 | version="v1",
322 | messages=(),
323 | offset=0,
324 | sep_style=SeparatorStyle.TWO,
325 | sep=" ",
326 |     sep2="</s>",
327 | )
328 |
329 | conv_llava_v1_mmtag = Conversation(
330 | system="A chat between a curious user and an artificial intelligence assistant. "
331 | "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
332 |            "The visual content will be provided with the following format: <Image>visual content</Image>.",
333 | roles=("USER", "ASSISTANT"),
334 | messages=(),
335 | offset=0,
336 | sep_style=SeparatorStyle.TWO,
337 | sep=" ",
338 |     sep2="</s>",
339 | version="v1_mmtag",
340 | )
341 |
342 | conv_mistral_instruct = Conversation(
343 | system="",
344 | roles=("USER", "ASSISTANT"),
345 | version="llama_v2",
346 | messages=(),
347 | offset=0,
348 | sep_style=SeparatorStyle.LLAMA_2,
349 | sep="",
350 |     sep2="</s>",
351 | )
352 |
353 | conv_chatml_direct = Conversation(
354 | system="""<|im_start|>system
355 | Answer the questions.""",
356 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
357 | version="mpt",
358 | messages=(),
359 | offset=0,
360 | sep_style=SeparatorStyle.MPT,
361 | sep="<|im_end|>",
362 | )
363 |
364 | default_conversation = conv_vicuna_v1
365 | conv_templates = {
366 | "default": conv_vicuna_v0,
367 | "v0": conv_vicuna_v0,
368 | "v1": conv_vicuna_v1,
369 | "vicuna_v1": conv_vicuna_v1,
370 | "llama_2": conv_llama_2,
371 | "mistral_instruct": conv_mistral_instruct,
372 | "chatml_direct": conv_chatml_direct,
373 | "mistral_direct": conv_chatml_direct,
374 |
375 | "plain": conv_llava_plain,
376 | "v0_plain": conv_llava_plain,
377 | "llava_v0": conv_llava_v0,
378 | "v0_mmtag": conv_llava_v0_mmtag,
379 | "llava_v1": conv_llava_v1,
380 | "v1_mmtag": conv_llava_v1_mmtag,
381 | "llava_llama_2": conv_llava_llama_2,
382 |
383 | "mpt": conv_mpt,
384 | }
385 |
386 | CONV_MODE = "llava_v1"
387 |
388 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
389 | WORKER_HEART_BEAT_INTERVAL = 15
390 |
391 | LOGDIR = "."
392 |
393 | # Model Constants
394 | IGNORE_INDEX = -100
395 | IMAGE_TOKEN_INDEX = -200
396 | DEFAULT_IMAGE_TOKEN = "<image>"
397 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
398 | DEFAULT_IM_START_TOKEN = "<im_start>"
399 | DEFAULT_IM_END_TOKEN = "<im_end>"
400 | IMAGE_PLACEHOLDER = "<image-placeholder>"
401 |
402 | class llava_ocr(base_ocr):
403 |
404 | def __init__(self, MODEL):
405 | super().__init__()
406 |
407 | self.model=AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype="auto", device_map="auto", trust_remote_code=True)
408 | self.tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
409 |
410 | def ocr_image(self, image_path)->str:
411 |
412 | image = Image.open(image_path)
413 | vision_tower = self.model.get_vision_tower()
414 | if not vision_tower.is_loaded:
415 | vision_tower.load_model(device_map="auto")
416 | # if device_map != 'auto':
417 | # vision_tower.to(device=device_map, dtype=torch.float16)
418 | image_processor = vision_tower.image_processor
419 | image_tensor = process_images([image], image_processor, self.model.config)[0]
420 | images = image_tensor.unsqueeze(0).half().cuda()
421 | image_sizes = [image.size]
422 | qs=self.prompt
423 | if getattr(self.model.config, 'mm_use_im_start_end', False):
424 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
425 | else:
426 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
427 |
428 | conv = conv_templates[CONV_MODE].copy()
429 | conv.append_message(conv.roles[0], qs)
430 | conv.append_message(conv.roles[1], None)
431 | prompt = conv.get_prompt()
432 |
433 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
434 |
435 | output_ids = self.model.generate(
436 | input_ids,
437 | images=images,
438 | image_sizes=image_sizes,
439 | max_new_tokens=1024,
440 | use_cache=True,
441 | )
442 |
443 | outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
444 |
445 | return outputs
446 |
447 |
448 |
449 |
--------------------------------------------------------------------------------
/models/ocr.py:
--------------------------------------------------------------------------------
1 |
2 | class base_ocr:
3 |
4 |     def __init__(self) -> None:
5 |         # Default prompt (in Chinese): "Extract the text and every formula verbatim,
6 |         # without translation. Reply directly with the result, no extra explanation."
7 |         self.prompt = "把里面的文字、每一条公式都原封不动、不带翻译地提取出来。你的回答应该直接是结果,不用你其他多余的说明。"
8 | 
9 |     def ocr_image(self, img):
10 |         pass
11 | 
12 |     def closuer_ocr_image(self):
13 |         # Return a plain function wrapping ocr_image, suitable for ThreadPoolExecutor.map.
14 |         def ocr_image(img):
15 |             return self.ocr_image(img)
16 | 
17 |         return ocr_image
18 | 
19 | 
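20 | # Adding a new backend (sketch): subclass base_ocr, implement ocr_image, and register it
21 | # in start_ocr.py next to the existing backends, e.g.:
22 | #
23 | #   class my_ocr(base_ocr):
24 | #       def ocr_image(self, img):
25 | #           return run_my_engine(img, self.prompt)  # run_my_engine is a hypothetical call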
--------------------------------------------------------------------------------
/models/openai_ocr.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import requests
3 | from models.ocr import base_ocr
4 |
5 |
6 |
7 | class openai_ocr(base_ocr):
8 | def __init__(self, MODEL, KEY):
9 | super().__init__()
10 | if MODEL == 'openai':
11 | MODEL='gpt-4-turbo'
12 | self.MODEL=MODEL
13 | self.api_key = KEY
14 |
15 | def ocr_image(self, image_path):
16 |
17 | # Function to encode the image
18 | def encode_image(image_path):
19 | with open(image_path, "rb") as image_file:
20 | return base64.b64encode(image_file.read()).decode('utf-8')
21 |
22 | # Path to your image
23 | # Getting the base64 string
24 | base64_image = encode_image(image_path)
25 |
26 | headers = {
27 | "Content-Type": "application/json",
28 | "Authorization": f"Bearer {self.api_key}"
29 | }
30 |
31 | payload = {
32 | "model": self.MODEL,
33 | "messages": [
34 | {
35 | "role": "user",
36 | "content": [
37 | {
38 | "type": "text",
39 | "text": self.prompt
40 | },
41 | {
42 | "type": "image_url",
43 | "image_url": {
44 | "url": f"data:image/jpeg;base64,{base64_image}"
45 | }
46 | }
47 | ]
48 | }
49 | ],
50 | "max_tokens": 300
51 | }
52 |
53 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
54 | # to be done...
55 | return str(response.json())
56 |
57 |
58 |
59 |
--------------------------------------------------------------------------------
/models/paddle_ocr.py:
--------------------------------------------------------------------------------
1 | from paddleocr import PaddleOCR, draw_ocr
2 | from models.ocr import base_ocr
3 |
4 | class paddle_ocr(base_ocr):
5 | def __init__(self, lang):
6 | super().__init__()
7 | self.lang=lang
8 | def ocr_image(self, img_path):
9 |         # The supported languages can be switched via the `lang` parameter,
10 |         # e.g. `ch`, `en`, `fr`, `german`, `korean`, `japan`
11 |         text = ""
12 |         pocr = PaddleOCR(use_angle_cls=True, lang=self.lang)  # downloads and loads the model into memory on first use
13 | result = pocr.ocr(img_path, cls=True)
14 | for idx in range(len(result)):
15 | res = result[idx]
16 | for line in res:
17 | text=text+line[1][0]+'\n'
18 | return text
19 |
--------------------------------------------------------------------------------
/models/qwen_ocr.py:
--------------------------------------------------------------------------------
1 | from http import HTTPStatus
2 | import dashscope
3 | from models.ocr import base_ocr
4 |
5 | class qwen_ocr(base_ocr):
6 | def __init__(self, MODEL, KEY):
7 | super().__init__()
8 | if MODEL == 'qwen':
9 | MODEL='qwen-vl-plus'
10 | self.MODEL=MODEL
11 | dashscope.api_key = KEY
12 |
13 | def ocr_image(self, img):
14 |
15 | messages = [
16 | {
17 | "role": "user",
18 | "content": [
19 | {"image": f"file://{img}"},
20 | {"text": self.prompt}
21 | ]
22 | }
23 | ]
24 | response = dashscope.MultiModalConversation.call(model=self.MODEL,messages=messages)
25 |
26 | if response.status_code == HTTPStatus.OK:
27 | return response.output.choices[0].message.content
28 | else:
29 | print(response.code) # The error code.
30 | print(response.message) # The error message.
31 | return None
32 |
33 |
34 |
35 |
--------------------------------------------------------------------------------
/start_ocr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from concurrent.futures import ThreadPoolExecutor
5 |
6 | def OCR_the_dataset(ocr_img, folder_path, outdir, batchsize = None):
7 | current_dir_path = os.path.dirname(os.path.abspath(__file__))
8 | if outdir is None:
9 | outdir=current_dir_path
10 | if folder_path is None:
11 | folder_path=current_dir_path
12 |
13 |     def ocrf(filename):
14 |         # Only OCR common image extensions; everything else yields an empty result.
15 |         if filename.endswith((".jpg", ".jpeg", ".png", ".tif")):
16 |             img_path = os.path.join(folder_path, filename)
17 |             ocr_result = ocr_img(img_path)
18 |             if ocr_result:
19 |                 print('-->one ocr success')
20 |                 return ocr_result
21 |             return ""
22 |         print('-->not an image')
23 |         return ""
24 |
25 | files=os.listdir(folder_path)
26 |
27 | if batchsize is None:
28 | batchsize = len(files)
29 | with ThreadPoolExecutor(max_workers=batchsize) as executor:
30 |
31 | ocr_results = list(executor.map(ocrf, files))
32 |
33 | save_file_name = os.path.join(outdir, "data.jsonl")
34 |     with open(save_file_name, 'w', encoding='utf-8') as jsonl_file:
35 | for i, ocr_result in enumerate(ocr_results):
36 | item={}
37 | item["name"]=files[i]
38 | item["ocr_result"]= ocr_result
39 |             jsonl_file.write(json.dumps(item, ensure_ascii=False) + '\n')
40 |
41 | print(f"data has been written to {save_file_name}")
42 |
43 |
44 | if __name__ == '__main__':
45 |
46 | parser = argparse.ArgumentParser(description="ocr")
47 |
48 | parser.add_argument('--model', type=str, help="model name")
49 | parser.add_argument('--key', type=str, help="api key")
50 | parser.add_argument('--img_path', type=str, help="images folder path")
51 | parser.add_argument('--outdir', type=str, help="output dir")
52 | parser.add_argument('--lang', default='en',type=str, help="language:`ch`, `en`, `fr`, `german`, `korean`, `japan`")
53 | parser.add_argument('--batchsize', default=None,type=int, help="batchsize")
54 | args = parser.parse_args()
55 |
56 |
57 | model = args.model
58 | key= args.key
59 | img_path=args.img_path
60 | outdir=args.outdir
61 | lang=args.lang
62 | if model.startswith("qwen"):
63 | from models import qwen_ocr
64 | ocr_img = qwen_ocr.qwen_ocr(model, key).closuer_ocr_image()
65 | elif model.startswith("kimi") or model.startswith("moonshot"):
66 | from models import kimi_ocr
67 | ocr_img = kimi_ocr.kimi_ocr(model, key).closuer_ocr_image()
68 | elif model=='paddle':
69 | from models import paddle_ocr
70 | ocr_img = paddle_ocr.paddle_ocr(lang).closuer_ocr_image()
71 | elif model == 'openai' or model.startswith("gpt"):
72 | from models import openai_ocr
73 | ocr_img = openai_ocr.openai_ocr(model, key).closuer_ocr_image()
74 | elif "llava" in model.lower():
75 | from models import llava_ocr
76 | ocr_img = llava_ocr.llava_ocr(model).closuer_ocr_image()
77 | else:
78 | raise Exception("This model has not been supported.")
79 |
80 | OCR_the_dataset(ocr_img, img_path, outdir, args.batchsize)
81 |
82 |
83 |
84 |
85 |
86 |
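87 | # Programmatic use (sketch): OCR_the_dataset accepts any callable mapping an image path
88 | # to text, so a backend can also be driven from another script, e.g.:
89 | #
90 | #   from start_ocr import OCR_the_dataset
91 | #   from models.paddle_ocr import paddle_ocr
92 | #   OCR_the_dataset(paddle_ocr("en").closuer_ocr_image(), "/images/", "/images/", batchsize=8)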
--------------------------------------------------------------------------------
/utils/mm_utils.py:
--------------------------------------------------------------------------------
1 | # This file is from https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
2 |
3 | from PIL import Image
4 | from io import BytesIO
5 | import base64
6 | import torch
7 | import math
8 | import ast
9 |
10 | from transformers import StoppingCriteria
11 |
12 | IMAGE_TOKEN_INDEX=-200
13 | def select_best_resolution(original_size, possible_resolutions):
14 | """
15 | Selects the best resolution from a list of possible resolutions based on the original size.
16 |
17 | Args:
18 | original_size (tuple): The original size of the image in the format (width, height).
19 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
20 |
21 | Returns:
22 | tuple: The best fit resolution in the format (width, height).
23 | """
24 | original_width, original_height = original_size
25 | best_fit = None
26 | max_effective_resolution = 0
27 | min_wasted_resolution = float('inf')
28 |
29 | for width, height in possible_resolutions:
30 | scale = min(width / original_width, height / original_height)
31 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
32 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
33 | wasted_resolution = (width * height) - effective_resolution
34 |
35 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
36 | max_effective_resolution = effective_resolution
37 | min_wasted_resolution = wasted_resolution
38 | best_fit = (width, height)
39 |
40 | return best_fit
41 |
42 |
43 | def resize_and_pad_image(image, target_resolution):
44 | """
45 | Resize and pad an image to a target resolution while maintaining aspect ratio.
46 |
47 | Args:
48 | image (PIL.Image.Image): The input image.
49 | target_resolution (tuple): The target resolution (width, height) of the image.
50 |
51 | Returns:
52 | PIL.Image.Image: The resized and padded image.
53 | """
54 | original_width, original_height = image.size
55 | target_width, target_height = target_resolution
56 |
57 | scale_w = target_width / original_width
58 | scale_h = target_height / original_height
59 |
60 | if scale_w < scale_h:
61 | new_width = target_width
62 | new_height = min(math.ceil(original_height * scale_w), target_height)
63 | else:
64 | new_height = target_height
65 | new_width = min(math.ceil(original_width * scale_h), target_width)
66 |
67 | # Resize the image
68 | resized_image = image.resize((new_width, new_height))
69 |
70 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
71 | paste_x = (target_width - new_width) // 2
72 | paste_y = (target_height - new_height) // 2
73 | new_image.paste(resized_image, (paste_x, paste_y))
74 |
75 | return new_image
76 |
77 |
78 | def divide_to_patches(image, patch_size):
79 | """
80 | Divides an image into patches of a specified size.
81 |
82 | Args:
83 | image (PIL.Image.Image): The input image.
84 | patch_size (int): The size of each patch.
85 |
86 | Returns:
87 | list: A list of PIL.Image.Image objects representing the patches.
88 | """
89 | patches = []
90 | width, height = image.size
91 | for i in range(0, height, patch_size):
92 | for j in range(0, width, patch_size):
93 | box = (j, i, j + patch_size, i + patch_size)
94 | patch = image.crop(box)
95 | patches.append(patch)
96 |
97 | return patches
98 |
99 |
100 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
101 | """
102 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
103 |
104 | Args:
105 | image_size (tuple): The size of the input image in the format (width, height).
106 | grid_pinpoints (str): A string representation of a list of possible resolutions.
107 | patch_size (int): The size of each image patch.
108 |
109 | Returns:
110 | tuple: The shape of the image patch grid in the format (width, height).
111 | """
112 | if type(grid_pinpoints) is list:
113 | possible_resolutions = grid_pinpoints
114 | else:
115 | possible_resolutions = ast.literal_eval(grid_pinpoints)
116 | width, height = select_best_resolution(image_size, possible_resolutions)
117 | return width // patch_size, height // patch_size
118 |
119 |
120 | def process_anyres_image(image, processor, grid_pinpoints):
121 | """
122 | Process an image with variable resolutions.
123 |
124 | Args:
125 | image (PIL.Image.Image): The input image to be processed.
126 | processor: The image processor object.
127 | grid_pinpoints (str): A string representation of a list of possible resolutions.
128 |
129 | Returns:
130 | torch.Tensor: A tensor containing the processed image patches.
131 | """
132 | if type(grid_pinpoints) is list:
133 | possible_resolutions = grid_pinpoints
134 | else:
135 | possible_resolutions = ast.literal_eval(grid_pinpoints)
136 | best_resolution = select_best_resolution(image.size, possible_resolutions)
137 | image_padded = resize_and_pad_image(image, best_resolution)
138 |
139 | patches = divide_to_patches(image_padded, processor.crop_size['height'])
140 |
141 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
142 |
143 | image_patches = [image_original_resize] + patches
144 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
145 | for image_patch in image_patches]
146 | return torch.stack(image_patches, dim=0)
147 |
148 |
149 | def load_image_from_base64(image):
150 | return Image.open(BytesIO(base64.b64decode(image)))
151 |
152 |
153 | def expand2square(pil_img, background_color):
154 | width, height = pil_img.size
155 | if width == height:
156 | return pil_img
157 | elif width > height:
158 | result = Image.new(pil_img.mode, (width, width), background_color)
159 | result.paste(pil_img, (0, (width - height) // 2))
160 | return result
161 | else:
162 | result = Image.new(pil_img.mode, (height, height), background_color)
163 | result.paste(pil_img, ((height - width) // 2, 0))
164 | return result
165 |
166 |
167 | def process_images(images, image_processor, model_cfg):
168 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
169 | new_images = []
170 | if image_aspect_ratio == 'pad':
171 | for image in images:
172 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
173 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
174 | new_images.append(image)
175 | elif image_aspect_ratio == "anyres":
176 | for image in images:
177 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
178 | new_images.append(image)
179 | else:
180 | return image_processor(images, return_tensors='pt')['pixel_values']
181 | if all(x.shape == new_images[0].shape for x in new_images):
182 | new_images = torch.stack(new_images, dim=0)
183 | return new_images
184 |
185 |
186 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
187 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
188 |
189 | def insert_separator(X, sep):
190 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
191 |
192 | input_ids = []
193 | offset = 0
194 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
195 | offset = 1
196 | input_ids.append(prompt_chunks[0][0])
197 |
198 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
199 | input_ids.extend(x[offset:])
200 |
201 | if return_tensors is not None:
202 | if return_tensors == 'pt':
203 | return torch.tensor(input_ids, dtype=torch.long)
204 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
205 | return input_ids
206 |
207 |
208 | def get_model_name_from_path(model_path):
209 | model_path = model_path.strip("/")
210 | model_paths = model_path.split("/")
211 | if model_paths[-1].startswith('checkpoint-'):
212 | return model_paths[-2] + "_" + model_paths[-1]
213 | else:
214 | return model_paths[-1]
215 |
216 | class KeywordsStoppingCriteria(StoppingCriteria):
217 | def __init__(self, keywords, tokenizer, input_ids):
218 | self.keywords = keywords
219 | self.keyword_ids = []
220 | self.max_keyword_len = 0
221 | for keyword in keywords:
222 | cur_keyword_ids = tokenizer(keyword).input_ids
223 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
224 | cur_keyword_ids = cur_keyword_ids[1:]
225 | if len(cur_keyword_ids) > self.max_keyword_len:
226 | self.max_keyword_len = len(cur_keyword_ids)
227 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
228 | self.tokenizer = tokenizer
229 | self.start_len = input_ids.shape[1]
230 |
231 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
232 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
233 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
234 | for keyword_id in self.keyword_ids:
235 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
236 | if torch.equal(truncated_output_ids, keyword_id):
237 | return True
238 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
239 | for keyword in self.keywords:
240 | if keyword in outputs:
241 | return True
242 | return False
243 |
244 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
245 | outputs = []
246 | for i in range(output_ids.shape[0]):
247 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
248 | return all(outputs)
--------------------------------------------------------------------------------
/utils/read_pdf_to_text.py:
--------------------------------------------------------------------------------
1 | import pdfplumber
2 |
3 | import os
4 | def extract_txt_from_pdf(fn, tgt_path):
5 | """
6 | Extract text from a pdf file and save to target path.
7 |
8 | :param fn: path to input pdf file
9 | :param tgt_path: path to save text file.
10 | """
11 | with pdfplumber.open(fn) as pdf:
12 | text = []
13 | for page in pdf.pages:
14 | # remove tables from each page extracted by pdfplumber
15 | tables = page.find_tables()
16 | for table in tables:
17 | page = page.outside_bbox(table.bbox)
18 | # remove page number from the end of each page
19 | page_text = page.extract_text()
20 | page_num = str(page.page_number)
21 | if page_text.rstrip().endswith(page_num):
22 | page_text = page_text.rstrip()[:-len(page_num)]
23 | if page_text.strip():
24 | text.append(page_text)
25 | base_fn = os.path.basename(fn).lower().replace('.pdf', '.txt')
26 | with open(os.path.join(tgt_path, base_fn), 'w', encoding='utf-8') as f:
27 | f.write('\n'.join(text))
28 |
29 | if __name__ == "__main__":
30 |     extract_txt_from_pdf("C:\\Users\\Administrator\\Desktop\\CVX.pdf", "C:\\Users\\Administrator\\Desktop")
--------------------------------------------------------------------------------