├── .gitignore
├── LICENSE
├── README.md
├── dataloader.py
├── dataset
│   └── .gitkeep
├── main.py
├── model.bin
├── model.py
├── prepare_IIIT5K_dataset.py
├── resource
│   └── architecture.png
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .idea/
106 |
107 | dataset/
108 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Sliding Convolution CTC for Scene Text Recognition
2 |
3 | Implementation of 'Scene Text Recognition with Sliding Convolutional Character Models'([pdf](https://arxiv.org/pdf/1709.01727))
4 |
5 | ### Model
6 |
7 | Sliding windows + CNN + CTC
8 |
9 | ![architecture](resource/architecture.png)
10 |
11 |
12 |
13 |
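The pipeline is compact enough to sketch in a few lines. The toy below is illustrative only (random scores stand in for the CNN); the window count, step, and class count follow the defaults in `main.py` and `model.py`:

```python
import numpy as np

# 1) sliding multi-scale windows give 54 positions on a 32x256 word image;
# 2) a shared CNN scores each window over 37 classes (CTC blank + 26 letters + 10 digits);
# 3) greedy CTC decoding merges adjacent repeats and drops blanks.
num_windows, num_classes = 54, 37
scores = np.random.randn(num_windows, num_classes)   # stand-in for CNN outputs
best = scores.argmax(axis=1)                         # best class per window
decoded = [c for i, c in enumerate(best)
           if c != 0 and (i == 0 or c != best[i - 1])]
print(decoded)                                       # a label-index sequence
```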
14 | ### Dependencies
15 |
16 | While this implementation may work in other setups, it has only been tested in the environment below:
17 |
18 | ```
19 | python == 3.7.0
20 | torch == 0.4.1
21 | tqdm
22 | numpy
23 | ```
24 |
25 | ```
26 | warp-ctc(for pytorch 0.4)
27 | ```
28 |
29 | ```
30 | CUDA 9.0.1
31 | CUDNN 7.0.5
32 | ```
33 |
34 | #### Install warp-ctc
35 |
36 | Follow this [instruction](https://github.com/SeanNaren/warp-ctc/tree/0.4.1)
37 |
38 | > **Note**: The warp-ctc version must match your PyTorch version. [Related issue](https://github.com/SeanNaren/warp-ctc/issues/101)
39 |
40 | ### Usage
41 |
42 | Download the [IIIT5K dataset](https://cdn.iiit.ac.in/cdn/cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz) and extract it into the `dataset` folder.
43 |
44 | Preprocess the IIIT5K dataset:
45 | ```bash
46 | python3 prepare_IIIT5K_dataset.py
47 | ```
48 |
49 | Train model:
50 | ```bash
51 | python3 main.py --cuda=True --mode=train
52 | ```
53 | Resume training:
54 | ```bash
55 | python3 main.py --cuda=True --warm-up=True --mode=train
56 | ```
57 | Test model:
58 | ```bash
59 | python3 main.py --cuda=True --mode=test
60 | ```
61 |
62 | > **Note**: `model.bin` is a pre-trained model that achieves about 53% accuracy; the accuracy is limited by the small training set.
63 |
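For quick experiments the bundled weights can also be driven directly from Python. A minimal sketch (the random tensor is a stand-in for a real batch from `dataloader.Dataset`, which yields one `windows x 3 x 32 x 32` tensor per image):

```python
import torch
from utils import load_model

model = load_model(torch.device('cpu'))
model.eval()
model.mode = 'test'  # in test mode, forward() returns decoded label sequences
with torch.no_grad():
    batch = torch.randn(1, 54, 3, 32, 32)  # batch x windows x slice channels x h x w
    print(model(batch))                    # one list of label indices per image
```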
64 | ### Citation
65 |
66 | If you find this work useful in your research, please consider citing:
67 |
68 | ```
69 | @article{yin2017scene,
70 | title={Scene text recognition with sliding convolutional character models},
71 | author={Yin, Fei and Wu, Yi-Chao and Zhang, Xu-Yao and Liu, Cheng-Lin},
72 | journal={arXiv preprint arXiv:1709.01727},
73 | year={2017}
74 | }
75 | ```
76 |
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import torch.utils.data as data
5 |
6 |
7 | class Dataset(data.Dataset):
8 |     """IIIT5K scene-text word dataset."""
9 |
10 | def __init__(self, name, mode, windows, step):
11 |         if name == 'IIIT5K':  # '==' rather than 'is': identity comparison on strings is unreliable
12 | assert (mode == 'train' or mode == 'test')
13 |
14 | from prepare_IIIT5K_dataset import load_data, prepare_images
15 | print('Loading %s data...' % mode)
16 | self.mode = mode
17 | self.windows = windows
18 | self.step = step
19 | self.img_root = 'dataset/IIIT5K'
20 | self.img_names, self.labels = load_data(mode + 'data')
21 | self.images = prepare_images(self.img_names)
22 |
23 | def __len__(self):
24 | return len(self.img_names)
25 |
26 | def __getitem__(self, idx):
27 | image = self.images[idx]
28 | image = self.slide_image(image)
29 | label = self.labels[idx]
30 | image = torch.FloatTensor(image)
31 | label = torch.IntTensor([int(i) for i in label])
32 | return image, label
33 |
34 | def slide_image(self, image):
35 | h, w = image.shape # No channel for gray image.
36 | output_image = []
37 |         half_of_max_window = max(self.windows) // 2  # slide from the center line of the largest window, advancing step pixels at a time
38 | for center_axis in range(half_of_max_window, w - half_of_max_window, self.step):
39 | slice_channel = []
40 | for window_size in self.windows:
41 | image_slice = image[:, center_axis - window_size // 2: center_axis + window_size // 2]
42 | image_slice = cv2.resize(image_slice, (32, 32))
43 | slice_channel.append(image_slice)
44 | output_image.append(np.asarray(slice_channel, dtype='float32'))
45 | return np.asarray(output_image, dtype='float32')
46 |
47 |
48 | class TrainBatch:
49 | def __init__(self, batch):
50 | transposed_data = list(zip(*batch))
51 | self.images = torch.stack(transposed_data[0], 0)
52 | self.labels = torch.cat(transposed_data[1], 0)
53 | self.label_lengths = torch.IntTensor([len(i) for i in transposed_data[1]])
54 |
55 |
56 | class TestBatch:
57 | def __init__(self, batch):
58 | transposed_data = list(zip(*batch))
59 | self.images = torch.stack(transposed_data[0], 0)
60 | self.labels = [i.tolist() for i in transposed_data[1]]
61 |
62 |
63 | def train_fn(batch):
64 | return TrainBatch(batch)
65 |
66 |
67 | def test_fn(batch):
68 | return TestBatch(batch)
69 |
--------------------------------------------------------------------------------
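A quick sanity check on the shapes `Dataset.slide_image` produces under the defaults in `main.py` (`windows=[24, 32, 40]`, `step=4`) and the 32x256 images from `prepare_images`: window centers run from 20 to 232, so `__getitem__` yields a `(54, 3, 32, 32)` tensor per image. The arithmetic, restated standalone:

```python
# Reproduce the center-position arithmetic from Dataset.slide_image.
windows, step, width = [24, 32, 40], 4, 256
half = max(windows) // 2                      # 20, half of the largest window
centers = list(range(half, width - half, step))
print(len(centers), centers[0], centers[-1])  # 54 20 232
```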
/dataset/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsvih/Sliding-Convolution/d3d71d094f4462c4130208734b4d018e9b27cad3/dataset/.gitkeep
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 | import torch
5 | import torch.optim as optim
6 | import torch.utils.data as data
7 | from tqdm import tqdm
8 | from warpctc_pytorch import CTCLoss
9 |
10 | from dataloader import Dataset, test_fn, train_fn
11 | from model import CNNCTC
12 | from utils import load_model
13 |
14 | torch.backends.cudnn.benchmark = True
15 |
16 |
17 | def main():
18 | if args.mode == 'train':
19 | train_dataset = Dataset(name=args.dataset, mode='train', windows=[24, 32, 40], step=4)
20 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=train_fn,
21 | shuffle=True, num_workers=args.workers, pin_memory=True)
22 | train(train_loader)
23 | if args.mode == 'test':
24 | test_dataset = Dataset(name=args.dataset, mode='test', windows=[24, 32, 40], step=4)
25 | test_loader = data.DataLoader(test_dataset, batch_size=args.batch_size, collate_fn=test_fn,
26 | shuffle=False, num_workers=args.workers, pin_memory=True)
27 | model = load_model(device)
28 | test(model, test_loader)
29 |
30 |
31 | def train(train_loader):
32 | model = CNNCTC(class_num=37).to(device)
33 | if args.warm_up:
34 | model = load_model(device)
35 | optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
36 | loss_function = CTCLoss(size_average=True, length_average=True).to(device)
37 | min_loss = np.Inf
38 | for epoch in range(args.epoch):
39 |         print('Epoch %d / %d' % (epoch + 1, args.epoch))
40 |         epoch_loss = train_epoch(train_loader, model, optimizer, loss_function)
41 |         print('Epoch loss: %f' % epoch_loss)
42 | if epoch_loss < min_loss:
43 | min_loss = epoch_loss
44 | torch.save(model.state_dict(), 'model.bin')
45 | return model
46 |
47 |
48 | def train_epoch(train_loader, model, optimizer, loss_function):
49 | total_loss = 0
50 | model.train()
51 | model.mode = 'train'
52 | for i, batch in enumerate(tqdm(train_loader)):
53 | optimizer.zero_grad()
54 | images = batch.images.to(device)
55 | labels = batch.labels
56 | label_lengths = batch.label_lengths
57 | probs = model(images)
58 | log_probs = probs.log_softmax(2).to(device)
59 |         prob_lengths = torch.IntTensor([log_probs.size(0)] * log_probs.size(1))  # use the real batch size: the last batch can be smaller than args.batch_size
60 |         loss = loss_function(log_probs, labels, prob_lengths, label_lengths) / log_probs.size(1)
61 | total_loss += loss.item()
62 | loss.backward()
63 | optimizer.step()
64 | return total_loss
65 |
66 |
67 | def test(model, test_loader):
68 | model.eval()
69 | model.mode = 'test'
70 | total, correct = 0, 0
71 | for i, batch in enumerate(tqdm(test_loader)):
72 | images = batch.images.to(device)
73 | labels = batch.labels
74 | out = model(images)
75 |         for actual, predicted in zip(labels, out):
76 |             if actual == predicted:
77 |                 correct += 1
78 | total += 1
79 |     print('Accuracy: %.4f' % (correct / total))
80 |
81 |
82 | if __name__ == '__main__':
83 | parser = argparse.ArgumentParser(description='Sliding convolution')
84 | parser.add_argument('--batch-size', default=16, type=int)
85 | parser.add_argument('--epoch', default=10, type=int)
86 | parser.add_argument('--workers', default=2, type=int)
87 | parser.add_argument('--dataset', default='IIIT5K')
88 |     parser.add_argument('--cuda', default=False, type=lambda s: str(s).lower() == 'true')  # plain type=bool would treat any non-empty string, even 'False', as True
89 |     parser.add_argument('--warm-up', default=False, type=lambda s: str(s).lower() == 'true')
90 | parser.add_argument('--lr', default=0.001, type=float)
91 | parser.add_argument('--weight_decay', default=0.0001, type=float)
92 | parser.add_argument('--mode', default='train', type=str)
93 | args = parser.parse_args()
94 | if args.cuda:
95 | device = torch.device('cuda')
96 | else:
97 | device = torch.device('cpu')
98 | main()
99 |
--------------------------------------------------------------------------------
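For reference, the argument shapes `train_epoch` feeds into warp-ctc look like this. A self-contained sketch, assuming the `warpctc_pytorch` binding from the README is installed: activations are `(T, N, C)`, all target sequences are concatenated into one flat 1-D `IntTensor`, and per-sample lengths are passed separately.

```python
import torch
from warpctc_pytorch import CTCLoss

T, N, C = 54, 2, 37                        # windows per image, batch size, classes
acts = torch.randn(T, N, C)                # per-window scores from the model
labels = torch.IntTensor([1, 2, 3, 4, 5])  # targets [1, 2, 3] and [4, 5], flattened
act_lens = torch.IntTensor([T, T])         # every sample uses all 54 windows
label_lens = torch.IntTensor([3, 2])
loss = CTCLoss(size_average=True, length_average=True)(acts, labels, act_lens, label_lens)
print(loss.item())
```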
/model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsvih/Sliding-Convolution/d3d71d094f4462c4130208734b4d018e9b27cad3/model.bin
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class CNNCTC(nn.Module):
7 | def __init__(self, class_num, mode='train'):
8 | super(CNNCTC, self).__init__()
9 | feature = [
10 | nn.Conv2d(3, 50, stride=1, kernel_size=3, padding=1),
11 | nn.BatchNorm2d(50),
12 | nn.ReLU(inplace=True),
13 | nn.Conv2d(50, 100, stride=1, kernel_size=3, padding=1),
14 | nn.Dropout(p=0.1),
15 | nn.Conv2d(100, 100, stride=1, kernel_size=3, padding=1),
16 | nn.Dropout(p=0.1),
17 | nn.BatchNorm2d(100),
18 | nn.ReLU(inplace=True),
19 | nn.MaxPool2d(2, stride=2),
20 | nn.Conv2d(100, 200, stride=1, kernel_size=3, padding=1),
21 | nn.Dropout(p=0.2),
22 | nn.Conv2d(200, 200, stride=1, kernel_size=3, padding=1),
23 | nn.Dropout(p=0.2),
24 | nn.BatchNorm2d(200),
25 | nn.ReLU(inplace=True),
26 | nn.MaxPool2d(2, stride=2),
27 | nn.Conv2d(200, 250, stride=1, kernel_size=3, padding=1),
28 | nn.Dropout(p=0.3),
29 | nn.BatchNorm2d(250),
30 | nn.ReLU(inplace=True),
31 | nn.Conv2d(250, 300, stride=1, kernel_size=3, padding=1),
32 | nn.Dropout(p=0.3),
33 | nn.Conv2d(300, 300, stride=1, kernel_size=3, padding=1),
34 | nn.Dropout(p=0.3),
35 | nn.BatchNorm2d(300),
36 | nn.ReLU(inplace=True),
37 | nn.MaxPool2d(2, stride=2),
38 | nn.Conv2d(300, 350, stride=1, kernel_size=3, padding=1),
39 | nn.Dropout(p=0.4),
40 | nn.BatchNorm2d(350),
41 | nn.ReLU(inplace=True),
42 | nn.Conv2d(350, 400, stride=1, kernel_size=3, padding=1),
43 | nn.Dropout(p=0.4),
44 | nn.Conv2d(400, 400, stride=1, kernel_size=3, padding=1),
45 | nn.Dropout(p=0.4),
46 | nn.BatchNorm2d(400),
47 | nn.ReLU(inplace=True),
48 | nn.MaxPool2d(2, stride=2)
49 | ]
50 |
51 | classifier = [
52 | nn.Linear(1600, 900),
53 | nn.ReLU(inplace=True),
54 | nn.Dropout(p=0.5),
55 | # nn.Linear(900, 200),
56 | # nn.ReLU(inplace=True),
57 | nn.Linear(900, class_num)
58 | ]
59 | self.mode = mode
60 | self.feature = nn.Sequential(*feature)
61 | self.classifier = nn.Sequential(*classifier)
62 |
63 | def forward(self, x): # x: batch, window, slice channel, h, w
64 | result = []
65 | for s in range(x.shape[1]):
66 | result.append(self.single_forward(x[:, s, :, :, :]))
67 | out = torch.stack(result)
68 | if self.mode != 'train':
69 | return self.decode(out)
70 | return out
71 |
72 | def single_forward(self, x):
73 | feat = self.feature(x)
74 | feat = feat.view(feat.shape[0], -1) # flatten
75 | out = self.classifier(feat)
76 | return out
77 |
78 | def decode(self, pred):
79 | pred = pred.permute(1, 0, 2).cpu().data.numpy() # batch, step, class
80 | seq = []
81 | for i in range(pred.shape[0]):
82 | seq.append(self.pred_to_string(pred[i]))
83 | return seq
84 |
85 | @staticmethod
86 | def pred_to_string(pred): # step, class
87 | seq = []
88 | for i in range(pred.shape[0]):
89 | label = np.argmax(pred[i])
90 | seq.append(label)
91 |         out = []  # collapse adjacent repeats, then drop blanks (label 0): CTC best-path decoding
92 | for i in range(len(seq)):
93 | if len(out) == 0:
94 | if seq[i] != 0:
95 | out.append(seq[i])
96 | else:
97 | if seq[i] != 0 and seq[i] != seq[i - 1]:
98 | out.append(seq[i])
99 | return out
100 |
--------------------------------------------------------------------------------
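`pred_to_string` is standard CTC best-path decoding: argmax class per window, merge adjacent repeats, drop the blank (label 0). Blanks act as separators, which is what lets genuine double letters survive. A small worked example against the static method itself:

```python
import numpy as np
from model import CNNCTC

# Per-window argmax sequence [0, 7, 7, 0, 7, 3, 3, 0]: the 7s on either side
# of a blank stay distinct, the adjacent 7s and 3s merge, blanks disappear.
steps = [0, 7, 7, 0, 7, 3, 3, 0]
pred = np.eye(37)[steps]            # one-hot scores, shape (8, 37)
print(CNNCTC.pred_to_string(pred))  # -> [7, 7, 3]
```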
/prepare_IIIT5K_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | import cv2
5 | import numpy as np
6 | import scipy.io as sio
7 |
8 |
9 | def char_to_label(char):  # 0 is reserved for the CTC blank: 'A'-'Z' -> 1-26, '0'-'9' -> 27-36
10 | if ord('A') <= ord(char) <= ord('Z'):
11 | return ord(char) - ord('A') + 1
12 | return 26 + ord(char) - ord('0') + 1
13 |
14 |
15 | def load_data_mat(name):
16 | print('Loading %s ...' % name)
17 | mat = sio.loadmat('dataset/IIIT5K/' + name + '.mat')[name][0]
18 | count = mat.shape[0]
19 | labels = []
20 | images = []
21 | for i in range(0, count):
22 | word = mat[i]['GroundTruth'][0]
23 | image = mat[i]['ImgName'][0]
24 | images.append('dataset/IIIT5K/' + image)
25 | labels.append([])
26 | for j in range(0, len(word)):
27 | labels[i].append(char_to_label(word[j]))
28 | labels[i] = np.asarray(labels[i], dtype='int32')
29 | labels = np.asarray(labels)
30 | return images, labels
31 |
32 |
33 | def prepare_images(images):
34 | decoded_images = []
35 | for img in images:
36 |         img = cv2.imread(img, 0)  # load as grayscale
37 |         scale = 32 / img.shape[0]  # scale so the height becomes 32 px
38 |         img = cv2.resize(img, None, fx=scale, fy=scale)
39 | if img.shape[1] < 256:
40 |             # Pad with black columns to reach width 256, roughly centering the word.
41 | img = np.concatenate([np.array([[0] * ((256 - img.shape[1]) // 2)] * 32), img], axis=1)
42 | img = np.concatenate([img, np.array([[0] * (256 - img.shape[1])] * 32)], axis=1)
43 | else:
44 | img = cv2.resize(img, None, fx=256 / img.shape[1], fy=1)
45 | if img.shape[1] != 256:
46 | raise ValueError('shape = %d,%d' % img.shape)
47 | decoded_images.append(img)
48 | return np.asarray(decoded_images, np.float32) / 255
49 |
50 |
51 | def convert_if_needed(name):
52 | if os.path.exists('dataset/IIIT5K/' + name + '.pickle'):
53 | return
54 | images, labels = load_data_mat(name)
55 |
56 | with open('dataset/IIIT5K/' + name + '.pickle', 'wb') as f:
57 | pickle.dump((images, labels), f)
58 |
59 |
60 | def load_data(name):
61 | with open('dataset/IIIT5K/' + name + '.pickle', 'rb') as f:
62 | return pickle.load(f)
63 |
64 |
65 | def main():
66 | convert_if_needed('traindata')
67 | convert_if_needed('testdata')
68 |
69 | images, labels = load_data('traindata')
70 | images = prepare_images(images)
71 | # assert not np.any(np.isnan(images))
72 | # assert not np.any(np.isnan(labels))
73 |
74 | eval_images, eval_labels = load_data('testdata')
75 | eval_images = prepare_images(eval_images)
76 |     # assert not np.any(np.isnan(eval_images))
77 |     # assert not np.any(np.isnan(eval_labels))
78 |
79 |
80 | if __name__ == '__main__':
81 | main()
82 |
--------------------------------------------------------------------------------
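The label scheme implied by `char_to_label` reserves 0 for the CTC blank and maps the 36 alphanumeric characters to 1..36, which is why `CNNCTC` is built with `class_num=37` elsewhere in the repo:

```python
from prepare_IIIT5K_dataset import char_to_label

print(char_to_label('A'), char_to_label('Z'))  # 1 26
print(char_to_label('0'), char_to_label('9'))  # 27 36
# label 0 stays free for the CTC blank -> 37 classes in total
```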
/resource/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lsvih/Sliding-Convolution/d3d71d094f4462c4130208734b4d018e9b27cad3/resource/architecture.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 | from model import CNNCTC
6 |
7 |
8 | def load_weights(target, source_state):  # copy only tensors whose name and shape match; keep the rest
9 | new_dict = OrderedDict()
10 | for k, v in target.state_dict().items():
11 | if k in source_state and v.size() == source_state[k].size():
12 | new_dict[k] = source_state[k]
13 | else:
14 | new_dict[k] = v
15 | target.load_state_dict(new_dict)
16 |
17 |
18 | def load_model(device):
19 | model = CNNCTC(class_num=37).to(device)
20 | load_weights(model, torch.load('model.bin', map_location='cpu'))
21 | return model
22 |
--------------------------------------------------------------------------------
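`load_weights` is a tolerant restore: it copies only the checkpoint tensors whose names and shapes match the current model and keeps the model's own initialization for everything else. One consequence, sketched below with an arbitrary `class_num=63` (a hypothetical head size, not something the repo uses): the checkpoint still loads, and only the mismatched final classifier layer stays freshly initialized.

```python
import torch
from model import CNNCTC
from utils import load_weights

# All feature-extractor weights come from model.bin; the (63 x 900) classifier
# head mismatches the checkpoint's (37 x 900) one and is simply skipped.
model = CNNCTC(class_num=63)
load_weights(model, torch.load('model.bin', map_location='cpu'))
```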