├── README.md ├── answer ├── t1.txt ├── t3.txt └── t6.txt ├── fonts └── simfang.ttf ├── images ├── t1.jpg ├── t2.jpg ├── t3.jpg ├── t4.jpg ├── t5.jpg ├── t6.jpg ├── t7.jpg └── trans.jpg ├── main.py ├── settings.yaml ├── test.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # DetectAnswer 2 | Python+OCR+OpenCV实现答题卡选项识别 3 | -------------------------------------------------------------------------------- /answer/t1.txt: -------------------------------------------------------------------------------- 1 | 1 B 2 | 2 C 3 | 3 A 4 | 4 D 5 | 5 B 6 | -------------------------------------------------------------------------------- /answer/t3.txt: -------------------------------------------------------------------------------- 1 | 1 B 2 | 2 C 3 | 3 B 4 | 4 D 5 | 5 B 6 | 6 A 7 | 7 C 8 | 8 D 9 | 9 B 10 | 10 D 11 | 11 B 12 | 12 D 13 | 13 C 14 | 14 A 15 | 15 C 16 | 16 B 17 | 17 B 18 | 18 C 19 | 19 B 20 | 20 A 21 | 21 B 22 | 22 D 23 | 23 C 24 | 24 B 25 | 25 D 26 | 26 B 27 | 27 C 28 | 28 B 29 | 29 B 30 | 30 C 31 | -------------------------------------------------------------------------------- /answer/t6.txt: -------------------------------------------------------------------------------- 1 | 1 B 2 | 2 C 3 | 3 A 4 | 4 C 5 | 5 D 6 | 6 A 7 | 7 A 8 | 8 A 9 | 9 C 10 | 10 C 11 | 11 B 12 | 12 D 13 | 13 A 14 | 14 C 15 | 15 B 16 | 16 D 17 | 17 C 18 | 18 B 19 | 19 B 20 | 20 A 21 | 21 B 22 | 22 C 23 | 23 D 24 | 24 A 25 | 25 C 26 | 26 B 27 | 27 D 28 | 28 C 29 | 29 D 30 | 30 A 31 | 31 C 32 | 32 D 33 | 33 B 34 | 34 C 35 | 35 A 36 | 36 A 37 | 37 B 38 | 38 D 39 | 39 C 40 | 40 C 41 | 41 B 42 | 42 D 43 | 43 B 44 | 44 D 45 | 45 A 46 | 46 C 47 | 47 D 48 | 48 B 49 | 49 D 50 | 50 A 51 | -------------------------------------------------------------------------------- /fonts/simfang.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/fonts/simfang.ttf -------------------------------------------------------------------------------- /images/t1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t1.jpg -------------------------------------------------------------------------------- /images/t2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t2.jpg -------------------------------------------------------------------------------- /images/t3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t3.jpg -------------------------------------------------------------------------------- /images/t4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t4.jpg -------------------------------------------------------------------------------- /images/t5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t5.jpg -------------------------------------------------------------------------------- /images/t6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t6.jpg -------------------------------------------------------------------------------- /images/t7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/t7.jpg -------------------------------------------------------------------------------- /images/trans.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/StuOfBupt/DetectAnswer/c7b0ad2b89dc46c4b76458b57bcd43048845ccd3/images/trans.jpg -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import imutils 3 | from PIL import Image 4 | from imutils.perspective import four_point_transform 5 | 6 | import test 7 | import utils 8 | from paddleocr import PaddleOCR, draw_ocr 9 | 10 | ocr = PaddleOCR(det_model_dir='./model/ch_ppocr_server_v2.0_det_infer', 11 | cls_model_dir='./model/ch_ppocr_mobile_v2.0_cls_infer', 12 | rec_model_dir='./model/ch_ppocr_server_v2.0_rec_infer', 13 | use_angle_cls=True, lang='ch', use_gpu=False) 14 | font_path = './fonts/simfang.ttf' 15 | 16 | DEBUG_MODE = False 17 | distance_estimation = set() 18 | 19 | 20 | def cv_show(img, msg=""): 21 | if DEBUG_MODE: 22 | cv2.imshow('temp', img) 23 | if msg: 24 | print(msg) 25 | cv2.waitKey(0) 26 | 27 | 28 | def get_answer(base_image, img_path=None): 29 | gray = cv2.cvtColor(base_image, cv2.COLOR_BGR2GRAY) 30 | blurred = cv2.GaussianBlur(gray, (5, 5), 0) 31 | cv_show(blurred, msg="blurred img") 32 | edged = imutils.auto_canny(blurred) 33 | cv_show(edged, msg="edged") 34 | 35 | thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1] 36 | cv_show(thresh, "thresh") 37 | 38 | # 二值图像中查找轮廓 39 | cnts = cv2.findContours(thresh.copy(), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) 40 | cnts = imutils.grab_contours(cnts) 41 | boxes = [cv2.boundingRect(cnt) for cnt in cnts] 42 | 43 | cp_img = base_image.copy() 44 | 45 | for bbox in boxes: 46 | [x, y, w, h] = bbox 47 | cv2.rectangle(cp_img, (x, y), (x + w, y + h), (255, 0, 0), 2) 48 | 49 | cv_show(cp_img, "cp") 50 | lines = ocr.ocr(img_path, cls=True) 51 | w_threshod, h_threshod = utils.get_choice_wh(lines) 52 | 53 | def between(var, threshod): 54 | return threshod[0] <= var <= threshod[1] 55 | 56 | choice_boxes = [] 57 | for box in boxes: 58 | [x, y, w, h] = box 59 | if between(w, w_threshod) and between(h, h_threshod): 60 | choice_boxes.append(box) 61 | cv2.rectangle(base_image, (x, y), (x + w, y + h), (255, 0, 0), 2) 62 | cv_show(base_image, 'finial') 63 | utils.compute_choice_interval(choice_boxes) 64 | utils.compute_choice_width(choice_boxes) 65 | ans = dict() 66 | for num, (res, box) in utils.analysis_orc_lines(lines).items(): 67 | if res is None: 68 | print('题号 %d ocr 未识别出选项,使用距离估算' % num) 69 | distance_estimation.add(num) 70 | res = utils.find_choice_by_num_box(box, choice_boxes) 71 | ans[num] = res 72 | # for num, xywh in utils.extract_num(lines).items(): 73 | # if num in ans: 74 | # print('题号 %d 重复判断' % num) 75 | # continue 76 | # print('题号 %d 使用距离估算' % num) 77 | # ans[num] = utils.find_choice_by_num_box(utils.xywh2box(*xywh), choice_boxes) 78 | 79 | test.format_print_dict(ans) 80 | utils.write_lines('lines.txt', lines) 81 | ocr_boxes = [line[0] for line in lines] 82 | im_show = draw_ocr(base_image, ocr_boxes, None, None, font_path=font_path) 83 | im_show = Image.fromarray(im_show) 84 | im_show.save('result.jpg') 85 | return ans 86 | 87 | 88 | def filter_boxes(boxes, image): 89 | area = image.shape[0] * image.shape[1] 90 | valid_boxes = [] 91 | max_area = 0 92 | for box in boxes: 93 | [_, _, w, h] = box 94 | if w > h and w / h < 3.0 and w * h < area / 8: 95 | valid_boxes.append(box) 96 | max_area = max(max_area, w * h) 97 | threshod = max_area * 0.6 98 | ret = [] 99 | for box in valid_boxes: 100 | [_, _, w, h] = box 101 | if w * h > threshod: 102 | ret.append(box) 103 | return ret 104 | 105 | 106 | def judge(ret, answer): 107 | count = 0 108 | for num, choice in answer.items(): 109 | if num not in ret: 110 | print('题号 %d 没有识别出来' % num) 111 | continue 112 | if ret[num] != choice: 113 | msg = '距离估算错误' if num in distance_estimation else 'ocr 识别错误' 114 | print('题号 %d 识别错误 %s (%s), 原因:%s' % (num, ret[num], choice, msg)) 115 | 116 | continue 117 | count += 1 118 | print('正确率:%.2f %d/%d' % (count / len(answer.keys()), count, len(answer.keys()))) 119 | 120 | 121 | if __name__ == '__main__': 122 | img_path = './images/t6.jpg' 123 | image = cv2.imread(img_path) 124 | ret = get_answer(image, img_path) 125 | answer = utils.read_answer('./answer/t6.txt') 126 | judge(ret, answer) 127 | -------------------------------------------------------------------------------- /settings.yaml: -------------------------------------------------------------------------------- 1 | # 选项间间隔 2 | choice_interval: 0 3 | # 选项宽度 4 | choice_width: 0 5 | # 选项 6 | choices: [ 'A', 'B', 'C', 'D' ] -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import yaml 4 | 5 | import numpy as np 6 | from imutils.perspective import four_point_transform 7 | import cv2 8 | import imutils 9 | from PIL import Image 10 | 11 | 12 | def format_print_dict(d): 13 | print(json.dumps(d, indent=4, ensure_ascii=False)) # 缩进4空格,中文字符不转义成Unicode) 14 | 15 | 16 | def load_config(): 17 | f = open('settings.yaml') 18 | cf = yaml.load(f, Loader=yaml.FullLoader) 19 | format_print_dict(cf) 20 | cf['ANS_IMG_KERNEL'] = np.ones((2, 2), np.uint8) 21 | cf['CHOICE_IMG_KERNEL'] = np.ones((2, 2), np.uint8) 22 | return cf 23 | 24 | 25 | def trans(base_image): 26 | gray = cv2.cvtColor(base_image, cv2.COLOR_BGR2GRAY) 27 | blurred = cv2.GaussianBlur(gray, (5, 5), 0) 28 | edged = imutils.auto_canny(blurred) 29 | cnts = cv2.findContours(edged.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 30 | cnts = imutils.grab_contours(cnts) 31 | docCnt = None 32 | # 寻找答题卡区域 33 | if len(cnts) > 0: 34 | # 轮廓按照面积降序排列 35 | cnts = sorted(cnts, key=cv2.contourArea, reverse=True) 36 | for c in cnts: 37 | peri = cv2.arcLength(c, True) 38 | approx = cv2.approxPolyDP(c, 0.02 * peri, True) 39 | if len(approx) == 4: 40 | docCnt = approx 41 | break 42 | else: 43 | return 44 | paper = four_point_transform(base_image, docCnt.reshape(4, 2)) 45 | im = Image.fromarray(paper) 46 | im.save('trans.jpg') 47 | 48 | 49 | 50 | 51 | 52 | if __name__ == '__main__': 53 | img_path = './images/t5.jpg' 54 | image = cv2.imread(img_path) 55 | trans(image) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from functools import cmp_to_key 2 | from test import load_config 3 | 4 | config = load_config() 5 | 6 | 7 | # 根据 orc 切分结果获取一个选项框的宽和高区域范围 8 | def get_choice_wh(lines): 9 | boxes = [line[0] for line in lines] 10 | heights = [] 11 | for box in boxes: 12 | heights.append(box[-1][1] - box[0][1]) 13 | # 高的计算直接拿平均宽度 * 1.5 14 | avg_h = sum(heights) / len(heights) 15 | min_h = avg_h * 0.7 16 | max_h = avg_h * 1.5 17 | min_w = min_h * 2 18 | max_w = max_h * 2 19 | return (min_w, max_w), (min_h, max_h) 20 | 21 | 22 | # 分析 ocr 结果 23 | def analysis_orc_lines(lines: list): 24 | # 只包含题号的 25 | # 只包含选项的 26 | # 包含题号和选项 27 | choice = config['choices'] 28 | 29 | # 返回 text 中有多少个数字以及每个数字的起始位置 30 | # eg. '12ABD 24 BCD' 31 | def split_line_by_num(text: str): 32 | result = [] 33 | length = len(text) 34 | i = 0 35 | while i < length: 36 | if not text[i].isdigit(): 37 | i += 1 38 | continue 39 | j = i + 1 40 | while j < length: 41 | if text[j].isdigit(): 42 | j += 1 43 | else: 44 | break 45 | result.append(i) 46 | i = j 47 | return result 48 | 49 | # 过滤文本 留下数字和大写的选项 50 | def filter(text): 51 | s = "" 52 | for c in text: 53 | if c.isalpha() and c.upper() in choice: 54 | s += c.upper() 55 | continue 56 | if c.isdigit(): 57 | s += c 58 | return s 59 | 60 | def get_choice(text): 61 | res = [] 62 | for c in choice: 63 | res.append(c in text) 64 | if sum(res) != len(choice) - 1: 65 | return None 66 | for i in range(len(res)): 67 | if not res[i]: 68 | return choice[i] 69 | return None 70 | 71 | # 分析题目和其选项 72 | # 返回题号和它对应的选项,若没有则置为 None 73 | def analysis(text: str): 74 | if text.isdigit(): 75 | return int(text), None 76 | elif text.isalnum(): 77 | num = "" 78 | for c in text: 79 | if not c.isdigit(): 80 | break 81 | num += c 82 | if not num.isdigit(): 83 | return None, None 84 | return int(num), get_choice(text) 85 | else: 86 | return None, None 87 | 88 | # 返回 {题目--->(选项, 位置)} 89 | ret = dict() 90 | for line in lines: 91 | txt, confidence = line[1] 92 | if txt.isdigit() and confidence <= 0.85: 93 | # 纯数字的置信度应大于 0.85 94 | continue 95 | box = line[0] 96 | X, Y, W, H = box2xywh(box) 97 | texts = split_line_by_num(txt) 98 | if not texts: 99 | # 不包含数字的行 直接过滤 100 | continue 101 | length = len(txt) 102 | for i in range(len(texts)): 103 | start = texts[i] 104 | end = length 105 | if i + 1 < len(texts): 106 | end = texts[i + 1] 107 | text = txt[start: end] 108 | s = filter(text) 109 | num, res = analysis(s) 110 | if num: 111 | # 若 num 不在 或者 单单分析出题号,则使用距离估算 112 | x = X + (start / length) * W 113 | w = W * (end - start) / length 114 | if num not in ret or res is None: 115 | ret[num] = (res, xywh2box(x, Y, w, H)) 116 | return ret 117 | 118 | 119 | def find_choice_by_num_box(num_box, choice_boxes): 120 | h0, h1 = num_box[0][1], num_box[-1][1] 121 | x0 = num_box[0][0] + (h1 - h0) * 1.5 # 题目所在 x 位置 122 | candidate = [] 123 | for box in choice_boxes: 124 | x, y0, w, h = box 125 | y1 = y0 + h 126 | min_y = min(y0, h0) 127 | max_y = max(y1, h1) 128 | # 选项在题目右侧并且横向有交集,则加入备选序列 129 | if (x + w / 2) > x0 and (max_y - min_y) < (h + h1 - h0): 130 | candidate.append(box) 131 | # 没有候选项,返回 None 132 | if len(candidate) == 0: 133 | print('没有候选项') 134 | return None 135 | 136 | # 按照 x 坐标排序 137 | def compare(b1, b2): 138 | if b1[0] == b2[0]: 139 | return 0 140 | elif b1[0] < b2[0]: 141 | return -1 142 | else: 143 | return 1 144 | 145 | candidate.sort(key=cmp_to_key(compare)) 146 | choice = candidate[0] 147 | x, y, w, h = choice 148 | choice = ['A', 'B', 'C', 'D'] 149 | if config['choice_width'] > 0: 150 | w = config['choice_width'] 151 | interval = config['choice_interval'] 152 | idx = int((x - x0) / (w + interval)) 153 | print('选项宽度: %.2f\t估算偏移:%d' % (w, idx)) 154 | if idx >= len(choice): 155 | return choice[-1] 156 | else: 157 | return choice[idx] 158 | 159 | 160 | def read_answer(file): 161 | ret = dict() 162 | with open(file) as fr: 163 | for line in fr.readlines(): 164 | line = line.strip().split(' ') 165 | ret[int(line[0])] = line[1] 166 | return ret 167 | 168 | 169 | def write_lines(file, lines): 170 | with open(file, 'w') as fr: 171 | for line in lines: 172 | fr.write(line[1][0]) 173 | fr.write('\t置信度:%.2f' % line[1][1]) 174 | fr.write('\n') 175 | 176 | 177 | # 计算每个选项框之间的间隙 178 | def compute_choice_interval(choice_boxes: list): 179 | if len(choice_boxes) < 2: 180 | return 181 | 182 | # 先按照 x 坐标排序 183 | def cmp_x(b1, b2): 184 | if b1[0] == b2[0]: 185 | return 0 186 | elif b1[0] < b2[0]: 187 | return -1 188 | else: 189 | return 1 190 | 191 | choice_boxes.sort(key=cmp_to_key(cmp_x)) 192 | intervals = [] 193 | for box in choice_boxes: 194 | x, _, w, _ = box 195 | if len(intervals) == 0: 196 | intervals.append([x, x + w]) 197 | continue 198 | last = intervals[-1] 199 | if x > last[1]: 200 | intervals.append([x, x + w]) 201 | else: 202 | last[1] = x + w 203 | interval = 9999 204 | for i in range(len(intervals) - 1): 205 | # 计算最小间隔作为每个选项的间隔 206 | interval = min(intervals[i + 1][0] - intervals[i][1], interval) 207 | config['choice_interval'] = interval 208 | print('选项间间隔:%.2f' % interval) 209 | 210 | 211 | # 计算选项框的平均宽度 212 | def compute_choice_width(choice_boxes: list): 213 | widths = [] 214 | for box in choice_boxes: 215 | _, _, w, _ = box 216 | widths.append(w) 217 | 218 | # 选项小于 10 直接计算平均 219 | if len(widths) >= 10: 220 | widths.sort() 221 | length = len(widths) 222 | # 长度大于 10 取其中 60% 作为依据 223 | length = length // 5 224 | widths = widths[length: -length] 225 | config['choice_width'] = sum(widths) / len(widths) 226 | 227 | 228 | # 从 ocr 文本行中提取选项和位置 229 | # 一个文本行中可能有多个,返回列表结果 [[num, x, y, w, h], ...] 230 | def extract_num_pos(line: list): 231 | result = [] 232 | txt, _ = line[1] 233 | # 不包含数字,过滤 234 | if not any([c.isdigit() for c in txt]): 235 | return result 236 | box = line[0] 237 | # 文本框的坐标,宽高 238 | X, Y = box[0] 239 | W = box[1][0] - X 240 | H = box[-1][1] - Y 241 | i = 0 242 | length = len(txt) 243 | while i < length: 244 | if not txt[i].isdigit(): 245 | i += 1 246 | continue 247 | j = i + 1 248 | while j < length: 249 | if txt[j].isdigit(): 250 | j += 1 251 | else: 252 | break 253 | num = int(txt[i: j]) 254 | x = X + W * (i / length) 255 | # 认为数字的宽度等于高度 256 | w = H 257 | result.append([num, x, Y, w, H]) 258 | i = j 259 | return result 260 | 261 | 262 | def extract_num(lines: list): 263 | ret = dict() 264 | for line in lines: 265 | nums = extract_num_pos(line) 266 | if not nums: 267 | continue 268 | for item in nums: 269 | ret[item[0]] = item[1:] 270 | return ret 271 | 272 | 273 | def xywh2box(x, y, w, h): 274 | return [[x, y], [x + w, y], [x + w, y + h], [x, y + h]] 275 | 276 | 277 | def box2xywh(box): 278 | x, y = box[0] 279 | w = box[1][0] - x 280 | h = box[-1][1] - y 281 | return x, y, w, h 282 | --------------------------------------------------------------------------------