#!/usr/bin/env python
# encoding: utf-8
"""
@version: 0.1
@author: springhser
@license: Apache Licence
@contact: endoffight@gmail.com
@site: http://www.springhser.com
@software: PyCharm Community Edition
@file: img_transform.py
@time: 2017/12/30 17:58
"""
# Repository layout (from the original dump):
#   001.jpg           - image read by the script entry point
#     https://raw.githubusercontent.com/springhser/img_trans/a11e96dd9cd4f438c3a19acda87f62a87fd10396/001.jpg
#   img_transform.py  - this script
#
# Pipeline: shrink the image, detect straight segments (Canny + probabilistic
# Hough), pick the outermost near-horizontal / near-vertical segments,
# intersect them to get four corner points, and warp those corners onto the
# corners of the full-size image.
import numpy as np
import cv2
from matplotlib import pyplot as plt

# Width (in pixels) the working copy is shrunk towards before line detection.
MIN_WIDTH = 200
# Upper bound on the shrink factor.
MAX_SCALE = 10.


def _get_scale(width):
    """Return the factor by which an image of *width* pixels is shrunk."""
    return min(MAX_SCALE, width * 1. / MIN_WIDTH)


class Line:
    """A line segment described by its two end points.

    Attributes:
        x1, y1: start point.
        x2, y2: end point.
        half_x, half_y: midpoint of the segment, used as a sort key.
    """

    def __init__(self, line):
        """Build a segment from a sequence ``(x1, y1, x2, y2)``."""
        # Coerce to float so the products in get_cross_point do not depend on
        # the (possibly narrow integer) dtype of the incoming coordinates.
        self.x1 = float(line[0])
        self.y1 = float(line[1])
        self.x2 = float(line[2])
        self.y2 = float(line[3])
        self.half_x = (self.x1 + self.x2) / 2
        self.half_y = (self.y1 + self.y2) / 2

    def get_cross_point(self, l_a):
        """Return the intersection of this segment's extension with *l_a*'s.

        Both segments are treated as infinite lines ``a*x + b*y = c``.

        Args:
            l_a: another ``Line``.

        Returns:
            The intersection as a tuple ``(x, y)`` of floats.

        Raises:
            ValueError: if the two lines are parallel or collinear.
        """
        a1 = self.y2 - self.y1
        b1 = self.x1 - self.x2
        c1 = a1 * self.x1 + b1 * self.y1
        a2 = l_a.y2 - l_a.y1
        b2 = l_a.x1 - l_a.x2
        # The constant term must use a point on the line: (x1, y1).
        c2 = a2 * l_a.x1 + b2 * l_a.y1
        d = a1 * b2 - a2 * b1
        if d == 0:  # parallel or collinear
            raise ValueError("lines are parallel or collinear")
        return (1. * (b2 * c1 - b1 * c2) / d, 1. * (a1 * c2 - a2 * c1) / d)


def img_process(old_img):
    """Shrink *old_img* and detect straight line segments in it.

    Args:
        old_img: image array as returned by ``cv2.imread`` (converted with
            ``cv2.COLOR_BGR2GRAY``, so a 3-channel BGR image is expected).

    Returns:
        A tuple ``(lines, new_img)`` where ``lines`` is the raw result of
        ``cv2.HoughLinesP`` on the shrunken image (``None`` when no segment
        is found) and ``new_img`` is the shrunken image.
    """
    # Work on a smaller copy to make the computation cheaper.
    height, width = old_img.shape[:2]
    scale = _get_scale(width)
    new_h = int(height * 1. / scale)
    new_w = int(width * 1. / scale)
    new_img = cv2.resize(old_img, (new_w, new_h))
    gray_img = cv2.cvtColor(new_img, cv2.COLOR_BGR2GRAY)

    # Canny edge detection; the Otsu threshold is used as the high threshold.
    # See http://blog.csdn.net/on2way/article/details/46812121
    high_threshold = cv2.threshold(
        gray_img, 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[0]
    low_threshold = high_threshold * 0.2
    canny_img = cv2.Canny(gray_img, low_threshold, high_threshold)

    # minLineLength / maxLineGap are passed by keyword: in the Python binding
    # the fifth positional parameter is the ``lines`` output array.
    lines = cv2.HoughLinesP(canny_img, 1, np.pi / 180, new_w // 3,
                            minLineLength=new_w // 3, maxLineGap=20)
    return lines, new_img


def get_target_points(lines, img):
    """Pick the four corner points framed by the outermost detected lines.

    Args:
        lines: result of ``cv2.HoughLinesP`` (array of ``x1, y1, x2, y2``
            rows, or ``None`` when nothing was detected).
        img: the image the lines were detected in; only ``img.shape`` is used.

    Returns:
        A list of four ``(x, y)`` tuples in the order top-left, top-right,
        bottom-left, bottom-right.
    """
    img_h, img_w = img.shape[:2]
    top = (0, 0, img_w - 1, 0)
    bottom = (0, img_h - 1, img_w - 1, img_h - 1)
    left = (0, 0, 0, img_h - 1)
    right = (img_w - 1, 0, img_w - 1, img_h - 1)

    # Split the segments into near-horizontal and near-vertical ones.
    lines_h = []
    lines_v = []
    if lines is not None:
        # reshape accepts both (N, 1, 4) and (1, N, 4) layouts.
        for coords in np.asarray(lines).reshape(-1, 4):
            line = Line(coords)
            if abs(line.x1 - line.x2) > abs(line.y1 - line.y2):
                lines_h.append(line)
            else:
                lines_v.append(line)

    # With fewer than two segments in a direction, substitute image borders:
    # both borders when none was found, otherwise the border on the opposite
    # side of the single detected segment.
    if not lines_h:
        lines_h = [Line(top), Line(bottom)]
    elif len(lines_h) == 1:
        lower_half = lines_h[0].half_y > img_h / 2.
        lines_h.append(Line(top if lower_half else bottom))
    if not lines_v:
        lines_v = [Line(left), Line(right)]
    elif len(lines_v) == 1:
        right_half = lines_v[0].half_x > img_w / 2.
        lines_v.append(Line(left if right_half else right))

    # Intersect the segments closest to each image edge.
    lines_h.sort(key=lambda line: line.half_y)
    lines_v.sort(key=lambda line: line.half_x)
    return [lines_h[0].get_cross_point(lines_v[0]),
            lines_h[0].get_cross_point(lines_v[-1]),
            lines_h[-1].get_cross_point(lines_v[0]),
            lines_h[-1].get_cross_point(lines_v[-1])]


def per_transform(target_points, old_img):
    """Warp *old_img* so that *target_points* land on its four corners.

    Args:
        target_points: four ``(x, y)`` points in shrunken-image coordinates,
            ordered top-left, top-right, bottom-left, bottom-right.
            The sequence is not modified.
        old_img: the full-size image.

    Returns:
        The perspective-corrected image, same size as *old_img*.
    """
    height, width = old_img.shape[:2]
    scale = _get_scale(width)
    # Map the points back to full-size coordinates.
    src_points = np.array([(x * scale, y * scale) for x, y in target_points],
                          np.float32)
    # Corners of the full-size image, in the same order as src_points.
    dst_points = np.array(((0, 0),
                           (width - 1, 0),
                           (0, height - 1),
                           (width - 1, height - 1)),
                          np.float32)
    matrix = cv2.getPerspectiveTransform(src_points, dst_points)
    return cv2.warpPerspective(old_img, matrix, (width, height))


def main():
    """Rectify ``001.jpg`` and display the result."""
    old_img = cv2.imread("001.jpg")
    if old_img is None:  # cv2.imread signals failure by returning None
        raise IOError("cannot read image '001.jpg'")
    lines, new_img = img_process(old_img)
    t_points = get_target_points(lines, new_img)
    revert_img = per_transform(t_points, old_img)
    # OpenCV stores channels as BGR while matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(revert_img, cv2.COLOR_BGR2RGB))
    # Hide tick values on the X and Y axes.
    plt.xticks([])
    plt.yticks([])
    plt.show()


if __name__ == '__main__':
    main()