├── .gitignore
├── README.md
├── craft_utils.py
├── divide_text_region_from_gt.py
├── file_utils.py
├── generate_score_map.py
├── imgproc.py
├── my_dataset.py
├── my_model.py
├── readme_imgs
│   ├── res_blw_1.jpg
│   └── 标注.png
├── rename_filename.py
├── test.py
├── train.py
└── vgg16_bn.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 | 
28 | # PyInstaller
29 | #  Usually these files are written by a python script from a template
30 | #  before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 | 
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 | 
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 | 
50 | # Translations
51 | *.mo
52 | *.pot
53 | 
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 | 
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 | 
63 | # Scrapy stuff:
64 | .scrapy
65 | 
66 | # Sphinx documentation
67 | docs/_build/
68 | 
69 | # PyBuilder
70 | target/
71 | 
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 | 
75 | # pyenv
76 | .python-version
77 | 
78 | # celery beat schedule file
79 | celerybeat-schedule
80 | 
81 | # SageMath parsed files
82 | *.sage.py
83 | 
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 | 
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 | 
97 | # Rope project settings
98 | .ropeproject
99 | 
100 | # mkdocs documentation
101 | /site
102 | 
103 | # mypy
104 | .mypy_cache/
105 | 
106 | # mine
107 | pretrained/
108 | data/
109 | models/
110 | result/
111 | imgs/
112 | divides/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # How to train a CRAFT model on your own text-detection dataset
2 | This project builds on the pretrained model released with CRAFT and applies transfer learning to detect the text in your own dataset.
3 | [[CRAFT paper]](https://arxiv.org/abs/1904.01941)
4 | [[Code]](https://github.com/CommissarMa/pytorch-CRAFT)
5 | [[Paper walkthrough (in Chinese)]](https://github.com/CommissarMa/Awesome_CV_papers/blob/master/Text_Related/cvpr2019_CRAFT/cvpr2019_CRAFT.md)
6 | 
7 | ## 1. Test your own text images directly with the pretrained CRAFT model
8 | 1. Download the pretrained CRAFT weights [[craft_mlt_25k.pth]](https://pan.baidu.com/s/1oinKoVnIMP017hc-1yX_CQ) (extraction code: 3bgk) and put the file into the pretrained directory.
9 | 2. Put all the images you want to detect into the imgs directory.
10 | 3. Run:
11 | ```
12 | python test.py --trained_model ./pretrained/craft_mlt_25k.pth
13 | ```
14 | 4. The detection results are saved in the result folder for inspection.
15 | 
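16 | Before moving on to your own data, you can sanity-check that the model and the pretrained weights load correctly. A minimal sketch (not part of the repo; paths as above):
17 | ```
18 | import torch
19 | from collections import OrderedDict
20 | from my_model import CRAFT
21 | 
22 | net = CRAFT()
23 | sd = torch.load('./pretrained/craft_mlt_25k.pth', map_location='cpu')
24 | # checkpoints saved from DataParallel carry a "module." prefix; strip it if present
25 | sd = OrderedDict((k[7:] if k.startswith('module.') else k, v) for k, v in sd.items())
26 | net.load_state_dict(sd)
27 | net.eval()
28 | with torch.no_grad():
29 |     y, _ = net(torch.randn(1, 3, 768, 768))
30 | print(y.shape)  # torch.Size([1, 384, 384, 2]): region and affinity score maps
31 | ```
32 | 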
33 | ## 2. Train CRAFT on your own dataset (transfer learning)
34 | 1. Annotate your dataset with the labeling tool [[labelme]](https://github.com/wkentaro/labelme). We annotate at the character level: for each character, mark 4 points clockwise to form a polygon, as shown below:
35 | ![annotation](./readme_imgs/标注.png)
36 | We then tag each polygon with its character, which comes in handy later if you want to do text recognition.
37 | 2. Suppose the root directory of the dataset is blw, containing two kinds of files: images blw_1.jpg and annotations blw_1.json [[reference dataset, extraction code: 6q33]](https://pan.baidu.com/s/10FO2Y9tMPcrjmBoTbPJlXw). Run generate_score_map.py (remember to set name = your root directory name in the main function). Afterwards the directory will additionally contain two kinds of .npy files, blw_region_1.npy and blw_affinity_1.npy, which correspond to CRAFT's region_map and affinity_map. Then create a root directory data and move the four kinds of files into the matching subdirectories, like this:
38 | + data:
39 |   + affinity: blw_affinity_1.npy
40 |   + anno: blw_1.json
41 |   + img: blw_1.jpg
42 |   + region: blw_region_1.npy
43 | 
44 | Now your dataset is ready.
45 | 3. Run train.py (adjust the parameter settings in its main function). Trained models are saved under ./models by default.
46 | 
47 | ## 3. Test your own text images with the newly trained model
48 | Here we run into a question: a model checkpoint is saved after every epoch, so which one should we use?
49 | There is no standard answer; try several and keep whichever gives the best test results.
50 | The steps are much the same as in section 1: put a checkpoint from the models folder, e.g. [[100.pth]](https://pan.baidu.com/s/1Na5hA2-RXMovIa6J7aJzhw) (extraction code: 1tmc), into ./pretrained, then run:
51 | ```
52 | python test.py --trained_model ./pretrained/100.pth
53 | ```
54 | The detection results are saved in the result folder; you should see a clear improvement over the raw pretrained model!
55 | ![detection result](./readme_imgs/res_blw_1.jpg)
56 | 
57 | ## 4. Perspective-transform the annotated points into horizontal text boxes (for text recognition)
58 | ```
59 | run divide_text_region_from_gt.py
60 | ```
61 | Remember to adjust the parameter settings in its main function. The script rectifies each annotated quadrilateral into an axis-aligned crop, as sketched below.
62 | 
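63 | The core of that script is one perspective warp per annotated quadrilateral. A minimal sketch of the idea (the box coordinates here are made up):
64 | ```
65 | import cv2
66 | import numpy as np
67 | 
68 | img = cv2.imread('./data/img/blw_1.jpg')
69 | # 4 annotated points, clockwise from the top-left corner (hypothetical values)
70 | src = np.float32([[120, 40], [260, 48], [255, 110], [115, 102]])
71 | w = int(np.linalg.norm(src[1] - src[0]))   # target width ~ length of the top edge
72 | h = int(np.linalg.norm(src[3] - src[0]))   # target height ~ length of the left edge
73 | tgt = np.float32([[0, 0], [w, 0], [w, h], [0, h]])
74 | M = cv2.getPerspectiveTransform(src, tgt)
75 | crop = cv2.warpPerspective(img, M, (w, h))   # horizontal text patch, ready for recognition
76 | cv2.imwrite('crop.jpg', crop)
77 | ```
78 | 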
--------------------------------------------------------------------------------
/craft_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 | 
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | import cv2
9 | import math
10 | 
11 | """ auxiliary functions """
12 | # unwarp coordinates
13 | def warpCoord(Minv, pt):
14 |     out = np.matmul(Minv, (pt[0], pt[1], 1))
15 |     return np.array([out[0]/out[2], out[1]/out[2]])
16 | """ end of auxiliary functions """
17 | 
18 | 
19 | def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text):
20 |     # prepare data
21 |     linkmap = linkmap.copy()
22 |     textmap = textmap.copy()
23 |     img_h, img_w = textmap.shape
24 | 
25 |     """ labeling method """
26 |     ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
27 |     ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)
28 | 
29 |     text_score_comb = np.clip(text_score + link_score, 0, 1)
30 |     nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)
31 | 
32 |     det = []
33 |     mapper = []
34 |     for k in range(1,nLabels):
35 |         # size filtering
36 |         size = stats[k, cv2.CC_STAT_AREA]
37 |         if size < 10: continue
38 | 
39 |         # thresholding
40 |         if np.max(textmap[labels==k]) < text_threshold: continue
41 | 
42 |         # make segmentation map
43 |         segmap = np.zeros(textmap.shape, dtype=np.uint8)
44 |         segmap[labels==k] = 255
45 |         segmap[np.logical_and(link_score==1, text_score==0)] = 0   # remove link area
46 |         x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
47 |         w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
48 |         niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
49 |         sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
50 |         # boundary check
51 |         if sx < 0 : sx = 0
52 |         if sy < 0 : sy = 0
53 |         if ex >= img_w: ex = img_w
54 |         if ey >= img_h: ey = img_h
55 |         kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
56 |         segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)
57 | 
58 |         # make box
59 |         np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
60 |         rectangle = cv2.minAreaRect(np_contours)
61 |         box = cv2.boxPoints(rectangle)
62 | 
63 |         # align diamond-shape
64 |         w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
65 |         box_ratio = max(w, h) / (min(w, h) + 1e-5)
66 |         if abs(1 - box_ratio) <= 0.1:
67 |             l, r = min(np_contours[:,0]), max(np_contours[:,0])
68 |             t, b = min(np_contours[:,1]), max(np_contours[:,1])
69 |             box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)
70 | 
71 |         # make clockwise order
72 |         startidx = box.sum(axis=1).argmin()
73 |         box = np.roll(box, 4-startidx, 0)
74 |         box = np.array(box)
75 | 
76 |         det.append(box)
77 |         mapper.append(k)
78 | 
79 |     return det, labels, mapper
80 | 
81 | def getPoly_core(boxes, labels, mapper, linkmap):
82 |     # configs
83 |     num_cp = 5
84 |     max_len_ratio = 0.7
85 |     expand_ratio = 1.45
86 |     max_r = 2.0
87 |     step_r = 0.2
88 | 
89 |     polys = []
90 |     for k, box in enumerate(boxes):
91 |         # size filter for small instance
92 |         w, h = int(np.linalg.norm(box[0] - box[1]) + 1), int(np.linalg.norm(box[1] - box[2]) + 1)
93 |         if w < 30 or h < 30:
94 |             polys.append(None); continue
95 | 
96 |         # warp image
97 |         tar = np.float32([[0,0],[w,0],[w,h],[0,h]])
98 |         M = cv2.getPerspectiveTransform(box, tar)
99 |         word_label = cv2.warpPerspective(labels, M, (w, h), flags=cv2.INTER_NEAREST)
100 |         try:
101 |             Minv = np.linalg.inv(M)
102 |         except np.linalg.LinAlgError:   # singular transform, skip this box
103 |             polys.append(None); continue
104 | 
105 |         # binarization for selected label
106 |         cur_label = mapper[k]
107 |         word_label[word_label != cur_label] = 0
108 |         word_label[word_label > 0] = 1
109 | 
110 |         """ Polygon generation """
111 |         # find top/bottom contours
112 |         cp = []
113 |         max_len = -1
114 |         for i in range(w):
115 |             region = np.where(word_label[:,i] != 0)[0]
116 |             if len(region) < 2 : continue
117 |             cp.append((i, region[0], region[-1]))
118 |             length = region[-1] - region[0] + 1
119 |             if length > max_len: max_len = length
120 | 
121 |         # pass if max_len is similar to h
122 |         if h * max_len_ratio < max_len:
123 |             polys.append(None); continue
124 | 
125 |         # get pivot points with fixed length
126 |         tot_seg = num_cp * 2 + 1
127 |         seg_w = w / tot_seg     # segment width
128 |         pp = [None] * num_cp    # init pivot points
129 |         cp_section = [[0, 0]] * tot_seg
130 |         seg_height = [0] * num_cp
131 |         seg_num = 0
132 |         num_sec = 0
133 |         prev_h = -1
134 |         for i in range(0,len(cp)):
135 |             (x, sy, ey) = cp[i]
136 |             if (seg_num + 1) * seg_w <= x and seg_num <= tot_seg:
137 |                 # average previous segment
138 |                 if num_sec == 0: break
139 |                 cp_section[seg_num] = [cp_section[seg_num][0] / num_sec, cp_section[seg_num][1] / num_sec]
140 |                 num_sec = 0
141 | 
142 |                 # reset variables
143 |                 seg_num += 1
144 |                 prev_h = -1
145 | 
146 |             # accumulate center points
147 |             cy = (sy + ey) * 0.5
148 |             cur_h = ey - sy + 1
149 |             cp_section[seg_num] = [cp_section[seg_num][0] + x, cp_section[seg_num][1] + cy]
150 |             num_sec += 1
151 | 
152 |             if seg_num % 2 == 0: continue # No polygon area
153 | 
154 |             if prev_h < cur_h:
155 |                 pp[int((seg_num - 1)/2)] = (x, cy)
156 |                 seg_height[int((seg_num - 1)/2)] = cur_h
157 |                 prev_h = cur_h
158 | 
159 |         # processing last segment
160 |         if num_sec != 0:
161 |             cp_section[-1] = [cp_section[-1][0] / num_sec, cp_section[-1][1] / num_sec]
162 | 
163 |         # pass if num of pivots is not sufficient or segment width is smaller than character height
164 |         if None in pp or seg_w < np.max(seg_height) * 0.25:
165 |             polys.append(None); continue
166 | 
167 |         # calc median maximum of pivot points
168 |         half_char_h = np.median(seg_height) * expand_ratio / 2
169 | 
170 |         # calc gradient and apply to make horizontal pivots
171 |         new_pp = []
172 |         for i, (x, cy) in enumerate(pp):
173 |             dx = cp_section[i * 2 + 2][0] - cp_section[i * 2][0]
174 |             dy = cp_section[i * 2 + 2][1] - cp_section[i * 2][1]
175 |             if dx == 0:     # gradient is zero
176 |                 new_pp.append([x, cy - half_char_h, x, cy + half_char_h])
177 |                 continue
178 |             rad = - math.atan2(dy, dx)
179 |             c, s = half_char_h * math.cos(rad), half_char_h * math.sin(rad)
180 |             new_pp.append([x - s, cy - c, x + s, cy + c])
181 | 
182 |         # get edge points to cover character heatmaps
183 |         isSppFound, isEppFound = False, False
184 |         grad_s = (pp[1][1] - pp[0][1]) / (pp[1][0] - pp[0][0]) + (pp[2][1] - pp[1][1]) / (pp[2][0] - pp[1][0])
185 |         grad_e = (pp[-2][1] - pp[-1][1]) / (pp[-2][0] - pp[-1][0]) + (pp[-3][1] - pp[-2][1]) / (pp[-3][0] - pp[-2][0])
186 |         for r in np.arange(0.5, max_r, step_r):
187 |             dx = 2 * half_char_h * r
188 |             if not isSppFound:
189 |                 line_img = np.zeros(word_label.shape, dtype=np.uint8)
190 |                 dy = grad_s * dx
191 |                 p = np.array(new_pp[0]) - np.array([dx, dy, dx, dy])
192 |                 cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
193 |                 if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
194 |                     spp = p
195 |                     isSppFound = True
196 |             if not isEppFound:
197 |                 line_img = np.zeros(word_label.shape, dtype=np.uint8)
198 |                 dy = grad_e * dx
199 |                 p = np.array(new_pp[-1]) + np.array([dx, dy, dx, dy])
200 |                 cv2.line(line_img, (int(p[0]), int(p[1])), (int(p[2]), int(p[3])), 1, thickness=1)
201 |                 if np.sum(np.logical_and(word_label, line_img)) == 0 or r + 2 * step_r >= max_r:
202 |                     epp = p
203 |                     isEppFound = True
204 |             if isSppFound and isEppFound:
205 |                 break
206 | 
207 |         # pass if boundary of polygon is not found
208 |         if not (isSppFound and isEppFound):
209 |             polys.append(None); continue
210 | 
211 |         # make final polygon
212 |         poly = []
213 |         poly.append(warpCoord(Minv, (spp[0], spp[1])))
214 |         for p in new_pp:
215 |             poly.append(warpCoord(Minv, (p[0], p[1])))
216 |         poly.append(warpCoord(Minv, (epp[0], epp[1])))
217 |         poly.append(warpCoord(Minv, (epp[2], epp[3])))
218 |         for p in reversed(new_pp):
219 |             poly.append(warpCoord(Minv, (p[2], p[3])))
220 |         poly.append(warpCoord(Minv, (spp[2], spp[3])))
221 | 
222 |         # add to final result
223 |         polys.append(np.array(poly))
224 | 
225 |     return polys
226 | 
227 | def getDetBoxes(textmap, linkmap, text_threshold, link_threshold, low_text, poly=False):
228 |     boxes, labels, mapper = getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text)
229 | 
230 |     if poly:
231 |         polys = getPoly_core(boxes, labels, mapper, linkmap)
232 |     else:
233 |         polys = [None] * len(boxes)
234 | 
235 |     return boxes, polys
236 | 
237 | def adjustResultCoordinates(polys, ratio_w, ratio_h, ratio_net = 2):
238 |     if len(polys) > 0:
239 |         polys = np.array(polys)
240 |         for k in range(len(polys)):
241 |             if polys[k] is not None:
242 |                 polys[k] *= (ratio_w * ratio_net, ratio_h * ratio_net)
243 |     return polys
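
A minimal usage sketch for the post-processing entry points above (here fed with the ground-truth score maps produced by generate_score_map.py, which are at full image resolution; network output is at half resolution and would be rescaled afterwards with adjustResultCoordinates, as in test.py):
```
import numpy as np
import craft_utils

textmap = np.load('./data/region/blw_region_1.npy').astype(np.float32)
linkmap = np.load('./data/affinity/blw_affinity_1.npy').astype(np.float32)
boxes, polys = craft_utils.getDetBoxes(textmap, linkmap,
                                       text_threshold=0.7, link_threshold=0.001, low_text=0.4)
print(len(boxes))  # each box is a (4, 2) array of clockwise corner points
```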
--------------------------------------------------------------------------------
/divide_text_region_from_gt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import cv2
5 | import json
6 | 
7 | import craft_utils
8 | import imgproc
9 | 
10 | 
11 | # bboxes, polys, score_text = test_net(net, image, args.text_threshold,
12 | #                                      args.link_threshold, args.low_text, args.cuda, args.poly)
13 | if __name__ == '__main__':
14 |     # parameter settings
15 |     canvas_size = 1280
16 |     mag_ratio = 1.0  # image magnification ratio
17 |     text_threshold = 0.7  # region_map threshold
18 |     low_text = 0.4  # text low-bound score
19 |     link_threshold = 0.001  # affinity_map threshold
20 |     poly = False  # whether to output polygons; by default output 4-point boxes
21 | 
22 | 
23 |     root = './data'
24 |     img_list = os.listdir(os.path.join(root,'img'))
25 | 
26 |     for img_path in img_list:
27 |         image_path = os.path.join(root, 'img', img_path)
28 |         affinity_path = os.path.join(root, 'affinity', img_path.split('_')[0] + '_affinity_' \
29 |                                      + img_path.split('_')[1].split('.')[0] + '.npy')
30 |         region_path = os.path.join(root, 'region', img_path.split('_')[0] + '_region_' \
31 |                                    + img_path.split('_')[1].split('.')[0] + '.npy')
32 |         anno_path = os.path.join(root, 'anno', img_path.split('.')[0] + '.json')
33 | 
34 |         image = np.array(plt.imread(image_path))  # e.g. 225*517*3
35 |         region = np.load(region_path)
36 |         affinity = np.load(affinity_path)
37 | 
38 |         # resize
39 |         # img_resized=352*800*3, target_ratio=1.5
40 |         # size_heatmap=400*176, ratio_h=w=0.66666667
41 |         # img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=mag_ratio)
42 |         # ratio_h = ratio_w = 1 / target_ratio
43 |         # plt.imshow(img_resized.astype(np.int))
44 |         # region = cv2.resize(region,(img_resized.shape[1]//2,img_resized.shape[0]//2))
45 |         # affinity = cv2.resize(affinity,(img_resized.shape[1]//2,img_resized.shape[0]//2))
46 | 
47 |         # Post-processing
48 |         boxes, polys = craft_utils.getDetBoxes(region, affinity, text_threshold, link_threshold, low_text, poly)
49 | 
50 |         # coordinate adjustment
51 |         # boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
52 |         # polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
53 |         # for k in range(len(polys)):
54 |         #     if polys[k] is None:
55 |         #         polys[k] = boxes[k]
56 | 
57 |         # render results (optional)
58 |         render_img = region.copy()
59 |         render_img = np.hstack((render_img, affinity))
60 |         ret_score_text = imgproc.cvt2HeatmapImg(render_img)
61 |         for i, box in enumerate(boxes):
62 |             _,(kernel_w,kernel_h),_ = cv2.minAreaRect(box)  # (center(x,y), (width,height), rotation angle) of the minimum-area rectangle
63 |             kernel_w,kernel_h = int(kernel_w),int(kernel_h)
64 |             if kernel_w < kernel_h:
65 |                 kernel_w,kernel_h = kernel_h,kernel_w
66 | 
67 |             box = np.array(box).astype(np.int32).reshape((-1))
68 |             box = box.reshape(-1, 2)
69 |             # cv2.polylines(image, [box.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
70 |             # perspective-transform the text region; coordinates are (column, row)
71 |             src = np.float32(box)  # clockwise: top-left, top-right, bottom-right, bottom-left
72 |             tgt = np.float32([(0,0),(kernel_w,0),(kernel_w,kernel_h),(0,kernel_h)])
73 |             M = cv2.getPerspectiveTransform(src, tgt)
74 |             dst = cv2.warpPerspective(image, M, (kernel_w,kernel_h))  # dst is the rectified text image we want
75 | 
76 |             '''read the characters from the annotation file'''
77 |             f=open(anno_path,encoding='utf-8')
78 |             anno = json.load(f)
79 |             shapes = anno['shapes']
80 | 
81 |             text = []
82 |             for s in shapes:
83 |                 text.append(s['label'])
84 |             text = ''.join(text)
85 | 
86 |             save_path = './divides'
87 |             if not os.path.exists(save_path):
88 |                 os.mkdir(save_path)
89 |             print(text)
90 |             cv2.imwrite(os.path.join(save_path,text+'_'+str(i)+'.jpg'),dst)  # include the box index so multiple boxes in one image do not overwrite each other
91 | 
92 |             # cv2.imshow('win',dst)
93 |             # if cv2.waitKey() == 0xFF & ord('q'):
94 |             #     cv2.destroyAllWindows()
95 |             #     import sys
96 |             #     sys.exit()
97 |             # break
98 | 
99 |         # break
--------------------------------------------------------------------------------
/file_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import numpy as np
4 | import cv2
5 | import imgproc
6 | 
7 | # borrowed from https://github.com/lengstrom/fast-style-transfer/blob/master/src/utils.py
8 | def get_files(img_dir):
9 |     imgs, masks, xmls = list_files(img_dir)
10 |     return imgs, masks, xmls
11 | 
12 | def list_files(in_path):
13 |     img_files = []
14 |     mask_files = []
15 |     gt_files = []
16 |     for (dirpath, dirnames, filenames) in os.walk(in_path):
17 |         for file in filenames:
18 |             filename, ext = os.path.splitext(file)
19 |             ext = str.lower(ext)
20 |             if ext == '.jpg' or ext == '.jpeg' or ext == '.gif' or ext == '.png' or ext == '.pgm':
21 |                 img_files.append(os.path.join(dirpath, file))
22 |             elif ext == '.bmp':
23 |                 mask_files.append(os.path.join(dirpath, file))
24 |             elif ext == '.xml' or ext == '.gt' or ext == '.txt':
25 |                 gt_files.append(os.path.join(dirpath, file))
26 |             elif ext == '.zip':
27 |                 continue
28 |     # img_files.sort()
29 |     # mask_files.sort()
30 |     # gt_files.sort()
31 |     return img_files, mask_files, gt_files
32 | 
33 | def saveResult(img_file, img, boxes, dirname='./result/', verticals=None, texts=None):
34 |     """ save text detection result one by one
35 |     Args:
36 |         img_file (str): image file name
37 |         img (array): raw image context
38 |         boxes (array): array of detected boxes
39 |             Shape: [num_detections, 4, 2] quadrilateral vertices in (x, y) order
40 |     Return:
41 |         None
42 |     """
43 |     img = np.array(img)
44 | 
45 |     # make result file list
46 |     filename, file_ext = os.path.splitext(os.path.basename(img_file))
47 | 
48 |     # result directory
49 |     res_file = dirname + "res_" + filename + '.txt'
50 |     res_img_file = dirname + "res_" + filename + '.jpg'
51 | 
52 |     if not os.path.isdir(dirname):
53 |         os.mkdir(dirname)
54 | 
55 |     with open(res_file, 'w') as f:
56 |         for i, box in enumerate(boxes):
57 |             poly = np.array(box).astype(np.int32).reshape((-1))
58 |             strResult = ','.join([str(p) for p in poly]) + '\r\n'
59 |             f.write(strResult)
60 | 
61 |             poly = poly.reshape(-1, 2)
62 |             cv2.polylines(img, [poly.reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
63 |             ptColor = (0, 255, 255)
64 |             if verticals is not None:
65 |                 if verticals[i]:
66 |                     ptColor = (255, 0, 0)
67 | 
68 |             if texts is not None:
69 |                 font = cv2.FONT_HERSHEY_SIMPLEX
70 |                 font_scale = 0.5
71 |                 cv2.putText(img, "{}".format(texts[i]), (poly[0][0]+1, poly[0][1]+1), font, font_scale, (0, 0, 0), thickness=1)
72 |                 cv2.putText(img, "{}".format(texts[i]), tuple(poly[0]), font, font_scale, (0, 255, 255), thickness=1)
73 | 
74 |     # Save result image
75 |     cv2.imwrite(res_img_file, img)
76 | 
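
A short sketch of how these helpers are typically combined (paths assumed; detections here left empty):
```
import file_utils
import imgproc

images, _, _ = file_utils.get_files('./imgs')      # walk the folder for image files
for path in images:
    img = imgproc.loadImage(path)                  # RGB numpy array
    boxes = []                                     # detections would come from the network
    file_utils.saveResult(path, img[:, :, ::-1], boxes, dirname='./result/')  # RGB -> BGR for cv2
```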
--------------------------------------------------------------------------------
/generate_score_map.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Tue Aug  6 10:13:25 2019
4 | 
5 | @author: Ma Zhenwei
6 | """
7 | import os
8 | import json
9 | import matplotlib.pyplot as plt
10 | import matplotlib.cm as CM
11 | import numpy as np
12 | from scipy.ndimage.filters import gaussian_filter
13 | import cv2
14 | 
15 | 
16 | def generate_region_score(img, shapes):
17 |     '''
18 |     img: array of shape (h, w, c)
19 |     shapes: the shapes field of the annotation file
20 |     '''
21 |     h, w, c = img.shape
22 |     region_map = np.zeros((h, w), dtype=float)
23 | 
24 |     for s in shapes:
25 |         points = s['points']  # 4 points
26 |         # turn the clockwise annotation into counter-clockwise order, as the cv2 functions expect
27 |         points = [(int(points[0][0]),int(points[0][1])),(int(points[3][0]),int(points[3][1])),
28 |                   (int(points[2][0]),int(points[2][1])),(int(points[1][0]),int(points[1][1]))]
29 | 
30 |         dst = generate_transformed_gaussian_kernel(h, w, points)
31 | 
32 |         # accumulate into the region_map
33 |         region_map += dst
34 |     return region_map
35 | 
36 | 
37 | def generate_affinity_score(img, shapes):
38 |     '''
39 |     img: array of shape (h, w, c)
40 |     shapes: the shapes field of the annotation file
41 |     '''
42 |     h, w, c = img.shape
43 |     affinity_map = np.zeros((h, w), dtype=float)
44 | 
45 |     for i in range(len(shapes)-1):
46 |         # positions of the first and the second character
47 |         points1 = np.float32(shapes[i]['points'])
48 |         points2 = np.float32(shapes[i+1]['points'])
49 |         # centers of the first and the second character
50 |         center1 = np.sum(np.array(points1),axis=0) / 4
51 |         center2 = np.sum(np.array(points2),axis=0) / 4
52 |         # the 4 vertices of the affinity box
53 |         top_left = (points1[0] + points1[1] + center1) /3
54 |         top_right = (points2[0] + points2[1] + center2) /3
55 |         down_left = (points1[2] + points1[3] + center1) /3
56 |         down_right = (points2[2] + points2[3] + center2) /3
57 |         points = np.float32([top_left, down_left, down_right, top_right])
58 |         dst = generate_transformed_gaussian_kernel(h, w, points)
59 | 
60 |         affinity_map += dst
61 |     return affinity_map
62 | 
63 | 
64 | def generate_transformed_gaussian_kernel(h, w, points):
65 |     '''
66 |     Model a region or an affinity with a perspective-transformed Gaussian kernel.
67 |     h: image height
68 |     w: image width
69 |     points: array of shape (4, 2)
70 |     '''
71 |     # generate the Gaussian kernel
72 |     minX, minY = points[0]
73 |     maxX, maxY = points[0]
74 |     for i in range(1,4):
75 |         minX = min(points[i][0],minX)
76 |         minY = min(points[i][1],minY)
77 |         maxX = max(points[i][0],maxX)
78 |         maxY = max(points[i][1],maxY)
79 |     kernel_w = int((maxX - minX + 1) // 2 * 2)
80 |     kernel_h = int((maxY - minY + 1) // 2 * 2)
81 | 
82 |     kernel_size = 31
83 |     kernel = np.zeros((kernel_size, kernel_size))
84 |     kernel[kernel_size//2, kernel_size//2] = 1
85 |     kernel = gaussian_filter(kernel, 10, mode='constant')
86 | 
87 |     kernel_size = max(kernel_h, kernel_w)
88 |     kernel = cv2.resize(kernel,(kernel_size,kernel_size))
89 | 
90 |     # perspective-transform the Gaussian kernel; coordinates are (column, row)
91 |     src = np.float32([(0,0),(0,kernel_size),(kernel_size,kernel_size),(kernel_size,0)])  # top-left, bottom-left, bottom-right, top-right
92 |     tgt = np.float32(points)
93 |     M = cv2.getPerspectiveTransform(src, tgt)
94 |     dst = cv2.warpPerspective(kernel, M, (w,h))
95 | 
96 |     # rescale the positive values into [0.001, 1]
97 |     mini = dst[np.where(dst>0)].min()
98 |     maxi = dst[np.where(dst>0)].max()
99 |     hi = 1
100 |     lo = 0.001  # keep the value distribution consistent with the pretrained model
101 |     dst[np.where(dst>0)] = ((hi-lo)*dst[np.where(dst>0)]-hi*mini+lo*maxi) / (maxi-mini)
102 | 
103 |     return dst
104 | 
105 | 
106 | if __name__ == '__main__':
107 |     # note: annotations are 4 vertices in clockwise order
108 | 
109 |     name = 'ydc'
110 |     root = './data/'+name
111 |     for c in os.listdir(root):
112 |         if '.json' in c:
113 |             continue
114 |         if '.npy' in c:
115 |             continue
116 | 
117 |         img_path = os.path.join(root, c)
118 |         anno_path = img_path.replace('.jpg','.json')
119 | 
120 |         img = plt.imread(img_path)
121 | 
122 |         f=open(anno_path,encoding='utf-8')
123 |         anno = json.load(f)
124 |         shapes = anno['shapes']
125 | 
126 |         region_map = generate_region_score(img,shapes)
127 |         affinity_map = generate_affinity_score(img,shapes)
128 |         np.save(os.path.join(root, name+'_region_'+(c.split('.')[0]).split('_')[1]+'.npy'), region_map)
129 |         np.save(os.path.join(root, name+'_affinity_'+(c.split('.')[0]).split('_')[1]+'.npy'), affinity_map)
130 |         print(c)
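
To eyeball what the script produced, a quick visualization sketch (file names follow the convention above; the sample name is hypothetical):
```
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as CM

region = np.load('./data/ydc/ydc_region_1.npy')
affinity = np.load('./data/ydc/ydc_affinity_1.npy')
plt.subplot(1, 2, 1); plt.imshow(region, cmap=CM.jet); plt.title('region')
plt.subplot(1, 2, 2); plt.imshow(affinity, cmap=CM.jet); plt.title('affinity')
plt.show()
```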
--------------------------------------------------------------------------------
/imgproc.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2019-present NAVER Corp.
3 | MIT License
4 | """
5 | 
6 | # -*- coding: utf-8 -*-
7 | import numpy as np
8 | from skimage import io
9 | import cv2
10 | 
11 | def loadImage(img_file):
12 |     img = io.imread(img_file)           # RGB order
13 |     if img.shape[0] == 2: img = img[0]
14 |     if len(img.shape) == 2 : img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
15 |     if img.shape[2] == 4:   img = img[:,:,:3]
16 |     img = np.array(img)
17 | 
18 |     return img
19 | 
20 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
21 |     # should be RGB order
22 |     img = in_img.copy().astype(np.float32)
23 | 
24 |     img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
25 |     img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
26 |     return img
27 | 
28 | def denormalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
29 |     # should be RGB order
30 |     img = in_img.copy()
31 |     img *= variance
32 |     img += mean
33 |     img *= 255.0
34 |     img = np.clip(img, 0, 255).astype(np.uint8)
35 |     return img
36 | 
37 | def resize_aspect_ratio(img, square_size, interpolation, mag_ratio=1):
38 |     height, width, channel = img.shape
39 | 
40 |     # magnify image size
41 |     target_size = mag_ratio * max(height, width)
42 | 
43 |     # set original image size
44 |     if target_size > square_size:
45 |         target_size = square_size
46 | 
47 |     ratio = target_size / max(height, width)
48 | 
49 |     target_h, target_w = int(height * ratio), int(width * ratio)
50 |     proc = cv2.resize(img, (target_w, target_h), interpolation = interpolation)
51 | 
52 | 
53 |     # make canvas and paste image
54 |     target_h32, target_w32 = target_h, target_w
55 |     if target_h % 32 != 0:
56 |         target_h32 = target_h + (32 - target_h % 32)
57 |     if target_w % 32 != 0:
58 |         target_w32 = target_w + (32 - target_w % 32)
59 |     resized = np.zeros((target_h32, target_w32, channel), dtype=np.float32)
60 |     resized[0:target_h, 0:target_w, :] = proc
61 |     target_h, target_w = target_h32, target_w32
62 | 
63 |     size_heatmap = (int(target_w/2), int(target_h/2))
64 | 
65 |     return resized, ratio, size_heatmap
66 | 
67 | def cvt2HeatmapImg(img):
68 |     img = (np.clip(img, 0, 1) * 255).astype(np.uint8)
69 |     img = cv2.applyColorMap(img, cv2.COLORMAP_JET)
70 |     return img
71 | 
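
A sketch of what resize_aspect_ratio returns for a typical input (the numbers follow from the code above; 225x517 matches the sample size noted in divide_text_region_from_gt.py):
```
import numpy as np
import cv2
import imgproc

img = np.zeros((225, 517, 3), dtype=np.uint8)   # dummy image, h=225, w=517
resized, ratio, size_heatmap = imgproc.resize_aspect_ratio(
    img, square_size=1280, interpolation=cv2.INTER_LINEAR, mag_ratio=1.5)
print(resized.shape)   # (352, 800, 3): magnified 1.5x, padded up to multiples of 32
print(1 / ratio)       # ~0.6667: factor used to map detections back to the input scale
print(size_heatmap)    # (400, 176): the half-resolution score-map size
```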
--------------------------------------------------------------------------------
/my_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import os
3 | import matplotlib.pyplot as plt
4 | import matplotlib.cm as CM
5 | import numpy as np
6 | import torch
7 | import cv2
8 | 
9 | 
10 | class MyDataset(Dataset):
11 |     def __init__(self, root):
12 |         self.root = root
13 |         self.imglist = [f.split('.')[0] for f in os.listdir(os.path.join(root, 'img'))]
14 | 
15 |     def __getitem__(self, index):
16 |         # read img, region_map, affinity_map
17 |         img_path = os.path.join(self.root, 'img', self.imglist[index]+'.jpg')
18 |         # img = plt.imread(img_path)
19 |         img = np.array(plt.imread(img_path))
20 | 
21 |         region_path = os.path.join(self.root, 'region',
22 |                                    self.imglist[index].split('_')[0]+'_region_'
23 |                                    +self.imglist[index].split('_')[1]+'.npy')
24 |         region_map = np.load(region_path).astype(np.float32)
25 | 
26 |         affinity_path = os.path.join(self.root, 'affinity',
27 |                                      self.imglist[index].split('_')[0]+'_affinity_'
28 |                                      +self.imglist[index].split('_')[1]+'.npy')
29 |         affinity_map = np.load(affinity_path).astype(np.float32)
30 | 
31 |         # make sure the image height and width are even
32 |         h, w, c = img.shape
33 |         if h % 2 != 0 or w % 2 != 0:
34 |             h = int(h // 2 * 2)
35 |             w = int(w // 2 * 2)
36 |             img = cv2.resize(img, (w, h))
37 |             region_map = cv2.resize(region_map, (w, h))
38 |             affinity_map = cv2.resize(affinity_map, (w, h))
39 | 
40 |         # preprocess
41 |         img = normalizeMeanVariance(img)
42 |         img = torch.from_numpy(img).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
43 | 
44 |         region_map = cv2.resize(region_map, (w//2, h//2))
45 |         region_map = torch.tensor(region_map).unsqueeze(2)
46 |         affinity_map = cv2.resize(affinity_map, (w//2, h//2))
47 |         affinity_map = torch.tensor(affinity_map).unsqueeze(2)
48 |         gt_map = torch.cat((region_map,affinity_map), dim=2)
49 | 
50 |         return {'img':img, 'gt':gt_map}
51 | 
52 | 
53 |     def __len__(self):
54 |         return len(self.imglist)
55 | 
56 | def normalizeMeanVariance(in_img, mean=(0.485, 0.456, 0.406), variance=(0.229, 0.224, 0.225)):
57 |     # should be RGB order
58 |     img = in_img.copy().astype(np.float32)
59 |     img -= np.array([mean[0] * 255.0, mean[1] * 255.0, mean[2] * 255.0], dtype=np.float32)
60 |     img /= np.array([variance[0] * 255.0, variance[1] * 255.0, variance[2] * 255.0], dtype=np.float32)
61 |     return img
62 | 
63 | 
64 | if __name__ == '__main__':
65 |     d = MyDataset('./blw')
66 |     for i, data in enumerate(d):
67 |         img = data['img']
68 |         # plt.imshow(img)
69 |         # plt.figure()
70 |         gt = data['gt']
71 |         # plt.imshow(region,cmap=CM.jet)
72 |         print(gt.max())
73 |         print(gt.shape)
74 |         break
75 | 
--------------------------------------------------------------------------------
/my_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | from vgg16_bn import vgg16_bn, init_weights
6 | 
7 | class double_conv(nn.Module):
8 |     def __init__(self, in_ch, mid_ch, out_ch):
9 |         super(double_conv, self).__init__()
10 |         self.conv = nn.Sequential(
11 |             nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
12 |             nn.BatchNorm2d(mid_ch),
13 |             nn.ReLU(inplace=True),
14 |             nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
15 |             nn.BatchNorm2d(out_ch),
16 |             nn.ReLU(inplace=True)
17 |         )
18 | 
19 |     def forward(self, x):
20 |         x = self.conv(x)
21 |         return x
22 | 
23 | 
24 | class CRAFT(nn.Module):
25 |     def __init__(self, pretrained=False, freeze=False, phase='test'):
26 |         super(CRAFT, self).__init__()
27 | 
28 |         """ Base network """
29 |         self.basenet = vgg16_bn(pretrained, freeze)
30 | 
31 |         """ Freeze the parameters created so far (only the VGG backbone at this point), for transfer learning """
32 |         if phase == 'train':
33 |             for p in self.parameters():
34 |                 p.requires_grad=False
35 | 
36 |         """ U network """
37 |         self.upconv1 = double_conv(1024, 512, 256)
38 |         self.upconv2 = double_conv(512, 256, 128)
39 |         self.upconv3 = double_conv(256, 128, 64)
40 |         self.upconv4 = double_conv(128, 64, 32)
41 | 
42 |         num_class = 2
43 |         self.conv_cls = nn.Sequential(
44 |             nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
45 |             nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
46 |             nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
47 |             nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
48 |             nn.Conv2d(16, num_class, kernel_size=1),
49 |         )
50 | 
51 |         init_weights(self.upconv1.modules())
52 |         init_weights(self.upconv2.modules())
53 |         init_weights(self.upconv3.modules())
54 |         init_weights(self.upconv4.modules())
55 |         init_weights(self.conv_cls.modules())
56 | 
57 |     def forward(self, x):
58 |         """ Base network """
59 |         sources = self.basenet(x)
60 | 
61 |         """ U network """
62 |         y = torch.cat([sources[0], sources[1]], dim=1)
63 |         y = self.upconv1(y)
64 | 
65 |         y = F.interpolate(y, size=sources[2].size()[2:], mode='bilinear', align_corners=False)
66 |         y = torch.cat([y, sources[2]], dim=1)
67 |         y = self.upconv2(y)
68 | 
69 |         y = F.interpolate(y, size=sources[3].size()[2:], mode='bilinear', align_corners=False)
70 |         y = torch.cat([y, sources[3]], dim=1)
71 |         y = self.upconv3(y)
72 | 
73 |         y = F.interpolate(y, size=sources[4].size()[2:], mode='bilinear', align_corners=False)
74 |         y = torch.cat([y, sources[4]], dim=1)
75 |         feature = self.upconv4(y)
76 | 
77 |         y = self.conv_cls(feature)
78 | 
79 |         return y.permute(0,2,3,1), feature
80 | 
81 | if __name__ == '__main__':
82 |     model = CRAFT(pretrained=False)
83 |     output, _ = model(torch.randn(1, 3, 768, 768))
84 |     print(output.shape)
--------------------------------------------------------------------------------
/readme_imgs/res_blw_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CommissarMa/pytorch-CRAFT/a6dabd98cec7c52ebe3ff806a62aca9bc137250e/readme_imgs/res_blw_1.jpg
--------------------------------------------------------------------------------
/readme_imgs/标注.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CommissarMa/pytorch-CRAFT/a6dabd98cec7c52ebe3ff806a62aca9bc137250e/readme_imgs/标注.png
--------------------------------------------------------------------------------
/rename_filename.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | 
4 | 
5 | for d in os.listdir('dataset'):
6 |     print(d)
7 |     i = 1
8 |     for f in tqdm(os.listdir(os.path.join('dataset',d))):
9 |         os.rename(os.path.join('dataset',d,f),os.path.join('dataset',d,d+'_'+str(i)+'.jpg'))
10 |         i += 1
11 | 
12 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import time
4 | import argparse
5 | 
6 | import torch
7 | import torch.nn as nn
8 | import torch.backends.cudnn as cudnn
9 | from torch.autograd import Variable
10 | 
11 | from PIL import Image
12 | 
13 | import cv2
14 | import matplotlib.pyplot as plt
15 | #from skimage import io
16 | import numpy as np
17 | import craft_utils
18 | import imgproc
19 | import file_utils
20 | import json
21 | import zipfile
22 | 
23 | from my_model import CRAFT
24 | 
25 | from collections import OrderedDict
26 | def copyStateDict(state_dict):
27 |     if list(state_dict.keys())[0].startswith("module"):
28 |         start_idx = 1
29 |     else:
30 |         start_idx = 0
31 |     new_state_dict = OrderedDict()
32 |     for k, v in state_dict.items():
33 |         name = ".".join(k.split(".")[start_idx:])
34 |         new_state_dict[name] = v
35 |     return new_state_dict
36 | 
37 | def str2bool(v):
38 |     return v.lower() in ("yes", "y", "true", "t", "1")
39 | 
40 | parser = argparse.ArgumentParser(description='CRAFT Text Detection')
41 | #parser.add_argument('--trained_model', default='./pretrained/craft_mlt_25k.pth', type=str, help='pretrained model')
42 | parser.add_argument('--trained_model', default='./pretrained/100.pth', type=str, help='pretrained model')
43 | parser.add_argument('--text_threshold', default=0.7, type=float, help='text confidence threshold')
44 | parser.add_argument('--low_text', default=0.4, type=float, help='text low-bound score')
45 | parser.add_argument('--link_threshold', default=0.001, type=float, help='link confidence threshold')
46 | parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model')
47 | parser.add_argument('--canvas_size', default=1280, type=int, help='image size for inference')
48 | parser.add_argument('--mag_ratio', default=1.5, type=float, help='image magnification ratio')
49 | parser.add_argument('--poly', default=False, action='store_true', help='enable polygon type')
50 | parser.add_argument('--show_time', default=False, action='store_true', help='show processing time')
51 | parser.add_argument('--test_folder', default='./imgs', type=str, help='folder path to input images')
52 | 
53 | args = parser.parse_args()
54 | 
55 | 
56 | """ For test images in a folder """
57 | image_list, _, _ = file_utils.get_files(args.test_folder)
58 | 
59 | result_folder = './result/'
60 | if not os.path.isdir(result_folder):
61 |     os.mkdir(result_folder)
62 | 
63 | def test_net(net, image, text_threshold, link_threshold, low_text, cuda, poly):
64 |     t0 = time.time()
65 | 
66 |     # resize
67 |     img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, args.canvas_size, interpolation=cv2.INTER_LINEAR, mag_ratio=args.mag_ratio)
68 |     ratio_h = ratio_w = 1 / target_ratio
69 | 
70 |     # preprocessing
71 |     x = imgproc.normalizeMeanVariance(img_resized)
72 |     x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
73 |     x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
74 |     if cuda:
75 |         x = x.cuda()
76 | 
77 |     # forward pass
78 |     y, _ = net(x)
79 | 
80 |     # make score and link map
81 |     score_text = y[0,:,:,0].cpu().data.numpy()
82 |     score_link = y[0,:,:,1].cpu().data.numpy()
83 | 
84 |     t0 = time.time() - t0
85 |     t1 = time.time()
86 | 
87 |     # Post-processing
88 |     boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)
89 | 
90 |     # coordinate adjustment
91 |     boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
92 |     polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
93 |     for k in range(len(polys)):
94 |         if polys[k] is None: polys[k] = boxes[k]
95 | 
96 |     t1 = time.time() - t1
97 | 
98 |     # render results (optional)
99 |     render_img = score_text.copy()
100 |     render_img = np.hstack((render_img, score_link))
101 |     ret_score_text = imgproc.cvt2HeatmapImg(render_img)
102 | 
103 |     if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))
104 | 
105 |     return boxes, polys, ret_score_text
106 | 
107 | 
108 | 
109 | if __name__ == '__main__':
110 |     args.cuda = torch.cuda.is_available()   # automatically decide between CPU and CUDA
111 | 
112 |     # load net
113 |     net = CRAFT()     # initialize
114 | 
115 |     print('Loading weights from checkpoint (' + args.trained_model + ')')
116 |     if args.cuda:
117 |         net.load_state_dict(copyStateDict(torch.load(args.trained_model)))
118 |     else:
119 |         net.load_state_dict(copyStateDict(torch.load(args.trained_model, map_location='cpu')))
120 | 
121 |     if args.cuda:
122 |         net = net.cuda()
123 |         net = torch.nn.DataParallel(net)
124 |         cudnn.benchmark = False
125 | 
126 |     net.eval()
127 | 
128 |     t = time.time()
129 | 
130 |     # load data
131 |     for k, image_path in enumerate(image_list):
132 |         print("Test image {:d}/{:d}: {:s}".format(k+1, len(image_list), image_path), end='\r')
133 |         image = imgproc.loadImage(image_path)
134 | 
135 |         bboxes, polys, score_text = test_net(net, image, args.text_threshold, args.link_threshold, args.low_text, args.cuda, args.poly)
136 | 
137 |         # save score text
138 |         filename, file_ext = os.path.splitext(os.path.basename(image_path))
139 |         mask_file = result_folder + "res_" + filename + '_mask.jpg'
140 |         cv2.imwrite(mask_file, score_text)
141 | 
142 |         file_utils.saveResult(image_path, image[:,:,::-1], polys, dirname=result_folder)
143 | 
144 |     print("elapsed time : {}s".format(time.time() - t))
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from my_dataset import MyDataset
3 | from my_model import CRAFT
4 | import torch.nn as nn
5 | import torch
6 | import os
7 | from collections import OrderedDict
8 | import numpy as np
9 | 
10 | 
11 | def copyStateDict(state_dict):
12 |     if list(state_dict.keys())[0].startswith("module"):
13 |         start_idx = 1
14 |     else:
15 |         start_idx = 0
16 |     new_state_dict = OrderedDict()
17 |     for k, v in state_dict.items():
18 |         name = ".".join(k.split(".")[start_idx:])
19 |         new_state_dict[name] = v
20 |     return new_state_dict
21 | 
22 | if __name__ == '__main__':
23 |     """Parameter settings"""
24 |     device = 'cuda'  # 'cpu' or 'cuda'
25 |     dataset_path = './data'  # path to your own dataset
26 |     pretrained_path = './pretrained/craft_mlt_25k.pth'  # where the pretrained weights are stored
27 |     model_path = './models'  # where the newly trained models will be stored
28 | 
29 | 
30 |     dataset = MyDataset(dataset_path)
31 |     loader = DataLoader(dataset, batch_size=1, shuffle=True)
32 |     net = CRAFT(phase='train').to(device)
33 |     net.load_state_dict(copyStateDict(torch.load(pretrained_path, map_location=device)))
34 |     criterion = nn.MSELoss(reduction='sum').to(device)  # equivalent to the deprecated size_average=False
35 |     optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), 1e-7,
36 |                                 momentum=0.95,
37 |                                 weight_decay=0)
38 |     if not os.path.exists(model_path):
39 |         os.mkdir(model_path)
40 | 
41 |     for epoch in range(500):
42 |         epoch_loss = 0
43 |         for i, data in enumerate(loader):
44 |             img = data['img'].to(device)
45 |             gt = data['gt'].to(device)
46 | 
47 |             # forward
48 |             y, _ = net(img)
49 |             loss = criterion(y, gt)
50 |             optimizer.zero_grad()
51 |             loss.backward()
52 |             optimizer.step()
53 |             epoch_loss += loss.detach()
54 |         print('epoch loss_'+str(epoch),':',epoch_loss/len(loader))
55 |         torch.save(net.state_dict(), os.path.join(model_path,str(epoch)+'.pth'))
56 | 
--------------------------------------------------------------------------------
/vgg16_bn.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.init as init
6 | from torchvision import models
7 | from torchvision.models.vgg import model_urls
8 | 
9 | def init_weights(modules):
10 |     for m in modules:
11 |         if isinstance(m, nn.Conv2d):
12 |             init.xavier_uniform_(m.weight.data)
13 |             if m.bias is not None:
14 |                 m.bias.data.zero_()
15 |         elif isinstance(m, nn.BatchNorm2d):
16 |             m.weight.data.fill_(1)
17 |             m.bias.data.zero_()
18 |         elif isinstance(m, nn.Linear):
19 |             m.weight.data.normal_(0, 0.01)
20 |             m.bias.data.zero_()
21 | 
22 | class vgg16_bn(torch.nn.Module):
23 |     def __init__(self, pretrained=True, freeze=True):
24 |         super(vgg16_bn, self).__init__()
25 |         model_urls['vgg16_bn'] = model_urls['vgg16_bn'].replace('https://', 'http://')
26 |         vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
27 |         self.slice1 = torch.nn.Sequential()
28 |         self.slice2 = torch.nn.Sequential()
29 |         self.slice3 = torch.nn.Sequential()
30 |         self.slice4 = torch.nn.Sequential()
31 |         self.slice5 = torch.nn.Sequential()
32 |         for x in range(12):         # conv2_2
33 |             self.slice1.add_module(str(x), vgg_pretrained_features[x])
34 |         for x in range(12, 19):         # conv3_3
35 | 
self.slice2.add_module(str(x), vgg_pretrained_features[x]) 36 | for x in range(19, 29): # conv4_3 37 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 38 | for x in range(29, 39): # conv5_3 39 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 40 | 41 | # fc6, fc7 without atrous conv 42 | self.slice5 = torch.nn.Sequential( 43 | nn.MaxPool2d(kernel_size=3, stride=1, padding=1), 44 | nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6), 45 | nn.Conv2d(1024, 1024, kernel_size=1) 46 | ) 47 | 48 | if not pretrained: 49 | init_weights(self.slice1.modules()) 50 | init_weights(self.slice2.modules()) 51 | init_weights(self.slice3.modules()) 52 | init_weights(self.slice4.modules()) 53 | 54 | init_weights(self.slice5.modules()) # no pretrained model for fc6 and fc7 55 | 56 | if freeze: 57 | for param in self.slice1.parameters(): # only first conv 58 | param.requires_grad= False 59 | 60 | def forward(self, X): 61 | h = self.slice1(X) 62 | h_relu2_2 = h 63 | h = self.slice2(h) 64 | h_relu3_2 = h 65 | h = self.slice3(h) 66 | h_relu4_3 = h 67 | h = self.slice4(h) 68 | h_relu5_3 = h 69 | h = self.slice5(h) 70 | h_fc7 = h 71 | vgg_outputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2']) 72 | out = vgg_outputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2) 73 | return out 74 | --------------------------------------------------------------------------------
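
For reference, a sketch of the feature pyramid this backbone hands to CRAFT for a 768x768 input (shapes inferred from the slice boundaries above):
```
import torch
from vgg16_bn import vgg16_bn

net = vgg16_bn(pretrained=False, freeze=False)
with torch.no_grad():
    out = net(torch.randn(1, 3, 768, 768))
print(out.fc7.shape)       # torch.Size([1, 1024, 48, 48])
print(out.relu5_3.shape)   # torch.Size([1, 512, 48, 48])
print(out.relu4_3.shape)   # torch.Size([1, 512, 96, 96])
print(out.relu3_2.shape)   # torch.Size([1, 256, 192, 192])
print(out.relu2_2.shape)   # torch.Size([1, 128, 384, 384])
```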