├── outputs
│   └── reference_result.png
├── tools
│   ├── __init__.py
│   ├── eval.py
│   └── train.py
├── requirements.txt
├── .idea
│   ├── misc.xml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── modules.xml
│   ├── Transformer_STR.iml
│   └── workspace.xml
├── data
│   ├── __init__.py
│   └── lmdb_dataset.py
├── checkpoints
│   └── Transformer_STR_CUTE80_pretrained.txt
├── utils
│   ├── __init__.py
│   ├── misc.py
│   └── model_util.py
├── config.py
├── model
│   └── __init__.py
├── README.MD
└── evaluation.py
/outputs/reference_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/opconty/Transformer_STR/HEAD/outputs/reference_result.png
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:17 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.18.2
2 | nltk==3.3
3 | six==1.14.0
4 | opencv_python_headless==4.1.2.30
5 | torch==1.3.1
6 | lmdb==0.98
7 | torchvision==0.4.2
8 | nicelogger==2.0
9 | Pillow==7.1.2
10 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:16 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | from .lmdb_dataset import hierarchical_dataset, AlignCollate, BatchBalancedDataset
9 |
--------------------------------------------------------------------------------
/checkpoints/Transformer_STR_CUTE80_pretrained.txt:
--------------------------------------------------------------------------------
1 | PLEASE DOWNLOAD THE PRETRAINED WEIGHTS FROM:
2 | https://drive.google.com/file/d/1o7aEt_Rmz5ZDIZqc2Z74lo01Tjq87uO1/view?usp=sharing
3 |
4 | AND PLACE THEM IN THE CURRENT DIR.
5 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:17 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | from .misc import TransLabelConverter, Averager, ResizeNormalize, get_img_tensor
9 |
10 | __all__ = ['TransLabelConverter', 'Averager', 'ResizeNormalize']
11 |
12 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:17 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import os
9 | import sys
10 | import re
11 | import string
12 | from nicelogger import ColorLogger
13 |
14 |
15 | logger = ColorLogger()
16 | proj_dir = os.path.abspath(os.path.dirname(__file__))
17 | if proj_dir not in sys.path:
18 | sys.path.insert(0, proj_dir)
19 |
20 |
21 | checkpoint_dir = os.path.join(proj_dir, 'checkpoints')
22 | # fixme: data path configuration
23 | data_dir = 'path/to/data_lmdb_text_recognition'
24 |
25 | p = re.compile(r'[!"#$%&()*+,/:;<=>?@\\^_`{|}~]')
26 |
27 |
28 | class Config:
29 | workers = 8
30 | batch_max_length = 25
31 | batch_size = 32
32 | imgH = 32
33 | imgW = 100
34 | keep_ratio = False
35 | rgb = False
36 | sensitive = False
37 | data_filtering_off = False
38 | keep_ratio_with_pad = False
39 | total_data_usage_ratio = 1.0
40 |
41 | punctuation = r"""'.-"""
42 | character = string.digits + string.ascii_lowercase + punctuation
43 |
44 | num_fiducial = 20
45 | input_channel = 1
46 | output_channel = 512
47 | hidden_size = 256
48 | num_gpu = 1
49 |
50 | lr = 1.0
51 | grad_clip = 5
52 | beta1 = 0.9
53 | rho = 0.95
54 | eps = 1e-8
55 | manualSeed = 2020
56 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:17 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import torch
9 | import torch.nn as nn
10 | import importlib
11 | from config import logger
12 | model_util = importlib.import_module('utils.model_util')
13 |
14 | TPS_STN = model_util.TPS_STN
15 | ResNet50 = model_util.ResNet50
16 | BiLSTM = model_util.BiLSTM
17 | Transformer = model_util.Transformer
18 |
19 |
20 | class Model(nn.Module):
21 | def __init__(self,
22 | imgh=32,
23 | imgw=100,
24 | input_channel=1,
25 | output_channel=512,
26 | hidden_size=256,
27 | num_fiducial=20,
28 | num_class=41,
29 | bilstm=True,
30 | device=torch.device('cuda:0')):
31 | super(Model, self).__init__()
32 |
33 | logger.info(f'bi-lstm: {bilstm} | device: {device} | num_class: {num_class}')
34 | self.num_class = num_class
35 | self.bilstm = bilstm
36 |
37 | self.transformation = TPS_STN(num_fiducial, I_size=(imgh, imgw), I_r_size=(imgh, imgw), device=device,
38 | I_channel_num=input_channel)
39 | self.fe = ResNet50(input_channel, output_channel)
40 |
41 | self.adaptive_avg_pool = nn.AdaptiveAvgPool2d((None, 1))
42 | self.seq = nn.Sequential(BiLSTM(output_channel, hidden_size, hidden_size),
43 | BiLSTM(hidden_size, hidden_size, hidden_size))
44 | if self.bilstm:
45 | self.seq_out_channels = hidden_size
46 | else:
47 | logger.warn('There is no sequence model specified')
48 | self.seq_out_channels = output_channel
49 | self.prediction = Transformer(self.num_class, self.seq_out_channels)
50 |
51 | def forward(self, x):
52 | x = self.transformation(x)
53 | x = self.fe(x)
54 | x = self.adaptive_avg_pool(x.permute(0,3,1,2)) # [b, c, h, w] -> [b, w, c, h]
55 | x = x.squeeze(3)
56 |
57 | if self.bilstm:
58 | x = self.seq(x)
59 |
60 | pred = self.prediction(x.contiguous())
61 | return pred
62 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # Transformer-based Scene Text Recognition (Transformer-STR)
2 |
3 | - PyTorch implementation of my new method for Scene Text Recognition (STR) based on [Transformer](https://arxiv.org/abs/1706.03762).
4 |
5 | I adapted the four-stage STR framework devised by [deep-text-recognition-benchmark](https://arxiv.org/abs/1904.01906), and replaced the `Pred.` stage with **Transformer**.
6 |
7 | Equipped with Transformer, this method outperforms the best model of the aforementioned deep-text-recognition-benchmark by **7.6%** on CUTE80.
8 |
9 | ### Download pretrained weights from [here](https://drive.google.com/file/d/1o7aEt_Rmz5ZDIZqc2Z74lo01Tjq87uO1/view?usp=sharing)
10 | These pretrained weights were trained on synthetic data for about 700K iterations.
11 |
12 | Clone this repo, download the weights file, and move it to the `checkpoints` directory.
13 |
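For a quick smoke test of the downloaded weights, here is a minimal single-image inference sketch (not part of the repo; `demo.jpg` is a placeholder for any cropped word image you supply). It mirrors how `evaluation.py` loads the checkpoint:

```python
import os
import torch
import config
from model import Model
from utils import TransLabelConverter
from utils.misc import get_img_tensor

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
converter = TransLabelConverter(config.Config.character, device)

# the checkpoint was saved from a DataParallel-wrapped model, so wrap before loading
net = torch.nn.DataParallel(Model(num_class=len(converter.character), device=device)).to(device)
net.load_state_dict(torch.load(
    os.path.join(config.checkpoint_dir, 'Transformer_STR_CUTE80_pretrained.pth'),
    map_location=device))
net.eval()

img = get_img_tensor('demo.jpg').to(device)  # placeholder image path; resized to 1x32x100 grayscale
with torch.no_grad():
    preds = net(img)                         # [1, T, num_class]
_, preds_index = preds.max(2)
pred = converter.decode(preds_index, [config.Config.batch_max_length])[0]
print(pred.split('[s]')[0])                  # keep only the text before the end-of-sequence token
```
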
14 | ### Download the lmdb dataset for training and evaluation from [here](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt) (provided by [deep-text-recognition-benchmark](https://arxiv.org/abs/1904.01906))
15 | data_lmdb_release.zip contains the following:
16 | training datasets : [MJSynth (MJ)](http://www.robots.ox.ac.uk/~vgg/data/text/)[1] and [SynthText (ST)](http://www.robots.ox.ac.uk/~vgg/data/scenetext/)[2] \
17 | validation datasets : the union of the training sets [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4)[4], [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)[5], and [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)[6].\
18 | evaluation datasets : benchmark evaluation datasets, consisting of [IIIT](http://cvit.iiit.ac.in/projects/SceneTextUnderstanding/IIIT5K.html)[5], [SVT](http://www.iapr-tc11.org/mediawiki/index.php/The_Street_View_Text_Dataset)[6], [IC03](http://www.iapr-tc11.org/mediawiki/index.php/ICDAR_2003_Robust_Reading_Competitions)[7], [IC13](http://rrc.cvc.uab.es/?ch=2)[3], [IC15](http://rrc.cvc.uab.es/?ch=4)[4], [SVTP](http://openaccess.thecvf.com/content_iccv_2013/papers/Phan_Recognizing_Text_with_2013_ICCV_paper.pdf)[8], and [CUTE](http://cs-chan.com/downloads_CUTE80_dataset.html)[9].
19 |
20 |
21 | ### Training
22 | Please configure `data_dir` in the `config.py` file (a sketch of the change follows the command below), then run:
23 |
24 | ```bash
25 | python tools/train.py
26 | ```
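Inside `config.py`, the only line that normally needs editing is `data_dir`; the path below is illustrative and should point at the unpacked `data_lmdb_release` folder containing the `training` and `validation` sub-directories used by `tools/train.py`:

```python
# config.py
data_dir = '/path/to/data_lmdb_release'  # illustrative path; must contain 'training' and 'validation'
```
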
27 |
28 | ### Evaluation on CUTE80
29 | Transformer-based STR achieves **0.815972** accuracy on CUTE80, outperforming the best model of *deep-text-recognition-benchmark*, which reaches 0.74.
30 |
31 | ![reference result](outputs/reference_result.png)
32 |
33 | If you want to reproduce the evaluation result, please run:
34 |
35 | ```bash
36 | python evaluation.py
37 | ```
38 |
39 | Make sure your `cute80_dir` and `saved_model` paths are correct; you'll get the result **0.815972**.
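Alternatively, here is a minimal sketch for running the same evaluation from Python (assuming `config.data_dir` already points at the unpacked lmdb folder and the pretrained `.pth` sits in `checkpoints/`):

```python
import os
import config
from evaluation import eval_cute80

# 'evaluation/CUTE80' is the sub-folder layout of data_lmdb_release
cute80_dir = os.path.join(config.data_dir, 'evaluation', 'CUTE80')
eval_cute80(cute80_dir)  # logs the CUTE80 word accuracy, e.g. 0.815972
```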
40 |
41 |
42 | ### Contact
43 | Feel free to contact me (gao.gzhou@gmail.com).
44 |
45 | ### License
46 | This project is released under the [Apache 2.0 license.](https://www.apache.org/licenses/LICENSE-2.0)
47 |
48 | ### References
49 | [deep-text-recognition-benchmark](https://arxiv.org/abs/1904.01906)
50 |
--------------------------------------------------------------------------------
/tools/eval.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/5/19 5:01 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import torch
9 | import torch.nn.functional as F
10 | import time
11 | from nltk.metrics.distance import edit_distance
12 | import importlib
13 |
14 | utils = importlib.import_module('utils')
15 | Averager = utils.Averager
16 |
17 |
18 | def validation(model, criterion, eval_loader, converter, device, cfg):
19 | n_correct = 0
20 | norm_ED = 0
21 | length_of_data = 0
22 | infer_time = 0
23 | valid_loss_avg = Averager()
24 |
25 | for i, (image_tensors, labels) in enumerate(eval_loader):
26 | batch_size = image_tensors.size(0)
27 | length_of_data = length_of_data + batch_size
28 | image = image_tensors.to(device)
29 | # For max length prediction
30 | length_for_pred = torch.IntTensor([cfg.batch_max_length] * batch_size).to(device)
31 |
32 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=cfg.batch_max_length)
33 |
34 | start_time = time.time()
35 |
36 | with torch.no_grad():
37 | preds = model(image)
38 | forward_time = time.time() - start_time
39 |
40 | cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), text_for_loss.contiguous().view(-1))
41 |
42 | # select max probabilty (greedy decoding) then decode index to character
43 | _, preds_index = preds.max(2)
44 | preds_str = converter.decode(preds_index, length_for_pred)
45 | labels = converter.decode(text_for_loss, length_for_loss)
46 |
47 | infer_time += forward_time
48 | valid_loss_avg.add(cost)
49 |
50 | # calculate accuracy & confidence score of one batch
51 | preds_prob = F.softmax(preds, dim=2)
52 | preds_max_prob, _ = preds_prob.max(dim=2)
53 | confidence_score_list = []
54 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
55 | gt = gt[:gt.find('[s]')]
56 | pred_EOS = pred.find('[s]')
57 | pred = pred[:pred_EOS]
58 | pred_max_prob = pred_max_prob[:pred_EOS]
59 |
60 | # fixme: ignore case here, even for a case-sensitive model
61 | if not cfg.sensitive:
62 | pred = pred.lower()
63 | gt = gt.lower()
64 |
65 | if pred == gt:
66 | n_correct += 1
67 | if len(gt) == 0:
68 | norm_ED += 1
69 | else:
70 | norm_ED += edit_distance(pred, gt) / len(gt)
71 |
72 | # calculate confidence score (= multiply of pred_max_prob)
73 | try:
74 | confidence_score = pred_max_prob.cumprod(dim=0)[-1]
75 | except:
76 | confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s])
77 | confidence_score_list.append(confidence_score)
78 | # print(pred, gt, pred==gt, confidence_score)
79 |
80 | accuracy = n_correct / float(length_of_data)
81 |
82 | return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data
83 |
84 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/5/18 6:06 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import torch
9 | from torchvision import transforms
10 | from PIL import Image
11 | import cv2
12 | import numpy as np
13 | import math
14 |
15 |
16 | class Averager(object):
17 | """Compute average for torch.Tensor, used for loss average."""
18 |
19 | def __init__(self):
20 | self.reset()
21 |
22 | def add(self, v):
23 | count = v.data.numel()
24 | v = v.data.sum()
25 | self.n_count += count
26 | self.sum += v
27 |
28 | def reset(self):
29 | self.n_count = 0
30 | self.sum = 0
31 |
32 | def val(self):
33 | res = 0
34 | if self.n_count != 0:
35 | res = self.sum / float(self.n_count)
36 | return res
37 |
38 |
39 | class ResizeNormalize(object):
40 |
41 | def __init__(self, size, interpolation=Image.BICUBIC):
42 | self.size = size
43 | self.interpolation = interpolation
44 | self.toTensor = transforms.ToTensor()
45 |
46 | def __call__(self, img):
47 | img = img.resize(self.size, self.interpolation)
48 | img = self.toTensor(img)
49 | img.sub_(0.5).div_(0.5)
50 | return img
51 |
52 |
53 | class NormalizePAD(object):
54 |
55 | def __init__(self, max_size, PAD_type='right'):
56 | self.toTensor = transforms.ToTensor()
57 | self.max_size = max_size
58 | self.max_width_half = math.floor(max_size[2] / 2)
59 | self.PAD_type = PAD_type
60 |
61 | def __call__(self, img):
62 | img = self.toTensor(img)
63 | img.sub_(0.5).div_(0.5)
64 | c, h, w = img.size()
65 | Pad_img = torch.FloatTensor(*self.max_size).fill_(0)
66 | Pad_img[:, :, :w] = img # right pad
67 | if self.max_size[2] != w: # add border Pad
68 | Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w)
69 |
70 | return Pad_img
71 |
72 |
73 | def get_img_tensor(img_path, newh=32, neww=100, keep_ratio=False):
74 | if isinstance(img_path, str):
75 | image = Image.open(img_path).convert('L')
76 | elif isinstance(img_path, np.ndarray):
77 | image = Image.fromarray(cv2.cvtColor(img_path, cv2.COLOR_BGR2RGB)).convert('L')
78 | else:
79 | raise Exception(f'{type(img_path)} not supported yet')
80 |
81 | if keep_ratio: # same concept with 'Rosetta' paper
82 | resized_max_w = neww
83 | input_channel = 3 if image.mode == 'RGB' else 1
84 | transform = NormalizePAD((input_channel, newh, resized_max_w))
85 |
86 | w, h = image.size
87 | ratio = w / float(h)
88 | if math.ceil(h * ratio) > neww:
89 | resized_w = neww
90 | else:
91 | resized_w = math.ceil(newh * ratio)
92 |
93 | resized_image = image.resize((resized_w, newh), Image.BICUBIC)
94 | t = transform(resized_image)
95 |
96 | else:
97 | transform = ResizeNormalize((neww, newh))
98 | t = transform(image)
99 | # print(f'image size: {image.size}\t tensor size: {t.size()}')
100 | return torch.unsqueeze(t, 0)
101 |
102 |
103 | class TransLabelConverter(object):
104 | """ Convert between text-label and text-index """
105 | def __init__(self, character, device):
106 | self.device = device
107 | list_token = ['[s]']  # end-of-sequence token
108 | self.character = list_token + list(character)
109 |
110 | self.dict = {}
111 | for i, char in enumerate(self.character):
112 | self.dict[char] = i
113 |
114 | def encode(self, text, batch_max_length=25):
115 | length = [len(s) + 1 for s in text]
116 | batch_text = torch.LongTensor(len(text), batch_max_length + 1).fill_(0)
117 | for i, t in enumerate(text):
118 | text = list(t)
119 | text.append('[s]')
120 | text = [self.dict[char] for char in text]
121 | batch_text[i][:len(text)] = torch.LongTensor(text)
122 | return batch_text.to(self.device), torch.IntTensor(length).to(self.device)
123 |
124 | def decode(self, text_index, length):
125 | texts = []
126 | for index, l in enumerate(length):
127 | text = ''.join([self.character[i] for i in text_index[index, :]])
128 | texts.append(text)
129 | return texts
130 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:17 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import os
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | import torch.nn.functional as F
12 | from torch.utils.data import DataLoader
13 | import re
14 | import string
15 | import config
16 | import warnings
17 | warnings.filterwarnings('ignore')
18 | import importlib
19 |
20 | utils = importlib.import_module('utils')
21 | eval = importlib.import_module('tools.eval')
22 | data = importlib.import_module('data')
23 | model = importlib.import_module('model')
24 |
25 | TransLabelConverter = utils.TransLabelConverter
26 | hierarchical_dataset = data.hierarchical_dataset
27 | AlignCollate = data.AlignCollate
28 | Model = model.Model
29 |
30 |
31 | logger = config.logger
32 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
33 | p = re.compile(r'[!"#$%&()*+,/:;<=>?@\\^_`{|}~]')
34 |
35 |
36 | class Config(config.Config):
37 | valid_data = os.path.join(config.data_dir, 'validation')
38 | workers = 0
39 | batch_size = 32
40 | num_class = 40
41 | with_bilstm = True
42 | sensitive = False
43 | filter_punctuation = False
44 | backbone = 'resnet'
45 |
46 | checkpoint_dir = config.checkpoint_dir
47 | saved_model = ''
48 |
49 |
50 | def create_model(cfg: Config):
51 | """model"""
52 | cfg.sensitive = True if 'sensitive' in cfg.saved_model else False
53 |
54 | if cfg.sensitive:
55 | cfg.character = string.digits + string.ascii_letters + cfg.punctuation
56 |
57 | converter = TransLabelConverter(cfg.character, device)
58 | cfg.num_class = len(converter.character)
59 |
60 | if cfg.rgb:
61 | cfg.input_channel = 3
62 | model = Model(cfg.imgH, cfg.imgW, cfg.input_channel, cfg.output_channel, cfg.hidden_size,
63 | cfg.num_fiducial, cfg.num_class, cfg.with_bilstm, device=device)
64 |
65 | # data parallel for multi-GPU
66 | model = torch.nn.DataParallel(model).to(device)
67 | assert os.path.exists(cfg.saved_model), FileNotFoundError(f'{cfg.saved_model}')
68 |
69 | if os.path.isfile(cfg.saved_model):
70 | logger.info(f'loading pretrained model from {os.path.relpath(cfg.saved_model, os.path.dirname(__file__))}')
71 | model.load_state_dict(torch.load(cfg.saved_model, map_location=device))
72 |
73 | model.eval()
74 |
75 | return model, converter
76 |
77 |
78 | def validation(cfg: Config, model, converter):
79 | cfg.sensitive = True if 'sensitive' in cfg.saved_model else False
80 | AlignCollate_valid = AlignCollate()
81 | valid_dataset = hierarchical_dataset(cfg.valid_data, cfg.imgH, cfg.imgW, cfg.batch_max_length, cfg.character,
82 | cfg.sensitive, cfg.rgb, cfg.data_filtering_off)
83 | valid_loader = DataLoader(
84 | valid_dataset, batch_size=cfg.batch_size,
85 | shuffle=False,
86 | num_workers=int(cfg.workers),
87 | collate_fn=AlignCollate_valid, pin_memory=True)
88 |
89 | model.eval()
90 |
91 | n_correct = 0
92 | length_of_data = 0
93 |
94 | for i, (image_tensors, labels) in enumerate(valid_loader):
95 | batch_size = image_tensors.size(0)
96 | length_of_data = length_of_data + batch_size
97 | image = image_tensors.to(device)
98 | # For max length prediction
99 | length_for_pred = torch.IntTensor([cfg.batch_max_length] * batch_size).to(device)
100 |
101 | text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=cfg.batch_max_length)
102 |
103 | with torch.no_grad():
104 | preds = model(image)
105 |
106 | # select max probabilty (greedy decoding) then decode index to character
107 | _, preds_index = preds.max(2)
108 | preds_str = converter.decode(preds_index, length_for_pred)
109 | labels = converter.decode(text_for_loss, length_for_loss)
110 |
111 | # calculate accuracy & confidence score of one batch
112 | preds_prob = F.softmax(preds, dim=2)
113 | preds_max_prob, _ = preds_prob.max(dim=2)
114 | for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
115 | gt = gt[:gt.find('[s]')]
116 | pred_EOS = pred.find('[s]')
117 | pred = pred[:pred_EOS]
118 | pred_max_prob = pred_max_prob[:pred_EOS]
119 |
120 | try:
121 | confidence_score = pred_max_prob.cumprod(dim=0)[-1]
122 | except:
123 | confidence_score = 0.0
124 |
125 | if not cfg.sensitive:
126 | pred = pred.lower()
127 |
128 | # fixme: filter punctuation
129 | if cfg.filter_punctuation:
130 | pred = re.sub(p, '', pred)
131 | gt = re.sub(p, '', gt)
132 |
133 | if pred == gt:
134 | n_correct += 1
135 |
136 | accuracy = n_correct / float(length_of_data)
137 |
138 | return accuracy
139 |
140 |
141 | def eval_cute80(cute80_data_dir):
142 | cfg = Config()
143 | cfg.saved_model = os.path.join(cfg.checkpoint_dir,
144 | 'Transformer_STR_CUTE80_pretrained.pth')
145 |
146 | model, converter = create_model(cfg)
147 |
148 | cfg.valid_data = cute80_data_dir
149 | acc = validation(cfg, model, converter)
150 |
151 | logger.success(f'{acc:.6f}')
152 |
153 |
154 | if __name__ == '__main__':
155 | cudnn.benchmark = True
156 | cudnn.deterministic = True
157 |
158 | cute80_dir = os.path.join(config.data_dir, 'evaluation', 'CUTE80')
159 | eval_cute80(cute80_dir)
160 |
161 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/6/9 12:26 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import os, sys, time, random
9 | import torch
10 | import torch.backends.cudnn as cudnn
11 | from torch.utils.data import DataLoader
12 | import torch.nn.init as init
13 | import torch.optim as optim
14 | import numpy as np
15 | import string
16 | import config
17 | import importlib
18 |
19 | utils = importlib.import_module('utils')
20 | eval = importlib.import_module('tools.eval')
21 | data = importlib.import_module('data')
22 | model = importlib.import_module('model')
23 |
24 | TransLabelConverter = utils.TransLabelConverter
25 | Averager = utils.Averager
26 | validation = eval.validation
27 | hierarchical_dataset = data.hierarchical_dataset
28 | AlignCollate = data.AlignCollate
29 | BatchBalancedDataset = data.BatchBalancedDataset
30 | Model = model.Model
31 |
32 |
33 | logger = config.logger
34 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
35 |
36 |
37 | class Config(config.Config):
38 | train_data = os.path.join(config.data_dir, 'training')
39 | valid_data = os.path.join(config.data_dir, 'validation')
40 | saved_model = ''
41 | select_data = ['ST', 'MJ'] # select training data (SynthText, MJSynth)
42 | batch_ratio = [0.5, 0.5] # assign ratio for each selected data in the batch
43 | workers = 0
44 | batch_size = 192
45 | num_iter = 3000000
46 | valInterval = 2000
47 | total_data_usage_ratio = 1.0
48 | num_class = 40
49 | with_bilstm = True
50 | sensitive = False
51 |
52 | checkpoint_dir = config.checkpoint_dir
53 | Name = f'trans_{"".join(select_data)}_brnn{with_bilstm}'
54 |
55 |
56 | def train(cfg):
57 | """ dataset preparation """
58 | if not cfg.data_filtering_off:
59 | logger.info('Filtering the images containing characters which are not in character')
60 | logger.info(f'Filtering the images whose label is longer than {cfg.batch_max_length}')
61 |
62 | if cfg.sensitive:
63 | cfg.character = string.digits + string.ascii_letters + cfg.punctuation
64 | cfg.Name += '_sensitive'
65 |
66 | train_dataset = BatchBalancedDataset(cfg.train_data, cfg.batch_max_length,cfg.character,
67 | cfg.select_data, cfg.batch_ratio, cfg.batch_size, cfg.total_data_usage_ratio,
68 | cfg.workers, cfg.sensitive, cfg.rgb, cfg.data_filtering_off, cfg.imgH, cfg.imgW,
69 | cfg.keep_ratio_with_pad)
70 |
71 | AlignCollate_valid = AlignCollate()
72 | valid_dataset = hierarchical_dataset(cfg.valid_data, cfg.imgH, cfg.imgW, cfg.batch_max_length, cfg.character,
73 | cfg.sensitive, cfg.rgb, cfg.data_filtering_off)
74 | valid_loader = DataLoader(
75 | valid_dataset, batch_size=cfg.batch_size,
76 | shuffle=True,
77 | num_workers=int(cfg.workers),
78 | collate_fn=AlignCollate_valid, pin_memory=True)
79 |
80 | """model"""
81 | converter = TransLabelConverter(cfg.character, device)
82 | cfg.num_class = len(converter.character)
83 | logger.verbose(f'{cfg.num_class}\n{converter.character}')
84 |
85 | if cfg.rgb:
86 | cfg.input_channel = 3
87 | model = Model(cfg.imgH,cfg.imgW, cfg.input_channel, cfg.output_channel, cfg.hidden_size,
88 | cfg.num_fiducial, cfg.num_class, cfg.with_bilstm,
89 | device=device)
90 |
91 | logger.info('initialize')
92 | for name, param in model.named_parameters():
93 | if 'localization_fc2' in name or 'decoder' in name or 'self_attn' in name:
94 | logger.info(f'Skip {name} as it is already initialized')
95 | continue
96 | try:
97 | if 'bias' in name:
98 | init.constant_(param, 0.0)
99 | elif 'weight' in name:
100 | init.xavier_normal_(param)
101 | except: # for batchnorm.
102 | if 'weight' in name:
103 | param.data.fill_(1)
104 | continue
105 |
106 | # data parallel for multi-GPU
107 | model = torch.nn.DataParallel(model).to(device)
108 | model.train()
109 | if os.path.isfile(cfg.saved_model):
110 | logger.info(f'loading pretrained model from {cfg.saved_model}')
111 | model.load_state_dict(torch.load(cfg.saved_model))
112 |
113 | criterion = torch.nn.CrossEntropyLoss().to(device)
114 | loss_avg = Averager()
115 |
116 | # keep only the parameters that require gradients
117 | filtered_parameters = []
118 | params_num = []
119 | for p in filter(lambda p: p.requires_grad, model.parameters()):
120 | filtered_parameters.append(p)
121 | params_num.append(np.prod(p.size()))
122 | logger.info(f'Trainable params num: {sum(params_num)}')
123 |
124 | # setup optimizer
125 | optimizer = optim.Adadelta(filtered_parameters, lr=cfg.lr, rho=cfg.rho, eps=cfg.eps)
126 |
127 | """ start training """
128 | start_iter = 0
129 | start_time = time.time()
130 | best_accuracy = -1
131 | best_norm_ED = 1e+6
132 | i = start_iter
133 | epoch_size = len(train_dataset) // cfg.batch_size
134 |
135 | while True:
136 | image_tensors, labels = train_dataset.get_batch()
137 | image = image_tensors.to(device)
138 | target, length = converter.encode(labels, batch_max_length=cfg.batch_max_length)
139 |
140 | preds = model(image)
141 | cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1))
142 |
143 | model.zero_grad()
144 | cost.backward()
145 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
146 | optimizer.step()
147 |
148 | loss_avg.add(cost)
149 |
150 | # validation part
151 | if i % cfg.valInterval == 0 and i > 0:
152 | elapsed_time = time.time() - start_time
153 |
154 | model.eval()
155 | with torch.no_grad():
156 | valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels, infer_time, length_of_data = validation(
157 | model, criterion, valid_loader, converter, device, cfg)
158 | model.train()
159 |
160 | # training loss and validation loss
161 | loss_log = f'[{i}/{cfg.num_iter}({epoch_size})] Train loss: {loss_avg.val():0.5f},' \
162 | f' Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}'
163 | logger.info(loss_log)
164 | loss_avg.reset()
165 |
166 | current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.2f}'
167 | logger.info(current_model_log)
168 |
169 | # keep best accuracy model (on valid dataset)
170 | if current_accuracy > best_accuracy and current_accuracy > 0.65:
171 | best_accuracy = current_accuracy
172 | best_norm_ED = current_norm_ED if current_norm_ED < best_norm_ED else best_norm_ED
173 | acc = f'{best_accuracy:.4f}'.replace('.', '')
174 | # save
175 | torch.save(model.state_dict(), os.path.join(cfg.checkpoint_dir, f'{cfg.Name}_iter{i + 1}_acc{acc}.pth'))
176 | best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.2f}'
177 | logger.success(best_model_log)
178 |
179 | # show some predicted results
180 | print('-' * 80)
181 | print(f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F')
182 | print('-' * 80)
183 | for gt, pred, confidence in zip(labels[:5], preds[:5], confidence_score[:5]):
184 | gt = gt[:gt.find('[s]')]
185 | pred = pred[:pred.find('[s]')]
186 |
187 | print(f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred.lower() == gt.lower())}')
188 | print('-' * 80)
189 |
190 | # save model per 1e+5 iter.
191 | if (i + 1) % 1e+5 == 0:
192 | torch.save(
193 | model.state_dict(), os.path.join(cfg.checkpoint_dir, f'{cfg.Name}_iter{i + 1}.pth'))
194 |
195 | if i == cfg.num_iter:
196 | print('end the training')
197 | sys.exit()
198 | i += 1
199 |
200 |
201 | if __name__ == '__main__':
202 | cfg = Config()
203 |
204 | random.seed(cfg.manualSeed)
205 | np.random.seed(cfg.manualSeed)
206 | torch.manual_seed(cfg.manualSeed)
207 | torch.cuda.manual_seed(cfg.manualSeed)
208 |
209 | cudnn.benchmark = True
210 | cudnn.deterministic = True
211 |
212 | train(cfg)
213 |
214 |
215 |
216 |
--------------------------------------------------------------------------------
/data/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/5/19 10:21 AM
4 | #
5 | # @author: laygin
6 | #
7 | # adapted from https://github.com/clovaai/deep-text-recognition-benchmark
8 | #'''
9 | import os, sys, re, six, math, lmdb
10 | import torch
11 | from PIL import Image
12 | from torch.utils.data import Dataset, ConcatDataset, Subset
13 | from utils.misc import ResizeNormalize, NormalizePAD
14 | import config
15 |
16 |
17 | logger = config.logger
18 |
19 |
20 | def _accumulate(iterable, fn=lambda x, y: x + y):
21 | '''Return running totals'''
22 | # _accumulate([1,2,3,4,5]) --> 1 3 6 10 15
23 | # _accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
24 | it = iter(iterable)
25 | try:
26 | total = next(it)
27 | except StopIteration:
28 | return
29 | yield total
30 | for element in it:
31 | total = fn(total, element)
32 | yield total
33 |
34 |
35 | class AlignCollate:
36 | def __init__(self, imgh=32, imgw=100, keep_ratio_with_pad=False):
37 | self.imgh = imgh
38 | self.imgw = imgw
39 | self.keep_ratio_with_pad = keep_ratio_with_pad
40 |
41 | def __call__(self, batch):
42 | batch = filter(lambda x:x is not None, batch)
43 | images, labels = zip(*batch)
44 |
45 | if self.keep_ratio_with_pad:
46 | resized_max_w = self.imgw
47 | input_channel = 3 if images[0].mode == 'RGB' else 1
48 | transform = NormalizePAD((input_channel, self.imgh, resized_max_w))
49 |
50 | resized_images = []
51 | for image in images:
52 | w, h = image.size
53 | ratio = w / float(h)
54 | if math.ceil(self.imgh * ratio) > self.imgw:
55 | resized_w = self.imgw
56 | else:
57 | resized_w = math.ceil(self.imgh * ratio)
58 |
59 | resized_image = image.resize((resized_w, self.imgh), Image.BICUBIC)
60 | resized_images.append(transform(resized_image))
61 | image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0)
62 | else:
63 | transform = ResizeNormalize((self.imgw, self.imgh))
64 | image_tensors = [transform(i) for i in images]
65 | image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0)
66 | return image_tensors, labels
67 |
68 |
69 | class LmdbDataset(Dataset):
70 | def __init__(self, root, imgh, imgw, batch_max_length, character, sensitive, rgb, data_filtering_off=False):
71 | self.root = root
72 | self.character = character
73 | self.sensitive = sensitive
74 | self.rgb = rgb
75 | self.imgH = imgh
76 | self.imgW = imgw
77 |
78 | self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
79 | if not self.env:
80 | logger.error('can not create lmdb from %s' % root)
81 | sys.exit(0)
82 |
83 | with self.env.begin(write=False) as txn:
84 | self.nsamples = int(txn.get('num-samples'.encode()))
85 |
86 | if data_filtering_off:
87 | self.filtered_index_list = [index + 1 for index in range(self.nsamples)]
88 | else:
89 | self.filtered_index_list = []
90 | for index in range(self.nsamples):
91 | index += 1
92 | label_key = 'label-%09d'.encode() % index
93 | label = txn.get(label_key).decode('utf-8')
94 |
95 | if len(label) > batch_max_length:
96 | continue
97 | out_of_char = f'[^{self.character}]'
98 | label = label if self.sensitive else label.lower()
99 | if re.search(out_of_char, label):
100 | continue
101 | self.filtered_index_list.append(index)
102 | self.nsamples = len(self.filtered_index_list)
103 |
104 | def __len__(self):
105 | return self.nsamples
106 |
107 | def __getitem__(self, index):
108 | assert index <= len(self), 'index range error'
109 | index = self.filtered_index_list[index]
110 |
111 | with self.env.begin(write=False) as txn:
112 | label_key = 'label-%09d'.encode() % index
113 | label = txn.get(label_key).decode('utf-8')
114 | img_key = 'image-%09d'.encode() % index
115 | imgbuf = txn.get(img_key)
116 |
117 | buf = six.BytesIO()
118 | buf.write(imgbuf)
119 | buf.seek(0)
120 | try:
121 | if self.rgb:
122 | img = Image.open(buf).convert('RGB') # for color image
123 | else:
124 | img = Image.open(buf).convert('L')
125 |
126 | except IOError:
127 | logger.error(f'Corrupted image for {index}')
128 | # make dummy image and dummy label for corrupted image.
129 | if self.rgb:
130 | img = Image.new('RGB', (self.imgW, self.imgH))
131 | else:
132 | img = Image.new('L', (self.imgW, self.imgH))
133 | label = '[dummy_label]'
134 |
135 | if not self.sensitive:
136 | label = label.lower()
137 |
138 | out_of_char = f'[^{self.character}]'
139 | label = re.sub(out_of_char, '', label)
140 |
141 | return img, label
142 |
143 |
144 | def hierarchical_dataset(root, imgh, imgw, batch_max_length,character,
145 | sensitive=False, rgb=False, data_filtering_off=False, select_data='/'):
146 | dataset_list = []
147 | for dirpath, dirnames, filenames in os.walk(root + '/'):
148 | if not dirnames:
149 | select_flag = False
150 | for select_d in select_data:
151 | if select_d in dirpath:
152 | select_flag = True
153 | break
154 | if select_flag:
155 | dataset = LmdbDataset(dirpath,
156 | imgh=imgh,
157 | imgw=imgw,
158 | batch_max_length=batch_max_length,
159 | character=character,sensitive=sensitive,rgb=rgb, data_filtering_off=data_filtering_off)
160 | logger.info(f'sub-directory:\t{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}')
161 | dataset_list.append(dataset)
162 | return ConcatDataset(dataset_list)
163 |
164 |
165 | class BatchBalancedDataset:
166 | def __init__(self,
167 | root,batch_max_length, character,
168 | select_data,
169 | batch_ratio,
170 | batch_size,
171 | total_data_usage_ratio,
172 | workers,
173 | sensitive=False, rgb=False, data_filtering_off=False,
174 | imgh=32, imgw=100, keep_ratio_with_pad=False):
175 | assert len(select_data) == len(batch_ratio)
176 |
177 | _AlignCollate = AlignCollate(imgh=imgh, imgw=imgw, keep_ratio_with_pad=keep_ratio_with_pad)
178 | self.data_loader_list = []
179 | self.data_loader_iter_list = []
180 | batch_size_list = []
181 | total_batch_size = 0
182 | self.total_samples = 0
183 | for selected_d, batch_ratio_d in zip(select_data, batch_ratio):
184 | _batch_size = max(round(batch_size * float(batch_ratio_d)), 1)
185 | print('-' * 80)
186 | _dataset = hierarchical_dataset(root, imgh,imgw,batch_max_length,
187 | character,sensitive,rgb,data_filtering_off,
188 | select_data=[selected_d])
189 | total_number_dataset = len(_dataset)
190 |
191 | """
192 | total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage.
193 | """
194 | number_dataset = int(total_number_dataset * float(total_data_usage_ratio))
195 | self.total_samples += number_dataset
196 | dataset_split = [number_dataset, total_number_dataset - number_dataset]
197 |
198 | indices = range(total_number_dataset)
199 | _dataset, _ = [Subset(_dataset, indices[offset-length: offset])
200 | for offset, length in zip(_accumulate(dataset_split), dataset_split)]
201 |
202 | logger.info(f'number total samples of {selected_d}: {total_number_dataset} x '
203 | f'{total_data_usage_ratio} (usage_ratio)={len(_dataset)}')
204 | logger.info(f'number samples of {selected_d} per batch: {batch_size} x'
205 | f'{batch_ratio_d}(batch ratio)={_batch_size}')
206 |
207 | batch_size_list.append(str(_batch_size))
208 | total_batch_size += _batch_size
209 |
210 | _data_loader = torch.utils.data.DataLoader(
211 | _dataset, batch_size=_batch_size, shuffle=True, num_workers=workers,
212 | collate_fn=_AlignCollate, pin_memory=True
213 | )
214 | self.data_loader_list.append(_data_loader)
215 | self.data_loader_iter_list.append(iter(_data_loader))
216 |
217 | print('-'*80)
218 | logger.info(f'Total_batch_size: {"+".join(batch_size_list)} = {total_batch_size}')
219 | logger.info(f'Total samples: {self.total_samples}')
220 | self.batch_size = total_batch_size
221 | print('-' * 80)
222 |
223 | def __len__(self):
224 | return self.total_samples
225 |
226 | def get_batch(self):
227 | balanced_batch_images = []
228 | balanced_batch_texts = []
229 |
230 | for i, data_loader_iter in enumerate(self.data_loader_iter_list):
231 | try:
232 | image, text = data_loader_iter.next()
233 | balanced_batch_images.append(image)
234 | balanced_batch_texts += text
235 | except StopIteration:
236 | self.data_loader_iter_list[i] = iter(self.data_loader_list[i])
237 | image, text = self.data_loader_iter_list[i].next()
238 | balanced_batch_images.append(image)
239 | balanced_batch_texts += text
240 | except ValueError:
241 | pass
242 |
243 | balanced_batch_images = torch.cat(balanced_batch_images, 0)
244 |
245 | return balanced_batch_images, balanced_batch_texts
246 |
--------------------------------------------------------------------------------
/utils/model_util.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 | #'''
3 | # @date: 2020/5/18 6:06 PM
4 | #
5 | # @author: laygin
6 | #
7 | #'''
8 | import math
9 | import numpy as np
10 | import torch
11 | import torch.nn as nn
12 | from torch.nn.init import xavier_uniform_
13 | from torch.nn.init import constant_
14 | from torch.nn.init import xavier_normal_
15 | import copy
16 | import torch.nn.functional as F
17 | from config import logger
18 |
19 |
20 | class TPS_STN(nn.Module):
21 | """ Rectification Network of RARE, namely TPS based STN """
22 |
23 | def __init__(self, F, I_size, I_r_size, device, I_channel_num=1):
24 | """ Based on RARE TPS
25 | input:
26 | batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width]
27 | I_size : (height, width) of the input image I
28 | I_r_size : (height, width) of the rectified image I_r
29 | I_channel_num : the number of channels of the input image I
30 | output:
31 | batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width]
32 | """
33 | super(TPS_STN, self).__init__()
34 | self.F = F
35 | self.I_size = I_size
36 | self.I_r_size = I_r_size # = (I_r_height, I_r_width)
37 | self.I_channel_num = I_channel_num
38 | self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num)
39 | self.GridGenerator = GridGenerator(self.F, self.I_r_size, device)
40 |
41 | def forward(self, batch_I):
42 | batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
43 | build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
44 | build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
45 | batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
46 |
47 | return batch_I_r
48 |
49 |
50 | class LocalizationNetwork(nn.Module):
51 | """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """
52 |
53 | def __init__(self, F, I_channel_num):
54 | super(LocalizationNetwork, self).__init__()
55 | self.F = F
56 | self.I_channel_num = I_channel_num
57 | self.conv = nn.Sequential(
58 | nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1,
59 | bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
60 | nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2
61 | nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True),
62 | nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4
63 | nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True),
64 | nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8
65 | nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True),
66 | nn.AdaptiveAvgPool2d(1) # batch_size x 512
67 | )
68 |
69 | self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True))
70 | self.localization_fc2 = nn.Linear(256, self.F * 2)
71 |
72 | # Init fc2 in LocalizationNetwork
73 | self.localization_fc2.weight.data.fill_(0)
74 | """ see RARE paper Fig. 6 (a) """
75 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
76 | ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2))
77 | ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2))
78 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
79 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
80 | initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
81 | self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1)
82 |
83 | def forward(self, batch_I):
84 | """
85 | input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width]
86 | output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2]
87 | """
88 | batch_size = batch_I.size(0)
89 | features = self.conv(batch_I).view(batch_size, -1)
90 | batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2)
91 | return batch_C_prime
92 |
93 |
94 | class GridGenerator(nn.Module):
95 | """ Grid Generator of RARE, which produces P_prime by multipling T with P """
96 |
97 | def __init__(self, F, I_r_size, device):
98 | """ Generate P_hat and inv_delta_C for later """
99 | super(GridGenerator, self).__init__()
100 | self.device = device
101 | self.eps = 1e-6
102 | self.I_r_height, self.I_r_width = I_r_size
103 | self.F = F
104 | self.C = self._build_C(self.F) # F x 2
105 | self.P = self._build_P(self.I_r_width, self.I_r_height)
106 | ## for multi-gpu, you need register buffer
107 | self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
108 | self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
109 | ## for fine-tuning with different image width, you may use below instead of self.register_buffer
110 | #self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
111 | #self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3
112 |
113 | def _build_C(self, F):
114 | """ Return coordinates of fiducial points in I_r; C """
115 | ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2))
116 | ctrl_pts_y_top = -1 * np.ones(int(F / 2))
117 | ctrl_pts_y_bottom = np.ones(int(F / 2))
118 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1)
119 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1)
120 | C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0)
121 | return C # F x 2
122 |
123 | def _build_inv_delta_C(self, F, C):
124 | """ Return inv_delta_C which is needed to calculate T """
125 | hat_C = np.zeros((F, F), dtype=float) # F x F
126 | for i in range(0, F):
127 | for j in range(i, F):
128 | r = np.linalg.norm(C[i] - C[j])
129 | hat_C[i, j] = r
130 | hat_C[j, i] = r
131 | np.fill_diagonal(hat_C, 1)
132 | hat_C = (hat_C ** 2) * np.log(hat_C)
133 | # print(C.shape, hat_C.shape)
134 | delta_C = np.concatenate( # F+3 x F+3
135 | [
136 | np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3
137 | np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3
138 | np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3
139 | ],
140 | axis=0
141 | )
142 | inv_delta_C = np.linalg.inv(delta_C)
143 | return inv_delta_C # F+3 x F+3
144 |
145 | def _build_P(self, I_r_width, I_r_height):
146 | I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width
147 | I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height
148 | P = np.stack( # self.I_r_width x self.I_r_height x 2
149 | np.meshgrid(I_r_grid_x, I_r_grid_y),
150 | axis=2
151 | )
152 | return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2
153 |
154 | def _build_P_hat(self, F, C, P):
155 | n = P.shape[0] # n (= self.I_r_width x self.I_r_height)
156 | P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2
157 | C_tile = np.expand_dims(C, axis=0) # 1 x F x 2
158 | P_diff = P_tile - C_tile # n x F x 2
159 | rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F
160 | rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F
161 | P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1)
162 | return P_hat # n x F+3
163 |
164 | def build_P_prime(self, batch_C_prime):
165 | """ Generate Grid from batch_C_prime [batch_size x F x 2] """
166 | batch_size = batch_C_prime.size(0)
167 | batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
168 | batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
169 | batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
170 | batch_size, 3, 2).float().to(self.device)), dim=1) # batch_size x F+3 x 2
171 | batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
172 | batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
173 | return batch_P_prime # batch_size x n x 2
174 |
175 |
176 | class BasicBlockRes(nn.Module):
177 | expansion = 1
178 |
179 | def __init__(self, inplanes, planes, stride=1, downsample=None):
180 | super(BasicBlockRes, self).__init__()
181 | self.conv1 = self._conv3x3(inplanes, planes)
182 | self.bn1 = nn.BatchNorm2d(planes)
183 | self.conv2 = self._conv3x3(planes, planes)
184 | self.bn2 = nn.BatchNorm2d(planes)
185 | self.relu = nn.ReLU(inplace=True)
186 | self.downsample = downsample
187 | self.stride = stride
188 |
189 | def _conv3x3(self, in_planes, out_planes, stride=1):
190 | "3x3 convolution with padding"
191 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
192 | padding=1, bias=False)
193 |
194 | def forward(self, x):
195 | residual = x
196 |
197 | out = self.conv1(x)
198 | out = self.bn1(out)
199 | out = self.relu(out)
200 |
201 | out = self.conv2(out)
202 | out = self.bn2(out)
203 |
204 | if self.downsample is not None:
205 | residual = self.downsample(x)
206 | out += residual
207 | out = self.relu(out)
208 |
209 | return out
210 |
211 |
212 | class ResNet(nn.Module):
213 |
214 | def __init__(self, input_channel, output_channel, block, layers):
215 | super(ResNet, self).__init__()
216 |
217 | self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
218 |
219 | self.inplanes = int(output_channel / 8)
220 | self.conv0_1 = nn.Conv2d(input_channel, int(output_channel / 16),
221 | kernel_size=3, stride=1, padding=1, bias=False)
222 | self.bn0_1 = nn.BatchNorm2d(int(output_channel / 16))
223 | self.conv0_2 = nn.Conv2d(int(output_channel / 16), self.inplanes,
224 | kernel_size=3, stride=1, padding=1, bias=False)
225 | self.bn0_2 = nn.BatchNorm2d(self.inplanes)
226 | self.relu = nn.ReLU(inplace=True)
227 |
228 | self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
229 | self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
230 | self.conv1 = nn.Conv2d(self.output_channel_block[0], self.output_channel_block[
231 | 0], kernel_size=3, stride=1, padding=1, bias=False)
232 | self.bn1 = nn.BatchNorm2d(self.output_channel_block[0])
233 |
234 | self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
235 | self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1], stride=1)
236 | self.conv2 = nn.Conv2d(self.output_channel_block[1], self.output_channel_block[
237 | 1], kernel_size=3, stride=1, padding=1, bias=False)
238 | self.bn2 = nn.BatchNorm2d(self.output_channel_block[1])
239 |
240 | self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), padding=(0, 1))
241 | self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2], stride=1)
242 | self.conv3 = nn.Conv2d(self.output_channel_block[2], self.output_channel_block[
243 | 2], kernel_size=3, stride=1, padding=1, bias=False)
244 | self.bn3 = nn.BatchNorm2d(self.output_channel_block[2])
245 |
246 | self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3], stride=1)
247 | self.conv4_1 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
248 | 3], kernel_size=2, stride=(2, 1), padding=(0, 1), bias=False)
249 | self.bn4_1 = nn.BatchNorm2d(self.output_channel_block[3])
250 | self.conv4_2 = nn.Conv2d(self.output_channel_block[3], self.output_channel_block[
251 | 3], kernel_size=2, stride=1, padding=0, bias=False)
252 | self.bn4_2 = nn.BatchNorm2d(self.output_channel_block[3])
253 |
254 | def _make_layer(self, block, planes, blocks, stride=1):
255 | downsample = None
256 | if stride != 1 or self.inplanes != planes * block.expansion:
257 | downsample = nn.Sequential(
258 | nn.Conv2d(self.inplanes, planes * block.expansion,
259 | kernel_size=1, stride=stride, bias=False),
260 | nn.BatchNorm2d(planes * block.expansion),
261 | )
262 |
263 | layers = []
264 | layers.append(block(self.inplanes, planes, stride, downsample))
265 | self.inplanes = planes * block.expansion
266 | for i in range(1, blocks):
267 | layers.append(block(self.inplanes, planes))
268 |
269 | return nn.Sequential(*layers)
270 |
271 | def forward(self, x):
272 | x = self.conv0_1(x)
273 | x = self.bn0_1(x)
274 | x = self.relu(x)
275 | x = self.conv0_2(x)
276 | x = self.bn0_2(x)
277 | x = self.relu(x)
278 |
279 | x = self.maxpool1(x)
280 | x = self.layer1(x)
281 | x = self.conv1(x)
282 | x = self.bn1(x)
283 | x = self.relu(x)
284 |
285 | x = self.maxpool2(x)
286 | x = self.layer2(x)
287 | x = self.conv2(x)
288 | x = self.bn2(x)
289 | x = self.relu(x)
290 |
291 | x = self.maxpool3(x)
292 | x = self.layer3(x)
293 | x = self.conv3(x)
294 | x = self.bn3(x)
295 | x = self.relu(x)
296 |
297 | x = self.layer4(x)
298 | x = self.conv4_1(x)
299 | x = self.bn4_1(x)
300 | x = self.relu(x)
301 | x = self.conv4_2(x)
302 | x = self.bn4_2(x)
303 | x = self.relu(x)
304 |
305 | return x
306 |
307 |
308 | class ResNet50(nn.Module):
309 |
310 | def __init__(self, input_channel, output_channel=512):
311 | super(ResNet50, self).__init__()
312 | self.ConvNet = ResNet(input_channel, output_channel, BasicBlockRes, [1, 2, 5, 3])
313 |
314 | def forward(self, input):
315 | return self.ConvNet(input)
316 |
317 |
318 | class BiLSTM(nn.Module):
319 |
320 | def __init__(self, input_size, hidden_size, output_size):
321 | super(BiLSTM, self).__init__()
322 | self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
323 | self.linear = nn.Linear(hidden_size * 2, output_size)
324 |
325 | def forward(self, input):
326 | """
327 | input : visual feature [batch_size x T x input_size]
328 | output : contextual feature [batch_size x T x output_size]
329 | """
330 | self.rnn.flatten_parameters()
331 | recurrent, _ = self.rnn(input) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
332 | output = self.linear(recurrent) # batch_size x T x output_size
333 | return output
334 |
335 |
336 | def _get_clones(module, N):
337 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
338 |
339 |
340 | class TransformerEncoder(nn.Module):
341 | __constants__ = ['norm']
342 |
343 | def __init__(self, encoder_layer, num_layers, norm=None):
344 | super(TransformerEncoder, self).__init__()
345 | self.layers = _get_clones(encoder_layer, num_layers)
346 | self.num_layers = num_layers
347 | self.norm = norm
348 |
349 | def forward(self, src, mask=None, src_key_padding_mask=None):
350 | # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tensor
351 | r"""Pass the input through the encoder layers in turn.
352 | Args:
353 | src: the sequence to the encoder (required).
354 | mask: the mask for the src sequence (optional).
355 | src_key_padding_mask: the mask for the src keys per batch (optional).
356 | Shape:
357 | see the docs in Transformer class.
358 | """
359 | output = src
360 |
361 | for mod in self.layers:
362 | output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
363 |
364 | if self.norm is not None:
365 | output = self.norm(output)
366 |
367 | return output
368 |
369 |
370 | def _get_activation_fn(activation):
371 | if activation == "relu":
372 | return F.relu
373 | elif activation == "gelu":
374 | return F.gelu
375 |
376 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
377 |
378 |
379 | class MultiheadAttention(nn.Module):
380 | __annotations__ = {
381 | 'bias_k': torch._jit_internal.Optional[torch.Tensor],
382 | 'bias_v': torch._jit_internal.Optional[torch.Tensor],
383 | }
384 | __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight']
385 |
386 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None):
387 | super(MultiheadAttention, self).__init__()
388 | self.embed_dim = embed_dim
389 | self.kdim = kdim if kdim is not None else embed_dim
390 | self.vdim = vdim if vdim is not None else embed_dim
391 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
392 |
393 | self.num_heads = num_heads
394 | self.dropout = dropout
395 | self.head_dim = embed_dim // num_heads
396 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
397 |
398 | if self._qkv_same_embed_dim is False:
399 | self.q_proj_weight = nn.Parameter(torch.Tensor(embed_dim, embed_dim))
400 | self.k_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.kdim))
401 | self.v_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.vdim))
402 | self.register_parameter('in_proj_weight', None)
403 | else:
404 | self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim))
405 | self.register_parameter('q_proj_weight', None)
406 | self.register_parameter('k_proj_weight', None)
407 | self.register_parameter('v_proj_weight', None)
408 |
409 | if bias:
410 | self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim))
411 | else:
412 | self.register_parameter('in_proj_bias', None)
413 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
414 |
415 | if add_bias_kv:
416 | self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim))
417 | self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim))
418 | else:
419 | self.bias_k = self.bias_v = None
420 |
421 | self.add_zero_attn = add_zero_attn
422 |
423 | self._reset_parameters()
424 |
425 | def _reset_parameters(self):
426 | if self._qkv_same_embed_dim:
427 | xavier_uniform_(self.in_proj_weight)
428 | else:
429 | xavier_uniform_(self.q_proj_weight)
430 | xavier_uniform_(self.k_proj_weight)
431 | xavier_uniform_(self.v_proj_weight)
432 |
433 | if self.in_proj_bias is not None:
434 | constant_(self.in_proj_bias, 0.)
435 | constant_(self.out_proj.bias, 0.)
436 | if self.bias_k is not None:
437 | xavier_normal_(self.bias_k)
438 | if self.bias_v is not None:
439 | xavier_normal_(self.bias_v)
440 |
441 | def __setstate__(self, state):
442 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0
443 | if '_qkv_same_embed_dim' not in state:
444 | state['_qkv_same_embed_dim'] = True
445 |
446 | super(MultiheadAttention, self).__setstate__(state)
447 |
448 | def forward(self, query, key, value, key_padding_mask=None,
449 | need_weights=True, attn_mask=None):
450 | if not self._qkv_same_embed_dim:
451 | return F.multi_head_attention_forward(
452 | query, key, value, self.embed_dim, self.num_heads,
453 | self.in_proj_weight, self.in_proj_bias,
454 | self.bias_k, self.bias_v, self.add_zero_attn,
455 | self.dropout, self.out_proj.weight, self.out_proj.bias,
456 | training=self.training,
457 | key_padding_mask=key_padding_mask, need_weights=need_weights,
458 | attn_mask=attn_mask, use_separate_proj_weight=True,
459 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight,
460 | v_proj_weight=self.v_proj_weight)
461 | else:
462 | return F.multi_head_attention_forward(
463 | query, key, value, self.embed_dim, self.num_heads,
464 | self.in_proj_weight, self.in_proj_bias,
465 | self.bias_k, self.bias_v, self.add_zero_attn,
466 | self.dropout, self.out_proj.weight, self.out_proj.bias,
467 | training=self.training,
468 | key_padding_mask=key_padding_mask, need_weights=need_weights,
469 | attn_mask=attn_mask)
470 |
471 |
472 | class TransformerEncoderLayer(nn.Module):
473 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
474 | super(TransformerEncoderLayer, self).__init__()
475 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
476 | # Implementation of Feedforward model
477 | self.linear1 = nn.Linear(d_model, dim_feedforward)
478 | self.dropout = nn.Dropout(dropout)
479 | self.linear2 = nn.Linear(dim_feedforward, d_model)
480 |
481 | self.norm1 = nn.LayerNorm(d_model)
482 | self.norm2 = nn.LayerNorm(d_model)
483 | self.dropout1 = nn.Dropout(dropout)
484 | self.dropout2 = nn.Dropout(dropout)
485 |
486 | self.activation = _get_activation_fn(activation)
487 |
488 | def __setstate__(self, state):
489 | if 'activation' not in state:
490 | state['activation'] = F.relu
491 | super(TransformerEncoderLayer, self).__setstate__(state)
492 |
493 | def forward(self, src, src_mask=None, src_key_padding_mask=None):
494 | src2 = self.self_attn(src, src, src, attn_mask=src_mask,
495 | key_padding_mask=src_key_padding_mask)[0]
496 | src = src + self.dropout1(src2)
497 | src = self.norm1(src)
498 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
499 | src = src + self.dropout2(src2)
500 | src = self.norm2(src)
501 | return src
502 |
503 |
504 | class PositionalEncoding(nn.Module):
505 |
506 | def __init__(self, d_model, dropout=0.1, max_len=5000):
507 | super(PositionalEncoding, self).__init__()
508 | self.dropout = nn.Dropout(p=dropout)
509 |
510 | pe = torch.zeros(max_len, d_model)
511 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
512 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
513 | pe[:, 0::2] = torch.sin(position * div_term)
514 | pe[:, 1::2] = torch.cos(position * div_term)
515 | pe = pe.unsqueeze(0).transpose(0, 1)
516 | self.register_buffer('pe', pe)
517 |
518 | def forward(self, x):
519 | x = x + self.pe[:x.size(0), :]
520 | return self.dropout(x)
521 |
522 |
523 | class Transformer(nn.Module):
524 |
525 | def __init__(self, ntoken, ninp, nhid=256, nhead=2, nlayers=2, dropout=0.2):
526 | super(Transformer, self).__init__()
527 | self.src_mask = None
528 | self.pos_encoder = PositionalEncoding(ninp, dropout)
529 | encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
530 | self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
531 | self.ninp = ninp
532 | self.decoder = nn.Linear(ninp, ntoken)
533 |
534 | self.init_weights()
535 |
536 | def _generate_square_subsequent_mask(self, sz):
537 | mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
538 | mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
539 | return mask
540 |
541 | def init_weights(self):
542 | initrange = 0.1
543 | self.decoder.bias.data.zero_()
544 | self.decoder.weight.data.uniform_(-initrange, initrange)
545 |
546 | def forward(self, src):
547 | if self.src_mask is None or self.src_mask.size(0) != len(src):
548 | mask = self._generate_square_subsequent_mask(len(src)).to(src.device)
549 | self.src_mask = mask
550 |
551 | src = src * math.sqrt(self.ninp)
552 | src = self.pos_encoder(src)
553 | output = self.transformer_encoder(src, self.src_mask)
554 | output = self.decoder(output)
555 | return output
556 |
--------------------------------------------------------------------------------