├── recognize ├── .DS_Store ├── alphabet.pkl ├── config.py ├── keys.py ├── crnn_recognizer.py └── crnn.py ├── test_images ├── t1.png ├── t2.png ├── t3.png ├── t4.png ├── t5.png └── .DS_Store ├── test_result └── t1.jpg ├── checkpoints └── .DS_Store ├── train_code ├── .DS_Store ├── train_crnn │ ├── t1.jpg │ ├── .DS_Store │ ├── alphabet.pkl │ ├── crnn_models │ │ └── .DS_Store │ ├── split_train_test.py │ ├── config.py │ ├── readme.md │ ├── keys.py │ ├── recognizer.py │ ├── online_test.py │ ├── crnn_recognizer.py │ ├── trans_utils.py │ ├── train_warp_ctc.py │ ├── train_warp_ctc_v2.py │ ├── train_pytorch_ctc.py │ ├── crnn.py │ ├── utils.py │ ├── trans.py │ └── mydataset.py └── train_ctpn │ ├── .DS_Store │ ├── __pycache__ │ ├── config.cpython-37.pyc │ ├── ctpn_model.cpython-37.pyc │ └── ctpn_utils.cpython-37.pyc │ ├── data │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ └── dataset.cpython-36.pyc │ ├── __init__.py │ └── dataset.py │ ├── readme.md │ ├── config.py │ ├── ctpn_predict.py │ ├── ctpn_train.py │ ├── ctpn_model_v2.py │ ├── ctpn_model.py │ └── ctpn_utils.py ├── detect ├── __pycache__ │ ├── config.cpython-37.pyc │ ├── ctpn_model.cpython-37.pyc │ ├── ctpn_utils.cpython-37.pyc │ └── ctpn_predict.cpython-37.pyc ├── config.py ├── ctpn_predict.py ├── ctpn_model.py └── ctpn_utils.py ├── .gitignore ├── test_one.py ├── LICENSE ├── demo.py ├── README.md └── ocr.py /recognize/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/recognize/.DS_Store -------------------------------------------------------------------------------- /test_images/t1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_images/t1.png -------------------------------------------------------------------------------- /test_images/t2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_images/t2.png -------------------------------------------------------------------------------- /test_images/t3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_images/t3.png -------------------------------------------------------------------------------- /test_images/t4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_images/t4.png -------------------------------------------------------------------------------- /test_images/t5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_images/t5.png -------------------------------------------------------------------------------- /test_result/t1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_result/t1.jpg -------------------------------------------------------------------------------- /checkpoints/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/checkpoints/.DS_Store -------------------------------------------------------------------------------- /recognize/alphabet.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/recognize/alphabet.pkl -------------------------------------------------------------------------------- /test_images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/test_images/.DS_Store -------------------------------------------------------------------------------- /train_code/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/.DS_Store -------------------------------------------------------------------------------- /train_code/train_crnn/t1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_crnn/t1.jpg -------------------------------------------------------------------------------- /train_code/train_crnn/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_crnn/.DS_Store -------------------------------------------------------------------------------- /train_code/train_ctpn/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_ctpn/.DS_Store -------------------------------------------------------------------------------- /train_code/train_crnn/alphabet.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_crnn/alphabet.pkl -------------------------------------------------------------------------------- /detect/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/detect/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /detect/__pycache__/ctpn_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/detect/__pycache__/ctpn_model.cpython-37.pyc -------------------------------------------------------------------------------- /detect/__pycache__/ctpn_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/detect/__pycache__/ctpn_utils.cpython-37.pyc -------------------------------------------------------------------------------- /train_code/train_crnn/crnn_models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_crnn/crnn_models/.DS_Store -------------------------------------------------------------------------------- /detect/__pycache__/ctpn_predict.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/detect/__pycache__/ctpn_predict.cpython-37.pyc -------------------------------------------------------------------------------- /train_code/train_ctpn/__pycache__/config.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_ctpn/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /train_code/train_ctpn/__pycache__/ctpn_model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_ctpn/__pycache__/ctpn_model.cpython-37.pyc -------------------------------------------------------------------------------- /train_code/train_ctpn/__pycache__/ctpn_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_ctpn/__pycache__/ctpn_utils.cpython-37.pyc -------------------------------------------------------------------------------- /train_code/train_ctpn/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_ctpn/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /train_code/train_ctpn/data/__pycache__/dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/courao/ocr.pytorch/HEAD/train_code/train_ctpn/data/__pycache__/dataset.cpython-36.pyc -------------------------------------------------------------------------------- /train_code/train_ctpn/data/__init__.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-27 上午10:33 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | from .dataset import VOCDataset,ICDARDataset -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.hd 2 | *.cmds 3 | *.acn 4 | *.acr 5 | *.alg 6 | *.aux 7 | *.bbl 8 | *.blg 9 | *.dvi 10 | *.fdb_latexmk 11 | *.glg 12 | *.glo 13 | *.gls 14 | *.idx 15 | *.ilg 16 | *.ind 17 | *.ist 18 | *.lof 19 | *.log 20 | *.lot 21 | *.maf 22 | *.mtc 23 | *.mtc0 24 | *.nav 25 | *.nlo 26 | *.out 27 | *.pdfsync 28 | *.ps 29 | *.snm 30 | *.synctex.gz 31 | *.toc 32 | *.vrb 33 | *.xdy 34 | *.tdo 35 | *.thm 36 | .DS_Store 37 | *.bak -------------------------------------------------------------------------------- /train_code/train_crnn/split_train_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | h1 = open('infofiles/infofile_selfcollect.txt',encoding='utf-8') 3 | h2 = open('infofiles/infofile_selfcollect_train.txt','w',encoding='utf-8') 4 | h3 = open('infofiles/infofile_selfcollect_test.txt','w',encoding='utf-8') 5 | content = h1.readlines() 6 | for line in content: 7 | if np.random.random()<0.05: 8 | h3.write(line) 9 | else: 10 | h2.write(line) 11 | h1.close() 12 | h2.close() 13 | h3.close() 14 | -------------------------------------------------------------------------------- /train_code/train_ctpn/readme.md: -------------------------------------------------------------------------------- 1 | ## Train CTPN 2 | > Modified codes from [pytorch_ctpn](https://github.com/opconty/pytorch_ctpn) 3 | Added OHEM 4 | Supports the ICDAR17 MLT dataset 5 | 6 | To train your own model, put your images into one directory [images], 7 | and labels into another directory [labels]. 8 | Replace the values of icdar17_mlt_img_dir and icdar17_mlt_gt_dir in config.py with your own paths.
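For example (the paths below are placeholders for your own data, not paths shipped with this repo):
>icdar17_mlt_img_dir = '/path/to/your/train_images/'
icdar17_mlt_gt_dir = '/path/to/your/train_gt/'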
9 | Then run 10 | >python3 ctpn_train.py 11 | 12 | If you want to train on the ICDAR datasets, please visit [here](https://rrc.cvc.uab.es) and download the datasets you like. 13 | 14 | 15 | -------------------------------------------------------------------------------- /test_one.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ocr import ocr 3 | import time 4 | import shutil 5 | import numpy as np 6 | from PIL import Image 7 | from glob import glob 8 | import cv2 9 | 10 | def single_pic_proc(image_file): 11 | image = np.array(Image.open(image_file).convert('RGB')) 12 | result, image_framed = ocr(image) 13 | return result,image_framed 14 | 15 | def dis(image): 16 | cv2.imshow('image', image) 17 | cv2.waitKey(0) 18 | cv2.destroyAllWindows() 19 | 20 | if __name__ == '__main__': 21 | import sys 22 | if len(sys.argv)>=2: 23 | filename = sys.argv[1] 24 | if filename.endswith('jpg') or filename.endswith('png'): 25 | result, image_framed = single_pic_proc(filename) 26 | print(result) 27 | dis(image_framed) 28 | -------------------------------------------------------------------------------- /train_code/train_crnn/config.py: -------------------------------------------------------------------------------- 1 | import keys 2 | 3 | train_infofile = 'data_set/infofile_train_10w.txt' 4 | train_infofile_fullimg = '' 5 | val_infofile = 'data_set/infofile_test.txt' 6 | alphabet = keys.alphabet 7 | alphabet_v2 = keys.alphabet_v2 8 | workers = 4 9 | batchSize = 50 10 | imgH = 32 11 | imgW = 280 12 | nc = 1 13 | nclass = len(alphabet)+1 14 | nh = 256 15 | niter = 100 16 | lr = 0.0003 17 | beta1 = 0.5 18 | cuda = True 19 | ngpu = 1 20 | pretrained_model = '' 21 | saved_model_dir = 'crnn_models' 22 | saved_model_prefix = 'CRNN-' 23 | use_log = False 24 | remove_blank = False 25 | 26 | experiment = None 27 | displayInterval = 500 28 | n_test_disp = 10 29 | valInterval = 500 30 | saveInterval = 500 31 | adam = False 32 | adadelta = False 33 | keep_ratio = False 34 | random_sample = True 35 | 36 | -------------------------------------------------------------------------------- /recognize/config.py: -------------------------------------------------------------------------------- 1 | from recognize import keys 2 | 3 | train_infofile = 'data_set/infofile_train_10w.txt' 4 | train_infofile_fullimg = '' 5 | val_infofile = 'data_set/infofile_test.txt' 6 | alphabet = keys.alphabet 7 | alphabet_v2 = keys.alphabet_v2 8 | workers = 4 9 | batchSize = 50 10 | imgH = 32 11 | imgW = 280 12 | nc = 1 13 | nclass = len(alphabet)+1 14 | nh = 256 15 | niter = 100 16 | lr = 0.0003 17 | beta1 = 0.5 18 | cuda = True 19 | ngpu = 1 20 | pretrained_model = '' 21 | saved_model_dir = 'crnn_models' 22 | saved_model_prefix = 'CRNN-' 23 | use_log = False 24 | remove_blank = False 25 | 26 | experiment = None 27 | displayInterval = 500 28 | n_test_disp = 10 29 | valInterval = 500 30 | saveInterval = 500 31 | adam = False 32 | adadelta = False 33 | keep_ratio = False 34 | random_sample = True 35 | 36 | -------------------------------------------------------------------------------- /train_code/train_crnn/readme.md: -------------------------------------------------------------------------------- 1 | ## Train CRNN 2 | 3 | To train your own model, first prepare your own text-line dataset and organize it with an infofile. 4 | In the infofile, each line represents one text-line image; the path to the image and the image label are separated by a tab character '\t'. 5 | For example: 6 | >data_set/my_data1/0001.jpg\t37918 7 | data_set/my_data1/0002.jpg\tHello World! 8 | data_set/my_data1/0003.jpg\t你好 9 | ...
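As an illustration, here is a minimal sketch of writing such an infofile from a list of (image path, label) pairs; the sample pairs and the output file name are hypothetical placeholders:
>samples = [('data_set/my_data1/0001.jpg', '37918'), ('data_set/my_data1/0002.jpg', 'Hello World!')]
with open('my_train_infofile.txt', 'w', encoding='utf-8') as f:
    for path, label in samples:
        f.write(path + '\t' + label + '\n')  # one tab-separated line per image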
10 | 11 | Then point the training script train_warp_ctc.py at your own infofiles: 12 | >config.train_infofile = ['path_to_train_infofile1.txt','path_to_train_infofile2.txt'] 13 | config.val_infofile = 'path_to_test_infofile.txt' 14 | 15 | If you want to use warp-ctc as the loss function, run 16 | >python3 train_warp_ctc.py 17 | 18 | or, if you want to use pytorch-ctc (PyTorch's built-in CTC loss), run 19 | >python3 train_pytorch_ctc.py 20 | 21 | 22 | -------------------------------------------------------------------------------- /recognize/keys.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | # generate the alphabet from the labels 3 | # alphabet_set = set() 4 | # infofiles = ['infofiles/infofile_selfcollect.txt','infofiles/infofile_train_public.txt'] 5 | # for infofile in infofiles: 6 | # f = open(infofile) 7 | # content = f.readlines() 8 | # f.close() 9 | # for line in content: 10 | # if len(line.strip())>0: 11 | # if len(line.strip().split('\t'))!=2: 12 | # print(line) 13 | # else: 14 | # fname,label = line.strip().split('\t') 15 | # for ch in label: 16 | # alphabet_set.add(ch) 17 | # 18 | # alphabet_list = sorted(list(alphabet_set)) 19 | # pkl.dump(alphabet_list,open('alphabet.pkl','wb')) 20 | 21 | alphabet_list = pkl.load(open('recognize/alphabet.pkl','rb')) 22 | alphabet = [ord(ch) for ch in alphabet_list] 23 | alphabet_v2 = alphabet 24 | # print(alphabet_v2) -------------------------------------------------------------------------------- /train_code/train_crnn/keys.py: -------------------------------------------------------------------------------- 1 | import pickle as pkl 2 | # generate the alphabet from the labels 3 | # alphabet_set = set() 4 | # infofiles = ['infofiles/infofile_selfcollect.txt','infofiles/infofile_train_public.txt'] 5 | # for infofile in infofiles: 6 | # f = open(infofile) 7 | # content = f.readlines() 8 | # f.close() 9 | # for line in content: 10 | # if len(line.strip())>0: 11 | # if len(line.strip().split('\t'))!=2: 12 | # print(line) 13 | # else: 14 | # fname,label = line.strip().split('\t') 15 | # for ch in label: 16 | # alphabet_set.add(ch) 17 | # 18 | # alphabet_list = sorted(list(alphabet_set)) 19 | # pkl.dump(alphabet_list,open('alphabet.pkl','wb')) 20 | 21 | alphabet_list = pkl.load(open('alphabet.pkl','rb')) 22 | alphabet = [ord(ch) for ch in alphabet_list] 23 | alphabet_v2 = alphabet 24 | # print(alphabet_v2) -------------------------------------------------------------------------------- /train_code/train_ctpn/config.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 上午10:09 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | 9 | # base_dir = 'path to dataset base dir' 10 | base_dir = './images' 11 | img_dir = os.path.join(base_dir, 'VOC2007_text_detection/JPEGImages') 12 | xml_dir = os.path.join(base_dir, 'VOC2007_text_detection/Annotations') 13 | 14 | icdar17_mlt_img_dir = '/home/data2/egz/ICDAR2017_MLT/train/' 15 | icdar17_mlt_gt_dir = '/home/data2/egz/ICDAR2017_MLT/train_gt/' 16 | num_workers = 2 17 | pretrained_weights = 'checkpoints/v3_ctpn_ep22_0.3801_0.0971_0.4773.pth' 18 | 19 | 20 | 21 | anchor_scale = 16 22 | IOU_NEGATIVE = 0.3 23
| IOU_POSITIVE = 0.7 24 | IOU_SELECT = 0.7 25 | 26 | RPN_POSITIVE_NUM = 150 27 | RPN_TOTAL_NUM = 300 28 | 29 | # bgr can find from here: https://github.com/fchollet/deep-learning-models/blob/master/imagenet_utils.py 30 | IMAGE_MEAN = [123.68, 116.779, 103.939] 31 | OHEM = True 32 | 33 | checkpoints_dir = './checkpoints' 34 | outputs = r'./logs' 35 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 唐董琦 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /detect/config.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 上午10:09 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | 9 | # base_dir = 'path to dataset base dir' 10 | base_dir = './images' 11 | img_dir = os.path.join(base_dir, 'VOC2007_text_detection/JPEGImages') 12 | xml_dir = os.path.join(base_dir, 'VOC2007_text_detection/Annotations') 13 | 14 | icdar17_mlt_img_dir = '/home/data2/egz/ICDAR2017_MLT/train/' 15 | icdar17_mlt_gt_dir = '/home/data2/egz/ICDAR2017_MLT/train_gt/' 16 | num_workers = 2 17 | pretrained_weights = 'checkpoints/base.pth.tar' 18 | 19 | train_txt_file = os.path.join(base_dir, r'VOC2007_text_detection/ImageSets/Main/train.txt') 20 | val_txt_file = os.path.join(base_dir, r'VOC2007_text_detection/ImageSets/Main/val.txt') 21 | 22 | 23 | anchor_scale = 16 24 | IOU_NEGATIVE = 0.3 25 | IOU_POSITIVE = 0.7 26 | IOU_SELECT = 0.7 27 | 28 | RPN_POSITIVE_NUM = 150 29 | RPN_TOTAL_NUM = 300 30 | 31 | # bgr can find from here: https://github.com/fchollet/deep-learning-models/blob/master/imagenet_utils.py 32 | IMAGE_MEAN = [123.68, 116.779, 103.939] 33 | 34 | 35 | checkpoints_dir = './checkpoints' 36 | outputs = r'./logs' 37 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | from ocr import ocr 3 | import time 4 | import shutil 5 | import numpy as np 6 | from PIL import Image 7 | from glob import glob 8 | 9 | 10 | def single_pic_proc(image_file): 11 | image = np.array(Image.open(image_file).convert('RGB')) 12 | result, image_framed = ocr(image) 13 | return result,image_framed 14 | 15 | 16 | if __name__ == 
'__main__': 17 | image_files = glob('./test_images/*.*') 18 | result_dir = './test_result' 19 | if os.path.exists(result_dir): 20 | shutil.rmtree(result_dir) 21 | os.mkdir(result_dir) 22 | 23 | for image_file in sorted(image_files): 24 | t = time.time() 25 | result, image_framed = single_pic_proc(image_file) 26 | output_file = os.path.join(result_dir, image_file.split('/')[-1]) 27 | txt_file = os.path.join(result_dir, image_file.split('/')[-1].split('.')[0]+'.txt') 28 | print(txt_file) 29 | txt_f = open(txt_file, 'w') 30 | Image.fromarray(image_framed).save(output_file) 31 | print("Mission complete, it took {:.3f}s".format(time.time() - t)) 32 | print("\nRecognition Result:\n") 33 | for key in result: 34 | print(result[key][1]) 35 | txt_f.write(result[key][1]+'\n') 36 | txt_f.close() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## ocr.pytorch 2 | > A pure PyTorch implemented OCR project. 3 | Text detection is based on CTPN and text recognition is based on CRNN. 4 | More detection and recognition methods will be supported! 5 | 6 | ## Prerequisite 7 | 8 | - python-3.5+ 9 | - pytorch-0.4.1+ 10 | - torchvision-0.2.1 11 | - opencv-3.4.0.14 12 | - numpy-1.14.3 13 | 14 | 15 | All of them except pytorch and torchvision can be installed through pip. As for pytorch and torchvision, 16 | they both depend on your CUDA version; please refer to [pytorch's official site](https://pytorch.org/) 17 | 18 | 19 | ### Detection 20 | Detection is based on [CTPN](https://arxiv.org/abs/1609.03605), some codes are borrowed from 21 | [pytorch_ctpn](https://github.com/opconty/pytorch_ctpn); several detection results: 22 | ![detect1](test_result/t1.jpg) 23 | ![detect2](test_result/t2.jpg) 24 | ### Recognition 25 | Recognition is based on [CRNN](http://arxiv.org/abs/1507.05717), some codes are borrowed from 26 | [crnn.pytorch](https://github.com/meijieru/crnn.pytorch) 27 | 28 | ### Test 29 | Download the pretrained models from [Baidu Netdisk](https://pan.baidu.com/s/1yllO9hBF8TgChHJ7i3WobA) (extract code: u2ff) or [Google Drive](https://drive.google.com/open?id=1hRr9v9ky4VGygToFjLD9Cd-9xan43qID) 30 | and put these files into the checkpoints directory. 31 | Then run 32 | >python3 demo.py 33 | 34 | The image files in ./test_images will be tested for text detection and recognition, and the results will be stored in ./test_result. 35 | 36 | If you want to test a single image, run 37 | >python3 test_one.py [filename] 38 |
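You can also call the OCR pipeline directly from Python. Below is a minimal sketch following the flow of demo.py (the test image path is just an example):
>from PIL import Image
import numpy as np
from ocr import ocr
image = np.array(Image.open('./test_images/t1.png').convert('RGB'))
result, image_framed = ocr(image)  # result maps a box index to [box, text]
for key in result:
    print(result[key][1])  # recognized text of each detected line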
39 | ### Train 40 | Training code is placed in the train_code directory. 41 | Train [CTPN](./train_code/train_ctpn/readme.md) 42 | Train [CRNN](./train_code/train_crnn/readme.md) 43 | 44 | ### License 45 | [MIT License](https://opensource.org/licenses/MIT) -------------------------------------------------------------------------------- /train_code/train_crnn/recognizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import utils 4 | import mydataset 5 | from PIL import Image 6 | import numpy as np 7 | import crnn as crnn 8 | import cv2 9 | import torch.nn.functional as F 10 | import keys 11 | import config 12 | gpu = True 13 | if not torch.cuda.is_available(): 14 | gpu = False 15 | 16 | model_path = './crnn_models/CRNN-0618-10w_21_990.pth' 17 | alphabet = keys.alphabet 18 | print(len(alphabet)) 19 | imgH = config.imgH 20 | imgW = config.imgW 21 | model = crnn.CRNN(imgH, 1, len(alphabet) + 1, 256) 22 | if gpu: 23 | model = model.cuda() 24 | print('loading pretrained model from %s' % model_path) 25 | if gpu: 26 | model.load_state_dict( torch.load( model_path ) ) 27 | else: 28 | model.load_state_dict(torch.load(model_path,map_location=lambda storage,loc:storage)) 29 | 30 | converter = utils.strLabelConverter(alphabet) 31 | transformer = mydataset.resizeNormalize((imgW, imgH),is_test=True) 32 | 33 | def recognize_downline(img,crnn_model=model): 34 | img = cv2.cvtColor( img, cv2.COLOR_BGR2RGB ) 35 | image = Image.fromarray(np.uint8(img)).convert('L') 36 | image = transformer( image ) 37 | if gpu: 38 | image = image.cuda() 39 | image = image.view( 1, *image.size() ) 40 | image = Variable( image ) 41 | 42 | model.eval() 43 | preds = model( image ) 44 | 45 | preds = F.log_softmax(preds,2) 46 | conf, preds = preds.max( 2 ) 47 | preds = preds.transpose( 1, 0 ).contiguous().view( -1 ) 48 | 49 | preds_size = Variable( torch.IntTensor( [preds.size( 0 )] ) ) 50 | raw_pred = converter.decode( preds.data, preds_size.data, raw=True ) 51 | sim_pred = converter.decode( preds.data, preds_size.data, raw=False ) 52 | return sim_pred.upper() 53 | 54 | 55 | if __name__ == '__main__': 56 | import shutil 57 | saved_path = 'test_imgs/' 58 | wrong_results = list() 59 | with open('data_set/infofile_test.txt') as f: 60 | content = f.readlines() 61 | num_all = 0 62 | num_correct = 0 63 | for line in content: 64 | fname, label = line.split('g:') 65 | fname += 'g' 66 | label = label.replace('\r', '').replace('\n', '') 67 | img = cv2.imread(fname) 68 | res = recognize_downline(img) 69 | if res==label: 70 | num_correct+=1 71 | else: 72 | # new_name = saved_path + fname.split('/')[-1] 73 | # shutil.copyfile(fname, new_name) 74 | wrong_results.append('res:{} / label:{}'.format(res,label)) 75 | num_all+=1 76 | print(fname,res==label,res,label) 77 | 78 | print(num_correct/num_all) 79 | # print(wrong_results) 80 | 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /ocr.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from math import * 3 | import numpy as np 4 | from detect.ctpn_predict import get_det_boxes 5 | from recognize.crnn_recognizer import PytorchOcr 6 | recognizer = PytorchOcr() 7 | 8 | def dis(image): 9 | cv2.imshow('image', image) 10 | cv2.waitKey(0) 11 | 12 | def sort_box(box): 13 | """ 14 | Sort the text boxes from top to bottom 15 | """ 16 | box = sorted(box, key=lambda x: sum([x[1], x[3], x[5], x[7]])) 17 | return box 18 |
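# Note: each box is an 8-value array (x1,y1,x2,y2,x3,y3,x4,y4); indices 1, 3, 5, 7
# are its four y-coordinates, so sorting by their sum orders the text lines top-to-bottom.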
19 | def dumpRotateImage(img, degree, pt1, pt2, pt3, pt4): 20 | height, width = img.shape[:2] 21 | heightNew = int(width * fabs(sin(radians(degree))) + height * fabs(cos(radians(degree)))) 22 | widthNew = int(height * fabs(sin(radians(degree))) + width * fabs(cos(radians(degree)))) 23 | matRotation = cv2.getRotationMatrix2D((width // 2, height // 2), degree, 1) 24 | matRotation[0, 2] += (widthNew - width) // 2 25 | matRotation[1, 2] += (heightNew - height) // 2 26 | imgRotation = cv2.warpAffine(img, matRotation, (widthNew, heightNew), borderValue=(255, 255, 255)) 27 | pt1 = list(pt1) 28 | pt3 = list(pt3) 29 | 30 | [[pt1[0]], [pt1[1]]] = np.dot(matRotation, np.array([[pt1[0]], [pt1[1]], [1]])) 31 | [[pt3[0]], [pt3[1]]] = np.dot(matRotation, np.array([[pt3[0]], [pt3[1]], [1]])) 32 | ydim, xdim = imgRotation.shape[:2] 33 | imgOut = imgRotation[max(1, int(pt1[1])): min(ydim - 1, int(pt3[1])), 34 | max(1, int(pt1[0])): min(xdim - 1, int(pt3[0]))] 35 | 36 | return imgOut 37 | 38 | 39 | def charRec(img, text_recs, adjust=False): 40 | """ 41 | Load the OCR model and run character recognition on each text box 42 | """ 43 | results = {} 44 | xDim, yDim = img.shape[1], img.shape[0] 45 | 46 | for index, rec in enumerate(text_recs): 47 | xlength = int((rec[6] - rec[0]) * 0.1) 48 | ylength = int((rec[7] - rec[1]) * 0.2) 49 | if adjust: 50 | pt1 = (max(1, rec[0] - xlength), max(1, rec[1] - ylength)) 51 | pt2 = (rec[2], rec[3]) 52 | pt3 = (min(rec[6] + xlength, xDim - 2), min(yDim - 2, rec[7] + ylength)) 53 | pt4 = (rec[4], rec[5]) 54 | else: 55 | pt1 = (max(1, rec[0]), max(1, rec[1])) 56 | pt2 = (rec[2], rec[3]) 57 | pt3 = (min(rec[6], xDim - 2), min(yDim - 2, rec[7])) 58 | pt4 = (rec[4], rec[5]) 59 | 60 | degree = degrees(atan2(pt2[1] - pt1[1], pt2[0] - pt1[0])) # skew angle of the text line 61 | 62 | partImg = dumpRotateImage(img, degree, pt1, pt2, pt3, pt4) 63 | # dis(partImg) 64 | if partImg.shape[0] < 1 or partImg.shape[1] < 1 or partImg.shape[0] > partImg.shape[1]: # filter out abnormal crops 65 | continue 66 | text = recognizer.recognize(partImg) 67 | if len(text) > 0: 68 | results[index] = [rec] 69 | results[index].append(text) # recognized text 70 | 71 | return results 72 | 73 | def ocr(image): 74 | # detect 75 | text_recs, img_framed, image = get_det_boxes(image) 76 | text_recs = sort_box(text_recs) 77 | result = charRec(image, text_recs) 78 | return result, img_framed 79 | -------------------------------------------------------------------------------- /train_code/train_ctpn/ctpn_predict.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 上午10:03 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 9 | import cv2 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from ctpn_model import CTPN_Model 15 | from ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox,nms, TextProposalConnectorOriented 16 | from ctpn_utils import resize 17 | import config 18 | 19 | 20 | prob_thresh = 0.5 21 | width = 960 22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | weights = os.path.join(config.checkpoints_dir, 'v3_ctpn_ep30_0.3699_0.0929_0.4628.pth')#'ctpn_ep17_0.0544_0.1125_0.1669.pth') 24 | 25 | 26 | model = CTPN_Model() 27 | model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict']) 28 | model.to(device) 29 | model.eval() 30 | 31 | 32 | def dis(image): 33 | cv2.imshow('image', image) 34 | cv2.waitKey(0) 35 | cv2.destroyAllWindows() 36 | 37 | 38 | def get_det_boxes(image,display = True): 39 | image = resize(image, height=720) 40 |
image_c = image.copy() 41 | h, w = image.shape[:2] 42 | image = image.astype(np.float32) - config.IMAGE_MEAN 43 | image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float() 44 | 45 | with torch.no_grad(): 46 | image = image.to(device) 47 | cls, regr = model(image) 48 | cls_prob = F.softmax(cls, dim=-1).cpu().numpy() 49 | regr = regr.cpu().numpy() 50 | anchor = gen_anchor((int(h / 16), int(w / 16)), 16) 51 | bbox = bbox_transfor_inv(anchor, regr) 52 | bbox = clip_box(bbox, [h, w]) 53 | # print(bbox.shape) 54 | 55 | fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0] 56 | # print(np.max(cls_prob[0, :, 1])) 57 | select_anchor = bbox[fg, :] 58 | select_score = cls_prob[0, fg, 1] 59 | select_anchor = select_anchor.astype(np.int32) 60 | # print(select_anchor.shape) 61 | keep_index = filter_bbox(select_anchor, 16) 62 | 63 | # nms 64 | select_anchor = select_anchor[keep_index] 65 | select_score = select_score[keep_index] 66 | select_score = np.reshape(select_score, (select_score.shape[0], 1)) 67 | nmsbox = np.hstack((select_anchor, select_score)) 68 | keep = nms(nmsbox, 0.3) 69 | # print(keep) 70 | select_anchor = select_anchor[keep] 71 | select_score = select_score[keep] 72 | 73 | # text line- 74 | textConn = TextProposalConnectorOriented() 75 | text = textConn.get_text_lines(select_anchor, select_score, [h, w]) 76 | print(text) 77 | if display: 78 | for i in text: 79 | s = str(round(i[-1] * 100, 2)) + '%' 80 | i = [int(j) for j in i] 81 | cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2) 82 | cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2) 83 | cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2) 84 | cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2) 85 | cv2.putText(image_c, s, (i[0]+13, i[1]+13), 86 | cv2.FONT_HERSHEY_SIMPLEX, 87 | 1, 88 | (255,0,0), 89 | 2, 90 | cv2.LINE_AA) 91 | 92 | return text,image_c 93 | 94 | if __name__ == '__main__': 95 | img_path = 'images/t1.png' 96 | image = cv2.imread(img_path) 97 | text,image = get_det_boxes(image) 98 | cv2.imwrite('results/t.jpg',image) 99 | # dis(image) -------------------------------------------------------------------------------- /train_code/train_crnn/online_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import utils 4 | import mydataset 5 | from PIL import Image 6 | import numpy as np 7 | import crnn as crnn 8 | import cv2 9 | import torch.nn.functional as F 10 | import keys 11 | import config 12 | 13 | alphabet = keys.alphabet_v2 14 | converter = utils.strLabelConverter(alphabet.copy()) 15 | 16 | 17 | def val_model(infofile,model,gpu,log_file = '0625.log'): 18 | h = open('log/{}'.format(log_file),'w') 19 | with open(infofile) as f: 20 | content = f.readlines() 21 | num_all = 0 22 | num_correct = 0 23 | 24 | for line in content: 25 | if '\t' in line: 26 | fname, label = line.split('\t') 27 | else: 28 | fname, label = line.split('g:') 29 | fname += 'g' 30 | label = label.replace('\r', '').replace('\n', '') 31 | img = cv2.imread(fname) 32 | res = val_on_image(img,model,gpu) 33 | res = res.strip() 34 | label = label.strip() 35 | if res == label: 36 | num_correct+=1 37 | else: 38 | print('filename:{}\npred :{}\ntarget:{}'.format(fname, res, label)) 39 | h.write('filename:{}\npred :{}\ntarget:{}\n'.format(fname,res, label)) 40 | # else: 41 | # # new_name = saved_path + fname.split('/')[-1] 42 | # # shutil.copyfile(fname, new_name) 43 | # wrong_results.append('res:{} / 
label:{}'.format(res,label)) 44 | num_all+=1 45 | h.write('ocr_correct: {}/{}/{}\n'.format(num_correct,num_all,num_correct/num_all)) 46 | print(num_correct/num_all) 47 | h.close() 48 | return num_correct, num_all 49 | 50 | def val_on_image(img,model,gpu): 51 | imgH = config.imgH 52 | h,w = img.shape[:2] 53 | imgW = imgH*w//h 54 | 55 | transformer = mydataset.resizeNormalize((imgW, imgH), is_test=True) 56 | img = cv2.cvtColor( img, cv2.COLOR_BGR2RGB ) 57 | image = Image.fromarray(np.uint8(img)).convert('L') 58 | image = transformer( image ) 59 | if gpu: 60 | image = image.cuda() 61 | image = image.view( 1, *image.size() ) 62 | image = Variable( image ) 63 | 64 | model.eval() 65 | preds = model( image ) 66 | 67 | preds = F.log_softmax(preds,2) 68 | conf, preds = preds.max( 2 ) 69 | preds = preds.transpose( 1, 0 ).contiguous().view( -1 ) 70 | 71 | preds_size = Variable( torch.IntTensor( [preds.size( 0 )] ) ) 72 | # raw_pred = converter.decode( preds.data, preds_size.data, raw=True ) 73 | sim_pred = converter.decode( preds.data, preds_size.data, raw=False ) 74 | return sim_pred 75 | 76 | 77 | if __name__ == '__main__': 78 | import sys 79 | model_path = './crnn_models/CRNN-0627-crop_48_901.pth' 80 | gpu = True 81 | if not torch.cuda.is_available(): 82 | gpu = False 83 | 84 | model = crnn.CRNN(config.imgH, 1, len(alphabet) + 1, 256) 85 | if gpu: 86 | model = model.cuda() 87 | print('loading pretrained model from %s' % model_path) 88 | if gpu: 89 | model.load_state_dict(torch.load(model_path)) 90 | else: 91 | model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage)) 92 | 93 | if len(sys.argv)>1 and 'train' in sys.argv[1]: 94 | infofile = 'data_set/infofile_updated_0627_train.txt' 95 | print(val_model(infofile, model, gpu, '0627_train.log')) 96 | elif len(sys.argv)>1 and 'gen' in sys.argv[1]: 97 | infofile = 'data_set/infofile_0627_gen_test.txt' 98 | print(val_model(infofile, model, gpu, '0627_gen.log')) 99 | else: 100 | infofile = 'data_set/infofile_updated_0627_test.txt' 101 | print(val_model(infofile, model, gpu, '0627_test.log')) 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /detect/ctpn_predict.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 上午10:03 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 9 | import cv2 10 | import numpy as np 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from detect.ctpn_model import CTPN_Model 15 | from detect.ctpn_utils import gen_anchor, bbox_transfor_inv, clip_box, filter_bbox,nms, TextProposalConnectorOriented 16 | from detect.ctpn_utils import resize 17 | from detect import config 18 | 19 | prob_thresh = 0.5 20 | height = 720 21 | gpu = True 22 | if not torch.cuda.is_available(): 23 | gpu = False 24 | device = torch.device('cuda:0' if gpu else 'cpu') 25 | weights = os.path.join(config.checkpoints_dir, 'CTPN.pth') 26 | model = CTPN_Model() 27 | model.load_state_dict(torch.load(weights, map_location=device)['model_state_dict']) 28 | model.to(device) 29 | model.eval() 30 | 31 | 32 | def dis(image): 33 | cv2.imshow('image', image) 34 | cv2.waitKey(0) 35 | cv2.destroyAllWindows() 36 | 37 | 38 | def get_det_boxes(image,display = True, expand = True): 39 | image = resize(image, height=height) 40 | image_r = image.copy() 41 | image_c = image.copy() 42 | h, w = image.shape[:2] 43 | image = 
image.astype(np.float32) - config.IMAGE_MEAN 44 | image = torch.from_numpy(image.transpose(2, 0, 1)).unsqueeze(0).float() 45 | 46 | with torch.no_grad(): 47 | image = image.to(device) 48 | cls, regr = model(image) 49 | cls_prob = F.softmax(cls, dim=-1).cpu().numpy() 50 | regr = regr.cpu().numpy() 51 | anchor = gen_anchor((int(h / 16), int(w / 16)), 16) 52 | bbox = bbox_transfor_inv(anchor, regr) 53 | bbox = clip_box(bbox, [h, w]) 54 | # print(bbox.shape) 55 | 56 | fg = np.where(cls_prob[0, :, 1] > prob_thresh)[0] 57 | # print(np.max(cls_prob[0, :, 1])) 58 | select_anchor = bbox[fg, :] 59 | select_score = cls_prob[0, fg, 1] 60 | select_anchor = select_anchor.astype(np.int32) 61 | # print(select_anchor.shape) 62 | keep_index = filter_bbox(select_anchor, 16) 63 | 64 | # nms 65 | select_anchor = select_anchor[keep_index] 66 | select_score = select_score[keep_index] 67 | select_score = np.reshape(select_score, (select_score.shape[0], 1)) 68 | nmsbox = np.hstack((select_anchor, select_score)) 69 | keep = nms(nmsbox, 0.3) 70 | # print(keep) 71 | select_anchor = select_anchor[keep] 72 | select_score = select_score[keep] 73 | 74 | # text line- 75 | textConn = TextProposalConnectorOriented() 76 | text = textConn.get_text_lines(select_anchor, select_score, [h, w]) 77 | 78 | # expand text 79 | if expand: 80 | for idx in range(len(text)): 81 | text[idx][0] = max(text[idx][0] - 10, 0) 82 | text[idx][2] = min(text[idx][2] + 10, w - 1) 83 | text[idx][4] = max(text[idx][4] - 10, 0) 84 | text[idx][6] = min(text[idx][6] + 10, w - 1) 85 | 86 | 87 | # print(text) 88 | if display: 89 | blank = np.zeros(image_c.shape,dtype=np.uint8) 90 | for box in select_anchor: 91 | pt1 = (box[0], box[1]) 92 | pt2 = (box[2], box[3]) 93 | blank = cv2.rectangle(blank, pt1, pt2, (50, 0, 0), -1) 94 | image_c = image_c+blank 95 | image_c[image_c>255] = 255 96 | for i in text: 97 | s = str(round(i[-1] * 100, 2)) + '%' 98 | i = [int(j) for j in i] 99 | cv2.line(image_c, (i[0], i[1]), (i[2], i[3]), (0, 0, 255), 2) 100 | cv2.line(image_c, (i[0], i[1]), (i[4], i[5]), (0, 0, 255), 2) 101 | cv2.line(image_c, (i[6], i[7]), (i[2], i[3]), (0, 0, 255), 2) 102 | cv2.line(image_c, (i[4], i[5]), (i[6], i[7]), (0, 0, 255), 2) 103 | cv2.putText(image_c, s, (i[0]+13, i[1]+13), 104 | cv2.FONT_HERSHEY_SIMPLEX, 105 | 1, 106 | (255,0,0), 107 | 2, 108 | cv2.LINE_AA) 109 | # dis(image_c) 110 | # print(text) 111 | return text,image_c,image_r 112 | 113 | if __name__ == '__main__': 114 | img_path = 'images/t1.png' 115 | image = cv2.imread(img_path) 116 | text,image = get_det_boxes(image) 117 | dis(image) -------------------------------------------------------------------------------- /detect/ctpn_model.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 上午10:01 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.models as models 12 | 13 | 14 | class RPN_REGR_Loss(nn.Module): 15 | def __init__(self, device, sigma=9.0): 16 | super(RPN_REGR_Loss, self).__init__() 17 | self.sigma = sigma 18 | self.device = device 19 | 20 | def forward(self, input, target): 21 | ''' 22 | smooth L1 loss 23 | :param input:y_preds 24 | :param target: y_true 25 | :return: 26 | ''' 27 | try: 28 | cls = target[0, :, 0] 29 | regr = target[0, :, 1:3] 30 | regr_keep = (cls == 1).nonzero()[:, 0] 31 | regr_true = regr[regr_keep] 32 | regr_pred = input[0][regr_keep] 33 | diff = 
torch.abs(regr_true - regr_pred) 34 | less_one = (diff<1.0/self.sigma).float() 35 | loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1- less_one) * (diff - 0.5/self.sigma) 36 | loss = torch.sum(loss, 1) 37 | loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0) 38 | except Exception as e: 39 | print('RPN_REGR_Loss Exception:', e) 40 | # print(input, target) 41 | loss = torch.tensor(0.0) 42 | 43 | return loss.to(self.device) 44 | 45 | 46 | class RPN_CLS_Loss(nn.Module): 47 | def __init__(self,device): 48 | super(RPN_CLS_Loss, self).__init__() 49 | self.device = device 50 | 51 | def forward(self, input, target): 52 | y_true = target[0][0] 53 | cls_keep = (y_true != -1).nonzero()[:, 0] 54 | cls_true = y_true[cls_keep].long() 55 | cls_pred = input[0][cls_keep] 56 | loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true) # original is sparse_softmax_cross_entropy_with_logits 57 | # loss = nn.BCEWithLogitsLoss()(cls_pred[:,0], cls_true.float()) # 18-12-8 58 | loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0) 59 | return loss.to(self.device) 60 | 61 | 62 | class basic_conv(nn.Module): 63 | def __init__(self, 64 | in_planes, 65 | out_planes, 66 | kernel_size, 67 | stride=1, 68 | padding=0, 69 | dilation=1, 70 | groups=1, 71 | relu=True, 72 | bn=True, 73 | bias=True): 74 | super(basic_conv, self).__init__() 75 | self.out_channels = out_planes 76 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 77 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 78 | self.relu = nn.ReLU(inplace=True) if relu else None 79 | 80 | def forward(self, x): 81 | x = self.conv(x) 82 | if self.bn is not None: 83 | x = self.bn(x) 84 | if self.relu is not None: 85 | x = self.relu(x) 86 | return x 87 | 88 | 89 | class CTPN_Model(nn.Module): 90 | def __init__(self): 91 | super().__init__() 92 | base_model = models.vgg16(pretrained=False) 93 | layers = list(base_model.features)[:-1] 94 | self.base_layers = nn.Sequential(*layers) # block5_conv3 output 95 | self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False) 96 | self.brnn = nn.GRU(512,128, bidirectional=True, batch_first=True) 97 | self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False) 98 | self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) 99 | self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) 100 | 101 | def forward(self, x): 102 | x = self.base_layers(x) 103 | # rpn 104 | x = self.rpn(x) #[b, c, h, w] 105 | 106 | x1 = x.permute(0,2,3,1).contiguous() # channels last [b, h, w, c] 107 | b = x1.size() # b, h, w, c 108 | x1 = x1.view(b[0]*b[1], b[2], b[3]) 109 | 110 | x2, _ = self.brnn(x1) 111 | 112 | xsz = x.size() 113 | x3 = x2.view(xsz[0], xsz[2], xsz[3], 256) # torch.Size([4, 20, 20, 256]) 114 | 115 | x3 = x3.permute(0,3,1,2).contiguous() # channels first [b, c, h, w] 116 | x3 = self.lstm_fc(x3) 117 | x = x3 118 | 119 | cls = self.rpn_class(x) 120 | regr = self.rpn_regress(x) 121 | 122 | cls = cls.permute(0,2,3,1).contiguous() 123 | regr = regr.permute(0,2,3,1).contiguous() 124 | 125 | cls = cls.view(cls.size(0), cls.size(1)*cls.size(2)*10, 2) 126 | regr = regr.view(regr.size(0), regr.size(1)*regr.size(2)*10, 2) 127 | 128 | return cls, regr 129 | -------------------------------------------------------------------------------- /train_code/train_ctpn/ctpn_train.py: 
-------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-27 上午10:31 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torch import optim 12 | import numpy as np 13 | import argparse 14 | 15 | import config 16 | from ctpn_model import CTPN_Model, RPN_CLS_Loss, RPN_REGR_Loss 17 | from data.dataset import ICDARDataset 18 | 19 | 20 | random_seed = 2019 21 | torch.random.manual_seed(random_seed) 22 | np.random.seed(random_seed) 23 | 24 | epochs = 30 25 | lr = 1e-3 26 | resume_epoch = 0 27 | 28 | 29 | def save_checkpoint(state, epoch, loss_cls, loss_regr, loss, ext='pth'): 30 | check_path = os.path.join(config.checkpoints_dir, 31 | f'v3_ctpn_ep{epoch:02d}_' 32 | f'{loss_cls:.4f}_{loss_regr:.4f}_{loss:.4f}.{ext}') 33 | 34 | try: 35 | torch.save(state, check_path) 36 | except BaseException as e: 37 | print(e) 38 | print('fail to save to {}'.format(check_path)) 39 | print('saving to {}'.format(check_path)) 40 | 41 | def weights_init(m): 42 | classname = m.__class__.__name__ 43 | if classname.find('Conv') != -1: 44 | m.weight.data.normal_(0.0, 0.02) 45 | elif classname.find('BatchNorm') != -1: 46 | m.weight.data.normal_(1.0, 0.02) 47 | m.bias.data.fill_(0) 48 | 49 | 50 | if __name__ == '__main__': 51 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 52 | checkpoints_weight = config.pretrained_weights 53 | print('exist pretrained ',os.path.exists(checkpoints_weight)) 54 | if os.path.exists(checkpoints_weight): 55 | pretrained = False 56 | 57 | dataset = ICDARDataset(config.icdar17_mlt_img_dir, config.icdar17_mlt_gt_dir) 58 | dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=config.num_workers) 59 | model = CTPN_Model() 60 | model.to(device) 61 | 62 | if os.path.exists(checkpoints_weight): 63 | print('using pretrained weight: {}'.format(checkpoints_weight)) 64 | cc = torch.load(checkpoints_weight, map_location=device) 65 | model.load_state_dict(cc['model_state_dict']) 66 | resume_epoch = cc['epoch'] 67 | else: 68 | model.apply(weights_init) 69 | 70 | params_to_update = model.parameters() 71 | optimizer = optim.SGD(params_to_update, lr=lr, momentum=0.9) 72 | 73 | criterion_cls = RPN_CLS_Loss(device) 74 | criterion_regr = RPN_REGR_Loss(device) 75 | 76 | best_loss_cls = 100 77 | best_loss_regr = 100 78 | best_loss = 100 79 | best_model = None 80 | epochs += resume_epoch 81 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) 82 | 83 | for epoch in range(resume_epoch+1, epochs): 84 | print(f'Epoch {epoch}/{epochs}') 85 | print('#'*50) 86 | epoch_size = len(dataset) // 1 87 | model.train() 88 | epoch_loss_cls = 0 89 | epoch_loss_regr = 0 90 | epoch_loss = 0 91 | scheduler.step(epoch) 92 | 93 | for batch_i, (imgs, clss, regrs) in enumerate(dataloader): 94 | # print(imgs.shape) 95 | imgs = imgs.to(device) 96 | clss = clss.to(device) 97 | regrs = regrs.to(device) 98 | 99 | optimizer.zero_grad() 100 | 101 | out_cls, out_regr = model(imgs) 102 | loss_cls = criterion_cls(out_cls, clss) 103 | loss_regr = criterion_regr(out_regr, regrs) 104 | 105 | loss = loss_cls + loss_regr # total loss 106 | loss.backward() 107 | optimizer.step() 108 | 109 | epoch_loss_cls += loss_cls.item() 110 | epoch_loss_regr += loss_regr.item() 111 | epoch_loss += loss.item() 112 | mmp = batch_i+1 113 | 114 | print(f'Ep:{epoch}/{epochs-1}--' 115 |
f'Batch:{batch_i}/{epoch_size}\n' 116 | f'batch: loss_cls:{loss_cls.item():.4f}--loss_regr:{loss_regr.item():.4f}--loss:{loss.item():.4f}\n' 117 | f'Epoch: loss_cls:{epoch_loss_cls/mmp:.4f}--loss_regr:{epoch_loss_regr/mmp:.4f}--' 118 | f'loss:{epoch_loss/mmp:.4f}\n') 119 | 120 | epoch_loss_cls /= epoch_size 121 | epoch_loss_regr /= epoch_size 122 | epoch_loss /= epoch_size 123 | print(f'Epoch:{epoch}--{epoch_loss_cls:.4f}--{epoch_loss_regr:.4f}--{epoch_loss:.4f}') 124 | if best_loss_cls > epoch_loss_cls or best_loss_regr > epoch_loss_regr or best_loss > epoch_loss: 125 | best_loss = epoch_loss 126 | best_loss_regr = epoch_loss_regr 127 | best_loss_cls = epoch_loss_cls 128 | best_model = model 129 | save_checkpoint({'model_state_dict': best_model.state_dict(), 130 | 'epoch': epoch}, 131 | epoch, 132 | best_loss_cls, 133 | best_loss_regr, 134 | best_loss) 135 | 136 | if torch.cuda.is_available(): 137 | torch.cuda.empty_cache() 138 | 139 | -------------------------------------------------------------------------------- /train_code/train_crnn/crnn_recognizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # import torchvision.models as models 3 | import torch, os 4 | from PIL import Image 5 | import cv2 6 | import torchvision.transforms as transforms 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import random 10 | from crnn import CRNN 11 | import config 12 | 13 | # copy from mydataset 14 | class resizeNormalize(object): 15 | def __init__(self, size, interpolation=Image.LANCZOS, is_test=True): 16 | self.size = size 17 | self.interpolation = interpolation 18 | self.toTensor = transforms.ToTensor() 19 | self.is_test = is_test 20 | 21 | def __call__(self, img): 22 | w, h = self.size 23 | w0 = img.size[0] 24 | h0 = img.size[1] 25 | if w <= (w0 / h0 * h): 26 | img = img.resize(self.size, self.interpolation) 27 | img = self.toTensor(img) 28 | img.sub_(0.5).div_(0.5) 29 | else: 30 | w_real = int(w0 / h0 * h) 31 | img = img.resize((w_real, h), self.interpolation) 32 | img = self.toTensor(img) 33 | img.sub_(0.5).div_(0.5) 34 | tmp = torch.zeros([img.shape[0], h, w]) 35 | start = random.randint(0, w - w_real - 1) 36 | if self.is_test: 37 | start = 0 38 | tmp[:, :, start:start + w_real] = img 39 | img = tmp 40 | return img 41 | 42 | # copy from utils 43 | class strLabelConverter(object): 44 | def __init__(self, alphabet, ignore_case=False): 45 | self._ignore_case = ignore_case 46 | if self._ignore_case: 47 | alphabet = alphabet.lower() 48 | self.alphabet = alphabet + '_' # for `-1` index 49 | 50 | self.dict = {} 51 | for i, char in enumerate(alphabet): 52 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 53 | self.dict[char] = i + 1 54 | 55 | # print(self.dict) 56 | def encode(self, text): 57 | length = [] 58 | result = [] 59 | for item in text: 60 | item = item.decode('utf-8', 'strict') 61 | length.append(len(item)) 62 | for char in item: 63 | if char not in self.dict.keys(): 64 | index = 0 65 | else: 66 | index = self.dict[char] 67 | result.append(index) 68 | text = result 69 | return (torch.IntTensor(text), torch.IntTensor(length)) 70 | 71 | def decode(self, t, length, raw=False): 72 | if length.numel() == 1: 73 | length = length[0] 74 | assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), 75 | length) 76 | if raw: 77 | return ''.join([self.alphabet[i - 1] for i in t]) 78 | else: 79 | char_list = [] 80 | for i in range(length): 81 | if t[i] != 0 and (not (i 
> 0 and t[i - 1] == t[i])): 82 | char_list.append(self.alphabet[t[i] - 1]) 83 | return ''.join(char_list) 84 | else: 85 | # batch mode 86 | assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format( 87 | t.numel(), length.sum()) 88 | texts = [] 89 | index = 0 90 | for i in range(length.numel()): 91 | l = length[i] 92 | texts.append( 93 | self.decode( 94 | t[index:index + l], torch.IntTensor([l]), raw=raw)) 95 | index += l 96 | return texts 97 | 98 | # recognize api 99 | class PytorchOcr(): 100 | def __init__(self, model_path): 101 | alphabet_unicode = config.alphabet_v2 102 | self.alphabet = ''.join([chr(uni) for uni in alphabet_unicode]) 103 | # print(len(self.alphabet)) 104 | self.nclass = len(self.alphabet) + 1 105 | self.model = CRNN(config.imgH, 1, self.nclass, 256) 106 | self.cuda = False 107 | if torch.cuda.is_available(): 108 | self.cuda = True 109 | self.model.cuda() 110 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()}) 111 | else: 112 | # self.model = nn.DataParallel(self.model) 113 | self.model.load_state_dict(torch.load(model_path, map_location='cpu')) 114 | self.model.eval() 115 | self.converter = strLabelConverter(self.alphabet) 116 | 117 | def recognize(self, img): 118 | h,w = img.shape[:2] 119 | if len(img.shape) == 3: 120 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 121 | image = Image.fromarray(img) 122 | transformer = resizeNormalize((int(w/h*32), 32)) 123 | image = transformer(image) 124 | image = image.view(1, *image.size()) 125 | image = Variable(image) 126 | 127 | if self.cuda: 128 | image = image.cuda() 129 | 130 | preds = self.model(image) 131 | 132 | _, preds = preds.max(2) 133 | preds = preds.transpose(1, 0).contiguous().view(-1) 134 | 135 | preds_size = Variable(torch.IntTensor([preds.size(0)])) 136 | txt = self.converter.decode(preds.data, preds_size.data, raw=False) 137 | 138 | return txt 139 | 140 | 141 | if __name__ == '__main__': 142 | model_path = './crnn_models/CRNN-1008.pth' 143 | recognizer = PytorchOcr(model_path) 144 | img_name = 't1.jpg' 145 | img = cv2.imread(img_name) 146 | h, w = img.shape[:2] 147 | res = recognizer.recognize(img) 148 | print(res) 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /recognize/crnn_recognizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | # import torchvision.models as models 3 | import torch, os 4 | from PIL import Image 5 | import cv2 6 | import torchvision.transforms as transforms 7 | from torch.autograd import Variable 8 | import numpy as np 9 | import random 10 | from recognize.crnn import CRNN 11 | from recognize import config 12 | 13 | # copy from mydataset 14 | class resizeNormalize(object): 15 | def __init__(self, size, interpolation=Image.LANCZOS, is_test=True): 16 | self.size = size 17 | self.interpolation = interpolation 18 | self.toTensor = transforms.ToTensor() 19 | self.is_test = is_test 20 | 21 | def __call__(self, img): 22 | w, h = self.size 23 | w0 = img.size[0] 24 | h0 = img.size[1] 25 | if w <= (w0 / h0 * h): 26 | img = img.resize(self.size, self.interpolation) 27 | img = self.toTensor(img) 28 | img.sub_(0.5).div_(0.5) 29 | else: 30 | w_real = int(w0 / h0 * h) 31 | img = img.resize((w_real, h), self.interpolation) 32 | img = self.toTensor(img) 33 | img.sub_(0.5).div_(0.5) 34 | tmp = torch.zeros([img.shape[0], h, w]) 35 | start = random.randint(0, w - w_real - 1) 36 | if 
self.is_test: 37 | start = 0 38 | tmp[:, :, start:start + w_real] = img 39 | img = tmp 40 | return img 41 | 42 | # copy from utils 43 | class strLabelConverter(object): 44 | def __init__(self, alphabet, ignore_case=False): 45 | self._ignore_case = ignore_case 46 | if self._ignore_case: 47 | alphabet = alphabet.lower() 48 | self.alphabet = alphabet + '_' # for `-1` index 49 | 50 | self.dict = {} 51 | for i, char in enumerate(alphabet): 52 | # NOTE: 0 is reserved for 'blank' required by wrap_ctc 53 | self.dict[char] = i + 1 54 | 55 | # print(self.dict) 56 | def encode(self, text): 57 | length = [] 58 | result = [] 59 | for item in text: 60 | item = item.decode('utf-8', 'strict') 61 | length.append(len(item)) 62 | for char in item: 63 | if char not in self.dict.keys(): 64 | index = 0 65 | else: 66 | index = self.dict[char] 67 | result.append(index) 68 | text = result 69 | return (torch.IntTensor(text), torch.IntTensor(length)) 70 | 71 | def decode(self, t, length, raw=False): 72 | if length.numel() == 1: 73 | length = length[0] 74 | assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), 75 | length) 76 | if raw: 77 | return ''.join([self.alphabet[i - 1] for i in t]) 78 | else: 79 | char_list = [] 80 | for i in range(length): 81 | if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])): 82 | char_list.append(self.alphabet[t[i] - 1]) 83 | return ''.join(char_list) 84 | else: 85 | # batch mode 86 | assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format( 87 | t.numel(), length.sum()) 88 | texts = [] 89 | index = 0 90 | for i in range(length.numel()): 91 | l = length[i] 92 | texts.append( 93 | self.decode( 94 | t[index:index + l], torch.IntTensor([l]), raw=raw)) 95 | index += l 96 | return texts 97 | 98 | # recognize api 99 | class PytorchOcr(): 100 | def __init__(self, model_path='checkpoints/CRNN-1010.pth'): 101 | alphabet_unicode = config.alphabet_v2 102 | self.alphabet = ''.join([chr(uni) for uni in alphabet_unicode]) 103 | # print(len(self.alphabet)) 104 | self.nclass = len(self.alphabet) + 1 105 | self.model = CRNN(config.imgH, 1, self.nclass, 256) 106 | self.cuda = False 107 | if torch.cuda.is_available(): 108 | self.cuda = True 109 | self.model.cuda() 110 | self.model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path).items()}) 111 | else: 112 | # self.model = nn.DataParallel(self.model) 113 | self.model.load_state_dict(torch.load(model_path, map_location='cpu')) 114 | self.model.eval() 115 | self.converter = strLabelConverter(self.alphabet) 116 | 117 | def recognize(self, img): 118 | h,w = img.shape[:2] 119 | if len(img.shape) == 3: 120 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 121 | image = Image.fromarray(img) 122 | transformer = resizeNormalize((int(w/h*32), 32)) 123 | image = transformer(image) 124 | image = image.view(1, *image.size()) 125 | image = Variable(image) 126 | 127 | if self.cuda: 128 | image = image.cuda() 129 | 130 | preds = self.model(image) 131 | 132 | _, preds = preds.max(2) 133 | preds = preds.transpose(1, 0).contiguous().view(-1) 134 | 135 | preds_size = Variable(torch.IntTensor([preds.size(0)])) 136 | txt = self.converter.decode(preds.data, preds_size.data, raw=False).strip() 137 | 138 | return txt 139 | 140 | 141 | if __name__ == '__main__': 142 | model_path = './recognize/crnn_models/CRNN-1008.pth' 143 | recognizer = PytorchOcr(model_path) 144 | img_name = 't1.jpg' 145 | img = cv2.imread(img_name) 146 | h, w = img.shape[:2] 147 | res = 
recognizer.recognize(img) 148 | print(res) 149 | 150 | 151 | 152 | 153 | -------------------------------------------------------------------------------- /train_code/train_ctpn/ctpn_model_v2.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 10:01 AM 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.models as models 12 | 13 | 14 | class RPN_REGR_Loss(nn.Module): 15 | def __init__(self, device, sigma=9.0): 16 | super(RPN_REGR_Loss, self).__init__() 17 | self.sigma = sigma 18 | self.device = device 19 | 20 | def forward(self, input, target): 21 | ''' 22 | smooth L1 loss 23 | :param input: y_preds 24 | :param target: y_true 25 | :return: 26 | ''' 27 | try: 28 | cls = target[0, :, 0] 29 | regr = target[0, :, 1:3] 30 | # apply regression only to positive samples 31 | regr_keep = (cls == 1).nonzero()[:, 0] 32 | regr_true = regr[regr_keep] 33 | regr_pred = input[0][regr_keep] 34 | diff = torch.abs(regr_true - regr_pred) 35 | less_one = (diff<1.0/self.sigma).float() 36 | loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1- less_one) * (diff - 0.5/self.sigma) 37 | loss = torch.sum(loss, 1) 38 | loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0) 39 | except Exception as e: 40 | print('RPN_REGR_Loss Exception:', e) 41 | # print(input, target) 42 | loss = torch.tensor(0.0) 43 | 44 | return loss.to(self.device) 45 | 46 | 47 | class RPN_CLS_Loss(nn.Module): 48 | def __init__(self,device): 49 | super(RPN_CLS_Loss, self).__init__() 50 | self.device = device 51 | 52 | def forward(self, input, target): 53 | y_true = target[0][0] 54 | cls_keep = (y_true != -1).nonzero()[:, 0] 55 | cls_true = y_true[cls_keep].long() 56 | cls_pred = input[0][cls_keep] 57 | loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), cls_true) # original is sparse_softmax_cross_entropy_with_logits 58 | # loss = nn.BCEWithLogitsLoss()(cls_pred[:,0], cls_true.float()) # 18-12-8 59 | loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0) 60 | return loss.to(self.device) 61 | 62 | class RPN_Loss(nn.Module): 63 | def __init__(self,device): 64 | super(RPN_Loss, self).__init__() 65 | self.device = device 66 | self.L_cls = nn.CrossEntropyLoss(reduction='none') 67 | self.L_regr = nn.SmoothL1Loss() 68 | self.L_refi = nn.SmoothL1Loss() 69 | self.pos_neg_ratio = 3 70 | 71 | def forward(self, cls, regr, refi, target_cls, target_regr, target_refi): 72 | 73 | # calculate classification loss 74 | cls_gt = target_cls[0][0] 75 | cls_pos = (cls_gt == 1).nonzero()[:, 0] 76 | gt_pos = cls_gt[cls_pos].long() 77 | cls_pred_pos = cls[0][cls_pos] # the predictions are the `cls` argument (the original indexed the undefined name `input` here) 78 | 79 | cls_neg = (cls_gt == 0).nonzero()[:, 0] 80 | gt_neg = cls_gt[cls_neg].long() 81 | cls_pred_neg = cls[0][cls_neg] # likewise `cls`, not `input` 82 | 83 | loss_pos = self.L_cls(cls_pred_pos.view(-1,2),gt_pos.view(-1)) 84 | loss_neg = self.L_cls(cls_pred_neg.view(-1,2),gt_neg.view(-1)) 85 | loss_neg_topK, _ = torch.topk(loss_neg, min(len(loss_neg), len(loss_pos) * self.pos_neg_ratio)) 86 | loss_cls = loss_pos.mean() + loss_neg_topK.mean() 87 | return loss_cls.to(self.device) 88 | 89 | 90 | 91 | 92 | class basic_conv(nn.Module): 93 | def __init__(self, 94 | in_planes, 95 | out_planes, 96 | kernel_size, 97 | stride=1, 98 | padding=0, 99 | dilation=1, 100 | groups=1, 101 | relu=True, 102 | bn=True, 103 | bias=True): 104 | super(basic_conv, self).__init__() 105 | self.out_channels =
out_planes 106 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 107 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 108 | self.relu = nn.ReLU(inplace=True) if relu else None 109 | 110 | def forward(self, x): 111 | x = self.conv(x) 112 | if self.bn is not None: 113 | x = self.bn(x) 114 | if self.relu is not None: 115 | x = self.relu(x) 116 | return x 117 | 118 | 119 | class CTPN_Model(nn.Module): 120 | def __init__(self): 121 | super().__init__() 122 | base_model = models.vgg16(pretrained=False) 123 | layers = list(base_model.features)[:-1] 124 | self.base_layers = nn.Sequential(*layers) # block5_conv3 output 125 | self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False) 126 | self.brnn = nn.GRU(512,128, bidirectional=True, batch_first=True) 127 | self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False) 128 | self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) 129 | self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) 130 | self.rpn_refiment = basic_conv(512, 10, 1, 1, relu=False, bn=False) 131 | 132 | def forward(self, x): 133 | x = self.base_layers(x) 134 | # rpn 135 | x = self.rpn(x) #[b, c, h, w] 136 | 137 | x1 = x.permute(0,2,3,1).contiguous() # channels last [b, h, w, c] 138 | b = x1.size() # b, h, w, c 139 | x1 = x1.view(b[0]*b[1], b[2], b[3]) 140 | 141 | x2, _ = self.brnn(x1) 142 | 143 | xsz = x.size() 144 | x3 = x2.view(xsz[0], xsz[2], xsz[3], 256) # torch.Size([4, 20, 20, 256]) 145 | 146 | x3 = x3.permute(0,3,1,2).contiguous() # channels first [b, c, h, w] 147 | x3 = self.lstm_fc(x3) 148 | x = x3 149 | 150 | cls = self.rpn_class(x) 151 | regr = self.rpn_regress(x) 152 | refi = self.rpn_refiment(x) 153 | 154 | cls = cls.permute(0,2,3,1).contiguous() # [b,h,w,c] 155 | regr = regr.permute(0,2,3,1).contiguous() 156 | refi = refi.permute(0,2,3,1).contiguous() 157 | 158 | cls = cls.view(cls.size(0), cls.size(1)*cls.size(2)*10, 2) 159 | regr = regr.view(regr.size(0), regr.size(1)*regr.size(2)*10, 2) 160 | refi = refi.view(refi.size(0),refi.size(1)*refi.size(2)*10,1) 161 | 162 | return cls, regr, refi 163 | -------------------------------------------------------------------------------- /train_code/train_crnn/trans_utils.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import shutil 3 | from PIL import Image, ImageDraw, ImageFont, ImageChops 4 | import cv2 5 | import numpy as np 6 | # import pyblur 7 | import PIL 8 | from PIL import Image, ImageEnhance 9 | # import 10 | import abc 11 | import time, datetime, inspect 12 | import hashlib 13 | import json 14 | import math 15 | 16 | 17 | def rename(filepath): 18 | # print(f'rename {filepath} to 00X') 19 | filelist = os.listdir(filepath) 20 | filelist.sort() 21 | i = 1 22 | for filename in filelist: 23 | if str(filename) == '.DS_Store': 24 | continue 25 | ext = filename.split('.')[-1] 26 | shutil.move(filepath + '/' + filename, filepath + '/' + str(i).zfill(3) + '.' 
+ ext) 27 | i += 1 28 | 29 | 30 | def zlog(func): 31 | def new_fn(*args): 32 | start = time.time() 33 | result = func(*args) 34 | end = time.time() 35 | duration = end - start 36 | duration = "%.4f" % duration 37 | # fulltime = time.strftime("%Y-%m-%d %H:%M:%S %f", time.localtime()) 38 | fulltime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f') 39 | # print(f'{fulltime} {__file__} {func.__name__}:{inspect.getsourcelines(func)[-1]} cost: {duration}s', 40 | # sep=' ', end='\n', file=sys.stdout, flush=False) 41 | return result 42 | 43 | return new_fn 44 | 45 | 46 | def getpilimage(image): 47 | if isinstance(image, PIL.Image.Image): # or isinstance(image, PIL.JpegImagePlugin.JpegImageFile): 48 | return image 49 | elif isinstance(image, np.ndarray): 50 | return cv2pil(image) 51 | 52 | 53 | def getcvimage(image): 54 | if isinstance(image, np.ndarray): 55 | return image 56 | elif isinstance(image, PIL.Image.Image): # or isinstance(image, PIL.JpegImagePlugin.JpegImageFile): 57 | return pil2cv(image) 58 | 59 | 60 | def cshowone(image): 61 | image = getcvimage(image) 62 | cv2.imshow('tmp', image) 63 | cv2.waitKey(3000) 64 | return 65 | 66 | 67 | def cshowtwo(image1, image2): 68 | width = int(800 / 2) # cast to int: PIL sizes and paste boxes must be integers (as in pshowtwo below) 69 | height = int(500 / 2) 70 | image1 = getpilimage(image1) 71 | image2 = getpilimage(image2) 72 | h, w = image1.size # NB: PIL `size` is (width, height) 73 | image1 = image1.resize((int(width), int(h * height / w))) 74 | image2 = image2.resize(image1.size) 75 | bigimg = Image.new('RGB', (width * 2, image1.size[1])) 76 | 77 | bigimg.paste(image1, (0, 0, image1.size[0], image1.size[1])) 78 | bigimg.paste(image2, (width, 0, width + image1.size[0], image1.size[1])) 79 | bigimg = getcvimage(bigimg) 80 | cshowone(bigimg) 81 | return 82 | 83 | 84 | def pshowone(image): 85 | image = getpilimage(image) 86 | image.show() 87 | return 88 | 89 | 90 | def pshowtwo(image1, image2): 91 | width = int(800 / 2) 92 | height = int(500 / 2) 93 | image1 = getpilimage(image1) 94 | image2 = getpilimage(image2) 95 | h, w = image1.size # NB: PIL `size` is (width, height) 96 | image1 = image1.resize((int(width), int(h * height / w))) 97 | image2 = image2.resize(image1.size) 98 | bigimg = Image.new('RGB', (width * 2, image1.size[1])) 99 | 100 | bigimg.paste(image1, (0, 0, image1.size[0], image1.size[1])) 101 | bigimg.paste(image2, (width, 0, width + image1.size[0], image1.size[1])) 102 | pshowone(bigimg) 103 | return 104 | 105 | 106 | def pil2cv(image): 107 | # assert isinstance(image, PIL.Image.Image) or isinstance(image, 108 | # PIL.JpegImagePlugin.JpegImageFile), f'input image type is not PIL.image and is {type( 109 | # image)}' 110 | if len(image.split()) == 1: 111 | return np.asarray(image) 112 | elif len(image.split()) == 3: 113 | return cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) 114 | elif len(image.split()) == 4: 115 | return cv2.cvtColor(np.asarray(image), cv2.COLOR_RGBA2BGR) 116 | 117 | 118 | def cv2pil(image): 119 | assert isinstance(image, np.ndarray), 'input image type is not cv2' 120 | if len(image.shape) == 2: 121 | return Image.fromarray(image) 122 | elif len(image.shape) == 3: 123 | return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) 124 | 125 | 126 | def rgb2gray(filename): 127 | im = Image.open(filename).convert('L') 128 | im.show() 129 | 130 | new_image = Image.new("L", (im.width + 6, im.height + 6), 0) 131 | out_image = Image.new("L", (im.width + 6, im.height + 6), 0) 132 | 133 | new_image.paste(im, (3, 3, im.width + 3, im.height + 3)) 134 | 135 | im = getcvimage(im) 136 | new_image = getcvimage(new_image) 137 | out_image = getcvimage(out_image) 138 | 139 | _,
thresh = cv2.threshold(new_image, 0, 255, cv2.THRESH_OTSU) 140 | pshowone(thresh) 141 | contours, hierarchy = cv2.findContours(thresh, 3, 2)[-2:] # [-2:] keeps this working on OpenCV 3 (3 return values) and OpenCV 4 (2 return values) 142 | # cnt = contours[0] 143 | # hull = cv2.convexHull(cnt) 144 | # image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR) 145 | print(len(contours)) 146 | cv2.polylines(out_image, contours, True, 255) 147 | # cv2.fillPoly(image, [cnt], 255) 148 | image = getpilimage(out_image) 149 | im = getpilimage(im) 150 | image = image.crop((3, 3, im.width + 3, im.height + 3)) 151 | # char_color = image.crop((3,3,char_image.width + 3, char_image.height + 3)) 152 | image.show() 153 | return 154 | 155 | 156 | def uniqueimg(filepath): 157 | # print(f'unique {filepath}') 158 | filepath += '/' 159 | filelist = os.listdir(filepath) 160 | filelist.sort() 161 | i = 1 162 | for filename in filelist: 163 | if str(filename) == '.DS_Store': 164 | continue 165 | fd = np.array(Image.open(filepath + filename)) 166 | fmd5 = hashlib.md5(fd) 167 | # print(fmd5.hexdigest()) 168 | # print(filename) 169 | ext = filename.split('.')[-1] 170 | shutil.move(filepath + filename, filepath + fmd5.hexdigest() + '.' + ext) 171 | # i += 1 172 | 173 | 174 | if __name__ == '__main__': 175 | # print(sys.argv) 176 | # rename(sys.argv[1]) 177 | # uniqueimg('/Users/ganyufei/temp/') 178 | # allimg = getlabeljson('/Users/ganyufei/Desktop/jiu_zheng/jiu_zheng.json') # getlabeljson is not defined in this file; kept commented for reference 179 | # print(cal_sim_all(allimg['20190113_092023.jpg'], allimg['NID 7333475056 (1) Front.jpg'])) # cal_sim_all is not defined in this file either 180 | # print(cal_sim_all(allimg['20190113_092023.jpg'], allimg['20190113_092023.jpg'])) 181 | 182 | # genpair(12) 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | -------------------------------------------------------------------------------- /train_code/train_ctpn/ctpn_model.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 18-12-11 10:01 AM 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torchvision.models as models 12 | import config 13 | 14 | class RPN_REGR_Loss(nn.Module): 15 | def __init__(self, device, sigma=9.0): 16 | super(RPN_REGR_Loss, self).__init__() 17 | self.sigma = sigma 18 | self.device = device 19 | 20 | def forward(self, input, target): 21 | ''' 22 | smooth L1 loss 23 | :param input: y_preds 24 | :param target: y_true 25 | :return: 26 | ''' 27 | try: 28 | cls = target[0, :, 0] 29 | regr = target[0, :, 1:3] 30 | # apply regression only to positive samples 31 | regr_keep = (cls == 1).nonzero()[:, 0] 32 | regr_true = regr[regr_keep] 33 | regr_pred = input[0][regr_keep] 34 | diff = torch.abs(regr_true - regr_pred) 35 | less_one = (diff<1.0/self.sigma).float() 36 | loss = less_one * 0.5 * diff ** 2 * self.sigma + torch.abs(1- less_one) * (diff - 0.5/self.sigma) 37 | loss = torch.sum(loss, 1) 38 | loss = torch.mean(loss) if loss.numel() > 0 else torch.tensor(0.0) 39 | except Exception as e: 40 | print('RPN_REGR_Loss Exception:', e) 41 | # print(input, target) 42 | loss = torch.tensor(0.0) 43 | 44 | return loss.to(self.device) 45 | 46 | 47 | class RPN_CLS_Loss(nn.Module): 48 | def __init__(self,device): 49 | super(RPN_CLS_Loss, self).__init__() 50 | self.device = device 51 | self.L_cls = nn.CrossEntropyLoss(reduction='none') 52 | # self.L_regr = nn.SmoothL1Loss() 53 | # self.L_refi = nn.SmoothL1Loss() 54 | self.pos_neg_ratio = 3 55 | 56 | def forward(self, input, target): 57 | if config.OHEM: 58 | cls_gt = target[0][0] 59 | num_pos = 0 60 |
loss_pos_sum = 0 61 | 62 | # print(len((cls_gt == 0).nonzero()),len((cls_gt == 1).nonzero())) 63 | 64 | if len((cls_gt == 1).nonzero())!=0: # avoid num of pos sample is 0 65 | cls_pos = (cls_gt == 1).nonzero()[:, 0] 66 | gt_pos = cls_gt[cls_pos].long() 67 | cls_pred_pos = input[0][cls_pos] 68 | # print(cls_pred_pos.shape) 69 | loss_pos = self.L_cls(cls_pred_pos.view(-1, 2), gt_pos.view(-1)) 70 | loss_pos_sum = loss_pos.sum() 71 | num_pos = len(loss_pos) 72 | 73 | cls_neg = (cls_gt == 0).nonzero()[:, 0] 74 | gt_neg = cls_gt[cls_neg].long() 75 | cls_pred_neg = input[0][cls_neg] 76 | 77 | loss_neg = self.L_cls(cls_pred_neg.view(-1, 2), gt_neg.view(-1)) 78 | loss_neg_topK, _ = torch.topk(loss_neg, min(len(loss_neg), config.RPN_TOTAL_NUM-num_pos)) 79 | loss_cls = loss_pos_sum+loss_neg_topK.sum() 80 | loss_cls = loss_cls/config.RPN_TOTAL_NUM 81 | return loss_cls.to(self.device) 82 | else: 83 | y_true = target[0][0] 84 | cls_keep = (y_true != -1).nonzero()[:, 0] 85 | cls_true = y_true[cls_keep].long() 86 | cls_pred = input[0][cls_keep] 87 | loss = F.nll_loss(F.log_softmax(cls_pred, dim=-1), 88 | cls_true) # original is sparse_softmax_cross_entropy_with_logits 89 | # loss = nn.BCEWithLogitsLoss()(cls_pred[:,0], cls_true.float()) # 18-12-8 90 | loss = torch.clamp(torch.mean(loss), 0, 10) if loss.numel() > 0 else torch.tensor(0.0) 91 | return loss.to(self.device) 92 | 93 | 94 | class basic_conv(nn.Module): 95 | def __init__(self, 96 | in_planes, 97 | out_planes, 98 | kernel_size, 99 | stride=1, 100 | padding=0, 101 | dilation=1, 102 | groups=1, 103 | relu=True, 104 | bn=True, 105 | bias=True): 106 | super(basic_conv, self).__init__() 107 | self.out_channels = out_planes 108 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 109 | self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None 110 | self.relu = nn.ReLU(inplace=True) if relu else None 111 | 112 | def forward(self, x): 113 | x = self.conv(x) 114 | if self.bn is not None: 115 | x = self.bn(x) 116 | if self.relu is not None: 117 | x = self.relu(x) 118 | return x 119 | 120 | 121 | class CTPN_Model(nn.Module): 122 | def __init__(self): 123 | super().__init__() 124 | base_model = models.vgg16(pretrained=False) 125 | layers = list(base_model.features)[:-1] 126 | self.base_layers = nn.Sequential(*layers) # block5_conv3 output 127 | self.rpn = basic_conv(512, 512, 3, 1, 1, bn=False) 128 | self.brnn = nn.GRU(512,128, bidirectional=True, batch_first=True) 129 | self.lstm_fc = basic_conv(256, 512, 1, 1, relu=True, bn=False) 130 | self.rpn_class = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) 131 | self.rpn_regress = basic_conv(512, 10 * 2, 1, 1, relu=False, bn=False) 132 | 133 | def forward(self, x): 134 | x = self.base_layers(x) 135 | # rpn 136 | x = self.rpn(x) #[b, c, h, w] 137 | 138 | x1 = x.permute(0,2,3,1).contiguous() # channels last [b, h, w, c] 139 | b = x1.size() # b, h, w, c 140 | x1 = x1.view(b[0]*b[1], b[2], b[3]) 141 | 142 | x2, _ = self.brnn(x1) 143 | 144 | xsz = x.size() 145 | x3 = x2.view(xsz[0], xsz[2], xsz[3], 256) # torch.Size([4, 20, 20, 256]) 146 | 147 | x3 = x3.permute(0,3,1,2).contiguous() # channels first [b, c, h, w] 148 | x3 = self.lstm_fc(x3) 149 | x = x3 150 | 151 | cls = self.rpn_class(x) 152 | regr = self.rpn_regress(x) 153 | 154 | cls = cls.permute(0,2,3,1).contiguous() 155 | regr = regr.permute(0,2,3,1).contiguous() 156 | 157 | cls = cls.view(cls.size(0), cls.size(1)*cls.size(2)*10, 2) 
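# annotation: each position of the stride-16 feature map carries 10 anchors, so the
# class head's [b, h, w, 20] output is flattened here to [b, h*w*10, 2] text/non-text
# logits per anchor, and the regression head below to [b, h*w*10, 2] (vertical center
# offset and height per anchor). As an illustrative example, a 640x640 input gives
# h = w = 40 after the VGG backbone, i.e. 40*40*10 = 16000 anchor proposals per image.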
158 | regr = regr.view(regr.size(0), regr.size(1)*regr.size(2)*10, 2) 159 | 160 | return cls, regr 161 | -------------------------------------------------------------------------------- /train_code/train_crnn/train_warp_ctc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random 4 | import torch 5 | # import torch.backends.cudnn as cudnn 6 | import torch.optim as optim 7 | import torch.utils.data 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from warpctc_pytorch import CTCLoss 11 | # from torch.nn import CTCLoss 12 | import utils 13 | import mydataset 14 | import crnn as crnn 15 | import config 16 | from online_test import val_model 17 | config.imgW = 800 18 | config.alphabet = config.alphabet_v2 19 | config.nclass = len(config.alphabet) + 1 20 | config.saved_model_prefix = 'CRNN-1010' 21 | config.train_infofile = ['path_to_train_infofile1.txt','path_to_train_infofile2.txt'] 22 | config.val_infofile = 'path_to_test_infofile.txt' 23 | config.keep_ratio = True 24 | config.use_log = True 25 | config.pretrained_model = 'path_to_your_pretrained_model.pth' 26 | config.batchSize = 80 27 | config.workers = 10 28 | config.adam = True 29 | # config.lr = 0.00003 30 | import os 31 | import datetime 32 | 33 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 34 | log_filename = os.path.join('log/','loss_acc-'+config.saved_model_prefix + '.log') 35 | if not os.path.exists('debug_files'): 36 | os.mkdir('debug_files') 37 | if not os.path.exists(config.saved_model_dir): 38 | os.mkdir(config.saved_model_dir) 39 | if config.use_log and not os.path.exists('log'): 40 | os.mkdir('log') 41 | if config.use_log and os.path.exists(log_filename): 42 | os.remove(log_filename) 43 | if config.experiment is None: 44 | config.experiment = 'expr' 45 | if not os.path.exists(config.experiment): 46 | os.mkdir(config.experiment) 47 | 48 | config.manualSeed = random.randint(1, 10000) # fix seed 49 | print("Random Seed: ", config.manualSeed) 50 | random.seed(config.manualSeed) 51 | np.random.seed(config.manualSeed) 52 | torch.manual_seed(config.manualSeed) 53 | 54 | # cudnn.benchmark = True 55 | train_dataset = mydataset.MyDataset(info_filename=config.train_infofile) 56 | assert train_dataset 57 | if not config.random_sample: 58 | sampler = mydataset.randomSequentialSampler(train_dataset, config.batchSize) 59 | else: 60 | sampler = None 61 | train_loader = torch.utils.data.DataLoader( 62 | train_dataset, batch_size=config.batchSize, 63 | shuffle=True, sampler=sampler, 64 | num_workers=int(config.workers), 65 | collate_fn=mydataset.alignCollate(imgH=config.imgH, imgW=config.imgW, keep_ratio=config.keep_ratio)) 66 | 67 | test_dataset = mydataset.MyDataset( 68 | info_filename=config.val_infofile, transform=mydataset.resizeNormalize((config.imgW, config.imgH), is_test=True)) 69 | 70 | converter = utils.strLabelConverter(config.alphabet) 71 | # criterion = CTCLoss(reduction='sum',zero_infinity=True) 72 | criterion = CTCLoss() 73 | best_acc = 0.9 74 | 75 | 76 | # custom weights initialization called on crnn 77 | def weights_init(m): 78 | classname = m.__class__.__name__ 79 | if classname.find('Conv') != -1: 80 | m.weight.data.normal_(0.0, 0.02) 81 | elif classname.find('BatchNorm') != -1: 82 | m.weight.data.normal_(1.0, 0.02) 83 | m.bias.data.fill_(0) 84 | 85 | 86 | crnn = crnn.CRNN(config.imgH, config.nc, config.nclass, config.nh) 87 | if config.pretrained_model!='' and os.path.exists(config.pretrained_model): 88 
| print('loading pretrained model from %s' % config.pretrained_model) 89 | crnn.load_state_dict(torch.load(config.pretrained_model)) 90 | else: 91 | crnn.apply(weights_init) 92 | 93 | print(crnn) 94 | 95 | # image = torch.FloatTensor(config.batchSize, 3, config.imgH, config.imgH) 96 | # text = torch.IntTensor(config.batchSize * 5) 97 | # length = torch.IntTensor(config.batchSize) 98 | device = torch.device('cpu') 99 | if config.cuda: 100 | crnn.cuda() 101 | # crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu)) 102 | # image = image.cuda() 103 | device = torch.device('cuda:0') 104 | criterion = criterion.cuda() 105 | 106 | # image = Variable(image) 107 | # text = Variable(text) 108 | # length = Variable(length) 109 | 110 | # loss averager 111 | loss_avg = utils.averager() 112 | 113 | # setup optimizer 114 | if config.adam: 115 | optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999)) 116 | elif config.adadelta: 117 | optimizer = optim.Adadelta(crnn.parameters(), lr=config.lr) 118 | else: 119 | optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr) 120 | 121 | 122 | def val(net, dataset, criterion, max_iter=100): 123 | print('Start val') 124 | for p in net.parameters(): 125 | p.requires_grad = False 126 | 127 | num_correct, num_all = val_model(config.val_infofile,net,True,log_file='compare-'+config.saved_model_prefix+'.log') 128 | accuracy = num_correct / num_all 129 | 130 | print('ocr_acc: %f' % (accuracy)) 131 | if config.use_log: 132 | with open(log_filename, 'a') as f: 133 | f.write('ocr_acc:{}\n'.format(accuracy)) 134 | global best_acc 135 | if accuracy > best_acc: 136 | best_acc = accuracy 137 | torch.save(crnn.state_dict(), '{}/{}_{}_{}.pth'.format(config.saved_model_dir, config.saved_model_prefix, epoch, 138 | int(best_acc * 1000))) 139 | torch.save(crnn.state_dict(), '{}/{}.pth'.format(config.saved_model_dir, config.saved_model_prefix)) 140 | 141 | 142 | def trainBatch(net, criterion, optimizer): 143 | data = next(train_iter) # `train_iter.next()` only worked on older PyTorch loader iterators; the built-in next() is portable 144 | cpu_images, cpu_texts = data 145 | batch_size = cpu_images.size(0) 146 | image = cpu_images.to(device) 147 | 148 | text, length = converter.encode(cpu_texts) 149 | # utils.loadData(text, t) 150 | # utils.loadData(length, l) 151 | 152 | preds = net(image) # seqLength x batchSize x alphabet_size 153 | preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) # seqLength x batchSize 154 | cost = criterion(preds, text, preds_size, length) / batch_size 155 | if torch.isnan(cost): 156 | print(batch_size,cpu_texts) 157 | else: 158 | net.zero_grad() 159 | cost.backward() 160 | optimizer.step() 161 | return cost 162 | 163 | 164 | for epoch in range(config.niter): 165 | loss_avg.reset() 166 | print('epoch {}....'.format(epoch)) 167 | train_iter = iter(train_loader) 168 | i = 0 169 | n_batch = len(train_loader) 170 | while i < len(train_loader): 171 | for p in crnn.parameters(): 172 | p.requires_grad = True 173 | crnn.train() 174 | cost = trainBatch(crnn, criterion, optimizer) 175 | print('epoch: {} iter: {}/{} Train loss: {:.3f}'.format(epoch, i, n_batch, cost.item())) 176 | loss_avg.add(cost) 177 | # (the original called loss_avg.add(cost) twice in a row; one update per batch is enough) 178 | i += 1 179 | print('Train loss: %f' % (loss_avg.val())) 180 | if config.use_log: 181 | with open(log_filename, 'a') as f: 182 | f.write('{}\n'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'))) 183 | f.write('train_loss:{}\n'.format(loss_avg.val())) 184 | 185 | val(crnn, test_dataset, criterion) 186 | 187 | 188 |
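Note on the three CRNN training scripts: train_warp_ctc_v2.py below is identical to this file, and train_pytorch_ctc.py differs only in its criterion setup and loss call — it uses torch.nn.CTCLoss(reduction='sum', zero_infinity=True) and feeds it log-probabilities via preds.log_softmax(2).cpu() instead of raw activations, since warp-ctc applies softmax internally while the native loss expects log-probabilities. For reference, a minimal self-contained sketch of the shapes the native criterion expects; the values of T, B, C and the target length 10 are illustrative, not taken from the repo:

import torch
from torch.nn import CTCLoss

T, B, C = 200, 4, 6736                                    # time steps, batch size, nclass (alphabet + blank)
log_probs = torch.randn(T, B, C).log_softmax(2)           # CRNN output: seqLength x batchSize x nclass
targets = torch.randint(1, C, (B, 10), dtype=torch.long)  # label indices; 0 is reserved for the CTC blank
input_lengths = torch.full((B,), T, dtype=torch.long)     # plays the role of preds_size in trainBatch
target_lengths = torch.full((B,), 10, dtype=torch.long)   # plays the role of length from converter.encode
cost = CTCLoss(reduction='sum', zero_infinity=True)(log_probs, targets, input_lengths, target_lengths) / B
print(cost.item())                                        # per-sample CTC loss, as averaged in the training loop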
-------------------------------------------------------------------------------- /train_code/train_crnn/train_warp_ctc_v2.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random 4 | import torch 5 | # import torch.backends.cudnn as cudnn 6 | import torch.optim as optim 7 | import torch.utils.data 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from warpctc_pytorch import CTCLoss 11 | # from torch.nn import CTCLoss 12 | import utils 13 | import mydataset 14 | import crnn as crnn 15 | import config 16 | from online_test import val_model 17 | config.imgW = 800 18 | config.alphabet = config.alphabet_v2 19 | config.nclass = len(config.alphabet) + 1 20 | config.saved_model_prefix = 'CRNN-1010' 21 | config.train_infofile = ['path_to_train_infofile1.txt','path_to_train_infofile2.txt'] 22 | config.val_infofile = 'path_to_test_infofile.txt' 23 | config.keep_ratio = True 24 | config.use_log = True 25 | config.pretrained_model = 'path_to_your_pretrained_model.pth' 26 | config.batchSize = 80 27 | config.workers = 10 28 | config.adam = True 29 | # config.lr = 0.00003 30 | import os 31 | import datetime 32 | 33 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 34 | log_filename = os.path.join('log/','loss_acc-'+config.saved_model_prefix + '.log') 35 | if not os.path.exists('debug_files'): 36 | os.mkdir('debug_files') 37 | if not os.path.exists(config.saved_model_dir): 38 | os.mkdir(config.saved_model_dir) 39 | if config.use_log and not os.path.exists('log'): 40 | os.mkdir('log') 41 | if config.use_log and os.path.exists(log_filename): 42 | os.remove(log_filename) 43 | if config.experiment is None: 44 | config.experiment = 'expr' 45 | if not os.path.exists(config.experiment): 46 | os.mkdir(config.experiment) 47 | 48 | config.manualSeed = random.randint(1, 10000) # fix seed 49 | print("Random Seed: ", config.manualSeed) 50 | random.seed(config.manualSeed) 51 | np.random.seed(config.manualSeed) 52 | torch.manual_seed(config.manualSeed) 53 | 54 | # cudnn.benchmark = True 55 | train_dataset = mydataset.MyDataset(info_filename=config.train_infofile) 56 | assert train_dataset 57 | if not config.random_sample: 58 | sampler = mydataset.randomSequentialSampler(train_dataset, config.batchSize) 59 | else: 60 | sampler = None 61 | train_loader = torch.utils.data.DataLoader( 62 | train_dataset, batch_size=config.batchSize, 63 | shuffle=True, sampler=sampler, 64 | num_workers=int(config.workers), 65 | collate_fn=mydataset.alignCollate(imgH=config.imgH, imgW=config.imgW, keep_ratio=config.keep_ratio)) 66 | 67 | test_dataset = mydataset.MyDataset( 68 | info_filename=config.val_infofile, transform=mydataset.resizeNormalize((config.imgW, config.imgH), is_test=True)) 69 | 70 | converter = utils.strLabelConverter(config.alphabet) 71 | # criterion = CTCLoss(reduction='sum',zero_infinity=True) 72 | criterion = CTCLoss() 73 | best_acc = 0.9 74 | 75 | 76 | # custom weights initialization called on crnn 77 | def weights_init(m): 78 | classname = m.__class__.__name__ 79 | if classname.find('Conv') != -1: 80 | m.weight.data.normal_(0.0, 0.02) 81 | elif classname.find('BatchNorm') != -1: 82 | m.weight.data.normal_(1.0, 0.02) 83 | m.bias.data.fill_(0) 84 | 85 | 86 | crnn = crnn.CRNN(config.imgH, config.nc, config.nclass, config.nh) 87 | if config.pretrained_model!='' and os.path.exists(config.pretrained_model): 88 | print('loading pretrained model from %s' % config.pretrained_model) 89 | 
crnn.load_state_dict(torch.load(config.pretrained_model)) 90 | else: 91 | crnn.apply(weights_init) 92 | 93 | print(crnn) 94 | 95 | # image = torch.FloatTensor(config.batchSize, 3, config.imgH, config.imgH) 96 | # text = torch.IntTensor(config.batchSize * 5) 97 | # length = torch.IntTensor(config.batchSize) 98 | device = torch.device('cpu') 99 | if config.cuda: 100 | crnn.cuda() 101 | # crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu)) 102 | # image = image.cuda() 103 | device = torch.device('cuda:0') 104 | criterion = criterion.cuda() 105 | 106 | # image = Variable(image) 107 | # text = Variable(text) 108 | # length = Variable(length) 109 | 110 | # loss averager 111 | loss_avg = utils.averager() 112 | 113 | # setup optimizer 114 | if config.adam: 115 | optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999)) 116 | elif config.adadelta: 117 | optimizer = optim.Adadelta(crnn.parameters(), lr=config.lr) 118 | else: 119 | optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr) 120 | 121 | 122 | def val(net, dataset, criterion, max_iter=100): 123 | print('Start val') 124 | for p in net.parameters(): 125 | p.requires_grad = False 126 | 127 | num_correct, num_all = val_model(config.val_infofile,net,True,log_file='compare-'+config.saved_model_prefix+'.log') 128 | accuracy = num_correct / num_all 129 | 130 | print('ocr_acc: %f' % (accuracy)) 131 | if config.use_log: 132 | with open(log_filename, 'a') as f: 133 | f.write('ocr_acc:{}\n'.format(accuracy)) 134 | global best_acc 135 | if accuracy > best_acc: 136 | best_acc = accuracy 137 | torch.save(crnn.state_dict(), '{}/{}_{}_{}.pth'.format(config.saved_model_dir, config.saved_model_prefix, epoch, 138 | int(best_acc * 1000))) 139 | torch.save(crnn.state_dict(), '{}/{}.pth'.format(config.saved_model_dir, config.saved_model_prefix)) 140 | 141 | 142 | def trainBatch(net, criterion, optimizer): 143 | data = train_iter.next() 144 | cpu_images, cpu_texts = data 145 | batch_size = cpu_images.size(0) 146 | image = cpu_images.to(device) 147 | 148 | text, length = converter.encode(cpu_texts) 149 | # utils.loadData(text, t) 150 | # utils.loadData(length, l) 151 | 152 | preds = net(image) # seqLength x batchSize x alphabet_size 153 | preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) # seqLength x batchSize 154 | cost = criterion(preds, text, preds_size, length) / batch_size 155 | if torch.isnan(cost): 156 | print(batch_size,cpu_texts) 157 | else: 158 | net.zero_grad() 159 | cost.backward() 160 | optimizer.step() 161 | return cost 162 | 163 | 164 | for epoch in range(config.niter): 165 | loss_avg.reset() 166 | print('epoch {}....'.format(epoch)) 167 | train_iter = iter(train_loader) 168 | i = 0 169 | n_batch = len(train_loader) 170 | while i < len(train_loader): 171 | for p in crnn.parameters(): 172 | p.requires_grad = True 173 | crnn.train() 174 | cost = trainBatch(crnn, criterion, optimizer) 175 | print('epoch: {} iter: {}/{} Train loss: {:.3f}'.format(epoch, i, n_batch, cost.item())) 176 | loss_avg.add(cost) 177 | loss_avg.add(cost) 178 | i += 1 179 | print('Train loss: %f' % (loss_avg.val())) 180 | if config.use_log: 181 | with open(log_filename, 'a') as f: 182 | f.write('{}\n'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'))) 183 | f.write('train_loss:{}\n'.format(loss_avg.val())) 184 | 185 | val(crnn, test_dataset, criterion) 186 | 187 | 188 | -------------------------------------------------------------------------------- /train_code/train_crnn/train_pytorch_ctc.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import random 4 | import torch 5 | # import torch.backends.cudnn as cudnn 6 | import torch.optim as optim 7 | import torch.utils.data 8 | from torch.autograd import Variable 9 | import numpy as np 10 | # from warpctc_pytorch import CTCLoss 11 | from torch.nn import CTCLoss 12 | import utils 13 | import mydataset 14 | import crnn as crnn 15 | import config 16 | from online_test import val_model 17 | config.imgW = 800 18 | config.alphabet = config.alphabet_v2 19 | config.nclass = len(config.alphabet) + 1 20 | config.saved_model_prefix = 'CRNN-1010' 21 | config.train_infofile = ['path_to_train_infofile1.txt','path_to_train_infofile2.txt'] 22 | config.val_infofile = 'path_to_test_infofile.txt' 23 | config.keep_ratio = True 24 | config.use_log = True 25 | config.pretrained_model = 'path_to_your_pretrained_model.pth' 26 | config.batchSize = 80 27 | config.workers = 10 28 | config.adam = True 29 | # config.lr = 0.00003 30 | import os 31 | import datetime 32 | 33 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 34 | log_filename = os.path.join('log/','loss_acc-'+config.saved_model_prefix + '.log') 35 | if not os.path.exists('debug_files'): 36 | os.mkdir('debug_files') 37 | if not os.path.exists(config.saved_model_dir): 38 | os.mkdir(config.saved_model_dir) 39 | if config.use_log and not os.path.exists('log'): 40 | os.mkdir('log') 41 | if config.use_log and os.path.exists(log_filename): 42 | os.remove(log_filename) 43 | if config.experiment is None: 44 | config.experiment = 'expr' 45 | if not os.path.exists(config.experiment): 46 | os.mkdir(config.experiment) 47 | 48 | config.manualSeed = random.randint(1, 10000) # fix seed 49 | print("Random Seed: ", config.manualSeed) 50 | random.seed(config.manualSeed) 51 | np.random.seed(config.manualSeed) 52 | torch.manual_seed(config.manualSeed) 53 | 54 | # cudnn.benchmark = True 55 | train_dataset = mydataset.MyDataset(info_filename=config.train_infofile) 56 | assert train_dataset 57 | if not config.random_sample: 58 | sampler = mydataset.randomSequentialSampler(train_dataset, config.batchSize) 59 | else: 60 | sampler = None 61 | train_loader = torch.utils.data.DataLoader( 62 | train_dataset, batch_size=config.batchSize, 63 | shuffle=True, sampler=sampler, 64 | num_workers=int(config.workers), 65 | collate_fn=mydataset.alignCollate(imgH=config.imgH, imgW=config.imgW, keep_ratio=config.keep_ratio)) 66 | 67 | test_dataset = mydataset.MyDataset( 68 | info_filename=config.val_infofile, transform=mydataset.resizeNormalize((config.imgW, config.imgH), is_test=True)) 69 | 70 | converter = utils.strLabelConverter(config.alphabet) 71 | criterion = CTCLoss(reduction='sum',zero_infinity=True) 72 | # criterion = CTCLoss() 73 | best_acc = 0.9 74 | 75 | 76 | # custom weights initialization called on crnn 77 | def weights_init(m): 78 | classname = m.__class__.__name__ 79 | if classname.find('Conv') != -1: 80 | m.weight.data.normal_(0.0, 0.02) 81 | elif classname.find('BatchNorm') != -1: 82 | m.weight.data.normal_(1.0, 0.02) 83 | m.bias.data.fill_(0) 84 | 85 | 86 | crnn = crnn.CRNN(config.imgH, config.nc, config.nclass, config.nh) 87 | if config.pretrained_model!='' and os.path.exists(config.pretrained_model): 88 | print('loading pretrained model from %s' % config.pretrained_model) 89 | crnn.load_state_dict(torch.load(config.pretrained_model)) 90 | else: 91 | crnn.apply(weights_init) 92 | 93 | print(crnn) 94 | 95 | # image = 
torch.FloatTensor(config.batchSize, 3, config.imgH, config.imgH) 96 | # text = torch.IntTensor(config.batchSize * 5) 97 | # length = torch.IntTensor(config.batchSize) 98 | device = torch.device('cpu') 99 | if config.cuda: 100 | crnn.cuda() 101 | # crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu)) 102 | # image = image.cuda() 103 | device = torch.device('cuda:0') 104 | criterion = criterion.cuda() 105 | 106 | # image = Variable(image) 107 | # text = Variable(text) 108 | # length = Variable(length) 109 | 110 | # loss averager 111 | loss_avg = utils.averager() 112 | 113 | # setup optimizer 114 | if config.adam: 115 | optimizer = optim.Adam(crnn.parameters(), lr=config.lr, betas=(config.beta1, 0.999)) 116 | elif config.adadelta: 117 | optimizer = optim.Adadelta(crnn.parameters(), lr=config.lr) 118 | else: 119 | optimizer = optim.RMSprop(crnn.parameters(), lr=config.lr) 120 | 121 | 122 | def val(net, dataset, criterion, max_iter=100): 123 | print('Start val') 124 | for p in net.parameters(): 125 | p.requires_grad = False 126 | 127 | num_correct, num_all = val_model(config.val_infofile,net,True,log_file='compare-'+config.saved_model_prefix+'.log') 128 | accuracy = num_correct / num_all 129 | 130 | print('ocr_acc: %f' % (accuracy)) 131 | if config.use_log: 132 | with open(log_filename, 'a') as f: 133 | f.write('ocr_acc:{}\n'.format(accuracy)) 134 | global best_acc 135 | if accuracy > best_acc: 136 | best_acc = accuracy 137 | torch.save(crnn.state_dict(), '{}/{}_{}_{}.pth'.format(config.saved_model_dir, config.saved_model_prefix, epoch, 138 | int(best_acc * 1000))) 139 | torch.save(crnn.state_dict(), '{}/{}.pth'.format(config.saved_model_dir, config.saved_model_prefix)) 140 | 141 | 142 | def trainBatch(net, criterion, optimizer): 143 | data = train_iter.next() 144 | cpu_images, cpu_texts = data 145 | batch_size = cpu_images.size(0) 146 | image = cpu_images.to(device) 147 | 148 | text, length = converter.encode(cpu_texts) 149 | # utils.loadData(text, t) 150 | # utils.loadData(length, l) 151 | 152 | preds = net(image) # seqLength x batchSize x alphabet_size 153 | preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size)) # seqLength x batchSize 154 | cost = criterion(preds.log_softmax(2).cpu(), text, preds_size, length) / batch_size 155 | if torch.isnan(cost): 156 | print(batch_size,cpu_texts) 157 | else: 158 | net.zero_grad() 159 | cost.backward() 160 | optimizer.step() 161 | return cost 162 | 163 | 164 | for epoch in range(config.niter): 165 | loss_avg.reset() 166 | print('epoch {}....'.format(epoch)) 167 | train_iter = iter(train_loader) 168 | i = 0 169 | n_batch = len(train_loader) 170 | while i < len(train_loader): 171 | for p in crnn.parameters(): 172 | p.requires_grad = True 173 | crnn.train() 174 | cost = trainBatch(crnn, criterion, optimizer) 175 | print('epoch: {} iter: {}/{} Train loss: {:.3f}'.format(epoch, i, n_batch, cost.item())) 176 | loss_avg.add(cost) 177 | loss_avg.add(cost) 178 | i += 1 179 | print('Train loss: %f' % (loss_avg.val())) 180 | if config.use_log: 181 | with open(log_filename, 'a') as f: 182 | f.write('{}\n'.format(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f'))) 183 | f.write('train_loss:{}\n'.format(loss_avg.val())) 184 | 185 | val(crnn, test_dataset, criterion) 186 | 187 | 188 | -------------------------------------------------------------------------------- /train_code/train_ctpn/data/dataset.py: -------------------------------------------------------------------------------- 1 | #-*- coding:utf-8 -*- 2 | #''' 3 | # Created on 
18-12-27 上午10:34 4 | # 5 | # @Author: Greg Gao(laygin) 6 | #''' 7 | 8 | import os 9 | import xml.etree.ElementTree as ET 10 | import numpy as np 11 | import cv2 12 | from torch.utils.data import Dataset 13 | import torch 14 | from config import IMAGE_MEAN 15 | from ctpn_utils import cal_rpn 16 | 17 | 18 | def readxml(path): 19 | gtboxes = [] 20 | imgfile = '' 21 | xml = ET.parse(path) 22 | for elem in xml.iter(): 23 | if 'filename' in elem.tag: 24 | imgfile = elem.text 25 | if 'object' in elem.tag: 26 | for attr in list(elem): 27 | if 'bndbox' in attr.tag: 28 | xmin = int(round(float(attr.find('xmin').text))) 29 | ymin = int(round(float(attr.find('ymin').text))) 30 | xmax = int(round(float(attr.find('xmax').text))) 31 | ymax = int(round(float(attr.find('ymax').text))) 32 | 33 | gtboxes.append((xmin, ymin, xmax, ymax)) 34 | 35 | return np.array(gtboxes), imgfile 36 | 37 | 38 | # for ctpn text detection 39 | class VOCDataset(Dataset): 40 | def __init__(self, 41 | datadir, 42 | labelsdir): 43 | ''' 44 | 45 | :param txtfile: image name list text file 46 | :param datadir: image's directory 47 | :param labelsdir: annotations' directory 48 | ''' 49 | if not os.path.isdir(datadir): 50 | raise Exception('[ERROR] {} is not a directory'.format(datadir)) 51 | if not os.path.isdir(labelsdir): 52 | raise Exception('[ERROR] {} is not a directory'.format(labelsdir)) 53 | 54 | self.datadir = datadir 55 | self.img_names = os.listdir(self.datadir) 56 | self.labelsdir = labelsdir 57 | 58 | def __len__(self): 59 | return len(self.img_names) 60 | 61 | def __getitem__(self, idx): 62 | img_name = self.img_names[idx] 63 | img_path = os.path.join(self.datadir, img_name) 64 | print(img_path) 65 | xml_path = os.path.join(self.labelsdir, img_name.replace('.jpg', '.xml')) 66 | gtbox, _ = readxml(xml_path) 67 | img = cv2.imread(img_path) 68 | h, w, c = img.shape 69 | 70 | # clip image 71 | if np.random.randint(2) == 1: 72 | img = img[:, ::-1, :] 73 | newx1 = w - gtbox[:, 2] - 1 74 | newx2 = w - gtbox[:, 0] - 1 75 | gtbox[:, 0] = newx1 76 | gtbox[:, 2] = newx2 77 | 78 | [cls, regr], _ = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox) 79 | 80 | m_img = img - IMAGE_MEAN 81 | 82 | regr = np.hstack([cls.reshape(cls.shape[0], 1), regr]) 83 | 84 | cls = np.expand_dims(cls, axis=0) 85 | 86 | # transform to torch tensor 87 | m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float() 88 | cls = torch.from_numpy(cls).float() 89 | regr = torch.from_numpy(regr).float() 90 | 91 | return m_img, cls, regr 92 | 93 | class ICDARDataset(Dataset): 94 | def __init__(self, 95 | datadir, 96 | labelsdir): 97 | ''' 98 | 99 | :param txtfile: image name list text file 100 | :param datadir: image's directory 101 | :param labelsdir: annotations' directory 102 | ''' 103 | if not os.path.isdir(datadir): 104 | raise Exception('[ERROR] {} is not a directory'.format(datadir)) 105 | if not os.path.isdir(labelsdir): 106 | raise Exception('[ERROR] {} is not a directory'.format(labelsdir)) 107 | 108 | self.datadir = datadir 109 | self.img_names = os.listdir(self.datadir) 110 | self.labelsdir = labelsdir 111 | 112 | def __len__(self): 113 | return len(self.img_names) 114 | 115 | def box_transfer(self,coor_lists,rescale_fac = 1.0): 116 | gtboxes = [] 117 | for coor_list in coor_lists: 118 | coors_x = [int(coor_list[2*i]) for i in range(4)] 119 | coors_y = [int(coor_list[2*i+1]) for i in range(4)] 120 | xmin = min(coors_x) 121 | xmax = max(coors_x) 122 | ymin = min(coors_y) 123 | ymax = max(coors_y) 124 | if rescale_fac>1.0: 125 | xmin = int(xmin / 
rescale_fac) 126 | xmax = int(xmax / rescale_fac) 127 | ymin = int(ymin / rescale_fac) 128 | ymax = int(ymax / rescale_fac) 129 | gtboxes.append((xmin, ymin, xmax, ymax)) 130 | return np.array(gtboxes) 131 | 132 | def box_transfer_v2(self,coor_lists,rescale_fac = 1.0): 133 | gtboxes = [] 134 | for coor_list in coor_lists: 135 | coors_x = [int(coor_list[2 * i]) for i in range(4)] 136 | coors_y = [int(coor_list[2 * i + 1]) for i in range(4)] 137 | xmin = min(coors_x) 138 | xmax = max(coors_x) 139 | ymin = min(coors_y) 140 | ymax = max(coors_y) 141 | if rescale_fac > 1.0: 142 | xmin = int(xmin / rescale_fac) 143 | xmax = int(xmax / rescale_fac) 144 | ymin = int(ymin / rescale_fac) 145 | ymax = int(ymax / rescale_fac) 146 | prev = xmin 147 | for i in range(xmin // 16 + 1, xmax // 16 + 1): 148 | next = 16*i-0.5 149 | gtboxes.append((prev, ymin, next, ymax)) 150 | prev = next 151 | gtboxes.append((prev, ymin, xmax, ymax)) 152 | return np.array(gtboxes) 153 | 154 | def parse_gtfile(self,gt_path,rescale_fac = 1.0): 155 | coor_lists = list() 156 | with open(gt_path) as f: 157 | content = f.readlines() 158 | for line in content: 159 | coor_list = line.split(',')[:8] 160 | if len(coor_list)==8: 161 | coor_lists.append(coor_list) 162 | return self.box_transfer_v2(coor_lists,rescale_fac) 163 | 164 | def draw_boxes(self,img,cls,base_anchors,gt_box): 165 | for i in range(len(cls)): 166 | if cls[i]==1: 167 | pt1 = (int(base_anchors[i][0]),int(base_anchors[i][1])) 168 | pt2 = (int(base_anchors[i][2]),int(base_anchors[i][3])) 169 | img = cv2.rectangle(img,pt1,pt2,(200,100,100)) 170 | for i in range(gt_box.shape[0]): 171 | pt1 = (int(gt_box[i][0]),int(gt_box[i][1])) 172 | pt2 = (int(gt_box[i][2]),int(gt_box[i][3])) 173 | img = cv2.rectangle(img, pt1, pt2, (100, 200, 100)) 174 | return img 175 | 176 | def __getitem__(self, idx): 177 | img_name = self.img_names[idx] 178 | img_path = os.path.join(self.datadir, img_name) 179 | # print(img_path) 180 | img = cv2.imread(img_path) 181 | #####for read error, use default image##### 182 | if img is None: 183 | print(img_path) 184 | with open('error_imgs.txt','a') as f: 185 | f.write('{}\n'.format(img_path)) 186 | img_name = 'img_2647.jpg' 187 | img_path = os.path.join(self.datadir, img_name) 188 | img = cv2.imread(img_path) 189 | 190 | #####for read error, use default image##### 191 | 192 | h, w, c = img.shape 193 | rescale_fac = max(h, w) / 1600 194 | if rescale_fac>1.0: 195 | h = int(h/rescale_fac) 196 | w = int(w/rescale_fac) 197 | img = cv2.resize(img,(w,h)) 198 | 199 | gt_path = os.path.join(self.labelsdir, 'gt_'+img_name.split('.')[0]+'.txt') 200 | gtbox = self.parse_gtfile(gt_path,rescale_fac) 201 | 202 | # clip image 203 | if np.random.randint(2) == 1: 204 | img = img[:, ::-1, :] 205 | newx1 = w - gtbox[:, 2] - 1 206 | newx2 = w - gtbox[:, 0] - 1 207 | gtbox[:, 0] = newx1 208 | gtbox[:, 2] = newx2 209 | 210 | [cls, regr], base_anchors = cal_rpn((h, w), (int(h / 16), int(w / 16)), 16, gtbox) 211 | # debug_img = self.draw_boxes(img.copy(),cls,base_anchors,gtbox) 212 | # cv2.imwrite('debug/{}'.format(img_name),debug_img) 213 | m_img = img - IMAGE_MEAN 214 | 215 | regr = np.hstack([cls.reshape(cls.shape[0], 1), regr]) 216 | 217 | cls = np.expand_dims(cls, axis=0) 218 | 219 | # transform to torch tensor 220 | m_img = torch.from_numpy(m_img.transpose([2, 0, 1])).float() 221 | cls = torch.from_numpy(cls).float() 222 | regr = torch.from_numpy(regr).float() 223 | 224 | return m_img, cls, regr 225 | 226 | if __name__ == '__main__': 227 | xmin = 15 228 | xmax = 95 229 | 
for i in range(xmin//16+1,xmax//16+1): 230 | print(16*i-0.5) -------------------------------------------------------------------------------- /recognize/crnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, nIn, nHidden, nOut): 7 | super(BidirectionalLSTM, self).__init__() 8 | 9 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 10 | self.embedding = nn.Linear(nHidden * 2, nOut) 11 | 12 | def forward(self, input): 13 | recurrent, _ = self.rnn(input) 14 | T, b, h = recurrent.size() 15 | t_rec = recurrent.view(T * b, h) 16 | 17 | output = self.embedding(t_rec) # [T * b, nOut] 18 | output = output.view(T, b, -1) 19 | return output 20 | 21 | 22 | class CRNN(nn.Module): 23 | 24 | def __init__(self, imgH, nc, nclass, nh, leakyRelu=False): 25 | super(CRNN, self).__init__() 26 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 27 | 28 | # 1x32x128 29 | self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1) 30 | self.relu1 = nn.ReLU(True) 31 | self.pool1 = nn.MaxPool2d(2, 2) 32 | 33 | # 64x16x64 34 | self.conv2 = nn.Conv2d(64, 128, 3, 1, 1) 35 | self.relu2 = nn.ReLU(True) 36 | self.pool2 = nn.MaxPool2d(2, 2) 37 | 38 | # 128x8x32 39 | self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1) 40 | self.bn3 = nn.BatchNorm2d(256) 41 | self.relu3_1 = nn.ReLU(True) 42 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1) 43 | self.relu3_2 = nn.ReLU(True) 44 | self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 45 | 46 | # 256x4x16 47 | self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1) 48 | self.bn4 = nn.BatchNorm2d(512) 49 | self.relu4_1 = nn.ReLU(True) 50 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1) 51 | self.relu4_2 = nn.ReLU(True) 52 | self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 53 | 54 | # 512x2x16 55 | self.conv5 = nn.Conv2d(512, 512, 2, 1, 0) 56 | self.bn5 = nn.BatchNorm2d(512) 57 | self.relu5 = nn.ReLU(True) 58 | 59 | # 512x1x16 60 | 61 | self.rnn = nn.Sequential( 62 | BidirectionalLSTM(512, nh, nh), 63 | BidirectionalLSTM(nh, nh, nclass)) 64 | 65 | 66 | def forward(self, input): 67 | # conv features 68 | x = self.pool1(self.relu1(self.conv1(input))) 69 | x = self.pool2(self.relu2(self.conv2(x))) 70 | x = self.pool3(self.relu3_2(self.conv3_2(self.relu3_1(self.bn3(self.conv3_1(x)))))) 71 | x = self.pool4(self.relu4_2(self.conv4_2(self.relu4_1(self.bn4(self.conv4_1(x)))))) 72 | conv = self.relu5(self.bn5(self.conv5(x))) 73 | # print(conv.size()) 74 | 75 | b, c, h, w = conv.size() 76 | assert h == 1, "the height of conv must be 1" 77 | conv = conv.squeeze(2) 78 | conv = conv.permute(2, 0, 1) # [w, b, c] 79 | 80 | # rnn features 81 | output = self.rnn(conv) 82 | 83 | return output 84 | 85 | 86 | class CRNN_v2(nn.Module): 87 | 88 | def __init__(self, imgH, nc, nclass, nh, leakyRelu=False): 89 | super(CRNN_v2, self).__init__() 90 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 91 | 92 | # 1x32x128 93 | self.conv1_1 = nn.Conv2d(nc, 32, 3, 1, 1) 94 | self.bn1_1 = nn.BatchNorm2d(32) 95 | self.relu1_1 = nn.ReLU(True) 96 | 97 | self.conv1_2 = nn.Conv2d(32, 64, 3, 1, 1) 98 | self.bn1_2 = nn.BatchNorm2d(64) 99 | self.relu1_2 = nn.ReLU(True) 100 | self.pool1 = nn.MaxPool2d(2, 2) 101 | 102 | # 64x16x64 103 | self.conv2_1 = nn.Conv2d(64, 64, 3, 1, 1) 104 | self.bn2_1 = nn.BatchNorm2d(64) 105 | self.relu2_1 = nn.ReLU(True) 106 | 107 | self.conv2_2 = nn.Conv2d(64, 128, 3, 1, 1) 108 | self.bn2_2 = nn.BatchNorm2d(128) 109 | self.relu2_2 = nn.ReLU(True) 110 | 
self.pool2 = nn.MaxPool2d(2, 2) 111 | 112 | # 128x8x32 113 | self.conv3_1 = nn.Conv2d(128, 96, 3, 1, 1) 114 | self.bn3_1 = nn.BatchNorm2d(96) 115 | self.relu3_1 = nn.ReLU(True) 116 | 117 | self.conv3_2 = nn.Conv2d(96, 192, 3, 1, 1) 118 | self.bn3_2 = nn.BatchNorm2d(192) 119 | self.relu3_2 = nn.ReLU(True) 120 | self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 121 | 122 | # 192x4x32 123 | self.conv4_1 = nn.Conv2d(192, 128, 3, 1, 1) 124 | self.bn4_1 = nn.BatchNorm2d(128) 125 | self.relu4_1 = nn.ReLU(True) 126 | self.conv4_2 = nn.Conv2d(128, 256, 3, 1, 1) 127 | self.bn4_2 = nn.BatchNorm2d(256) 128 | self.relu4_2 = nn.ReLU(True) 129 | self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 130 | 131 | # 256x2x32 132 | self.bn5 = nn.BatchNorm2d(256) 133 | 134 | 135 | # 256x2x32 136 | 137 | self.rnn = nn.Sequential( 138 | BidirectionalLSTM(512, nh, nh), 139 | BidirectionalLSTM(nh, nh, nclass)) 140 | 141 | 142 | def forward(self, input): 143 | # conv features 144 | x = self.pool1(self.relu1_2(self.bn1_2(self.conv1_2(self.relu1_1(self.bn1_1(self.conv1_1(input))))))) 145 | x = self.pool2(self.relu2_2(self.bn2_2(self.conv2_2(self.relu2_1(self.bn2_1(self.conv2_1(x))))))) 146 | x = self.pool3(self.relu3_2(self.bn3_2(self.conv3_2(self.relu3_1(self.bn3_1(self.conv3_1(x))))))) 147 | x = self.pool4(self.relu4_2(self.bn4_2(self.conv4_2(self.relu4_1(self.bn4_1(self.conv4_1(x))))))) 148 | conv = self.bn5(x) 149 | # print(conv.size()) 150 | 151 | b, c, h, w = conv.size() 152 | assert h == 2, "the height of conv must be 2" 153 | conv = conv.reshape([b,c*h,w]) 154 | conv = conv.permute(2, 0, 1) # [w, b, c] 155 | 156 | # rnn features 157 | output = self.rnn(conv) 158 | 159 | return output 160 | 161 | 162 | def conv3x3(nIn, nOut, stride=1): 163 | # "3x3 convolution with padding" 164 | return nn.Conv2d( nIn, nOut, kernel_size=3, stride=stride, padding=1, bias=False ) 165 | 166 | 167 | class basic_res_block(nn.Module): 168 | 169 | def __init__(self, nIn, nOut, stride=1, downsample=None): 170 | super( basic_res_block, self ).__init__() 171 | m = OrderedDict() 172 | m['conv1'] = conv3x3( nIn, nOut, stride ) 173 | m['bn1'] = nn.BatchNorm2d( nOut ) 174 | m['relu1'] = nn.ReLU( inplace=True ) 175 | m['conv2'] = conv3x3( nOut, nOut ) 176 | m['bn2'] = nn.BatchNorm2d( nOut ) 177 | self.group1 = nn.Sequential( m ) 178 | 179 | self.relu = nn.Sequential( nn.ReLU( inplace=True ) ) 180 | self.downsample = downsample 181 | 182 | def forward(self, x): 183 | if self.downsample is not None: 184 | residual = self.downsample( x ) 185 | else: 186 | residual = x 187 | out = self.group1( x ) + residual 188 | out = self.relu( out ) 189 | return out 190 | 191 | 192 | class CRNN_res(nn.Module): 193 | 194 | def __init__(self, imgH, nc, nclass, nh): 195 | super(CRNN_res, self).__init__() 196 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 197 | 198 | self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1) 199 | self.relu1 = nn.ReLU(True) 200 | self.res1 = basic_res_block(64, 64) 201 | # 1x32x128 202 | 203 | down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128)) 204 | self.res2_1 = basic_res_block( 64, 128, 2, down1 ) 205 | self.res2_2 = basic_res_block(128,128) 206 | # 64x16x64 207 | 208 | down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(256)) 209 | self.res3_1 = basic_res_block(128, 256, 2, down2) 210 | self.res3_2 = basic_res_block(256, 256) 211 | self.res3_3 = basic_res_block(256, 256) 212 | # 128x8x32 213 | 214 | down3 = nn.Sequential(nn.Conv2d(256, 512, 
kernel_size=1, stride=(2, 1), bias=False),nn.BatchNorm2d(512)) 215 | self.res4_1 = basic_res_block(256, 512, (2, 1), down3) 216 | self.res4_2 = basic_res_block(512, 512) 217 | self.res4_3 = basic_res_block(512, 512) 218 | # 256x4x16 219 | 220 | self.pool = nn.AvgPool2d((2, 2), (2, 1), (0, 1)) 221 | # 512x2x16 222 | 223 | self.conv5 = nn.Conv2d(512, 512, 2, 1, 0) 224 | self.bn5 = nn.BatchNorm2d(512) 225 | self.relu5 = nn.ReLU(True) 226 | # 512x1x16 227 | 228 | self.rnn = nn.Sequential( 229 | BidirectionalLSTM(512, nh, nh), 230 | BidirectionalLSTM(nh, nh, nclass)) 231 | 232 | def forward(self, input): 233 | # conv features 234 | x = self.res1(self.relu1(self.conv1(input))) 235 | x = self.res2_2(self.res2_1(x)) 236 | x = self.res3_3(self.res3_2(self.res3_1(x))) 237 | x = self.res4_3(self.res4_2(self.res4_1(x))) 238 | x = self.pool(x) 239 | conv = self.relu5(self.bn5(self.conv5(x))) 240 | # print(conv.size()) 241 | b, c, h, w = conv.size() 242 | assert h == 1, "the height of conv must be 1" 243 | conv = conv.squeeze(2) 244 | conv = conv.permute(2, 0, 1) # [w, b, c] 245 | 246 | # rnn features 247 | output = self.rnn(conv) 248 | 249 | return output 250 | 251 | if __name__ == '__main__': 252 | pass 253 | -------------------------------------------------------------------------------- /train_code/train_crnn/crnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import OrderedDict 3 | 4 | class BidirectionalLSTM(nn.Module): 5 | 6 | def __init__(self, nIn, nHidden, nOut): 7 | super(BidirectionalLSTM, self).__init__() 8 | 9 | self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True) 10 | self.embedding = nn.Linear(nHidden * 2, nOut) 11 | 12 | def forward(self, input): 13 | recurrent, _ = self.rnn(input) 14 | T, b, h = recurrent.size() 15 | t_rec = recurrent.view(T * b, h) 16 | 17 | output = self.embedding(t_rec) # [T * b, nOut] 18 | output = output.view(T, b, -1) 19 | return output 20 | 21 | 22 | class CRNN(nn.Module): 23 | 24 | def __init__(self, imgH, nc, nclass, nh, leakyRelu=False): 25 | super(CRNN, self).__init__() 26 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 27 | 28 | # 1x32x128 29 | self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1) 30 | self.relu1 = nn.ReLU(True) 31 | self.pool1 = nn.MaxPool2d(2, 2) 32 | 33 | # 64x16x64 34 | self.conv2 = nn.Conv2d(64, 128, 3, 1, 1) 35 | self.relu2 = nn.ReLU(True) 36 | self.pool2 = nn.MaxPool2d(2, 2) 37 | 38 | # 128x8x32 39 | self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1) 40 | self.bn3 = nn.BatchNorm2d(256) 41 | self.relu3_1 = nn.ReLU(True) 42 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1) 43 | self.relu3_2 = nn.ReLU(True) 44 | self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 45 | 46 | # 256x4x16 47 | self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1) 48 | self.bn4 = nn.BatchNorm2d(512) 49 | self.relu4_1 = nn.ReLU(True) 50 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1) 51 | self.relu4_2 = nn.ReLU(True) 52 | self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 53 | 54 | # 512x2x16 55 | self.conv5 = nn.Conv2d(512, 512, 2, 1, 0) 56 | self.bn5 = nn.BatchNorm2d(512) 57 | self.relu5 = nn.ReLU(True) 58 | 59 | # 512x1x16 60 | 61 | self.rnn = nn.Sequential( 62 | BidirectionalLSTM(512, nh, nh), 63 | BidirectionalLSTM(nh, nh, nclass)) 64 | 65 | 66 | def forward(self, input): 67 | # conv features 68 | x = self.pool1(self.relu1(self.conv1(input))) 69 | x = self.pool2(self.relu2(self.conv2(x))) 70 | x = self.pool3(self.relu3_2(self.conv3_2(self.relu3_1(self.bn3(self.conv3_1(x)))))) 71 | x = 
self.pool4(self.relu4_2(self.conv4_2(self.relu4_1(self.bn4(self.conv4_1(x)))))) 72 | conv = self.relu5(self.bn5(self.conv5(x))) 73 | # print(conv.size()) 74 | 75 | b, c, h, w = conv.size() 76 | assert h == 1, "the height of conv must be 1" 77 | conv = conv.squeeze(2) 78 | conv = conv.permute(2, 0, 1) # [w, b, c] 79 | 80 | # rnn features 81 | output = self.rnn(conv) 82 | 83 | return output 84 | 85 | 86 | class CRNN_v2(nn.Module): 87 | 88 | def __init__(self, imgH, nc, nclass, nh, leakyRelu=False): 89 | super(CRNN_v2, self).__init__() 90 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 91 | 92 | # 1x32x128 93 | self.conv1_1 = nn.Conv2d(nc, 32, 3, 1, 1) 94 | self.bn1_1 = nn.BatchNorm2d(32) 95 | self.relu1_1 = nn.ReLU(True) 96 | 97 | self.conv1_2 = nn.Conv2d(32, 64, 3, 1, 1) 98 | self.bn1_2 = nn.BatchNorm2d(64) 99 | self.relu1_2 = nn.ReLU(True) 100 | self.pool1 = nn.MaxPool2d(2, 2) 101 | 102 | # 64x16x64 103 | self.conv2_1 = nn.Conv2d(64, 64, 3, 1, 1) 104 | self.bn2_1 = nn.BatchNorm2d(64) 105 | self.relu2_1 = nn.ReLU(True) 106 | 107 | self.conv2_2 = nn.Conv2d(64, 128, 3, 1, 1) 108 | self.bn2_2 = nn.BatchNorm2d(128) 109 | self.relu2_2 = nn.ReLU(True) 110 | self.pool2 = nn.MaxPool2d(2, 2) 111 | 112 | # 128x8x32 113 | self.conv3_1 = nn.Conv2d(128, 96, 3, 1, 1) 114 | self.bn3_1 = nn.BatchNorm2d(96) 115 | self.relu3_1 = nn.ReLU(True) 116 | 117 | self.conv3_2 = nn.Conv2d(96, 192, 3, 1, 1) 118 | self.bn3_2 = nn.BatchNorm2d(192) 119 | self.relu3_2 = nn.ReLU(True) 120 | self.pool3 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 121 | 122 | # 192x4x32 123 | self.conv4_1 = nn.Conv2d(192, 128, 3, 1, 1) 124 | self.bn4_1 = nn.BatchNorm2d(128) 125 | self.relu4_1 = nn.ReLU(True) 126 | self.conv4_2 = nn.Conv2d(128, 256, 3, 1, 1) 127 | self.bn4_2 = nn.BatchNorm2d(256) 128 | self.relu4_2 = nn.ReLU(True) 129 | self.pool4 = nn.MaxPool2d((2, 2), (2, 1), (0, 1)) 130 | 131 | # 256x2x32 132 | self.bn5 = nn.BatchNorm2d(256) 133 | 134 | 135 | # 256x2x32 136 | 137 | self.rnn = nn.Sequential( 138 | BidirectionalLSTM(512, nh, nh), 139 | BidirectionalLSTM(nh, nh, nclass)) 140 | 141 | 142 | def forward(self, input): 143 | # conv features 144 | x = self.pool1(self.relu1_2(self.bn1_2(self.conv1_2(self.relu1_1(self.bn1_1(self.conv1_1(input))))))) 145 | x = self.pool2(self.relu2_2(self.bn2_2(self.conv2_2(self.relu2_1(self.bn2_1(self.conv2_1(x))))))) 146 | x = self.pool3(self.relu3_2(self.bn3_2(self.conv3_2(self.relu3_1(self.bn3_1(self.conv3_1(x))))))) 147 | x = self.pool4(self.relu4_2(self.bn4_2(self.conv4_2(self.relu4_1(self.bn4_1(self.conv4_1(x))))))) 148 | conv = self.bn5(x) 149 | # print(conv.size()) 150 | 151 | b, c, h, w = conv.size() 152 | assert h == 2, "the height of conv must be 2" 153 | conv = conv.reshape([b,c*h,w]) 154 | conv = conv.permute(2, 0, 1) # [w, b, c] 155 | 156 | # rnn features 157 | output = self.rnn(conv) 158 | 159 | return output 160 | 161 | 162 | def conv3x3(nIn, nOut, stride=1): 163 | # "3x3 convolution with padding" 164 | return nn.Conv2d( nIn, nOut, kernel_size=3, stride=stride, padding=1, bias=False ) 165 | 166 | 167 | class basic_res_block(nn.Module): 168 | 169 | def __init__(self, nIn, nOut, stride=1, downsample=None): 170 | super( basic_res_block, self ).__init__() 171 | m = OrderedDict() 172 | m['conv1'] = conv3x3( nIn, nOut, stride ) 173 | m['bn1'] = nn.BatchNorm2d( nOut ) 174 | m['relu1'] = nn.ReLU( inplace=True ) 175 | m['conv2'] = conv3x3( nOut, nOut ) 176 | m['bn2'] = nn.BatchNorm2d( nOut ) 177 | self.group1 = nn.Sequential( m ) 178 | 179 | self.relu = nn.Sequential( nn.ReLU( inplace=True 
) ) 180 | self.downsample = downsample 181 | 182 | def forward(self, x): 183 | if self.downsample is not None: 184 | residual = self.downsample( x ) 185 | else: 186 | residual = x 187 | out = self.group1( x ) + residual 188 | out = self.relu( out ) 189 | return out 190 | 191 | 192 | class CRNN_res(nn.Module): 193 | 194 | def __init__(self, imgH, nc, nclass, nh): 195 | super(CRNN_res, self).__init__() 196 | assert imgH % 16 == 0, 'imgH has to be a multiple of 16' 197 | 198 | self.conv1 = nn.Conv2d(nc, 64, 3, 1, 1) 199 | self.relu1 = nn.ReLU(True) 200 | self.res1 = basic_res_block(64, 64) 201 | # 1x32x128 202 | 203 | down1 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128)) 204 | self.res2_1 = basic_res_block( 64, 128, 2, down1 ) 205 | self.res2_2 = basic_res_block(128,128) 206 | # 64x16x64 207 | 208 | down2 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(256)) 209 | self.res3_1 = basic_res_block(128, 256, 2, down2) 210 | self.res3_2 = basic_res_block(256, 256) 211 | self.res3_3 = basic_res_block(256, 256) 212 | # 128x8x32 213 | 214 | down3 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=(2, 1), bias=False),nn.BatchNorm2d(512)) 215 | self.res4_1 = basic_res_block(256, 512, (2, 1), down3) 216 | self.res4_2 = basic_res_block(512, 512) 217 | self.res4_3 = basic_res_block(512, 512) 218 | # 256x4x16 219 | 220 | self.pool = nn.AvgPool2d((2, 2), (2, 1), (0, 1)) 221 | # 512x2x16 222 | 223 | self.conv5 = nn.Conv2d(512, 512, 2, 1, 0) 224 | self.bn5 = nn.BatchNorm2d(512) 225 | self.relu5 = nn.ReLU(True) 226 | # 512x1x16 227 | 228 | self.rnn = nn.Sequential( 229 | BidirectionalLSTM(512, nh, nh), 230 | BidirectionalLSTM(nh, nh, nclass)) 231 | 232 | def forward(self, input): 233 | # conv features 234 | x = self.res1(self.relu1(self.conv1(input))) 235 | x = self.res2_2(self.res2_1(x)) 236 | x = self.res3_3(self.res3_2(self.res3_1(x))) 237 | x = self.res4_3(self.res4_2(self.res4_1(x))) 238 | x = self.pool(x) 239 | conv = self.relu5(self.bn5(self.conv5(x))) 240 | # print(conv.size()) 241 | b, c, h, w = conv.size() 242 | assert h == 1, "the height of conv must be 1" 243 | conv = conv.squeeze(2) 244 | conv = conv.permute(2, 0, 1) # [w, b, c] 245 | 246 | # rnn features 247 | output = self.rnn(conv) 248 | 249 | return output 250 | 251 | if __name__ == '__main__': 252 | pass 253 | -------------------------------------------------------------------------------- /train_code/train_crnn/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # encoding: utf-8 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | import collections 8 | from datetime import datetime 9 | import torch.nn.functional as F 10 | 11 | 12 | def get_acc(output, label): 13 | total = output.shape[0] 14 | _, pred_label = output.max(1) 15 | num_correct = (pred_label == label).sum().item() 16 | # print( pred_label.data.cpu().numpy() ) 17 | # print( label.data.cpu().numpy() ) 18 | return 1.0*num_correct / total 19 | 20 | def adjust_learning_rate(optimizer,decay_rate = 0.97): 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = param_group['lr']*decay_rate 23 | 24 | def train(net, train_data, valid_data, num_epochs, optimizer, criterion,saver_freq = 50,saver_prefix = 'vgg16'): 25 | if torch.cuda.is_available(): 26 | net = net.cuda() 27 | prev_time = datetime.now() 28 | best_acc = 0.98 29 | for epoch in range(num_epochs): 30 | train_loss = 0 31 | 
train_acc = 0
32 |         net = net.train()
33 |         for im, label in train_data:
34 |             # print(label)
35 |             if torch.cuda.is_available():
36 |                 im = Variable(im.cuda())  # (bs, 3, h, w)
37 |                 label = Variable(label.cuda())  # (bs, h, w)
38 |             else:
39 |                 im = Variable(im)
40 |                 label = Variable(label)
41 |             # forward
42 |             output = net(im)
43 |             loss = criterion(output, label)
44 |             # backward
45 |             optimizer.zero_grad()
46 |             loss.backward()
47 |             optimizer.step()
48 | 
49 |             train_loss += loss.item()
50 |             train_acc += get_acc(output, label)
51 | 
52 |         cur_time = datetime.now()
53 |         h, remainder = divmod((cur_time - prev_time).seconds, 3600)
54 |         m, s = divmod(remainder, 60)
55 |         time_str = "Time %02d:%02d:%02d" % (h, m, s)
56 |         if valid_data is not None:
57 |             valid_loss = 0
58 |             valid_acc = 0
59 |             net = net.eval()
60 |             for im, label in valid_data:
61 |                 # `volatile=True` was removed in PyTorch 0.4;
62 |                 # torch.no_grad() is the modern way to disable autograd for validation
63 |                 with torch.no_grad():
64 |                     if torch.cuda.is_available():
65 |                         im = im.cuda()
66 |                         label = label.cuda()
67 |                     output = net(im)
68 |                     loss = criterion(output, label)
69 |                 valid_loss += loss.item()
70 |                 valid_acc += get_acc(output, label)
71 |             epoch_str = (
72 |                 "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
73 |                 % (epoch, train_loss / len(train_data),
74 |                    train_acc / len(train_data), valid_loss / len(valid_data),
75 |                    valid_acc / len(valid_data)))
76 |             if valid_acc / len(valid_data) > best_acc:
77 |                 best_acc = valid_acc / len(valid_data)
78 |                 torch.save(net.state_dict(), 'models/{}-{}-{}-0819-model-db.pth'.format(saver_prefix, epoch + 1, int(best_acc * 1000)))
79 |         else:
80 |             epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
81 |                          (epoch, train_loss / len(train_data),
82 |                           train_acc / len(train_data)))
83 |         prev_time = cur_time
84 |         print(epoch_str + time_str)
85 |         # if (epoch+1)%saver_freq == 0:
86 |         #     # torch.save(net,'models/vgg-16-'+str(epoch+1)+'-model.pth')
87 |         #     # another weight saver method
88 |         #     torch.save(net.state_dict(),'models/{}-{}-0711-model.pth'.format(saver_prefix, epoch+1))
89 |         adjust_learning_rate(optimizer)
90 | 
91 | class strLabelConverter(object):
92 |     """Convert between str and label.
93 | 
94 |     NOTE:
95 |         Insert `blank` to the alphabet for CTC.
96 | 
97 |     Args:
98 |         alphabet (list of int): character codes of the supported characters (loaded from alphabet.pkl).
99 |         ignore_case (bool, default=False): whether or not to ignore all of the case.
100 |     """
101 | 
102 |     def __init__(self, alphabet, ignore_case=False):
103 |         self._ignore_case = ignore_case
104 |         if self._ignore_case:
105 |             alphabet = alphabet.lower()
106 |         self.alphabet = alphabet
107 |         self.alphabet.append(ord('_'))  # for `-1` index
108 |         # print(self.alphabet)
109 | 
110 |         self.dict = {}
111 |         for i, char in enumerate(alphabet):
112 |             # NOTE: 0 is reserved for 'blank' required by wrap_ctc
113 |             self.dict[char] = i + 1
114 | 
115 |     def encode(self, text):
116 |         """Support batch or single str.
117 | 
118 |         Args:
119 |             text (str or list of str): texts to convert.
120 | 
121 |         Returns:
122 |             torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
123 |             torch.IntTensor [n]: length of each text
124 |         """
125 |         # print(text)
126 |         try:
127 |             if isinstance(text, str):
128 |                 # for char in text:
129 |                 #     print(char)
130 |                 text = [
131 |                     self.dict[ord(char.lower() if self._ignore_case else char)]
132 |                     for char in text  # if char in self.dict.keys()
133 |                 ]
134 |                 length = [len(text)]
135 |             elif isinstance(text, collections.abc.Iterable):  # collections.Iterable was removed in Python 3.10
136 |                 length = [len(s) for s in text]
137 |                 text = ''.join(text)
138 |                 text, _ = self.encode(text)
139 |         except KeyError as e:
140 |             # print(text)
141 |             print(e)
142 |             for ch in text:
143 |                 if ord(ch) not in self.dict.keys():
144 |                     print('Not Covering Char: {} - {}'.format(ch, ord(ch)))
145 |         return (torch.IntTensor(text), torch.IntTensor(length))
146 | 
147 |     def decode(self, t, length, raw=False):
148 |         """Decode encoded texts back into strs.
149 | 
150 |         Args:
151 |             torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
152 |             torch.IntTensor [n]: length of each text.
153 | 
154 |         Raises:
155 |             AssertionError: when the texts and their lengths do not match.
156 | 
157 |         Returns:
158 |             text (str or list of str): converted texts.
159 |         """
160 |         if length.numel() == 1:
161 |             length = length[0]
162 |             assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
163 |             if raw:
164 |                 return ''.join([chr(self.alphabet[i - 1]) for i in t])
165 |             else:
166 |                 char_list = []
167 |                 for i in range(length):
168 |                     if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):  # CTC rule: drop blanks and repeats
169 |                         char_list.append(chr(self.alphabet[t[i] - 1]))
170 | 
171 |                 return ''.join(char_list)
172 |         else:
173 |             # batch mode
174 |             assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
175 |             texts = []
176 |             index = 0
177 |             for i in range(length.numel()):
178 |                 l = length[i]
179 |                 texts.append(
180 |                     self.decode(
181 |                         t[index:index + l], torch.IntTensor([l]), raw=raw))
182 |                 index += l
183 |             return texts
184 | 
185 | 
186 | class averager(object):
187 |     """Compute average for `torch.Variable` and `torch.Tensor`. """
188 | 
189 |     def __init__(self):
190 |         self.reset()
191 | 
192 |     def add(self, v):
193 |         if isinstance(v, Variable):
194 |             count = v.data.numel()
195 |             v = v.data.sum()
196 |         elif isinstance(v, torch.Tensor):
197 |             count = v.numel()
198 |             v = v.sum()
199 | 
200 |         self.n_count += count
201 |         self.sum += v
202 | 
203 |     def reset(self):
204 |         self.n_count = 0
205 |         self.sum = 0
206 | 
207 |     def val(self):
208 |         res = 0
209 |         if self.n_count != 0:
210 |             res = self.sum / float(self.n_count)
211 |         return res
212 | 
213 | 
214 | def oneHot(v, v_length, nc):
215 |     batchSize = v_length.size(0)
216 |     maxLength = v_length.max()
217 |     v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
218 |     acc = 0
219 |     for i in range(batchSize):
220 |         length = v_length[i]
221 |         label = v[acc:acc + length].view(-1, 1).long()
222 |         v_onehot[i, :length].scatter_(1, label, 1.0)
223 |         acc += length
224 |     return v_onehot
225 | 
226 | 
227 | def loadData(v, data):
228 |     v.data.resize_(data.size()).copy_(data)
229 | 
230 | 
231 | def prettyPrint(v):
232 |     print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
233 |     print('| Max: %f | Min: %f | Mean: %f' % (v.max().item(), v.min().item(),
234 |                                               v.mean().item()))
235 | 
236 | 
237 | def assureRatio(img):
238 |     """Ensure imgH <= imgW."""
239 |     b, c, h, w = img.size()
240 |     if h > w:
241 |         main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
242 |         img = main(img)
243 |     return img
244 | 
--------------------------------------------------------------------------------
/train_code/train_crnn/trans.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #coding:utf-8
3 | import sys
4 | # reload(sys)
5 | # sys.setdefaultencoding("utf-8")
6 | import os, sys, shutil, math, random, json, multiprocessing, threading
7 | from PIL import Image, ImageDraw, ImageFont, ImageChops
8 | import cv2
9 | import numpy as np
10 | from PIL import Image, ImageEnhance, ImageFilter, ImageOps
11 | # import
12 | import abc
13 | import trans_utils
14 | from trans_utils import getpilimage
15 | 
16 | 
17 | 
18 | global colormap
19 | index = 0
20 | 
21 | class TransBase(object):
22 |     def __init__(self, probability=1.):
23 |         super(TransBase, self).__init__()
24 |         self.probability = probability
25 |     @abc.abstractmethod
26 |     def tranfun(self, inputimage):
27 |         pass
28 |     # @utils.zlog
29 |     def process(self, inputimage):
30 |         if np.random.random() < self.probability:
31 |             return self.tranfun(inputimage)
32 |         else:
33 |             return inputimage
34 | 
35 | class RandomContrast(TransBase):
36 |     def setparam(self, lower=0.5, upper=1.5):
37 |         self.lower = lower
38 |         self.upper = upper
39 |         assert self.upper >= self.lower, "upper must be >= lower."
40 |         assert self.lower >= 0, "lower must be non-negative."
41 |     def tranfun(self, image):
42 |         image = getpilimage(image)
43 |         enh_con = ImageEnhance.Contrast(image)  # was ImageEnhance.Brightness, which made this a duplicate of RandomBrightness
44 |         return enh_con.enhance(random.uniform(self.lower, self.upper))
45 | 
46 | class RandomBrightness(TransBase):
47 |     def setparam(self, lower=0.5, upper=1.5):
48 |         self.lower = lower
49 |         self.upper = upper
50 |         assert self.upper >= self.lower, "upper must be >= lower."
51 |         assert self.lower >= 0, "lower must be non-negative."
52 |     def tranfun(self, image):
53 |         image = getpilimage(image)
54 |         bri = ImageEnhance.Brightness(image)
55 |         return bri.enhance(random.uniform(self.lower, self.upper))
56 | 
57 | class RandomColor(TransBase):
58 |     def setparam(self, lower=0.5, upper=1.5):
59 |         self.lower = lower
60 |         self.upper = upper
61 |         assert self.upper >= self.lower, "upper must be >= lower."
62 |         assert self.lower >= 0, "lower must be non-negative."
63 |     def tranfun(self, image):
64 |         image = getpilimage(image)
65 |         col = ImageEnhance.Color(image)
66 |         return col.enhance(random.uniform(self.lower, self.upper))
67 | 
68 | class RandomSharpness(TransBase):
69 |     def setparam(self, lower=0.5, upper=1.5):
70 |         self.lower = lower
71 |         self.upper = upper
72 |         assert self.upper >= self.lower, "upper must be >= lower."
73 |         assert self.lower >= 0, "lower must be non-negative."
74 |     def tranfun(self, image):
75 |         image = getpilimage(image)
76 |         sha = ImageEnhance.Sharpness(image)
77 |         return sha.enhance(random.uniform(self.lower, self.upper))
78 | 
79 | class Compress(TransBase):
80 |     def setparam(self, lower=5, upper=85):
81 |         self.lower = lower
82 |         self.upper = upper
83 |         assert self.upper >= self.lower, "upper must be >= lower."
84 |         assert self.lower >= 0, "lower must be non-negative."
85 |     def tranfun(self, image):
86 |         img = trans_utils.getcvimage(image)
87 |         param = [int(cv2.IMWRITE_JPEG_QUALITY), random.randint(self.lower, self.upper)]
88 |         img_encode = cv2.imencode('.jpeg', img, param)
89 |         img_decode = cv2.imdecode(img_encode[1], cv2.IMREAD_COLOR)
90 |         pil_img = trans_utils.cv2pil(img_decode)
91 |         if len(image.split()) == 1:
92 |             pil_img = pil_img.convert('L')
93 |         return pil_img
94 | 
95 | class Exposure(TransBase):
96 |     def setparam(self, lower=5, upper=10):
97 |         self.lower = lower
98 |         self.upper = upper
99 |         assert self.upper >= self.lower, "upper must be >= lower."
100 |         assert self.lower >= 0, "lower must be non-negative."
101 |     def tranfun(self, image):
102 |         image = trans_utils.getcvimage(image)
103 |         h, w = image.shape[:2]
104 |         x0 = random.randint(0, w)
105 |         y0 = random.randint(0, h)
106 |         x1 = random.randint(x0, w)
107 |         y1 = random.randint(y0, h)
108 |         transparent_area = (x0, y0, x1, y1)
109 |         mask = Image.new('L', (w, h), color=255)
110 |         draw = ImageDraw.Draw(mask)
111 |         draw.rectangle(transparent_area, fill=random.randint(150, 255))  # draw before np.array(); the original drew afterwards, so the patch never reached the array
112 |         mask = np.array(mask)
113 |         if len(image.shape) == 3:
114 |             mask = mask[:, :, np.newaxis]
115 |             mask = np.concatenate([mask, mask, mask], axis=2)
116 |         reflection_result = image.astype(np.int32) + (255 - mask)  # int32 to avoid uint8 overflow
117 |         reflection_result = np.clip(reflection_result, 0, 255).astype(np.uint8)
118 |         return trans_utils.cv2pil(reflection_result)
119 | 
120 | class Rotate(TransBase):
121 |     def setparam(self, lower=-5, upper=5):
122 |         self.lower = lower
123 |         self.upper = upper
124 |         assert self.upper >= self.lower, "upper must be >= lower."
125 |         # assert self.lower >= 0, "lower must be non-negative."
126 |     def tranfun(self, image):
127 |         image = getpilimage(image)
128 |         rot = random.uniform(self.lower, self.upper)
129 |         trans_img = image.rotate(rot, expand=True)
130 |         # trans_img.show()
131 |         return trans_img
132 | 
133 | class Blur(TransBase):
134 |     def setparam(self, lower=0, upper=1):
135 |         self.lower = lower
136 |         self.upper = upper
137 |         assert self.upper >= self.lower, "upper must be >= lower."
138 |         assert self.lower >= 0, "lower must be non-negative."
139 |     def tranfun(self, image):
140 |         image = getpilimage(image)
141 |         image = image.filter(ImageFilter.GaussianBlur(radius=random.uniform(self.lower, self.upper)))  # was a hard-coded radius=1, which ignored setparam
142 |         # blurred_image = image.filter(ImageFilter.Kernel((3,3), (1,1,1,0,0,0,2,0,2)))
143 |         # Kernel
144 |         return image
145 | 
146 | class Salt(TransBase):
147 |     def setparam(self, rate=0.02):
148 |         self.rate = rate
149 |     def tranfun(self, image):
150 |         image = getpilimage(image)
151 |         num_noise = int(image.size[1] * image.size[0] * self.rate)
152 |         # assert len(image.split()) == 1
153 |         for k in range(num_noise):
154 |             i = int(np.random.random() * image.size[1])
155 |             j = int(np.random.random() * image.size[0])
156 |             image.putpixel((j, i), int(np.random.random() * 255))
157 |         return image
158 | 
159 | 
160 | class AdjustResolution(TransBase):
161 |     def setparam(self, max_rate=0.95, min_rate=0.5):
162 |         self.max_rate = max_rate
163 |         self.min_rate = min_rate
164 | 
165 |     def tranfun(self, image):
166 |         image = getpilimage(image)
167 |         w, h = image.size
168 |         rate = np.random.random() * (self.max_rate - self.min_rate) + self.min_rate
169 |         w2 = int(w * rate)
170 |         h2 = int(h * rate)
171 |         image = image.resize((w2, h2))
172 |         image = image.resize((w, h))
173 |         return image
174 | 
175 | 
176 | class Crop(TransBase):
177 |     def setparam(self, maxv=2):
178 |         self.maxv = maxv
179 |     def tranfun(self, image):
180 |         img = trans_utils.getcvimage(image)
181 |         h, w = img.shape[:2]
182 |         org = np.array([[0, np.random.randint(0, self.maxv)],
183 |                         [w, np.random.randint(0, self.maxv)],
184 |                         [0, h - np.random.randint(0, self.maxv)],
185 |                         [w, h - np.random.randint(0, self.maxv)]], np.float32)
186 |         dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
187 |         M = cv2.getPerspectiveTransform(org, dst)
188 |         res = cv2.warpPerspective(img, M, (w, h))
189 |         return getpilimage(res)
190 | 
191 | class Crop2(TransBase):
192 |     def setparam(self, maxv_h=4, maxv_w=4):
193 |         self.maxv_h = maxv_h
194 |         self.maxv_w = maxv_w
195 |     def tranfun(self, image_and_loc):
196 |         image, left, top, right, bottom = image_and_loc
197 |         w, h = image.size
198 |         left = np.clip(left, 0, w - 1)
199 |         right = np.clip(right, 0, w - 1)
200 |         top = np.clip(top, 0, h - 1)
201 |         bottom = np.clip(bottom, 0, h - 1)
202 |         img = trans_utils.getcvimage(image)
203 |         try:
204 |             # global index
205 |             res = getpilimage(img[top:bottom, left:right])
206 |             # res.save('test_imgs/crop-debug-{}.jpg'.format(index))
207 |             # index+=1
208 |             return res
209 |         except AttributeError as e:
210 |             print('error')  # fall through to the perspective-jitter crop below
211 |             image.save('test_imgs/t.png')
212 |             print(left, top, right, bottom)
213 | 
214 |         h = bottom - top
215 |         w = right - left
216 |         org = np.array([[left - np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h // 2)],
217 |                         [right + np.random.randint(0, self.maxv_w), top + np.random.randint(-self.maxv_h, self.maxv_h // 2)],
218 |                         [left - np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h // 2)],
219 |                         [right + np.random.randint(0, self.maxv_w), bottom - np.random.randint(-self.maxv_h, self.maxv_h // 2)]], np.float32)
220 |         dst = np.array([[0, 0], [w, 0], [0, h], [w, h]], np.float32)
221 |         M = cv2.getPerspectiveTransform(org, dst)
222 |         res = cv2.warpPerspective(img, M, (w, h))
223 |         return getpilimage(res)
224 | 
225 | class Stretch(TransBase):
226 |     def setparam(self, max_rate=1.2, min_rate=0.8):
227 |         self.max_rate = max_rate
228 |         self.min_rate = min_rate
229 | 
230 |     def tranfun(self, image):
231 |         image = getpilimage(image)
232 |         w, h = image.size
233 |         rate = np.random.random() * (self.max_rate - self.min_rate) + self.min_rate
234 |         w2 = int(w * rate)
235 |         image = image.resize((w2, h))
236 |         return image
237 | 
238 | 
239 | if __name__ == "__main__":
240 |     # img_name = 'test_files/NID 1468666480 (1) Front.jpg'
241 |     # img = Image.open(img_name)
242 |     # w,h = img.size
243 |     #
244 |     # img.show()
245 |     # rc = Crop2()
246 |     # rc.setparam()
247 |     # img = rc.process([img,362,418,581,463])
248 |     # # img = ImageOps.invert(img)
249 |     # img.show()
250 | 
251 |     img_name = 'data_set/images_0701_EC_3/0.png'
252 |     img = Image.open(img_name)
253 |     print(img.size[1])
254 |     w, h = img.size
255 |     img_cv = trans_utils.pil2cv(img)
256 |     print(img_cv.shape)
257 |     # print(len(img.split()))
258 | 
259 |     img.show()
260 |     # img = cv2.imread(img_name)
261 |     rc = Compress()
262 |     rc.setparam()
263 |     img = rc.process(img)
264 |     # img = ImageOps.invert(img)
265 |     img.show()
266 | 
267 | 
--------------------------------------------------------------------------------
/train_code/train_crnn/mydataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # encoding: utf-8
3 | 
4 | import random
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torch.utils.data import sampler
8 | import torchvision.transforms as transforms
9 | from PIL import Image, ImageEnhance, ImageOps
10 | import numpy as np
11 | import codecs
12 | import trans
13 | 
14 | debug_idx = 0
15 | debug = True
16 | 
17 | crop = trans.Crop(probability=0.1)
18 | crop2 = trans.Crop2(probability=1.1)  # probability > 1 means Crop2 is always applied
19 | random_contrast = trans.RandomContrast(probability=0.1)
20 | random_brightness = trans.RandomBrightness(probability=0.1)
21 | random_color = trans.RandomColor(probability=0.1)
22 | random_sharpness = trans.RandomSharpness(probability=0.1)
23 | compress = trans.Compress(probability=0.3)
24 | exposure = trans.Exposure(probability=0.1)
25 | rotate = trans.Rotate(probability=0.1)
26 | blur = trans.Blur(probability=0.1)
27 | salt = trans.Salt(probability=0.1)
28 | adjust_resolution = trans.AdjustResolution(probability=0.1)
29 | stretch = trans.Stretch(probability=0.1)
30 | 
31 | crop.setparam()
32 | crop2.setparam()
33 | random_contrast.setparam()
34 | random_brightness.setparam()
35 | random_color.setparam()
36 | random_sharpness.setparam()
37 | compress.setparam()
38 | exposure.setparam()
39 | rotate.setparam()
40 | blur.setparam()
41 | salt.setparam()
42 | adjust_resolution.setparam()
43 | stretch.setparam()
44 | 
45 | def randomColor(image):
46 |     """
47 |     Apply random color jitter to an image.
48 |     :param image: PIL image
49 |     :return: color-jittered PIL image
50 |     """
51 |     random_factor = np.random.randint(0, 31) / 10.  # random factor
52 |     color_image = ImageEnhance.Color(image).enhance(random_factor)  # adjust saturation
53 |     random_factor = np.random.randint(10, 21) / 10.  # random factor
54 |     brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor)  # adjust brightness
55 |     random_factor = np.random.randint(10, 21) / 10.  # random factor
56 |     contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor)  # adjust contrast
57 |     random_factor = np.random.randint(0, 31) / 10.  # random factor
58 |     return ImageEnhance.Sharpness(contrast_image).enhance(random_factor)  # adjust sharpness
59 | 
60 | def randomGaussian(image, mean=0.2, sigma=0.3):
61 |     """
62 |     Add Gaussian noise to an image.
63 |     :param image: PIL image
64 |     :return: noisy PIL image
65 |     """
66 | 
67 |     def gaussianNoisy(im, mean=0.2, sigma=0.3):
68 |         """
69 |         Apply Gaussian noise to one channel.
70 |         :param im: flattened single-channel image
71 |         :param mean: offset
72 |         :param sigma: standard deviation
73 |         :return:
74 |         """
75 |         for _i in range(len(im)):
76 |             im[_i] += random.gauss(mean, sigma)
77 |         return im
78 | 
79 |     # convert the image to a numpy array
80 |     img = np.asarray(image)
81 |     img.flags.writeable = True  # make the array writable
82 |     width, height = img.shape[:2]  # note: shape is (rows, cols); the names are swapped but used consistently below
83 |     img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma)
84 |     img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma)
85 |     img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma)
86 |     img[:, :, 0] = img_r.reshape([width, height])
87 |     img[:, :, 1] = img_g.reshape([width, height])
88 |     img[:, :, 2] = img_b.reshape([width, height])
89 |     return Image.fromarray(np.uint8(img))
90 | 
91 | def inverse_color(image):
92 |     if np.random.random() < 0.4:
93 |         image = ImageOps.invert(image)
94 |     return image
95 | 
96 | # def data_tf(img):
97 | #     img = randomColor(img)
98 | #     # img = randomGaussian(img)
99 | #     img = inverse_color(img)
100 | #     return img
101 | 
102 | def data_tf(img):
103 |     img = crop.process(img)
104 |     img = random_contrast.process(img)
105 |     img = random_brightness.process(img)
106 |     img = random_color.process(img)
107 |     img = random_sharpness.process(img)
108 |     if img.size[1] >= 32:
109 |         img = compress.process(img)
110 |         img = adjust_resolution.process(img)
111 |         img = blur.process(img)
112 |     img = exposure.process(img)
113 |     # img = rotate.process(img)
114 |     img = salt.process(img)
115 |     img = inverse_color(img)
116 |     img = stretch.process(img)
117 |     if debug and np.random.random() < 0.001:
118 |         global debug_idx
119 |         img.save('debug_files/{:05}.jpg'.format(debug_idx))
120 |         debug_idx += 1
121 |         if debug_idx == 10000:
122 |             debug_idx = 0
123 |     return img
124 | 
125 | def data_tf_fullimg(img, loc):
126 |     left, top, right, bottom = loc
127 |     img = crop2.process([img, left, top, right, bottom])
128 |     img = random_contrast.process(img)
129 |     img = random_brightness.process(img)
130 |     img = random_color.process(img)
131 |     img = random_sharpness.process(img)
132 |     img = compress.process(img)
133 |     img = exposure.process(img)
134 |     # img = rotate.process(img)
135 |     img = blur.process(img)
136 |     img = salt.process(img)
137 |     # img = inverse_color(img)
138 |     img = adjust_resolution.process(img)
139 |     img = stretch.process(img)
140 |     return img
141 | 
142 | 
143 | 
144 | class MyDataset(Dataset):
145 |     def __init__(self, info_filename, train=True, transform=data_tf, target_transform=None, remove_blank=False):
146 |         super(Dataset, self).__init__()
147 |         self.transform = transform
148 |         self.target_transform = target_transform
149 |         self.info_filename = info_filename
150 |         if isinstance(self.info_filename, str):
151 |             self.info_filename = [self.info_filename]
152 |         self.train = train
153 |         self.files = list()
154 |         self.labels = list()
155 |         for info_name in self.info_filename:
156 |             with open(info_name) as f:
157 |                 content = f.readlines()
158 |                 for line in content:
159 |                     if '\t' in line:
160 |                         if len(line.split('\t')) != 2:
161 |                             print(line)
162 |                         fname, label = line.split('\t')
163 | 
164 |                     else:
165 |                         fname, label = line.split('g:')  # no tab: the file name ends with 'g' (.jpg/.png), so split on 'g:'
166 |                         fname += 'g'
167 |                     if remove_blank:
168 |                         label = label.strip()
169 |                     else:
170 |                         label = ' ' + label.strip() + ' '
171 |                     self.files.append(fname)
172 |                     self.labels.append(label)
173 | 
174 |     def name(self):
175 |         return 'MyDataset'
176 | 
177 |     def __getitem__(self, index):
178 |         # print(self.files[index])
179 | 
180 |         img = Image.open(self.files[index])
181 |         if self.transform is not None:
182 |             img = self.transform(img)
183 |         img = img.convert('L')
184 |         # target = torch.zeros(len(self.labels_min))
185 |         # target[self.labels_min.index(self.labels[index])] = 1
186 |         label = self.labels[index]
187 |         if self.target_transform is not None:
188 |             label = self.target_transform(label)
189 |         return (img, label)
190 | 
191 |     def __len__(self):
192 |         return len(self.labels)
193 | 
194 | class MyDatasetPro(Dataset):
195 |     def __init__(self, info_filename_txtline=list(), info_filename_fullimg=list(), train=True, txtline_transform=data_tf,
196 |                  fullimg_transform=data_tf_fullimg, target_transform=None):
197 |         super(Dataset, self).__init__()
198 |         self.txtline_transform = txtline_transform
199 |         self.fullimg_transform = fullimg_transform
200 |         self.target_transform = target_transform
201 |         self.info_filename_txtline = info_filename_txtline
202 |         self.info_filename_fullimg = info_filename_fullimg
203 |         if isinstance(self.info_filename_txtline, str):
204 |             self.info_filename_txtline = [self.info_filename_txtline]
205 |         if isinstance(self.info_filename_fullimg, str):
206 |             self.info_filename_fullimg = [self.info_filename_fullimg]
207 |         self.train = train
208 |         self.files = list()
209 |         self.labels = list()
210 |         self.locs = list()
211 |         for info_name in self.info_filename_txtline:
212 |             with open(info_name) as f:
213 |                 content = f.readlines()
214 |                 for line in content:
215 |                     fname, label = line.split('g:')
216 |                     fname += 'g'
217 |                     label = label.replace('\r', '').replace('\n', '')
218 |                     self.files.append(fname)
219 |                     self.labels.append(label)
220 |         self.txtline_len = len(self.labels)
221 |         for info_name in self.info_filename_fullimg:
222 |             with open(info_name) as f:
223 |                 content = f.readlines()
224 |                 for line in content:
225 |                     fname, label, left, top, right, bottom = line.strip().split('\t')
226 |                     self.files.append(fname)
227 |                     self.labels.append(label)
228 |                     self.locs.append([int(left), int(top), int(right), int(bottom)])
229 |         print(len(self.labels), len(self.files))
230 |     def name(self):
231 |         return 'MyDatasetPro'
232 | 
233 |     def __getitem__(self, index):
234 |         # print(self.files[index])
235 |         label = self.labels[index]
236 |         if self.target_transform is not None:
237 |             label = self.target_transform(label)
238 |         img = Image.open(self.files[index])
239 |         if index >= self.txtline_len:
240 |             # print('fullimg:{}'.format(self.files[index]), img.size)
241 |             img = self.fullimg_transform(img, self.locs[index - self.txtline_len])
242 |             if index % 100 == 0:
243 |                 img.save('test_imgs/debug-{}-{}.jpg'.format(index, label.strip()))  # debug
244 |         else:
245 |             if self.txtline_transform is not None:
246 |                 img = self.txtline_transform(img)
247 |         img = img.convert('L')
248 |         # target = torch.zeros(len(self.labels_min))
249 |         # target[self.labels_min.index(self.labels[index])] = 1
250 |         return (img, label)
251 | 
252 |     def __len__(self):
253 |         return len(self.labels)
254 | 
255 | class resizeNormalize2(object):
256 | 
257 |     def __init__(self, size, interpolation=Image.LANCZOS):
258 |         self.size = size
259 |         self.interpolation = interpolation
260 |         self.toTensor = transforms.ToTensor()
261 | 
262 |     def __call__(self, img):
263 |         img = img.resize(self.size, self.interpolation)
264 |         img = self.toTensor(img)
265 |         img.sub_(0.5).div_(0.5)  # normalize from [0, 1] to [-1, 1]
266 |         return img
267 | 
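# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# How the pieces above fit together for CTC training: MyDataset yields
# (PIL image, text) pairs, alignCollate (below) batches them into one tensor,
# and strLabelConverter from utils.py encodes the labels. 'train.txt', the
# alphabet list and the crnn model are placeholders, not repo fixtures.
#
# from torch.utils.data import DataLoader
# import utils
#
# dataset = MyDataset(info_filename='train.txt')
# loader = DataLoader(dataset, batch_size=32, shuffle=True,
#                     collate_fn=alignCollate(imgH=32, imgW=280))
# converter = utils.strLabelConverter(alphabet)
# for images, labels in loader:
#     text, lengths = converter.encode(labels)  # flat IntTensor + per-sample lengths
#     preds = crnn(images)                      # [w, b, nclass], ready for nn.CTCLoss
#     break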
268 | class resizeNormalize(object):
269 |     def __init__(self, size, interpolation=Image.LANCZOS, is_test=False):
270 |         self.size = size
271 |         self.interpolation = interpolation
272 |         self.toTensor = transforms.ToTensor()
273 |         self.is_test = is_test
274 | 
275 |     def __call__(self, img):
276 |         w, h = self.size
277 |         w0 = img.size[0]
278 |         h0 = img.size[1]
279 |         if w <= (w0 / h0 * h):
280 |             img = img.resize(self.size, self.interpolation)
281 |             img = self.toTensor(img)
282 |             img.sub_(0.5).div_(0.5)
283 |         else:
284 |             w_real = int(w0 / h0 * h)  # keep the aspect ratio, then pad to the target width
285 |             img = img.resize((w_real, h), self.interpolation)
286 |             img = self.toTensor(img)
287 |             img.sub_(0.5).div_(0.5)
288 |             start = random.randint(0, w - w_real - 1)
289 |             if self.is_test:
290 |                 start = 5
291 |                 w += 10
292 |             tmp = torch.zeros([img.shape[0], h, w]) + 0.5
293 |             tmp[:, :, start:start + w_real] = img
294 |             img = tmp
295 |         return img
296 | 
297 | class randomSequentialSampler(sampler.Sampler):
298 | 
299 |     def __init__(self, data_source, batch_size):
300 |         self.num_samples = len(data_source)
301 |         self.batch_size = batch_size
302 | 
303 |     def __iter__(self):
304 |         n_batch = len(self) // self.batch_size
305 |         tail = len(self) % self.batch_size
306 |         index = torch.LongTensor(len(self)).fill_(0)
307 |         for i in range(n_batch):
308 |             random_start = random.randint(0, len(self) - self.batch_size)
309 |             batch_index = random_start + torch.arange(0, self.batch_size)  # torch.range was removed; arange excludes the end point
310 |             index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
311 |         # deal with tail
312 |         if tail:
313 |             random_start = random.randint(0, len(self) - self.batch_size)
314 |             tail_index = random_start + torch.arange(0, tail)
315 |             index[(i + 1) * self.batch_size:] = tail_index
316 | 
317 |         return iter(index)
318 | 
319 |     def __len__(self):
320 |         return self.num_samples
321 | 
322 | 
323 | class alignCollate(object):
324 | 
325 |     def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
326 |         self.imgH = imgH
327 |         self.imgW = imgW
328 |         self.keep_ratio = keep_ratio
329 |         self.min_ratio = min_ratio
330 | 
331 |     def __call__(self, batch):
332 |         images, labels = zip(*batch)
333 | 
334 |         imgH = self.imgH
335 |         imgW = self.imgW
336 |         if self.keep_ratio:
337 |             ratios = []
338 |             for image in images:
339 |                 w, h = image.size
340 |                 ratios.append(w / float(h))
341 |             ratios.sort()
342 |             max_ratio = ratios[-1]
343 |             imgW = int(np.floor(max_ratio * imgH))
344 |             imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH * min_ratio
345 | 
346 |         transform = resizeNormalize((imgW, imgH))
347 |         images = [transform(image) for image in images]
348 |         images = torch.cat([t.unsqueeze(0) for t in images], 0)
349 | 
350 |         return images, labels
351 | 
352 | 
353 | 
354 | if __name__ == '__main__':
355 |     import os
356 |     path = 'images'
357 |     files = os.listdir(path)
358 |     idx = 0
359 |     for f in files:
360 |         img_name = os.path.join(path, f)
361 |         img = Image.open(img_name)
362 |         img.show()
363 |         idx += 1
364 |         if idx > 5:
365 |             break
--------------------------------------------------------------------------------
/detect/ctpn_utils.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-11 10:05 AM
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import numpy as np
8 | import cv2
9 | from detect.config import *
10 | 
11 | 
12 | def resize(image, width=None, height=None, inter=cv2.INTER_AREA):
13 |     # initialize the dimensions of the image to be resized and
14 |     # grab the image size
15 |     dim = None
16 |     (h, w) = image.shape[:2]
17 | 
18 |     # if both the width and height are None, then return the
19 |     # original image
20 |     if width is None and height is None:
21 |         return image
22 | 
23 |     # check to see if the width is None
24 |     if width is None:
25 |         # calculate the ratio of the height and construct the
26 |         # dimensions
27 |         r = height / float(h)
28 |         dim = (int(w * r), height)
29 | 
30 |     # otherwise, the height is None
31 |     else:
32 |         # calculate the ratio of the width and construct the
33 |         # dimensions
34 |         r = width / float(w)
35 |         dim = (width, int(h * r))
36 | 
37 |     # resize the image
38 |     resized = cv2.resize(image, dim, interpolation=inter)
39 | 
40 |     # return the resized image
41 |     return resized
42 | 
43 | 
44 | def gen_anchor(featuresize, scale):
45 |     """
46 |     gen base anchors from feature map [HxW][10][4]
47 |     reshape [HxW][10][4] to [HxWx10][4]
48 |     """
49 |     heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
50 |     widths = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
51 | 
52 |     # gen k=10 anchor sizes (h, w)
53 |     heights = np.array(heights).reshape(len(heights), 1)
54 |     widths = np.array(widths).reshape(len(widths), 1)
55 | 
56 |     base_anchor = np.array([0, 0, 15, 15])
57 |     # center x,y
58 |     xt = (base_anchor[0] + base_anchor[2]) * 0.5
59 |     yt = (base_anchor[1] + base_anchor[3]) * 0.5
60 | 
61 |     # x1 y1 x2 y2
62 |     x1 = xt - widths * 0.5
63 |     y1 = yt - heights * 0.5
64 |     x2 = xt + widths * 0.5
65 |     y2 = yt + heights * 0.5
66 |     base_anchor = np.hstack((x1, y1, x2, y2))
67 | 
68 |     h, w = featuresize
69 |     shift_x = np.arange(0, w) * scale
70 |     shift_y = np.arange(0, h) * scale
71 |     # apply shift
72 |     anchor = []
73 |     for i in shift_y:
74 |         for j in shift_x:
75 |             anchor.append(base_anchor + [j, i, j, i])
76 |     return np.array(anchor).reshape((-1, 4))
77 | 
78 | 
79 | def cal_iou(box1, box1_area, boxes2, boxes2_area):
80 |     """
81 |     box1 [x1,y1,x2,y2]
82 |     boxes2 [Msample,x1,y1,x2,y2]
83 |     """
84 |     x1 = np.maximum(box1[0], boxes2[:, 0])
85 |     x2 = np.minimum(box1[2], boxes2[:, 2])
86 |     y1 = np.maximum(box1[1], boxes2[:, 1])
87 |     y2 = np.minimum(box1[3], boxes2[:, 3])
88 | 
89 |     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
90 |     iou = intersection / (box1_area + boxes2_area[:] - intersection[:])
91 |     return iou
92 | 
93 | 
94 | def cal_overlaps(boxes1, boxes2):
95 |     """
96 |     boxes1 [Nsample,x1,y1,x2,y2] anchor
97 |     boxes2 [Msample,x1,y1,x2,y2] ground-truth box
98 | 
99 |     """
100 |     area1 = (boxes1[:, 0] - boxes1[:, 2]) * (boxes1[:, 1] - boxes1[:, 3])
101 |     area2 = (boxes2[:, 0] - boxes2[:, 2]) * (boxes2[:, 1] - boxes2[:, 3])
102 | 
103 |     overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
104 | 
105 |     # calculate the intersection of boxes1(anchor) and boxes2(GT box)
106 |     for i in range(boxes1.shape[0]):
107 |         overlaps[i][:] = cal_iou(boxes1[i], area1[i], boxes2, area2)
108 | 
109 |     return overlaps
110 | 
111 | 
112 | def bbox_transfrom(anchors, gtboxes):
113 |     """
114 |     compute relative predicted vertical coordinates Vc, Vh
115 |     with respect to the bounding box location of an anchor
116 |     """
117 |     regr = np.zeros((anchors.shape[0], 2))
118 |     Cy = (gtboxes[:, 1] + gtboxes[:, 3]) * 0.5
119 |     Cya = (anchors[:, 1] + anchors[:, 3]) * 0.5
120 |     h = gtboxes[:, 3] - gtboxes[:, 1] + 1.0
121 |     ha = anchors[:, 3] - anchors[:, 1] + 1.0
122 | 
123 |     Vc = (Cy - Cya) / ha
124 |     Vh = np.log(h / ha)
125 | 
126 |     return np.vstack((Vc, Vh)).transpose()
127 | 
128 | 
129 | def bbox_transfor_inv(anchor, regr):
130 |     """
131 |     return predict bbox
132 |     """
133 | 
134 |     Cya = (anchor[:, 1] + anchor[:, 3]) * 0.5
135 |     ha = anchor[:, 3] - anchor[:, 1] + 1
136 | 
137 |     Vcx = regr[0, :, 0]
138 |     Vhx = regr[0, :, 1]
139 | 
140 |     Cyx = Vcx * ha + Cya
141 |     hx = np.exp(Vhx) * ha
142 |     xt = (anchor[:, 0] + anchor[:, 2]) * 0.5
143 | 
144 |     x1 = xt - 16 * 0.5
145 |     y1 = Cyx - hx * 0.5
146 |     x2 = xt + 16 * 0.5
147 |     y2 = Cyx + hx * 0.5
148 |     bbox = np.vstack((x1, y1, x2, y2)).transpose()
149 | 
150 |     return bbox
151 | 
152 | 
153 | def clip_box(bbox, im_shape):
154 |     # x1 >= 0
155 |     bbox[:, 0] = np.maximum(np.minimum(bbox[:, 0], im_shape[1] - 1), 0)
156 |     # y1 >= 0
157 |     bbox[:, 1] = np.maximum(np.minimum(bbox[:, 1], im_shape[0] - 1), 0)
158 |     # x2 < im_shape[1]
159 |     bbox[:, 2] = np.maximum(np.minimum(bbox[:, 2], im_shape[1] - 1), 0)
160 |     # y2 < im_shape[0]
161 |     bbox[:, 3] = np.maximum(np.minimum(bbox[:, 3], im_shape[0] - 1), 0)
162 | 
163 |     return bbox
164 | 
165 | 
166 | def filter_bbox(bbox, minsize):
167 |     ws = bbox[:, 2] - bbox[:, 0] + 1
168 |     hs = bbox[:, 3] - bbox[:, 1] + 1
169 |     keep = np.where((ws >= minsize) & (hs >= minsize))[0]
170 |     return keep
171 | 
172 | 
173 | def cal_rpn(imgsize, featuresize, scale, gtboxes):
174 |     imgh, imgw = imgsize
175 | 
176 |     # gen base anchor
177 |     base_anchor = gen_anchor(featuresize, scale)
178 | 
179 |     # calculate iou
180 |     overlaps = cal_overlaps(base_anchor, gtboxes)
181 | 
182 |     # init labels: -1 don't care, 0 is negative, 1 is positive
183 |     labels = np.empty(base_anchor.shape[0])
184 |     labels.fill(-1)
185 | 
186 |     # for each GT box, the anchor with the highest IOU
187 |     gt_argmax_overlaps = overlaps.argmax(axis=0)
188 | 
189 |     # for each anchor, the GT box with the highest IOU overlap
190 |     anchor_argmax_overlaps = overlaps.argmax(axis=1)
191 |     anchor_max_overlaps = overlaps[range(overlaps.shape[0]), anchor_argmax_overlaps]
192 | 
193 |     # IOU > IOU_POSITIVE
194 |     labels[anchor_max_overlaps > IOU_POSITIVE] = 1
195 |     # IOU < IOU_NEGATIVE
196 |     labels[anchor_max_overlaps < IOU_NEGATIVE] = 0
197 |     # ensure that every GT box has at least one positive RPN region
198 |     labels[gt_argmax_overlaps] = 1
199 | 
200 |     # only keep anchors inside the image
201 |     outside_anchor = np.where(
202 |         (base_anchor[:, 0] < 0) |
203 |         (base_anchor[:, 1] < 0) |
204 |         (base_anchor[:, 2] >= imgw) |
205 |         (base_anchor[:, 3] >= imgh)
206 |     )[0]
207 |     labels[outside_anchor] = -1
208 | 
209 |     # subsample positive labels, if greater than RPN_POSITIVE_NUM (default 128)
210 |     fg_index = np.where(labels == 1)[0]
211 |     if (len(fg_index) > RPN_POSITIVE_NUM):
212 |         labels[np.random.choice(fg_index, len(fg_index) - RPN_POSITIVE_NUM, replace=False)] = -1
213 | 
214 |     # subsample negative labels
215 |     bg_index = np.where(labels == 0)[0]
216 |     num_bg = RPN_TOTAL_NUM - np.sum(labels == 1)
217 |     if (len(bg_index) > num_bg):
218 |         # print('bgindex:',len(bg_index),'num_bg',num_bg)
219 |         labels[np.random.choice(bg_index, len(bg_index) - num_bg, replace=False)] = -1
220 | 
221 |     # calculate bbox targets
222 |     # debug here
223 |     bbox_targets = bbox_transfrom(base_anchor, gtboxes[anchor_argmax_overlaps, :])
224 |     # bbox_targets=[]
225 | 
226 |     return [labels, bbox_targets], base_anchor
227 | 
228 | 
229 | def nms(dets, thresh):
230 |     x1 = dets[:, 0]
231 |     y1 = dets[:, 1]
232 |     x2 = dets[:, 2]
233 |     y2 = dets[:, 3]
234 |     scores = dets[:, 4]
235 | 
236 |     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
237 |     order = scores.argsort()[::-1]
238 | 
239 |     keep = []
240 |     while order.size > 0:
241 |         i = order[0]
242 |         keep.append(i)
243 |         xx1 = np.maximum(x1[i], x1[order[1:]])
244 |         yy1 = np.maximum(y1[i], y1[order[1:]])
245 |         xx2 = np.minimum(x2[i], x2[order[1:]])
246 |         yy2 = np.minimum(y2[i], y2[order[1:]])
247 | 
248 |         w = np.maximum(0.0, xx2 - xx1 + 1)
249 |         h = np.maximum(0.0, yy2 - yy1 + 1)
250 |         inter = w * h
251 |         ovr = inter / (areas[i] + areas[order[1:]] - inter)
252 | 
253 |         inds = np.where(ovr <= thresh)[0]
254 |         order = order[inds + 1]
255 |     return keep
256 | 
257 | 
258 | # for predict
259 | class Graph:
260 |     def __init__(self, graph):
261 |         self.graph = graph
262 | 
263 |     def sub_graphs_connected(self):
264 |         sub_graphs = []
265 |         for index in range(self.graph.shape[0]):
266 |             if not self.graph[:, index].any() and self.graph[index, :].any():
267 |                 v = index
268 |                 sub_graphs.append([v])
269 |                 while self.graph[v, :].any():
270 |                     v = np.where(self.graph[v, :])[0][0]
271 |                     sub_graphs[-1].append(v)
272 |         return sub_graphs
273 | 
274 | 
275 | class TextLineCfg:
276 |     SCALE = 600
277 |     MAX_SCALE = 1200
278 |     TEXT_PROPOSALS_WIDTH = 16
279 |     MIN_NUM_PROPOSALS = 2
280 |     MIN_RATIO = 0.5
281 |     LINE_MIN_SCORE = 0.9
282 |     MAX_HORIZONTAL_GAP = 60
283 |     TEXT_PROPOSALS_MIN_SCORE = 0.7
284 |     TEXT_PROPOSALS_NMS_THRESH = 0.3
285 |     MIN_V_OVERLAPS = 0.6
286 |     MIN_SIZE_SIM = 0.6
287 | 
288 | 
289 | class TextProposalGraphBuilder:
290 |     """
291 |     Build text proposals into a graph.
292 |     """
293 | 
294 |     def get_successions(self, index):
295 |         box = self.text_proposals[index]
296 |         results = []
297 |         for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
298 |             adj_box_indices = self.boxes_table[left]
299 |             for adj_box_index in adj_box_indices:
300 |                 if self.meet_v_iou(adj_box_index, index):
301 |                     results.append(adj_box_index)
302 |             if len(results) != 0:
303 |                 return results
304 |         return results
305 | 
306 |     def get_precursors(self, index):
307 |         box = self.text_proposals[index]
308 |         results = []
309 |         for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
310 |             adj_box_indices = self.boxes_table[left]
311 |             for adj_box_index in adj_box_indices:
312 |                 if self.meet_v_iou(adj_box_index, index):
313 |                     results.append(adj_box_index)
314 |             if len(results) != 0:
315 |                 return results
316 |         return results
317 | 
318 |     def is_succession_node(self, index, succession_index):
319 |         precursors = self.get_precursors(succession_index)
320 |         if self.scores[index] >= np.max(self.scores[precursors]):
321 |             return True
322 |         return False
323 | 
324 |     def meet_v_iou(self, index1, index2):
325 |         def overlaps_v(index1, index2):
326 |             h1 = self.heights[index1]
327 |             h2 = self.heights[index2]
328 |             y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
329 |             y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
330 |             return max(0, y1 - y0 + 1) / min(h1, h2)
331 | 
332 |         def size_similarity(index1, index2):
333 |             h1 = self.heights[index1]
334 |             h2 = self.heights[index2]
335 |             return min(h1, h2) / max(h1, h2)
336 | 
337 |         return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
338 |                size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM
339 | 
340 |     def build_graph(self, text_proposals, scores, im_size):
341 |         self.text_proposals = text_proposals
342 |         self.scores = scores
343 |         self.im_size = im_size
344 |         self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1
345 | 
346 |         boxes_table = [[] for _ in range(self.im_size[1])]
347 |         for index, box in enumerate(text_proposals):
348 |             boxes_table[int(box[0])].append(index)
349 |         self.boxes_table = boxes_table
350 | 
351 |         graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)  # np.bool was removed from numpy 1.24
352 | 
353 |         for index, box in enumerate(text_proposals):
354 |             successions = self.get_successions(index)
355 |             if len(successions) == 0:
356 |                 continue
357 |             succession_index = successions[np.argmax(scores[successions])]
358 |             if self.is_succession_node(index, succession_index):
359 |                 # NOTE: a box can have multiple successions(precursors) if multiple successions(precursors)
360 |                 # have equal scores.
361 |                 graph[index, succession_index] = True
362 |         return Graph(graph)
363 | 
364 | 
365 | class TextProposalConnectorOriented:
366 |     """
367 |     Connect text proposals into text lines
368 |     """
369 | 
370 |     def __init__(self):
371 |         self.graph_builder = TextProposalGraphBuilder()
372 | 
373 |     def group_text_proposals(self, text_proposals, scores, im_size):
374 |         graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
375 |         return graph.sub_graphs_connected()
376 | 
377 |     def fit_y(self, X, Y, x1, x2):
378 |         # len(X) != 0
379 |         # if X only includes one point, the function will return the line y=Y[0]
380 |         if np.sum(X == X[0]) == len(X):
381 |             return Y[0], Y[0]
382 |         p = np.poly1d(np.polyfit(X, Y, 1))
383 |         return p(x1), p(x2)
384 | 
385 |     def get_text_lines(self, text_proposals, scores, im_size):
386 |         """
387 |         text_proposals: boxes
388 | 
389 |         """
390 |         # tp = text proposal
391 |         tp_groups = self.group_text_proposals(text_proposals, scores, im_size)  # build the graph to find which proposals form each text line
392 | 
393 |         text_lines = np.zeros((len(tp_groups), 8), np.float32)
394 | 
395 |         for index, tp_indices in enumerate(tp_groups):
396 |             text_line_boxes = text_proposals[list(tp_indices)]  # all proposals of this text line
397 |             X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # center (x, y) of each proposal
398 |             Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2
399 | 
400 |             z1 = np.polyfit(X, Y, 1)  # least-squares line fit through the proposal centers
401 | 
402 |             x0 = np.min(text_line_boxes[:, 0])  # smallest x of the text line
403 |             x1 = np.max(text_line_boxes[:, 2])  # largest x of the text line
404 | 
405 |             offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half of a proposal width
406 | 
407 |             # fit a line through the top-left corners of all proposals, then take its y at the extreme left/right x
408 |             lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
409 |             # same with the bottom-left corners
410 |             lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)
411 | 
412 |             score = scores[list(tp_indices)].sum() / float(len(tp_indices))  # mean proposal score as the text-line score
413 | 
414 |             text_lines[index, 0] = x0
415 |             text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the line's top edge
416 |             text_lines[index, 2] = x1
417 |             text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the line's bottom edge
418 |             text_lines[index, 4] = score  # text-line score
419 |             text_lines[index, 5] = z1[0]  # slope k and intercept b of the center-fit line
420 |             text_lines[index, 6] = z1[1]
421 |             height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean proposal height
422 |             text_lines[index, 7] = height + 2.5
423 | 
424 |         text_recs = np.zeros((len(text_lines), 9), np.float32)  # np.float was removed from numpy 1.24
425 |         index = 0
426 |         for line in text_lines:
427 |             b1 = line[6] - line[7] / 2  # from the height and the center line, get the intercepts of the top and bottom lines
428 |             b2 = line[6] + line[7] / 2
429 |             x1 = line[0]
430 |             y1 = line[5] * line[0] + b1  # top-left
431 |             x2 = line[2]
432 |             y2 = line[5] * line[2] + b1  # top-right
433 |             x3 = line[0]
434 |             y3 = line[5] * line[0] + b2  # bottom-left
435 |             x4 = line[2]
436 |             y4 = line[5] * line[2] + b2  # bottom-right
437 |             disX = x2 - x1
438 |             disY = y2 - y1
439 |             width = np.sqrt(disX * disX + disY * disY)  # text-line width
440 | 
441 |             fTmp0 = y3 - y1  # text-line height
442 |             fTmp1 = fTmp0 * disY / width
443 |             x = np.fabs(fTmp1 * disX / width)  # compensation
444 |             y = np.fabs(fTmp1 * disY / width)
445 |             if line[5] < 0:
446 |                 x1 -= x
447 |                 y1 += y
448 |                 x4 += x
449 |                 y4 -= y
450 |             else:
451 |                 x2 += x
452 |                 y2 += y
453 |                 x3 -= x
454 |                 y3 -= y
455 |             text_recs[index, 0] = x1
456 |             text_recs[index, 1] = y1
457 |             text_recs[index, 2] = x2
458 |             text_recs[index, 3] = y2
459 |             text_recs[index, 4] = x3
460 |             text_recs[index, 5] = y3
461 |             text_recs[index, 6] = x4
462 |             text_recs[index, 7] = y4
463 |             text_recs[index, 8] = line[4]
464 |             index = index + 1
465 | 
466 |         return text_recs
467 | 
468 | 
--------------------------------------------------------------------------------
/train_code/train_ctpn/ctpn_utils.py:
--------------------------------------------------------------------------------
1 | #-*- coding:utf-8 -*-
2 | #'''
3 | # Created on 18-12-11 10:05 AM
4 | #
5 | # @Author: Greg Gao(laygin)
6 | #'''
7 | import numpy as np
8 | import cv2
9 | from config import *
10 | 
11 | 
12 | def resize(image, width=None, height=None, inter=cv2.INTER_AREA):
13 |     # initialize the dimensions of the image to be resized and
14 |     # grab the image size
15 |     dim = None
16 |     (h, w) = image.shape[:2]
17 | 
18 |     # if both the width and height are None, then return the
19 |     # original image
20 |     if width is None and height is None:
21 |         return image
22 | 
23 |     # check to see if the width is None
24 |     if width is None:
25 |         # calculate the ratio of the height and construct the
26 |         # dimensions
27 |         r = height / float(h)
28 |         dim = (int(w * r), height)
29 | 
30 |     # otherwise, the height is None
31 |     else:
32 |         # calculate the ratio of the width and construct the
33 |         # dimensions
34 |         r = width / float(w)
35 |         dim = (width, int(h * r))
36 | 
37 |     # resize the image
38 |     resized = cv2.resize(image, dim, interpolation=inter)
39 | 
40 |     # return the resized image
41 |     return resized
42 | 
43 | 
44 | def gen_anchor(featuresize, scale):
45 |     """
46 |     gen base anchors from feature map [HxW][10][4]
47 |     reshape [HxW][10][4] to [HxWx10][4]
48 |     """
49 |     heights = [11, 16, 23, 33, 48, 68, 97, 139, 198, 283]
50 |     widths = [16, 16, 16, 16, 16, 16, 16, 16, 16, 16]
51 | 
52 |     # gen k=10 anchor sizes (h, w)
53 |     heights = np.array(heights).reshape(len(heights), 1)
54 |     widths = np.array(widths).reshape(len(widths), 1)
55 | 
56 |     base_anchor = np.array([0, 0, 15, 15])
57 |     # center x,y
58 |     xt = (base_anchor[0] + base_anchor[2]) * 0.5
59 |     yt = (base_anchor[1] + base_anchor[3]) * 0.5
60 | 
61 |     # x1 y1 x2 y2
62 |     x1 = xt - widths * 0.5
63 |     y1 = yt - heights * 0.5
64 |     x2 = xt + widths * 0.5
65 |     y2 = yt + heights * 0.5
66 |     base_anchor = np.hstack((x1, y1, x2, y2))
67 | 
68 |     h, w = featuresize
69 |     shift_x = np.arange(0, w) * scale
70 |     shift_y = np.arange(0, h) * scale
71 |     # apply shift
72 |     anchor = []
73 |     for i in shift_y:
74 |         for j in shift_x:
75 |             anchor.append(base_anchor + [j, i, j, i])
76 |     return np.array(anchor).reshape((-1, 4))
77 | 
78 | 
79 | def cal_iou(box1, box1_area, boxes2, boxes2_area):
80 |     """
81 |     box1 [x1,y1,x2,y2]
82 |     boxes2 [Msample,x1,y1,x2,y2]
83 |     """
84 |     x1 = np.maximum(box1[0], boxes2[:, 0])
85 |     x2 = np.minimum(box1[2], boxes2[:, 2])
86 |     y1 = np.maximum(box1[1], boxes2[:, 1])
87 |     y2 = np.minimum(box1[3], boxes2[:, 3])
88 | 
89 |     intersection = np.maximum(x2 - x1, 0) * np.maximum(y2 - y1, 0)
90 |     iou = intersection / (box1_area + boxes2_area[:] - intersection[:])
91 |     return iou
92 | 
93 | 
94 | def cal_overlaps(boxes1, boxes2):
95 |     """
96 |     boxes1 [Nsample,x1,y1,x2,y2] anchor
97 |     boxes2 [Msample,x1,y1,x2,y2] ground-truth box
98 | 
99 |     """
100 |     area1 = (boxes1[:, 0] - boxes1[:, 2]) * (boxes1[:, 1] - boxes1[:, 3])
101 |     area2 = (boxes2[:, 0] - boxes2[:, 2]) * (boxes2[:, 1] - boxes2[:, 3])
102 | 
103 |     overlaps = np.zeros((boxes1.shape[0], boxes2.shape[0]))
104 | 
105 |     # calculate the intersection of boxes1(anchor) and boxes2(GT box)
106 |     for i in range(boxes1.shape[0]):
107 |         overlaps[i][:] = cal_iou(boxes1[i], area1[i], boxes2, area2)
108 | 
109 |     return overlaps
110 | 
111 | 
112 | def bbox_transfrom(anchors, gtboxes):
113 |     """
114 |     compute relative predicted vertical coordinates Vc, Vh
115 |     with respect to the bounding box location of an anchor
116 |     """
117 |     regr = np.zeros((anchors.shape[0], 2))
118 |     Cy = (gtboxes[:, 1] + gtboxes[:, 3]) * 0.5
119 |     Cya = (anchors[:, 1] + anchors[:, 3]) * 0.5
120 |     h = gtboxes[:, 3] - gtboxes[:, 1] + 1.0
121 |     ha = anchors[:, 3] - anchors[:, 1] + 1.0
122 | 
123 |     Vc = (Cy - Cya) / ha
124 |     Vh = np.log(h / ha)
125 | 
126 |     return np.vstack((Vc, Vh)).transpose()
127 | 
128 | 
129 | def bbox_transfor_inv(anchor, regr):
130 |     """
131 |     return predict bbox
132 |     """
133 | 
134 |     Cya = (anchor[:, 1] + anchor[:, 3]) * 0.5
135 |     ha = anchor[:, 3] - anchor[:, 1] + 1
136 | 
137 |     Vcx = regr[0, :, 0]
138 |     Vhx = regr[0, :, 1]
139 | 
140 |     Cyx = Vcx * ha + Cya
141 |     hx = np.exp(Vhx) * ha
142 |     xt = (anchor[:, 0] + anchor[:, 2]) * 0.5
143 | 
144 |     x1 = xt - 16 * 0.5
145 |     y1 = Cyx - hx * 0.5
146 |     x2 = xt + 16 * 0.5
147 |     y2 = Cyx + hx * 0.5
148 |     bbox = np.vstack((x1, y1, x2, y2)).transpose()
149 | 
150 |     return bbox
151 | 
152 | 
153 | def clip_box(bbox, im_shape):
154 |     # x1 >= 0
155 |     bbox[:, 0] = np.maximum(np.minimum(bbox[:, 0], im_shape[1] - 1), 0)
156 |     # y1 >= 0
157 |     bbox[:, 1] = np.maximum(np.minimum(bbox[:, 1], im_shape[0] - 1), 0)
158 |     # x2 < im_shape[1]
159 |     bbox[:, 2] = np.maximum(np.minimum(bbox[:, 2], im_shape[1] - 1), 0)
160 |     # y2 < im_shape[0]
161 |     bbox[:, 3] = np.maximum(np.minimum(bbox[:, 3], im_shape[0] - 1), 0)
162 | 
163 |     return bbox
164 | 
165 | 
166 | def filter_bbox(bbox, minsize):
167 |     ws = bbox[:, 2] - bbox[:, 0] + 1
168 |     hs = bbox[:, 3] - bbox[:, 1] + 1
169 |     keep = np.where((ws >= minsize) & (hs >= minsize))[0]
170 |     return keep
171 | 
172 | 
173 | def cal_rpn(imgsize, featuresize, scale, gtboxes):
174 |     imgh, imgw = imgsize
175 | 
176 |     # gen base anchor
177 |     base_anchor = gen_anchor(featuresize, scale)
178 | 
179 |     # calculate iou
180 |     overlaps = cal_overlaps(base_anchor, gtboxes)
181 | 
182 |     # init labels: -1 don't care, 0 is negative, 1 is positive
183 |     labels = np.empty(base_anchor.shape[0])
184 |     labels.fill(-1)
185 | 
186 |     # for each GT box, the anchor with the highest IOU
187 |     gt_argmax_overlaps = overlaps.argmax(axis=0)
188 | 
189 |     # for each anchor, the GT box with the highest IOU overlap
190 |     anchor_argmax_overlaps = overlaps.argmax(axis=1)
191 |     anchor_max_overlaps = overlaps[range(overlaps.shape[0]), anchor_argmax_overlaps]
192 | 
193 |     # IOU > IOU_POSITIVE
194 |     labels[anchor_max_overlaps > IOU_POSITIVE] = 1
195 |     # IOU < IOU_NEGATIVE
196 |     labels[anchor_max_overlaps < IOU_NEGATIVE] = 0
197 |     # ensure that every GT box has at least one positive RPN region
198 |     labels[gt_argmax_overlaps] = 1
199 | 
200 |     # only keep anchors inside the image
201 |     outside_anchor = np.where(
202 |         (base_anchor[:, 0] < 0) |
203 |         (base_anchor[:, 1] < 0) |
204 |         (base_anchor[:, 2] >= imgw) |
205 |         (base_anchor[:, 3] >= imgh)
206 |     )[0]
207 |     labels[outside_anchor] = -1
208 | 
209 |     # subsample positive labels, if greater than RPN_POSITIVE_NUM (default 128)
210 |     fg_index = np.where(labels == 1)[0]
211 |     # print(len(fg_index))
212 |     if (len(fg_index) > RPN_POSITIVE_NUM):
213 |         labels[np.random.choice(fg_index, len(fg_index) - RPN_POSITIVE_NUM, replace=False)] = -1
214 | 
215 |     # subsample negative labels (skipped when OHEM picks hard negatives instead)
216 |     if not OHEM:
217 |         bg_index = np.where(labels == 0)[0]
218 |         num_bg = RPN_TOTAL_NUM - np.sum(labels == 1)
219 |         if (len(bg_index) > num_bg):
220 |             # print('bgindex:',len(bg_index),'num_bg',num_bg)
221 |             labels[np.random.choice(bg_index, len(bg_index) - num_bg, replace=False)] = -1
222 | 
223 |     # calculate bbox targets
224 |     # debug here
225 |     bbox_targets = bbox_transfrom(base_anchor, gtboxes[anchor_argmax_overlaps, :])
226 |     # bbox_targets=[]
227 |     # print(len(labels),len(bbox_targets),len(base_anchor),base_anchor[0],labels[0])
228 | 
229 |     return [labels, bbox_targets], base_anchor
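# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# cal_rpn() pairs every anchor with its best ground-truth box and returns the
# CTPN training targets: per-anchor labels (1 positive / 0 negative / -1 ignore)
# and (Vc, Vh) regression targets. The stride of 16 matches the VGG16 conv5
# feature map. The image size and gtboxes below are made up for illustration.
#
# imgh, imgw = 608, 1088
# feat_h, feat_w = imgh // 16, imgw // 16
# gtboxes = np.array([[48, 100, 400, 130]], dtype=np.float32)  # x1, y1, x2, y2
# [labels, bbox_targets], base_anchor = cal_rpn((imgh, imgw), (feat_h, feat_w), 16, gtboxes)
# assert labels.shape[0] == feat_h * feat_w * 10  # 10 anchor heights per feature-map cell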
230 | 231 | 232 | def nms(dets, thresh): 233 | x1 = dets[:, 0] 234 | y1 = dets[:, 1] 235 | x2 = dets[:, 2] 236 | y2 = dets[:, 3] 237 | scores = dets[:, 4] 238 | 239 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 240 | order = scores.argsort()[::-1] 241 | 242 | keep = [] 243 | while order.size > 0: 244 | i = order[0] 245 | keep.append(i) 246 | xx1 = np.maximum(x1[i], x1[order[1:]]) 247 | yy1 = np.maximum(y1[i], y1[order[1:]]) 248 | xx2 = np.minimum(x2[i], x2[order[1:]]) 249 | yy2 = np.minimum(y2[i], y2[order[1:]]) 250 | 251 | w = np.maximum(0.0, xx2 - xx1 + 1) 252 | h = np.maximum(0.0, yy2 - yy1 + 1) 253 | inter = w * h 254 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 255 | 256 | inds = np.where(ovr <= thresh)[0] 257 | order = order[inds + 1] 258 | return keep 259 | 260 | 261 | # for predict 262 | class Graph: 263 | def __init__(self, graph): 264 | self.graph = graph 265 | 266 | def sub_graphs_connected(self): 267 | sub_graphs = [] 268 | for index in range(self.graph.shape[0]): 269 | if not self.graph[:, index].any() and self.graph[index, :].any(): 270 | v = index 271 | sub_graphs.append([v]) 272 | while self.graph[v, :].any(): 273 | v = np.where(self.graph[v, :])[0][0] 274 | sub_graphs[-1].append(v) 275 | return sub_graphs 276 | 277 | 278 | class TextLineCfg: 279 | SCALE = 600 280 | MAX_SCALE = 1200 281 | TEXT_PROPOSALS_WIDTH = 16 282 | MIN_NUM_PROPOSALS = 2 283 | MIN_RATIO = 0.5 284 | LINE_MIN_SCORE = 0.9 285 | MAX_HORIZONTAL_GAP = 60 286 | TEXT_PROPOSALS_MIN_SCORE = 0.7 287 | TEXT_PROPOSALS_NMS_THRESH = 0.3 288 | MIN_V_OVERLAPS = 0.6 289 | MIN_SIZE_SIM = 0.6 290 | 291 | 292 | class TextProposalGraphBuilder: 293 | """ 294 | Build Text proposals into a graph. 295 | """ 296 | 297 | def get_successions(self, index): 298 | box = self.text_proposals[index] 299 | results = [] 300 | for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])): 301 | adj_box_indices = self.boxes_table[left] 302 | for adj_box_index in adj_box_indices: 303 | if self.meet_v_iou(adj_box_index, index): 304 | results.append(adj_box_index) 305 | if len(results) != 0: 306 | return results 307 | return results 308 | 309 | def get_precursors(self, index): 310 | box = self.text_proposals[index] 311 | results = [] 312 | for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1): 313 | adj_box_indices = self.boxes_table[left] 314 | for adj_box_index in adj_box_indices: 315 | if self.meet_v_iou(adj_box_index, index): 316 | results.append(adj_box_index) 317 | if len(results) != 0: 318 | return results 319 | return results 320 | 321 | def is_succession_node(self, index, succession_index): 322 | precursors = self.get_precursors(succession_index) 323 | if self.scores[index] >= np.max(self.scores[precursors]): 324 | return True 325 | return False 326 | 327 | def meet_v_iou(self, index1, index2): 328 | def overlaps_v(index1, index2): 329 | h1 = self.heights[index1] 330 | h2 = self.heights[index2] 331 | y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1]) 332 | y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3]) 333 | return max(0, y1 - y0 + 1) / min(h1, h2) 334 | 335 | def size_similarity(index1, index2): 336 | h1 = self.heights[index1] 337 | h2 = self.heights[index2] 338 | return min(h1, h2) / max(h1, h2) 339 | 340 | return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \ 341 | size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM 342 | 343 | def build_graph(self, text_proposals, scores, 
class TextProposalGraphBuilder:
    """
    Build text proposals into a graph.
    """

    def get_successions(self, index):
        # Scan the columns up to MAX_HORIZONTAL_GAP pixels to the right of this
        # box; the first column that yields vertically compatible boxes wins.
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) + 1, min(int(box[0]) + TextLineCfg.MAX_HORIZONTAL_GAP + 1, self.im_size[1])):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results
        return results

    def get_precursors(self, index):
        # Mirror image of get_successions: scan the columns to the left.
        box = self.text_proposals[index]
        results = []
        for left in range(int(box[0]) - 1, max(int(box[0] - TextLineCfg.MAX_HORIZONTAL_GAP), 0) - 1, -1):
            adj_box_indices = self.boxes_table[left]
            for adj_box_index in adj_box_indices:
                if self.meet_v_iou(adj_box_index, index):
                    results.append(adj_box_index)
            if len(results) != 0:
                return results
        return results

    def is_succession_node(self, index, succession_index):
        # index is a valid predecessor only if no precursor of the succession
        # box has a strictly higher score.
        precursors = self.get_precursors(succession_index)
        return self.scores[index] >= np.max(self.scores[precursors])

    def meet_v_iou(self, index1, index2):
        def overlaps_v(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            y0 = max(self.text_proposals[index2][1], self.text_proposals[index1][1])
            y1 = min(self.text_proposals[index2][3], self.text_proposals[index1][3])
            return max(0, y1 - y0 + 1) / min(h1, h2)

        def size_similarity(index1, index2):
            h1 = self.heights[index1]
            h2 = self.heights[index2]
            return min(h1, h2) / max(h1, h2)

        return overlaps_v(index1, index2) >= TextLineCfg.MIN_V_OVERLAPS and \
               size_similarity(index1, index2) >= TextLineCfg.MIN_SIZE_SIM

    def build_graph(self, text_proposals, scores, im_size):
        self.text_proposals = text_proposals
        self.scores = scores
        self.im_size = im_size
        self.heights = text_proposals[:, 3] - text_proposals[:, 1] + 1

        # Bucket proposal indices by the x coordinate of their left edge so
        # that neighbours can be looked up per column.
        boxes_table = [[] for _ in range(self.im_size[1])]
        for index, box in enumerate(text_proposals):
            boxes_table[int(box[0])].append(index)
        self.boxes_table = boxes_table

        graph = np.zeros((text_proposals.shape[0], text_proposals.shape[0]), bool)

        for index, box in enumerate(text_proposals):
            successions = self.get_successions(index)
            if len(successions) == 0:
                continue
            succession_index = successions[np.argmax(scores[successions])]
            if self.is_succession_node(index, succession_index):
                # NOTE: a box can have multiple successions (precursors) if
                # multiple successions (precursors) have equal scores.
                graph[index, succession_index] = True
        return Graph(graph)
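
# Illustrative usage sketch (an editorial addition, not part of the upstream
# file): building the succession graph over three toy proposals and reading
# off the connected groups. Boxes are [x1, y1, x2, y2]; im_size is (height, width).
def _graph_demo():
    text_proposals = np.array([[0.0, 10.0, 15.0, 30.0],
                               [16.0, 11.0, 31.0, 31.0],     # right neighbour of box 0
                               [200.0, 10.0, 215.0, 30.0]])  # too far right to connect
    scores = np.array([0.9, 0.8, 0.95])
    graph = TextProposalGraphBuilder().build_graph(text_proposals, scores, (100, 400))
    return graph.sub_graphs_connected()  # -> [[0, 1]]; the isolated box 2 forms no chain
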
class TextProposalConnectorOriented:
    """
    Connect text proposals into text lines.
    """

    def __init__(self):
        self.graph_builder = TextProposalGraphBuilder()

    def group_text_proposals(self, text_proposals, scores, im_size):
        graph = self.graph_builder.build_graph(text_proposals, scores, im_size)
        return graph.sub_graphs_connected()

    def fit_y(self, X, Y, x1, x2):
        # assumes len(X) != 0
        # if X contains only one distinct value, fall back to the horizontal line y = Y[0]
        if np.sum(X == X[0]) == len(X):
            return Y[0], Y[0]
        p = np.poly1d(np.polyfit(X, Y, 1))
        return p(x1), p(x2)

    def get_text_lines(self, text_proposals, scores, im_size):
        """
        text_proposals: (N, 4) proposal boxes; scores: (N,) proposal scores.
        Returns an (M, 9) array: four corner points (x1, y1, ..., x4, y4)
        plus a score for each text line.
        """
        # tp = text proposal
        tp_groups = self.group_text_proposals(text_proposals, scores, im_size)  # build the graph first to find which small boxes make up each text line

        text_lines = np.zeros((len(tp_groups), 8), np.float32)

        for index, tp_indices in enumerate(tp_groups):
            text_line_boxes = text_proposals[list(tp_indices)]  # all small boxes of this text line
            X = (text_line_boxes[:, 0] + text_line_boxes[:, 2]) / 2  # centre x, y of each small box
            Y = (text_line_boxes[:, 1] + text_line_boxes[:, 3]) / 2

            z1 = np.polyfit(X, Y, 1)  # least-squares line fitted through the box centres

            x0 = np.min(text_line_boxes[:, 0])  # smallest x of the text line
            x1 = np.max(text_line_boxes[:, 2])  # largest x of the text line

            offset = (text_line_boxes[0, 2] - text_line_boxes[0, 0]) * 0.5  # half the width of a small box

            # Fit a line through the top-left corners of all the small boxes, then
            # evaluate its y at the leftmost and rightmost x of the text line.
            lt_y, rt_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 1], x0 + offset, x1 - offset)
            # Same for the bottom-left corners.
            lb_y, rb_y = self.fit_y(text_line_boxes[:, 0], text_line_boxes[:, 3], x0 + offset, x1 - offset)

            score = scores[list(tp_indices)].sum() / float(len(tp_indices))  # mean score of the small boxes as the text line score

            text_lines[index, 0] = x0
            text_lines[index, 1] = min(lt_y, rt_y)  # smaller y of the line's top edge
            text_lines[index, 2] = x1
            text_lines[index, 3] = max(lb_y, rb_y)  # larger y of the line's bottom edge
            text_lines[index, 4] = score            # text line score
            text_lines[index, 5] = z1[0]            # slope k and intercept b of the fitted centre line
            text_lines[index, 6] = z1[1]
            height = np.mean((text_line_boxes[:, 3] - text_line_boxes[:, 1]))  # mean height of the small boxes
            text_lines[index, 7] = height + 2.5

        text_recs = np.zeros((len(text_lines), 9), float)
        index = 0
        for line in text_lines:
            # Derive the intercepts of the top and bottom edges from the
            # centre line and the line height.
            b1 = line[6] - line[7] / 2
            b2 = line[6] + line[7] / 2
            x1 = line[0]
            y1 = line[5] * line[0] + b1  # top left
            x2 = line[2]
            y2 = line[5] * line[2] + b1  # top right
            x3 = line[0]
            y3 = line[5] * line[0] + b2  # bottom left
            x4 = line[2]
            y4 = line[5] * line[2] + b2  # bottom right
            disX = x2 - x1
            disY = y2 - y1
            width = np.sqrt(disX * disX + disY * disY)  # text line width

            fTmp0 = y3 - y1  # text line height
            fTmp1 = fTmp0 * disY / width
            x = np.fabs(fTmp1 * disX / width)  # compensate for the slant
            y = np.fabs(fTmp1 * disY / width)
            if line[5] < 0:
                x1 -= x
                y1 += y
                x4 += x
                y4 -= y
            else:
                x2 += x
                y2 += y
                x3 -= x
                y3 -= y
            text_recs[index, 0] = x1
            text_recs[index, 1] = y1
            text_recs[index, 2] = x2
            text_recs[index, 3] = y2
            text_recs[index, 4] = x3
            text_recs[index, 5] = y3
            text_recs[index, 6] = x4
            text_recs[index, 7] = y4
            text_recs[index, 8] = line[4]
            index += 1

        return text_recs
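
# Illustrative end-to-end sketch (an editorial addition, not part of the
# upstream file): connect two toy proposals into one oriented text line. In
# the real pipeline, `proposals` and `scores` come from the CTPN detector
# after NMS.
if __name__ == '__main__':
    proposals = np.array([[0.0, 10.0, 15.0, 30.0],
                          [16.0, 11.0, 31.0, 31.0]])
    scores = np.array([0.9, 0.8])
    connector = TextProposalConnectorOriented()
    recs = connector.get_text_lines(proposals, scores, (100, 400))
    print(recs)  # one 9-value row: corners x1 y1 ... x4 y4 plus the line score

--------------------------------------------------------------------------------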