├── img ├── img.png ├── img_1.png ├── img_2.png ├── img_3.png └── img_4.png ├── requirements.txt ├── ocr ├── __init__.py ├── utils.py └── ocr.py ├── README.md ├── AI-医学图片OCR.py └── ocr_utils.py /img/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianchiguaixia/medical_ocr_streamlit/HEAD/img/img.png -------------------------------------------------------------------------------- /img/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianchiguaixia/medical_ocr_streamlit/HEAD/img/img_1.png -------------------------------------------------------------------------------- /img/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianchiguaixia/medical_ocr_streamlit/HEAD/img/img_2.png -------------------------------------------------------------------------------- /img/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianchiguaixia/medical_ocr_streamlit/HEAD/img/img_3.png -------------------------------------------------------------------------------- /img/img_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianchiguaixia/medical_ocr_streamlit/HEAD/img/img_4.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | paddleocr~=2.6.0.2 2 | opencv-python~=4.6.0.66 3 | numpy~=1.23.4 4 | pillow~=8.0.1 5 | streamlit~=1.12.2 6 | pandas~=1.5.0 -------------------------------------------------------------------------------- /ocr/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # time: 2022/10/17 13:03 3 | # file: __init__.py.py 4 | # author: fengchenggang 5 | # email: chenggang.feng@ashermed.com 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 医学图片的ocr识别 2 | ### 背景 3 | 该项目主要使用百度的paddleocr对医学图片进行识别。利用PPStructure对识别的内容进行结构化,最终将结构化的内容保存成csv文件。 4 | 整个项目通过streamlit进行前端的展示。 5 | 6 | 7 | ### 数据形式 8 | ![](img/img_1.png) 9 | 10 | ### 代码结构 11 | 12 | ``` 13 | ├── AI-医学图片OCR.py # 前端展示 14 | ├── ocr 15 | │ ├── __init__.py # 初始化 16 | │ ├── ocr.py # ocr识别 17 | │ └── utils.py # 一些工具函数 18 | ├── ocr_utils.py # 一些工具函数 19 | ``` 20 | 21 | 22 | ### 项目启动 23 | 24 | ``` 25 | streamlit run AI-医学图片OCR.py 26 | ``` 27 | 28 | 29 | 30 | ### 前端展示 31 | * 前端展示地址:http://ip:8501 32 | ![](img/img_4.png) 33 | 34 | -------------------------------------------------------------------------------- /ocr/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # time: 2022/10/17 13:04 3 | # file: utils.py 4 | 5 | 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | def bytes_to_numpy(image_bytes, channels='BGR'): 11 | """ 12 | 图片格式转换 bytes -> numpy 13 | args: 14 | image_bytes(str): 图片的字节流 15 | channels(str): 图片的格式 ['BGR'|'RGB'] 16 | return(array): 17 | 转换后的图片 18 | """ 19 | _image_np = np.frombuffer(image_bytes, dtype=np.uint8) 20 | image_np = cv2.imdecode(_image_np, cv2.IMREAD_COLOR) 21 | if channels == 'BGR': 22 | return image_np 23 | elif channels == 'RGB': 24 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 25 | return image_np 26 | 27 | 28 | def numpy_to_pic(image_path, image_np, channels='BGR'): 29 | """ 30 | 保存图片 31 | args: 32 | image_path(str): 图片路径 33 | image_np(array): numpy格式的图片数据 34 | channels(str): 图片的格式 ['BGR'|'RGB'] 35 | """ 36 | if channels == 'BGR': 37 | cv2.imwrite(image_path, image_np) 38 | elif channels == 'RGB': 39 | image_np = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) 40 | cv2.imwrite(image_path, image_np) -------------------------------------------------------------------------------- /ocr/ocr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # time: 2022/10/17 13:04 3 | # file: ocr.py 4 | 5 | 6 | from paddleocr import PaddleOCR 7 | from ocr_utils import draw_ocr, draw_ocr_box_txt 8 | 9 | ocr = PaddleOCR(lang='ch',use_angle_cls=True) 10 | 11 | 12 | def detect(image): 13 | """ 14 | 文本检测 15 | args: 16 | image(array): numpy格式的图片 'RGB' 17 | return(array): 18 | 检测后的图片 numpy格式 'RGB' 19 | """ 20 | result = ocr.ocr(image, rec=False) 21 | im_show = draw_ocr(image, result) 22 | 23 | return im_show 24 | 25 | 26 | def recognize(image, output_mode=0): 27 | """ 28 | 文本识别 29 | args: 30 | image(array): numpy格式的图片 'RGB' 31 | output_mode(int): 图片输出模式 [0|1] 32 | return(array): 33 | 识别后的图片 numpy格式 'RGB' 34 | """ 35 | result = ocr.ocr(image) 36 | boxes = [line[0] for line in result] 37 | txts = [line[1][0] for line in result] 38 | scores = [line[1][1] for line in result] 39 | if output_mode == 0: 40 | im_show = draw_ocr_box_txt(image, boxes, txts, scores) 41 | elif output_mode == 1: 42 | im_show = draw_ocr(image, boxes, txts, scores) 43 | 44 | return im_show 45 | 46 | -------------------------------------------------------------------------------- /AI-医学图片OCR.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # time: 2022/10/17 11:22 3 | # file: AI-医学图片OCR.py 4 | 5 | 6 | 7 | import streamlit as st 8 | 9 | from ocr.ocr import detect, recognize 10 | from ocr.utils import bytes_to_numpy 11 | import pandas as pd 12 | 13 | import os 14 | import cv2 15 | from paddleocr import PPStructure,draw_structure_result,save_structure_res 16 | 17 | st.title("AI-医学图片OCR") 18 | def convert_df(df): 19 | # IMPORTANT: Cache the conversion to prevent computation on every rerun 20 | return df.to_csv().encode("gbk") 21 | 22 | 23 | # 上传图片 24 | uploaded_file = st.sidebar.file_uploader('请选择一张图片', type=['png', 'jpg', 'jpeg']) 25 | print('uploaded_file:', uploaded_file) 26 | table_engine = PPStructure(show_log=True) 27 | if uploaded_file is not None: 28 | # To read file as bytes: 29 | # content = cv2.imread(uploaded_file) 30 | # st.write(content) 31 | bytes_data = uploaded_file.getvalue() 32 | # 转换格式 33 | img = bytes_to_numpy(bytes_data, channels='RGB') 34 | option_task = st.sidebar.radio('请选择要执行的任务', ('查看原图', '文本检测')) 35 | if option_task == '查看原图': 36 | st.image(img, caption='原图') 37 | elif option_task == '文本检测': 38 | im_show = detect(img) 39 | st.image(im_show, caption='文本检测后的图片') 40 | 41 | base_path="streamlit_data" 42 | 43 | path=os.path.exists(base_path+"/"+uploaded_file.name.split('.')[0]) 44 | 45 | if st.button('✨ 启动!'): 46 | local_path=base_path +"/"+uploaded_file.name.split('.')[0] 47 | result = table_engine(img) 48 | save_structure_res(result, base_path,uploaded_file.name.split('.')[0]) 49 | with st.container(): 50 | with st.expander(label="json结果展示", expanded=False): 51 | st.write(result) 52 | for i in os.listdir(local_path): 53 | if ".xlsx" in i: 54 | df = pd.read_excel(os.path.join(local_path, i)) 55 | df=df.fillna("") 56 | st.write(df) 57 | csv = convert_df(df) 58 | st.download_button( 59 | label="Download data as csv", 60 | data=csv, 61 | file_name='large_df.csv', 62 | mime='text/csv', 63 | ) 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /ocr_utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # time: 2022/10/17 13:25 3 | # file: ocr_utils.py 4 | 5 | import cv2 6 | import math 7 | import numpy as np 8 | 9 | from PIL import Image, ImageDraw, ImageFont 10 | 11 | 12 | def resize_img(img, input_size=600): 13 | """ 14 | resize img and limit the longest side of the image to input_size 15 | """ 16 | img = np.array(img) 17 | im_shape = img.shape 18 | im_size_max = np.max(im_shape[0:2]) 19 | im_scale = float(input_size) / float(im_size_max) 20 | img = cv2.resize(img, None, None, fx=im_scale, fy=im_scale) 21 | return img 22 | 23 | 24 | def draw_ocr( 25 | image, 26 | boxes, 27 | txts=None, 28 | scores=None, 29 | drop_score=0.5, 30 | font_path="./fonts/font.ttf" 31 | ): 32 | """ 33 | Visualize the results of OCR detection and recognition 34 | args: 35 | image(Image|array): RGB image 36 | boxes(list): boxes with shape(N, 4, 2) 37 | txts(list): the texts 38 | scores(list): txxs corresponding scores 39 | drop_score(float): only scores greater than drop_threshold will be visualized 40 | font_path: the path of font which is used to draw text 41 | return(array): 42 | the visualized img 43 | """ 44 | if scores is None: 45 | scores = [1] * len(boxes) 46 | box_num = len(boxes) 47 | for i in range(box_num): 48 | if scores is not None and (scores[i] < drop_score or math.isnan(scores[i])): 49 | continue 50 | box = np.reshape(np.array(boxes[i]), [-1, 1, 2]).astype(np.int64) 51 | image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) 52 | if txts is not None: 53 | img = np.array(resize_img(image, input_size=600)) 54 | txt_img = text_visual( 55 | txts, 56 | scores, 57 | img_h=img.shape[0], 58 | img_w=600, 59 | threshold=drop_score, 60 | font_path=font_path 61 | ) 62 | img = np.concatenate([np.array(img), np.array(txt_img)], axis=1) 63 | return img 64 | return image 65 | 66 | 67 | def draw_ocr_box_txt( 68 | image, 69 | boxes, 70 | txts, 71 | scores=None, 72 | drop_score=0.5, 73 | font_path="./fonts/font.ttf" 74 | ): 75 | image = Image.fromarray(image) 76 | h, w = image.height, image.width 77 | img_left = image.copy() 78 | img_right = Image.new('RGB', (w, h), (255, 255, 255)) 79 | 80 | import random 81 | 82 | random.seed(0) 83 | draw_left = ImageDraw.Draw(img_left) 84 | draw_right = ImageDraw.Draw(img_right) 85 | for idx, (box, txt) in enumerate(zip(boxes, txts)): 86 | if scores is not None and scores[idx] < drop_score: 87 | continue 88 | color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)) 89 | draw_left.polygon( 90 | [ 91 | box[0][0], box[0][1], box[1][0], box[1][1], box[2][0], 92 | box[2][1], box[3][0], box[3][1] 93 | ], 94 | fill=color) 95 | draw_right.polygon( 96 | [ 97 | box[0][0], box[0][1], box[1][0], box[1][1], box[2][0], 98 | box[2][1], box[3][0], box[3][1] 99 | ], 100 | outline=color) 101 | box_height = math.sqrt((box[0][0] - box[3][0])**2 + (box[0][1] - box[3][ 102 | 1])**2) 103 | box_width = math.sqrt((box[0][0] - box[1][0])**2 + (box[0][1] - box[1][ 104 | 1])**2) 105 | if box_height > 2 * box_width: 106 | font_size = max(int(box_width * 0.9), 10) 107 | font = ImageFont.truetype(font_path, font_size, encoding="utf-8") 108 | cur_y = box[0][1] 109 | for c in txt: 110 | char_size = font.getsize(c) 111 | draw_right.text((box[0][0] + 3, cur_y), c, fill=(0, 0, 0), font=font) 112 | cur_y += char_size[1] 113 | else: 114 | font_size = max(int(box_height * 0.8), 10) 115 | font = ImageFont.truetype(font_path, font_size, encoding="utf-8") 116 | draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font) 117 | img_left = Image.blend(image, img_left, 0.5) 118 | img_show = Image.new('RGB', (w * 2, h), (255, 255, 255)) 119 | img_show.paste(img_left, (0, 0, w, h)) 120 | img_show.paste(img_right, (w, 0, w * 2, h)) 121 | return np.array(img_show) 122 | 123 | 124 | def str_count(s): 125 | """ 126 | Count the number of Chinese characters, 127 | a single English character and a single number 128 | equal to half the length of Chinese characters. 129 | args: 130 | s(string): the input of string 131 | return(int): 132 | the number of Chinese characters 133 | """ 134 | import string 135 | count_zh = count_pu = 0 136 | s_len = len(s) 137 | en_dg_count = 0 138 | for c in s: 139 | if c in string.ascii_letters or c.isdigit() or c.isspace(): 140 | en_dg_count += 1 141 | elif c.isalpha(): 142 | count_zh += 1 143 | else: 144 | count_pu += 1 145 | return s_len - math.ceil(en_dg_count / 2) 146 | 147 | 148 | def text_visual( 149 | texts, 150 | scores, 151 | img_h=400, 152 | img_w=600, 153 | threshold=0., 154 | font_path="./fonts/font.ttf" 155 | ): 156 | """ 157 | create new blank img and draw txt on it 158 | args: 159 | texts(list): the text will be draw 160 | scores(list|None): corresponding score of each txt 161 | img_h(int): the height of blank img 162 | img_w(int): the width of blank img 163 | font_path: the path of font which is used to draw text 164 | return(array): 165 | """ 166 | if scores is not None: 167 | assert len(texts) == len(scores), "The number of txts and corresponding scores must match" 168 | 169 | def create_blank_img(): 170 | blank_img = np.ones(shape=[img_h, img_w], dtype=np.int8) * 255 171 | blank_img[:, img_w - 1:] = 0 172 | blank_img = Image.fromarray(blank_img).convert("RGB") 173 | draw_txt = ImageDraw.Draw(blank_img) 174 | return blank_img, draw_txt 175 | 176 | blank_img, draw_txt = create_blank_img() 177 | 178 | font_size = 20 179 | txt_color = (0, 0, 0) 180 | font = ImageFont.truetype(font_path, font_size, encoding="utf-8") 181 | 182 | gap = font_size + 5 183 | txt_img_list = [] 184 | count, index = 1, 0 185 | for idx, txt in enumerate(texts): 186 | index += 1 187 | if scores[idx] < threshold or math.isnan(scores[idx]): 188 | index -= 1 189 | continue 190 | first_line = True 191 | while str_count(txt) >= img_w // font_size - 4: 192 | tmp = txt 193 | txt = tmp[:img_w // font_size - 4] 194 | if first_line: 195 | new_txt = str(index) + ': ' + txt 196 | first_line = False 197 | else: 198 | new_txt = ' ' + txt 199 | draw_txt.text((0, gap * count), new_txt, txt_color, font=font) 200 | txt = tmp[img_w // font_size - 4:] 201 | if count >= img_h // gap - 1: 202 | txt_img_list.append(np.array(blank_img)) 203 | blank_img, draw_txt = create_blank_img() 204 | count = 0 205 | count += 1 206 | if first_line: 207 | new_txt = str(index) + ': ' + txt + ' ' + '%.3f' % (scores[idx]) 208 | else: 209 | new_txt = " " + txt + " " + '%.3f' % (scores[idx]) 210 | draw_txt.text((0, gap * count), new_txt, txt_color, font=font) 211 | # whether add new blank img or not 212 | if count >= img_h // gap - 1 and idx + 1 < len(texts): 213 | txt_img_list.append(np.array(blank_img)) 214 | blank_img, draw_txt = create_blank_img() 215 | count = 0 216 | count += 1 217 | txt_img_list.append(np.array(blank_img)) 218 | if len(txt_img_list) == 1: 219 | blank_img = np.array(txt_img_list[0]) 220 | else: 221 | blank_img = np.concatenate(txt_img_list, axis=1) 222 | return np.array(blank_img) 223 | 224 | 225 | def base64_to_cv2(b64str): 226 | import base64 227 | data = base64.b64decode(b64str.encode('utf8')) 228 | data = np.fromstring(data, np.uint8) 229 | data = cv2.imdecode(data, cv2.IMREAD_COLOR) 230 | return data 231 | 232 | 233 | def draw_boxes(image, boxes, scores=None, drop_score=0.5): 234 | if scores is None: 235 | scores = [1] * len(boxes) 236 | for (box, score) in zip(boxes, scores): 237 | if score < drop_score: 238 | continue 239 | box = np.reshape(np.array(box), [-1, 1, 2]).astype(np.int64) 240 | image = cv2.polylines(np.array(image), [box], True, (255, 0, 0), 2) 241 | return image 242 | 243 | 244 | def get_rotate_crop_image(img, points): 245 | ''' 246 | img_height, img_width = img.shape[0:2] 247 | left = int(np.min(points[:, 0])) 248 | right = int(np.max(points[:, 0])) 249 | top = int(np.min(points[:, 1])) 250 | bottom = int(np.max(points[:, 1])) 251 | img_crop = img[top:bottom, left:right, :].copy() 252 | points[:, 0] = points[:, 0] - left 253 | points[:, 1] = points[:, 1] - top 254 | ''' 255 | assert len(points) == 4, "shape of points must be 4*2" 256 | img_crop_width = int( 257 | max( 258 | np.linalg.norm(points[0] - points[1]), 259 | np.linalg.norm(points[2] - points[3]))) 260 | img_crop_height = int( 261 | max( 262 | np.linalg.norm(points[0] - points[3]), 263 | np.linalg.norm(points[1] - points[2]))) 264 | pts_std = np.float32([[0, 0], [img_crop_width, 0], 265 | [img_crop_width, img_crop_height], 266 | [0, img_crop_height]]) 267 | M = cv2.getPerspectiveTransform(points, pts_std) 268 | dst_img = cv2.warpPerspective( 269 | img, 270 | M, (img_crop_width, img_crop_height), 271 | borderMode=cv2.BORDER_REPLICATE, 272 | flags=cv2.INTER_CUBIC 273 | ) 274 | dst_img_height, dst_img_width = dst_img.shape[0:2] 275 | if dst_img_height * 1.0 / dst_img_width >= 1.5: 276 | dst_img = np.rot90(dst_img) 277 | return dst_img 278 | 279 | 280 | if __name__ == '__main__': 281 | pass 282 | --------------------------------------------------------------------------------