├── .gitignore
├── LICENSE
├── README.md
├── dataloader.py
├── dataset
│   └── .gitkeep
├── main.py
├── model.bin
├── model.py
├── prepare_IIIT5K_dataset.py
├── resource
│   └── architecture.png
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.idea/

dataset/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Sliding Convolution CTC for Scene Text Recognition

Implementation of 'Scene Text Recognition with Sliding Convolutional Character Models' ([pdf](https://arxiv.org/pdf/1709.01727))

### Model

Sliding windows + CNN + CTC
![architecture](resource/architecture.png)
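
In short: the input image is swept with multi-scale sliding windows, a shared CNN classifies the crop at every window position, and CTC aligns the per-position character probabilities with the label sequence. Below is a minimal sketch of the two non-obvious pieces, mirroring the defaults in `dataloader.py` and `model.py` (window sizes 24/32/40, step 4, class 0 as the CTC blank); the helper names `window_centers` and `greedy_ctc_decode` are illustrative, not part of the codebase:

```python
import numpy as np

def window_centers(width, windows=(24, 32, 40), step=4):
    # Centers swept across a height-32 image; one 32x32 crop per window
    # size is taken at each center (cf. Dataset.slide_image).
    half = max(windows) // 2
    return range(half, width - half, step)

def greedy_ctc_decode(logits):
    # logits: (steps, classes). Take the argmax per step, collapse repeats,
    # then drop blanks (class 0), as in CNNCTC.pred_to_string.
    best = logits.argmax(axis=1)
    out, prev = [], 0
    for label in best:
        if label != 0 and label != prev:
            out.append(int(label))
        prev = label
    return out

print(len(list(window_centers(256))))  # a width-256 image yields 54 positions
```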

### Dependencies

While this implementation might work in other environments, it has only been tested with the setup below:

```
python == 3.7.0
torch == 0.4.1
tqdm
numpy
```

```
warp-ctc (for pytorch 0.4)
```

```
CUDA 9.0.1
CUDNN 7.0.5
```

#### Install warp-ctc

Follow the [instructions](https://github.com/SeanNaren/warp-ctc/tree/0.4.1)

> **Note**: The warp-ctc version must match your PyTorch version. [Related issue](https://github.com/SeanNaren/warp-ctc/issues/101)

### Usage

Download the [IIIT5K dataset](https://cdn.iiit.ac.in/cdn/cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz) and extract the files into the `dataset` folder.

Preprocess the IIIT5K dataset:
```bash
python3 prepare_IIIT5K_dataset.py
```

Train the model:
```bash
python3 main.py --cuda=True --mode=train
```
Resume training:
```bash
python3 main.py --cuda=True --warm-up=True --mode=train
```
Test the model:
```bash
python3 main.py --cuda=True --mode=test
```

> **Note**: `model.bin` is a pre-trained model that reaches about 53% accuracy (limited by the small training set).

### Citation

If you find this work useful in your research, please consider citing:

```
@article{yin2017scene,
  title={Scene text recognition with sliding convolutional character models},
  author={Yin, Fei and Wu, Yi-Chao and Zhang, Xu-Yao and Liu, Cheng-Lin},
  journal={arXiv preprint arXiv:1709.01727},
  year={2017}
}
```

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
import torch
import torch.utils.data as data


class Dataset(data.Dataset):
    """IIIT5K scene-text dataset with multi-scale sliding windows."""

    def __init__(self, name, mode, windows, step):
        if name == 'IIIT5K':
            assert (mode == 'train' or mode == 'test')

            from prepare_IIIT5K_dataset import load_data, prepare_images
            print('Loading %s data...' % mode)
            self.mode = mode
            self.windows = windows
            self.step = step
            self.img_root = 'dataset/IIIT5K'
            self.img_names, self.labels = load_data(mode + 'data')
            self.images = prepare_images(self.img_names)

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        image = self.images[idx]
        image = self.slide_image(image)
        label = self.labels[idx]
        image = torch.FloatTensor(image)
        label = torch.IntTensor([int(i) for i in label])
        return image, label

    def slide_image(self, image):
        h, w = image.shape  # No channel dimension for grayscale images.
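        # At each center position, crop one slice per window size, resize it
        # to 32x32, and stack the slices as channels: one position yields a
        # (len(windows), 32, 32) array, and the whole image a sequence of them.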
        output_image = []
        half_of_max_window = max(self.windows) // 2  # Start sliding from the centerline of the largest window, moving one step at a time.
        for center_axis in range(half_of_max_window, w - half_of_max_window, self.step):
            slice_channel = []
            for window_size in self.windows:
                image_slice = image[:, center_axis - window_size // 2: center_axis + window_size // 2]
                image_slice = cv2.resize(image_slice, (32, 32))
                slice_channel.append(image_slice)
            output_image.append(np.asarray(slice_channel, dtype='float32'))
        return np.asarray(output_image, dtype='float32')


class TrainBatch:
    def __init__(self, batch):
        transposed_data = list(zip(*batch))
        self.images = torch.stack(transposed_data[0], 0)
        self.labels = torch.cat(transposed_data[1], 0)
        self.label_lengths = torch.IntTensor([len(i) for i in transposed_data[1]])


class TestBatch:
    def __init__(self, batch):
        transposed_data = list(zip(*batch))
        self.images = torch.stack(transposed_data[0], 0)
        self.labels = [i.tolist() for i in transposed_data[1]]


def train_fn(batch):
    return TrainBatch(batch)


def test_fn(batch):
    return TestBatch(batch)

--------------------------------------------------------------------------------
/dataset/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsvih/Sliding-Convolution/d3d71d094f4462c4130208734b4d018e9b27cad3/dataset/.gitkeep

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data as data
from tqdm import tqdm
from warpctc_pytorch import CTCLoss

from dataloader import Dataset, test_fn, train_fn
from model import CNNCTC
from utils import load_model

torch.backends.cudnn.benchmark = True


def main():
    if args.mode == 'train':
        train_dataset = Dataset(name=args.dataset, mode='train', windows=[24, 32, 40], step=4)
        train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_fn,
                                       shuffle=True, num_workers=args.workers, pin_memory=True)
        train(train_loader)
    if args.mode == 'test':
        test_dataset = Dataset(name=args.dataset, mode='test', windows=[24, 32, 40], step=4)
        test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=test_fn,
                                      shuffle=False, num_workers=args.workers, pin_memory=True)
        model = load_model(device)
        test(model, test_loader)


def train(train_loader):
    model = CNNCTC(class_num=37).to(device)
    if args.warm_up:
        model = load_model(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_function = CTCLoss(size_average=True, length_average=True).to(device)
    min_loss = np.Inf
    for epoch in range(args.epoch):
        print('%d / %d Epoch' % (epoch, args.epoch))
        epoch_loss = train_epoch(train_loader, model, optimizer, loss_function)
        print(epoch_loss)
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            torch.save(model.state_dict(), 'model.bin')
    return model


def train_epoch(train_loader, model, optimizer, loss_function):
    total_loss = 0
    model.train()
    model.mode = 'train'
    for i, batch in enumerate(tqdm(train_loader)):
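        # warp-ctc expects activations of shape (T, N, C) with T = number of
        # window positions, a flat IntTensor of all target labels concatenated,
        # and per-sample activation/label lengths.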
        optimizer.zero_grad()
        images = batch.images.to(device)
        labels = batch.labels
        label_lengths = batch.label_lengths
        probs = model(images)
        log_probs = probs.log_softmax(2).to(device)
        batch_size = images.size(0)  # the last batch may be smaller than args.batch_size
        prob_lengths = torch.IntTensor([log_probs.size(0)] * batch_size)
        loss = loss_function(log_probs, labels, prob_lengths, label_lengths) / batch_size
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    return total_loss


def test(model, test_loader):
    model.eval()
    model.mode = 'test'
    total, correct = 0, 0
    for i, batch in enumerate(tqdm(test_loader)):
        images = batch.images.to(device)
        labels = batch.labels
        out = model(images)
        for actual, label in zip(labels, out):
            if actual == label:
                correct += 1
            total += 1
    print(correct / total)


def str2bool(v):
    # argparse's `type=bool` treats any non-empty string (including 'False')
    # as True, so boolean flags are parsed explicitly.
    return str(v).lower() in ('true', '1', 'yes')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Sliding convolution')
    parser.add_argument('--batch-size', default=16, type=int)
    parser.add_argument('--epoch', default=10, type=int)
    parser.add_argument('--workers', default=2, type=int)
    parser.add_argument('--dataset', default='IIIT5K')
    parser.add_argument('--cuda', default=False, type=str2bool)
    parser.add_argument('--warm-up', default=False, type=str2bool)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--weight_decay', default=0.0001, type=float)
    parser.add_argument('--mode', default='train', type=str)
    args = parser.parse_args()
    if args.cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    main()

--------------------------------------------------------------------------------
/model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsvih/Sliding-Convolution/d3d71d094f4462c4130208734b4d018e9b27cad3/model.bin

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn


class CNNCTC(nn.Module):
    def __init__(self, class_num, mode='train'):
        super(CNNCTC, self).__init__()
        feature = [
            # 3 input channels: one per sliding-window size (24/32/40), not RGB.
            nn.Conv2d(3, 50, stride=1, kernel_size=3, padding=1),
            nn.BatchNorm2d(50),
            nn.ReLU(inplace=True),
            nn.Conv2d(50, 100, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.1),
            nn.Conv2d(100, 100, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.1),
            nn.BatchNorm2d(100),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(100, 200, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.2),
            nn.Conv2d(200, 200, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.2),
            nn.BatchNorm2d(200),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(200, 250, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.3),
            nn.BatchNorm2d(250),
            nn.ReLU(inplace=True),
            nn.Conv2d(250, 300, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.3),
            nn.Conv2d(300, 300, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.3),
            nn.BatchNorm2d(300),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
            nn.Conv2d(300, 350, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.4),
            nn.BatchNorm2d(350),
            nn.ReLU(inplace=True),
            nn.Conv2d(350, 400, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.4),
            nn.Conv2d(400, 400, stride=1, kernel_size=3, padding=1),
            nn.Dropout(p=0.4),
            nn.BatchNorm2d(400),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2)
        ]

        classifier = [
            nn.Linear(1600, 900),  # 400 channels * 2 * 2 after four 2x2 max-pools on a 32x32 input
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            # nn.Linear(900, 200),
            # nn.ReLU(inplace=True),
            nn.Linear(900, class_num)
        ]
        self.mode = mode
        self.feature = nn.Sequential(*feature)
        self.classifier = nn.Sequential(*classifier)

    def forward(self, x):  # x: batch, window, slice channel, h, w
        result = []
        for s in range(x.shape[1]):  # run the shared CNN on each window position
            result.append(self.single_forward(x[:, s, :, :, :]))
        out = torch.stack(result)
        if self.mode != 'train':
            return self.decode(out)
        return out

    def single_forward(self, x):
        feat = self.feature(x)
        feat = feat.view(feat.shape[0], -1)  # flatten
        out = self.classifier(feat)
        return out

    def decode(self, pred):
        pred = pred.permute(1, 0, 2).cpu().data.numpy()  # batch, step, class
        seq = []
        for i in range(pred.shape[0]):
            seq.append(self.pred_to_string(pred[i]))
        return seq

    @staticmethod
    def pred_to_string(pred):  # step, class
        # Greedy CTC decoding: argmax at each step, then collapse repeats and
        # drop blanks (label 0).
        seq = []
        for i in range(pred.shape[0]):
            label = np.argmax(pred[i])
            seq.append(label)
        out = []
        for i in range(len(seq)):
            if len(out) == 0:
                if seq[i] != 0:
                    out.append(seq[i])
            else:
                if seq[i] != 0 and seq[i] != seq[i - 1]:
                    out.append(seq[i])
        return out

--------------------------------------------------------------------------------
/prepare_IIIT5K_dataset.py:
--------------------------------------------------------------------------------
import os
import pickle

import cv2
import numpy as np
import scipy.io as sio


def char_to_label(char):
    # 0 is reserved for the CTC blank; 'A'-'Z' map to 1-26, '0'-'9' to 27-36.
    if ord('A') <= ord(char) <= ord('Z'):
        return ord(char) - ord('A') + 1
    return 26 + ord(char) - ord('0') + 1


def load_data_mat(name):
    print('Loading %s ...' % name)
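    # Each entry of the struct array in traindata.mat / testdata.mat carries an
    # 'ImgName' (relative image path) and a 'GroundTruth' (the word label).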
    mat = sio.loadmat('dataset/IIIT5K/' + name + '.mat')[name][0]
    count = mat.shape[0]
    labels = []
    images = []
    for i in range(0, count):
        word = mat[i]['GroundTruth'][0]
        image = mat[i]['ImgName'][0]
        images.append('dataset/IIIT5K/' + image)
        labels.append([])
        for j in range(0, len(word)):
            labels[i].append(char_to_label(word[j]))
        labels[i] = np.asarray(labels[i], dtype='int32')
    labels = np.asarray(labels)
    return images, labels


def prepare_images(images):
    decoded_images = []
    for img in images:
        img = cv2.imread(img, 0)  # read as grayscale
        scale = 32 / img.shape[0]
        img = cv2.resize(img, None, fx=scale, fy=scale)  # normalize height to 32
        if img.shape[1] < 256:
            # Pad narrow images to width 256: half of the padding on the left,
            # the remainder on the right.
            img = np.concatenate([np.array([[0] * ((256 - img.shape[1]) // 2)] * 32), img], axis=1)
            img = np.concatenate([img, np.array([[0] * (256 - img.shape[1])] * 32)], axis=1)
        else:
            # Resize to an explicit 256x32 to avoid off-by-one rounding from a
            # fractional fx scale factor.
            img = cv2.resize(img, (256, 32))
        if img.shape[1] != 256:
            raise ValueError('shape = %d,%d' % img.shape)
        decoded_images.append(img)
    return np.asarray(decoded_images, np.float32) / 255


def convert_if_needed(name):
    if os.path.exists('dataset/IIIT5K/' + name + '.pickle'):
        return
    images, labels = load_data_mat(name)

    with open('dataset/IIIT5K/' + name + '.pickle', 'wb') as f:
        pickle.dump((images, labels), f)


def load_data(name):
    with open('dataset/IIIT5K/' + name + '.pickle', 'rb') as f:
        return pickle.load(f)


def main():
    convert_if_needed('traindata')
    convert_if_needed('testdata')

    images, labels = load_data('traindata')
    images = prepare_images(images)
    # assert not np.any(np.isnan(images))
    # assert not np.any(np.isnan(labels))

    eval_images, eval_labels = load_data('testdata')
    eval_images = prepare_images(eval_images)
    # assert not np.any(np.isnan(eval_images))
    # assert not np.any(np.isnan(eval_labels))


if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/resource/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsvih/Sliding-Convolution/d3d71d094f4462c4130208734b4d018e9b27cad3/resource/architecture.png

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from collections import OrderedDict

import torch

from model import CNNCTC


def load_weights(target, source_state):
    # Copy weights whose names and shapes match; keep the model's own
    # initialization for everything else.
    new_dict = OrderedDict()
    for k, v in target.state_dict().items():
        if k in source_state and v.size() == source_state[k].size():
            new_dict[k] = source_state[k]
        else:
            new_dict[k] = v
    target.load_state_dict(new_dict)


def load_model(device):
    model = CNNCTC(class_num=37).to(device)
    load_weights(model, torch.load('model.bin', map_location='cpu'))
    return model

--------------------------------------------------------------------------------