├── DocImgTranslation ├── FastSAM │ ├── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── tools.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ │ ├── tools_gradio.py │ │ └── tools.py │ ├── fastsam │ │ ├── .DS_Store │ │ ├── __pycache__ │ │ │ ├── model.cpython-38.pyc │ │ │ ├── prompt.cpython-38.pyc │ │ │ ├── utils.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── decoder.cpython-38.pyc │ │ │ └── predict.cpython-38.pyc │ │ ├── __init__.py │ │ ├── utils.py │ │ ├── predict.py │ │ ├── model.py │ │ ├── decoder.py │ │ └── prompt.py │ ├── __pycache__ │ │ └── Inference.cpython-38.pyc │ ├── setup.py │ ├── cog.yaml │ ├── segpredict.py │ ├── Inference.py │ ├── predict.py │ └── app_gradio.py ├── imageEnhancement.jpg ├── __pycache__ │ ├── structure.cpython-38.pyc │ ├── SegmentImage.cpython-38.pyc │ └── imageEnhancement.cpython-38.pyc ├── imageEnhancement.py ├── imageSegmentation1.py ├── ocr_image_ch2en.py ├── ocr_image_en2ch.py ├── SegmentImage.py └── structure.py ├── image ├── 12.jpg ├── translate.png ├── Xnip2023-09-10_23-21-12.jpg ├── Xnip2023-09-10_23-22-08.jpg ├── Xnip2023-09-10_23-22-59.jpg ├── Xnip2023-09-10_23-25-36.jpg ├── Xnip2023-09-10_23-26-39.jpg ├── Xnip2023-09-10_23-30-04.jpg ├── Xnip2023-09-10_23-30-57.jpg ├── Xnip2023-09-10_23-41-39.jpg ├── Xnip2023-09-10_23-50-55.jpg ├── Xnip2023-09-10_23-53-20.jpg ├── Xnip2023-09-10_23-54-02.jpg ├── Xnip2023-09-11_00-00-42.jpg ├── Xnip2023-09-11_00-03-44.jpg ├── Xnip2023-09-11_00-07-27.jpg ├── Xnip2023-09-11_00-09-16.jpg ├── Xnip2023-09-11_11-55-56.jpg └── Xnip2023-09-11_11-57-18.jpg ├── static ├── simfang.ttf └── translate.png ├── config.conf ├── Translation ├── __pycache__ │ ├── translation_ch2en.cpython-38.pyc │ └── translation_en2ch.cpython-38.pyc ├── translation_ch2en.py └── translation_en2ch.py ├── SpeechTranslation ├── __pycache__ │ ├── speech_translation_ch2en.cpython-38.pyc │ └── speech_translation_en2ch.cpython-38.pyc ├── speech_translation_en2ch.py └── speech_translation_ch2en.py ├── SubtitleTranslation ├── tmp.srt ├── __pycache__ │ ├── subtitle_translation_ch2en.cpython-38.pyc │ └── subtitle_translation_en2ch.cpython-38.pyc ├── subtitle_translation_ch2en.py └── subtitle_translation_en2ch.py ├── FileTranslation ├── Excel │ ├── tmp.html │ ├── excel_translation_ch2en.py │ └── excel_translation_en2ch.py ├── TXT │ ├── txt_translation_ch2en.py │ └── txt_translation_en2ch.py ├── PPT │ ├── ppt_translation_ch2en.py │ └── ppt_translation_en2ch.py ├── Word │ ├── word_translation_ch2en.py │ └── word_translation_en2ch.py ├── Image │ ├── image_translation_ch2en.py │ └── image_translation_en2ch.py └── PDF │ ├── pdf_translation_ch2en.py │ └── pdf_translation_en2ch.py ├── requirements.txt ├── HyperTranslation ├── hyper_translation_ch2en.py └── hyper_translation_en2ch.py ├── my_utils.py ├── ScreenshotTranslation ├── screenshot_translation_ch2en.py └── screenshot_translation_en2ch.py ├── README.md └── LICENSE /DocImgTranslation/FastSAM/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/12.jpg -------------------------------------------------------------------------------- /image/translate.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/translate.png -------------------------------------------------------------------------------- /static/simfang.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/static/simfang.ttf -------------------------------------------------------------------------------- /static/translate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/static/translate.png -------------------------------------------------------------------------------- /config.conf: -------------------------------------------------------------------------------- 1 | [shortcuts] 2 | translation = ++f 3 | screenshot = ++s 4 | 5 | [devices] 6 | device = cpu -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-21-12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-21-12.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-22-08.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-22-08.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-22-59.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-22-59.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-25-36.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-25-36.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-26-39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-26-39.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-30-04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-30-04.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-30-57.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-30-57.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-41-39.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-41-39.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-50-55.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-50-55.jpg 
-------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-53-20.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-53-20.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-10_23-54-02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-10_23-54-02.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-11_00-00-42.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-11_00-00-42.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-11_00-03-44.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-11_00-03-44.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-11_00-07-27.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-11_00-07-27.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-11_00-09-16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-11_00-09-16.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-11_11-55-56.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-11_11-55-56.jpg -------------------------------------------------------------------------------- /image/Xnip2023-09-11_11-57-18.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/image/Xnip2023-09-11_11-57-18.jpg -------------------------------------------------------------------------------- /DocImgTranslation/imageEnhancement.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/imageEnhancement.jpg -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/.DS_Store -------------------------------------------------------------------------------- /DocImgTranslation/__pycache__/structure.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/__pycache__/structure.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/__pycache__/SegmentImage.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/__pycache__/SegmentImage.cpython-38.pyc -------------------------------------------------------------------------------- /Translation/__pycache__/translation_ch2en.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/Translation/__pycache__/translation_ch2en.cpython-38.pyc -------------------------------------------------------------------------------- /Translation/__pycache__/translation_en2ch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/Translation/__pycache__/translation_en2ch.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/__pycache__/Inference.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/__pycache__/Inference.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/__pycache__/imageEnhancement.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/__pycache__/imageEnhancement.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/utils/__pycache__/tools.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/utils/__pycache__/tools.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__pycache__/prompt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/__pycache__/prompt.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__pycache__/decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/__pycache__/decoder.cpython-38.pyc -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__pycache__/predict.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/DocImgTranslation/FastSAM/fastsam/__pycache__/predict.cpython-38.pyc -------------------------------------------------------------------------------- /SpeechTranslation/__pycache__/speech_translation_ch2en.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/SpeechTranslation/__pycache__/speech_translation_ch2en.cpython-38.pyc -------------------------------------------------------------------------------- /SpeechTranslation/__pycache__/speech_translation_en2ch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/SpeechTranslation/__pycache__/speech_translation_en2ch.cpython-38.pyc -------------------------------------------------------------------------------- /SubtitleTranslation/tmp.srt: -------------------------------------------------------------------------------- 1 | 1 2 | 0:00:01,000 --> 0:00:01,1500 3 | Let's travel from time to time. 4 | 5 | 2 6 | 0:00:03,000 --> 0:00:03,1500 7 | The story of opening a plant paradise. 8 | 9 | -------------------------------------------------------------------------------- /SubtitleTranslation/__pycache__/subtitle_translation_ch2en.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/SubtitleTranslation/__pycache__/subtitle_translation_ch2en.cpython-38.pyc -------------------------------------------------------------------------------- /SubtitleTranslation/__pycache__/subtitle_translation_en2ch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianclll/Ace-Translate/HEAD/SubtitleTranslation/__pycache__/subtitle_translation_en2ch.cpython-38.pyc -------------------------------------------------------------------------------- /FileTranslation/Excel/tmp.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Table HTML Page 5 | 6 | 7 |
<tr><td>Category</td><td>Prices</td></tr>
<tr><td>Apple</td><td>.6</td></tr>
<tr><td>Bananas</td><td>.2</td></tr>
8 | 9 | 10 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/__init__.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | 3 | from .model import FastSAM 4 | from .predict import FastSAMPredictor 5 | from .prompt import FastSAMPrompt 6 | # from .val import FastSAMValidator 7 | from .decoder import FastSAMDecoder 8 | 9 | __all__ = 'FastSAMPredictor', 'FastSAM', 'FastSAMPrompt', 'FastSAMDecoder' 10 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | paddleocr>=2.6.0.3 2 | paddleclas>=2.4.3 3 | pytest-runner 4 | paddlespeech==1.2.0 5 | paddlenlp<=2.0.5 6 | torch>1.7.0 7 | Pillow==9.5.0 8 | nltk 9 | pystray 10 | wave 11 | pydub 12 | pynput 13 | pyperclip 14 | ultralytics == 8.0.120 15 | pandas>=1.1.4 16 | tqdm>=4.64.0 17 | scipy>=1.4.1 18 | requests>=2.23.0 19 | PyYAML>=5.3.1 20 | opencv-python>=4.6.0 21 | ttkthemes 22 | python-pptx 23 | transformers 24 | sacremoses 25 | moviepy 26 | pysrt 27 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/setup.py: -------------------------------------------------------------------------------- 1 | # This source code is licensed under the license found in the 2 | # LICENSE file in the root directory of this source tree. 3 | 4 | from setuptools import find_packages, setup 5 | 6 | setup( 7 | name="fastsam", 8 | version="0.1.1", 9 | install_requires=[], 10 | package_dir= { 11 | "fastsam": "fastsam", 12 | "fastsam_tools": "utils", 13 | }, 14 | url="https://github.com/CASIA-IVA-Lab/FastSAM", 15 | ) -------------------------------------------------------------------------------- /Translation/translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import transformers 2 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 3 | def translate(text): 4 | model_path = 'models/translate/zh-en' 5 | tokenizer = AutoTokenizer.from_pretrained(model_path) 6 | translate_model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 7 | pipeline = transformers.pipeline("translation", model=translate_model, tokenizer=tokenizer) 8 | translate_text = pipeline(text)[0]['translation_text'] 9 | return translate_text -------------------------------------------------------------------------------- /Translation/translation_en2ch.py: -------------------------------------------------------------------------------- 1 | 2 | import transformers 3 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 4 | import sys 5 | import os 6 | sys.path.append(os.getcwd()) 7 | import my_utils 8 | def translate(text): 9 | model_path = 'models/translate/en-zh' 10 | tokenizer = AutoTokenizer.from_pretrained(model_path) 11 | translate_model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 12 | pipeline = transformers.pipeline("translation", model=translate_model, tokenizer=tokenizer) 13 | translate_text = pipeline(text)[0]['translation_text'] 14 | translate_text = my_utils.do_sentence(translate_text) 15 | return translate_text -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: 
https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | # Thanks for chenxwh. 4 | 5 | build: 6 | # set to true if your model requires a GPU 7 | gpu: true 8 | cuda: "11.7" 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | python_version: "3.8" 13 | python_packages: 14 | - "matplotlib==3.7.1" 15 | - "opencv-python==4.7.0.72" 16 | - "Pillow==9.5.0" 17 | - "PyYAML==6.0" 18 | - "requests==2.31.0" 19 | - "scipy==1.10.1" 20 | - "torch==2.0.1" 21 | - "torchvision==0.15.2" 22 | - "tqdm==4.65.0" 23 | - "pandas==2.0.2" 24 | - "seaborn==0.12.0" 25 | - "ultralytics==8.0.121" 26 | - git+https://github.com/openai/CLIP.git 27 | predict: "predict.py:Predictor" 28 | -------------------------------------------------------------------------------- /FileTranslation/TXT/txt_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import transformers 4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 5 | if __name__ == '__main__': 6 | savePath = sys.argv[2] 7 | # 1、excel地址 8 | txtPath = sys.argv[1] 9 | model_path = 'models/translate/zh-en' 10 | tokenizer = AutoTokenizer.from_pretrained(model_path) 11 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 12 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 13 | with open(txtPath,'r') as f: 14 | text = f.readlines() 15 | f.close() 16 | 17 | lines = [] 18 | for line in text: 19 | line_clean = line.replace('\n','') 20 | translate_line = translate_model(line_clean)[0]['translation_text'] 21 | translate_line = translate_line + '\n' 22 | lines.append(translate_line) 23 | 24 | txt_name = os.path.basename(txtPath) 25 | with open(os.path.join(savePath,txt_name),'w') as f: 26 | f.writelines(lines) 27 | f.close() -------------------------------------------------------------------------------- /DocImgTranslation/imageEnhancement.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | # Gamma校正 fGamaa=0.45是常用值 6 | def gamma_correction(src, fGamma): 7 | # 构建查找表 8 | lut = np.array([np.uint8(pow(i / 255.0, fGamma) * 255) for i in range(256)]) 9 | dst = cv2.LUT(src, lut) 10 | return dst 11 | 12 | 13 | def image_enhancement(image_data): 14 | # image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2GRAY) 15 | # 划分算法 16 | # 如果混合色与基色相同则结果色为白色 17 | # 如混合色为白色则结果色为基色不变 18 | # 如混合色为黑色则结果色为白色 19 | src = image_data.astype(np.float32) / 255.0 20 | gauss = cv2.GaussianBlur(src, (101, 101), 0) 21 | dst = src / gauss 22 | dst = np.clip(dst * 255, 0, 255).astype(np.uint8) 23 | 24 | # gamma变换 25 | matGamma = gamma_correction(dst.copy(), 1.5) 26 | 27 | # 显示最终结果 28 | cv2.imwrite("DocImgTranslation/imageEnhancement.jpg", matGamma) 29 | return matGamma 30 | 31 | if __name__ == "__main__": 32 | P = cv2.imread('12.jpg', cv2.IMREAD_COLOR) 33 | gray_P = cv2.cvtColor(P, cv2.COLOR_BGR2GRAY) 34 | P_shape = gray_P.shape 35 | 36 | image_enhancement(gray_P) 37 | 38 | -------------------------------------------------------------------------------- /FileTranslation/TXT/txt_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | import my_utils 6 | import transformers 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | if __name__ == '__main__': 9 | savePath = sys.argv[2] 10 | # 1、excel地址 11 | txtPath = sys.argv[1] 12 | model_path = 
'models/translate/en-zh' 13 | tokenizer = AutoTokenizer.from_pretrained(model_path) 14 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 15 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 16 | with open(txtPath,'r') as f: 17 | text = f.readlines() 18 | f.close() 19 | 20 | lines = [] 21 | for line in text: 22 | line_clean = line.replace('\n','') 23 | translate_line = translate_model(line_clean)[0]['translation_text'] 24 | translate_line = my_utils.do_sentence(translate_line) 25 | translate_line = translate_line + '\n' 26 | lines.append(translate_line) 27 | 28 | txt_name = os.path.basename(txtPath) 29 | with open(os.path.join(savePath,txt_name),'w') as f: 30 | f.writelines(lines) 31 | f.close() -------------------------------------------------------------------------------- /FileTranslation/PPT/ppt_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | from pptx import Presentation 2 | from time import sleep 3 | import transformers 4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 5 | import os 6 | import sys 7 | if __name__ == '__main__': 8 | savePath = sys.argv[2] 9 | # 1、ppt地址 10 | pptPath = sys.argv[1] 11 | model_path = 'models/translate/zh-en' 12 | tokenizer = AutoTokenizer.from_pretrained(model_path) 13 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 14 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 15 | prs = Presentation(pptPath) 16 | 17 | print("start") 18 | for ns, slide in enumerate(prs.slides): 19 | for nsh, shape in enumerate(slide.shapes): 20 | if not shape.has_text_frame: 21 | continue 22 | for np, paragraph in enumerate(shape.text_frame.paragraphs): 23 | for rs, run in enumerate(paragraph.runs): 24 | str_in = run.text 25 | str_out = translate_model(str_in)[0]['translation_text'] 26 | prs.slides[ns].shapes[nsh].text_frame.paragraphs[np].runs[rs].text = str_out 27 | sleep(1.5) 28 | print(np) 29 | ppt_name = os.path.basename(pptPath) 30 | print(ppt_name) 31 | prs.save(os.path.join(savePath,ppt_name)) 32 | print(os.path.join(savePath,ppt_name)) -------------------------------------------------------------------------------- /FileTranslation/PPT/ppt_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | from pptx import Presentation 2 | from time import sleep 3 | import transformers 4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 5 | import os 6 | import sys 7 | sys.path.append(os.getcwd()) 8 | import my_utils 9 | 10 | 11 | if __name__ == '__main__': 12 | savePath = sys.argv[2] 13 | # 1、ppt地址 14 | pptPath = sys.argv[1] 15 | model_path = 'models/translate/en-zh' 16 | tokenizer = AutoTokenizer.from_pretrained(model_path) 17 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 18 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 19 | prs = Presentation(pptPath) 20 | 21 | print("start") 22 | for ns, slide in enumerate(prs.slides): 23 | for nsh, shape in enumerate(slide.shapes): 24 | if not shape.has_text_frame: 25 | continue 26 | for np, paragraph in enumerate(shape.text_frame.paragraphs): 27 | for rs, run in enumerate(paragraph.runs): 28 | str_in = run.text 29 | str_out = translate_model(str_in)[0]['translation_text'] 30 | str_out = my_utils.do_sentence(str_out) 31 | prs.slides[ns].shapes[nsh].text_frame.paragraphs[np].runs[rs].text = str_out 32 | sleep(1.5) 33 | ppt_name = 
os.path.basename(pptPath) 34 | prs.save(os.path.join(savePath,ppt_name)) -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/segpredict.py: -------------------------------------------------------------------------------- 1 | from fastsam import FastSAM, FastSAMPrompt 2 | import torch 3 | 4 | model = FastSAM('FastSAM.pt') 5 | IMAGE_PATH = './images/dogs.jpg' 6 | DEVICE = torch.device( 7 | "cuda" 8 | if torch.cuda.is_available() 9 | else "mps" 10 | if torch.backends.mps.is_available() 11 | else "cpu" 12 | ) 13 | everything_results = model( 14 | IMAGE_PATH, 15 | device=DEVICE, 16 | retina_masks=True, 17 | imgsz=1024, 18 | conf=0.4, 19 | iou=0.9, 20 | ) 21 | prompt_process = FastSAMPrompt(IMAGE_PATH, everything_results, device=DEVICE) 22 | 23 | # # everything prompt 24 | ann = prompt_process.everything_prompt() 25 | 26 | # # bbox prompt 27 | # # bbox default shape [0,0,0,0] -> [x1,y1,x2,y2] 28 | # bboxes default shape [[0,0,0,0]] -> [[x1,y1,x2,y2]] 29 | # ann = prompt_process.box_prompt(bbox=[200, 200, 300, 300]) 30 | # ann = prompt_process.box_prompt(bboxes=[[200, 200, 300, 300], [500, 500, 600, 600]]) 31 | 32 | # # text prompt 33 | # ann = prompt_process.text_prompt(text='a photo of a dog') 34 | 35 | # # point prompt 36 | # # points default [[0,0]] [[x1,y1],[x2,y2]] 37 | # # point_label default [0] [1,0] 0:background, 1:foreground 38 | # ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1]) 39 | 40 | # point prompt 41 | # points default [[0,0]] [[x1,y1],[x2,y2]] 42 | # point_label default [0] [1,0] 0:background, 1:foreground 43 | ann = prompt_process.point_prompt(points=[[620, 360]], pointlabel=[1]) 44 | 45 | prompt_process.plot( 46 | annotations=ann, 47 | output='./output/', 48 | mask_random_color=True, 49 | better_quality=True, 50 | retina=False, 51 | withContours=True, 52 | ) 53 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/Inference.py: -------------------------------------------------------------------------------- 1 | 2 | from DocImgTranslation.FastSAM.fastsam import FastSAM, FastSAMPrompt 3 | import ast 4 | import torch 5 | from PIL import Image 6 | import configparser 7 | config = configparser.ConfigParser() 8 | config.read("config.conf") 9 | device = config.get("devices", "device") 10 | def main(img_path,device=device,text_prompt="A sheet of white A4 paper",model_path='models/FastSAM.pt'): 11 | # load model 12 | model = FastSAM(model_path) 13 | input = Image.open(img_path) 14 | 15 | # 检查图像的旋转方向标志 16 | if hasattr(input, '_getexif'): # 检查是否有EXIF数据 17 | exif = input._getexif() 18 | if exif is not None: 19 | orientation = exif.get(0x0112) # 获取方向标志(0x0112是EXIF中的方向标志) 20 | if orientation is not None: 21 | if orientation == 1: # 没有旋转 22 | pass 23 | elif orientation == 3: # 旋转180度 24 | input = input.rotate(180, expand=True) 25 | elif orientation == 6: # 顺时针旋转90度 26 | input = input.rotate(270, expand=True) 27 | elif orientation == 8: # 逆时针旋转90度 28 | input = input.rotate(90, expand=True) 29 | input = input.convert("RGB") 30 | everything_results = model( 31 | input, 32 | device=device, 33 | retina_masks=True, 34 | imgsz=512, 35 | conf=0.7, 36 | iou=0.6 37 | ) 38 | prompt_process = FastSAMPrompt(input, everything_results, device=device) 39 | # ann = prompt_process.text_prompt(text=text_prompt) 40 | # everything prompt 41 | ann = prompt_process.everything_prompt() 42 | return ann 43 | 44 | if __name__ == "__main__": 45 | 
img_path="/Users/liuhongdi/计算机设计大赛/OCR图片文字识别/数据/jpg_pdf/1/16.jpg" 46 | main(img_path,device='cpu') 47 | -------------------------------------------------------------------------------- /FileTranslation/Word/word_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from docx import Document 3 | import os 4 | import transformers 5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 6 | def get_paragraphs_text(path): 7 | """ 8 | 获取所有段落的文本 9 | :param path: word路径 10 | :return: list类型,如: 11 | ['Test', 'hello world', ...] 12 | """ 13 | document = Document(path) 14 | all_paragraphs = document.paragraphs 15 | paragraphs_text = [] 16 | for paragraph in all_paragraphs: 17 | # 拼接一个list,包括段落的结构和内容 18 | paragraphs_text.append([paragraph.style.name,paragraph.text]) 19 | return paragraphs_text 20 | if __name__ == '__main__': 21 | savePath = sys.argv[2] 22 | # 1、excel地址 23 | wordPath = sys.argv[1] 24 | # 提取文档信息 25 | text = get_paragraphs_text(wordPath) 26 | model_path = 'models/translate/zh-en' 27 | tokenizer = AutoTokenizer.from_pretrained(model_path) 28 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 29 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 30 | new_document = Document() 31 | for item in text: 32 | src_style = item[0] 33 | src_texts = [item[1]] 34 | print('标题层级:', src_style) 35 | print(src_texts) 36 | n_best = 1 # 每个输入样本的输出候选句子数量 37 | trg_texts = [translate_model(src_texts[0])[0]['translation_text']] 38 | print(trg_texts) 39 | if trg_texts == ['N']: 40 | trg_texts = [''] 41 | if src_style == 'Title': 42 | new_document.add_heading(trg_texts, 0) 43 | elif src_style[:7:] == 'Heading': 44 | new_document.add_heading(trg_texts, level=int(src_style[-1])) 45 | else: 46 | new_document.add_paragraph(trg_texts) 47 | word_name = os.path.basename(wordPath) 48 | new_document.save(os.path.join(savePath,word_name)) -------------------------------------------------------------------------------- /FileTranslation/Word/word_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from docx import Document 3 | import os 4 | sys.path.append(os.getcwd()) 5 | import my_utils 6 | import transformers 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | def get_paragraphs_text(path): 9 | """ 10 | 获取所有段落的文本 11 | :param path: word路径 12 | :return: list类型,如: 13 | ['Test', 'hello world', ...] 
14 | """ 15 | document = Document(path) 16 | all_paragraphs = document.paragraphs 17 | paragraphs_text = [] 18 | for paragraph in all_paragraphs: 19 | # 拼接一个list,包括段落的结构和内容 20 | paragraphs_text.append([paragraph.style.name,paragraph.text]) 21 | return paragraphs_text 22 | if __name__ == '__main__': 23 | savePath = sys.argv[2] 24 | # 1、excel地址 25 | wordPath = sys.argv[1] 26 | # 提取文档信息 27 | text = get_paragraphs_text(wordPath) 28 | model_path = 'models/translate/en-zh' 29 | tokenizer = AutoTokenizer.from_pretrained(model_path) 30 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 31 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 32 | new_document = Document() 33 | for item in text: 34 | src_style = item[0] 35 | src_texts = [item[1]] 36 | print('标题层级:', src_style) 37 | print(src_texts) 38 | n_best = 1 # 每个输入样本的输出候选句子数量 39 | trg_texts = [my_utils.do_sentence(translate_model(src_texts[0])[0]['translation_text'])] 40 | print(trg_texts) 41 | if trg_texts == ['N']: 42 | trg_texts = [''] 43 | if src_style == 'Title': 44 | new_document.add_heading(trg_texts, 0) 45 | elif src_style[:7:] == 'Heading': 46 | new_document.add_heading(trg_texts, level=int(src_style[-1])) 47 | else: 48 | new_document.add_paragraph(trg_texts) 49 | word_name = os.path.basename(wordPath) 50 | new_document.save(os.path.join(savePath,word_name)) -------------------------------------------------------------------------------- /DocImgTranslation/imageSegmentation1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import requests 3 | import json 4 | import base64 5 | import numpy as np 6 | import cv2 7 | def get_file_content(filePath): 8 | with open(filePath, 'rb') as fp: 9 | return fp.read() 10 | 11 | class CommonOcr(object): 12 | def __init__(self, img_path): 13 | # 请登录后前往 “工作台-账号设置-开发者信息” 查看 x-ti-app-id 14 | # 示例代码中 x-ti-app-id 非真实数据 15 | self._app_id = 'a1ad6dc8a0a9f5da5487cb5a1a6e5bcf' 16 | # 请登录后前往 “工作台-账号设置-开发者信息” 查看 x-ti-secret-code 17 | # 示例代码中 x-ti-secret-code 非真实数据 18 | self._secret_code = '9ec1a530b4c985ed431891ca83e19acd' 19 | self._img_path = img_path 20 | 21 | def recognize(self): 22 | # 文档图像切边矫正 23 | url = 'https://api.textin.com/ai/service/v1/dewarp' 24 | head = {} 25 | try: 26 | image = get_file_content(self._img_path) 27 | head['x-ti-app-id'] = self._app_id 28 | head['x-ti-secret-code'] = self._secret_code 29 | result = requests.post(url, data=image, headers=head) 30 | return result.text 31 | except Exception as e: 32 | return e 33 | def image_segmentation(image_path): 34 | response = CommonOcr(image_path) 35 | result = response.recognize() 36 | result = json.loads(result)["result"]["image"] 37 | img_binary = base64.b64decode(result) 38 | # Convert binary to numpy array 39 | img_np = np.fromstring(img_binary, np.uint8) 40 | # Convert numpy array to image 41 | img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) 42 | return img 43 | if __name__ == "__main__": 44 | response = CommonOcr(r'/Users/liuhongdi/计算机设计大赛/OCR图片文字识别/数据/JPG/000022.JPG') 45 | result = response.recognize() 46 | result = json.loads(result)["result"]["image"] 47 | 48 | img_binary = base64.b64decode(result) 49 | 50 | # Convert binary to numpy array 51 | img_np = np.fromstring(img_binary, np.uint8) 52 | 53 | # Convert numpy array to image 54 | img = cv2.imdecode(img_np, cv2.IMREAD_COLOR) -------------------------------------------------------------------------------- /DocImgTranslation/ocr_image_ch2en.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from structure import Structure 4 | from imageEnhancement import image_enhancement 5 | from SegmentImage import image_segmentation 6 | from paddleocr.ppocr.utils.logging import get_logger 7 | 8 | logger = get_logger() 9 | ap = argparse.ArgumentParser() 10 | ap.add_argument("-i", "--image", help="path to the input image") 11 | ap.add_argument("-d","--image_dir") 12 | ap.add_argument("-o","--output") 13 | args = vars(ap.parse_args()) 14 | 15 | if args['image'] != None and args['image_dir'] == None: 16 | logger.info('开始分割图像...') 17 | img = image_segmentation(args['image']) 18 | #img = cv2.imread(args['image'],0) 19 | logger.info('分割图像已完成!') 20 | logger.info('开始图像增强...') 21 | img = image_enhancement(img) 22 | logger.info('图像增强已完成!') 23 | logger.info('开始版面分析并转为Word...') 24 | Structure(img,args['image'],args['output'],"ch2en") 25 | logger.info('已处理完成!') 26 | elif args['image'] == None and args['image_dir'] != None: 27 | image_dir = args['image_dir'] 28 | if os.path.exists(image_dir): 29 | contents = os.listdir(image_dir) 30 | for i, image in enumerate(contents): 31 | if image[-3:] not in ['jpg', 'JPG', 'png', 'PNG', 'peg']: 32 | contents.pop(i) 33 | logger.info('检测到{}张图片!'.format(len(contents))) 34 | for i,image in enumerate(contents): 35 | logger.info('正在处理第{}张图片...'.format(i+1)) 36 | # image = cv2.imread(os.path.join(image_dir,i)) 37 | logger.info('开始分割图像...') 38 | img = image_segmentation(os.path.join(image_dir,image)) 39 | logger.info('分割图像已完成!') 40 | logger.info('开始图像增强...') 41 | img = image_enhancement(img) 42 | logger.info('图像增强已完成!') 43 | logger.info('开始版面分析并转为Word...') 44 | Structure(img, os.path.join(image_dir,image),args['output']) 45 | logger.info('第{}张图片已完成!'.format(i+1)) 46 | print() 47 | if i == len(contents)-1: 48 | logger.info('已全部处理完成!') 49 | break 50 | else: 51 | print('请输入正确的路径!!!') 52 | 53 | -------------------------------------------------------------------------------- /DocImgTranslation/ocr_image_en2ch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from structure import Structure 4 | from imageEnhancement import image_enhancement 5 | from SegmentImage import image_segmentation 6 | from paddleocr.ppocr.utils.logging import get_logger 7 | 8 | logger = get_logger() 9 | ap = argparse.ArgumentParser() 10 | ap.add_argument("-i", "--image", help="path to the input image") 11 | ap.add_argument("-d","--image_dir") 12 | ap.add_argument("-o","--output") 13 | args = vars(ap.parse_args()) 14 | 15 | if args['image'] != None and args['image_dir'] == None: 16 | logger.info('开始分割图像...') 17 | img = image_segmentation(args['image']) 18 | #img = cv2.imread(args['image'],0) 19 | logger.info('分割图像已完成!') 20 | logger.info('开始图像增强...') 21 | img = image_enhancement(img) 22 | logger.info('图像增强已完成!') 23 | logger.info('开始版面分析并转为Word...') 24 | Structure(img,args['image'],args['output'],"en2ch") 25 | logger.info('已处理完成!') 26 | elif args['image'] == None and args['image_dir'] != None: 27 | image_dir = args['image_dir'] 28 | if os.path.exists(image_dir): 29 | contents = os.listdir(image_dir) 30 | for i, image in enumerate(contents): 31 | if image[-3:] not in ['jpg', 'JPG', 'png', 'PNG', 'peg']: 32 | contents.pop(i) 33 | logger.info('检测到{}张图片!'.format(len(contents))) 34 | for i,image in enumerate(contents): 35 | logger.info('正在处理第{}张图片...'.format(i+1)) 36 | # image = cv2.imread(os.path.join(image_dir,i)) 37 | 
logger.info('开始分割图像...') 38 | img = image_segmentation(os.path.join(image_dir,image)) 39 | logger.info('分割图像已完成!') 40 | logger.info('开始图像增强...') 41 | img = image_enhancement(img) 42 | logger.info('图像增强已完成!') 43 | logger.info('开始版面分析并转为Word...') 44 | Structure(img, os.path.join(image_dir,image),args['output']) 45 | logger.info('第{}张图片已完成!'.format(i+1)) 46 | print() 47 | if i == len(contents)-1: 48 | logger.info('已全部处理完成!') 49 | break 50 | else: 51 | print('请输入正确的路径!!!') 52 | 53 | -------------------------------------------------------------------------------- /SpeechTranslation/speech_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python 2 | import os 3 | import shutil 4 | from paddlespeech.cli.st.infer import STExecutor 5 | from pydub import AudioSegment 6 | from pydub.utils import make_chunks 7 | import configparser 8 | config = configparser.ConfigParser() 9 | config.read("config.conf") 10 | device = config.get("devices", "device") 11 | # 输入音频文件的路径(例如 MP3、FLAC、OGG 等) 12 | 13 | def translate(input_audio_file): 14 | # 指定输出的 WAV 文件路径 15 | output_wav_file = "SpeechTranslation/output_audio.wav" 16 | 17 | # 使用 AudioSegment 加载音频文件 18 | audio = AudioSegment.from_file(input_audio_file) 19 | audio_length_seconds = len(audio) / 1000.0 20 | # 将音频文件保存为 WAV 格式 21 | audio.export(output_wav_file, format="wav") 22 | st = STExecutor() 23 | text='' 24 | if audio_length_seconds > 50: 25 | # 切分保存的 WAV 文件为多个50秒的片段 26 | # 定义切分的时长(单位:毫秒) 27 | chunk_length_ms = 50000 # 50秒 28 | 29 | # 使用 AudioSegment 再次加载 WAV 文件 30 | wav_audio = AudioSegment.from_file(output_wav_file) 31 | 32 | # 切分音频文件成多个片段 33 | chunks = make_chunks(wav_audio, chunk_length_ms) 34 | 35 | # 确保输出文件夹存在 36 | 37 | audio_folder = "SpeechTranslation/output_chunks/" 38 | os.makedirs(audio_folder, exist_ok=True) 39 | 40 | # 保存切分后的音频片段 41 | for i, chunk in enumerate(chunks): 42 | chunk = chunk.set_frame_rate(16000) 43 | output_file = f"{audio_folder}{i + 1}.wav" 44 | chunk.export(output_file, format="wav") 45 | 46 | print("音频文件已成功转换为 WAV 格式,并切分为多个50秒的片段,保存在:", audio_folder) 47 | 48 | audio_files = [os.path.join(audio_folder, file) for file in os.listdir(audio_folder) if file.endswith(".wav")] 49 | # 对音频文件列表按文件名排序 50 | audio_files.sort() 51 | for audio in audio_files: 52 | result = st(audio_file=audio,device=device) 53 | text = result+text 54 | else: 55 | text = st(audio_file=output_wav_file, device=device) 56 | shutil.rmtree("exp") 57 | os.remove("SpeechTranslation/audio_recording.wav") 58 | os.remove("SpeechTranslation/output_audio.wav") 59 | shutil.rmtree("SpeechTranslation/output_chunks") 60 | return text 61 | -------------------------------------------------------------------------------- /SpeechTranslation/speech_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | import os 3 | import transformers 4 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 5 | from paddlespeech.cli.asr.infer import ASRExecutor 6 | from paddlespeech.cli.text.infer import TextExecutor 7 | from pydub import AudioSegment 8 | from pydub.utils import make_chunks 9 | import configparser 10 | config = configparser.ConfigParser() 11 | config.read("config.conf") 12 | device = config.get("devices", "device") 13 | # 输入音频文件的路径(例如 MP3、FLAC、OGG 等) 14 | 15 | def translate(input_audio_file): 16 | # 指定输出的 WAV 文件路径 17 | output_wav_file = "SpeechTranslation/output_audio.wav" 18 | # 确保输出文件夹存在 19 | 20 | audio_folder = "SpeechTranslation/output_chunks/" 21 | os.makedirs(audio_folder, exist_ok=True) 22 | # 使用 AudioSegment 加载音频文件 23 | audio = AudioSegment.from_file(input_audio_file) 24 | audio_length_seconds = len(audio) / 1000.0 25 | # 将音频文件保存为 WAV 格式 26 | audio.export(output_wav_file, format="wav") 27 | asr = ASRExecutor() 28 | text='' 29 | if audio_length_seconds > 50: 30 | # 切分保存的 WAV 文件为多个50秒的片段 31 | # 定义切分的时长(单位:毫秒) 32 | chunk_length_ms = 50000 # 50秒 33 | 34 | # 使用 AudioSegment 再次加载 WAV 文件 35 | wav_audio = AudioSegment.from_file(output_wav_file) 36 | 37 | # 切分音频文件成多个片段 38 | chunks = make_chunks(wav_audio, chunk_length_ms) 39 | 40 | 41 | # 保存切分后的音频片段 42 | for i, chunk in enumerate(chunks): 43 | chunk = chunk.set_frame_rate(16000) 44 | output_file = f"{audio_folder}{i + 1}.wav" 45 | chunk.export(output_file, format="wav") 46 | 47 | print("音频文件已成功转换为 WAV 格式,并切分为多个50秒的片段,保存在:", audio_folder) 48 | 49 | audio_files = [os.path.join(audio_folder, file) for file in os.listdir(audio_folder) if file.endswith(".wav")] 50 | # 对音频文件列表按文件名排序 51 | audio_files.sort() 52 | for audio in audio_files: 53 | result = asr(audio_file=audio,device=device) 54 | text = result+text 55 | else: 56 | text = asr(audio_file=output_wav_file, device=device) 57 | # text_punc = TextExecutor() 58 | # translate_text = text_punc(text=text) 59 | model_path = 'models/translate/zh-en' 60 | tokenizer = AutoTokenizer.from_pretrained(model_path) 61 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 62 | pipeline = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 63 | translate_text = pipeline(text)[0]['translation_text'] 64 | # shutil.rmtree("exp") 65 | os.remove("SpeechTranslation/audio_recording.wav") 66 | os.remove("SpeechTranslation/output_audio.wav") 67 | return translate_text 68 | 69 |
-------------------------------------------------------------------------------- /FileTranslation/Excel/excel_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import pandas as pd 4 | from bs4 import BeautifulSoup 5 | import transformers 6 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 7 | def table_translate(html_table,translate_model): 8 | # 解析 HTML 9 | soup = BeautifulSoup(html_table, 'html.parser') 10 | # 提取并翻译表格的第一行(表头) 11 | original_header_row = soup.find('tr') 12 | translated_header_row = '<tr>' 13 | for cell in original_header_row.find_all('td'): 14 | original_text = cell.text.strip() 15 | translated_text = translate_model(original_text)[0]['translation_text'] 16 | translated_cell = f'<td>{translated_text}</td>' 17 | translated_header_row += translated_cell 18 | translated_header_row += '</tr>' 19 | 20 | # 提取并翻译表格的内容行 21 | translated_table = soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 22 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 23 | for row in soup.find_all('tr')[1:]: 24 | translated_row = '<tr>' # 创建新的行 25 | for cell in row.find_all('td'): 26 | original_text = cell.text.strip() # 提取原始文本 27 | translated_text = translate_model(original_text)[0]['translation_text'] 28 | translated_cell = f'<td>{translated_text}</td>' # 创建新的单元格 29 | translated_row += translated_cell # 将单元格添加到行中 30 | translated_row += '</tr>' 31 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 32 | 33 | return translated_table 34 | if __name__ == '__main__': 35 | savePath = sys.argv[2] 36 | # 1、excel地址 37 | excelPath = sys.argv[1] 38 | model_path = 'models/translate/zh-en' 39 | tokenizer = AutoTokenizer.from_pretrained(model_path) 40 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 41 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 42 | # 读取 Excel 文件 43 | df = pd.read_excel(excelPath, sheet_name=0, header=None) 44 | html = df.to_html(index=False) 45 | html_string = table_translate(html,translate_model) 46 | # 创建 HTML 模板 47 | html_template = f""" 48 | <html> 49 | <head> 50 | <title>Table HTML Page</title> 51 | </head> 52 | <body> 53 | {html_string} 54 | </body> 55 | </html> 56 | """ 57 | 58 | with open("FileTranslation/Excel/tmp.html", "w") as file: 59 | file.write(html_template) 60 | # 使用 pandas 读取表格 61 | df = pd.read_html("FileTranslation/Excel/tmp.html", encoding='utf-8')[0] 62 | print(df) 63 | # 将数据框保存为 Excel 文件 64 | excel_name = os.path.basename(excelPath) 65 | df.to_excel(os.path.join(savePath, excel_name), index=False, header=None)
-------------------------------------------------------------------------------- /DocImgTranslation/SegmentImage.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from FastSAM import Inference 4 | 5 | 6 | def order_points(pts): 7 | # 一共4个坐标点 8 | rect = np.zeros((4, 2), dtype="float32") 9 | # 按顺序找到对应坐标0123分别是 左上,右上,右下,左下 10 | # 计算左上,右下 11 | s = pts.sum(axis = 1) 12 | rect[0] = pts[np.argmin(s)] 13 | rect[2] = pts[np.argmax(s)] 14 | # 计算右上和左下 15 | diff = np.diff(pts, axis = 1) 16 | rect[1] = pts[np.argmin(diff)] 17 | rect[3] = pts[np.argmax(diff)] 18 | return rect 19 | 20 | 21 | def four_point_transform(image, pts): 22 | # 获取输入坐标点 23 | rect = order_points(pts) 24 | (tl, tr, br, bl) = rect 25 | 26 | # 计算输入的w和h值 27 | widthA = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) 28 | widthB = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) 29 | maxWidth = max(int(widthA), int(widthB)) 30 | 31 | heightA = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) 32 | heightB = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) 33 | maxHeight = max(int(heightA), int(heightB)) 34 | 35 | # 变换后对应坐标位置 36 | dst = np.array([ 37 | [0, 0], 38 | [maxWidth - 1, 0], 39 | [maxWidth - 1, maxHeight - 1], 40 | [0, maxHeight - 1]], dtype="float32") 41 | 42 | # 计算变换矩阵 43 | M = cv2.getPerspectiveTransform(rect, dst) 44 | warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight)) 45 | # 返回变换后结果 46 | return warped 47 | 48 | 49 | def image_segmentation(img_path): 50 | img = cv2.imread(img_path) 51 | ann = Inference.main(img_path) 52 | ann = ann.numpy() 53 | ann = ann.astype('uint8') 54 | # ann = np.reshape(ann, (ann.shape[1],ann.shape[2])) 55 | ann[ann == 1] = 255 56 | contours = [] 57 | # print(ann.shape) 58 | for i in ann: 59 | # print(i.shape) 60 | kernel = np.ones((20, 20), np.uint8) 61 | i = cv2.erode(i, kernel, iterations=1) 62 | contour, hierarchy = cv2.findContours(i, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) 63 |
contours.append(contour[0]) 64 | docCnt = None 65 | if len(contours) > 0: 66 | # 根据轮廓大小进行排序 67 | cnts = sorted(contours, key=cv2.contourArea, reverse=True) 68 | # 遍历每一个轮廓 69 | for c in cnts: 70 | # 近似 71 | peri = cv2.arcLength(c, True) 72 | approx = cv2.approxPolyDP(c, 0.02 * peri, True) 73 | # 准备做透视变换 74 | if len(approx) == 4: 75 | if approx[0][0][0] == 0 and approx[0][0][1] == 0: 76 | continue 77 | docCnt = approx 78 | break 79 | 80 | warped = four_point_transform(img, docCnt.reshape(4, 2)) 81 | return warped 82 | if __name__ == '__main__': 83 | image_segmentation("./16.jpg") 84 | -------------------------------------------------------------------------------- /HyperTranslation/hyper_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import tkinter as tk 3 | 4 | import pyperclip 5 | import transformers 6 | from pynput import keyboard 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | 9 | config = configparser.ConfigParser() 10 | config.read("config.conf") 11 | class TranslatorApp(tk.Tk): 12 | def __init__(self): 13 | super().__init__() 14 | self.title("全能翻译") 15 | self.geometry("300x100") 16 | 17 | self.init_ui() 18 | self.withdraw() # 初始时隐藏窗口 19 | 20 | def init_ui(self): 21 | self.protocol("WM_DELETE_WINDOW", self.on_closing) # 拦截窗口关闭事件 22 | 23 | # self.label = tk.Label(self, text="使用快捷键 Ctrl+Shift+F 进行截图并识别文字:") 24 | # self.label.pack(pady=10) 25 | 26 | self.frame = tk.Frame(self) 27 | self.frame.pack() 28 | 29 | self.text_frame = tk.Frame(self.frame) 30 | self.text_frame.pack(padx=10, pady=10) 31 | 32 | self.text_label = tk.Label(self.text_frame, text="翻译结果:") 33 | self.text_label.pack() 34 | 35 | self.text_box = tk.Text(self.text_frame, height=5, width=40) 36 | self.text_box.pack(fill=tk.BOTH, expand=True) 37 | translation_shortcut = config.get("shortcuts", "translation") 38 | self.listener = keyboard.GlobalHotKeys({ 39 | translation_shortcut: self.text_ocr 40 | }) 41 | self.listener.start() 42 | # 43 | # def get_selected_text(self): 44 | # try: 45 | # selected_text = subprocess.check_output(['pbpaste'], universal_newlines=True) 46 | # return selected_text 47 | # except subprocess.CalledProcessError: 48 | # return "" 49 | def text_ocr(self): 50 | # keyboard1 = Controller() 51 | # 模拟按下 Ctrl + C 组合键 52 | # with keyboard1.pressed(Key.cmd): 53 | # keyboard1.press('c') 54 | # keyboard1.release('c') 55 | # time.sleep(0.5) 56 | # text = self.get_selected_text() 57 | text = pyperclip.paste() 58 | model_path = 'models/translate/zh-en' 59 | tokenizer = AutoTokenizer.from_pretrained(model_path) 60 | translate_model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 61 | pipeline = transformers.pipeline("translation", model=translate_model, tokenizer=tokenizer) 62 | translate_text = pipeline(text) 63 | self.display_translation(translate_text[0]['translation_text']) 64 | print(translate_text[0]['translation_text']) 65 | self.deiconify() # 在截图完成后显示窗口 66 | def display_translation(self, text): 67 | self.text_box.delete("1.0", tk.END) # 清空文本框 68 | self.text_box.insert(tk.END, text) 69 | def on_closing(self): 70 | self.withdraw() # 仅隐藏窗口,不退出程序 71 | if __name__ == "__main__": 72 | app = TranslatorApp() 73 | app.mainloop() -------------------------------------------------------------------------------- /FileTranslation/Excel/excel_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import sys 4 | 
sys.path.append(os.getcwd()) 5 | import my_utils 6 | from bs4 import BeautifulSoup 7 | import transformers 8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 9 | def table_translate(html_table,translate_model): 10 | # 解析 HTML 11 | soup = BeautifulSoup(html_table, 'html.parser') 12 | # 提取并翻译表格的第一行(表头) 13 | original_header_row = soup.find('tr') 14 | translated_header_row = '<tr>' 15 | for cell in original_header_row.find_all('td'): 16 | original_text = cell.text.strip() 17 | translated_text = translate_model(original_text)[0]['translation_text'] 18 | translated_text = my_utils.do_sentence(translated_text) 19 | translated_cell = f'<td>{translated_text}</td>' 20 | translated_header_row += translated_cell 21 | translated_header_row += '</tr>' 22 | 23 | # 提取并翻译表格的内容行 24 | translated_table = soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 25 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 26 | for row in soup.find_all('tr')[1:]: 27 | translated_row = '<tr>' # 创建新的行 28 | for cell in row.find_all('td'): 29 | original_text = cell.text.strip() # 提取原始文本 30 | translated_text = translate_model(original_text)[0]['translation_text'] 31 | translated_text = my_utils.do_sentence(translated_text) 32 | translated_cell = f'<td>{translated_text}</td>' # 创建新的单元格 33 | translated_row += translated_cell # 将单元格添加到行中 34 | translated_row += '</tr>' 35 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 36 | 37 | return translated_table 38 | if __name__ == '__main__': 39 | savePath = sys.argv[2] 40 | # 1、excel地址 41 | excelPath = sys.argv[1] 42 | model_path = 'models/translate/en-zh' 43 | tokenizer = AutoTokenizer.from_pretrained(model_path) 44 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 45 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 46 | # 读取 Excel 文件 47 | df = pd.read_excel(excelPath, sheet_name=0, header=None) 48 | html = df.to_html(index=False) 49 | html_string = table_translate(html,translate_model) 50 | # 创建 HTML 模板 51 | html_template = f""" 52 | <html> 53 | <head> 54 | <title>Table HTML Page</title> 55 | </head> 56 | <body> 57 | {html_string} 58 | </body> 59 | </html> 60 | """ 61 | 62 | with open("FileTranslation/Excel/tmp.html", "w") as file: 63 | file.write(html_template) 64 | # 使用 pandas 读取表格 65 | df = pd.read_html("FileTranslation/Excel/tmp.html", encoding='utf-8')[0] 66 | print(df) 67 | # 将数据框保存为 Excel 文件 68 | excel_name = os.path.basename(excelPath) 69 | df.to_excel(os.path.join(savePath, excel_name), index=False, header=None)
-------------------------------------------------------------------------------- /HyperTranslation/hyper_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import tkinter as tk 3 | import sys 4 | import os 5 | sys.path.append(os.getcwd()) 6 | import my_utils 7 | import pyperclip 8 | import transformers 9 | from pynput import keyboard 10 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 11 | config = configparser.ConfigParser() 12 | config.read("config.conf") 13 | class TranslatorApp(tk.Tk): 14 | def __init__(self): 15 | super().__init__() 16 | self.title("全能翻译") 17 | self.geometry("300x100") 18 | 19 | self.init_ui() 20 | self.withdraw() # 初始时隐藏窗口 21 | 22 | def init_ui(self): 23 | self.protocol("WM_DELETE_WINDOW", self.on_closing) # 拦截窗口关闭事件 24 | 25 | # self.label = tk.Label(self, text="使用快捷键 Ctrl+Shift+F 进行截图并识别文字:") 26 | # self.label.pack(pady=10) 27 | 28 | self.frame = tk.Frame(self) 29 | self.frame.pack() 30 | 31 | self.text_frame =
tk.Frame(self.frame) 32 | self.text_frame.pack(padx=10, pady=10) 33 | 34 | self.text_label = tk.Label(self.text_frame, text="翻译结果:") 35 | self.text_label.pack() 36 | 37 | self.text_box = tk.Text(self.text_frame, height=5, width=40) 38 | self.text_box.pack(fill=tk.BOTH, expand=True) 39 | translation_shortcut = config.get("shortcuts", "translation") 40 | self.listener = keyboard.GlobalHotKeys({ 41 | translation_shortcut: self.text_ocr 42 | }) 43 | self.listener.start() 44 | # 45 | # def get_selected_text(self): 46 | # try: 47 | # selected_text = subprocess.check_output(['pbpaste'], universal_newlines=True) 48 | # return selected_text 49 | # except subprocess.CalledProcessError: 50 | # return "" 51 | def text_ocr(self): 52 | # keyboard1 = Controller() 53 | # 模拟按下 Ctrl + C 组合键 54 | # with keyboard1.pressed(Key.cmd): 55 | # keyboard1.press('c') 56 | # keyboard1.release('c') 57 | # time.sleep(0.5) 58 | # text = self.get_selected_text() 59 | text = pyperclip.paste() 60 | model_path = 'models/translate/en-zh' 61 | tokenizer = AutoTokenizer.from_pretrained(model_path) 62 | translate_model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 63 | pipeline = transformers.pipeline("translation", model=translate_model, tokenizer=tokenizer) 64 | translate_text = pipeline(text) 65 | translate_text = my_utils.do_sentence(translate_text[0]['translation_text']) 66 | self.display_translation(translate_text) 67 | print(translate_text) 68 | self.deiconify() # 在截图完成后显示窗口 69 | def display_translation(self, text): 70 | self.text_box.delete("1.0", tk.END) # 清空文本框 71 | self.text_box.insert(tk.END, text) 72 | def on_closing(self): 73 | self.withdraw() # 仅隐藏窗口,不退出程序 74 | if __name__ == "__main__": 75 | app = TranslatorApp() 76 | app.mainloop() 77 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | 6 | def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20): 7 | '''Adjust bounding boxes to stick to image border if they are within a certain threshold. 8 | Args: 9 | boxes: (n, 4) 10 | image_shape: (height, width) 11 | threshold: pixel threshold 12 | Returns: 13 | adjusted_boxes: adjusted bounding boxes 14 | ''' 15 | 16 | # Image dimensions 17 | h, w = image_shape 18 | 19 | # Adjust boxes 20 | boxes[:, 0] = torch.where(boxes[:, 0] < threshold, torch.tensor( 21 | 0, dtype=torch.float, device=boxes.device), boxes[:, 0]) # x1 22 | boxes[:, 1] = torch.where(boxes[:, 1] < threshold, torch.tensor( 23 | 0, dtype=torch.float, device=boxes.device), boxes[:, 1]) # y1 24 | boxes[:, 2] = torch.where(boxes[:, 2] > w - threshold, torch.tensor( 25 | w, dtype=torch.float, device=boxes.device), boxes[:, 2]) # x2 26 | boxes[:, 3] = torch.where(boxes[:, 3] > h - threshold, torch.tensor( 27 | h, dtype=torch.float, device=boxes.device), boxes[:, 3]) # y2 28 | 29 | return boxes 30 | 31 | 32 | 33 | def convert_box_xywh_to_xyxy(box): 34 | x1 = box[0] 35 | y1 = box[1] 36 | x2 = box[0] + box[2] 37 | y2 = box[1] + box[3] 38 | return [x1, y1, x2, y2] 39 | 40 | 41 | def bbox_iou(box1, boxes, iou_thres=0.9, image_shape=(640, 640), raw_output=False): 42 | '''Compute the Intersection-Over-Union of a bounding box with respect to an array of other bounding boxes. 
43 | Args: 44 | box1: (4, ) 45 | boxes: (n, 4) 46 | Returns: 47 | high_iou_indices: Indices of boxes with IoU > thres 48 | ''' 49 | boxes = adjust_bboxes_to_image_border(boxes, image_shape) 50 | # obtain coordinates for intersections 51 | x1 = torch.max(box1[0], boxes[:, 0]) 52 | y1 = torch.max(box1[1], boxes[:, 1]) 53 | x2 = torch.min(box1[2], boxes[:, 2]) 54 | y2 = torch.min(box1[3], boxes[:, 3]) 55 | 56 | # compute the area of intersection 57 | intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0) 58 | 59 | # compute the area of both individual boxes 60 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 61 | box2_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 62 | 63 | # compute the area of union 64 | union = box1_area + box2_area - intersection 65 | 66 | # compute the IoU 67 | iou = intersection / union # Should be shape (n, ) 68 | if raw_output: 69 | if iou.numel() == 0: 70 | return 0 71 | return iou 72 | 73 | # get indices of boxes with IoU > thres 74 | high_iou_indices = torch.nonzero(iou > iou_thres).flatten() 75 | 76 | return high_iou_indices 77 | 78 | 79 | def image_to_np_ndarray(image): 80 | if type(image) is str: 81 | return np.array(Image.open(image)) 82 | elif issubclass(type(image), Image.Image): 83 | return np.array(image) 84 | elif type(image) is np.ndarray: 85 | return image 86 | return None 87 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/predict.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ultralytics.yolo.engine.results import Results 4 | from ultralytics.yolo.utils import DEFAULT_CFG, ops 5 | from ultralytics.yolo.v8.detect.predict import DetectionPredictor 6 | from .utils import bbox_iou 7 | 8 | class FastSAMPredictor(DetectionPredictor): 9 | 10 | def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None): 11 | super().__init__(cfg, overrides, _callbacks) 12 | self.args.task = 'segment' 13 | 14 | def postprocess(self, preds, img, orig_imgs): 15 | """TODO: filter by classes.""" 16 | p = ops.non_max_suppression(preds[0], 17 | self.args.conf, 18 | self.args.iou, 19 | agnostic=self.args.agnostic_nms, 20 | max_det=self.args.max_det, 21 | nc=len(self.model.names), 22 | classes=self.args.classes) 23 | 24 | results = [] 25 | if len(p) == 0 or len(p[0]) == 0: 26 | print("No object detected.") 27 | return results 28 | 29 | full_box = torch.zeros_like(p[0][0]) 30 | full_box[2], full_box[3], full_box[4], full_box[6:] = img.shape[3], img.shape[2], 1.0, 1.0 31 | full_box = full_box.view(1, -1) 32 | critical_iou_index = bbox_iou(full_box[0][:4], p[0][:, :4], iou_thres=0.9, image_shape=img.shape[2:]) 33 | if critical_iou_index.numel() != 0: 34 | full_box[0][4] = p[0][critical_iou_index][:,4] 35 | full_box[0][6:] = p[0][critical_iou_index][:,6:] 36 | p[0][critical_iou_index] = full_box 37 | 38 | proto = preds[1][-1] if len(preds[1]) == 3 else preds[1] # second output is len 3 if pt, but only 1 if exported 39 | for i, pred in enumerate(p): 40 | orig_img = orig_imgs[i] if isinstance(orig_imgs, list) else orig_imgs 41 | path = self.batch[0] 42 | img_path = path[i] if isinstance(path, list) else path 43 | if not len(pred): # save empty boxes 44 | results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6])) 45 | continue 46 | if self.args.retina_masks: 47 | if not isinstance(orig_imgs, torch.Tensor): 48 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], 
orig_img.shape) 49 | masks = ops.process_mask_native(proto[i], pred[:, 6:], pred[:, :4], orig_img.shape[:2]) # HWC 50 | else: 51 | masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True) # HWC 52 | if not isinstance(orig_imgs, torch.Tensor): 53 | pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape) 54 | results.append( 55 | Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred[:, :6], masks=masks)) 56 | return results 57 | -------------------------------------------------------------------------------- /my_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | from nltk import word_tokenize 3 | import numpy as np 4 | import jieba 5 | 6 | def process_word(text): 7 | 8 | if re.search(r'[\u4e00-\u9fff]', text): # 判断是否包含中文字符 9 | text = text.replace(' ', '') 10 | if len(text) > 1: 11 | result = re.sub(r'([\u4e00-\u9fff])\1+', r'\1', text) # 替换连续的相同中文字为第一个中文字 12 | return result 13 | else: 14 | return text 15 | else: 16 | result = re.sub(r'([a-zA-Z])\1{2,}', r'\1\1', text) # 替换连续的相同字母为第一个字母 17 | return result 18 | def remove_last_punctuation(text): 19 | # 匹配最后一个标点符号的正则表达式 20 | pattern = r'[,.?;]$' 21 | punctuation_list = [] 22 | cleaned_list = [] 23 | for sentence in text: 24 | match = re.search(pattern, sentence) 25 | if match: 26 | # 如果匹配到了标点符号,则去除标点符号并记录到结果列表中 27 | cleaned_sentence = re.sub(pattern, '', sentence) 28 | punctuation = match.group() 29 | else: 30 | # 如果没有匹配到标点符号,则保持原始文本并将结果列表中添加空字符串 31 | cleaned_sentence = sentence 32 | punctuation = '' 33 | punctuation_list.append(punctuation) 34 | cleaned_list.append(cleaned_sentence) 35 | return punctuation_list, cleaned_list 36 | 37 | def cosine_similarity(u, v): 38 | dot_product = np.dot(u, v) 39 | norm_u = np.linalg.norm(u) 40 | norm_v = np.linalg.norm(v) 41 | similarity = dot_product / (norm_u * norm_v) 42 | return similarity 43 | def do_sentence(translated_text): 44 | if translated_text == '': 45 | return '' 46 | translated_text = process_word(translated_text) 47 | print(translated_text) 48 | tokenized_sentence = list(jieba.cut(translated_text)) 49 | # 保留第一个词语 50 | words = [tokenized_sentence[0]] # 保留第一个词语 51 | for i in range(1, len(tokenized_sentence)): 52 | if tokenized_sentence[i] != tokenized_sentence[i - 1]: 53 | words.append(tokenized_sentence[i]) 54 | words = [process_word(word) for word in words] 55 | print(words) 56 | filtered_sentence = "".join(words) 57 | sentences = [] 58 | j = '' 59 | for i in words: 60 | j = j + i 61 | if i[-1] in [',', '?', '!', ')', ';', ' ', '...']: 62 | sentences.append(j) 63 | j = '' 64 | sentences.append(j) 65 | if len(sentences) <= 1: 66 | return filtered_sentence 67 | # 分词和构建词袋表示 68 | punctuation_list, sentences = remove_last_punctuation(sentences) 69 | print(punctuation_list) 70 | print(sentences) 71 | tokenized_sentences = [list(jieba.cut(sentence)) for sentence in sentences] 72 | vocabulary = set() 73 | for sentence in tokenized_sentences: 74 | vocabulary.update(sentence) 75 | vocabulary = list(vocabulary) 76 | word_to_index = {word: i for i, word in enumerate(vocabulary)} 77 | bag_of_words = np.zeros((len(sentences), len(vocabulary))) 78 | for i, sentence in enumerate(tokenized_sentences): 79 | for word in sentence: 80 | word_index = word_to_index[word] 81 | bag_of_words[i, word_index] += 1 82 | 83 | # 计算词频分布 84 | word_frequency = bag_of_words / np.linalg.norm(bag_of_words, axis=1, keepdims=True) 85 | 86 | # 保留第一个句子 87 | filtered_sentences = 
[sentences[0]+punctuation_list[0]] # 保留第一个句子 88 | for i in range(1, len(sentences)): 89 | similarity = cosine_similarity(word_frequency[i-1], word_frequency[i]) 90 | print(similarity) 91 | if similarity < 0.75: 92 | filtered_sentences.append(sentences[i]+punctuation_list[i]) 93 | print(punctuation_list[i]) 94 | 95 | # 输出结果 96 | 97 | sentence = ''.join(filtered_sentences) 98 | return sentence 99 | if __name__ == '__main__': 100 | 101 | print(do_sentence("")) -------------------------------------------------------------------------------- /ScreenshotTranslation/screenshot_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tkinter as tk 4 | import configparser 5 | from PIL import Image, ImageTk 6 | from pynput import keyboard 7 | from paddleocr import PaddleOCR 8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 9 | import transformers 10 | import platform 11 | 12 | config = configparser.ConfigParser() 13 | config.read("config.conf") 14 | 15 | 16 | class ScreenShotTool(tk.Tk): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | self.title("全能翻译") 21 | self.geometry("700x200") 22 | 23 | self.init_ui() 24 | self.withdraw() # 初始时隐藏窗口 25 | 26 | def init_ui(self): 27 | self.protocol("WM_DELETE_WINDOW", self.on_closing) # 拦截窗口关闭事件 28 | 29 | self.frame = tk.Frame(self) 30 | self.frame.pack() 31 | 32 | self.image_frame = tk.Frame(self.frame) 33 | self.image_frame.pack(side=tk.LEFT, padx=10, pady=10) 34 | 35 | self.text_frame = tk.Frame(self.frame) 36 | self.text_frame.pack(side=tk.RIGHT, padx=10, pady=10) 37 | self.image_label1 = tk.Label(self.image_frame, text="截图:") 38 | self.image_label1.pack() 39 | self.image_label = tk.Label(self.image_frame) 40 | self.image_label.pack() 41 | 42 | self.text_label = tk.Label(self.text_frame, text="翻译结果:") 43 | self.text_label.pack() 44 | 45 | self.text_box = tk.Text(self.text_frame, height=10, width=40) 46 | self.text_box.pack(fill=tk.BOTH, expand=True) 47 | screenshot_shortcut = config.get("shortcuts", "screenshot") 48 | self.listener = keyboard.GlobalHotKeys({ 49 | screenshot_shortcut: self.capture_and_ocr 50 | }) 51 | self.listener.start() 52 | 53 | def capture_and_ocr(self): 54 | ocr = PaddleOCR(use_angle_cls=True, lang='ch',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer') 55 | model_path = 'models/translate/zh-en' 56 | tokenizer = AutoTokenizer.from_pretrained(model_path) 57 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 58 | pipeline = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 59 | screenshot_path = "temp_screenshot.png" 60 | 61 | system = platform.system() 62 | if system == "Windows": 63 | from PIL import ImageGrab 64 | subprocess.run(["snippingtool", "/clip"]) 65 | clipboard_image = ImageGrab.grabclipboard() 66 | 67 | if clipboard_image is not None: 68 | # 保存图像到指定路径 69 | clipboard_image.save(screenshot_path) 70 | elif system == "Linux": 71 | subprocess.run(["gnome-screenshot", "-a", "-f", screenshot_path]) 72 | elif system == "Darwin": 73 | subprocess.run(["screencapture", "-i", screenshot_path]) 74 | else: 75 | print("Unsupported operating system.") 76 | return 77 | 78 | image = Image.open(screenshot_path) 79 | image.thumbnail((300, 300)) 80 | photo = ImageTk.PhotoImage(image) 81 | self.image_label.config(image=photo) 82 | self.image_label.image = photo 83 | 84 | 
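        # 下面逐行处理 OCR 结果:PaddleOCR 返回的每个 line 形如 [四点坐标, (文本, 置信度)],
        # 因此取 line[1][0] 作为原文,逐行送入翻译 pipeline 后再拼接到文本框中显示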
recognized_text = [] 85 | results = ocr.ocr(screenshot_path, cls=True) 86 | for idx in range(len(results)): 87 | res = results[idx] 88 | for line in res: 89 | translate_text = pipeline(line[1][0])[0]['translation_text'] 90 | recognized_text.append(translate_text) 91 | translate_text = "\n".join(recognized_text) 92 | print(translate_text) 93 | self.text_box.delete("1.0", tk.END) # 清空文本框 94 | self.text_box.insert(tk.END, translate_text) 95 | 96 | os.remove(screenshot_path) # 删除临时截图文件 97 | 98 | self.deiconify() # 在截图完成后显示窗口 99 | 100 | def on_closing(self): 101 | self.withdraw() # 仅隐藏窗口,不退出程序 102 | 103 | 104 | if __name__ == "__main__": 105 | app = ScreenShotTool() 106 | app.mainloop() 107 | -------------------------------------------------------------------------------- /ScreenshotTranslation/screenshot_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import tkinter as tk 4 | import configparser 5 | from PIL import Image, ImageTk 6 | from pynput import keyboard 7 | from paddleocr import PaddleOCR 8 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 9 | import transformers 10 | import platform 11 | import sys 12 | sys.path.append(os.getcwd()) 13 | import my_utils 14 | 15 | config = configparser.ConfigParser() 16 | config.read("config.conf") 17 | 18 | 19 | class ScreenShotTool(tk.Tk): 20 | def __init__(self): 21 | super().__init__() 22 | 23 | self.title("全能翻译") 24 | self.geometry("700x200") 25 | 26 | self.init_ui() 27 | self.withdraw() # 初始时隐藏窗口 28 | 29 | def init_ui(self): 30 | self.protocol("WM_DELETE_WINDOW", self.on_closing) # 拦截窗口关闭事件 31 | 32 | self.frame = tk.Frame(self) 33 | self.frame.pack() 34 | 35 | self.image_frame = tk.Frame(self.frame) 36 | self.image_frame.pack(side=tk.LEFT, padx=10, pady=10) 37 | 38 | self.text_frame = tk.Frame(self.frame) 39 | self.text_frame.pack(side=tk.RIGHT, padx=10, pady=10) 40 | self.image_label1 = tk.Label(self.image_frame, text="截图:") 41 | self.image_label1.pack() 42 | self.image_label = tk.Label(self.image_frame) 43 | self.image_label.pack() 44 | 45 | self.text_label = tk.Label(self.text_frame, text="翻译结果:") 46 | self.text_label.pack() 47 | 48 | self.text_box = tk.Text(self.text_frame, height=10, width=40) 49 | self.text_box.pack(fill=tk.BOTH, expand=True) 50 | screenshot_shortcut = config.get("shortcuts", "screenshot") 51 | self.listener = keyboard.GlobalHotKeys({ 52 | screenshot_shortcut: self.capture_and_ocr 53 | }) 54 | self.listener.start() 55 | 56 | def capture_and_ocr(self): 57 | ocr = PaddleOCR(use_angle_cls=True, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer') 58 | model_path = 'models/translate/en-zh' 59 | tokenizer = AutoTokenizer.from_pretrained(model_path) 60 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 61 | pipeline = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 62 | screenshot_path = "temp_screenshot.png" 63 | 64 | system = platform.system() 65 | 66 | if system == "Windows": 67 | from PIL import ImageGrab 68 | subprocess.run(["snippingtool", "/clip"]) 69 | clipboard_image = ImageGrab.grabclipboard() 70 | 71 | if clipboard_image is not None: 72 | # 保存图像到指定路径 73 | clipboard_image.save(screenshot_path) 74 | elif system == "Linux": 75 | subprocess.run(["gnome-screenshot", "-a", "-f", screenshot_path]) 76 | elif system == "Darwin": 77 | 
subprocess.run(["screencapture", "-i", screenshot_path]) 78 | else: 79 | print("Unsupported operating system.") 80 | return 81 | 82 | image = Image.open(screenshot_path) 83 | image.thumbnail((300, 300)) 84 | photo = ImageTk.PhotoImage(image) 85 | self.image_label.config(image=photo) 86 | self.image_label.image = photo 87 | 88 | recognized_text = [] 89 | results = ocr.ocr(screenshot_path, cls=True) 90 | for idx in range(len(results)): 91 | res = results[idx] 92 | for line in res: 93 | translate_text = pipeline(line[1][0])[0]['translation_text'] 94 | translate_text = my_utils.do_sentence(translate_text) 95 | recognized_text.append(translate_text) 96 | translate_text = "\n".join(recognized_text) 97 | print(translate_text) 98 | self.text_box.delete("1.0", tk.END) # 清空文本框 99 | self.text_box.insert(tk.END, translate_text) 100 | 101 | os.remove(screenshot_path) # 删除临时截图文件 102 | 103 | self.deiconify() # 在截图完成后显示窗口 104 | 105 | def on_closing(self): 106 | self.withdraw() # 仅隐藏窗口,不退出程序 107 | 108 | 109 | if __name__ == "__main__": 110 | app = ScreenShotTool() 111 | app.mainloop() 112 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/model.py: -------------------------------------------------------------------------------- 1 | # Ultralytics YOLO 🚀, AGPL-3.0 license 2 | """ 3 | FastSAM model interface. 4 | 5 | Usage - Predict: 6 | from ultralytics import FastSAM 7 | 8 | model = FastSAM('last.pt') 9 | results = model.predict('ultralytics/assets/bus.jpg') 10 | """ 11 | 12 | from ultralytics.yolo.cfg import get_cfg 13 | from ultralytics.yolo.engine.exporter import Exporter 14 | from ultralytics.yolo.engine.model import YOLO 15 | from ultralytics.yolo.utils import DEFAULT_CFG, LOGGER, ROOT, is_git_dir 16 | from ultralytics.yolo.utils.checks import check_imgsz 17 | 18 | from ultralytics.yolo.utils.torch_utils import model_info, smart_inference_mode 19 | from .predict import FastSAMPredictor 20 | 21 | 22 | class FastSAM(YOLO): 23 | 24 | @smart_inference_mode() 25 | def predict(self, source=None, stream=False, **kwargs): 26 | """ 27 | Perform prediction using the YOLO model. 28 | 29 | Args: 30 | source (str | int | PIL | np.ndarray): The source of the image to make predictions on. 31 | Accepts all source types accepted by the YOLO model. 32 | stream (bool): Whether to stream the predictions or not. Defaults to False. 33 | **kwargs : Additional keyword arguments passed to the predictor. 34 | Check the 'configuration' section in the documentation for all available options. 35 | 36 | Returns: 37 | (List[ultralytics.yolo.engine.results.Results]): The prediction results. 38 | """ 39 | if source is None: 40 | source = ROOT / 'assets' if is_git_dir() else 'https://ultralytics.com/images/bus.jpg' 41 | LOGGER.warning(f"WARNING ⚠️ 'source' is missing. 
Using 'source={source}'.") 42 | overrides = self.overrides.copy() 43 | overrides['conf'] = 0.25 44 | overrides.update(kwargs) # prefer kwargs 45 | overrides['mode'] = kwargs.get('mode', 'predict') 46 | assert overrides['mode'] in ['track', 'predict'] 47 | overrides['save'] = kwargs.get('save', False) # do not save by default if called in Python 48 | self.predictor = FastSAMPredictor(overrides=overrides) 49 | self.predictor.setup_model(model=self.model, verbose=False) 50 | try: 51 | return self.predictor(source, stream=stream) 52 | except Exception as e: 53 | return None 54 | 55 | def train(self, **kwargs): 56 | """Function trains models but raises an error as FastSAM models do not support training.""" 57 | raise NotImplementedError("Currently, the training codes are on the way.") 58 | 59 | def val(self, **kwargs): 60 | """Run validation given dataset.""" 61 | overrides = dict(task='segment', mode='val') 62 | overrides.update(kwargs) # prefer kwargs 63 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 64 | args.imgsz = check_imgsz(args.imgsz, max_dim=1) 65 | validator = FastSAM(args=args) 66 | validator(model=self.model) 67 | self.metrics = validator.metrics 68 | return validator.metrics 69 | 70 | @smart_inference_mode() 71 | def export(self, **kwargs): 72 | """ 73 | Export model. 74 | 75 | Args: 76 | **kwargs : Any other args accepted by the predictors. To see all args check 'configuration' section in docs 77 | """ 78 | overrides = dict(task='detect') 79 | overrides.update(kwargs) 80 | overrides['mode'] = 'export' 81 | args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides) 82 | args.task = self.task 83 | if args.imgsz == DEFAULT_CFG.imgsz: 84 | args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed 85 | if args.batch == DEFAULT_CFG.batch: 86 | args.batch = 1 # default to 1 if not modified 87 | return Exporter(overrides=args)(model=self.model) 88 | 89 | def info(self, detailed=False, verbose=True): 90 | """ 91 | Logs model info. 92 | 93 | Args: 94 | detailed (bool): Show detailed information about model. 95 | verbose (bool): Controls verbosity. 96 | """ 97 | return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640) 98 | 99 | def __call__(self, source=None, stream=False, **kwargs): 100 | """Calls the 'predict' function with given arguments to perform object detection.""" 101 | return self.predict(source, stream, **kwargs) 102 | 103 | def __getattr__(self, attr): 104 | """Raises error if object has no requested attribute.""" 105 | name = self.__class__.__name__ 106 | raise AttributeError(f"'{name}' object has no attribute '{attr}'. 
See valid attributes below.\n{self.__doc__}") 107 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/decoder.py: -------------------------------------------------------------------------------- 1 | from .model import FastSAM 2 | import numpy as np 3 | from PIL import Image 4 | import clip 5 | from typing import Optional, List, Tuple, Union 6 | 7 | 8 | class FastSAMDecoder: 9 | def __init__( 10 | self, 11 | model: FastSAM, 12 | device: str='cpu', 13 | conf: float=0.4, 14 | iou: float=0.9, 15 | imgsz: int=1024, 16 | retina_masks: bool=True, 17 | ): 18 | self.model = model 19 | self.device = device 20 | self.retina_masks = retina_masks 21 | self.imgsz = imgsz 22 | self.conf = conf 23 | self.iou = iou 24 | self.image = None 25 | self.image_embedding = None 26 | 27 | def run_encoder(self, image): 28 | if isinstance(image,str): 29 | image = np.array(Image.open(image)) 30 | self.image = image 31 | image_embedding = self.model( 32 | self.image, 33 | device=self.device, 34 | retina_masks=self.retina_masks, 35 | imgsz=self.imgsz, 36 | conf=self.conf, 37 | iou=self.iou 38 | ) 39 | return image_embedding[0].numpy() 40 | 41 | def run_decoder( 42 | self, 43 | image_embedding, 44 | point_prompt: Optional[np.ndarray]=None, 45 | point_label: Optional[np.ndarray]=None, 46 | box_prompt: Optional[np.ndarray]=None, 47 | text_prompt: Optional[str]=None, 48 | )->np.ndarray: 49 | self.image_embedding = image_embedding 50 | if point_prompt is not None: 51 | ann = self.point_prompt(points=point_prompt, pointlabel=point_label) 52 | return ann 53 | elif box_prompt is not None: 54 | ann = self.box_prompt(bbox=box_prompt) 55 | return ann 56 | elif text_prompt is not None: 57 | ann = self.text_prompt(text=text_prompt) 58 | return ann 59 | else: 60 | return None 61 | 62 | def box_prompt(self, bbox): 63 | assert (bbox[2] != 0 and bbox[3] != 0) 64 | masks = self.image_embedding.masks.data 65 | target_height = self.image.shape[0] 66 | target_width = self.image.shape[1] 67 | h = masks.shape[1] 68 | w = masks.shape[2] 69 | if h != target_height or w != target_width: 70 | bbox = [ 71 | int(bbox[0] * w / target_width), 72 | int(bbox[1] * h / target_height), 73 | int(bbox[2] * w / target_width), 74 | int(bbox[3] * h / target_height), ] 75 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 76 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 77 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 78 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 79 | 80 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 81 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 82 | 83 | masks_area = np.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], axis=(1, 2)) 84 | orig_masks_area = np.sum(masks, axis=(1, 2)) 85 | 86 | union = bbox_area + orig_masks_area - masks_area 87 | IoUs = masks_area / union 88 | max_iou_index = np.argmax(IoUs) 89 | 90 | return np.array([masks[max_iou_index].cpu().numpy()]) 91 | 92 | def point_prompt(self, points, pointlabel): # numpy 93 | 94 | masks = self._format_results(self.image_embedding[0], 0) 95 | target_height = self.image.shape[0] 96 | target_width = self.image.shape[1] 97 | h = masks[0]['segmentation'].shape[0] 98 | w = masks[0]['segmentation'].shape[1] 99 | if h != target_height or w != target_width: 100 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] 101 | onemask = np.zeros((h, w)) 102 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 103 | for i, 
annotation in enumerate(masks): 104 | if type(annotation) == dict: 105 | mask = annotation['segmentation'] 106 | else: 107 | mask = annotation 108 | for i, point in enumerate(points): 109 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: 110 | onemask[mask] = 1 111 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: 112 | onemask[mask] = 0 113 | onemask = onemask >= 1 114 | return np.array([onemask]) 115 | 116 | def _format_results(self, result, filter=0): 117 | annotations = [] 118 | n = len(result.masks.data) 119 | for i in range(n): 120 | annotation = {} 121 | mask = result.masks.data[i] == 1.0 122 | 123 | if np.sum(mask) < filter: 124 | continue 125 | annotation['id'] = i 126 | annotation['segmentation'] = mask 127 | annotation['bbox'] = result.boxes.data[i] 128 | annotation['score'] = result.boxes.conf[i] 129 | annotation['area'] = annotation['segmentation'].sum() 130 | annotations.append(annotation) 131 | return annotations 132 | -------------------------------------------------------------------------------- /SubtitleTranslation/subtitle_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | import math 3 | import cv2 4 | import os 5 | from paddleocr import PaddleOCR 6 | import transformers 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | from moviepy.editor import VideoFileClip 9 | from moviepy.editor import TextClip 10 | from moviepy.editor import CompositeVideoClip 11 | running = True 12 | def cvsecs(time): 13 | print(time) 14 | time_parts = time.split(':') 15 | time_parts[-1] = time_parts[-1].replace(',','.') 16 | seconds = time_parts[-1].replace(',','.') 17 | seconds = seconds.split('.')[0] 18 | milliseconds = "0" 19 | if len(time_parts[-1].split('.')) > 1: 20 | milliseconds = time_parts[-1].split('.')[1] 21 | milliseconds = milliseconds.replace(',', '.') 22 | hours = time_parts[0] 23 | minutes = time_parts[1] 24 | 25 | return float(hours) * 3600 + float(minutes) * 60 + float(seconds) + float(milliseconds) / 1000 26 | 27 | def gen_video(video_path, output_file,fps): 28 | # 读取原始视频 29 | video = VideoFileClip(video_path) 30 | 31 | # 读取字幕文件 32 | subtitles = [] 33 | with open("SubtitleTranslation/tmp.srt", "r") as srt_file: 34 | for line in srt_file: 35 | if not running: 36 | break 37 | line = line.strip() 38 | if line.isdigit(): 39 | continue 40 | elif "-->" in line: 41 | start, end = line.split(" --> ") 42 | elif line: 43 | subtitles.append((start, end, line)) 44 | print(subtitles) 45 | 46 | # 创建带有字幕的视频剪辑 47 | subtitled_video = video 48 | for subtitle in subtitles: 49 | if not running: 50 | break 51 | start_time = subtitle[0] 52 | end_time = subtitle[1] 53 | text = subtitle[2] 54 | 55 | # 将时间字符串转换为浮点数 56 | start_time = cvsecs(start_time) 57 | end_time = cvsecs(end_time) 58 | 59 | subtitle_clip = TextClip(text, fontsize=30, color='white', font='Arial', bg_color='black').set_position( 60 | ('center', 'bottom')).set_start(start_time).set_end(end_time) 61 | 62 | # 叠加字幕剪辑到视频上 63 | subtitled_video = CompositeVideoClip([subtitled_video.set_audio(None), subtitle_clip]) 64 | 65 | 66 | # 设置视频的持续时间 67 | subtitled_video = subtitled_video.set_duration(video.duration) 68 | print(subtitled_video) 69 | # 添加原始视频的音频 70 | subtitled_video = subtitled_video.set_audio(video.audio) 71 | # 生成输出视频文件 72 | subtitled_video.write_videofile(output_file, codec='libx264',audio_codec="aac", fps=fps) 73 | 74 | def translate(video_path, output_path, type): 75 | src_video = 
cv2.VideoCapture(video_path) 76 | ocr = PaddleOCR(use_angle_cls=False, lang='ch',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer') 77 | model_path = 'models/translate/zh-en' 78 | tokenizer = AutoTokenizer.from_pretrained(model_path) 79 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 80 | pipeline = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 81 | 82 | # 获取视频帧率 83 | fps = round(src_video.get(cv2.CAP_PROP_FPS)) 84 | # 计算视频总帧数 85 | total_frame = int(src_video.get(cv2.CAP_PROP_FRAME_COUNT)) 86 | 87 | # 设置字幕序号 88 | num = 0 89 | save_text = [] 90 | save_time = [] 91 | last_res = 'None' 92 | print(total_frame) 93 | 94 | if type == "输出字幕文件": 95 | video_name = os.path.basename(video_path) 96 | video_name = video_name.replace('mp4','srt') 97 | srt_path = os.path.join(output_path, video_name) 98 | else: 99 | srt_path = 'SubtitleTranslation/tmp.srt' 100 | 101 | with open(srt_path, 'w+') as f: 102 | for i in range(total_frame): 103 | if not running: 104 | break 105 | success, frame = src_video.read() 106 | # 在视频字幕提取任务中,为了提速改为了每秒抽一帧,因此所有时间都是整数 107 | if i % math.floor(fps) == 0: 108 | if success: 109 | result = ocr.ocr(frame[-120:-30, :], cls=True) 110 | if len(result[0]) > 0: 111 | res = result[0][0][1][0] 112 | res = pipeline(res)[0]['translation_text'] 113 | if ((res[:-1] not in last_res) and (last_res[:-1] not in res)): 114 | # 检测到新字幕,录入并开始计时 115 | last_res = res 116 | # 更新序号 117 | num += 1 118 | # 更新开始时间 119 | start_time = "%s,%03d" % (timedelta(seconds=i // fps), 0 * 1000) 120 | # 一段语句的结束时间,和语速也有一定关系,读者可以自行调整 121 | end_time = "%s,%03d" % (timedelta(seconds=i // fps), 1.5 * 1000) 122 | f.write(str(num) + '\n') 123 | f.write(start_time + ' --> ' + end_time + '\n') 124 | f.write(res + '\n') 125 | f.write('\n') 126 | # 保存文稿信息,用于后续标点恢复和问答任务 127 | save_text.append(res) 128 | # 保存字幕 + 时间信息,用于快速截取视频片段 129 | save_time.append(res + ' ' + str(i // fps)) 130 | 131 | if type == "输出视频": 132 | output_path = os.path.join(output_path, os.path.basename(video_path)) 133 | gen_video(video_path, output_path,fps) 134 | 135 | text = "已处理完成!" 
136 | return text 137 | def stop_translate(): 138 | global running 139 | running = False 140 | if __name__ == '__main__': 141 | translate('/Users/liuhongdi/Downloads/Trim.mp4', "/Users/liuhongdi/Downloads/", "输出视频") 142 | -------------------------------------------------------------------------------- /SubtitleTranslation/subtitle_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | import math 3 | import cv2 4 | import os 5 | from paddleocr import PaddleOCR 6 | import transformers 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | from moviepy.editor import VideoFileClip 9 | from moviepy.editor import TextClip 10 | from moviepy.editor import CompositeVideoClip 11 | import sys 12 | sys.path.append(os.getcwd()) 13 | import my_utils 14 | 15 | running = True 16 | def cvsecs(time): 17 | print(time) 18 | time_parts = time.split(':') 19 | time_parts[-1] = time_parts[-1].replace(',','.') 20 | seconds = time_parts[-1].replace(',','.') 21 | seconds = seconds.split('.')[0] 22 | milliseconds = "0" 23 | if len(time_parts[-1].split('.')) > 1: 24 | milliseconds = time_parts[-1].split('.')[1] 25 | milliseconds = milliseconds.replace(',', '.') 26 | hours = time_parts[0] 27 | minutes = time_parts[1] 28 | 29 | return float(hours) * 3600 + float(minutes) * 60 + float(seconds) + float(milliseconds) / 1000 30 | 31 | def gen_video(video_path, output_file,fps): 32 | # 读取原始视频 33 | video = VideoFileClip(video_path) 34 | 35 | # 读取字幕文件 36 | subtitles = [] 37 | with open("SubtitleTranslation/tmp.srt", "r") as srt_file: 38 | for line in srt_file: 39 | if not running: 40 | break 41 | line = line.strip() 42 | if line.isdigit(): 43 | continue 44 | elif "-->" in line: 45 | start, end = line.split(" --> ") 46 | elif line: 47 | subtitles.append((start, end, line)) 48 | print(subtitles) 49 | 50 | # 创建带有字幕的视频剪辑 51 | subtitled_video = video 52 | for subtitle in subtitles: 53 | if not running: 54 | break 55 | start_time = subtitle[0] 56 | end_time = subtitle[1] 57 | text = subtitle[2] 58 | 59 | # 将时间字符串转换为浮点数 60 | start_time = cvsecs(start_time) 61 | end_time = cvsecs(end_time) 62 | 63 | subtitle_clip = TextClip(text, fontsize=30, color='white', font='Arial', bg_color='black').set_position( 64 | ('center', 'bottom')).set_start(start_time).set_end(end_time) 65 | 66 | # 叠加字幕剪辑到视频上 67 | subtitled_video = CompositeVideoClip([subtitled_video.set_audio(None), subtitle_clip]) 68 | 69 | # 设置视频的持续时间 70 | subtitled_video = subtitled_video.set_duration(video.duration) 71 | 72 | # 添加原始视频的音频 73 | subtitled_video = subtitled_video.set_audio(video.audio) 74 | 75 | # 生成输出视频文件 76 | subtitled_video.write_videofile(output_file, fps=fps, codec='libx264',audio_codec="aac") 77 | 78 | def translate(video_path, output_path, type): 79 | src_video = cv2.VideoCapture(video_path) 80 | ocr = PaddleOCR(use_angle_cls=False, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer') 81 | model_path = 'models/translate/en-zh' 82 | tokenizer = AutoTokenizer.from_pretrained(model_path) 83 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 84 | pipeline = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 85 | 86 | # 获取视频帧率 87 | fps = round(src_video.get(cv2.CAP_PROP_FPS)) 88 | # 计算视频总帧数 89 | total_frame = int(src_video.get(cv2.CAP_PROP_FRAME_COUNT)) 90 | 91 | # 设置字幕序号 92 | num = 0 93 | save_text = [] 94 
| save_time = [] 95 | last_res = 'None' 96 | print(total_frame) 97 | 98 | if type == "输出字幕文件": 99 | video_name = os.path.basename(video_path) 100 | video_name = video_name.replace('mp4','srt') 101 | srt_path = os.path.join(output_path, video_name) 102 | else: 103 | srt_path = 'SubtitleTranslation/tmp.srt' 104 | 105 | with open(srt_path, 'w+') as f: 106 | for i in range(total_frame): 107 | if not running: 108 | break 109 | success, frame = src_video.read() 110 | 111 | if i % math.floor(fps) == 0: 112 | if success: 113 | result = ocr.ocr(frame[-120:-30, :], cls=True) 114 | if len(result[0]) > 0: 115 | res = result[0][0][1][0] 116 | res = pipeline(res)[0]['translation_text'] 117 | res = my_utils.do_sentence(res) 118 | if ((res[:-1] not in last_res) and (last_res[:-1] not in res)): 119 | # 检测到新字幕,录入并开始计时 120 | last_res = res 121 | # 更新序号 122 | num += 1 123 | # 更新开始时间 124 | start_time = "%s,%03d" % (timedelta(seconds=i // fps), 0 * 1000) 125 | # 一段语句的结束时间,和语速也有一定关系,读者可以自行调整 126 | end_time = "%s,%03d" % (timedelta(seconds=i // fps), 1.5 * 1000) 127 | f.write(str(num) + '\n') 128 | f.write(start_time + ' --> ' + end_time + '\n') 129 | f.write(res + '\n') 130 | f.write('\n') 131 | # 保存文稿信息,用于后续标点恢复和问答任务 132 | save_text.append(res) 133 | # 保存字幕 + 时间信息,用于快速截取视频片段 134 | save_time.append(res + ' ' + str(i // fps)) 135 | 136 | if type == "输出视频": 137 | output_path = os.path.join(output_path, os.path.basename(video_path)) 138 | gen_video(video_path, output_path,fps) 139 | 140 | text = "已处理完成!" 141 | return text 142 | def stop_translate(): 143 | global running 144 | running = False 145 | if __name__ == '__main__': 146 | translate('/Users/liuhongdi/Downloads/Trim.mp4', "/Users/liuhongdi/Downloads/", "输出视频") 147 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/predict.py: -------------------------------------------------------------------------------- 1 | # Prediction interface for Cog ⚙️ 2 | # https://github.com/replicate/cog/blob/main/docs/python.md 3 | # Thanks for chenxwh. 
4 | 5 | import argparse 6 | import cv2 7 | import shutil 8 | import ast 9 | from cog import BasePredictor, Input, Path 10 | from ultralytics import YOLO 11 | from utils.tools import * 12 | 13 | 14 | class Predictor(BasePredictor): 15 | def setup(self): 16 | """Load the model into memory to make running multiple predictions efficient""" 17 | self.models = {k: YOLO(f"{k}.pt") for k in ["FastSAM-s", "FastSAM-x"]} 18 | 19 | def predict( 20 | self, 21 | input_image: Path = Input(description="Input image"), 22 | model_name: str = Input( 23 | description="choose a model", 24 | choices=["FastSAM-x", "FastSAM-s"], 25 | default="FastSAM-x", 26 | ), 27 | iou: float = Input( 28 | description="iou threshold for filtering the annotations", default=0.7 29 | ), 30 | text_prompt: str = Input( 31 | description='use text prompt eg: "a black dog"', default=None 32 | ), 33 | conf: float = Input(description="object confidence threshold", default=0.25), 34 | retina: bool = Input( 35 | description="draw high-resolution segmentation masks", default=True 36 | ), 37 | box_prompt: str = Input(default="[0,0,0,0]", description="[x,y,w,h]"), 38 | point_prompt: str = Input(default="[[0,0]]", description="[[x1,y1],[x2,y2]]"), 39 | point_label: str = Input(default="[0]", description="[1,0] 0:background, 1:foreground"), 40 | withContours: bool = Input( 41 | description="draw the edges of the masks", default=False 42 | ), 43 | better_quality: bool = Input( 44 | description="better quality using morphologyEx", default=False 45 | ), 46 | ) -> Path: 47 | """Run a single prediction on the model""" 48 | 49 | # default params 50 | 51 | out_path = "output" 52 | if os.path.exists(out_path): 53 | shutil.rmtree(out_path) 54 | os.makedirs(out_path, exist_ok=True) 55 | 56 | device = torch.device( 57 | "cuda" 58 | if torch.cuda.is_available() 59 | else "mps" 60 | if torch.backends.mps.is_available() 61 | else "cpu" 62 | ) 63 | 64 | args = argparse.Namespace( 65 | better_quality=better_quality, 66 | box_prompt=box_prompt, 67 | conf=conf, 68 | device=device, 69 | img_path=str(input_image), 70 | imgsz=1024, 71 | iou=iou, 72 | model_path="FastSAM-x.pt", 73 | output=out_path, 74 | point_label=point_label, 75 | point_prompt=point_prompt, 76 | randomcolor=True, 77 | retina=retina, 78 | text_prompt=text_prompt, 79 | withContours=withContours, 80 | ) 81 | args.point_prompt = ast.literal_eval(args.point_prompt) 82 | args.box_prompt = ast.literal_eval(args.box_prompt) 83 | args.point_label = ast.literal_eval(args.point_label) 84 | 85 | model = self.models[model_name] 86 | 87 | results = model( 88 | str(input_image), 89 | imgsz=args.imgsz, 90 | device=args.device, 91 | retina_masks=args.retina, 92 | iou=args.iou, 93 | conf=args.conf, 94 | max_det=100, 95 | ) 96 | 97 | if args.box_prompt[2] != 0 and args.box_prompt[3] != 0: 98 | annotations = prompt(results, args, box=True) 99 | annotations = np.array([annotations]) 100 | fast_process( 101 | annotations=annotations, 102 | args=args, 103 | mask_random_color=args.randomcolor, 104 | bbox=convert_box_xywh_to_xyxy(args.box_prompt), 105 | ) 106 | 107 | elif args.text_prompt != None: 108 | results = format_results(results[0], 0) 109 | annotations = prompt(results, args, text=True) 110 | annotations = np.array([annotations]) 111 | fast_process( 112 | annotations=annotations, args=args, mask_random_color=args.randomcolor 113 | ) 114 | 115 | elif args.point_prompt[0] != [0, 0]: 116 | results = format_results(results[0], 0) 117 | annotations = prompt(results, args, point=True) 118 | # list to numpy 119 | 
annotations = np.array([annotations]) 120 | fast_process( 121 | annotations=annotations, 122 | args=args, 123 | mask_random_color=args.randomcolor, 124 | points=args.point_prompt, 125 | ) 126 | 127 | else: 128 | fast_process( 129 | annotations=results[0].masks.data, 130 | args=args, 131 | mask_random_color=args.randomcolor, 132 | ) 133 | 134 | out = "/tmp.out.png" 135 | shutil.copy(os.path.join(out_path, os.listdir(out_path)[0]), out) 136 | 137 | return Path(out) 138 | 139 | 140 | def prompt(results, args, box=None, point=None, text=None): 141 | ori_img = cv2.imread(args.img_path) 142 | ori_h = ori_img.shape[0] 143 | ori_w = ori_img.shape[1] 144 | if box: 145 | mask, idx = box_prompt( 146 | results[0].masks.data, 147 | convert_box_xywh_to_xyxy(args.box_prompt), 148 | ori_h, 149 | ori_w, 150 | ) 151 | elif point: 152 | mask, idx = point_prompt( 153 | results, args.point_prompt, args.point_label, ori_h, ori_w 154 | ) 155 | elif text: 156 | mask, idx = text_prompt(results, args.text_prompt, args.img_path, args.device) 157 | else: 158 | return None 159 | return mask 160 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 |

# Ace-Translate

5 |

6 | 一款本地离线的翻译程序 7 |

8 | 9 | 10 | 11 |

12 | 13 | 14 | 15 |

16 | 17 | 18 | 19 | 20 | ## Features 21 | 22 | 支持多种翻译场景。 23 | 24 | - 汉译英和英译汉 25 | - 文本翻译 26 | - 划词翻译 27 | - 截图翻译 28 | - 视频翻译 29 | - 文件翻译,包括TXT文件、Excel、PPT、PDF、图片和Word 30 | - 文档图片翻译 31 | 32 | 33 | 34 | ## INSTALL 35 | 36 | 推荐使用`python3.8`+`paddlepaddle2.4.0`+`torch2.0.1` 37 | 38 | ### 1.拉代码 39 | 40 | ``` 41 | git clone https://github.com/tianclll/Ace-Translate.git 42 | ``` 43 | 44 | ``` 45 | cd Ace-Translate 46 | ``` 47 | 48 | 49 | 50 | ### 2.安装 51 | 52 | #### 2.1安装PaddlePaddle 53 | 54 | - GPU 55 | 56 | ``` 57 | python3 -m pip install paddlepaddle-gpu==2.4.0 -i https://mirror.baidu.com/pypi/simple 58 | ``` 59 | 60 | - CPU 61 | 62 | ``` 63 | python3 -m pip install paddlepaddle==2.4.0 -i https://mirror.baidu.com/pypi/simple 64 | ``` 65 | 66 | #### 2.2安装依赖 67 | 68 | ``` 69 | pip install -r requirements.txt 70 | ``` 71 | - Windows需要额外下载 72 | ``` 73 | pip install transformers[sentencepiece] 74 | pip install clip 75 | ``` 76 | 77 | #### 2.3下载模型文件 78 | 79 | 点击[此处](https://www.123pan.com/s/knrdjv-JN5N3.html)下载 80 | 81 | 解压后,放入项目文件夹(Ace-Translate)中。 82 | 83 | #### 2.4安装Pyaudio 84 | 需要运行语音翻译才安装 85 | 86 | - Linux 87 | 88 | ``` 89 | sudo apt-get install libasound2-dev 90 | wget https://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz 91 | tar -xvf pa_stable_v190700_20210406.tgz 92 | cd portaudio 93 | ./configure 94 | make 95 | sudo make install 96 | make clean 97 | sudo apt-get install python3-pyaudio 98 | pip install pyaudio 99 | ``` 100 | - Mac 101 | 102 | ``` 103 | sudo brew install libasound2-dev 104 | wget https://files.portaudio.com/archives/pa_stable_v190700_20210406.tgz 105 | tar -xvf pa_stable_v190700_20210406.tgz 106 | cd portaudio 107 | ./configure 108 | make 109 | sudo make install 110 | make clean 111 | pip install pyaudio 112 | ``` 113 | - Windows 114 | ``` 115 | pip install pipwin 116 | pipwin install pyaudio 117 | ``` 118 | #### 2.5安装ImageMagick 119 | 120 | 需要运行视频翻译才安装 121 | [官网下载地址](https://www.imagemagick.org/script/download.php) 122 | ### 3.设置 123 | 124 | 修改`config.conf`文件: 125 | 126 | - 设置快捷键 127 | - 设置运行设备 `gpu` or `cpu` 128 | 129 | ### 4.运行 130 | 131 | 注意:第一次语音翻译模块都需要连网 132 | 133 | ``` 134 | python main.py 135 | ``` 136 | 137 | 138 | 139 | ## 效果展示 140 | 141 | 有"划词翻译","截图翻译","PDF翻译","文档图片翻译"四个功能,项目运行后会挂载到状态栏上,点击"x"时不会退出只是隐藏,点击状态栏上的"打开",就会弹出,点击状态栏上的"退出",才是真正的退出程序。(Ubuntu18.04及以后默认无状态栏) 142 | 143 | ### 文本翻译 144 | 145 |
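文本翻译直接调用本地的 seq2seq 翻译模型(即 2.3 节下载到 `models/translate/` 的模型)。下面是一个最小调用示意(假设模型文件已就位,仅演示核心流程,并非完整界面代码):

```python
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# 中译英用 models/translate/zh-en,英译汉用 models/translate/en-zh
model_path = 'models/translate/zh-en'
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
translator = transformers.pipeline("translation", model=model, tokenizer=tokenizer)

print(translator("这是一段测试文本")[0]['translation_text'])
```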

146 | 文本翻译 147 |

148 | 149 | ### 划词翻译 150 | 151 | - 选择"汉译英"或者"英译汉",然后点击开始 152 | - 然后鼠标选中想要翻译的内容,点击复制 153 | - 按下设置的快捷键,就能翻译了 154 | 155 |
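上面三步的核心实现是"全局快捷键 + 剪贴板":按下快捷键后,用 pyperclip 读取刚复制的文本,再交给前面示例中的翻译 pipeline。简化示意如下(快捷键从 config.conf 读取,非完整程序):

```python
import configparser
import pyperclip
from pynput import keyboard

config = configparser.ConfigParser()
config.read("config.conf")

def on_hotkey():
    text = pyperclip.paste()  # 读取刚复制到剪贴板的内容
    # translator 即"文本翻译"一节示例中的翻译 pipeline
    print(translator(text)[0]['translation_text'])

listener = keyboard.GlobalHotKeys({config.get("shortcuts", "translation"): on_hotkey})
listener.start()
listener.join()
```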

156 | 划词翻译 157 |

158 | 159 | ### 截图翻译 160 | 161 |
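截图翻译按操作系统调用各自的系统截图工具,把截图保存成临时文件后再做 OCR 和翻译。平台分发部分的示意如下(与 ScreenshotTranslation 下脚本的做法一致,省略了 OCR 与界面部分):

```python
import platform
import subprocess

screenshot_path = "temp_screenshot.png"
system = platform.system()
if system == "Windows":
    from PIL import ImageGrab
    subprocess.run(["snippingtool", "/clip"])   # 截图结果进入剪贴板
    img = ImageGrab.grabclipboard()
    if img is not None:
        img.save(screenshot_path)
elif system == "Linux":
    subprocess.run(["gnome-screenshot", "-a", "-f", screenshot_path])
elif system == "Darwin":
    subprocess.run(["screencapture", "-i", screenshot_path])
```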

162 | 截图翻译 163 |

164 | 165 | ### 语音翻译 166 | 167 | 支持音频文件和语音录入 168 | 169 |

170 | 语音翻译 171 |

172 | 173 | ### 视频翻译 174 | 175 | 支持输出srt字幕文件和视频 176 | 177 |
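视频翻译为了提速,约每秒只抽取一帧,对画面底部的字幕条做 OCR,翻译后按秒生成 SRT 时间轴。时间轴写法的简化示意如下(仅为示意,并非逐行照搬源码):

```python
import math
from datetime import timedelta

fps = 25            # 实际由 cv2.CAP_PROP_FPS 读出
frame_idx = 75      # 每 math.floor(fps) 帧处理一次,相当于每秒一帧
if frame_idx % math.floor(fps) == 0:
    second = frame_idx // fps
    start = "%s,%03d" % (timedelta(seconds=second), 0)    # 0:00:03,000
    end = "%s,%03d" % (timedelta(seconds=second + 2), 0)  # 简化:每条字幕固定显示 2 秒
    print(f"1\n{start} --> {end}\n这里是翻译后的字幕\n")
```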

178 | 视频翻译 179 | 视频翻译 180 |

181 | 182 | 183 | 184 | ### 文件翻译 185 | 186 | - TXT 187 | 188 |

189 | TXT文件翻译 190 | TXT文件翻译 191 |

192 | 193 | 194 | - PDF 195 | 196 |

197 | PDF文件翻译 198 | PDF文件翻译 199 |

200 | 201 | 202 | - Excel 203 | 204 |
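以 Excel 英译汉为例,脚本也可以单独从命令行调用:第一个参数是输入文件,第二个参数是输出目录(见 excel_translation_en2ch.py 末尾的 sys.argv 用法;路径仅为示意):

```
python FileTranslation/Excel/excel_translation_en2ch.py ./demo.xlsx ./output/
```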

205 | Excel文件翻译 206 | Excel文件翻译 207 |

208 | 209 | 210 | - Word 211 | 212 |

213 | Word文件翻译 214 | Word文件翻译 215 |

216 | 217 | 218 | ### 文档图片翻译 219 | 220 |
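文档图片翻译基于 PP-Structure:先做版面分析,把页面切成文本、表格、图片等区域;文本区域逐段送翻译模型,表格按 HTML 翻译,图片内的文字则擦除原文后重绘译文,最后按版面恢复为 docx。主干流程示意如下(省略了本地模型目录等参数,`doc.jpg` 为示意路径):

```python
import cv2
from paddleocr import PPStructure
from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx

table_engine = PPStructure(show_log=False, recovery=True)  # 实际代码中还指定了本地模型目录
img = cv2.imread("doc.jpg")
result = table_engine(img)
for region in result:
    if region["type"] not in ("figure", "table", "equation"):
        for line in region["res"]:
            # translator 即"文本翻译"一节示例中的翻译 pipeline
            line["text"] = translator(line["text"])[0]['translation_text']
h, w, _ = img.shape
res = sorted_layout_boxes(result, w)
convert_info_docx(img, res, "output/", "doc")
```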

221 | 文档图片翻译 222 | 文档图片翻译 223 |

224 | 225 | 226 | ## 第三方开源软件 227 | 228 | 本项目使用了以下基于 Apache 许可证,版本 2.0 的第三方开源软件: 229 | 230 | - 项目 PaddleOCR 由以下贡献者开发: 231 | - [WenmuZhou](https://github.com/WenmuZhou) 232 | - [LDOUBLEV](https://github.com/LDOUBLEV) 233 | - [MissPenguin](https://github.com/MissPenguin) 234 | ... 235 | 236 | [项目 PaddleOCR 链接](https://github.com/PaddlePaddle/PaddleOCR) 237 | 238 | - 项目 PaddleSpeech 由以下贡献者开发: 239 | - [zh794390558](https://github.com/zh794390558) 240 | - [yt605155624](https://github.com/yt605155624) 241 | - [Jackwaterveg](https://github.com/Jackwaterveg) 242 | ... 243 | 244 | [项目 PaddleSpeech 链接](https://github.com/PaddlePaddle/PaddleSpeech) 245 | 246 | - 项目 FastSAM 由以下贡献者开发: 247 | - [zxDeepDiver](https://github.com/zxDeepDiver) 248 | - [YinglongDu](https://github.com/YinglongDu) 249 | - [berry-ding](https://github.com/berry-ding) 250 | ... 251 | 252 | [项目 FastSAM 链接](https://github.com/CASIA-IVA-Lab/FastSAM) 253 | ## LICENSE 254 | 255 | 本项目的发布受[Apache 2.0 license](https://github.com/tianclll/Ace-Translate/blob/main/LICENSE)许可认证。 256 | 257 | Copyright (c) 2023 tianclll 258 | -------------------------------------------------------------------------------- /FileTranslation/Image/image_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from bs4 import BeautifulSoup 5 | from PIL import Image,ImageFont,ImageDraw 6 | from paddleocr import PPStructure,save_structure_res,PaddleOCR 7 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx 8 | import transformers 9 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 10 | from paddleocr.ppocr.utils.logging import get_logger 11 | import sys 12 | logger = get_logger() 13 | def image_translate(img,text_region,text): 14 | x_coords = [text_region[0][0], text_region[1][0], text_region[2][0], text_region[3][0]] 15 | y_coords = [text_region[0][1],text_region[1][1],text_region[2][1],text_region[3][1]] 16 | x1, x2 = int(min(x_coords)), int(max(x_coords)) 17 | y1, y2 = int(min(y_coords)), int(max(y_coords)) 18 | img[y1:y2, x1:x2] = (255, 255, 255) # 将区域填充为背景色 19 | bbox_width = max(x_coords) - min(x_coords) 20 | bbox_height = max(y_coords) - min(y_coords) 21 | font_size = 12 22 | # 将NumPy数组转换为PIL Image对象 23 | image_pil = Image.fromarray(img) 24 | draw = ImageDraw.Draw(image_pil) 25 | font = ImageFont.truetype("static/simfang.ttf", font_size) 26 | while draw.textsize(text, font=font)[0] < bbox_width and draw.textsize(text, font=font)[ 27 | 1] < bbox_height: 28 | font_size += 1 29 | font = ImageFont.truetype("static/simfang.ttf", font_size) 30 | text_x = text_region[0][0] 31 | text_y = text_region[0][1] 32 | font_color = (0, 0, 0) 33 | draw.text((text_x, text_y), text, font=font, fill=font_color) 34 | # image_pil.save("output_image.jpg") 35 | # 将修改后的PIL Image对象转回OpenCV格式 36 | img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) 37 | return img 38 | def table_translate(html_table,translate_model): 39 | # 解析 HTML 40 | soup = BeautifulSoup(html_table, 'html.parser') 41 | # 提取并翻译表格的第一行(表头) 42 | original_header_row = soup.find('tr') 43 | translated_header_row = '' 44 | for cell in original_header_row.find_all('td'): 45 | original_text = cell.text.strip() 46 | translated_text = translate_model(original_text)[0]['translation_text'] 47 | translated_cell = f'{translated_text}' 48 | translated_header_row += translated_cell 49 | translated_header_row += '' 50 | 51 | # 提取并翻译表格的内容行 52 | translated_table = 
soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 53 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 54 | for row in soup.find_all('tr')[1:]: 55 | translated_row = '' # 创建新的行 56 | for cell in row.find_all('td'): 57 | original_text = cell.text.strip() # 提取原始文本 58 | translated_text = translate_model(original_text)[0]['translation_text'] 59 | translated_cell = f'{translated_text}' # 创建新的单元格 60 | translated_row += translated_cell # 将单元格添加到行中 61 | translated_row += '' 62 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 63 | 64 | return translated_table 65 | def img2doc(img,image_name,savePath): 66 | # 加载翻译模型 67 | #translate_model = hub.Module(name='transformer_zh-en', beam_size=5) 68 | # translate_model = hub.Module(name='baidu_translate') 69 | model_path = 'models/translate/zh-en' 70 | tokenizer = AutoTokenizer.from_pretrained(model_path) 71 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 72 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 73 | #table_engine = PPStructure(show_log=False, recovery=True, image_orientation=True, layout_model_dir="CDLA dict",structure_version="PP-StructureV2") 74 | table_engine = PPStructure(show_log=False, recovery=True, image_orientation=True, layout_model_dir="models/ocr/layout/CDLA dict", 75 | structure_version="PP-StructureV2",cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer', 76 | table_model_dir='models/ocr/table/ch_ppstructure_mobile_v2.0_SLANet_infer') 77 | ocr = PaddleOCR(use_angle_cls=True, lang='ch',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer') 78 | save_folder = os.path.join(savePath,'output/') 79 | result = table_engine(img) 80 | for i in range(len(result)): 81 | if result[i]["type"] != "figure" and result[i]["type"] != "table" and result[i]["type"] != "equation": 82 | # 本地中译英模型 83 | for j in result[i]["res"]: 84 | content = j["text"] 85 | translate_text = translate_model(content)[0]['translation_text'] 86 | j["text"] = translate_text 87 | # 百度模型英译中 88 | # content = j["text"] 89 | # translate_text = translate_model.translate(content) 90 | # j["text"] = translate_text 91 | elif result[i]["type"] == "figure": 92 | roi_img = result[i]["img"] 93 | res = ocr.ocr(roi_img, cls=True) 94 | for j in res[0]: 95 | content = j[1][0] 96 | translate_text = translate_model(content)[0]['translation_text'] 97 | text_region = j[0] 98 | roi_img = image_translate(roi_img,text_region, translate_text) 99 | result[i]['img'] = roi_img 100 | elif result[i]["type"] == "table": 101 | result[i]["res"]["html"] = "{}".format(table_translate(result[i]["res"]["html"],translate_model)) 102 | save_structure_res(result, save_folder, image_name) 103 | h, w, _ = img.shape 104 | res = sorted_layout_boxes(result, w) 105 | convert_info_docx(img, res, save_folder, image_name) 106 | logger.info('已全部处理完成!') 107 | if __name__ == '__main__': 108 | savePath = sys.argv[2] 109 | imgPath = sys.argv[1] 110 | img_name = os.path.basename(imgPath).split('.')[0] 111 | img = cv2.imread(imgPath) 112 | img2doc(img, img_name, savePath) -------------------------------------------------------------------------------- /FileTranslation/Image/image_translation_en2ch.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import sys 5 | sys.path.append(os.getcwd()) 6 | import my_utils 7 | from bs4 import BeautifulSoup 8 | from PIL import Image,ImageFont,ImageDraw 9 | from paddleocr import PPStructure,save_structure_res,PaddleOCR 10 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx 11 | import transformers 12 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 13 | from paddleocr.ppocr.utils.logging import get_logger 14 | import sys 15 | logger = get_logger() 16 | def image_translate(img,text_region,text): 17 | x_coords = [text_region[0][0], text_region[1][0], text_region[2][0], text_region[3][0]] 18 | y_coords = [text_region[0][1],text_region[1][1],text_region[2][1],text_region[3][1]] 19 | x1, x2 = int(min(x_coords)), int(max(x_coords)) 20 | y1, y2 = int(min(y_coords)), int(max(y_coords)) 21 | img[y1:y2, x1:x2] = (255, 255, 255) # 将区域填充为背景色 22 | bbox_width = max(x_coords) - min(x_coords) 23 | bbox_height = max(y_coords) - min(y_coords) 24 | font_size = 12 25 | # 将NumPy数组转换为PIL Image对象 26 | image_pil = Image.fromarray(img) 27 | draw = ImageDraw.Draw(image_pil) 28 | font = ImageFont.truetype("static/simfang.ttf", font_size) 29 | while draw.textsize(text, font=font)[0] < bbox_width and draw.textsize(text, font=font)[ 30 | 1] < bbox_height: 31 | font_size += 1 32 | font = ImageFont.truetype("static/simfang.ttf", font_size) 33 | text_x = text_region[0][0] 34 | text_y = text_region[0][1] 35 | font_color = (0, 0, 0) 36 | draw.text((text_x, text_y), text, font=font, fill=font_color) 37 | # image_pil.save("output_image.jpg") 38 | # 将修改后的PIL Image对象转回OpenCV格式 39 | img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) 40 | return img 41 | def table_translate(html_table,translate_model): 42 | # 解析 HTML 43 | soup = BeautifulSoup(html_table, 'html.parser') 44 | # 提取并翻译表格的第一行(表头) 45 | original_header_row = soup.find('tr') 46 | translated_header_row = '' 47 | for cell in original_header_row.find_all('td'): 48 | original_text = cell.text.strip() 49 | translated_text = translate_model(original_text)[0]['translation_text'] 50 | translated_text = my_utils.do_sentence(translated_text) 51 | translated_cell = f'{translated_text}' 52 | translated_header_row += translated_cell 53 | translated_header_row += '' 54 | 55 | # 提取并翻译表格的内容行 56 | translated_table = soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 57 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 58 | for row in soup.find_all('tr')[1:]: 59 | translated_row = '' # 创建新的行 60 | for cell in row.find_all('td'): 61 | original_text = cell.text.strip() # 提取原始文本 62 | translated_text = translate_model(original_text)[0]['translation_text'] 63 | translated_text = my_utils.do_sentence(translated_text) 64 | translated_cell = f'{translated_text}' # 创建新的单元格 65 | translated_row += translated_cell # 将单元格添加到行中 66 | translated_row += '' 67 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 68 | 69 | return translated_table 70 | def img2doc(img,image_name,savePath): 71 | # 加载翻译模型 72 | #translate_model = hub.Module(name='transformer_zh-en', beam_size=5) 73 | # translate_model = hub.Module(name='baidu_translate') 74 | model_path = 'models/translate/en-zh' 75 | tokenizer = AutoTokenizer.from_pretrained(model_path) 76 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 77 | translate_model = 
transformers.pipeline("translation", model=model, tokenizer=tokenizer) 78 | #table_engine = PPStructure(show_log=False, recovery=True, image_orientation=True, layout_model_dir="CDLA dict",structure_version="PP-StructureV2") 79 | table_engine = PPStructure(show_log=False,recovery=True, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer', 80 | layout_model_dir='models/ocr/layout/picodet_lcnet_x1_0_fgd_layout_infer',table_model_dir='models/ocr/table/en_ppstructure_mobile_v2.0_SLANet_infer') 81 | ocr = PaddleOCR(use_angle_cls=True, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer') 82 | save_folder = os.path.join(savePath,'output/') 83 | result = table_engine(img) 84 | for i in range(len(result)): 85 | if result[i]["type"] != "figure" and result[i]["type"] != "table" and result[i]["type"] != "equation": 86 | # 本地中译英模型 87 | for j in result[i]["res"]: 88 | content = j["text"] 89 | translate_text = translate_model(content)[0]['translation_text'] 90 | translate_text = my_utils.do_sentence(translate_text) 91 | j["text"] = translate_text 92 | # 百度模型英译中 93 | # content = j["text"] 94 | # translate_text = translate_model.translate(content) 95 | # j["text"] = translate_text 96 | elif result[i]["type"] == "figure": 97 | roi_img = result[i]["img"] 98 | res = ocr.ocr(roi_img, cls=True) 99 | for j in res[0]: 100 | content = j[1][0] 101 | translate_text = translate_model(content)[0]['translation_text'] 102 | translate_text = my_utils.do_sentence(translate_text) 103 | text_region = j[0] 104 | roi_img = image_translate(roi_img,text_region, translate_text) 105 | result[i]['img'] = roi_img 106 | elif result[i]["type"] == "table": 107 | result[i]["res"]["html"] = "{}".format(table_translate(result[i]["res"]["html"],translate_model)) 108 | save_structure_res(result, save_folder, image_name) 109 | h, w, _ = img.shape 110 | res = sorted_layout_boxes(result, w) 111 | convert_info_docx(img, res, save_folder, image_name) 112 | logger.info('已全部处理完成!') 113 | if __name__ == '__main__': 114 | savePath = sys.argv[2] 115 | imgPath = sys.argv[1] 116 | img_name = os.path.basename(imgPath).split('.')[0] 117 | img = cv2.imread(imgPath) 118 | img2doc(img, img_name, savePath) -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/utils/tools_gradio.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | import torch 6 | 7 | 8 | def fast_process( 9 | annotations, 10 | image, 11 | device, 12 | scale, 13 | better_quality=False, 14 | mask_random_color=True, 15 | bbox=None, 16 | use_retina=True, 17 | withContours=True, 18 | ): 19 | if isinstance(annotations[0], dict): 20 | annotations = [annotation['segmentation'] for annotation in annotations] 21 | 22 | original_h = image.height 23 | original_w = image.width 24 | if better_quality: 25 | if isinstance(annotations[0], torch.Tensor): 26 | annotations = np.array(annotations.cpu()) 27 | for i, mask in enumerate(annotations): 28 | mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)) 29 | annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)) 30 | if device == 'cpu': 
31 | annotations = np.array(annotations) 32 | inner_mask = fast_show_mask( 33 | annotations, 34 | plt.gca(), 35 | random_color=mask_random_color, 36 | bbox=bbox, 37 | retinamask=use_retina, 38 | target_height=original_h, 39 | target_width=original_w, 40 | ) 41 | else: 42 | if isinstance(annotations[0], np.ndarray): 43 | annotations = torch.from_numpy(annotations) 44 | inner_mask = fast_show_mask_gpu( 45 | annotations, 46 | plt.gca(), 47 | random_color=mask_random_color, 48 | bbox=bbox, 49 | retinamask=use_retina, 50 | target_height=original_h, 51 | target_width=original_w, 52 | ) 53 | if isinstance(annotations, torch.Tensor): 54 | annotations = annotations.cpu().numpy() 55 | 56 | if withContours: 57 | contour_all = [] 58 | temp = np.zeros((original_h, original_w, 1)) 59 | for i, mask in enumerate(annotations): 60 | if type(mask) == dict: 61 | mask = mask['segmentation'] 62 | annotation = mask.astype(np.uint8) 63 | if use_retina == False: 64 | annotation = cv2.resize( 65 | annotation, 66 | (original_w, original_h), 67 | interpolation=cv2.INTER_NEAREST, 68 | ) 69 | contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 70 | for contour in contours: 71 | contour_all.append(contour) 72 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale) 73 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9]) 74 | contour_mask = temp / 255 * color.reshape(1, 1, -1) 75 | 76 | image = image.convert('RGBA') 77 | overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA') 78 | image.paste(overlay_inner, (0, 0), overlay_inner) 79 | 80 | if withContours: 81 | overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA') 82 | image.paste(overlay_contour, (0, 0), overlay_contour) 83 | 84 | return image 85 | 86 | 87 | # CPU post process 88 | def fast_show_mask( 89 | annotation, 90 | ax, 91 | random_color=False, 92 | bbox=None, 93 | retinamask=True, 94 | target_height=960, 95 | target_width=960, 96 | ): 97 | mask_sum = annotation.shape[0] 98 | height = annotation.shape[1] 99 | weight = annotation.shape[2] 100 | # 将annotation 按照面积 排序 101 | areas = np.sum(annotation, axis=(1, 2)) 102 | sorted_indices = np.argsort(areas)[::1] 103 | annotation = annotation[sorted_indices] 104 | 105 | index = (annotation != 0).argmax(axis=0) 106 | if random_color: 107 | color = np.random.random((mask_sum, 1, 1, 3)) 108 | else: 109 | color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) 110 | transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6 111 | visual = np.concatenate([color, transparency], axis=-1) 112 | mask_image = np.expand_dims(annotation, -1) * visual 113 | 114 | mask = np.zeros((height, weight, 4)) 115 | 116 | h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij') 117 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 118 | 119 | mask[h_indices, w_indices, :] = mask_image[indices] 120 | if bbox is not None: 121 | x1, y1, x2, y2 = bbox 122 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) 123 | 124 | if not retinamask: 125 | mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST) 126 | 127 | return mask 128 | 129 | 130 | def fast_show_mask_gpu( 131 | annotation, 132 | ax, 133 | random_color=False, 134 | bbox=None, 135 | retinamask=True, 136 | target_height=960, 137 | target_width=960, 138 | ): 139 | device = annotation.device 140 | mask_sum = annotation.shape[0] 141 | height = 
annotation.shape[1] 142 | weight = annotation.shape[2] 143 | areas = torch.sum(annotation, dim=(1, 2)) 144 | sorted_indices = torch.argsort(areas, descending=False) 145 | annotation = annotation[sorted_indices] 146 | # 找每个位置第一个非零值下标 147 | index = (annotation != 0).to(torch.long).argmax(dim=0) 148 | if random_color: 149 | color = torch.rand((mask_sum, 1, 1, 3)).to(device) 150 | else: 151 | color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor( 152 | [30 / 255, 144 / 255, 255 / 255] 153 | ).to(device) 154 | transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6 155 | visual = torch.cat([color, transparency], dim=-1) 156 | mask_image = torch.unsqueeze(annotation, -1) * visual 157 | # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 158 | mask = torch.zeros((height, weight, 4)).to(device) 159 | h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight)) 160 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 161 | # 使用向量化索引更新show的值 162 | mask[h_indices, w_indices, :] = mask_image[indices] 163 | mask_cpu = mask.cpu().numpy() 164 | if bbox is not None: 165 | x1, y1, x2, y2 = bbox 166 | ax.add_patch( 167 | plt.Rectangle( 168 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 169 | ) 170 | ) 171 | if not retinamask: 172 | mask_cpu = cv2.resize( 173 | mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST 174 | ) 175 | return mask_cpu 176 | -------------------------------------------------------------------------------- /DocImgTranslation/structure.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | import my_utils 6 | import numpy as np 7 | from bs4 import BeautifulSoup 8 | import transformers 9 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 10 | from PIL import Image,ImageFont,ImageDraw 11 | from paddleocr import PPStructure,save_structure_res,PaddleOCR 12 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx 13 | 14 | def image_translate(img,text_region,text): 15 | x_coords = [text_region[0][0], text_region[1][0], text_region[2][0], text_region[3][0]] 16 | y_coords = [text_region[0][1],text_region[1][1],text_region[2][1],text_region[3][1]] 17 | x1, x2 = int(min(x_coords)), int(max(x_coords)) 18 | y1, y2 = int(min(y_coords)), int(max(y_coords)) 19 | img[y1:y2, x1:x2] = (255, 255, 255) # 将区域填充为背景色 20 | bbox_width = max(x_coords) - min(x_coords) 21 | bbox_height = max(y_coords) - min(y_coords) 22 | font_size = 12 23 | # 将NumPy数组转换为PIL Image对象 24 | image_pil = Image.fromarray(img) 25 | draw = ImageDraw.Draw(image_pil) 26 | font = ImageFont.truetype("static/simfang.ttf", font_size) 27 | while draw.textsize(text, font=font)[0] < bbox_width and draw.textsize(text, font=font)[ 28 | 1] < bbox_height: 29 | font_size += 1 30 | font = ImageFont.truetype("static/simfang.ttf", font_size) 31 | text_x = text_region[0][0] 32 | text_y = text_region[0][1] 33 | font_color = (0, 0, 0) 34 | draw.text((text_x, text_y), text, font=font, fill=font_color) 35 | # image_pil.save("output_image.jpg") 36 | # 将修改后的PIL Image对象转回OpenCV格式 37 | img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) 38 | return img 39 | def table_translate(html_table,translate_model): 40 | # 解析 HTML 41 | soup = BeautifulSoup(html_table, 'html.parser') 42 | # 提取并翻译表格的第一行(表头) 43 | original_header_row = soup.find('tr') 44 | translated_header_row = '' 45 | for cell in 
original_header_row.find_all('td'): 46 | original_text = cell.text.strip() 47 | translated_text = translate_model(original_text)[0]['translation_text'] 48 | translated_text = my_utils.do_sentence(translated_text) 49 | translated_cell = f'{translated_text}' 50 | translated_header_row += translated_cell 51 | translated_header_row += '' 52 | 53 | # 提取并翻译表格的内容行 54 | translated_table = soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 55 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 56 | for row in soup.find_all('tr')[1:]: 57 | translated_row = '' # 创建新的行 58 | for cell in row.find_all('td'): 59 | original_text = cell.text.strip() # 提取原始文本 60 | translated_text = translate_model(original_text)[0]['translation_text'] 61 | translated_text = my_utils.do_sentence(translated_text) 62 | translated_cell = f'{translated_text}' # 创建新的单元格 63 | translated_row += translated_cell # 将单元格添加到行中 64 | translated_row += '' 65 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 66 | 67 | return translated_table 68 | def Structure(img,img_path,save_path,mode): 69 | # 加载翻译模型 70 | if mode == "en2ch": 71 | model_path = 'models/translate/en-zh' 72 | # translate_model = hub.Module(name='transformer_zh-en', beam_size=5) 73 | # translate_model = hub.Module(name='baidu_translate') 74 | # translate_model = hub.Module(name='tfbasemt_enzh', beam_size=5) 75 | # English image 76 | tokenizer = AutoTokenizer.from_pretrained(model_path) 77 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 78 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 79 | table_engine = PPStructure(show_log=False, recovery=True, image_orientation=True, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer', 80 | layout_model_dir='models/ocr/layout/picodet_lcnet_x1_0_fgd_layout_infer',table_model_dir='models/ocr/table/en_ppstructure_mobile_v2.0_SLANet_infer') 81 | ocr = PaddleOCR(use_angle_cls=True, lang='ch',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer') 82 | else: 83 | model_path = 'models/translate/zh-en' 84 | tokenizer = AutoTokenizer.from_pretrained(model_path) 85 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 86 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 87 | # Chinese image 88 | table_engine = PPStructure(show_log=False, recovery=True, image_orientation=True, layout_model_dir="models/ocr/layout/CDLA dict", 89 | structure_version="PP-StructureV2",cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer', 90 | table_model_dir='models/ocr/table/ch_ppstructure_mobile_v2.0_SLANet_infer') 91 | ocr = PaddleOCR(use_angle_cls=True, lang='ch',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer') 92 | 93 | save_folder = os.path.join(save_path,'./output') 94 | result = table_engine(img) 95 | for i in range(len(result)): 96 | if result[i]["type"] != "figure" and result[i]["type"] != "table" and result[i]["type"] != "equation": 97 | # 本地中译英模型 98 | for j in result[i]["res"]: 99 | # print(j) 100 | 
content = j["text"] 101 | translate_text = translate_model(content)[0]['translation_text'] 102 | translate_text = my_utils.do_sentence(translate_text) 103 | j["text"] = translate_text 104 | elif result[i]["type"] == "figure": 105 | roi_img = result[i]["img"] 106 | res = ocr.ocr(roi_img, cls=True) 107 | for j in res[0]: 108 | content = j[1][0] 109 | translate_text = translate_model(content)[0]['translation_text'] 110 | translate_text = my_utils.do_sentence(translate_text) 111 | text_region = j[0] 112 | roi_img = image_translate(roi_img, text_region, translate_text) 113 | result[i]['img'] = roi_img 114 | elif result[i]["type"] == "table": 115 | result[i]["res"]["html"] = "{}".format(table_translate(result[i]["res"]["html"], translate_model)) 116 | save_structure_res(result, save_folder, os.path.basename(img_path).split('.')[0]) 117 | 118 | # for line in result: 119 | # line.pop('img') 120 | # # print(line) 121 | h, w, _ = img.shape 122 | res = sorted_layout_boxes(result, w) 123 | convert_info_docx(img, res, save_folder, os.path.basename(img_path).split('.')[0]) -------------------------------------------------------------------------------- /FileTranslation/PDF/pdf_translation_ch2en.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | import fitz # fitz就是pip install PyMuPDF 6 | import transformers 7 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 8 | from bs4 import BeautifulSoup 9 | from PIL import Image,ImageDraw,ImageFont 10 | from paddleocr import PPStructure,save_structure_res,PaddleOCR 11 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx 12 | from paddleocr.ppocr.utils.logging import get_logger 13 | 14 | logger = get_logger() 15 | def pyMuPDF_fitz(pdfPath): 16 | pdf_name = os.path.basename(pdfPath).split('.')[0] 17 | pdfDoc = fitz.open(pdfPath) 18 | imgs = [] 19 | logger.info('检测PDF文档有共{}页'.format(pdfDoc.page_count)) 20 | logger.info('正在处理...') 21 | for pg in range(0,pdfDoc.page_count): 22 | page = pdfDoc[pg] 23 | rotate = int(0) 24 | # 每个尺寸的缩放系数为1.3,这将为我们生成分辨率提高2.6的图像。 25 | # 此处若是不做设置,默认图片大小为:792X612, dpi=96 26 | zoom_x = 2 # (1.33333333-->1056x816) (2-->1584x1224) 27 | zoom_y = 2 28 | mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate) 29 | pm = page.get_pixmap(matrix=mat, alpha=False) 30 | img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples) 31 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 32 | imgs.append(img) 33 | # if not os.path.exists(imagePath): # 判断存放图片的文件夹是否存在 34 | # os.makedirs(imagePath) # 若图片文件夹不存在就创建 35 | # img.imwrite('./1.pjg',0) # 将图片写入指定的文件夹内 36 | 37 | return imgs,pdf_name 38 | def image_translate(img,text_region,text): 39 | x_coords = [text_region[0][0], text_region[1][0], text_region[2][0], text_region[3][0]] 40 | y_coords = [text_region[0][1],text_region[1][1],text_region[2][1],text_region[3][1]] 41 | x1, x2 = int(min(x_coords)), int(max(x_coords)) 42 | y1, y2 = int(min(y_coords)), int(max(y_coords)) 43 | img[y1:y2, x1:x2] = (255, 255, 255) # 将区域填充为背景色 44 | bbox_width = max(x_coords) - min(x_coords) 45 | bbox_height = max(y_coords) - min(y_coords) 46 | font_size = 1 47 | # 将NumPy数组转换为PIL Image对象 48 | image_pil = Image.fromarray(img) 49 | draw = ImageDraw.Draw(image_pil) 50 | font = ImageFont.truetype("static/simfang.ttf", font_size) 51 | while draw.textsize(text, font=font)[0] < bbox_width and draw.textsize(text, font=font)[ 52 | 1] < bbox_height: 53 | font_size += 1 54 | font = 
ImageFont.truetype("static/simfang.ttf", font_size) 55 | text_x = text_region[0][0] 56 | text_y = text_region[0][1] 57 | font_color = (0, 0, 0) 58 | draw.text((text_x, text_y), text, font=font, fill=font_color) 59 | # image_pil.save("output_image.jpg") 60 | # 将修改后的PIL Image对象转回OpenCV格式 61 | img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) 62 | return img 63 | def table_translate(html_table,translate_model): 64 | # 解析 HTML 65 | soup = BeautifulSoup(html_table, 'html.parser') 66 | # 提取并翻译表格的第一行(表头) 67 | original_header_row = soup.find('tr') 68 | translated_header_row = '' 69 | for cell in original_header_row.find_all('td'): 70 | original_text = cell.text.strip() 71 | translated_text = translate_model(original_text)[0]['translation_text'] 72 | translated_cell = f'{translated_text}' 73 | translated_header_row += translated_cell 74 | translated_header_row += '' 75 | 76 | # 提取并翻译表格的内容行 77 | translated_table = soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 78 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 79 | for row in soup.find_all('tr')[1:]: 80 | translated_row = '' # 创建新的行 81 | for cell in row.find_all('td'): 82 | original_text = cell.text.strip() # 提取原始文本 83 | translated_text = translate_model(original_text)[0]['translation_text'] 84 | translated_cell = f'{translated_text}' # 创建新的单元格 85 | translated_row += translated_cell # 将单元格添加到行中 86 | translated_row += '' 87 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 88 | 89 | return translated_table 90 | def img2doc(imgs,pdf_name,imagePath): 91 | # 加载翻译模型 92 | #translate_model = hub.Module(name='transformer_zh-en', beam_size=5) 93 | # translate_model = hub.Module(name='baidu_translate') 94 | model_path = 'models/translate/zh-en' 95 | tokenizer = AutoTokenizer.from_pretrained(model_path) 96 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 97 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 98 | all_res = [] 99 | table_engine = PPStructure(show_log=False, recovery=True, image_orientation=True, layout_model_dir="models/ocr/layout/CDLA dict", 100 | structure_version="PP-StructureV2",cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer', 101 | table_model_dir='models/ocr/table/ch_ppstructure_mobile_v2.0_SLANet_infer') 102 | # table_engine = PPStructure(show_log=False,recovery=True, lang='en') 103 | ocr = PaddleOCR(use_angle_cls=True, lang='ch',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/ch/ch_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/ch/ch_PP-OCRv4_rec_infer') 104 | save_folder = os.path.join(imagePath,'output/') 105 | for index, img in enumerate(imgs): 106 | result = table_engine(img) 107 | for i in range(len(result)): 108 | if result[i]["type"] != "figure" and result[i]["type"] != "table" and result[i]["type"] != "equation": 109 | #本地中译英模型 110 | for j in result[i]["res"]: 111 | content = j["text"] 112 | translate_text = translate_model(content)[0]['translation_text'] 113 | j["text"] = translate_text 114 | #百度模型英译中 115 | # content = j["text"] 116 | # translate_text = translate_model.translate(content) 117 | # j["text"] = translate_text 118 | elif result[i]["type"] == "figure": 119 | roi_img = result[i]["img"] 120 | res = ocr.ocr(roi_img, cls=True) 121 | for j in res[0]: 122 | content = j[1][0] 123 | translate_text = 
translate_model(content)[0]['translation_text'] 124 | text_region = j[0] 125 | roi_img = image_translate(roi_img,text_region, translate_text) 126 | result[i]['img'] = roi_img 127 | elif result[i]["type"] == "table": 128 | result[i]["res"]["html"] = "{}".format(table_translate(result[i]["res"]["html"],translate_model)) 129 | 130 | save_structure_res(result, save_folder, pdf_name) 131 | h, w, _ = img.shape 132 | res = sorted_layout_boxes(result, w) 133 | all_res += res 134 | logger.info('第{}页已处理完成...'.format(index+1)) 135 | convert_info_docx(imgs, all_res, save_folder, pdf_name) 136 | logger.info('已全部处理完成!') 137 | if __name__ == "__main__": 138 | imagePath = sys.argv[2] 139 | # 1、PDF地址 140 | pdfPath = sys.argv[1] 141 | # 2、需要储存图片的目录 142 | imgs,pdf_name= pyMuPDF_fitz(pdfPath) 143 | img2doc(imgs,pdf_name,imagePath) -------------------------------------------------------------------------------- /FileTranslation/PDF/pdf_translation_en2ch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import sys 4 | sys.path.append(os.getcwd()) 5 | import my_utils 6 | import numpy as np 7 | import fitz # fitz就是pip install PyMuPDF 8 | import transformers 9 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 10 | from bs4 import BeautifulSoup 11 | from PIL import Image,ImageDraw,ImageFont 12 | from paddleocr import PPStructure,save_structure_res,PaddleOCR 13 | from paddleocr.ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx 14 | from paddleocr.ppocr.utils.logging import get_logger 15 | 16 | logger = get_logger() 17 | def pyMuPDF_fitz(pdfPath): 18 | pdf_name = os.path.basename(pdfPath).split('.')[0] 19 | pdfDoc = fitz.open(pdfPath) 20 | imgs = [] 21 | logger.info('检测PDF文档有共{}页'.format(pdfDoc.page_count)) 22 | logger.info('正在处理...') 23 | for pg in range(0,pdfDoc.page_count): 24 | page = pdfDoc[pg] 25 | rotate = int(0) 26 | # 每个尺寸的缩放系数为1.3,这将为我们生成分辨率提高2.6的图像。 27 | # 此处若是不做设置,默认图片大小为:792X612, dpi=96 28 | zoom_x = 2 # (1.33333333-->1056x816) (2-->1584x1224) 29 | zoom_y = 2 30 | mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate) 31 | pm = page.get_pixmap(matrix=mat, alpha=False) 32 | img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples) 33 | img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) 34 | imgs.append(img) 35 | # if not os.path.exists(imagePath): # 判断存放图片的文件夹是否存在 36 | # os.makedirs(imagePath) # 若图片文件夹不存在就创建 37 | # img.imwrite('./1.pjg',0) # 将图片写入指定的文件夹内 38 | 39 | return imgs,pdf_name 40 | def image_translate(img,text_region,text): 41 | x_coords = [text_region[0][0], text_region[1][0], text_region[2][0], text_region[3][0]] 42 | y_coords = [text_region[0][1],text_region[1][1],text_region[2][1],text_region[3][1]] 43 | x1, x2 = int(min(x_coords)), int(max(x_coords)) 44 | y1, y2 = int(min(y_coords)), int(max(y_coords)) 45 | img[y1:y2, x1:x2] = (255, 255, 255) # 将区域填充为背景色 46 | bbox_width = max(x_coords) - min(x_coords) 47 | bbox_height = max(y_coords) - min(y_coords) 48 | font_size = 1 49 | # 将NumPy数组转换为PIL Image对象 50 | image_pil = Image.fromarray(img) 51 | draw = ImageDraw.Draw(image_pil) 52 | font = ImageFont.truetype("static/simfang.ttf", font_size) 53 | while draw.textsize(text, font=font)[0] < bbox_width and draw.textsize(text, font=font)[ 54 | 1] < bbox_height: 55 | font_size += 1 56 | font = ImageFont.truetype("static/simfang.ttf", font_size) 57 | text_x = text_region[0][0] 58 | text_y = text_region[0][1] 59 | font_color = (0, 0, 0) 60 | draw.text((text_x, text_y), text, font=font, 
fill=font_color) 61 | # image_pil.save("output_image.jpg") 62 | # 将修改后的PIL Image对象转回OpenCV格式 63 | img = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) 64 | return img 65 | def table_translate(html_table,translate_model): 66 | # 解析 HTML 67 | soup = BeautifulSoup(html_table, 'html.parser') 68 | # 提取并翻译表格的第一行(表头) 69 | original_header_row = soup.find('tr') 70 | translated_header_row = '' 71 | for cell in original_header_row.find_all('td'): 72 | original_text = cell.text.strip() 73 | translated_text = translate_model(original_text)[0]['translation_text'] 74 | translated_text = my_utils.do_sentence(translated_text) 75 | translated_cell = f'{translated_text}' 76 | translated_header_row += translated_cell 77 | translated_header_row += '' 78 | 79 | # 提取并翻译表格的内容行 80 | translated_table = soup.new_tag('table') # 创建一个新的表格用于存储翻译后的文本 81 | translated_table.append(BeautifulSoup(translated_header_row, 'html.parser')) # 将翻译后的表头行插入到表格中 82 | for row in soup.find_all('tr')[1:]: 83 | translated_row = '' # 创建新的行 84 | for cell in row.find_all('td'): 85 | original_text = cell.text.strip() # 提取原始文本 86 | translated_text = translate_model(original_text)[0]['translation_text'] 87 | translated_text = my_utils.do_sentence(translated_text) 88 | translated_cell = f'{translated_text}' # 创建新的单元格 89 | translated_row += translated_cell # 将单元格添加到行中 90 | translated_row += '' 91 | translated_table.append(BeautifulSoup(translated_row, 'html.parser')) # 将行添加到表格中 92 | 93 | return translated_table 94 | def img2doc(imgs,pdf_name,imagePath): 95 | # 加载翻译模型 96 | #translate_model = hub.Module(name='transformer_zh-en', beam_size=5) 97 | # translate_model = hub.Module(name='baidu_translate') 98 | model_path = 'models/translate/en-zh' 99 | tokenizer = AutoTokenizer.from_pretrained(model_path) 100 | model = AutoModelForSeq2SeqLM.from_pretrained(model_path) 101 | translate_model = transformers.pipeline("translation", model=model, tokenizer=tokenizer) 102 | all_res = [] 103 | # table_engine = PPStructure(show_log=False, recovery=True, layout_model_dir="models/CDLA dict", save_pdf=True, structure_version="PP-StructureV2") 104 | table_engine = PPStructure(show_log=False,recovery=True, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer', 105 | layout_model_dir='models/ocr/layout/picodet_lcnet_x1_0_fgd_layout_infer',table_model_dir='models/ocr/table/en_ppstructure_mobile_v2.0_SLANet_infer') 106 | ocr = PaddleOCR(use_angle_cls=True, lang='en',cls_model_dir='models/ocr/cls/ch_ppocr_mobile_v2.0_cls_infer',det_model_dir='models/ocr/det/en/en_PP-OCRv4_det_infer',rec_model_dir='models/ocr/rec/en/en_PP-OCRv4_rec_infer') 107 | save_folder = os.path.join(imagePath,'output/') 108 | for index, img in enumerate(imgs): 109 | result = table_engine(img) 110 | for i in range(len(result)): 111 | if result[i]["type"] != "figure" and result[i]["type"] != "table" and result[i]["type"] != "equation": 112 | #本地中译英模型 113 | for j in result[i]["res"]: 114 | content = j["text"] 115 | translate_text = translate_model(content)[0]['translation_text'] 116 | translate_text = my_utils.do_sentence(translate_text) 117 | j["text"] = translate_text 118 | #百度模型英译中 119 | # content = j["text"] 120 | # translate_text = translate_model.translate(content) 121 | # j["text"] = translate_text 122 | elif result[i]["type"] == "figure": 123 | roi_img = result[i]["img"] 124 | res = ocr.ocr(roi_img, cls=True) 125 | for j in res[0]: 126 | content = j[1][0] 127 | 
translate_text = translate_model(content)[0]['translation_text'] 128 | translate_text = my_utils.do_sentence(translate_text) 129 | text_region = j[0] 130 | roi_img = image_translate(roi_img,text_region, translate_text) 131 | result[i]['img'] = roi_img 132 | elif result[i]["type"] == "table": 133 | result[i]["res"]["html"] = "{}".format(table_translate(result[i]["res"]["html"],translate_model)) 134 | 135 | save_structure_res(result, save_folder, pdf_name) 136 | h, w, _ = img.shape 137 | res = sorted_layout_boxes(result, w) 138 | all_res += res 139 | logger.info('第{}页已处理完成...'.format(index+1)) 140 | convert_info_docx(imgs, all_res, save_folder, pdf_name) 141 | logger.info('已全部处理完成!') 142 | if __name__ == "__main__": 143 | imagePath = sys.argv[2] 144 | # 1、PDF地址 145 | pdfPath = sys.argv[1] 146 | # 2、需要储存图片的目录 147 | imgs,pdf_name= pyMuPDF_fitz(pdfPath) 148 | img2doc(imgs,pdf_name,imagePath) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/app_gradio.py: -------------------------------------------------------------------------------- 1 | from ultralytics import YOLO 2 | import gradio as gr 3 | import torch 4 | from utils.tools_gradio import fast_process 5 | from utils.tools import format_results, box_prompt, point_prompt, text_prompt 6 | from PIL import ImageDraw 7 | import numpy as np 8 | 9 | # Load the pre-trained model 10 | model = YOLO('./weights/FastSAM.pt') 11 | 12 | device = torch.device( 13 | "cuda" 14 | if torch.cuda.is_available() 15 | else "mps" 16 | if torch.backends.mps.is_available() 17 | else "cpu" 18 | ) 19 | 20 | # Description 21 | title = "
🏃 Fast Segment Anything 🤗
" 22 | 23 | news = """ # 📖 News 24 | 🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)). 25 | 26 | 🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)). 27 | 28 | 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!) 29 | 30 | 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment. 31 | """ 32 | 33 | description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it. 34 | 35 | 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon. 36 | 37 | ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded. 38 | 39 | 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked. 40 | 41 | 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing) 42 | 43 | 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant. 44 | 45 | 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM) 46 | 47 | """ 48 | 49 | description_p = """ # 🎯 Instructions for points mode 50 | This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it. 51 | 52 | 1. Upload an image or choose an example. 53 | 54 | 2. Choose the point label ('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented). 55 | 56 | 3. Add points one by one on the image. 57 | 58 | 4. Click the 'Segment with points prompt' button to get the segmentation results. 59 | 60 | **5. If you get Error, click the 'Clear points' button and try again may help.** 61 | 62 | """ 63 | 64 | examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"], 65 | ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]] 66 | 67 | default_example = examples[0] 68 | 69 | css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }" 70 | 71 | 72 | def segment_everything( 73 | input, 74 | input_size=1024, 75 | iou_threshold=0.7, 76 | conf_threshold=0.25, 77 | better_quality=False, 78 | withContours=True, 79 | use_retina=True, 80 | text="", 81 | wider=False, 82 | mask_random_color=True, 83 | ): 84 | input_size = int(input_size) # 确保 imgsz 是整数 85 | # Thanks for the suggestion by hysts in HuggingFace. 
86 | w, h = input.size 87 | scale = input_size / max(w, h) 88 | new_w = int(w * scale) 89 | new_h = int(h * scale) 90 | input = input.resize((new_w, new_h)) 91 | 92 | results = model(input, 93 | device=device, 94 | retina_masks=True, 95 | iou=iou_threshold, 96 | conf=conf_threshold, 97 | imgsz=input_size,) 98 | 99 | if len(text) > 0: 100 | results = format_results(results[0], 0) 101 | annotations, _ = text_prompt(results, text, input, device=device, wider=wider) 102 | annotations = np.array([annotations]) 103 | else: 104 | annotations = results[0].masks.data 105 | 106 | fig = fast_process(annotations=annotations, 107 | image=input, 108 | device=device, 109 | scale=(1024 // input_size), 110 | better_quality=better_quality, 111 | mask_random_color=mask_random_color, 112 | bbox=None, 113 | use_retina=use_retina, 114 | withContours=withContours,) 115 | return fig 116 | 117 | 118 | def segment_with_points( 119 | input, 120 | input_size=1024, 121 | iou_threshold=0.7, 122 | conf_threshold=0.25, 123 | better_quality=False, 124 | withContours=True, 125 | use_retina=True, 126 | mask_random_color=True, 127 | ): 128 | global global_points 129 | global global_point_label 130 | 131 | input_size = int(input_size) # 确保 imgsz 是整数 132 | # Thanks for the suggestion by hysts in HuggingFace. 133 | w, h = input.size 134 | scale = input_size / max(w, h) 135 | new_w = int(w * scale) 136 | new_h = int(h * scale) 137 | input = input.resize((new_w, new_h)) 138 | 139 | scaled_points = [[int(x * scale) for x in point] for point in global_points] 140 | 141 | results = model(input, 142 | device=device, 143 | retina_masks=True, 144 | iou=iou_threshold, 145 | conf=conf_threshold, 146 | imgsz=input_size,) 147 | 148 | results = format_results(results[0], 0) 149 | annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w) 150 | annotations = np.array([annotations]) 151 | 152 | fig = fast_process(annotations=annotations, 153 | image=input, 154 | device=device, 155 | scale=(1024 // input_size), 156 | better_quality=better_quality, 157 | mask_random_color=mask_random_color, 158 | bbox=None, 159 | use_retina=use_retina, 160 | withContours=withContours,) 161 | 162 | global_points = [] 163 | global_point_label = [] 164 | return fig, None 165 | 166 | 167 | def get_points_with_draw(image, label, evt: gr.SelectData): 168 | global global_points 169 | global global_point_label 170 | 171 | x, y = evt.index[0], evt.index[1] 172 | point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255) 173 | global_points.append([x, y]) 174 | global_point_label.append(1 if label == 'Add Mask' else 0) 175 | 176 | print(x, y, label == 'Add Mask') 177 | 178 | # 创建一个可以在图像上绘图的对象 179 | draw = ImageDraw.Draw(image) 180 | draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color) 181 | return image 182 | 183 | 184 | cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil') 185 | cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil') 186 | cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil') 187 | 188 | segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil') 189 | segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil') 190 | segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil') 191 | 192 | global_points = [] 193 | global_point_label = [] 194 | 195 | input_size_slider = 
gr.components.Slider(minimum=512, 196 | maximum=1024, 197 | value=1024, 198 | step=64, 199 | label='Input_size', 200 | info='Our model was trained on a size of 1024') 201 | 202 | with gr.Blocks(css=css, title='Fast Segment Anything') as demo: 203 | with gr.Row(): 204 | with gr.Column(scale=1): 205 | # Title 206 | gr.Markdown(title) 207 | 208 | with gr.Column(scale=1): 209 | # News 210 | gr.Markdown(news) 211 | 212 | with gr.Tab("Everything mode"): 213 | # Images 214 | with gr.Row(variant="panel"): 215 | with gr.Column(scale=1): 216 | cond_img_e.render() 217 | 218 | with gr.Column(scale=1): 219 | segm_img_e.render() 220 | 221 | # Submit & Clear 222 | with gr.Row(): 223 | with gr.Column(): 224 | input_size_slider.render() 225 | 226 | with gr.Row(): 227 | contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks') 228 | 229 | with gr.Column(): 230 | segment_btn_e = gr.Button("Segment Everything", variant='primary') 231 | clear_btn_e = gr.Button("Clear", variant="secondary") 232 | 233 | gr.Markdown("Try some of the examples below ⬇️") 234 | gr.Examples(examples=examples, 235 | inputs=[cond_img_e], 236 | outputs=segm_img_e, 237 | fn=segment_everything, 238 | cache_examples=True, 239 | examples_per_page=4) 240 | 241 | with gr.Column(): 242 | with gr.Accordion("Advanced options", open=False): 243 | iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations') 244 | conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold') 245 | with gr.Row(): 246 | mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx') 247 | with gr.Column(): 248 | retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks') 249 | 250 | # Description 251 | gr.Markdown(description_e) 252 | 253 | segment_btn_e.click(segment_everything, 254 | inputs=[ 255 | cond_img_e, 256 | input_size_slider, 257 | iou_threshold, 258 | conf_threshold, 259 | mor_check, 260 | contour_check, 261 | retina_check, 262 | ], 263 | outputs=segm_img_e) 264 | 265 | with gr.Tab("Points mode"): 266 | # Images 267 | with gr.Row(variant="panel"): 268 | with gr.Column(scale=1): 269 | cond_img_p.render() 270 | 271 | with gr.Column(scale=1): 272 | segm_img_p.render() 273 | 274 | # Submit & Clear 275 | with gr.Row(): 276 | with gr.Column(): 277 | with gr.Row(): 278 | add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)") 279 | 280 | with gr.Column(): 281 | segment_btn_p = gr.Button("Segment with points prompt", variant='primary') 282 | clear_btn_p = gr.Button("Clear points", variant='secondary') 283 | 284 | gr.Markdown("Try some of the examples below ⬇️") 285 | gr.Examples(examples=examples, 286 | inputs=[cond_img_p], 287 | # outputs=segm_img_p, 288 | # fn=segment_with_points, 289 | # cache_examples=True, 290 | examples_per_page=4) 291 | 292 | with gr.Column(): 293 | # Description 294 | gr.Markdown(description_p) 295 | 296 | cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p) 297 | 298 | segment_btn_p.click(segment_with_points, 299 | inputs=[cond_img_p], 300 | outputs=[segm_img_p, cond_img_p]) 301 | 302 | with gr.Tab("Text mode"): 303 | # Images 304 | with gr.Row(variant="panel"): 305 | with gr.Column(scale=1): 306 | cond_img_t.render() 307 | 308 | with gr.Column(scale=1): 309 | segm_img_t.render() 310 | 311 | # Submit & Clear 312 | with 
gr.Row(): 313 | with gr.Column(): 314 | input_size_slider_t = gr.components.Slider(minimum=512, 315 | maximum=1024, 316 | value=1024, 317 | step=64, 318 | label='Input_size', 319 | info='Our model was trained on a size of 1024') 320 | with gr.Row(): 321 | with gr.Column(): 322 | contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks') 323 | text_box = gr.Textbox(label="text prompt", value="a black dog") 324 | 325 | with gr.Column(): 326 | segment_btn_t = gr.Button("Segment with text", variant='primary') 327 | clear_btn_t = gr.Button("Clear", variant="secondary") 328 | 329 | gr.Markdown("Try some of the examples below ⬇️") 330 | gr.Examples(examples=[["examples/dogs.jpg"]] + examples, 331 | inputs=[cond_img_e], 332 | # outputs=segm_img_e, 333 | # fn=segment_everything, 334 | # cache_examples=True, 335 | examples_per_page=4) 336 | 337 | with gr.Column(): 338 | with gr.Accordion("Advanced options", open=False): 339 | iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations') 340 | conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold') 341 | with gr.Row(): 342 | mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx') 343 | retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks') 344 | wider_check = gr.Checkbox(value=False, label='wider', info='wider result') 345 | 346 | # Description 347 | gr.Markdown(description_e) 348 | 349 | segment_btn_t.click(segment_everything, 350 | inputs=[ 351 | cond_img_t, 352 | input_size_slider_t, 353 | iou_threshold, 354 | conf_threshold, 355 | mor_check, 356 | contour_check, 357 | retina_check, 358 | text_box, 359 | wider_check, 360 | ], 361 | outputs=segm_img_t) 362 | 363 | def clear(): 364 | return None, None 365 | 366 | def clear_text(): 367 | return None, None, None 368 | 369 | clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e]) 370 | clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p]) 371 | clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box]) 372 | 373 | demo.queue() 374 | demo.launch() 375 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/utils/tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | import torch 6 | import os 7 | import sys 8 | import clip 9 | 10 | 11 | def convert_box_xywh_to_xyxy(box): 12 | if len(box) == 4: 13 | return [box[0], box[1], box[0] + box[2], box[1] + box[3]] 14 | else: 15 | result = [] 16 | for b in box: 17 | b = convert_box_xywh_to_xyxy(b) 18 | result.append(b) 19 | return result 20 | 21 | 22 | def segment_image(image, bbox): 23 | image_array = np.array(image) 24 | segmented_image_array = np.zeros_like(image_array) 25 | x1, y1, x2, y2 = bbox 26 | segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] 27 | segmented_image = Image.fromarray(segmented_image_array) 28 | black_image = Image.new("RGB", image.size, (255, 255, 255)) 29 | # transparency_mask = np.zeros_like((), dtype=np.uint8) 30 | transparency_mask = np.zeros( 31 | (image_array.shape[0], image_array.shape[1]), dtype=np.uint8 32 | ) 33 | transparency_mask[y1:y2, x1:x2] = 255 34 | transparency_mask_image = Image.fromarray(transparency_mask, mode="L") 35 | 
black_image.paste(segmented_image, mask=transparency_mask_image) 36 | return black_image 37 | 38 | 39 | def format_results(result, filter=0): 40 | annotations = [] 41 | n = len(result.masks.data) 42 | for i in range(n): 43 | annotation = {} 44 | mask = result.masks.data[i] == 1.0 45 | 46 | if torch.sum(mask) < filter: 47 | continue 48 | annotation["id"] = i 49 | annotation["segmentation"] = mask.cpu().numpy() 50 | annotation["bbox"] = result.boxes.data[i] 51 | annotation["score"] = result.boxes.conf[i] 52 | annotation["area"] = annotation["segmentation"].sum() 53 | annotations.append(annotation) 54 | return annotations 55 | 56 | 57 | def filter_masks(annotations): # filter the overlap mask 58 | annotations.sort(key=lambda x: x["area"], reverse=True) 59 | to_remove = set() 60 | for i in range(0, len(annotations)): 61 | a = annotations[i] 62 | for j in range(i + 1, len(annotations)): 63 | b = annotations[j] 64 | if i != j and j not in to_remove: 65 | # check if 66 | if b["area"] < a["area"]: 67 | if (a["segmentation"] & b["segmentation"]).sum() / b[ 68 | "segmentation" 69 | ].sum() > 0.8: 70 | to_remove.add(j) 71 | 72 | return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove 73 | 74 | 75 | def get_bbox_from_mask(mask): 76 | mask = mask.astype(np.uint8) 77 | contours, hierarchy = cv2.findContours( 78 | mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE 79 | ) 80 | x1, y1, w, h = cv2.boundingRect(contours[0]) 81 | x2, y2 = x1 + w, y1 + h 82 | if len(contours) > 1: 83 | for b in contours: 84 | x_t, y_t, w_t, h_t = cv2.boundingRect(b) 85 | # 将多个bbox合并成一个 86 | x1 = min(x1, x_t) 87 | y1 = min(y1, y_t) 88 | x2 = max(x2, x_t + w_t) 89 | y2 = max(y2, y_t + h_t) 90 | h = y2 - y1 91 | w = x2 - x1 92 | return [x1, y1, x2, y2] 93 | 94 | 95 | def fast_process( 96 | annotations, args, mask_random_color, bbox=None, points=None, edges=False 97 | ): 98 | if isinstance(annotations[0], dict): 99 | annotations = [annotation["segmentation"] for annotation in annotations] 100 | result_name = os.path.basename(args.img_path) 101 | image = cv2.imread(args.img_path) 102 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 103 | original_h = image.shape[0] 104 | original_w = image.shape[1] 105 | if sys.platform == "darwin": 106 | plt.switch_backend("TkAgg") 107 | plt.figure(figsize=(original_w/100, original_h/100)) 108 | # Add subplot with no margin. 
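# With matplotlib's default figure.dpi of 100, figsize=(original_w/100,
# original_h/100) above gives a canvas of roughly original_w x original_h
# pixels; the subplots_adjust/margins/NullLocator calls below then remove all
# padding, ticks and labels so that the rendered mask overlay stays aligned
# 1:1 with the source image when it is finally written out via cv2.imwrite.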
109 | plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) 110 | plt.margins(0, 0) 111 | plt.gca().xaxis.set_major_locator(plt.NullLocator()) 112 | plt.gca().yaxis.set_major_locator(plt.NullLocator()) 113 | plt.imshow(image) 114 | if args.better_quality == True: 115 | if isinstance(annotations[0], torch.Tensor): 116 | annotations = np.array(annotations.cpu()) 117 | for i, mask in enumerate(annotations): 118 | mask = cv2.morphologyEx( 119 | mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8) 120 | ) 121 | annotations[i] = cv2.morphologyEx( 122 | mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8) 123 | ) 124 | if args.device == "cpu": 125 | annotations = np.array(annotations) 126 | fast_show_mask( 127 | annotations, 128 | plt.gca(), 129 | random_color=mask_random_color, 130 | bbox=bbox, 131 | points=points, 132 | point_label=args.point_label, 133 | retinamask=args.retina, 134 | target_height=original_h, 135 | target_width=original_w, 136 | ) 137 | else: 138 | if isinstance(annotations[0], np.ndarray): 139 | annotations = torch.from_numpy(annotations) 140 | fast_show_mask_gpu( 141 | annotations, 142 | plt.gca(), 143 | random_color=args.randomcolor, 144 | bbox=bbox, 145 | points=points, 146 | point_label=args.point_label, 147 | retinamask=args.retina, 148 | target_height=original_h, 149 | target_width=original_w, 150 | ) 151 | if isinstance(annotations, torch.Tensor): 152 | annotations = annotations.cpu().numpy() 153 | if args.withContours == True: 154 | contour_all = [] 155 | temp = np.zeros((original_h, original_w, 1)) 156 | for i, mask in enumerate(annotations): 157 | if type(mask) == dict: 158 | mask = mask["segmentation"] 159 | annotation = mask.astype(np.uint8) 160 | if args.retina == False: 161 | annotation = cv2.resize( 162 | annotation, 163 | (original_w, original_h), 164 | interpolation=cv2.INTER_NEAREST, 165 | ) 166 | contours, hierarchy = cv2.findContours( 167 | annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE 168 | ) 169 | for contour in contours: 170 | contour_all.append(contour) 171 | cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2) 172 | color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8]) 173 | contour_mask = temp / 255 * color.reshape(1, 1, -1) 174 | plt.imshow(contour_mask) 175 | 176 | save_path = args.output 177 | if not os.path.exists(save_path): 178 | os.makedirs(save_path) 179 | plt.axis("off") 180 | fig = plt.gcf() 181 | plt.draw() 182 | 183 | try: 184 | buf = fig.canvas.tostring_rgb() 185 | except AttributeError: 186 | fig.canvas.draw() 187 | buf = fig.canvas.tostring_rgb() 188 | 189 | cols, rows = fig.canvas.get_width_height() 190 | img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3) 191 | cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)) 192 | 193 | 194 | # CPU post process 195 | def fast_show_mask( 196 | annotation, 197 | ax, 198 | random_color=False, 199 | bbox=None, 200 | points=None, 201 | point_label=None, 202 | retinamask=True, 203 | target_height=960, 204 | target_width=960, 205 | ): 206 | msak_sum = annotation.shape[0] 207 | height = annotation.shape[1] 208 | weight = annotation.shape[2] 209 | # 将annotation 按照面积 排序 210 | areas = np.sum(annotation, axis=(1, 2)) 211 | sorted_indices = np.argsort(areas) 212 | annotation = annotation[sorted_indices] 213 | 214 | index = (annotation != 0).argmax(axis=0) 215 | if random_color == True: 216 | color = np.random.random((msak_sum, 1, 1, 3)) 217 | else: 218 | color = np.ones((msak_sum, 1, 1, 3)) * 
np.array( 219 | [30 / 255, 144 / 255, 255 / 255] 220 | ) 221 | transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 222 | visual = np.concatenate([color, transparency], axis=-1) 223 | mask_image = np.expand_dims(annotation, -1) * visual 224 | 225 | show = np.zeros((height, weight, 4)) 226 | h_indices, w_indices = np.meshgrid( 227 | np.arange(height), np.arange(weight), indexing="ij" 228 | ) 229 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 230 | # 使用向量化索引更新show的值 231 | show[h_indices, w_indices, :] = mask_image[indices] 232 | if bbox is not None: 233 | x1, y1, x2, y2 = bbox 234 | ax.add_patch( 235 | plt.Rectangle( 236 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 237 | ) 238 | ) 239 | # draw point 240 | if points is not None: 241 | plt.scatter( 242 | [point[0] for i, point in enumerate(points) if point_label[i] == 1], 243 | [point[1] for i, point in enumerate(points) if point_label[i] == 1], 244 | s=20, 245 | c="y", 246 | ) 247 | plt.scatter( 248 | [point[0] for i, point in enumerate(points) if point_label[i] == 0], 249 | [point[1] for i, point in enumerate(points) if point_label[i] == 0], 250 | s=20, 251 | c="m", 252 | ) 253 | 254 | if retinamask == False: 255 | show = cv2.resize( 256 | show, (target_width, target_height), interpolation=cv2.INTER_NEAREST 257 | ) 258 | ax.imshow(show) 259 | 260 | 261 | def fast_show_mask_gpu( 262 | annotation, 263 | ax, 264 | random_color=False, 265 | bbox=None, 266 | points=None, 267 | point_label=None, 268 | retinamask=True, 269 | target_height=960, 270 | target_width=960, 271 | ): 272 | msak_sum = annotation.shape[0] 273 | height = annotation.shape[1] 274 | weight = annotation.shape[2] 275 | areas = torch.sum(annotation, dim=(1, 2)) 276 | sorted_indices = torch.argsort(areas, descending=False) 277 | annotation = annotation[sorted_indices] 278 | # 找每个位置第一个非零值下标 279 | index = (annotation != 0).to(torch.long).argmax(dim=0) 280 | if random_color == True: 281 | color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device) 282 | else: 283 | color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor( 284 | [30 / 255, 144 / 255, 255 / 255] 285 | ).to(annotation.device) 286 | transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6 287 | visual = torch.cat([color, transparency], dim=-1) 288 | mask_image = torch.unsqueeze(annotation, -1) * visual 289 | # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式 290 | show = torch.zeros((height, weight, 4)).to(annotation.device) 291 | h_indices, w_indices = torch.meshgrid( 292 | torch.arange(height), torch.arange(weight), indexing="ij" 293 | ) 294 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 295 | # 使用向量化索引更新show的值 296 | show[h_indices, w_indices, :] = mask_image[indices] 297 | show_cpu = show.cpu().numpy() 298 | if bbox is not None: 299 | x1, y1, x2, y2 = bbox 300 | ax.add_patch( 301 | plt.Rectangle( 302 | (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1 303 | ) 304 | ) 305 | # draw point 306 | if points is not None: 307 | plt.scatter( 308 | [point[0] for i, point in enumerate(points) if point_label[i] == 1], 309 | [point[1] for i, point in enumerate(points) if point_label[i] == 1], 310 | s=20, 311 | c="y", 312 | ) 313 | plt.scatter( 314 | [point[0] for i, point in enumerate(points) if point_label[i] == 0], 315 | [point[1] for i, point in enumerate(points) if point_label[i] == 0], 316 | s=20, 317 | c="m", 318 | ) 319 | if retinamask == False: 320 | show_cpu = cv2.resize( 321 | 
show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST 322 | ) 323 | ax.imshow(show_cpu) 324 | 325 | 326 | # clip 327 | @torch.no_grad() 328 | def retriev( 329 | model, preprocess, elements: [Image.Image], search_text: str, device 330 | ): 331 | preprocessed_images = [preprocess(image).to(device) for image in elements] 332 | tokenized_text = clip.tokenize([search_text]).to(device) 333 | stacked_images = torch.stack(preprocessed_images) 334 | image_features = model.encode_image(stacked_images) 335 | text_features = model.encode_text(tokenized_text) 336 | image_features /= image_features.norm(dim=-1, keepdim=True) 337 | text_features /= text_features.norm(dim=-1, keepdim=True) 338 | probs = 100.0 * image_features @ text_features.T 339 | return probs[:, 0].softmax(dim=0) 340 | 341 | 342 | def crop_image(annotations, image_like): 343 | if isinstance(image_like, str): 344 | image = Image.open(image_like) 345 | else: 346 | image = image_like 347 | ori_w, ori_h = image.size 348 | mask_h, mask_w = annotations[0]["segmentation"].shape 349 | if ori_w != mask_w or ori_h != mask_h: 350 | image = image.resize((mask_w, mask_h)) 351 | cropped_boxes = [] 352 | cropped_images = [] 353 | not_crop = [] 354 | origin_id = [] 355 | for _, mask in enumerate(annotations): 356 | if np.sum(mask["segmentation"]) <= 100: 357 | continue 358 | origin_id.append(_) 359 | bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox 360 | cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片 361 | # cropped_boxes.append(segment_image(image,mask["segmentation"])) 362 | cropped_images.append(bbox) # 保存裁剪的图片的bbox 363 | return cropped_boxes, cropped_images, not_crop, origin_id, annotations 364 | 365 | 366 | def box_prompt(masks, bbox, target_height, target_width): 367 | h = masks.shape[1] 368 | w = masks.shape[2] 369 | if h != target_height or w != target_width: 370 | bbox = [ 371 | int(bbox[0] * w / target_width), 372 | int(bbox[1] * h / target_height), 373 | int(bbox[2] * w / target_width), 374 | int(bbox[3] * h / target_height), 375 | ] 376 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 377 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 378 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 379 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 380 | 381 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 382 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 383 | 384 | masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2)) 385 | orig_masks_area = torch.sum(masks, dim=(1, 2)) 386 | 387 | union = bbox_area + orig_masks_area - masks_area 388 | IoUs = masks_area / union 389 | max_iou_index = torch.argmax(IoUs) 390 | 391 | return masks[max_iou_index].cpu().numpy(), max_iou_index 392 | 393 | 394 | def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理 395 | h = masks[0]["segmentation"].shape[0] 396 | w = masks[0]["segmentation"].shape[1] 397 | if h != target_height or w != target_width: 398 | points = [ 399 | [int(point[0] * w / target_width), int(point[1] * h / target_height)] 400 | for point in points 401 | ] 402 | onemask = np.zeros((h, w)) 403 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 404 | for i, annotation in enumerate(masks): 405 | if type(annotation) == dict: 406 | mask = annotation['segmentation'] 407 | else: 408 | mask = annotation 409 | for i, point in enumerate(points): 410 | if mask[point[1], point[0]] == 1 and point_label[i] == 1: 411 | onemask[mask] = 1 412 | if 
mask[point[1], point[0]] == 1 and point_label[i] == 0: 413 | onemask[mask] = 0 414 | onemask = onemask >= 1 415 | return onemask, 0 416 | 417 | 418 | def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9): 419 | cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image( 420 | annotations, img_path 421 | ) 422 | clip_model, preprocess = clip.load("ViT-B/32", device=device) 423 | scores = retriev( 424 | clip_model, preprocess, cropped_boxes, text, device=device 425 | ) 426 | max_idx = scores.argsort() 427 | max_idx = max_idx[-1] 428 | max_idx = origin_id[int(max_idx)] 429 | 430 | # find the biggest mask which contains the mask with max score 431 | if wider: 432 | mask0 = annotations_[max_idx]["segmentation"] 433 | area0 = np.sum(mask0) 434 | areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id] 435 | areas = sorted(areas, key=lambda area: area[1], reverse=True) 436 | indices = [area[0] for area in areas] 437 | for index in indices: 438 | if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold: 439 | max_idx = index 440 | break 441 | 442 | return annotations_[max_idx]["segmentation"], max_idx 443 | -------------------------------------------------------------------------------- /DocImgTranslation/FastSAM/fastsam/prompt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from .utils import image_to_np_ndarray 8 | from PIL import Image 9 | 10 | try: 11 | import clip # for linear_assignment 12 | 13 | except (ImportError, AssertionError, AttributeError): 14 | from ultralytics.yolo.utils.checks import check_requirements 15 | 16 | check_requirements('git+https://github.com/openai/CLIP.git') # required before installing lap from source 17 | import clip 18 | 19 | 20 | class FastSAMPrompt: 21 | 22 | def __init__(self, image, results, device='cuda'): 23 | if isinstance(image, str) or isinstance(image, Image.Image): 24 | image = image_to_np_ndarray(image) 25 | self.device = device 26 | self.results = results 27 | self.img = image 28 | 29 | def _segment_image(self, image, bbox): 30 | if isinstance(image, Image.Image): 31 | image_array = np.array(image) 32 | else: 33 | image_array = image 34 | segmented_image_array = np.zeros_like(image_array) 35 | x1, y1, x2, y2 = bbox 36 | segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2] 37 | segmented_image = Image.fromarray(segmented_image_array) 38 | black_image = Image.new('RGB', image.size, (255, 255, 255)) 39 | # transparency_mask = np.zeros_like((), dtype=np.uint8) 40 | transparency_mask = np.zeros((image_array.shape[0], image_array.shape[1]), dtype=np.uint8) 41 | transparency_mask[y1:y2, x1:x2] = 255 42 | transparency_mask_image = Image.fromarray(transparency_mask, mode='L') 43 | black_image.paste(segmented_image, mask=transparency_mask_image) 44 | return black_image 45 | 46 | def _format_results(self, result, filter=0): 47 | annotations = [] 48 | n = len(result.masks.data) 49 | for i in range(n): 50 | annotation = {} 51 | mask = result.masks.data[i] == 1.0 52 | 53 | if torch.sum(mask) < filter: 54 | continue 55 | annotation['id'] = i 56 | annotation['segmentation'] = mask.cpu().numpy() 57 | annotation['bbox'] = result.boxes.data[i] 58 | annotation['score'] = result.boxes.conf[i] 59 | annotation['area'] = annotation['segmentation'].sum() 60 | 
            annotations.append(annotation)
61 |         return annotations
62 | 
63 |     def filter_masks(annotations):  # filter the overlap mask
64 |         annotations.sort(key=lambda x: x['area'], reverse=True)
65 |         to_remove = set()
66 |         for i in range(0, len(annotations)):
67 |             a = annotations[i]
68 |             for j in range(i + 1, len(annotations)):
69 |                 b = annotations[j]
70 |                 if i != j and j not in to_remove:
71 |                     # check if
72 |                     if b['area'] < a['area']:
73 |                         if (a['segmentation'] & b['segmentation']).sum() / b['segmentation'].sum() > 0.8:
74 |                             to_remove.add(j)
75 | 
76 |         return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
77 | 
78 |     def _get_bbox_from_mask(self, mask):
79 |         mask = mask.astype(np.uint8)
80 |         contours, hierarchy = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
81 |         x1, y1, w, h = cv2.boundingRect(contours[0])
82 |         x2, y2 = x1 + w, y1 + h
83 |         if len(contours) > 1:
84 |             for b in contours:
85 |                 x_t, y_t, w_t, h_t = cv2.boundingRect(b)
86 |                 # Merge multiple bounding boxes into one.
87 |                 x1 = min(x1, x_t)
88 |                 y1 = min(y1, y_t)
89 |                 x2 = max(x2, x_t + w_t)
90 |                 y2 = max(y2, y_t + h_t)
91 |             h = y2 - y1
92 |             w = x2 - x1
93 |         return [x1, y1, x2, y2]
94 | 
95 |     def plot_to_result(self,
96 |                        annotations,
97 |                        bboxes=None,
98 |                        points=None,
99 |                        point_label=None,
100 |                        mask_random_color=True,
101 |                        better_quality=True,
102 |                        retina=False,
103 |                        withContours=True) -> np.ndarray:
104 |         if isinstance(annotations[0], dict):
105 |             annotations = [annotation['segmentation'] for annotation in annotations]
106 |         image = self.img
107 |         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
108 |         original_h = image.shape[0]
109 |         original_w = image.shape[1]
110 |         if sys.platform == "darwin":
111 |             plt.switch_backend("TkAgg")
112 |         plt.figure(figsize=(original_w / 100, original_h / 100))
113 |         # Add subplot with no margin.
114 |         plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
115 |         plt.margins(0, 0)
116 |         plt.gca().xaxis.set_major_locator(plt.NullLocator())
117 |         plt.gca().yaxis.set_major_locator(plt.NullLocator())
118 | 
119 |         plt.imshow(image)
120 |         if better_quality:
121 |             if isinstance(annotations[0], torch.Tensor):
122 |                 annotations = np.array(annotations.cpu())
123 |             for i, mask in enumerate(annotations):
124 |                 mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
125 |                 annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
126 |         if self.device == 'cpu':
127 |             annotations = np.array(annotations)
128 |             self.fast_show_mask(
129 |                 annotations,
130 |                 plt.gca(),
131 |                 random_color=mask_random_color,
132 |                 bboxes=bboxes,
133 |                 points=points,
134 |                 pointlabel=point_label,
135 |                 retinamask=retina,
136 |                 target_height=original_h,
137 |                 target_width=original_w,
138 |             )
139 |         else:
140 |             if isinstance(annotations[0], np.ndarray):
141 |                 annotations = torch.from_numpy(annotations)
142 |             self.fast_show_mask_gpu(
143 |                 annotations,
144 |                 plt.gca(),
145 |                 random_color=mask_random_color,
146 |                 bboxes=bboxes,
147 |                 points=points,
148 |                 pointlabel=point_label,
149 |                 retinamask=retina,
150 |                 target_height=original_h,
151 |                 target_width=original_w,
152 |             )
153 |         if isinstance(annotations, torch.Tensor):
154 |             annotations = annotations.cpu().numpy()
155 |         if withContours:
156 |             contour_all = []
157 |             temp = np.zeros((original_h, original_w, 1))
158 |             for i, mask in enumerate(annotations):
159 |                 if type(mask) == dict:
160 |                     mask = mask['segmentation']
161 |                 annotation = mask.astype(np.uint8)
162 |                 if not retina:
163 |                     annotation = cv2.resize(
164 |                         annotation,
165 |                         (original_w, original_h),
166 |                         interpolation=cv2.INTER_NEAREST,
167 |                     )
168 |                 contours, hierarchy = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
169 |                 for contour in contours:
170 |                     contour_all.append(contour)
171 |             cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
172 |             color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
173 |             contour_mask = temp / 255 * color.reshape(1, 1, -1)
174 |             plt.imshow(contour_mask)
175 | 
176 |         plt.axis('off')
177 |         fig = plt.gcf()
178 |         plt.draw()
179 | 
180 |         try:
181 |             buf = fig.canvas.tostring_rgb()
182 |         except AttributeError:
183 |             fig.canvas.draw()
184 |             buf = fig.canvas.tostring_rgb()
185 |         cols, rows = fig.canvas.get_width_height()
186 |         img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
187 |         result = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
188 |         plt.close()
189 |         return result
190 | 
191 |     # Remark for refactoring: IMO a function should do one thing only, storing the image and plotting should be separated and do not necessarily need to be class functions but standalone utility functions that the user can chain in his scripts to have more fine-grained control.
192 | def plot(self, 193 | annotations, 194 | output_path, 195 | bboxes=None, 196 | points=None, 197 | point_label=None, 198 | mask_random_color=True, 199 | better_quality=True, 200 | retina=False, 201 | withContours=True): 202 | if len(annotations) == 0: 203 | return None 204 | result = self.plot_to_result( 205 | annotations, 206 | bboxes, 207 | points, 208 | point_label, 209 | mask_random_color, 210 | better_quality, 211 | retina, 212 | withContours, 213 | ) 214 | 215 | path = os.path.dirname(os.path.abspath(output_path)) 216 | if not os.path.exists(path): 217 | os.makedirs(path) 218 | result = result[:, :, ::-1] 219 | cv2.imwrite(output_path, result) 220 | 221 | # CPU post process 222 | def fast_show_mask( 223 | self, 224 | annotation, 225 | ax, 226 | random_color=False, 227 | bboxes=None, 228 | points=None, 229 | pointlabel=None, 230 | retinamask=True, 231 | target_height=960, 232 | target_width=960, 233 | ): 234 | msak_sum = annotation.shape[0] 235 | height = annotation.shape[1] 236 | weight = annotation.shape[2] 237 | #Sort annotations based on area. 238 | areas = np.sum(annotation, axis=(1, 2)) 239 | sorted_indices = np.argsort(areas) 240 | annotation = annotation[sorted_indices] 241 | 242 | index = (annotation != 0).argmax(axis=0) 243 | if random_color: 244 | color = np.random.random((msak_sum, 1, 1, 3)) 245 | else: 246 | color = np.ones((msak_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255]) 247 | transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6 248 | visual = np.concatenate([color, transparency], axis=-1) 249 | mask_image = np.expand_dims(annotation, -1) * visual 250 | 251 | show = np.zeros((height, weight, 4)) 252 | h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij') 253 | indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None)) 254 | # Use vectorized indexing to update the values of 'show'. 255 | show[h_indices, w_indices, :] = mask_image[indices] 256 | if bboxes is not None: 257 | for bbox in bboxes: 258 | x1, y1, x2, y2 = bbox 259 | ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1)) 260 | # draw point 261 | if points is not None: 262 | plt.scatter( 263 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 1], 264 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], 265 | s=20, 266 | c='y', 267 | ) 268 | plt.scatter( 269 | [point[0] for i, point in enumerate(points) if pointlabel[i] == 0], 270 | [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], 271 | s=20, 272 | c='m', 273 | ) 274 | 275 | if not retinamask: 276 | show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST) 277 | ax.imshow(show) 278 | 279 | def fast_show_mask_gpu( 280 | self, 281 | annotation, 282 | ax, 283 | random_color=False, 284 | bboxes=None, 285 | points=None, 286 | pointlabel=None, 287 | retinamask=True, 288 | target_height=960, 289 | target_width=960, 290 | ): 291 | msak_sum = annotation.shape[0] 292 | height = annotation.shape[1] 293 | weight = annotation.shape[2] 294 | areas = torch.sum(annotation, dim=(1, 2)) 295 | sorted_indices = torch.argsort(areas, descending=False) 296 | annotation = annotation[sorted_indices] 297 | # Find the index of the first non-zero value at each position. 
        index = (annotation != 0).to(torch.long).argmax(dim=0)
299 |         if random_color:
300 |             color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
301 |         else:
302 |             color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor([
303 |                 30 / 255, 144 / 255, 255 / 255]).to(annotation.device)
304 |         transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
305 |         visual = torch.cat([color, transparency], dim=-1)
306 |         mask_image = torch.unsqueeze(annotation, -1) * visual
307 |         # Select data according to the index. The index indicates which batch's data to choose at each position, converting the mask_image into a single batch form.
308 |         show = torch.zeros((height, weight, 4)).to(annotation.device)
309 |         try:
310 |             h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight), indexing='ij')
311 |         except:
312 |             h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
313 |         indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
314 |         # Use vectorized indexing to update the values of 'show'.
315 |         show[h_indices, w_indices, :] = mask_image[indices]
316 |         show_cpu = show.cpu().numpy()
317 |         if bboxes is not None:
318 |             for bbox in bboxes:
319 |                 x1, y1, x2, y2 = bbox
320 |                 ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
321 |         # draw point
322 |         if points is not None:
323 |             plt.scatter(
324 |                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
325 |                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
326 |                 s=20,
327 |                 c='y',
328 |             )
329 |             plt.scatter(
330 |                 [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
331 |                 [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
332 |                 s=20,
333 |                 c='m',
334 |             )
335 |         if not retinamask:
336 |             show_cpu = cv2.resize(show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
337 |         ax.imshow(show_cpu)
338 | 
339 |     # clip
340 |     @torch.no_grad()
341 |     def retrieve(self, model, preprocess, elements, search_text: str, device) -> int:
342 |         preprocessed_images = [preprocess(image).to(device) for image in elements]
343 |         tokenized_text = clip.tokenize([search_text]).to(device)
344 |         stacked_images = torch.stack(preprocessed_images)
345 |         image_features = model.encode_image(stacked_images)
346 |         text_features = model.encode_text(tokenized_text)
347 |         image_features /= image_features.norm(dim=-1, keepdim=True)
348 |         text_features /= text_features.norm(dim=-1, keepdim=True)
349 |         probs = 100.0 * image_features @ text_features.T
350 |         return probs[:, 0].softmax(dim=0)
351 | 
352 |     def _crop_image(self, format_results):
353 | 
354 |         image = Image.fromarray(cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB))
355 |         ori_w, ori_h = image.size
356 |         annotations = format_results
357 |         mask_h, mask_w = annotations[0]['segmentation'].shape
358 |         if ori_w != mask_w or ori_h != mask_h:
359 |             image = image.resize((mask_w, mask_h))
360 |         cropped_boxes = []
361 |         cropped_images = []
362 |         not_crop = []
363 |         filter_id = []
364 |         # annotations, _ = filter_masks(annotations)
365 |         # filter_id = list(_)
366 |         for _, mask in enumerate(annotations):
367 |             if np.sum(mask['segmentation']) <= 100:
368 |                 filter_id.append(_)
369 |                 continue
370 |             bbox = self._get_bbox_from_mask(mask['segmentation'])  # bbox of the mask
371 |             cropped_boxes.append(self._segment_image(image, bbox))
372 |             # cropped_boxes.append(segment_image(image,mask["segmentation"]))
373 |             cropped_images.append(bbox)  # Save the
bounding box of the cropped image. 374 | 375 | return cropped_boxes, cropped_images, not_crop, filter_id, annotations 376 | 377 | def box_prompt(self, bbox=None, bboxes=None): 378 | if self.results == None: 379 | return [] 380 | assert bbox or bboxes 381 | if bboxes is None: 382 | bboxes = [bbox] 383 | max_iou_index = [] 384 | for bbox in bboxes: 385 | assert (bbox[2] != 0 and bbox[3] != 0) 386 | masks = self.results[0].masks.data 387 | target_height = self.img.shape[0] 388 | target_width = self.img.shape[1] 389 | h = masks.shape[1] 390 | w = masks.shape[2] 391 | if h != target_height or w != target_width: 392 | bbox = [ 393 | int(bbox[0] * w / target_width), 394 | int(bbox[1] * h / target_height), 395 | int(bbox[2] * w / target_width), 396 | int(bbox[3] * h / target_height), ] 397 | bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0 398 | bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0 399 | bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w 400 | bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h 401 | 402 | # IoUs = torch.zeros(len(masks), dtype=torch.float32) 403 | bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0]) 404 | 405 | masks_area = torch.sum(masks[:, bbox[1]:bbox[3], bbox[0]:bbox[2]], dim=(1, 2)) 406 | orig_masks_area = torch.sum(masks, dim=(1, 2)) 407 | 408 | union = bbox_area + orig_masks_area - masks_area 409 | IoUs = masks_area / union 410 | max_iou_index.append(int(torch.argmax(IoUs))) 411 | max_iou_index = list(set(max_iou_index)) 412 | return np.array(masks[max_iou_index].cpu().numpy()) 413 | 414 | def point_prompt(self, points, pointlabel): # numpy 415 | if self.results == None: 416 | return [] 417 | masks = self._format_results(self.results[0], 0) 418 | target_height = self.img.shape[0] 419 | target_width = self.img.shape[1] 420 | h = masks[0]['segmentation'].shape[0] 421 | w = masks[0]['segmentation'].shape[1] 422 | if h != target_height or w != target_width: 423 | points = [[int(point[0] * w / target_width), int(point[1] * h / target_height)] for point in points] 424 | onemask = np.zeros((h, w)) 425 | masks = sorted(masks, key=lambda x: x['area'], reverse=True) 426 | for i, annotation in enumerate(masks): 427 | if type(annotation) == dict: 428 | mask = annotation['segmentation'] 429 | else: 430 | mask = annotation 431 | for i, point in enumerate(points): 432 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 1: 433 | onemask[mask] = 1 434 | if mask[point[1], point[0]] == 1 and pointlabel[i] == 0: 435 | onemask[mask] = 0 436 | onemask = onemask >= 1 437 | return np.array([onemask]) 438 | 439 | def text_prompt(self, text): 440 | if self.results == None: 441 | return [] 442 | format_results = self._format_results(self.results[0], 0) 443 | cropped_boxes, cropped_images, not_crop, filter_id, annotations = self._crop_image(format_results) 444 | clip_model, preprocess = clip.load('ViT-B/32', device=self.device) 445 | scores = self.retrieve(clip_model, preprocess, cropped_boxes, text, device=self.device) 446 | max_idx = scores.argsort() 447 | max_idx = max_idx[-1] 448 | max_idx += sum(np.array(filter_id) <= int(max_idx)) 449 | return np.array([annotations[max_idx]['segmentation']]) 450 | 451 | def everything_prompt(self): 452 | if self.results == None: 453 | return [] 454 | return self.results[0].masks.data 455 | 456 | --------------------------------------------------------------------------------
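The fast_show_mask and fast_show_mask_gpu helpers above share one compositing trick: masks are sorted by area, (annotation != 0).argmax(axis=0) finds, for every pixel, the first (smallest) mask that covers it, and a single vectorized fancy-indexing assignment then paints each pixel with that mask's RGBA value instead of looping over masks in Python. Below is a small, self-contained NumPy sketch of that indexing pattern; the toy masks and colors are made up purely for illustration and do not come from these files.

import numpy as np

# Two toy binary masks over a 4x4 image, already ordered by area (ascending),
# mirroring `annotation = annotation[np.argsort(areas)]` in fast_show_mask.
masks = np.zeros((2, 4, 4), dtype=bool)
masks[0, 1:3, 1:3] = True          # small mask
masks[1, :, 0:3] = True            # larger mask that overlaps the small one

colors = np.array([[1.0, 0.0, 0.0, 0.6],   # RGBA assigned to mask 0
                   [0.0, 0.0, 1.0, 0.6]])  # RGBA assigned to mask 1
mask_image = masks[..., None] * colors[:, None, None, :]   # shape (n, H, W, 4)

# For every pixel, the index of the first mask that covers it (0 where none do).
index = (masks != 0).argmax(axis=0)

h_idx, w_idx = np.meshgrid(np.arange(4), np.arange(4), indexing="ij")
show = np.zeros((4, 4, 4))
# One vectorized assignment replaces a per-mask Python loop: pixels covered by
# the small mask keep its color, pixels covered only by the big mask get the
# big mask's color, and uncovered pixels stay zero (mask_image[0] is zero there).
show[h_idx, w_idx, :] = mask_image[index[h_idx, w_idx], h_idx, w_idx, :]

print(show[..., 3])   # alpha channel: 0.6 wherever some mask covers the pixel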
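For orientation, here is a minimal usage sketch of the FastSAMPrompt class defined above. It assumes the fastsam package also exposes a FastSAM model wrapper (its model.py is not reproduced in this excerpt); the checkpoint path, image path, inference arguments, and output path are illustrative placeholders rather than values taken from the repository.

# Minimal sketch, assuming `fastsam` exports a FastSAM model wrapper alongside
# FastSAMPrompt; all paths and inference arguments below are hypothetical.
from fastsam import FastSAM, FastSAMPrompt

model = FastSAM("path/to/FastSAM-x.pt")                  # hypothetical checkpoint path
results = model("path/to/image.jpg", device="cpu",       # assumed wrapper call signature
                retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)

prompt_process = FastSAMPrompt("path/to/image.jpg", results, device="cpu")

# Pick one of the prompt modes implemented above.
ann = prompt_process.everything_prompt()
# ann = prompt_process.box_prompt(bbox=[100, 100, 400, 400])
# ann = prompt_process.point_prompt(points=[[250, 250]], pointlabel=[1])
# ann = prompt_process.text_prompt(text="a photo of a dog")

prompt_process.plot(annotations=ann, output_path="path/to/output.jpg")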