├── .env.example ├── .github └── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ ├── feature_request.yml │ └── question.yml ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── app ├── api │ ├── md.py │ └── uvdoc.py ├── banner.txt ├── config.py └── utils │ ├── __init__.py │ ├── file_utils.py │ ├── image_processing.py │ ├── model.py │ └── pdf_utils.py ├── main.py ├── model └── best_model.pkl └── requirements.txt /.env.example: -------------------------------------------------------------------------------- 1 | # API配置 2 | API_KEY=your_api_key_here 3 | BASE_URL=https://open.bigmodel.cn/api/paas/v4 4 | MODEL=glm-4v-flash 5 | PROMPT=提取图片中全部的文本,不需要任何推理和总结,只需要原文 6 | 7 | # 文件处理配置 8 | FILE_DELETE_DELAY=300 # 5分钟后删除临时文件 9 | 10 | # PDF处理配置 11 | PDF_CONCURRENT_LIMIT=5 12 | PDF_BATCH_SIZE=10 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | name: Bug 报告 2 | description: 创建 bug 报告以帮助我们改进 3 | labels: ["bug"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | 感谢您花时间填写这份 bug 报告! 9 | 10 | - type: textarea 11 | id: bug-description 12 | attributes: 13 | label: 描述这个 bug 14 | description: 请清晰简洁地描述这个 bug 是什么 15 | placeholder: 当我...的时候,出现了...的情况 16 | validations: 17 | required: true 18 | 19 | - type: textarea 20 | id: reproduction 21 | attributes: 22 | label: 重现步骤 23 | description: 请描述重现这个 bug 的步骤 24 | placeholder: | 25 | 1. 进入 '...' 26 | 2. 点击 '....' 27 | 3. 滚动到 '....' 28 | 4. 
看到错误 29 | validations: 30 | required: true 31 | 32 | - type: textarea 33 | id: expected 34 | attributes: 35 | label: 期望行为 36 | description: 请清晰简洁地描述您期望发生的事情 37 | validations: 38 | required: true 39 | 40 | - type: textarea 41 | id: environment 42 | attributes: 43 | label: 环境信息 44 | description: | 45 | 请提供您的环境信息: 46 | - 操作系统 47 | - Python 版本 48 | - 相关依赖版本 49 | placeholder: | 50 | - 操作系统: Ubuntu 20.04 51 | - Python: 3.8.10 52 | - 依赖版本: requirements.txt 中的版本 53 | validations: 54 | required: true 55 | 56 | - type: textarea 57 | id: additional 58 | attributes: 59 | label: 补充信息 60 | description: 添加任何其他有关该问题的上下文信息 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: GitHub Discussions 4 | url: https://github.com/OWNER/REPO/discussions 5 | about: 请在这里提出一般性问题和讨论 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: 功能请求 2 | description: 为这个项目提出一个想法 3 | labels: ["enhancement"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | 感谢您为项目提出新的想法! 9 | 10 | - type: textarea 11 | id: problem 12 | attributes: 13 | label: 这个功能请求是否与某个问题相关? 
14 | description: 请清晰简洁地描述问题所在 15 | placeholder: 当我想要...的时候,总是感觉很困难 16 | validations: 17 | required: true 18 | 19 | - type: textarea 20 | id: solution 21 | attributes: 22 | label: 描述您想要的解决方案 23 | description: 请清晰简洁地描述您希望发生的事情 24 | validations: 25 | required: true 26 | 27 | - type: textarea 28 | id: alternatives 29 | attributes: 30 | label: 您考虑过的替代方案 31 | description: 请描述您考虑过的任何替代解决方案或功能 32 | validations: 33 | required: false 34 | 35 | - type: textarea 36 | id: additional 37 | attributes: 38 | label: 补充信息 39 | description: 添加任何其他有关功能请求的上下文或截图 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.yml: -------------------------------------------------------------------------------- 1 | name: 问题咨询 2 | description: 询问使用相关的问题 3 | labels: ["question"] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | 在提问之前,请确保您已经: 9 | 1. 搜索过现有的 issues 和 discussions 10 | 2. 阅读过项目文档 11 | 3. 查看过相关示例 12 | 13 | - type: textarea 14 | id: question 15 | attributes: 16 | label: 您的问题是什么? 17 | description: 请清晰简洁地描述您的问题 18 | validations: 19 | required: true 20 | 21 | - type: textarea 22 | id: context 23 | attributes: 24 | label: 补充信息 25 | description: 提供任何可能有助于我们理解您问题的其他信息 26 | validations: 27 | required: false -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | 5 | # 环境文件 6 | .env 7 | .idea 8 | 9 | # 上传文件目录 10 | files/ 11 | 12 | # 虚拟环境 13 | venv/ 14 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # 使用更轻量级的 Python 基础镜像 2 | FROM python:3.10-slim AS builder 3 | 4 | # 设置工作目录 5 | WORKDIR /app 6 | 7 | # 只复制依赖文件 8 | COPY requirements.txt . 
9 | 10 | # 安装构建依赖和 Python 包 11 | RUN apt-get update && \ 12 | apt-get install -y --no-install-recommends \ 13 | ffmpeg \ 14 | build-essential \ 15 | && \ 16 | pip install --no-cache-dir --user -r requirements.txt && \ 17 | rm -rf /var/lib/apt/lists/* 18 | 19 | # 第二阶段:准备应用代码 20 | FROM python:3.10-slim AS app-prep 21 | 22 | WORKDIR /app 23 | 24 | # 只复制必要的应用代码 25 | COPY main.py . 26 | COPY app/ ./app/ 27 | COPY model/ ./model/ 28 | 29 | # 第三阶段:最终运行环境 30 | FROM python:3.10-slim 31 | 32 | # 定义构建参数 33 | ARG API_KEY 34 | ARG BASE_URL 35 | ARG MODEL 36 | ARG PROMPT 37 | ARG FILE_DELETE_DELAY 38 | 39 | # 设置环境变量 40 | ENV API_KEY=${API_KEY} \ 41 | BASE_URL=${BASE_URL} \ 42 | MODEL=${MODEL} \ 43 | PROMPT=${PROMPT} \ 44 | FILE_DELETE_DELAY=${FILE_DELETE_DELAY} \ 45 | PYTHONUNBUFFERED=1 \ 46 | PYTHONDONTWRITEBYTECODE=1 \ 47 | PIP_NO_CACHE_DIR=1 48 | 49 | WORKDIR /app 50 | 51 | # 只安装运行时必需的系统依赖 52 | RUN apt-get update && \ 53 | apt-get install -y --no-install-recommends \ 54 | ffmpeg \ 55 | && \ 56 | apt-get clean && \ 57 | rm -rf /var/lib/apt/lists/* 58 | 59 | # 从 builder 阶段复制安装好的 Python 包 60 | COPY --from=builder /root/.local /root/.local 61 | 62 | # 从 app-prep 阶段复制应用代码 63 | COPY --from=app-prep /app /app 64 | 65 | # 确保 Python 包在 PATH 中 66 | ENV PATH=/root/.local/bin:$PATH 67 | 68 | # 暴露端口 69 | EXPOSE 8000 70 | 71 | # 使用 uvicorn 运行 FastAPI 应用 72 | CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # office2md 2 | 3 | 这是一项基于 Markdown 格式的多功能转换服务,支持将 PowerPoint、Word、Excel、图像、音频和 HTML 等文件转化为 Markdown 格式。同时,服务整合了 Gitee AI 和智谱 AI 提供的 GLM-4V 模型,以及阿里云百炼平台的 Qwen-VL-Max 模型,用于图片和 PDF 文件的高效文本识别。 4 | 5 | ## Docker 使用说明 6 | 7 | ### 1. 快速使用 8 | 9 | ```bash 10 | # 内置了GLM-4V-FLASH视觉模型,仅供测试使用 11 | docker run -p 8000:8000 registry.cn-hangzhou.aliyuncs.com/dockerhub_mirror/markitdown 12 | ``` 13 | 14 | ### 2. 使用 Gitee AI 15 | 16 | ```bash 17 | docker run -d \ 18 | -p 8000:8000 \ 19 | -e API_KEY=gitee_ai_key \ 20 | -e MODEL=InternVL2_5-26B \ 21 | -e BASE_URL=https://ai.gitee.com/v1 \ 22 | registry.cn-hangzhou.aliyuncs.com/dockerhub_mirror/markitdown 23 | ``` 24 | 25 | ### 3. 
使用阿里云百炼平台 26 | 27 | ```bash 28 | docker run -d \ 29 | -p 8000:8000 \ 30 | -e API_KEY=your_aliyun_api_key \ 31 | -e MODEL=qwen-vl-max \ 32 | -e BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 \ 33 | registry.cn-hangzhou.aliyuncs.com/dockerhub_mirror/markitdown 34 | ``` 35 | 36 | ## 环境变量说明 37 | 38 | 服务支持以下环境变量配置: 39 | 40 | | 环境变量 | 说明 | 默认值 | 41 | | ----------------- | ---------------------- | ------------------------------------------------------ | 42 | | API_KEY | AI 平台的 API 密钥 | XXXX | 43 | | BASE_URL | AI 平台的 API 基础 URL | https://open.bigmodel.cn/api/paas/v4 | 44 | | MODEL | 使用的模型名称 | glm-4v-flash | 45 | | FILE_DELETE_DELAY | 临时文件删除延迟(秒) | 300 | 46 | | PROMPT | 文本提取提示词 | 提取图片中全部的文本,不需要任何推理和总结,只需要原文 | 47 | 48 | ### 支持的模型配置 49 | 50 | #### 智谱 AI 51 | 52 | - MODEL=glm-4v-flash 53 | - BASE_URL=https://open.bigmodel.cn/api/paas/v4 54 | 55 | #### Gitee AI 56 | 57 | - MODEL=InternVL2_5-26B 58 | - BASE_URL=https://ai.gitee.com/v1 59 | 60 | #### 阿里云百炼 61 | 62 | - MODEL=qwen-vl-max 63 | - BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 64 | 65 | ## API 接口 66 | 67 | ### 1. 上传图片并提取文本 68 | 69 | **Endpoint:** POST /upload 70 | 71 | **请求格式:** multipart/form-data 72 | 73 | **参数:** 74 | 75 | - file: 图片文件 76 | 77 | **响应示例:** 78 | 79 | ```json 80 | { 81 | "text": "提取的文本内容" 82 | } 83 | ``` 84 | 85 | ### 2. 
文档图像矫正 86 | 87 | **Endpoint:** POST /uvdoc/unwarp 88 | 89 | **请求格式:** multipart/form-data 90 | 91 | **参数:** 92 | 93 | - file: 需要进行展平处理的文档图片文件 94 | 95 | **响应格式:** image/png 96 | 97 | **说明:** 98 | 99 | - 该接口用于处理弯曲变形的文档图片,返回展平后的图片 100 | - 支持常见图片格式(PNG、JPEG等) 101 | - 返回的是展平后的PNG格式图片数据 102 | 103 | **错误响应:** 104 | 105 | ```json 106 | { 107 | "detail": "Error message" 108 | } 109 | ``` 110 | 111 | ## 源码运行 112 | 113 | ``` 114 | git clone https://gitee.com/log4j/office2md.git 115 | 116 | cd office2md 117 | 118 | python3 -m venv venvdev 119 | 120 | source venvdev/bin/activate 121 | 122 | pip install -r requirements.txt 123 | 124 | # 启动服务 125 | uvicorn main:app --reload 126 | ``` 127 | 128 | ## 注意事项 129 | 130 | 1. 使用前请确保已获取相应平台的 API 密钥 131 | 2. 智谱 AI 和阿里云百炼平台的接口略有不同,请确保使用正确的配置 132 | 3. 上传的图片文件会在处理后自动删除(默认 5 分钟) 133 | 4. 服务默认监听 8000 端口 134 | -------------------------------------------------------------------------------- /app/api/md.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import asyncio 4 | import logging 5 | import json 6 | from fastapi import APIRouter, File, UploadFile, status, HTTPException, Form, Body, Depends 7 | from typing import Dict, Optional 8 | from openai import OpenAI 9 | from markitdown import MarkItDown 10 | from app.utils.pdf_utils import PDFProcessor 11 | from pydantic import BaseModel 12 | 13 | # 配置日志 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | from app.config import ( 18 | API_KEY, 19 | BASE_URL, 20 | MODEL, 21 | FILE_DELETE_DELAY, 22 | MLM_PROMPT, 23 | PDF_CONCURRENT_LIMIT, 24 | PDF_BATCH_SIZE 25 | ) 26 | from app.utils.file_utils import save_upload_file, delete_files 27 | 28 | router = APIRouter() 29 | 30 | client = OpenAI( 31 | base_url=BASE_URL, 32 | api_key=API_KEY 33 | ) 34 | 35 | class MarkdownResponse(BaseModel): 36 | success: bool 37 | message: str 38 | text: Optional[str] = None 39 | 40 | class AiMarkitdownDTO(BaseModel): 41 | """AI Markitdown服务的配置参数""" 42 
| base_url: Optional[str] = "" 43 | """自定义OpenAI API基础URL,为空时使用系统配置""" 44 | 45 | api_key: Optional[str] = "" 46 | """自定义OpenAI API密钥,为空时使用系统配置""" 47 | 48 | model: Optional[str] = "" 49 | """自定义OpenAI模型名称,为空时使用系统配置""" 50 | 51 | prompt: Optional[str] = "" 52 | """自定义提示词,为空时使用系统配置""" 53 | 54 | concurrent_limit: Optional[int] = 5 55 | """自定义PDF处理并发限制,控制PDF处理时的并发数量""" 56 | 57 | batch_size: Optional[int] = 10 58 | """自定义PDF批处理大小,控制PDF处理时的批量大小""" 59 | 60 | delete_delay: Optional[int] = 300 61 | """自定义文件删除延迟时间(秒),控制临时文件的保留时间""" 62 | 63 | async def parse_request_json(request: Optional[str] = Form("", description="JSON格式的配置参数,包含API密钥、模型等设置")) -> Optional[AiMarkitdownDTO]: 64 | """从表单字段解析JSON对象""" 65 | if not request: 66 | return None 67 | try: 68 | data = json.loads(request) 69 | return AiMarkitdownDTO(**data) 70 | except Exception as e: 71 | logger.error(f"Error parsing request JSON: {e}") 72 | return None 73 | 74 | @router.post("/upload", 75 | response_model=MarkdownResponse, 76 | status_code=status.HTTP_200_OK, 77 | summary="上传文件", 78 | description="上传图片或PDF文件并提取其中的文本内容", 79 | responses={ 80 | 200: { 81 | "description": "成功提取文本", 82 | "content": { 83 | "application/json": { 84 | "example": { 85 | "success": True, 86 | "message": "Text extracted successfully", 87 | "text": "提取的文本内容" 88 | } 89 | } 90 | } 91 | } 92 | } 93 | ) 94 | async def upload_file( 95 | file: UploadFile = File(..., description="要上传的文件,支持常见图片格式和PDF文件"), 96 | request: Optional[AiMarkitdownDTO] = Depends(parse_request_json) 97 | ): 98 | # 创建一个默认的DTO对象,如果request为None 99 | if request is None: 100 | request = AiMarkitdownDTO() 101 | 102 | # 记录请求参数(安全处理API密钥) 103 | masked_api_key = "None" if not request.api_key else f"****{request.api_key[-4:]}" if len(request.api_key) > 4 else "****" 104 | logger.debug(f"Received request with parameters: base_url={request.base_url}, api_key={masked_api_key}, model={request.model}, prompt_provided={request.prompt != ''}, concurrent_limit={request.concurrent_limit}, 
batch_size={request.batch_size}, delete_delay={request.delete_delay}") 105 | 106 | timestamp = int(time.time()) 107 | file_extension = os.path.splitext(file.filename)[1].lower() 108 | new_filename = f"{timestamp}{file_extension}" 109 | 110 | content = await file.read() 111 | file_path = await save_upload_file(content, new_filename) 112 | 113 | # 使用用户提供的参数或默认值 114 | used_base_url = request.base_url or BASE_URL 115 | used_api_key = request.api_key or API_KEY 116 | used_model = request.model or MODEL 117 | used_prompt = request.prompt or MLM_PROMPT 118 | used_concurrent_limit = request.concurrent_limit or PDF_CONCURRENT_LIMIT 119 | used_batch_size = request.batch_size or PDF_BATCH_SIZE 120 | used_delete_delay = request.delete_delay or FILE_DELETE_DELAY 121 | 122 | # 处理API密钥显示,只显示最后4位 123 | masked_api_key = "None" if not used_api_key else f"****{used_api_key[-4:]}" if len(used_api_key) > 4 else "****" 124 | 125 | logger.info(f"Processing with parameters: base_url={used_base_url}, api_key={masked_api_key}, model={used_model}, prompt_provided={used_prompt is not None}, concurrent_limit={used_concurrent_limit}, batch_size={used_batch_size}, delete_delay={used_delete_delay}") 126 | 127 | # 如果用户提供了自定义参数,创建新的OpenAI客户端 128 | current_client = client 129 | if request.base_url or request.api_key: 130 | current_client = OpenAI( 131 | base_url=used_base_url, 132 | api_key=used_api_key 133 | ) 134 | 135 | # 使用当前客户端创建MarkItDown实例 136 | markitdown = MarkItDown(llm_client=current_client, llm_model=used_model) 137 | 138 | result = markitdown.convert(file_path, llm_prompt=used_prompt) 139 | 140 | # 如果是PDF文件且未提取到文本,则尝试其他方法 141 | if file_extension == '.pdf' and not result.text_content: 142 | async with PDFProcessor(concurrent_limit=used_concurrent_limit) as processor: 143 | success, text = await processor.extract_text( 144 | file_path, 145 | used_base_url, 146 | used_api_key, 147 | used_model, 148 | used_prompt, 149 | batch_size=used_batch_size, 150 | ) 151 | if success: 152 | 
result.text_content = text 153 | 154 | # 创建异步任务删除临时文件 155 | asyncio.create_task(delete_files(file_path, "", used_delete_delay)) 156 | 157 | return MarkdownResponse( 158 | success=True, 159 | message="Text extracted successfully", 160 | text=result.text_content or "" 161 | ) -------------------------------------------------------------------------------- /app/api/uvdoc.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import BytesIO 3 | from typing import Optional 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from fastapi import APIRouter, File, UploadFile, HTTPException, Response 9 | from pydantic import BaseModel 10 | from PIL import Image 11 | 12 | from app.utils import IMG_SIZE, bilinear_unwarping, load_model 13 | 14 | # 路由器添加描述 15 | router = APIRouter( 16 | prefix="/uvdoc" 17 | ) 18 | 19 | # 全局变量存储加载的模型 20 | model = None 21 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 22 | 23 | 24 | class UnwarpResponse(BaseModel): 25 | success: bool 26 | message: str 27 | text: Optional[str] = None 28 | 29 | 30 | # 创建一个普通函数来加载模型 31 | def load_model_fn(): 32 | """加载模型的函数""" 33 | global model 34 | try: 35 | ckpt_path = os.getenv("MODEL_CHECKPOINT_PATH", "./model/best_model.pkl") 36 | model = load_model(ckpt_path) 37 | model.to(device) 38 | model.eval() 39 | except Exception as e: 40 | print(f"Error loading model: {str(e)}") 41 | model = None 42 | 43 | # 在main.py中会调用这个函数 44 | load_model_fn() 45 | 46 | 47 | @router.post( 48 | "/unwarp", 49 | response_class=Response, 50 | summary="文档图像展平", 51 | description="接收一张弯曲变形的文档图片,返回展平后的图片", 52 | responses={ 53 | 200: { 54 | "description": "图片展平处理成功", 55 | "content": {"image/png": {}} 56 | }, 57 | 500: { 58 | "description": "模型加载错误或处理错误", 59 | "content": { 60 | "application/json": { 61 | "example": { 62 | "detail": "Model not loaded" 63 | } 64 | } 65 | } 66 | } 67 | } 68 | ) 69 | async def unwarp_image( 70 | file: UploadFile = File(..., 
description="需要进行展平处理的文档图片文件") 71 | ): 72 | """ 73 | 使用深度学习模型对文档图片进行展平处理。 74 | 75 | 参数: 76 | file (UploadFile): 输入的图片文件,支持常见图片格式(PNG、JPEG等) 77 | 78 | 返回: 79 | Response: 包含展平后图片数据的响应对象 80 | 81 | 异常: 82 | HTTPException: 当模型未加载或发生其他处理错误时抛出 83 | """ 84 | if model is None: 85 | raise HTTPException(status_code=500, detail="Model not loaded") 86 | 87 | try: 88 | # 读取上传的图片 89 | contents = await file.read() 90 | nparr = np.frombuffer(contents, np.uint8) 91 | img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) 92 | 93 | if img is None: 94 | return UnwarpResponse(success=False, message="Invalid image file") 95 | 96 | # 转换图片格式 97 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255 98 | inp = torch.from_numpy(cv2.resize(img, IMG_SIZE).transpose(2, 0, 1)).unsqueeze(0) 99 | 100 | # 进行预测 101 | inp = inp.to(device) 102 | with torch.no_grad(): 103 | point_positions2D, _ = model(inp) 104 | 105 | # 展平处理 106 | size = img.shape[:2][::-1] 107 | unwarped = bilinear_unwarping( 108 | warped_img=torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(device), 109 | point_positions=torch.unsqueeze(point_positions2D[0], dim=0), 110 | img_size=tuple(size), 111 | ) 112 | unwarped = (unwarped[0].detach().cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8) 113 | 114 | # 转换为PIL Image并直接输出为字节流 115 | pil_img = Image.fromarray(cv2.cvtColor(unwarped, cv2.COLOR_RGB2BGR)) 116 | img_byte_arr = BytesIO() 117 | pil_img.save(img_byte_arr, format='PNG') 118 | img_byte_arr.seek(0) 119 | 120 | return Response( 121 | content=img_byte_arr.getvalue(), 122 | media_type="image/png" 123 | ) 124 | 125 | except Exception as e: 126 | raise HTTPException( 127 | status_code=500, 128 | detail=f"Error processing image: {str(e)}" 129 | ) 130 | -------------------------------------------------------------------------------- /app/banner.txt: -------------------------------------------------------------------------------- 1 | +------------------------------------------------+ 2 | | | 3 | | 🚀 MarkItDown API 
Server | 4 | | | 5 | | ✨ Server is running... | 6 | | | 7 | | 🌐 http://localhost:8000/docs | 8 | | | 9 | | 🤖 https://ai.pig4cloud.com | 10 | | | 11 | +------------------------------------------------+ -------------------------------------------------------------------------------- /app/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | 4 | # 加载.env文件 5 | load_dotenv() 6 | 7 | # API配置 8 | API_KEY = os.getenv('API_KEY') 9 | BASE_URL = os.getenv('BASE_URL') 10 | MODEL = os.getenv('MODEL') 11 | 12 | # 文件处理配置 13 | FILE_DELETE_DELAY = int(os.getenv('FILE_DELETE_DELAY', 300)) # 默认5分钟 14 | MLM_PROMPT = os.getenv('PROMPT') 15 | 16 | # PDF处理相关配置 17 | PDF_CONCURRENT_LIMIT = int(os.getenv('PDF_CONCURRENT_LIMIT', '5')) 18 | PDF_BATCH_SIZE = int(os.getenv('PDF_BATCH_SIZE', '10')) -------------------------------------------------------------------------------- /app/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .image_processing import ( 2 | IMG_SIZE, 3 | GRID_SIZE, 4 | load_model, 5 | bilinear_unwarping 6 | ) 7 | 8 | __all__ = [ 9 | 'IMG_SIZE', 10 | 'GRID_SIZE', 11 | 'load_model', 12 | 'bilinear_unwarping' 13 | ] -------------------------------------------------------------------------------- /app/utils/file_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | import aiofiles 4 | from typing import Optional 5 | 6 | async def save_upload_file(file_content: bytes, filename: str, directory: str = "files") -> str: 7 | """保存上传的文件""" 8 | if not os.path.exists(directory): 9 | os.makedirs(directory) 10 | 11 | file_path = os.path.join(directory, filename) 12 | async with aiofiles.open(file_path, 'wb') as out_file: 13 | await out_file.write(file_content) 14 | return file_path 15 | 16 | async def delete_files(file_path: str, output_path: Optional[str], delay: 
int): 17 | """延迟删除文件""" 18 | await asyncio.sleep(delay) 19 | if os.path.exists(file_path): 20 | os.remove(file_path) 21 | if output_path and os.path.exists(output_path): 22 | os.remove(output_path) -------------------------------------------------------------------------------- /app/utils/image_processing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | from app.utils.model import UVDocnet 5 | 6 | IMG_SIZE = [488, 712] 7 | GRID_SIZE = [45, 31] 8 | 9 | 10 | def load_model(ckpt_path): 11 | """ 12 | Load UVDocnet model. 13 | """ 14 | model = UVDocnet(num_filter=32, kernel_size=5) 15 | ckpt = torch.load(ckpt_path, map_location=torch.device('cpu')) 16 | model.load_state_dict(ckpt["model_state"]) 17 | return model 18 | 19 | 20 | def bilinear_unwarping(warped_img, point_positions, img_size): 21 | """ 22 | Utility function that unwarps an image. 23 | Unwarp warped_img based on the 2D grid point_positions with a size img_size. 
def conv3x3(in_channels, out_channels, kernel_size, stride=1):
    """Plain convolution with size-preserving padding (kernel_size // 2)."""
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )


def dilated_conv_bn_act(in_channels, out_channels, act_fn, BatchNorm, dilation):
    """3x3 dilated conv -> norm -> activation, padded so spatial size is kept."""
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            bias=False,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        ),
        BatchNorm(out_channels),
        act_fn,
    )


def dilated_conv(in_channels, out_channels, kernel_size, dilation, stride=1):
    """Single dilated convolution (wrapped in Sequential), size-preserving padding."""
    return nn.Sequential(
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size // 2),
            dilation=dilation,
        )
    )


class ResidualBlockWithDilation(nn.Module):
    """Residual block whose convolutions are dilated except at stage tops.

    Strided blocks and stage-top blocks (``is_top``) use plain convolutions;
    every other block uses dilation=3 convolutions. ``downsample`` adapts the
    identity branch when the shape or channel count changes.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        BatchNorm,
        kernel_size,
        stride=1,
        downsample=None,
        is_activation=True,
        is_top=False,
    ):
        super(ResidualBlockWithDilation, self).__init__()
        self.stride = stride
        self.downsample = downsample
        self.is_activation = is_activation
        self.is_top = is_top
        # NOTE: submodule names (conv1/conv2/bn1/bn2) determine the
        # checkpoint's state_dict keys — do not rename them.
        use_plain_conv = self.stride != 1 or self.is_top
        if use_plain_conv:
            self.conv1 = conv3x3(in_channels, out_channels, kernel_size, self.stride)
            self.conv2 = conv3x3(out_channels, out_channels, kernel_size)
        else:
            self.conv1 = dilated_conv(in_channels, out_channels, kernel_size, dilation=3)
            self.conv2 = dilated_conv(out_channels, out_channels, kernel_size, dilation=3)

        self.bn1 = BatchNorm(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = BatchNorm(out_channels)

    def forward(self, x):
        # Identity branch, optionally projected to the new shape.
        identity = x if self.downsample is None else self.downsample(x)

        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        return self.relu(branch + identity)
class UVDocnet(nn.Module):
    """Document-unwarping network.

    Pipeline (see ``forward``): a strided convolutional head, a ResNet-style
    downsampling trunk, six parallel dilated-convolution "bridge" branches
    whose outputs are concatenated and fused by a 1x1 convolution, and two
    output heads producing a 2-channel and a 3-channel map respectively.

    Args:
        num_filter: Base channel count; widths are multiples of this.
        kernel_size: Convolution kernel size used throughout (default 5).

    NOTE: attribute names (resnet_head, bridge_1..bridge_6, ...) determine the
    checkpoint's state_dict keys and must not be changed.
    """

    def __init__(self, num_filter, kernel_size=5):
        super(UVDocnet, self).__init__()
        self.num_filter = num_filter
        self.in_channels = 3  # RGB input
        self.kernel_size = kernel_size
        self.stride = [1, 2, 2, 2]

        BatchNorm = nn.BatchNorm2d
        act_fn = nn.ReLU(inplace=True)
        # Channel-width multipliers per stage.
        map_num = [1, 2, 4, 8, 16]

        # Head: two stride-2 convolutions (overall 4x spatial downsampling).
        self.resnet_head = nn.Sequential(
            nn.Conv2d(
                self.in_channels,
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=2,
                padding=self.kernel_size // 2,
            ),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
            nn.Conv2d(
                self.num_filter * map_num[0],
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=2,
                padding=self.kernel_size // 2,
            ),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
        )

        # ResNet-style trunk built from dilated residual blocks.
        self.resnet_down = ResnetStraight(
            self.num_filter,
            map_num,
            BatchNorm,
            block_nums=[3, 4, 6, 3],
            block=ResidualBlockWithDilation,
            kernel_size=self.kernel_size,
            stride=self.stride,
        )

        # All bridge branches operate at the trunk's output width.
        map_num_i = 2

        # Six parallel branches with increasing dilation (receptive field);
        # the extra Sequential wrappers around bridge_1..3 are redundant but
        # changing them would alter state_dict keys.
        self.bridge_1 = nn.Sequential(
            dilated_conv_bn_act(
                self.num_filter * map_num[map_num_i],
                self.num_filter * map_num[map_num_i],
                act_fn,
                BatchNorm,
                dilation=1,
            )
        )

        self.bridge_2 = nn.Sequential(
            dilated_conv_bn_act(
                self.num_filter * map_num[map_num_i],
                self.num_filter * map_num[map_num_i],
                act_fn,
                BatchNorm,
                dilation=2,
            )
        )

        self.bridge_3 = nn.Sequential(
            dilated_conv_bn_act(
                self.num_filter * map_num[map_num_i],
                self.num_filter * map_num[map_num_i],
                act_fn,
                BatchNorm,
                dilation=5,
            )
        )

        self.bridge_4 = nn.Sequential(
            *[
                dilated_conv_bn_act(
                    self.num_filter * map_num[map_num_i],
                    self.num_filter * map_num[map_num_i],
                    act_fn,
                    BatchNorm,
                    dilation=d,
                )
                for d in [8, 3, 2]
            ]
        )

        self.bridge_5 = nn.Sequential(
            *[
                dilated_conv_bn_act(
                    self.num_filter * map_num[map_num_i],
                    self.num_filter * map_num[map_num_i],
                    act_fn,
                    BatchNorm,
                    dilation=d,
                )
                for d in [12, 7, 4]
            ]
        )

        self.bridge_6 = nn.Sequential(
            *[
                dilated_conv_bn_act(
                    self.num_filter * map_num[map_num_i],
                    self.num_filter * map_num[map_num_i],
                    act_fn,
                    BatchNorm,
                    dilation=d,
                )
                for d in [18, 12, 6]
            ]
        )

        # 1x1 fusion of the six concatenated branches (hence the *6 factor).
        self.bridge_concat = nn.Sequential(
            nn.Conv2d(
                self.num_filter * map_num[map_num_i] * 6,
                self.num_filter * map_num[2],
                bias=False,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            BatchNorm(self.num_filter * map_num[2]),
            act_fn,
        )

        # Output head: 2-channel map (reflect padding avoids border artifacts).
        self.out_point_positions2D = nn.Sequential(
            nn.Conv2d(
                self.num_filter * map_num[2],
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2d(
                self.num_filter * map_num[0],
                2,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
        )

        # Output head: 3-channel map, same structure as the 2D head.
        self.out_point_positions3D = nn.Sequential(
            nn.Conv2d(
                self.num_filter * map_num[2],
                self.num_filter * map_num[0],
                bias=False,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2d(
                self.num_filter * map_num[0],
                3,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-normal init (gain 0.2) for all Conv2d/ConvTranspose2d weights."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight, gain=0.2)
            if isinstance(m, nn.ConvTranspose2d):
                # Only square kernels are supported by this init path.
                assert m.kernel_size[0] == m.kernel_size[1]
                nn.init.xavier_normal_(m.weight, gain=0.2)

    def forward(self, x):
        """Run the network; returns (2-channel map, 3-channel map)."""
        resnet_head = self.resnet_head(x)
        resnet_down = self.resnet_down(resnet_head)
        # Six parallel dilated branches over the same trunk features.
        bridge_1 = self.bridge_1(resnet_down)
        bridge_2 = self.bridge_2(resnet_down)
        bridge_3 = self.bridge_3(resnet_down)
        bridge_4 = self.bridge_4(resnet_down)
        bridge_5 = self.bridge_5(resnet_down)
        bridge_6 = self.bridge_6(resnet_down)
        bridge_concat = torch.cat([bridge_1, bridge_2, bridge_3, bridge_4, bridge_5, bridge_6], dim=1)
        bridge = self.bridge_concat(bridge_concat)

        out_point_positions2D = self.out_point_positions2D(bridge)
        out_point_positions3D = self.out_point_positions3D(bridge)

        return out_point_positions2D, out_point_positions3D
    async def process_page(
        self,
        page: fitz.Page,
        base_url: str,
        api_key: str,
        model: str,
        mlm_prompt: str,
        page_num: int
    ) -> Tuple[int, Optional[str]]:
        """Extract text from one PDF page: embedded text first, OCR fallback.

        Args:
            page: The PyMuPDF page object to process.
            base_url / api_key / model / mlm_prompt: OCR backend parameters
                forwarded to the worker process.
            page_num: Zero-based page index, returned with the result so
                callers can restore page order after concurrent processing.

        Returns:
            (page_num, text) — text is None when both extraction paths fail.
        """
        logger.info(f"开始处理第 {page_num + 1} 页...")

        # Try the page's embedded text layer first; OCR is only needed for
        # scanned/image-only pages.
        text = page.get_text().strip()
        if text:
            logger.info(f"第 {page_num + 1} 页: 成功直接提取文本,长度 {len(text)} 字符")
            return page_num, text

        # OCR path: rasterize the page to a temp PNG and hand it to a worker
        # process via the pool executor.
        logger.info(f"第 {page_num + 1} 页: 无法直接提取文本,开始OCR处理...")
        try:
            pix = page.get_pixmap()
            # delete=False because the worker process reopens the file by name.
            # NOTE(review): re-opening a still-open NamedTemporaryFile by name
            # fails on Windows — presumably this runs on POSIX; confirm.
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_img:
                pix.save(temp_img.name)
                logger.debug(f"第 {page_num + 1} 页: 临时图片已保存到 {temp_img.name}")

                try:
                    args = OCRArgs(
                        image_path=temp_img.name,
                        api_key=api_key,
                        mlm_prompt=mlm_prompt,
                        base_url=base_url,
                        model=model
                    )

                    logger.info(f"第 {page_num + 1} 页: 开始OCR识别...")
                    start_time = time.perf_counter()

                    # Run the blocking OCR call off the event loop, in the
                    # process pool created by __aenter__.
                    loop = asyncio.get_running_loop()
                    text_content = await loop.run_in_executor(
                        self.executor,
                        ocr_worker,
                        args
                    )

                    process_time = time.perf_counter() - start_time
                    if text_content:
                        logger.info(
                            f"第 {page_num + 1} 页: OCR识别成功,"
                            f"耗时 {process_time:.1f}秒,"
                            f"提取文本长度 {len(text_content)} 字符"
                        )
                    else:
                        logger.warning(f"第 {page_num + 1} 页: OCR识别失败,耗时 {process_time:.1f}秒")

                    return page_num, text_content

                finally:
                    # Always clean up the temp image, even when OCR raised;
                    # cleanup failures are logged but never propagated.
                    try:
                        os.unlink(temp_img.name)
                        logger.debug(f"第 {page_num + 1} 页: 临时文件已清理")
                    except Exception as e:
                        logger.warning(f"第 {page_num + 1} 页: 清理临时文件失败: {str(e)}")

        except Exception as e:
            # Best-effort per page: any failure yields (page_num, None) so one
            # bad page cannot abort the whole document.
            logger.error(f"第 {page_num + 1} 页处理错误: {str(e)}")
            return page_num, None
asyncio.gather(*tasks) 149 | return results 150 | 151 | async def extract_text( 152 | self, 153 | pdf_path: str, 154 | base_url: str, 155 | api_key: str, 156 | model: str, 157 | mlm_prompt: str, 158 | batch_size: int = 10 159 | ) -> Tuple[bool, str]: 160 | """提取PDF文本""" 161 | pdf_document = None 162 | start_time = time.perf_counter() 163 | try: 164 | pdf_document = fitz.open(pdf_path) 165 | total_pages = len(pdf_document) 166 | total_batches = (total_pages + batch_size - 1) // batch_size 167 | logger.info(f"开始处理PDF文件,总页数: {total_pages},批次数: {total_batches}") 168 | 169 | # 创建所有批次的任务 170 | tasks = [] 171 | for start_page in range(0, total_pages, batch_size): 172 | end_page = min(start_page + batch_size, total_pages) 173 | batch_num = start_page // batch_size + 1 174 | logger.info(f"创建批次 {batch_num}/{total_batches} (页码 {start_page+1}-{end_page})") 175 | task = self.process_batch( 176 | pdf_document, 177 | base_url, 178 | api_key, 179 | model, 180 | mlm_prompt, 181 | start_page, 182 | end_page 183 | ) 184 | tasks.append((batch_num, task)) 185 | 186 | # 并发执行所有批次 187 | logger.info(f"开始并发处理 {len(tasks)} 个批次...") 188 | batch_results = await asyncio.gather(*(task for _, task in tasks)) 189 | 190 | # 合并所有批次的结果并记录每个批次的完成情况 191 | all_results = [] 192 | for (batch_num, _), results in zip(tasks, batch_results): 193 | batch_success = sum(1 for _, text in results if text is not None) 194 | batch_total = len(results) 195 | logger.info( 196 | f"批次 {batch_num}/{total_batches} 完成: " 197 | f"成功 {batch_success}/{batch_total} 页 " 198 | f"({batch_success/batch_total*100:.1f}%)" 199 | ) 200 | all_results.extend(results) 201 | 202 | # 最终处理结果 203 | all_results.sort(key=lambda x: x[0]) 204 | valid_texts = [text for _, text in all_results if text] 205 | 206 | total_time = time.perf_counter() - start_time 207 | logger.info( 208 | f"PDF处理完成: 成功率 {(len(valid_texts)/total_pages*100):.1f}%," 209 | f" 总耗时 {total_time:.1f}秒," 210 | f" 平均每页 {(total_time/total_pages):.1f}秒" 211 | ) 212 | 213 | if 
valid_texts: 214 | return True, '\n'.join(valid_texts) 215 | return False, "" 216 | 217 | except Exception as e: 218 | elapsed_time = time.perf_counter() - start_time 219 | logger.error(f"PDF处理错误: {str(e)}, 耗时 {elapsed_time:.1f}秒") 220 | return False, "" 221 | finally: 222 | if pdf_document: 223 | pdf_document.close() 224 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from contextlib import asynccontextmanager 3 | from fastapi import FastAPI 4 | from fastapi.middleware.cors import CORSMiddleware 5 | from fastapi.responses import JSONResponse 6 | from fastapi import HTTPException 7 | import logging 8 | import warnings 9 | 10 | from app.api.md import router as md_router 11 | from app.api.uvdoc import router as uvdoc_router 12 | 13 | # 配置警告过滤 14 | warnings.filterwarnings("ignore", category=DeprecationWarning) 15 | warnings.filterwarnings("ignore", category=ResourceWarning) 16 | 17 | # 更多具体的警告过滤 18 | warnings.filterwarnings("ignore", message=".*swigvarlink.*") 19 | warnings.filterwarnings("ignore", message=".*unclosed.*SSLSocket.*") 20 | 21 | # 配置日志 22 | logging.basicConfig( 23 | level=logging.INFO, 24 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' 25 | ) 26 | 27 | @asynccontextmanager 28 | async def lifespan(app: FastAPI): 29 | # 启动时执行 30 | from app.api.uvdoc import load_model_fn 31 | load_model_fn() # 加载模型 32 | 33 | current_dir = Path(__file__).parent 34 | banner_path = current_dir / 'app' / 'banner.txt' 35 | print(f"Looking for banner at: {banner_path.absolute()}") 36 | try: 37 | with open(banner_path, 'r', encoding='utf-8') as f: 38 | banner = f.read() 39 | print(banner) 40 | except FileNotFoundError: 41 | print(f"Banner file not found at {banner_path}, starting server without banner...") 42 | yield 43 | 44 | def create_app() -> FastAPI: 45 | app = FastAPI( 46 | lifespan=lifespan 47 | ) 48 | 
49 | # Add CORS middleware 50 | app.add_middleware( 51 | CORSMiddleware, 52 | allow_origins=["*"], 53 | allow_credentials=True, 54 | allow_methods=["*"], 55 | allow_headers=["*"], 56 | ) 57 | 58 | # 注册路由 59 | app.include_router(md_router) 60 | app.include_router(uvdoc_router) 61 | 62 | @app.exception_handler(HTTPException) 63 | async def http_exception_handler(request, exc): 64 | return JSONResponse( 65 | status_code=exc.status_code, 66 | content={"message": exc.detail}, 67 | ) 68 | 69 | return app 70 | 71 | app = create_app() -------------------------------------------------------------------------------- /model/best_model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pig-mesh/office2md/d60522a86b66ccf37fe5af43ab7ed15bb9a2347c/model/best_model.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi>=0.68.0 2 | aiofiles 3 | uvicorn>=0.15.0 4 | openai 5 | socksio 6 | PyMuPDF 7 | markitdown[all]==0.1.1 8 | python-multipart>=0.0.5 9 | python-dotenv==1.0.1 10 | aiomultiprocess==0.9.0 11 | numpy==1.24.0 12 | opencv_python_headless==4.7.0.68 13 | torch==1.13.0 14 | pillow>=8.0.0 --------------------------------------------------------------------------------