├── models
│   ├── __init__.py
│   ├── crnn.pyc
│   ├── utils.pyc
│   ├── __init__.pyc
│   ├── utils.py
│   └── crnn.py
├── keys.pyc
├── utils.pyc
├── dataset.pyc
├── data
│   ├── demo.png
│   └── image33.jpg
├── image
│   ├── image1.jpg
│   ├── image2.jpg
│   └── image3.jpg
├── keys.py
├── LICENSE.md
├── demo.py
├── tool
│   ├── convert_t7.lua
│   ├── tolmdb.py
│   ├── tolmdb-python3.py
│   └── convert_t7.py
├── README.md
├── test
│   └── test_utils.py
├── dataset.py
├── utils.py
└── crnn_main.py

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/keys.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/keys.pyc

--------------------------------------------------------------------------------
/utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/utils.pyc

--------------------------------------------------------------------------------
/dataset.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/dataset.pyc

--------------------------------------------------------------------------------
/data/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/data/demo.png

--------------------------------------------------------------------------------
/data/image33.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/data/image33.jpg

--------------------------------------------------------------------------------
/image/image1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/image/image1.jpg

--------------------------------------------------------------------------------
/image/image2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/image/image2.jpg

--------------------------------------------------------------------------------
/image/image3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/image/image3.jpg

--------------------------------------------------------------------------------
/models/crnn.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/models/crnn.pyc

--------------------------------------------------------------------------------
/models/utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/models/utils.pyc

--------------------------------------------------------------------------------
/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YoungMiao/crnn/HEAD/models/__init__.pyc

--------------------------------------------------------------------------------
/models/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# encoding: utf-8

import torch
import torch.nn as nn
import torch.nn.parallel


def data_parallel(model, input, ngpu):
    if isinstance(input.data, torch.cuda.FloatTensor) and ngpu > 1:
        output = nn.parallel.data_parallel(model, input, range(ngpu))
    else:
        output = model(input)
    return output
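
# Usage sketch (hypothetical names): `net` is a CRNN instance and `images`
# a CUDA batch wrapped in a Variable; with ngpu > 1 the batch is split
# across GPUs, otherwise this is just net(images):
#
#   preds = data_parallel(net, images, 2)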

--------------------------------------------------------------------------------
/keys.py:
--------------------------------------------------------------------------------
#coding:UTF-8
#alphabet = '万下依口哺摄次状璐癌草血运重'
alphabet = 'ACIMRey万下依口哺摄次状璐癌草血运重'
#alphabet = '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'
#alphabet = '羁一异内萄肇涉而成市且抗化尘多依功射吠瘤级长作安水妹常机醋弓脾见所病织屈立职经同科恐鳓后停非寥器闭衬壳穴其腺矽替七如岗簇有予下外服肌踏酬脑源专排五定钠帚骤性动集犬由草纸阿设及组遇准等腿俊位野道略取高牛城胜卡步颧另峨自良米已诊基达拔最较适文悉型上球明输得世者供通额张带用第枯澳头冷制离为咽看吴壁睡心层仅来忠拉量二材染氏受入痛曝卜獠痣院以此皤易可似至并竹酸移空能兼顾应往帅隆三醇弋不式炎我析鬓帕丙赘亚斜惊急肝窦夫纪中乳术发医肠囊前菌协从当柔套司牙严翁远回葡珠正波物遭结软确致敷加轻虽小腔'

--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
The MIT License (MIT)

Copyright (c) 2017 Jieru Mei

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
#coding: utf-8
import torch
from torch.autograd import Variable
import utils
import dataset
import os
from PIL import Image

import models.crnn as crnn

#os.environ["CUDA_VISIBLE_DEVICES"] ="1"
model_path = './data/netCRNN_ch_nc_21_nh_128.pth'
img_path = './data/image33.jpg'
alphabet = u'\'ACIMRey万下依口哺摄次状璐癌草血运重'
#print(alphabet)
nclass = len(alphabet) + 1
model = crnn.CRNN(32, 1, nclass, 128).cuda()
print('loading pretrained model from %s' % model_path)
pre_model = torch.load(model_path)
for k, v in pre_model.items():
    print(k, len(v))
model.load_state_dict(pre_model)

converter = utils.strLabelConverter(alphabet)

transformer = dataset.resizeNormalize((100, 32))
image = Image.open(img_path).convert('L')
image = transformer(image).cuda()
image = image.view(1, *image.size())
image = Variable(image)

model.eval()
preds = model(image)

_, preds = preds.max(2)
preds = preds.squeeze(2)
preds = preds.transpose(1, 0).contiguous().view(-1)

preds_size = Variable(torch.IntTensor([preds.size(0)]))
raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
print('%-20s => %-20s' % (raw_pred.encode('utf8'), sim_pred.encode('utf8')))
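
# Note on the two decodes above: the raw decode keeps CTC blanks ('-') and
# repeated frames, while raw=False collapses repeats and strips blanks.
# E.g. a raw output like '--癌癌-草-' (made-up illustration) collapses to
# the final prediction '癌草'.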

--------------------------------------------------------------------------------
/models/crnn.py:
--------------------------------------------------------------------------------
import torch.nn as nn


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class CRNN(nn.Module):

    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(CRNN, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)

        return output
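
# Shape walk-through for the 1x32x100 grayscale inputs used in demo.py:
# the conv/pool stages map 1x32x100 -> 512x1x26, the squeeze/permute in
# forward() turns that into a 26-step sequence of 512-d frames, and the
# two BiLSTMs emit a [26, batch, nclass] tensor for CTC.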

--------------------------------------------------------------------------------
/tool/convert_t7.lua:
--------------------------------------------------------------------------------
require('table')
require('torch')
require('os')

function clone(t)
    -- deep-copy a table
    if type(t) ~= "table" then return t end
    local meta = getmetatable(t)
    local target = {}
    for k, v in pairs(t) do
        if type(v) == "table" then
            target[k] = clone(v)
        else
            target[k] = v
        end
    end
    setmetatable(target, meta)
    return target
end


function tableMerge(lhs, rhs)
    output = clone(lhs)
    for _, v in pairs(rhs) do
        table.insert(output, v)
    end
    return output
end


function isInTable(val, val_list)
    for _, item in pairs(val_list) do
        if val == item then
            return true
        end
    end
    return false
end


function modelToList(model)
    local ignoreList = {
        'nn.Copy',
        'nn.AddConstant',
        'nn.MulConstant',
        'nn.View',
        'nn.Transpose',
        'nn.SplitTable',
        'nn.SharedParallelTable',
        'nn.JoinTable',
    }
    local state = {}
    local param
    for i, layer in pairs(model.modules) do
        local typeName = torch.type(layer)
        if not isInTable(typeName, ignoreList) then
            if typeName == 'nn.Sequential' or typeName == 'nn.ConcatTable' then
                param = modelToList(layer)
            elseif typeName == 'cudnn.SpatialConvolution' or typeName == 'nn.SpatialConvolution' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialBatchNormalization' or typeName == 'nn.SpatialBatchNormalization' then
                param = layer:parameters()
                bn_vars = {layer.running_mean, layer.running_var}
                param = tableMerge(param, bn_vars)
            elseif typeName == 'nn.LstmLayer' then
                param = layer:parameters()
            elseif typeName == 'nn.BiRnnJoin' then
                param = layer:parameters()
            elseif typeName == 'cudnn.SpatialMaxPooling' or typeName == 'nn.SpatialMaxPooling' then
                param = {}
            elseif typeName == 'cudnn.ReLU' or typeName == 'nn.ReLU' then
                param = {}
            else
                print(string.format('Unknown class %s', typeName))
                os.exit(0)
            end
            table.insert(state, {typeName, param})
        else
            print(string.format('pass %s', typeName))
        end
    end
    return state
end


function saveModel(model, output_path)
    local state = modelToList(model)
    torch.save(output_path, state)
end

--------------------------------------------------------------------------------
/tool/tolmdb.py:
--------------------------------------------------------------------------------
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import numpy as np


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    try:
        imageBuf = np.fromstring(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        imgH, imgW = img.shape[0], img.shape[1]
    except:
        return False
    else:
        if imgH * imgW == 0:
            return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.iteritems():
            txn.put(k, v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding ground-truth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if True, check the validity of every image
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for i in xrange(nSamples):
        imagePath = './recognition/' + ''.join(imagePathList[i]).split()[0].replace('\n', '').replace('\r\n', '')
        label = ''.join(labelList[i])
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue

        # read the raw bytes; 'rb' keeps the binary image data intact
        with open(imagePath, 'rb') as f:
            imageBin = f.read()

        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    print(cnt)
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
    outputPath = "./train_lmdb"
    imgdata = open("./train_241.txt")
    imagePathList = list(imgdata)

    labelList = []
    for line in imagePathList:
        word = line.split()[1]
        labelList.append(word)
    createDataset(outputPath, imagePathList, labelList)
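
# Quick sanity check (sketch): read the freshly written LMDB back the same
# way dataset.lmdbDataset does:
#
#   env = lmdb.open("./train_lmdb", readonly=True)
#   with env.begin() as txn:
#       print(txn.get('num-samples'))
#       print(txn.get('label-%09d' % 1))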

--------------------------------------------------------------------------------
/tool/tolmdb-python3.py:
--------------------------------------------------------------------------------
#coding: utf-8
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import numpy as np


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    try:
        imageBuf = np.fromstring(imageBin, dtype=np.uint8)
        img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
        imgH, imgW = img.shape[0], img.shape[1]
    except:
        return False
    else:
        if imgH * imgW == 0:
            return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            # keys are str; values may be str (labels, counters) or raw
            # image bytes -- only encode the former, since str(v) would
            # corrupt binary data
            if isinstance(v, str):
                v = v.encode()
            txn.put(k.encode(), v)


def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image paths
        labelList     : list of corresponding ground-truth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if True, check the validity of every image
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = ''.join(imagePathList[i]).split()[0].replace('\n', '').replace('\r\n', '')
        label = ''.join(labelList[i])
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue

        with open(imagePath, 'rb') as f:
            imageBin = f.read()

        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue
        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    print(cnt)
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)


if __name__ == '__main__':
    outputPath = "./SVT_lmdb"
    imgdata = open("./gt.txt")
    imagePathList = list(imgdata)

    labelList = []
    for line in imagePathList:
        word = line.split()[1]
        labelList.append(word)
    createDataset(outputPath, imagePathList, labelList)

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
fork from meijieru/crnn.pytorch https://github.com/meijieru/crnn.pytorch
## CRNN implementation notes (PyTorch)
### 1. Environment setup
#### 1.1 Base environment
* Ubuntu 14.04 + CUDA
* OpenCV 2.4 + PyTorch + lmdb + warp-ctc

Install lmdb: `apt-get install lmdb`
#### 1.2 Install PyTorch
pip, Linux, CUDA 8.0, Python 2.7: `pip install http://download.pytorch.org/whl/cu80/torch-0.1.12.post2-cp27-none-linux_x86_64.whl`
See: http://pytorch.org/
#### 1.3 Install warp-ctc
    git clone https://github.com/baidu-research/warp-ctc.git
    cd warp-ctc
    mkdir build; cd build
    cmake ..
    make

For the GPU build, add this to your environment:

    export CUDA_HOME="/usr/local/cuda"

Then:

    cd pytorch_binding
    python setup.py install

See: https://github.com/SeanNaren/warp-ctc/tree/pytorch_bindings/pytorch_binding
#### 1.4 Known issues
1. Missing cffi library files: install with `pip install cffi`.
2. Make sure CUDA_HOME is set before installing pytorch_binding. The build succeeds without it, but GPU calls will then fail with an error that warp_ctc has no GPU attribute.
### 2. CRNN prediction (21-class Chinese/English example)
Model download: https://eyun.baidu.com/s/3dEUJJg9 password: vKeD

Run `/contrib/crnn/demo.py`

Original images: ![](./image/image2.jpg)

![](./image/image3.jpg)


Recognition result: ![](./image/image1.jpg)

    # Load the model
    model_path = './samples/netCRNN_9_112580.pth'
    # Image to recognize
    img_path = './data/demo.png'
    # Recognition classes
    alphabet = 'ACIMRey万下依口哺摄次状璐癌草血运重'
    # Model parameters: image height imgH=32, nc input channels, nclass=len(alphabet)+1 classes (one reserved slot), LSTM hidden size nh=128, number of GPUs ngpu=1
    model = crnn.CRNN(32, 1, 22, 128, 1).cuda()

When swapping in another model, make sure the number of classes matches.
## CRNN training (21-class Chinese/English example)
1. Data preprocessing

Run `/contrib/crnn/tool/tolmdb.py`

    # Output path of the generated lmdb
    outputPath = "./train_lmdb"
    # Images and their labels
    imgdata = open("./train.txt")
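The annotation file pairs one image with one label per line: `tolmdb.py` takes `split()[0]` as the image path (resolved under `./recognition/`) and `split()[1]` as the label, so `train.txt` looks roughly like this (file names invented for illustration):

    img_0001.jpg 哺乳
    img_0002.jpg ACMe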
2. Train the model

Run `/contrib/crnn/crnn_main.py`

    python crnn_main.py [--param val]
    --trainroot        path to the training set
    --valroot          path to the validation set
    --workers          number of data-loading workers, default=2
    --batchSize        batch size, default=64
    --imgH             image height, default=32
    --nh               LSTM hidden-state size, default=256
    --niter            number of training epochs, default=1000
    --lr               learning rate, default=0.01
    --beta1
    --cuda             use the GPU, action='store_true'
    --ngpu             number of GPUs to use, default=1
    --crnn             path to a pretrained model
    --alphabet         set of classes
    --Diters
    --experiment       directory for saving models
    --displayInterval  iterations between progress prints, default=500
    --n_test_disp      samples displayed per validation, default=10
    --valInterval      iterations between validations, default=500
    --saveInterval     iterations between checkpoints, default=500
    --adam             use the adam optimizer, action='store_true'
    --adadelta         use the adadelta optimizer, action='store_true'
    --keep_ratio       keep the aspect ratio when scaling images, action='store_true'
    --random_sample    whether to sample the dataset with a random sampler, action='store_true'

Example: `python /contrib/crnn/crnn_main.py --trainroot [training set path] --valroot [validation set path] --nh 128 --cuda --crnn [pretrained model path]`

To add or remove classes, edit `alphabet = 'ACIMRey万下依口哺摄次状璐癌草血运重'` in `/contrib/crnn/keys.py`.

3. Notes

The number of classes and the LSTM hidden-state size must be the same for training and prediction.
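As a quick guard against mismatches, you can derive the class count from `keys.py` instead of hard-coding it (sketch; assumes the checkpoint was trained with the current `keys.py`):

    import keys
    import models.crnn as crnn

    nclass = len(keys.alphabet) + 1  # one extra slot reserved for the CTC blank
    model = crnn.CRNN(32, 1, nclass, 128)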

--------------------------------------------------------------------------------
/test/test_utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# encoding: utf-8

import sys
import unittest
import torch
from torch.autograd import Variable
import collections
origin_path = sys.path
sys.path.append("..")
import utils
sys.path = origin_path


def equal(a, b):
    if isinstance(a, torch.Tensor):
        return a.equal(b)
    elif isinstance(a, str):
        return a == b
    elif isinstance(a, collections.Iterable):
        res = True
        for (x, y) in zip(a, b):
            res = res & equal(x, y)
        return res
    else:
        return a == b


class utilsTestCase(unittest.TestCase):

    def checkConverter(self):
        encoder = utils.strLabelConverter('abcdefghijklmnopqrstuvwxyz')

        # Encode
        # trivial mode
        result = encoder.encode('efa')
        target = (torch.IntTensor([5, 6, 1]), torch.IntTensor([3]))
        self.assertTrue(equal(result, target))

        # batch mode
        result = encoder.encode(['efa', 'ab'])
        target = (torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2]))
        self.assertTrue(equal(result, target))

        # Decode
        # trivial mode
        result = encoder.decode(
            torch.IntTensor([5, 6, 1]), torch.IntTensor([3]))
        target = 'efa'
        self.assertTrue(equal(result, target))

        # replicate mode
        result = encoder.decode(
            torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([4]))
        target = 'ea'
        self.assertTrue(equal(result, target))

        # raise AssertionError
        def f():
            result = encoder.decode(
                torch.IntTensor([5, 5, 0, 1]), torch.IntTensor([3]))
        self.assertRaises(AssertionError, f)

        # batch mode
        result = encoder.decode(
            torch.IntTensor([5, 6, 1, 1, 2]), torch.IntTensor([3, 2]))
        target = ['efa', 'ab']
        self.assertTrue(equal(result, target))

    def checkOneHot(self):
        v = torch.LongTensor([1, 2, 1, 2, 0])
        v_length = torch.LongTensor([2, 3])
        v_onehot = utils.oneHot(v, v_length, 4)
        target = torch.FloatTensor([[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 0]],
                                    [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]]])
        assert target.equal(v_onehot)

    def checkAverager(self):
        acc = utils.averager()
        acc.add(Variable(torch.Tensor([1, 2])))
        acc.add(Variable(torch.Tensor([[5, 6]])))
        assert acc.val() == 3.5

        acc = utils.averager()
        acc.add(torch.Tensor([1, 2]))
        acc.add(torch.Tensor([[5, 6]]))
        assert acc.val() == 3.5

    def checkAssureRatio(self):
        img = torch.Tensor([[1], [3]]).view(1, 1, 2, 1)
        img = Variable(img)
        img = utils.assureRatio(img)
        assert torch.Size([1, 1, 2, 2]) == img.size()


def _suite():
    suite = unittest.TestSuite()
    suite.addTest(utilsTestCase("checkConverter"))
    suite.addTest(utilsTestCase("checkOneHot"))
    suite.addTest(utilsTestCase("checkAverager"))
    suite.addTest(utilsTestCase("checkAssureRatio"))
    return suite


if __name__ == "__main__":
    suite = _suite()
    runner = unittest.TextTestRunner()
    runner.run(suite)

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# encoding: utf-8

import random
import torch
from torch.utils.data import Dataset
from torch.utils.data import sampler
import torchvision.transforms as transforms
import lmdb
import six
import sys
from PIL import Image
import numpy as np


class lmdbDataset(Dataset):

    def __init__(self, root=None, transform=None, target_transform=None):
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot create lmdb from %s' % (root))
            sys.exit(0)

        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'))
            self.nSamples = nSamples

        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1
        with self.env.begin(write=False) as txn:
            img_key = 'image-%09d' % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                return self[index + 1]

            if self.transform is not None:
                img = self.transform(img)

            label_key = 'label-%09d' % index
            label = str(txn.get(label_key))

            if self.target_transform is not None:
                label = self.target_transform(label)

        return (img, label)
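
# Typical wiring (sketch, mirroring crnn_main.py): pair lmdbDataset with
# alignCollate below so variable-width crops are resized into one batch:
#
#   train_dataset = lmdbDataset(root='./train_lmdb')
#   loader = torch.utils.data.DataLoader(
#       train_dataset, batch_size=64, shuffle=True,
#       collate_fn=alignCollate(imgH=32, imgW=100))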

class resizeNormalize(object):

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()

    def __call__(self, img):
        img = img.resize(self.size, self.interpolation)
        img = self.toTensor(img)
        img.sub_(0.5).div_(0.5)
        return img


class randomSequentialSampler(sampler.Sampler):

    def __init__(self, data_source, batch_size):
        self.num_samples = len(data_source)
        self.batch_size = batch_size

    def __iter__(self):
        n_batch = len(self) // self.batch_size
        tail = len(self) % self.batch_size
        index = torch.LongTensor(len(self)).fill_(0)
        for i in range(n_batch):
            random_start = random.randint(0, len(self) - self.batch_size)
            batch_index = random_start + torch.range(0, self.batch_size - 1)
            index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
        # deal with tail
        if tail:
            random_start = random.randint(0, len(self) - self.batch_size)
            tail_index = random_start + torch.range(0, tail - 1)
            index[(i + 1) * self.batch_size:] = tail_index

        return iter(index)

    def __len__(self):
        return self.num_samples


class alignCollate(object):

    def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
        self.imgH = imgH
        self.imgW = imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio

    def __call__(self, batch):
        images, labels = zip(*batch)

        imgH = self.imgH
        imgW = self.imgW
        if self.keep_ratio:
            ratios = []
            for image in images:
                w, h = image.size
                ratios.append(w / float(h))
            ratios.sort()
            max_ratio = ratios[-1]
            imgW = int(np.floor(max_ratio * imgH))
            imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH

        transform = resizeNormalize((imgW, imgH))
        images = [transform(image) for image in images]
        images = torch.cat([t.unsqueeze(0) for t in images], 0)

        return images, labels

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
#!/usr/bin/python
# encoding: utf-8

import torch
import torch.nn as nn
from torch.autograd import Variable
import collections
import numpy as np
import sys
reload(sys)
sys.setdefaultencoding('utf8')


class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=True):
        self._ignore_case = ignore_case
        # NOTE: case folding is deliberately not applied here; the mixed
        # Chinese/English alphabet in keys.py is case sensitive.
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by warp_ctc
            self.dict[char] = i + 1

    def is_chinese(self, uchar):
        """Heuristic: treat the string as Chinese if any of its characters
        is not alphanumeric (byte strings holding UTF-8 Chinese fail
        isalnum(); plain ASCII text passes it)."""
        alnum = np.array([ch.isalnum() for ch in uchar])
        if not alnum.all():
            return True
        else:
            return False

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        for item in text:
            if self.is_chinese(item):
                item = unicode(item, 'utf-8')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
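
    # Decode example (values taken from test/test_utils.py, with alphabet
    # 'abcdefghijklmnopqrstuvwxyz'):
    #   decode(IntTensor([5, 6, 1]), IntTensor([3]))    -> 'efa'
    #   decode(IntTensor([5, 5, 0, 1]), IntTensor([4])) -> 'ea'
    # repeated labels collapse and 0 acts as the CTC blank.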

class averager(object):
    """Compute average for `torch.Variable` and `torch.Tensor`. """

    def __init__(self):
        self.reset()

    def add(self, v):
        if isinstance(v, Variable):
            count = v.data.numel()
            v = v.data.sum()
        elif isinstance(v, torch.Tensor):
            count = v.numel()
            v = v.sum()

        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res


def oneHot(v, v_length, nc):
    batchSize = v_length.size(0)
    maxLength = v_length.max()
    v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
    acc = 0
    for i in range(batchSize):
        length = v_length[i]
        label = v[acc:acc + length].view(-1, 1).long()
        v_onehot[i, :length].scatter_(1, label, 1.0)
        acc += length
    return v_onehot


def loadData(v, data):
    v.data.resize_(data.size()).copy_(data)


def prettyPrint(v):
    print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
    print('| Max: %f | Min: %f | Mean: %f' % (v.max().data[0], v.min().data[0],
                                              v.mean().data[0]))


def assureRatio(img):
    """Ensure imgH <= imgW."""
    b, c, h, w = img.size()
    if h > w:
        main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
        img = main(img)
    return img

--------------------------------------------------------------------------------
/tool/convert_t7.py:
--------------------------------------------------------------------------------
import torchfile
import argparse
import torch
from torch.nn.parameter import Parameter
import numpy as np
import models.crnn as crnn


layer_map = {
    'SpatialConvolution': 'Conv2d',
    'SpatialBatchNormalization': 'BatchNorm2d',
    'ReLU': 'ReLU',
    'SpatialMaxPooling': 'MaxPool2d',
    'SpatialAveragePooling': 'AvgPool2d',
    'SpatialUpSamplingNearest': 'UpsamplingNearest2d',
    'View': None,
    'Linear': 'linear',
    'Dropout': 'Dropout',
    'SoftMax': 'Softmax',
    'Identity': None,
    'SpatialFullConvolution': 'ConvTranspose2d',
    'SpatialReplicationPadding': None,
    'SpatialReflectionPadding': None,
    'Copy': None,
    'Narrow': None,
    'SpatialCrossMapLRN': None,
    'Sequential': None,
    'ConcatTable': None,  # output is list
    'CAddTable': None,  # input is list
    'Concat': None,
    'TorchObject': None,
    'LstmLayer': 'LSTM',
    'BiRnnJoin': 'Linear'
}


def torch_layer_serial(layer, layers):
    name = layer[0]
    if name == 'nn.Sequential' or name == 'nn.ConcatTable':
        tmp_layers = []
        for sub_layer in layer[1]:
            torch_layer_serial(sub_layer, tmp_layers)
        layers.extend(tmp_layers)
    else:
        layers.append(layer)


def py_layer_serial(layer, layers):
    """
    Assume modules are defined as executive sequence.
    """
    if len(layer._modules) >= 1:
        tmp_layers = []
        for sub_layer in layer.children():
            py_layer_serial(sub_layer, tmp_layers)
        layers.extend(tmp_layers)
    else:
        layers.append(layer)


def trans_pos(param, part_indexes, dim=0):
    parts = np.split(param, len(part_indexes), dim)
    new_parts = []
    for i in part_indexes:
        new_parts.append(parts[i])
    return np.concatenate(new_parts, dim)


def load_params(py_layer, t7_layer):
    if type(py_layer).__name__ == 'LSTM':
        # LSTM
        all_weights = []
        num_directions = 2 if py_layer.bidirectional else 1
        for i in range(py_layer.num_layers):
            for j in range(num_directions):
                suffix = '_reverse' if j == 1 else ''
                weights = ['weight_ih_l{}{}', 'bias_ih_l{}{}',
                           'weight_hh_l{}{}', 'bias_hh_l{}{}']
                weights = [x.format(i, suffix) for x in weights]
                all_weights += weights

        params = []
        for i in range(len(t7_layer)):
            params.extend(t7_layer[i][1])
        params = [trans_pos(p, [0, 1, 3, 2], dim=0) for p in params]
    else:
        all_weights = []
        name = t7_layer[0].split('.')[-1]
        if name == 'BiRnnJoin':
            weight_0, bias_0, weight_1, bias_1 = t7_layer[1]
            weight = np.concatenate((weight_0, weight_1), axis=1)
            bias = bias_0 + bias_1
            t7_layer[1] = [weight, bias]
            all_weights += ['weight', 'bias']
        elif name == 'SpatialConvolution' or name == 'Linear':
            all_weights += ['weight', 'bias']
        elif name == 'SpatialBatchNormalization':
            all_weights += ['weight', 'bias', 'running_mean', 'running_var']

        params = t7_layer[1]

    params = [torch.from_numpy(item) for item in params]
    assert len(all_weights) == len(params), "params' number not match"
    for py_param_name, t7_param in zip(all_weights, params):
        item = getattr(py_layer, py_param_name)
        if isinstance(item, Parameter):
            item = item.data
        try:
            item.copy_(t7_param)
        except RuntimeError:
            print('Size not match between %s and %s' %
                  (item.size(), t7_param.size()))


def torch_to_pytorch(model, t7_file, output):
    py_layers = []
    for layer in list(model.children()):
        py_layer_serial(layer, py_layers)

    t7_data = torchfile.load(t7_file)
    t7_layers = []
    for layer in t7_data:
        torch_layer_serial(layer, t7_layers)

    j = 0
    for i, py_layer in enumerate(py_layers):
        py_name = type(py_layer).__name__
        t7_layer = t7_layers[j]
        t7_name = t7_layer[0].split('.')[-1]
        if layer_map[t7_name] != py_name:
            raise RuntimeError('%s does not match %s' % (py_name, t7_name))

        if py_name == 'LSTM':
            n_layer = 2 if py_layer.bidirectional else 1
            n_layer *= py_layer.num_layers
            t7_layer = t7_layers[j:j + n_layer]
            j += n_layer
        else:
            j += 1

        load_params(py_layer, t7_layer)

    torch.save(model.state_dict(), output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Convert torch t7 model to pytorch'
    )
    parser.add_argument(
        '--model_file',
        '-m',
        type=str,
        required=True,
        help='torch model file in t7 format'
    )
    parser.add_argument(
        '--output',
        '-o',
        type=str,
        default=None,
        help='output file name prefix, xxx.py xxx.pth'
    )
    args = parser.parse_args()

    py_model = crnn.CRNN(32, 1, 37, 256, 1)
    torch_to_pytorch(py_model, args.model_file, args.output)

--------------------------------------------------------------------------------
/crnn_main.py:
--------------------------------------------------------------------------------
#coding: utf-8
from __future__ import print_function
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
from warpctc_pytorch import CTCLoss
import os
import utils
import dataset
import keys
import sys
import collections
reload(sys)
sys.setdefaultencoding('utf8')

import models.crnn as crnn

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
str1 = keys.alphabet
parser = argparse.ArgumentParser()
parser.add_argument('--trainroot', required=True, help='path to dataset')
parser.add_argument('--valroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=100, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for Critic, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--crnn', default='', help="path to crnn (to continue training)")
parser.add_argument('--alphabet', type=str, default=str1)
parser.add_argument('--Diters', type=int, default=5, help='number of D iters per each G iter')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio', action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--random_sample', action='store_true', help='whether to sample the dataset with random sampler')
opt = parser.parse_args()
print(opt)

if opt.experiment is None:
    opt.experiment = 'expr'
os.system('mkdir {0}'.format(opt.experiment))

opt.manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

train_dataset = dataset.lmdbDataset(root=opt.trainroot)
assert train_dataset
if not opt.random_sample:
    sampler = dataset.randomSequentialSampler(train_dataset, opt.batchSize)
else:
    sampler = None
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=True, sampler=sampler,
    num_workers=int(opt.workers),
    collate_fn=dataset.alignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio=opt.keep_ratio))
test_dataset = dataset.lmdbDataset(
    root=opt.valroot, transform=dataset.resizeNormalize((100, 32)))

alphabet = opt.alphabet.decode('utf-8')

nclass = len(alphabet) + 1
nc = 1

converter = utils.strLabelConverter(alphabet)
criterion = CTCLoss()


# custom weights initialization called on crnn
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


crnn = crnn.CRNN(opt.imgH, nc, nclass, opt.nh)
crnn.apply(weights_init)
if opt.crnn != '':
    print('loading pretrained model from %s' % opt.crnn)
    pre_trainmodel = torch.load(opt.crnn)
    model_dict = crnn.state_dict()
    weig1 = 'rnn.1.embedding.weight'
    bias1 = 'rnn.1.embedding.bias'
    if len(model_dict[weig1]) == len(pre_trainmodel[weig1]) and len(model_dict[bias1]) == len(pre_trainmodel[bias1]):
        crnn.load_state_dict(pre_trainmodel)
    else:
        # copy everything except the final embedding layer, whose shape
        # depends on the number of classes
        for k, v in model_dict.items():
            if k != weig1 and k != bias1:
                model_dict[k] = pre_trainmodel[k]
        crnn.load_state_dict(model_dict)
print(crnn)

image = torch.FloatTensor(opt.batchSize, 3, opt.imgH, opt.imgH)
text = torch.IntTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    crnn.cuda()
    crnn = torch.nn.DataParallel(crnn, device_ids=range(opt.ngpu))
    image = image.cuda()
    criterion = criterion.cuda()

image = Variable(image)
text = Variable(text)
length = Variable(length)

# loss averager
loss_avg = utils.averager()

# setup optimizer
if opt.adam:
    optimizer = optim.Adam(crnn.parameters(), lr=opt.lr,
                           betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(crnn.parameters(), lr=opt.lr)
else:
    optimizer = optim.RMSprop(crnn.parameters(), lr=opt.lr)


def val(net, dataset, criterion, max_iter=100):
    print('Start val')

    for p in crnn.parameters():
        p.requires_grad = False

    net.eval()
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=True, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)

    i = 0
    n_correct = 0
    loss_avg = utils.averager()

    max_iter = min(max_iter, len(data_loader))
    for i in range(max_iter):
        data = val_iter.next()
        i += 1
        cpu_images, cpu_texts = data
        batch_size = cpu_images.size(0)
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts)
        utils.loadData(text, t)
        utils.loadData(length, l)

        preds = crnn(image)
        preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
        cost = criterion(preds, text, preds_size, length) / batch_size
        loss_avg.add(cost)

        _, preds = preds.max(2)
        preds = preds.squeeze(2)
        preds = preds.transpose(1, 0).contiguous().view(-1)
        sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target:
                n_correct += 1

    raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:opt.n_test_disp]
    for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts):
        print('%-20s => %-20s, gt: %-20s' % (raw_pred.encode('utf-8'), pred.encode('utf-8'), gt.encode('utf-8')))

    accuracy = n_correct / float(max_iter * opt.batchSize)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))


def trainBatch(net, criterion, optimizer):
    data = train_iter.next()
    cpu_images, cpu_texts = data
    batch_size = cpu_images.size(0)
    utils.loadData(image, cpu_images)
    t, l = converter.encode(cpu_texts)
    utils.loadData(text, t)
    utils.loadData(length, l)

    preds = crnn(image)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))
    cost = criterion(preds, text, preds_size, length) / batch_size
    crnn.zero_grad()
    cost.backward()
    optimizer.step()
    return cost
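
# To resume training from a checkpoint written below (sketch; paths depend
# on your run, e.g. the default experiment dir 'expr'):
#   python crnn_main.py --trainroot <train_lmdb> --valroot <val_lmdb> \
#       --nh 128 --cuda --crnn expr/netCRNN_0_500.pth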

for epoch in range(opt.niter):
    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):
        for p in crnn.parameters():
            p.requires_grad = True
        crnn.train()

        cost = trainBatch(crnn, criterion, optimizer)
        loss_avg.add(cost)
        i += 1

        if i % opt.displayInterval == 0:
            print('[%d/%d][%d/%d] Loss: %f' %
                  (epoch, opt.niter, i, len(train_loader), loss_avg.val()))
            loss_avg.reset()

        if i % opt.valInterval == 0:
            val(crnn, test_dataset, criterion)

        # do checkpointing
        if i % opt.saveInterval == 0:
            torch.save(
                crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(opt.experiment, epoch, i))

--------------------------------------------------------------------------------