├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── data └── demo.png ├── demo.py ├── environment.yml ├── lib ├── __init__.py ├── datasets │ ├── __init__.py │ ├── concatdataset.py │ └── dataset.py ├── evaluation_metrics │ ├── __init__.py │ └── metrics.py ├── evaluators.py ├── loss │ ├── __init__.py │ └── sequenceCrossEntropyLoss.py ├── models │ ├── __init__.py │ ├── attention_recognition_head.py │ ├── model_builder.py │ ├── resnet_aster.py │ ├── stn_head.py │ └── tps_spatial_transformer.py ├── tools │ ├── create_sub_lmdb.py │ └── create_svtp_lmdb.py ├── trainers.py └── utils │ ├── __init__.py │ ├── labelmaps.py │ ├── logging.py │ ├── meters.py │ ├── osutils.py │ ├── serialization.py │ └── visualization_utils.py ├── main.py ├── overview.png └── scripts ├── main_test_all.sh ├── main_test_image.sh └── stn_att_rec.sh /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | 3 | # temporary files which can be created if a process still has a handle open of a deleted file 4 | .fuse_hidden* 5 | 6 | # KDE directory preferences 7 | .directory 8 | 9 | # Linux trash folder which might appear on any partition or disk 10 | .Trash-* 11 | 12 | # .nfs files are created when an open file is removed but is still being accessed 13 | .nfs* 14 | 15 | 16 | *.DS_Store 17 | .AppleDouble 18 | .LSOverride 19 | 20 | # Icon must end with two \r 21 | Icon 22 | 23 | 24 | # Thumbnails 25 | ._* 26 | 27 | # Files that might appear in the root of a volume 28 | .DocumentRevisions-V100 29 | .fseventsd 30 | .Spotlight-V100 31 | .TemporaryItems 32 | .Trashes 33 | .VolumeIcon.icns 34 | .com.apple.timemachine.donotpresent 35 | 36 | # Directories potentially created on remote AFP share 37 | .AppleDB 38 | .AppleDesktop 39 | Network Trash Folder 40 | Temporary Items 41 | .apdisk 42 | 43 | 44 | # swap 45 | [._]*.s[a-v][a-z] 46 | [._]*.sw[a-p] 47 | [._]s[a-v][a-z] 48 | [._]sw[a-p] 49 | # session 50 | Session.vim 51 | # temporary 52 | .netrwhist 53 | *~ 54 | # auto-generated tag files 55 | tags 56 | 57 | 58 | # cache files for sublime text 59 | *.tmlanguage.cache 60 | *.tmPreferences.cache 61 | *.stTheme.cache 62 | 63 | # workspace files are user-specific 64 | *.sublime-workspace 65 | 66 | # project files should be checked into the repository, unless a significant 67 | # proportion of contributors will probably not be using SublimeText 68 | # *.sublime-project 69 | 70 | # sftp configuration file 71 | sftp-config.json 72 | 73 | # Package control specific files 74 | Package Control.last-run 75 | Package Control.ca-list 76 | Package Control.ca-bundle 77 | Package Control.system-ca-bundle 78 | Package Control.cache/ 79 | Package Control.ca-certs/ 80 | Package Control.merged-ca-bundle 81 | Package Control.user-ca-bundle 82 | oscrypto-ca-bundle.crt 83 | bh_unicode_properties.cache 84 | 85 | # Sublime-github package stores a github token in this file 86 | # https://packagecontrol.io/packages/sublime-github 87 | GitHub.sublime-settings 88 | 89 | 90 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 91 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 92 | 93 | # User-specific stuff: 94 | .idea 95 | .idea/**/workspace.xml 96 | .idea/**/tasks.xml 97 | 98 | # Sensitive or high-churn files: 99 | .idea/**/dataSources/ 100 | .idea/**/dataSources.ids 101 | .idea/**/dataSources.xml 102 | .idea/**/dataSources.local.xml 103 | .idea/**/sqlDataSources.xml 104 | .idea/**/dynamic.xml 105 | .idea/**/uiDesigner.xml 
106 | 107 | # Gradle: 108 | .idea/**/gradle.xml 109 | .idea/**/libraries 110 | 111 | # Mongo Explorer plugin: 112 | .idea/**/mongoSettings.xml 113 | 114 | ## File-based project format: 115 | *.iws 116 | 117 | ## Plugin-specific files: 118 | 119 | # IntelliJ 120 | /out/ 121 | 122 | # mpeltonen/sbt-idea plugin 123 | .idea_modules/ 124 | 125 | # JIRA plugin 126 | atlassian-ide-plugin.xml 127 | 128 | # Crashlytics plugin (for Android Studio and IntelliJ) 129 | com_crashlytics_export_strings.xml 130 | crashlytics.properties 131 | crashlytics-build.properties 132 | fabric.properties 133 | 134 | 135 | # Byte-compiled / optimized / DLL files 136 | __pycache__/ 137 | *.py[cod] 138 | *$py.class 139 | 140 | # C extensions 141 | *.so 142 | 143 | # Distribution / packaging 144 | .Python 145 | env/ 146 | build/ 147 | develop-eggs/ 148 | dist/ 149 | downloads/ 150 | eggs/ 151 | .eggs/ 152 | lib64/ 153 | parts/ 154 | sdist/ 155 | var/ 156 | wheels/ 157 | *.egg-info/ 158 | .installed.cfg 159 | *.egg 160 | MANIFEST 161 | 162 | # PyInstaller 163 | # Usually these files are written by a python script from a template 164 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 165 | *.manifest 166 | *.spec 167 | 168 | # Installer logs 169 | pip-log.txt 170 | pip-delete-this-directory.txt 171 | 172 | # Unit test / coverage reports 173 | htmlcov/ 174 | .tox/ 175 | .coverage 176 | .coverage.* 177 | .cache 178 | nosetests.xml 179 | coverage.xml 180 | *.cover 181 | .hypothesis/ 182 | .pytest_cache/ 183 | 184 | # Translations 185 | *.mo 186 | *.pot 187 | 188 | # Django stuff: 189 | *.log 190 | local_settings.py 191 | 192 | # Flask stuff: 193 | instance/ 194 | .webassets-cache 195 | 196 | # Scrapy stuff: 197 | .scrapy 198 | 199 | # Sphinx documentation 200 | docs/_build/ 201 | 202 | # PyBuilder 203 | target/ 204 | 205 | # IPython Notebook 206 | .ipynb_checkpoints 207 | 208 | # pyenv 209 | .python-version 210 | 211 | # celery beat schedule file 212 | celerybeat-schedule 213 | 214 | # SageMath parsed files 215 | *.sage.py 216 | 217 | # Environments 218 | .env 219 | .venv 220 | env/ 221 | venv/ 222 | ENV/ 223 | env.bak/ 224 | venv.bak/ 225 | 226 | # Spyder project settings 227 | .spyderproject 228 | .spyproject 229 | 230 | # Rope project settings 231 | .ropeproject 232 | 233 | 234 | # Project specific 235 | logs 236 | # *.png 237 | *.jpg 238 | *.jpeg 239 | viz 240 | vis 241 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Mingkun Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ASTER: Attentional Scene Text Recognizer with Flexible Rectification 2 | 3 | This repository implements ASTER in PyTorch. The original software can be found [here](https://github.com/bgshih/aster). 4 | 5 | ASTER is an accurate scene text recognizer with a flexible rectification mechanism. The research paper can be found [here](https://ieeexplore.ieee.org/abstract/document/8395027/). 6 | 7 | ![ASTER Overview](overview.png) 8 | 9 | ## Installation 10 | 11 | ``` 12 | conda env create -f environment.yml 13 | ``` 14 | 15 | ## Train 16 | 17 | [**NOTE**] Some users report that they cannot reproduce the reported performance after making minor modifications, e.g. [1](https://github.com/ayumiymk/aster.pytorch/issues/17#issuecomment-527380815) and [2](https://github.com/ayumiymk/aster.pytorch/issues/17#issuecomment-528718596). I have not tried other settings, so I cannot guarantee the same performance with different settings. To reproduce the results, run the following script without any modification. 18 | ``` 19 | bash scripts/stn_att_rec.sh 20 | ``` 21 | 22 | ## Test 23 | 24 | You can test with `.lmdb` files by 25 | ``` 26 | bash scripts/main_test_all.sh 27 | ``` 28 | Or test with a single image by 29 | ``` 30 | bash scripts/main_test_image.sh 31 | ``` 32 | 33 | ## Pretrained model 34 | The pretrained model is available on our [release page](https://github.com/ayumiymk/aster.pytorch/releases/download/v1.0/demo.pth.tar). Download `demo.pth.tar`, put it somewhere, and point `--resume` to its location before running. 35 | 36 | ## Reproduced results 37 | 38 | | | IIIT5k | SVT | IC03 | IC13 | IC15 | SVTP | CUTE | 39 | |:-------------:|:------:|:----:|:-----:|:-----:|:-----:|:-----:|:-----:| 40 | | ASTER (L2R) | 92.67 | - | 93.72 | 90.74 | - | 78.76 | 76.39 | 41 | | ASTER.Pytorch | 93.2 | 89.2 | 92.2 | 91 | 78.0 | 81.2 | 81.9 | 42 | 43 | At present, the bidirectional attention decoder proposed in ASTER is not included in my implementation. 44 | 45 | You can use this code to bootstrap your next text recognition research project. 46 | 47 | 48 | ## Data preparation 49 | 50 | We give an example of constructing your own datasets; for details, please refer to `tools/create_svtp_lmdb.py`. 51 | 52 | We also provide datasets for [training](https://pan.baidu.com/s/1BMYb93u4gW_3GJdjBWSCSw&shfl=sharepset) (password: wi05) and [testing](https://drive.google.com/open?id=1U4mGLlsm9Ade1-gQOyd6He5R0yiaafYJ).
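If you only need the on-disk layout, here is a minimal, hypothetical sketch of how such an LMDB can be written (the `create_lmdb` helper, the paths, and the `map_size` value are illustrative and not part of this repository; `tools/create_svtp_lmdb.py` remains the authoritative script). It writes the `image-%09d`, `label-%09d`, and `num-samples` keys that `lib/datasets/dataset.py` reads:

```
# Illustrative sketch only: the key/value layout expected by lib/datasets/dataset.py.
import lmdb

def create_lmdb(output_path, image_paths, labels):
  env = lmdb.open(output_path, map_size=1 << 40)  # placeholder map size; adjust to your data
  with env.begin(write=True) as txn:
    for i, (img_path, label) in enumerate(zip(image_paths, labels), start=1):
      with open(img_path, 'rb') as f:
        txn.put(b'image-%09d' % i, f.read())        # raw encoded image bytes, decoded by PIL at read time
      txn.put(b'label-%09d' % i, label.encode())    # ground-truth transcription
    txn.put(b'num-samples', str(len(labels)).encode())
  env.close()
```

Note that sample indices start at 1, matching the `index += 1` in `LmdbDataset.__getitem__`.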
53 | 54 | ## Citation 55 | 56 | If you find this project helpful for your research, please cite the following papers: 57 | 58 | ``` 59 | @article{bshi2018aster, 60 | author = {Baoguang Shi and 61 | Mingkun Yang and 62 | Xinggang Wang and 63 | Pengyuan Lyu and 64 | Cong Yao and 65 | Xiang Bai}, 66 | title = {ASTER: An Attentional Scene Text Recognizer with Flexible Rectification}, 67 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence}, 68 | volume = {41}, 69 | number = {9}, 70 | pages = {2035--2048}, 71 | year = {2019}, 72 | } 73 | 74 | @inproceedings{ShiWLYB16, 75 | author = {Baoguang Shi and 76 | Xinggang Wang and 77 | Pengyuan Lyu and 78 | Cong Yao and 79 | Xiang Bai}, 80 | title = {Robust Scene Text Recognition with Automatic Rectification}, 81 | booktitle = {2016 {IEEE} Conference on Computer Vision and Pattern Recognition, 82 | {CVPR} 2016, Las Vegas, NV, USA, June 27-30, 2016}, 83 | pages = {4168--4176}, 84 | year = {2016} 85 | } 86 | ``` 87 | 88 | IMPORTANT NOTICE: Although this software is licensed under MIT, our intention is to make it free for academic research purposes. If you are going to use it in a product, we suggest you [contact us](mailto:xbai@hust.edu.cn) regarding possible patent issues. 89 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import sys 5 | sys.path.append('./') 6 | 7 | import six 8 | import os 9 | import os.path as osp 10 | import math 11 | import argparse 12 | 13 | 14 | parser = argparse.ArgumentParser(description="Softmax loss classification") 15 | # data 16 | parser.add_argument('--synthetic_train_data_dir', nargs='+', type=str, metavar='PATH', 17 | default=['/share/zhui/reg_dataset/NIPS2014']) 18 | parser.add_argument('--real_train_data_dir', type=str, metavar='PATH', 19 | default='/data/zhui/benchmark/cocotext_trainval') 20 | parser.add_argument('--extra_train_data_dir', nargs='+', type=str, metavar='PATH', 21 | default=['/share/zhui/reg_dataset/CVPR2016']) 22 | parser.add_argument('--test_data_dir', type=str, metavar='PATH', 23 | default='/share/zhui/reg_dataset/IIIT5K_3000') 24 | parser.add_argument('--MULTI_TRAINDATA', action='store_true', default=False, 25 | help='whether to use the extra_train_data for training.') 26 | parser.add_argument('-b', '--batch_size', type=int, default=128) 27 | parser.add_argument('-j', '--workers', type=int, default=8) 28 | parser.add_argument('--height', type=int, default=64, 29 | help="input height, default: 256 for resnet*, ""64 for inception") 30 | parser.add_argument('--width', type=int, default=256, 31 | help="input width, default: 128 for resnet*, ""256 for inception") 32 | parser.add_argument('--keep_ratio', action='store_true', default=False, 33 | help='fixed length or variable length.') 34 | parser.add_argument('--voc_type', type=str, default='ALLCASES_SYMBOLS', 35 | choices=['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS']) 36 | parser.add_argument('--mix_data', action='store_true', 37 | help="whether to combine multiple datasets in the training stage.") 38 | parser.add_argument('--num_train', type=int, default=math.inf) 39 | parser.add_argument('--num_test', type=int, default=math.inf) 40 | parser.add_argument('--aug', action='store_true', default=False, 41 | help='whether to use data augmentation.') 42 | parser.add_argument('--lexicon_type', type=str,
default='0', choices=['0', '50', '1k', 'full'], 43 | help='which lexicon associated with the image is used.') 44 | parser.add_argument('--image_path', type=str, default='', 45 | help='the path of a single image, used in demo.py.') 46 | parser.add_argument('--tps_inputsize', nargs='+', type=int, default=[32, 64]) 47 | parser.add_argument('--tps_outputsize', nargs='+', type=int, default=[32, 100]) 48 | # model 49 | parser.add_argument('-a', '--arch', type=str, default='ResNet_ASTER') 50 | parser.add_argument('--dropout', type=float, default=0.5) 51 | parser.add_argument('--max_len', type=int, default=100) 52 | parser.add_argument('--n_group', type=int, default=1) 53 | parser.add_argument('--STN_ON', action='store_true', 54 | help='add the stn head.') 55 | parser.add_argument('--tps_margins', nargs='+', type=float, default=[0.05,0.05]) 56 | parser.add_argument('--stn_activation', type=str, default='none') 57 | parser.add_argument('--num_control_points', type=int, default=20) 58 | parser.add_argument('--stn_with_dropout', action='store_true', default=False) 59 | ## lstm 60 | parser.add_argument('--with_lstm', action='store_true', default=False, 61 | help='whether to append an lstm after the cnn in the encoder part.') 62 | parser.add_argument('--decoder_sdim', type=int, default=512, 63 | help="the dim of the hidden layer in the decoder.") 64 | parser.add_argument('--attDim', type=int, default=512, 65 | help="the dim for attention.") 66 | # optimizer 67 | parser.add_argument('--lr', type=float, default=1, 68 | help="learning rate of new parameters, for pretrained " 69 | "parameters it is 10 times smaller than this") 70 | parser.add_argument('--momentum', type=float, default=0.9) 71 | parser.add_argument('--weight_decay', type=float, default=0.0) # the model may be under-fitting; 0.0 gives much better results.
72 | parser.add_argument('--grad_clip', type=float, default=1.0) 73 | parser.add_argument('--loss_weights', nargs='+', type=float, default=[1,1,1]) 74 | # training configs 75 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 76 | parser.add_argument('--evaluate', action='store_true', 77 | help="evaluation only") 78 | parser.add_argument('--epochs', type=int, default=6) 79 | parser.add_argument('--start_save', type=int, default=0, 80 | help="start saving checkpoints after a specific epoch") 81 | parser.add_argument('--seed', type=int, default=1) 82 | parser.add_argument('--print_freq', type=int, default=100) 83 | parser.add_argument('--cuda', default=True, type=bool, 84 | help='whether to use cuda support.') 85 | # testing configs 86 | parser.add_argument('--evaluation_metric', type=str, default='accuracy') 87 | parser.add_argument('--evaluate_with_lexicon', action='store_true', default=False) 88 | parser.add_argument('--beam_width', type=int, default=5) 89 | # misc 90 | working_dir = osp.dirname(osp.dirname(osp.abspath(__file__))) 91 | parser.add_argument('--logs_dir', type=str, metavar='PATH', 92 | default=osp.join(working_dir, 'logs')) 93 | parser.add_argument('--real_logs_dir', type=str, metavar='PATH', 94 | default='/media/mkyang/research/recognition/selfattention_rec') 95 | parser.add_argument('--debug', action='store_true', 96 | help="if debugging, some steps will be skipped.") 97 | parser.add_argument('--vis_dir', type=str, metavar='PATH', default='', 98 | help="the directory for visualizing the results during evaluation.") 99 | parser.add_argument('--run_on_remote', action='store_true', default=False, 100 | help="whether to run the code remotely or locally.") 101 | 102 | def get_args(sys_args): 103 | global_args = parser.parse_args(sys_args) 104 | return global_args -------------------------------------------------------------------------------- /data/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayumiymk/aster.pytorch/be670046c775b54de79766208f0c59321ae1eccf/data/demo.png -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | sys.path.append('./') 4 | 5 | import argparse 6 | import os 7 | import os.path as osp 8 | import numpy as np 9 | import math 10 | import time 11 | from PIL import Image, ImageFile 12 | 13 | import torch 14 | from torch import nn, optim 15 | from torch.backends import cudnn 16 | from torch.utils.data import DataLoader 17 | from torchvision import transforms 18 | 19 | from config import get_args 20 | from lib import datasets, evaluation_metrics, models 21 | from lib.models.model_builder import ModelBuilder 22 | from lib.datasets.dataset import LmdbDataset, AlignCollate 23 | from lib.loss import SequenceCrossEntropyLoss 24 | from lib.trainers import Trainer 25 | from lib.evaluators import Evaluator 26 | from lib.utils.logging import Logger, TFLogger 27 | from lib.utils.serialization import load_checkpoint, save_checkpoint 28 | from lib.utils.osutils import make_symlink_if_not_exists 29 | from lib.evaluation_metrics.metrics import get_str_list 30 | from lib.utils.labelmaps import get_vocabulary, labels2strs 31 | 32 | global_args = get_args(sys.argv[1:]) 33 | 34 | def image_process(image_path, imgH=32, imgW=100, keep_ratio=False, min_ratio=1): 35 | img = Image.open(image_path).convert('RGB') 36 | 37 | if keep_ratio: 38
| w, h = img.size 39 | ratio = w / float(h) 40 | imgW = int(np.floor(ratio * imgH)) 41 | imgW = max(imgH * min_ratio, imgW) 42 | 43 | img = img.resize((imgW, imgH), Image.BILINEAR) 44 | img = transforms.ToTensor()(img) 45 | img.sub_(0.5).div_(0.5) 46 | 47 | return img 48 | 49 | class DataInfo(object): 50 | """ 51 | Save the info about the dataset. 52 | This is a code snippet from dataset.py. 53 | """ 54 | def __init__(self, voc_type): 55 | super(DataInfo, self).__init__() 56 | self.voc_type = voc_type 57 | 58 | assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS'] 59 | self.EOS = 'EOS' 60 | self.PADDING = 'PADDING' 61 | self.UNKNOWN = 'UNKNOWN' 62 | self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN) 63 | self.char2id = dict(zip(self.voc, range(len(self.voc)))) 64 | self.id2char = dict(zip(range(len(self.voc)), self.voc)) 65 | 66 | self.rec_num_classes = len(self.voc) 67 | 68 | 69 | def main(args): 70 | np.random.seed(args.seed) 71 | torch.manual_seed(args.seed) 72 | torch.cuda.manual_seed(args.seed) 73 | torch.cuda.manual_seed_all(args.seed) 74 | cudnn.benchmark = True 75 | torch.backends.cudnn.deterministic = True 76 | 77 | args.cuda = args.cuda and torch.cuda.is_available() 78 | if args.cuda: 79 | print('using cuda.') 80 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 81 | else: 82 | torch.set_default_tensor_type('torch.FloatTensor') 83 | 84 | # Create data loaders 85 | if args.height is None or args.width is None: 86 | args.height, args.width = (32, 100) 87 | 88 | dataset_info = DataInfo(args.voc_type) 89 | 90 | # Create model 91 | model = ModelBuilder(arch=args.arch, rec_num_classes=dataset_info.rec_num_classes, 92 | sDim=args.decoder_sdim, attDim=args.attDim, max_len_labels=args.max_len, 93 | eos=dataset_info.char2id[dataset_info.EOS], STN_ON=args.STN_ON) 94 | 95 | # Load from checkpoint 96 | if args.resume: 97 | checkpoint = load_checkpoint(args.resume) 98 | model.load_state_dict(checkpoint['state_dict']) 99 | 100 | if args.cuda: 101 | device = torch.device("cuda") 102 | model = model.to(device) 103 | model = nn.DataParallel(model) 104 | 105 | # Evaluation 106 | model.eval() 107 | img = image_process(args.image_path) 108 | with torch.no_grad(): 109 | img = img.to(device) 110 | input_dict = {} 111 | input_dict['images'] = img.unsqueeze(0) 112 | # TODO: testing should be cleaner. 113 | # to be compatible with the lmdb-based testing, we need to construct some placeholder variables.
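# Added note: rec_targets below is only a placeholder -- a (1, max_len) tensor of dummy ids with the EOS id in its last position -- and rec_lengths is set to [max_len], so the input dict carries the keys and shapes the lmdb-based evaluation path expects.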
114 | rec_targets = torch.IntTensor(1, args.max_len).fill_(1) 115 | rec_targets[:,args.max_len-1] = dataset_info.char2id[dataset_info.EOS] 116 | input_dict['rec_targets'] = rec_targets 117 | input_dict['rec_lengths'] = [args.max_len] 118 | output_dict = model(input_dict) 119 | pred_rec = output_dict['output']['pred_rec'] 120 | pred_str, _ = get_str_list(pred_rec, input_dict['rec_targets'], dataset=dataset_info) 121 | print('Recognition result: {0}'.format(pred_str[0])) 122 | 123 | 124 | if __name__ == '__main__': 125 | # parse the config 126 | args = get_args(sys.argv[1:]) 127 | main(args) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: recognition 2 | channels: 3 | - pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/ 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - absl-py=0.7.1=py37_0 10 | - astor=0.7.1=py_0 11 | - blas=2.11=openblas 12 | - bzip2=1.0.8=h7b6447c_0 13 | - c-ares=1.15.0=h516909a_1001 14 | - ca-certificates=2019.6.16=hecc5488_0 15 | - cairo=1.14.12=h8948797_3 16 | - certifi=2019.6.16=py37_1 17 | - cffi=1.12.3=py37h2e261b9_0 18 | - cudatoolkit=9.0=h13b8566_0 19 | - cycler=0.10.0=py_1 20 | - dbus=1.13.6=h746ee38_0 21 | - editdistance=0.5.3=py37hf484d3e_0 22 | - expat=2.2.6=he6710b0_0 23 | - ffmpeg=4.0=hcdf2ecd_0 24 | - fontconfig=2.13.1=he4413a7_1000 25 | - freeglut=3.0.0=hf484d3e_5 26 | - freetype=2.9.1=h8a8886c_1 27 | - gast=0.2.2=py_0 28 | - gettext=0.19.8.1=hc5be6a0_1002 29 | - glib=2.56.2=hd408876_0 30 | - graphite2=1.3.13=h23475e2_0 31 | - grpcio=1.16.1=py37hf8bcb03_1 32 | - gst-plugins-base=1.14.0=hbbd80ab_1 33 | - gstreamer=1.14.0=hb453b48_1 34 | - h5py=2.8.0=py37h3010b51_1003 35 | - harfbuzz=1.8.8=hffaf4a1_0 36 | - hdf5=1.10.2=hba1933b_1 37 | - icu=58.2=h9c2bf20_1 38 | - intel-openmp=2019.4=243 39 | - jasper=2.0.14=h07fcdf6_1 40 | - jpeg=9c=h14c3975_1001 41 | - keras-applications=1.0.7=py_1 42 | - keras-preprocessing=1.0.9=py_1 43 | - kiwisolver=1.1.0=py37hc9558a2_0 44 | - libblas=3.8.0=11_openblas 45 | - libcblas=3.8.0=11_openblas 46 | - libedit=3.1.20181209=hc058e9b_0 47 | - libffi=3.2.1=hd88cf55_4 48 | - libgcc-ng=9.1.0=hdf63c60_0 49 | - libgfortran-ng=7.3.0=hdf63c60_0 50 | - libglu=9.0.0=hf484d3e_1 51 | - libiconv=1.15=h516909a_1005 52 | - liblapack=3.8.0=11_openblas 53 | - liblapacke=3.8.0=11_openblas 54 | - libopenblas=0.3.6=h5a2b251_1 55 | - libopencv=3.4.2=hb342d67_1 56 | - libopus=1.3=h7b6447c_0 57 | - libpng=1.6.37=hbc83047_0 58 | - libprotobuf=3.9.1=h8b12597_0 59 | - libstdcxx-ng=9.1.0=hdf63c60_0 60 | - libtiff=4.0.10=h2733197_2 61 | - libuuid=2.32.1=h14c3975_1000 62 | - libvpx=1.7.0=h439df22_0 63 | - libxcb=1.13=h14c3975_1002 64 | - libxml2=2.9.9=hea5a465_1 65 | - lmdb=0.9.24=h516909a_0 66 | - markdown=3.1.1=py_0 67 | - matplotlib=3.1.0=py37h5429711_0 68 | - matplotlib-base=3.1.1=py37hfd891ef_0 69 | - mkl=2019.4=243 70 | - mkl-service=2.1.0=py37h516909a_0 71 | - mkl_fft=1.0.14=py37h516909a_1 72 | - mkl_random=1.0.4=py37hf2d7682_0 73 | - mock=3.0.5=py37_0 74 | - ncurses=6.1=he6710b0_1 75 | - ninja=1.9.0=py37hfd86e86_0 76 | - numpy=1.16.4=py37h99e49ec_0 77 | - numpy-base=1.16.4=py37h2f8d375_0 78 | - olefile=0.46=py37_0 79 | - openblas=0.3.3=h9ac9557_1001 80 | - opencv=3.4.2=py37h6fd60c2_1 81 | - openssl=1.1.1c=h516909a_0 82 | - pcre=8.43=he6710b0_0 83 | - pillow=6.1.0=py37h6b7be26_1 84 | - pip=19.1.1=py37_0 85 | - pixman=0.38.0=h7b6447c_0 86 | - 
protobuf=3.9.1=py37he1b5a44_0 87 | - pthread-stubs=0.4=h14c3975_1001 88 | - py-opencv=3.4.2=py37hb342d67_1 89 | - pycparser=2.19=py37_0 90 | - pyparsing=2.4.2=py_0 91 | - pyqt=5.9.2=py37hcca6a23_2 92 | - python=3.7.3=h0371630_0 93 | - python-dateutil=2.8.0=py_0 94 | - python-lmdb=0.96=py37he1b5a44_0 95 | - pytorch=1.1.0=py3.7_cuda9.0.176_cudnn7.5.1_0 96 | - pytz=2019.1=py_0 97 | - qt=5.9.7=h5867ecd_1 98 | - readline=7.0=h7b6447c_5 99 | - scipy=1.1.0=py37he2b7bc3_2 100 | - setuptools=41.0.1=py37_0 101 | - sip=4.19.8=py37hf484d3e_1000 102 | - six=1.12.0=py37_0 103 | - sqlite=3.29.0=h7b6447c_0 104 | - tensorboard=1.13.1=py37_0 105 | - tensorflow=1.13.1=py37h90a7d86_1 106 | - tensorflow-estimator=1.13.0=py_0 107 | - termcolor=1.1.0=py_2 108 | - tk=8.6.9=hed695b0_1002 109 | - torchvision=0.3.0=py37_cu9.0.176_1 110 | - tornado=6.0.3=py37h516909a_0 111 | - tqdm=4.33.0=py_0 112 | - werkzeug=0.15.5=py_0 113 | - wheel=0.33.4=py37_0 114 | - xorg-libxau=1.0.9=h14c3975_0 115 | - xorg-libxdmcp=1.1.3=h516909a_0 116 | - xz=5.2.4=h14c3975_4 117 | - zlib=1.2.11=h7b6447c_3 118 | - zstd=1.3.7=h0b5b093_0 119 | prefix: /home/mkyang/miniconda3/envs/recognition 120 | 121 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import loss 6 | from . import models 7 | from . import utils 8 | from . import evaluators 9 | from . import trainers 10 | 11 | __version__ = '1.0.1.post2' -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayumiymk/aster.pytorch/be670046c775b54de79766208f0c59321ae1eccf/lib/datasets/__init__.py -------------------------------------------------------------------------------- /lib/datasets/concatdataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import bisect 3 | import warnings 4 | 5 | import torch 6 | from torch import randperm 7 | from torch._utils import _accumulate 8 | from torch.utils.data import Dataset 9 | 10 | class ConcatDataset(Dataset): 11 | """ 12 | Dataset to concatenate multiple datasets. 13 | Purpose: useful to assemble different existing datasets, possibly 14 | large-scale datasets as the concatenation operation is done in an 15 | on-the-fly manner. 
16 | Arguments: 17 | datasets (sequence): List of datasets to be concatenated 18 | """ 19 | 20 | @staticmethod 21 | def cumsum(sequence): 22 | r, s = [], 0 23 | for e in sequence: 24 | l = len(e) 25 | r.append(l + s) 26 | s += l 27 | return r 28 | 29 | def __init__(self, datasets): 30 | super(ConcatDataset, self).__init__() 31 | assert len(datasets) > 0, 'datasets should not be an empty iterable' 32 | self.datasets = list(datasets) 33 | self.cumulative_sizes = self.cumsum(self.datasets) 34 | self.max_len = max([_dataset.max_len for _dataset in self.datasets]) 35 | for _dataset in self.datasets: 36 | _dataset.max_len = self.max_len 37 | 38 | def __len__(self): 39 | return self.cumulative_sizes[-1] 40 | 41 | def __getitem__(self, idx): 42 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 43 | if dataset_idx == 0: 44 | sample_idx = idx 45 | else: 46 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 47 | return self.datasets[dataset_idx][sample_idx] 48 | 49 | @property 50 | def cummulative_sizes(self): 51 | warnings.warn("cummulative_sizes attribute is renamed to " 52 | "cumulative_sizes", DeprecationWarning, stacklevel=2) 53 | return self.cumulative_sizes -------------------------------------------------------------------------------- /lib/datasets/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | # import sys 4 | # sys.path.append('./') 5 | 6 | import os 7 | # import moxing as mox 8 | 9 | import pickle 10 | from tqdm import tqdm 11 | from PIL import Image, ImageFile 12 | import numpy as np 13 | import random 14 | import cv2 15 | import lmdb 16 | import sys 17 | import six 18 | 19 | import torch 20 | from torch.utils import data 21 | from torch.utils.data import sampler 22 | from torchvision import transforms 23 | 24 | from lib.utils.labelmaps import get_vocabulary, labels2strs 25 | from lib.utils import to_numpy 26 | 27 | ImageFile.LOAD_TRUNCATED_IMAGES = True 28 | 29 | 30 | from config import get_args 31 | global_args = get_args(sys.argv[1:]) 32 | 33 | if global_args.run_on_remote: 34 | import moxing as mox 35 | 36 | class LmdbDataset(data.Dataset): 37 | def __init__(self, root, voc_type, max_len, num_samples, transform=None): 38 | super(LmdbDataset, self).__init__() 39 | 40 | if global_args.run_on_remote: 41 | dataset_name = os.path.basename(root) 42 | data_cache_url = "/cache/%s" % dataset_name 43 | if not os.path.exists(data_cache_url): 44 | os.makedirs(data_cache_url) 45 | if mox.file.exists(root): 46 | mox.file.copy_parallel(root, data_cache_url) 47 | else: 48 | raise ValueError("%s not exists!" 
% root) 49 | 50 | self.env = lmdb.open(data_cache_url, max_readers=32, readonly=True) 51 | else: 52 | self.env = lmdb.open(root, max_readers=32, readonly=True) 53 | 54 | assert self.env is not None, "cannot create lmdb from %s" % root 55 | self.txn = self.env.begin() 56 | 57 | self.voc_type = voc_type 58 | self.transform = transform 59 | self.max_len = max_len 60 | self.nSamples = int(self.txn.get(b"num-samples")) 61 | self.nSamples = min(self.nSamples, num_samples) 62 | 63 | assert voc_type in ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS'] 64 | self.EOS = 'EOS' 65 | self.PADDING = 'PADDING' 66 | self.UNKNOWN = 'UNKNOWN' 67 | self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN) 68 | self.char2id = dict(zip(self.voc, range(len(self.voc)))) 69 | self.id2char = dict(zip(range(len(self.voc)), self.voc)) 70 | 71 | self.rec_num_classes = len(self.voc) 72 | self.lowercase = (voc_type == 'LOWERCASE') 73 | 74 | def __len__(self): 75 | return self.nSamples 76 | 77 | def __getitem__(self, index): 78 | assert index <= len(self), 'index range error' 79 | index += 1 80 | img_key = b'image-%09d' % index 81 | imgbuf = self.txn.get(img_key) 82 | 83 | buf = six.BytesIO() 84 | buf.write(imgbuf) 85 | buf.seek(0) 86 | try: 87 | img = Image.open(buf).convert('RGB') 88 | # img = Image.open(buf).convert('L') 89 | # img = img.convert('RGB') 90 | except IOError: 91 | print('Corrupted image for %d' % index) 92 | return self[index + 1] 93 | 94 | # reconition labels 95 | label_key = b'label-%09d' % index 96 | word = self.txn.get(label_key).decode() 97 | if self.lowercase: 98 | word = word.lower() 99 | ## fill with the padding token 100 | label = np.full((self.max_len,), self.char2id[self.PADDING], dtype=np.int) 101 | label_list = [] 102 | for char in word: 103 | if char in self.char2id: 104 | label_list.append(self.char2id[char]) 105 | else: 106 | ## add the unknown token 107 | print('{0} is out of vocabulary.'.format(char)) 108 | label_list.append(self.char2id[self.UNKNOWN]) 109 | ## add a stop token 110 | label_list = label_list + [self.char2id[self.EOS]] 111 | assert len(label_list) <= self.max_len 112 | label[:len(label_list)] = np.array(label_list) 113 | 114 | if len(label) <= 0: 115 | return self[index + 1] 116 | 117 | # label length 118 | label_len = len(label_list) 119 | 120 | if self.transform is not None: 121 | img = self.transform(img) 122 | return img, label, label_len 123 | 124 | 125 | class ResizeNormalize(object): 126 | def __init__(self, size, interpolation=Image.BILINEAR): 127 | self.size = size 128 | self.interpolation = interpolation 129 | self.toTensor = transforms.ToTensor() 130 | 131 | def __call__(self, img): 132 | img = img.resize(self.size, self.interpolation) 133 | img = self.toTensor(img) 134 | img.sub_(0.5).div_(0.5) 135 | return img 136 | 137 | 138 | class RandomSequentialSampler(sampler.Sampler): 139 | 140 | def __init__(self, data_source, batch_size): 141 | self.num_samples = len(data_source) 142 | self.batch_size = batch_size 143 | 144 | def __len__(self): 145 | return self.num_samples 146 | 147 | def __iter__(self): 148 | n_batch = len(self) // self.batch_size 149 | tail = len(self) % self.batch_size 150 | index = torch.LongTensor(len(self)).fill_(0) 151 | for i in range(n_batch): 152 | random_start = random.randint(0, len(self) - self.batch_size) 153 | batch_index = random_start + torch.arange(0, self.batch_size) 154 | index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index 155 | # deal with tail 156 | if tail: 157 | random_start = 
random.randint(0, len(self) - self.batch_size) 158 | tail_index = random_start + torch.arange(0, tail) 159 | index[(i + 1) * self.batch_size:] = tail_index 160 | 161 | return iter(index.tolist()) 162 | 163 | 164 | class AlignCollate(object): 165 | 166 | def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1): 167 | self.imgH = imgH 168 | self.imgW = imgW 169 | self.keep_ratio = keep_ratio 170 | self.min_ratio = min_ratio 171 | 172 | def __call__(self, batch): 173 | images, labels, lengths = zip(*batch) 174 | b_lengths = torch.IntTensor(lengths) 175 | b_labels = torch.IntTensor(labels) 176 | 177 | imgH = self.imgH 178 | imgW = self.imgW 179 | if self.keep_ratio: 180 | ratios = [] 181 | for image in images: 182 | w, h = image.size 183 | ratios.append(w / float(h)) 184 | ratios.sort() 185 | max_ratio = ratios[-1] 186 | imgW = int(np.floor(max_ratio * imgH)) 187 | imgW = max(imgH * self.min_ratio, imgW) # assure imgH >= imgW 188 | imgW = min(imgW, 400) 189 | 190 | transform = ResizeNormalize((imgW, imgH)) 191 | images = [transform(image) for image in images] 192 | b_images = torch.stack(images) 193 | 194 | return b_images, b_labels, b_lengths 195 | 196 | 197 | def test(): 198 | # lmdb_path = "/share/zhui/reg_dataset/NIPS2014" 199 | lmdb_path = "/share/zhui/reg_dataset/IIIT5K_3000" 200 | train_dataset = LmdbDataset(root=lmdb_path, voc_type='ALLCASES_SYMBOLS', max_len=50) 201 | batch_size = 1 202 | train_dataloader = data.DataLoader( 203 | train_dataset, 204 | batch_size=batch_size, 205 | shuffle=False, 206 | num_workers=4, 207 | collate_fn=AlignCollate(imgH=64, imgW=256, keep_ratio=False)) 208 | 209 | for i, (images, labels, label_lens) in enumerate(train_dataloader): 210 | # visualization of input image 211 | # toPILImage = transforms.ToPILImage() 212 | images = images.permute(0,2,3,1) 213 | images = to_numpy(images) 214 | images = images * 0.5 + 0.5 215 | images = images * 255 216 | for id, (image, label, label_len) in enumerate(zip(images, labels, label_lens)): 217 | image = Image.fromarray(np.uint8(image)) 218 | # image = toPILImage(image) 219 | image.show() 220 | print(image.size) 221 | print(labels2strs(label, train_dataset.id2char, train_dataset.char2id)) 222 | print(label_len.item()) 223 | input() 224 | 225 | 226 | if __name__ == "__main__": 227 | test() -------------------------------------------------------------------------------- /lib/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .metrics import Accuracy, EditDistance, RecPostProcess, Accuracy_with_lexicon, EditDistance_with_lexicon 4 | 5 | 6 | __factory = { 7 | 'accuracy': Accuracy, 8 | 'editdistance': EditDistance, 9 | 'accuracy_with_lexicon': Accuracy_with_lexicon, 10 | 'editdistance_with_lexicon': EditDistance_with_lexicon, 11 | } 12 | 13 | def names(): 14 | return sorted(__factory.keys()) 15 | 16 | def factory(): 17 | return __factory -------------------------------------------------------------------------------- /lib/evaluation_metrics/metrics.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | import editdistance 5 | import string 6 | import math 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | 11 | from ..utils import to_torch, to_numpy 12 | 13 | 14 | def _normalize_text(text): 15 | text = ''.join(filter(lambda x: x in (string.digits + string.ascii_letters), text)) 16 | 
return text.lower() 17 | 18 | 19 | def get_str_list(output, target, dataset=None): 20 | # label_seq 21 | assert output.dim() == 2 and target.dim() == 2 22 | 23 | end_label = dataset.char2id[dataset.EOS] 24 | unknown_label = dataset.char2id[dataset.UNKNOWN] 25 | num_samples, max_len_labels = output.size() 26 | num_classes = len(dataset.char2id.keys()) 27 | assert num_samples == target.size(0) and max_len_labels == target.size(1) 28 | output = to_numpy(output) 29 | target = to_numpy(target) 30 | 31 | # list of char list 32 | pred_list, targ_list = [], [] 33 | for i in range(num_samples): 34 | pred_list_i = [] 35 | for j in range(max_len_labels): 36 | if output[i, j] != end_label: 37 | if output[i, j] != unknown_label: 38 | pred_list_i.append(dataset.id2char[output[i, j]]) 39 | else: 40 | break 41 | pred_list.append(pred_list_i) 42 | 43 | for i in range(num_samples): 44 | targ_list_i = [] 45 | for j in range(max_len_labels): 46 | if target[i, j] != end_label: 47 | if target[i, j] != unknown_label: 48 | targ_list_i.append(dataset.id2char[target[i, j]]) 49 | else: 50 | break 51 | targ_list.append(targ_list_i) 52 | 53 | # char list to string 54 | # if dataset.lowercase: 55 | if True: 56 | # pred_list = [''.join(pred).lower() for pred in pred_list] 57 | # targ_list = [''.join(targ).lower() for targ in targ_list] 58 | pred_list = [_normalize_text(pred) for pred in pred_list] 59 | targ_list = [_normalize_text(targ) for targ in targ_list] 60 | else: 61 | pred_list = [''.join(pred) for pred in pred_list] 62 | targ_list = [''.join(targ) for targ in targ_list] 63 | 64 | return pred_list, targ_list 65 | 66 | 67 | def _lexicon_search(lexicon, word): 68 | edit_distances = [] 69 | for lex_word in lexicon: 70 | edit_distances.append(editdistance.eval(_normalize_text(lex_word), _normalize_text(word))) 71 | edit_distances = np.asarray(edit_distances, dtype=np.int) 72 | argmin = np.argmin(edit_distances) 73 | return lexicon[argmin] 74 | 75 | 76 | def Accuracy(output, target, dataset=None): 77 | pred_list, targ_list = get_str_list(output, target, dataset) 78 | 79 | acc_list = [(pred == targ) for pred, targ in zip(pred_list, targ_list)] 80 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 81 | return accuracy 82 | 83 | 84 | def Accuracy_with_lexicon(output, target, dataset=None, file_names=None): 85 | pred_list, targ_list = get_str_list(output, target, dataset) 86 | accuracys = [] 87 | 88 | # with no lexicon 89 | acc_list = [(pred == targ) for pred, targ in zip(pred_list, targ_list)] 90 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 91 | accuracys.append(accuracy) 92 | 93 | # lexicon50 94 | if len(file_names) == 0 or len(dataset.lexicons50[file_names[0]]) == 0: 95 | accuracys.append(0) 96 | else: 97 | refined_pred_list = [_lexicon_search(dataset.lexicons50[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 98 | acc_list = [(pred == targ) for pred, targ in zip(refined_pred_list, targ_list)] 99 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 100 | accuracys.append(accuracy) 101 | 102 | # lexicon1k 103 | if len(file_names) == 0 or len(dataset.lexicons1k[file_names[0]]) == 0: 104 | accuracys.append(0) 105 | else: 106 | refined_pred_list = [_lexicon_search(dataset.lexicons1k[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 107 | acc_list = [(pred == targ) for pred, targ in zip(refined_pred_list, targ_list)] 108 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 109 | accuracys.append(accuracy) 110 | 111 | # lexiconfull 112 | if len(file_names) == 0 or 
len(dataset.lexiconsfull[file_names[0]]) == 0: 113 | accuracys.append(0) 114 | else: 115 | refined_pred_list = [_lexicon_search(dataset.lexiconsfull[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 116 | acc_list = [(pred == targ) for pred, targ in zip(refined_pred_list, targ_list)] 117 | accuracy = 1.0 * sum(acc_list) / len(acc_list) 118 | accuracys.append(accuracy) 119 | 120 | return accuracys 121 | 122 | 123 | def EditDistance(output, target, dataset=None): 124 | pred_list, targ_list = get_str_list(output, target, dataset) 125 | 126 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(pred_list, targ_list)] 127 | eds = sum(ed_list) 128 | return eds 129 | 130 | 131 | def EditDistance_with_lexicon(output, target, dataset=None, file_names=None): 132 | pred_list, targ_list = get_str_list(output, target, dataset) 133 | eds = [] 134 | 135 | # with no lexicon 136 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(pred_list, targ_list)] 137 | ed = sum(ed_list) 138 | eds.append(ed) 139 | 140 | # lexicon50 141 | if len(file_names) == 0 or len(dataset.lexicons50[file_names[0]]) == 0: 142 | eds.append(0) 143 | else: 144 | refined_pred_list = [_lexicon_search(dataset.lexicons50[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 145 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(refined_pred_list, targ_list)] 146 | ed = sum(ed_list) 147 | eds.append(ed) 148 | 149 | # lexicon1k 150 | if len(file_names) == 0 or len(dataset.lexicons1k[file_names[0]]) == 0: 151 | eds.append(0) 152 | else: 153 | refined_pred_list = [_lexicon_search(dataset.lexicons1k[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 154 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(refined_pred_list, targ_list)] 155 | ed = sum(ed_list) 156 | eds.append(ed) 157 | 158 | # lexiconfull 159 | if len(file_names) == 0 or len(dataset.lexiconsfull[file_names[0]]) == 0: 160 | eds.append(0) 161 | else: 162 | refined_pred_list = [_lexicon_search(dataset.lexiconsfull[file_name], pred) for file_name, pred in zip(file_names, pred_list)] 163 | ed_list = [editdistance.eval(pred, targ) for pred, targ in zip(refined_pred_list, targ_list)] 164 | ed = sum(ed_list) 165 | eds.append(ed) 166 | 167 | return eds 168 | 169 | 170 | def RecPostProcess(output, target, score, dataset=None): 171 | pred_list, targ_list = get_str_list(output, target, dataset) 172 | max_len_labels = output.size(1) 173 | score_list = [] 174 | 175 | score = to_numpy(score) 176 | for i, pred in enumerate(pred_list): 177 | len_pred = len(pred) + 1 # eos should be included 178 | len_pred = min(max_len_labels, len_pred) # maybe the predicted string don't include a eos. 179 | score_i = score[i,:len_pred] 180 | score_i = math.exp(sum(map(math.log, score_i))) 181 | score_list.append(score_i) 182 | 183 | return pred_list, targ_list, score_list -------------------------------------------------------------------------------- /lib/evaluators.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from time import gmtime, strftime 4 | from datetime import datetime 5 | from collections import OrderedDict 6 | 7 | import torch 8 | 9 | import numpy as np 10 | from random import randint 11 | from PIL import Image 12 | import sys 13 | 14 | from . 
import evaluation_metrics 15 | from .evaluation_metrics import Accuracy, EditDistance, RecPostProcess 16 | from .utils.meters import AverageMeter 17 | from .utils.visualization_utils import recognition_vis, stn_vis 18 | 19 | metrics_factory = evaluation_metrics.factory() 20 | 21 | from config import get_args 22 | global_args = get_args(sys.argv[1:]) 23 | 24 | class BaseEvaluator(object): 25 | def __init__(self, model, metric, use_cuda=True): 26 | super(BaseEvaluator, self).__init__() 27 | self.model = model 28 | self.metric = metric 29 | self.use_cuda = use_cuda 30 | self.device = torch.device("cuda" if use_cuda else "cpu") 31 | 32 | def evaluate(self, data_loader, step=1, print_freq=1, tfLogger=None, dataset=None, vis_dir=None): 33 | self.model.eval() 34 | 35 | batch_time = AverageMeter() 36 | data_time = AverageMeter() 37 | 38 | # forward the network 39 | images, outputs, targets, losses = [], {}, [], [] 40 | file_names = [] 41 | 42 | end = time.time() 43 | for i, inputs in enumerate(data_loader): 44 | data_time.update(time.time() - end) 45 | 46 | input_dict = self._parse_data(inputs) 47 | output_dict = self._forward(input_dict) 48 | 49 | batch_size = input_dict['images'].size(0) 50 | 51 | total_loss_batch = 0. 52 | for k, loss in output_dict['losses'].items(): 53 | loss = loss.mean(dim=0, keepdim=True) 54 | total_loss_batch += loss.item() * batch_size 55 | 56 | images.append(input_dict['images']) 57 | targets.append(input_dict['rec_targets']) 58 | losses.append(total_loss_batch) 59 | if global_args.evaluate_with_lexicon: 60 | file_names += input_dict['file_name'] 61 | for k, v in output_dict['output'].items(): 62 | if k not in outputs: 63 | outputs[k] = [] 64 | outputs[k].append(v.cpu()) 65 | 66 | batch_time.update(time.time() - end) 67 | end = time.time() 68 | 69 | if (i + 1) % print_freq == 0: 70 | print('[{}]\t' 71 | 'Evaluation: [{}/{}]\t' 72 | 'Time {:.3f} ({:.3f})\t' 73 | 'Data {:.3f} ({:.3f})\t' 74 | # .format(strftime("%Y-%m-%d %H:%M:%S", gmtime()), 75 | .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 76 | i + 1, len(data_loader), 77 | batch_time.val, batch_time.avg, 78 | data_time.val, data_time.avg)) 79 | 80 | if not global_args.keep_ratio: 81 | images = torch.cat(images) 82 | num_samples = images.size(0) 83 | else: 84 | num_samples = sum([subimages.size(0) for subimages in images]) 85 | targets = torch.cat(targets) 86 | losses = np.sum(losses) / (1.0 * num_samples) 87 | for k, v in outputs.items(): 88 | outputs[k] = torch.cat(outputs[k]) 89 | 90 | # save info for recognition 91 | if 'pred_rec' in outputs: 92 | # evaluation with metric 93 | if global_args.evaluate_with_lexicon: 94 | eval_res = metrics_factory[self.metric+'_with_lexicon'](outputs['pred_rec'], targets, dataset, file_names) 95 | print('lexicon0: {0}, {1:.3f}'.format(self.metric, eval_res[0])) 96 | print('lexicon50: {0}, {1:.3f}'.format(self.metric, eval_res[1])) 97 | print('lexicon1k: {0}, {1:.3f}'.format(self.metric, eval_res[2])) 98 | print('lexiconfull: {0}, {1:.3f}'.format(self.metric, eval_res[3])) 99 | eval_res = eval_res[0] 100 | else: 101 | eval_res = metrics_factory[self.metric](outputs['pred_rec'], targets, dataset) 102 | print('lexicon0: {0}: {1:.3f}'.format(self.metric, eval_res)) 103 | pred_list, targ_list, score_list = RecPostProcess(outputs['pred_rec'], targets, outputs['pred_rec_score'], dataset) 104 | 105 | if tfLogger is not None: 106 | # (1) Log the scalar values 107 | info = { 108 | 'loss': losses, 109 | self.metric: eval_res, 110 | } 111 | for tag, value in info.items(): 112 | 
tfLogger.scalar_summary(tag, value, step) 113 | 114 | #====== Visualization ======# 115 | if vis_dir is not None: 116 | # recognition_vis(images, outputs['pred_rec'], targets, score_list, dataset, vis_dir) 117 | stn_vis(images, outputs['rectified_images'], outputs['ctrl_points'], outputs['pred_rec'], 118 | targets, score_list, outputs['pred_score'] if 'pred_score' in outputs else None, dataset, vis_dir) 119 | return eval_res 120 | 121 | 122 | def _parse_data(self, inputs): 123 | raise NotImplementedError 124 | 125 | def _forward(self, inputs): 126 | raise NotImplementedError 127 | 128 | 129 | class Evaluator(BaseEvaluator): 130 | def _parse_data(self, inputs): 131 | input_dict = {} 132 | if global_args.evaluate_with_lexicon: 133 | imgs, label_encs, lengths, file_name = inputs 134 | else: 135 | imgs, label_encs, lengths = inputs 136 | 137 | with torch.no_grad(): 138 | images = imgs.to(self.device) 139 | if label_encs is not None: 140 | labels = label_encs.to(self.device) 141 | 142 | input_dict['images'] = images 143 | input_dict['rec_targets'] = labels 144 | input_dict['rec_lengths'] = lengths 145 | if global_args.evaluate_with_lexicon: 146 | input_dict['file_name'] = file_name 147 | return input_dict 148 | 149 | def _forward(self, input_dict): 150 | self.model.eval() 151 | with torch.no_grad(): 152 | output_dict = self.model(input_dict) 153 | return output_dict -------------------------------------------------------------------------------- /lib/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .sequenceCrossEntropyLoss import SequenceCrossEntropyLoss 4 | 5 | 6 | __all__ = [ 7 | 'SequenceCrossEntropyLoss', 8 | ] -------------------------------------------------------------------------------- /lib/loss/sequenceCrossEntropyLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | import torch.nn.functional as F 7 | 8 | def to_contiguous(tensor): 9 | if tensor.is_contiguous(): 10 | return tensor 11 | else: 12 | return tensor.contiguous() 13 | 14 | def _assert_no_grad(variable): 15 | assert not variable.requires_grad, \ 16 | "nn criterions don't compute the gradient w.r.t. 
targets - please " \ 17 | "mark these variables as not requiring gradients" 18 | 19 | class SequenceCrossEntropyLoss(nn.Module): 20 | def __init__(self, 21 | weight=None, 22 | size_average=True, 23 | ignore_index=-100, 24 | sequence_normalize=False, 25 | sample_normalize=True): 26 | super(SequenceCrossEntropyLoss, self).__init__() 27 | self.weight = weight 28 | self.size_average = size_average 29 | self.ignore_index = ignore_index 30 | self.sequence_normalize = sequence_normalize 31 | self.sample_normalize = sample_normalize 32 | 33 | assert (sequence_normalize and sample_normalize) == False 34 | 35 | def forward(self, input, target, length): 36 | _assert_no_grad(target) 37 | # length to mask 38 | batch_size, def_max_length = target.size(0), target.size(1) 39 | mask = torch.zeros(batch_size, def_max_length) 40 | for i in range(batch_size): 41 | mask[i,:length[i]].fill_(1) 42 | mask = mask.type_as(input) 43 | # truncate to the same size 44 | max_length = max(length) 45 | assert max_length == input.size(1) 46 | target = target[:, :max_length] 47 | mask = mask[:, :max_length] 48 | input = to_contiguous(input).view(-1, input.size(2)) 49 | input = F.log_softmax(input, dim=1) 50 | target = to_contiguous(target).view(-1, 1) 51 | mask = to_contiguous(mask).view(-1, 1) 52 | output = - input.gather(1, target.long()) * mask 53 | # if self.size_average: 54 | # output = torch.sum(output) / torch.sum(mask) 55 | # elif self.reduce: 56 | # output = torch.sum(output) 57 | ## 58 | output = torch.sum(output) 59 | if self.sequence_normalize: 60 | output = output / torch.sum(mask) 61 | if self.sample_normalize: 62 | output = output / batch_size 63 | 64 | return output -------------------------------------------------------------------------------- /lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .resnet_aster import * 4 | 5 | __factory = { 6 | 'ResNet_ASTER': ResNet_ASTER, 7 | } 8 | 9 | def names(): 10 | return sorted(__factory.keys()) 11 | 12 | 13 | def create(name, *args, **kwargs): 14 | """Create a model instance. 15 | 16 | Parameters 17 | ---------- 18 | name: str 19 | Model name. One of __factory 20 | pretrained: bool, optional 21 | If True, will use ImageNet pretrained model. Default: True 22 | num_classes: int, optional 23 | If positive, will change the original classifier the fit the new classifier with num_classes. Default: True 24 | with_words: bool, optional 25 | If True, the input of this model is the combination of image and word. Default: False 26 | """ 27 | if name not in __factory: 28 | raise KeyError('Unknown model:', name) 29 | return __factory[name](*args, **kwargs) -------------------------------------------------------------------------------- /lib/models/attention_recognition_head.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import sys 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.nn import init 9 | 10 | 11 | class AttentionRecognitionHead(nn.Module): 12 | """ 13 | input: [b x 16 x 64 x in_planes] 14 | output: probability sequence: [b x T x num_classes] 15 | """ 16 | def __init__(self, num_classes, in_planes, sDim, attDim, max_len_labels): 17 | super(AttentionRecognitionHead, self).__init__() 18 | self.num_classes = num_classes # this is the output classes. So it includes the . 
19 | self.in_planes = in_planes 20 | self.sDim = sDim 21 | self.attDim = attDim 22 | self.max_len_labels = max_len_labels 23 | 24 | self.decoder = DecoderUnit(sDim=sDim, xDim=in_planes, yDim=num_classes, attDim=attDim) 25 | 26 | def forward(self, x): 27 | x, targets, lengths = x 28 | batch_size = x.size(0) 29 | # Decoder 30 | state = torch.zeros(1, batch_size, self.sDim) 31 | outputs = [] 32 | 33 | for i in range(max(lengths)): 34 | if i == 0: 35 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes) # the last one is used as the . 36 | else: 37 | y_prev = targets[:,i-1] 38 | 39 | output, state = self.decoder(x, state, y_prev) 40 | outputs.append(output) 41 | outputs = torch.cat([_.unsqueeze(1) for _ in outputs], 1) 42 | return outputs 43 | 44 | # inference stage. 45 | def sample(self, x): 46 | x, _, _ = x 47 | batch_size = x.size(0) 48 | # Decoder 49 | state = torch.zeros(1, batch_size, self.sDim) 50 | 51 | predicted_ids, predicted_scores = [], [] 52 | for i in range(self.max_len_labels): 53 | if i == 0: 54 | y_prev = torch.zeros((batch_size)).fill_(self.num_classes) 55 | else: 56 | y_prev = predicted 57 | 58 | output, state = self.decoder(x, state, y_prev) 59 | output = F.softmax(output, dim=1) 60 | score, predicted = output.max(1) 61 | predicted_ids.append(predicted.unsqueeze(1)) 62 | predicted_scores.append(score.unsqueeze(1)) 63 | predicted_ids = torch.cat(predicted_ids, 1) 64 | predicted_scores = torch.cat(predicted_scores, 1) 65 | # return predicted_ids.squeeze(), predicted_scores.squeeze() 66 | return predicted_ids, predicted_scores 67 | 68 | def beam_search(self, x, beam_width, eos): 69 | 70 | def _inflate(tensor, times, dim): 71 | repeat_dims = [1] * tensor.dim() 72 | repeat_dims[dim] = times 73 | return tensor.repeat(*repeat_dims) 74 | 75 | # https://github.com/IBM/pytorch-seq2seq/blob/fede87655ddce6c94b38886089e05321dc9802af/seq2seq/models/TopKDecoder.py 76 | batch_size, l, d = x.size() 77 | # inflated_encoder_feats = _inflate(encoder_feats, beam_width, 0) # ABC --> AABBCC -/-> ABCABC 78 | inflated_encoder_feats = x.unsqueeze(1).permute((1,0,2,3)).repeat((beam_width,1,1,1)).permute((1,0,2,3)).contiguous().view(-1, l, d) 79 | 80 | # Initialize the decoder 81 | state = torch.zeros(1, batch_size * beam_width, self.sDim) 82 | pos_index = (torch.Tensor(range(batch_size)) * beam_width).long().view(-1, 1) 83 | 84 | # Initialize the scores 85 | sequence_scores = torch.Tensor(batch_size * beam_width, 1) 86 | sequence_scores.fill_(-float('Inf')) 87 | sequence_scores.index_fill_(0, torch.Tensor([i * beam_width for i in range(0, batch_size)]).long(), 0.0) 88 | # sequence_scores.fill_(0.0) 89 | 90 | # Initialize the input vector 91 | y_prev = torch.zeros((batch_size * beam_width)).fill_(self.num_classes) 92 | 93 | # Store decisions for backtracking 94 | stored_scores = list() 95 | stored_predecessors = list() 96 | stored_emitted_symbols = list() 97 | 98 | for i in range(self.max_len_labels): 99 | output, state = self.decoder(inflated_encoder_feats, state, y_prev) 100 | log_softmax_output = F.log_softmax(output, dim=1) 101 | 102 | sequence_scores = _inflate(sequence_scores, self.num_classes, 1) 103 | sequence_scores += log_softmax_output 104 | scores, candidates = sequence_scores.view(batch_size, -1).topk(beam_width, dim=1) 105 | 106 | # Reshape input = (bk, 1) and sequence_scores = (bk, 1) 107 | y_prev = (candidates % self.num_classes).view(batch_size * beam_width) 108 | sequence_scores = scores.view(batch_size * beam_width, 1) 109 | 110 | # Update fields for next timestep 111 | 
predecessors = (candidates / self.num_classes + pos_index.expand_as(candidates)).view(batch_size * beam_width, 1) 112 | state = state.index_select(1, predecessors.squeeze()) 113 | 114 | # Update sequence socres and erase scores for symbol so that they aren't expanded 115 | stored_scores.append(sequence_scores.clone()) 116 | eos_indices = y_prev.view(-1, 1).eq(eos) 117 | if eos_indices.nonzero().dim() > 0: 118 | sequence_scores.masked_fill_(eos_indices, -float('inf')) 119 | 120 | # Cache results for backtracking 121 | stored_predecessors.append(predecessors) 122 | stored_emitted_symbols.append(y_prev) 123 | 124 | # Do backtracking to return the optimal values 125 | #====== backtrak ======# 126 | # Initialize return variables given different types 127 | p = list() 128 | l = [[self.max_len_labels] * beam_width for _ in range(batch_size)] # Placeholder for lengths of top-k sequences 129 | 130 | # the last step output of the beams are not sorted 131 | # thus they are sorted here 132 | sorted_score, sorted_idx = stored_scores[-1].view(batch_size, beam_width).topk(beam_width) 133 | # initialize the sequence scores with the sorted last step beam scores 134 | s = sorted_score.clone() 135 | 136 | batch_eos_found = [0] * batch_size # the number of EOS found 137 | # in the backward loop below for each batch 138 | t = self.max_len_labels - 1 139 | # initialize the back pointer with the sorted order of the last step beams. 140 | # add pos_index for indexing variable with b*k as the first dimension. 141 | t_predecessors = (sorted_idx + pos_index.expand_as(sorted_idx)).view(batch_size * beam_width) 142 | while t >= 0: 143 | # Re-order the variables with the back pointer 144 | current_symbol = stored_emitted_symbols[t].index_select(0, t_predecessors) 145 | t_predecessors = stored_predecessors[t].index_select(0, t_predecessors).squeeze() 146 | eos_indices = stored_emitted_symbols[t].eq(eos).nonzero() 147 | if eos_indices.dim() > 0: 148 | for i in range(eos_indices.size(0)-1, -1, -1): 149 | # Indices of the EOS symbol for both variables 150 | # with b*k as the first dimension, and b, k for 151 | # the first two dimensions 152 | idx = eos_indices[i] 153 | b_idx = int(idx[0] / beam_width) 154 | # The indices of the replacing position 155 | # according to the replacement strategy noted above 156 | res_k_idx = beam_width - (batch_eos_found[b_idx] % beam_width) - 1 157 | batch_eos_found[b_idx] += 1 158 | res_idx = b_idx * beam_width + res_k_idx 159 | 160 | # Replace the old information in return variables 161 | # with the new ended sequence information 162 | t_predecessors[res_idx] = stored_predecessors[t][idx[0]] 163 | current_symbol[res_idx] = stored_emitted_symbols[t][idx[0]] 164 | s[b_idx, res_k_idx] = stored_scores[t][idx[0], [0]] 165 | l[b_idx][res_k_idx] = t + 1 166 | 167 | # record the back tracked results 168 | p.append(current_symbol) 169 | 170 | t -= 1 171 | 172 | # Sort and re-order again as the added ended sequences may change 173 | # the order (very unlikely) 174 | s, re_sorted_idx = s.topk(beam_width) 175 | for b_idx in range(batch_size): 176 | l[b_idx] = [l[b_idx][k_idx.item()] for k_idx in re_sorted_idx[b_idx,:]] 177 | 178 | re_sorted_idx = (re_sorted_idx + pos_index.expand_as(re_sorted_idx)).view(batch_size*beam_width) 179 | 180 | # Reverse the sequences and re-order at the same time 181 | # It is reversed because the backtracking happens in reverse time order 182 | p = [step.index_select(0, re_sorted_idx).view(batch_size, beam_width, -1) for step in reversed(p)] 183 | p = torch.cat(p, 
-1)[:,0,:] 184 | return p, torch.ones_like(p) 185 | 186 | 187 | class AttentionUnit(nn.Module): 188 | def __init__(self, sDim, xDim, attDim): 189 | super(AttentionUnit, self).__init__() 190 | 191 | self.sDim = sDim 192 | self.xDim = xDim 193 | self.attDim = attDim 194 | 195 | self.sEmbed = nn.Linear(sDim, attDim) 196 | self.xEmbed = nn.Linear(xDim, attDim) 197 | self.wEmbed = nn.Linear(attDim, 1) 198 | 199 | # self.init_weights() 200 | 201 | def init_weights(self): 202 | init.normal_(self.sEmbed.weight, std=0.01) 203 | init.constant_(self.sEmbed.bias, 0) 204 | init.normal_(self.xEmbed.weight, std=0.01) 205 | init.constant_(self.xEmbed.bias, 0) 206 | init.normal_(self.wEmbed.weight, std=0.01) 207 | init.constant_(self.wEmbed.bias, 0) 208 | 209 | def forward(self, x, sPrev): 210 | batch_size, T, _ = x.size() # [b x T x xDim] 211 | x = x.view(-1, self.xDim) # [(b x T) x xDim] 212 | xProj = self.xEmbed(x) # [(b x T) x attDim] 213 | xProj = xProj.view(batch_size, T, -1) # [b x T x attDim] 214 | 215 | sPrev = sPrev.squeeze(0) 216 | sProj = self.sEmbed(sPrev) # [b x attDim] 217 | sProj = torch.unsqueeze(sProj, 1) # [b x 1 x attDim] 218 | sProj = sProj.expand(batch_size, T, self.attDim) # [b x T x attDim] 219 | 220 | sumTanh = torch.tanh(sProj + xProj) 221 | sumTanh = sumTanh.view(-1, self.attDim) 222 | 223 | vProj = self.wEmbed(sumTanh) # [(b x T) x 1] 224 | vProj = vProj.view(batch_size, T) 225 | 226 | alpha = F.softmax(vProj, dim=1) # attention weights for each sample in the minibatch 227 | 228 | return alpha 229 | 230 | 231 | class DecoderUnit(nn.Module): 232 | def __init__(self, sDim, xDim, yDim, attDim): 233 | super(DecoderUnit, self).__init__() 234 | self.sDim = sDim 235 | self.xDim = xDim 236 | self.yDim = yDim 237 | self.attDim = attDim 238 | self.emdDim = attDim 239 | 240 | self.attention_unit = AttentionUnit(sDim, xDim, attDim) 241 | self.tgt_embedding = nn.Embedding(yDim+1, self.emdDim) # the last is used for 242 | self.gru = nn.GRU(input_size=xDim+self.emdDim, hidden_size=sDim, batch_first=True) 243 | self.fc = nn.Linear(sDim, yDim) 244 | 245 | # self.init_weights() 246 | 247 | def init_weights(self): 248 | init.normal_(self.tgt_embedding.weight, std=0.01) 249 | init.normal_(self.fc.weight, std=0.01) 250 | init.constant_(self.fc.bias, 0) 251 | 252 | def forward(self, x, sPrev, yPrev): 253 | # x: feature sequence from the image decoder. 254 | batch_size, T, _ = x.size() 255 | alpha = self.attention_unit(x, sPrev) 256 | context = torch.bmm(alpha.unsqueeze(1), x).squeeze(1) 257 | yProj = self.tgt_embedding(yPrev.long()) 258 | # self.gru.flatten_parameters() 259 | output, state = self.gru(torch.cat([yProj, context], 1).unsqueeze(1), sPrev) 260 | output = output.squeeze(1) 261 | 262 | output = self.fc(output) 263 | return output, state -------------------------------------------------------------------------------- /lib/models/model_builder.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from PIL import Image 4 | import numpy as np 5 | from collections import OrderedDict 6 | import sys 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | from torch.nn import init 12 | 13 | from . 
import create 14 | from .attention_recognition_head import AttentionRecognitionHead 15 | from ..loss.sequenceCrossEntropyLoss import SequenceCrossEntropyLoss 16 | from .tps_spatial_transformer import TPSSpatialTransformer 17 | from .stn_head import STNHead 18 | 19 | 20 | from config import get_args 21 | global_args = get_args(sys.argv[1:]) 22 | 23 | 24 | class ModelBuilder(nn.Module): 25 | """ 26 | This is the integrated model. 27 | """ 28 | def __init__(self, arch, rec_num_classes, sDim, attDim, max_len_labels, eos, STN_ON=False): 29 | super(ModelBuilder, self).__init__() 30 | 31 | self.arch = arch 32 | self.rec_num_classes = rec_num_classes 33 | self.sDim = sDim 34 | self.attDim = attDim 35 | self.max_len_labels = max_len_labels 36 | self.eos = eos 37 | self.STN_ON = STN_ON 38 | self.tps_inputsize = global_args.tps_inputsize 39 | 40 | self.encoder = create(self.arch, 41 | with_lstm=global_args.with_lstm, 42 | n_group=global_args.n_group) 43 | encoder_out_planes = self.encoder.out_planes 44 | 45 | self.decoder = AttentionRecognitionHead( 46 | num_classes=rec_num_classes, 47 | in_planes=encoder_out_planes, 48 | sDim=sDim, 49 | attDim=attDim, 50 | max_len_labels=max_len_labels) 51 | self.rec_crit = SequenceCrossEntropyLoss() 52 | 53 | if self.STN_ON: 54 | self.tps = TPSSpatialTransformer( 55 | output_image_size=tuple(global_args.tps_outputsize), 56 | num_control_points=global_args.num_control_points, 57 | margins=tuple(global_args.tps_margins)) 58 | self.stn_head = STNHead( 59 | in_planes=3, 60 | num_ctrlpoints=global_args.num_control_points, 61 | activation=global_args.stn_activation) 62 | 63 | def forward(self, input_dict): 64 | return_dict = {} 65 | return_dict['losses'] = {} 66 | return_dict['output'] = {} 67 | 68 | x, rec_targets, rec_lengths = input_dict['images'], \ 69 | input_dict['rec_targets'], \ 70 | input_dict['rec_lengths'] 71 | 72 | # rectification 73 | if self.STN_ON: 74 | # input images are downsampled before being fed into stn_head. 
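      # The STN branch only needs a coarse view to localize the text: the full image is
      # resized to `tps_inputsize` (32x64 in the provided scripts), stn_head regresses
      # `num_control_points` normalized (x, y) control points from it, and
      # TPSSpatialTransformer then warps the original `x` to `tps_outputsize` (32x100 in
      # the scripts) using those points before recognition.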
75 | stn_input = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True) 76 | stn_img_feat, ctrl_points = self.stn_head(stn_input) 77 | x, _ = self.tps(x, ctrl_points) 78 | if not self.training: 79 | # save for visualization 80 | return_dict['output']['ctrl_points'] = ctrl_points 81 | return_dict['output']['rectified_images'] = x 82 | 83 | encoder_feats = self.encoder(x) 84 | encoder_feats = encoder_feats.contiguous() 85 | 86 | if self.training: 87 | rec_pred = self.decoder([encoder_feats, rec_targets, rec_lengths]) 88 | loss_rec = self.rec_crit(rec_pred, rec_targets, rec_lengths) 89 | return_dict['losses']['loss_rec'] = loss_rec 90 | else: 91 | rec_pred, rec_pred_scores = self.decoder.beam_search(encoder_feats, global_args.beam_width, self.eos) 92 | rec_pred_ = self.decoder([encoder_feats, rec_targets, rec_lengths]) 93 | loss_rec = self.rec_crit(rec_pred_, rec_targets, rec_lengths) 94 | return_dict['losses']['loss_rec'] = loss_rec 95 | return_dict['output']['pred_rec'] = rec_pred 96 | return_dict['output']['pred_rec_score'] = rec_pred_scores 97 | 98 | # pytorch0.4 bug on gathering scalar(0-dim) tensors 99 | for k, v in return_dict['losses'].items(): 100 | return_dict['losses'][k] = v.unsqueeze(0) 101 | 102 | return return_dict -------------------------------------------------------------------------------- /lib/models/resnet_aster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | import sys 6 | import math 7 | 8 | from config import get_args 9 | global_args = get_args(sys.argv[1:]) 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | """3x3 convolution with padding""" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | def conv1x1(in_planes, out_planes, stride=1): 19 | """1x1 convolution""" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 21 | 22 | 23 | def get_sinusoid_encoding(n_position, feat_dim, wave_length=10000): 24 | # [n_position] 25 | positions = torch.arange(0, n_position)#.cuda() 26 | # [feat_dim] 27 | dim_range = torch.arange(0, feat_dim)#.cuda() 28 | dim_range = torch.pow(wave_length, 2 * (dim_range // 2) / feat_dim) 29 | # [n_position, feat_dim] 30 | angles = positions.unsqueeze(1) / dim_range.unsqueeze(0) 31 | angles = angles.float() 32 | angles[:, 0::2] = torch.sin(angles[:, 0::2]) 33 | angles[:, 1::2] = torch.cos(angles[:, 1::2]) 34 | return angles 35 | 36 | 37 | class AsterBlock(nn.Module): 38 | 39 | def __init__(self, inplanes, planes, stride=1, downsample=None): 40 | super(AsterBlock, self).__init__() 41 | self.conv1 = conv1x1(inplanes, planes, stride) 42 | self.bn1 = nn.BatchNorm2d(planes) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.conv2 = conv3x3(planes, planes) 45 | self.bn2 = nn.BatchNorm2d(planes) 46 | self.downsample = downsample 47 | self.stride = stride 48 | 49 | def forward(self, x): 50 | residual = x 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | 57 | if self.downsample is not None: 58 | residual = self.downsample(x) 59 | out += residual 60 | out = self.relu(out) 61 | return out 62 | 63 | 64 | class ResNet_ASTER(nn.Module): 65 | """For aster or crnn""" 66 | 67 | def __init__(self, with_lstm=False, n_group=1): 68 | super(ResNet_ASTER, self).__init__() 69 | self.with_lstm = with_lstm 70 | self.n_group = n_group 71 | 72 | 
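    # Backbone layout for the default 32x100 input (see main.py): every _make_layer call
    # below halves the height, while the width is halved only in the first two stages --
    # the [h, w] comments trace this -- so layer5 ends at 1x25. forward() squeezes the
    # singleton height away, giving a length-25 feature sequence of dim 512 (or 2*256
    # when the bidirectional LSTM is enabled).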
in_channels = 3 73 | self.layer0 = nn.Sequential( 74 | nn.Conv2d(in_channels, 32, kernel_size=(3, 3), stride=1, padding=1, bias=False), 75 | nn.BatchNorm2d(32), 76 | nn.ReLU(inplace=True)) 77 | 78 | self.inplanes = 32 79 | self.layer1 = self._make_layer(32, 3, [2, 2]) # [16, 50] 80 | self.layer2 = self._make_layer(64, 4, [2, 2]) # [8, 25] 81 | self.layer3 = self._make_layer(128, 6, [2, 1]) # [4, 25] 82 | self.layer4 = self._make_layer(256, 6, [2, 1]) # [2, 25] 83 | self.layer5 = self._make_layer(512, 3, [2, 1]) # [1, 25] 84 | 85 | if with_lstm: 86 | self.rnn = nn.LSTM(512, 256, bidirectional=True, num_layers=2, batch_first=True) 87 | self.out_planes = 2 * 256 88 | else: 89 | self.out_planes = 512 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 94 | elif isinstance(m, nn.BatchNorm2d): 95 | nn.init.constant_(m.weight, 1) 96 | nn.init.constant_(m.bias, 0) 97 | 98 | def _make_layer(self, planes, blocks, stride): 99 | downsample = None 100 | if stride != [1, 1] or self.inplanes != planes: 101 | downsample = nn.Sequential( 102 | conv1x1(self.inplanes, planes, stride), 103 | nn.BatchNorm2d(planes)) 104 | 105 | layers = [] 106 | layers.append(AsterBlock(self.inplanes, planes, stride, downsample)) 107 | self.inplanes = planes 108 | for _ in range(1, blocks): 109 | layers.append(AsterBlock(self.inplanes, planes)) 110 | return nn.Sequential(*layers) 111 | 112 | def forward(self, x): 113 | x0 = self.layer0(x) 114 | x1 = self.layer1(x0) 115 | x2 = self.layer2(x1) 116 | x3 = self.layer3(x2) 117 | x4 = self.layer4(x3) 118 | x5 = self.layer5(x4) 119 | 120 | cnn_feat = x5.squeeze(2) # [N, c, w] 121 | cnn_feat = cnn_feat.transpose(2, 1) 122 | if self.with_lstm: 123 | rnn_feat, _ = self.rnn(cnn_feat) 124 | return rnn_feat 125 | else: 126 | return cnn_feat 127 | 128 | 129 | if __name__ == "__main__": 130 | x = torch.randn(3, 3, 32, 100) 131 | net = ResNet_ASTER(use_self_attention=True, use_position_embedding=True) 132 | encoder_feat = net(x) 133 | print(encoder_feat.size()) -------------------------------------------------------------------------------- /lib/models/stn_head.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import math 4 | import numpy as np 5 | import sys 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.nn import init 11 | 12 | 13 | def conv3x3_block(in_planes, out_planes, stride=1): 14 | """3x3 convolution with padding""" 15 | conv_layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1) 16 | 17 | block = nn.Sequential( 18 | conv_layer, 19 | nn.BatchNorm2d(out_planes), 20 | nn.ReLU(inplace=True), 21 | ) 22 | return block 23 | 24 | 25 | class STNHead(nn.Module): 26 | def __init__(self, in_planes, num_ctrlpoints, activation='none'): 27 | super(STNHead, self).__init__() 28 | 29 | self.in_planes = in_planes 30 | self.num_ctrlpoints = num_ctrlpoints 31 | self.activation = activation 32 | self.stn_convnet = nn.Sequential( 33 | conv3x3_block(in_planes, 32), # 32*64 34 | nn.MaxPool2d(kernel_size=2, stride=2), 35 | conv3x3_block(32, 64), # 16*32 36 | nn.MaxPool2d(kernel_size=2, stride=2), 37 | conv3x3_block(64, 128), # 8*16 38 | nn.MaxPool2d(kernel_size=2, stride=2), 39 | conv3x3_block(128, 256), # 4*8 40 | nn.MaxPool2d(kernel_size=2, stride=2), 41 | conv3x3_block(256, 256), # 2*4, 42 | nn.MaxPool2d(kernel_size=2, stride=2), 43 | conv3x3_block(256, 
256)) # 1*2 44 | 45 | self.stn_fc1 = nn.Sequential( 46 | nn.Linear(2*256, 512), 47 | nn.BatchNorm1d(512), 48 | nn.ReLU(inplace=True)) 49 | self.stn_fc2 = nn.Linear(512, num_ctrlpoints*2) 50 | 51 | self.init_weights(self.stn_convnet) 52 | self.init_weights(self.stn_fc1) 53 | self.init_stn(self.stn_fc2) 54 | 55 | def init_weights(self, module): 56 | for m in module.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 59 | m.weight.data.normal_(0, math.sqrt(2. / n)) 60 | if m.bias is not None: 61 | m.bias.data.zero_() 62 | elif isinstance(m, nn.BatchNorm2d): 63 | m.weight.data.fill_(1) 64 | m.bias.data.zero_() 65 | elif isinstance(m, nn.Linear): 66 | m.weight.data.normal_(0, 0.001) 67 | m.bias.data.zero_() 68 | 69 | def init_stn(self, stn_fc2): 70 | margin = 0.01 71 | sampling_num_per_side = int(self.num_ctrlpoints / 2) 72 | ctrl_pts_x = np.linspace(margin, 1.-margin, sampling_num_per_side) 73 | ctrl_pts_y_top = np.ones(sampling_num_per_side) * margin 74 | ctrl_pts_y_bottom = np.ones(sampling_num_per_side) * (1-margin) 75 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 76 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 77 | ctrl_points = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0).astype(np.float32) 78 | if self.activation is 'none': 79 | pass 80 | elif self.activation == 'sigmoid': 81 | ctrl_points = -np.log(1. / ctrl_points - 1.) 82 | stn_fc2.weight.data.zero_() 83 | stn_fc2.bias.data = torch.Tensor(ctrl_points).view(-1) 84 | 85 | def forward(self, x): 86 | x = self.stn_convnet(x) 87 | batch_size, _, h, w = x.size() 88 | x = x.view(batch_size, -1) 89 | img_feat = self.stn_fc1(x) 90 | x = self.stn_fc2(0.1 * img_feat) 91 | if self.activation == 'sigmoid': 92 | x = F.sigmoid(x) 93 | x = x.view(-1, self.num_ctrlpoints, 2) 94 | return img_feat, x 95 | 96 | 97 | if __name__ == "__main__": 98 | in_planes = 3 99 | num_ctrlpoints = 20 100 | activation='none' # 'sigmoid' 101 | stn_head = STNHead(in_planes, num_ctrlpoints, activation) 102 | input = torch.randn(10, 3, 32, 64) 103 | control_points = stn_head(input) 104 | print(control_points.size()) -------------------------------------------------------------------------------- /lib/models/tps_spatial_transformer.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import numpy as np 5 | import itertools 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | def grid_sample(input, grid, canvas = None): 12 | output = F.grid_sample(input, grid) 13 | if canvas is None: 14 | return output 15 | else: 16 | input_mask = input.data.new(input.size()).fill_(1) 17 | output_mask = F.grid_sample(input_mask, grid) 18 | padded_output = output * output_mask + canvas * (1 - output_mask) 19 | return padded_output 20 | 21 | 22 | # phi(x1, x2) = r^2 * log(r), where r = ||x1 - x2||_2 23 | def compute_partial_repr(input_points, control_points): 24 | N = input_points.size(0) 25 | M = control_points.size(0) 26 | pairwise_diff = input_points.view(N, 1, 2) - control_points.view(1, M, 2) 27 | # original implementation, very slow 28 | # pairwise_dist = torch.sum(pairwise_diff ** 2, dim = 2) # square of distance 29 | pairwise_diff_square = pairwise_diff * pairwise_diff 30 | pairwise_dist = pairwise_diff_square[:, :, 0] + pairwise_diff_square[:, :, 1] 31 | repr_matrix = 0.5 * pairwise_dist * torch.log(pairwise_dist) 32 | # fix numerical error for 0 * 
log(0), substitute all nan with 0 33 | mask = repr_matrix != repr_matrix 34 | repr_matrix.masked_fill_(mask, 0) 35 | return repr_matrix 36 | 37 | 38 | # output_ctrl_pts are specified, according to our task. 39 | def build_output_control_points(num_control_points, margins): 40 | margin_x, margin_y = margins 41 | num_ctrl_pts_per_side = num_control_points // 2 42 | ctrl_pts_x = np.linspace(margin_x, 1.0 - margin_x, num_ctrl_pts_per_side) 43 | ctrl_pts_y_top = np.ones(num_ctrl_pts_per_side) * margin_y 44 | ctrl_pts_y_bottom = np.ones(num_ctrl_pts_per_side) * (1.0 - margin_y) 45 | ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) 46 | ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) 47 | # ctrl_pts_top = ctrl_pts_top[1:-1,:] 48 | # ctrl_pts_bottom = ctrl_pts_bottom[1:-1,:] 49 | output_ctrl_pts_arr = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) 50 | output_ctrl_pts = torch.Tensor(output_ctrl_pts_arr) 51 | return output_ctrl_pts 52 | 53 | 54 | # demo: ~/test/models/test_tps_transformation.py 55 | class TPSSpatialTransformer(nn.Module): 56 | 57 | def __init__(self, output_image_size=None, num_control_points=None, margins=None): 58 | super(TPSSpatialTransformer, self).__init__() 59 | self.output_image_size = output_image_size 60 | self.num_control_points = num_control_points 61 | self.margins = margins 62 | 63 | self.target_height, self.target_width = output_image_size 64 | target_control_points = build_output_control_points(num_control_points, margins) 65 | N = num_control_points 66 | # N = N - 4 67 | 68 | # create padded kernel matrix 69 | forward_kernel = torch.zeros(N + 3, N + 3) 70 | target_control_partial_repr = compute_partial_repr(target_control_points, target_control_points) 71 | forward_kernel[:N, :N].copy_(target_control_partial_repr) 72 | forward_kernel[:N, -3].fill_(1) 73 | forward_kernel[-3, :N].fill_(1) 74 | forward_kernel[:N, -2:].copy_(target_control_points) 75 | forward_kernel[-2:, :N].copy_(target_control_points.transpose(0, 1)) 76 | # compute inverse matrix 77 | inverse_kernel = torch.inverse(forward_kernel) 78 | 79 | # create target cordinate matrix 80 | HW = self.target_height * self.target_width 81 | target_coordinate = list(itertools.product(range(self.target_height), range(self.target_width))) 82 | target_coordinate = torch.Tensor(target_coordinate) # HW x 2 83 | Y, X = target_coordinate.split(1, dim = 1) 84 | Y = Y / (self.target_height - 1) 85 | X = X / (self.target_width - 1) 86 | target_coordinate = torch.cat([X, Y], dim = 1) # convert from (y, x) to (x, y) 87 | target_coordinate_partial_repr = compute_partial_repr(target_coordinate, target_control_points) 88 | target_coordinate_repr = torch.cat([ 89 | target_coordinate_partial_repr, torch.ones(HW, 1), target_coordinate 90 | ], dim = 1) 91 | 92 | # register precomputed matrices 93 | self.register_buffer('inverse_kernel', inverse_kernel) 94 | self.register_buffer('padding_matrix', torch.zeros(3, 2)) 95 | self.register_buffer('target_coordinate_repr', target_coordinate_repr) 96 | self.register_buffer('target_control_points', target_control_points) 97 | 98 | def forward(self, input, source_control_points): 99 | assert source_control_points.ndimension() == 3 100 | assert source_control_points.size(1) == self.num_control_points 101 | assert source_control_points.size(2) == 2 102 | batch_size = source_control_points.size(0) 103 | 104 | Y = torch.cat([source_control_points, self.padding_matrix.expand(batch_size, 3, 2)], 1) 105 | mapping_matrix = torch.matmul(self.inverse_kernel, Y) 
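    # Standard thin-plate-spline solve: Y stacks the predicted source control points
    # padded with three zero rows, so multiplying by the precomputed inverse kernel gives
    # the (N + 3) x 2 TPS coefficients. Applying them to `target_coordinate_repr`
    # (radial-basis terms, a constant and the target x/y coordinates) maps every output
    # pixel to a sampling location in the source image.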
106 | source_coordinate = torch.matmul(self.target_coordinate_repr, mapping_matrix) 107 | 108 | grid = source_coordinate.view(-1, self.target_height, self.target_width, 2) 109 | grid = torch.clamp(grid, 0, 1) # the source_control_points may be out of [0, 1]. 110 | # the input to grid_sample is normalized [-1, 1], but what we get is [0, 1] 111 | grid = 2.0 * grid - 1.0 112 | output_maps = grid_sample(input, grid, canvas=None) 113 | return output_maps, source_coordinate -------------------------------------------------------------------------------- /lib/tools/create_sub_lmdb.py: -------------------------------------------------------------------------------- 1 | import lmdb 2 | import six 3 | import numpy as np 4 | from PIL import Image 5 | 6 | read_root_dir = '/data/zhui/back/NIPS2014' 7 | write_root_dir = '/home/mkyang/data/sub_nips2014' 8 | read_env = lmdb.open(read_root_dir, max_readers=32, readonly=True) 9 | write_env = lmdb.open(write_root_dir, map_size=1099511627776) 10 | 11 | def writeCache(env, cache): 12 | with env.begin(write=True) as txn: 13 | for k, v in cache.items(): 14 | txn.put(k.encode(), v) 15 | 16 | assert read_env is not None, "cannot create lmdb from %s" % read_root_dir 17 | read_txn = read_env.begin() 18 | nSamples = int(read_txn.get(b"num-samples")) 19 | sub_nsamples = 10000 20 | indices = list(np.random.permutation(nSamples)) 21 | indices = indices[:sub_nsamples] 22 | 23 | cache = {} 24 | for i, index in enumerate(indices): 25 | img_key = b'image-%09d' % index 26 | label_key = b'label-%09d' % index 27 | 28 | imgbuf = read_txn.get(img_key) 29 | word = read_txn.get(label_key) 30 | 31 | new_img_key = 'image-%09d' % (i+1) 32 | new_label_key = 'label-%09d' % (i+1) 33 | cache[new_img_key] = imgbuf 34 | cache[new_label_key] = word 35 | 36 | cache['num-samples'] = str(sub_nsamples).encode() 37 | writeCache(write_env, cache) -------------------------------------------------------------------------------- /lib/tools/create_svtp_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lmdb # install lmdb by "pip install lmdb" 3 | import cv2 4 | import numpy as np 5 | from tqdm import tqdm 6 | import six 7 | from PIL import Image 8 | import scipy.io as sio 9 | from tqdm import tqdm 10 | import re 11 | 12 | def checkImageIsValid(imageBin): 13 | if imageBin is None: 14 | return False 15 | imageBuf = np.fromstring(imageBin, dtype=np.uint8) 16 | img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE) 17 | imgH, imgW = img.shape[0], img.shape[1] 18 | if imgH * imgW == 0: 19 | return False 20 | return True 21 | 22 | 23 | def writeCache(env, cache): 24 | with env.begin(write=True) as txn: 25 | for k, v in cache.items(): 26 | txn.put(k.encode(), v) 27 | 28 | 29 | def _is_difficult(word): 30 | assert isinstance(word, str) 31 | return not re.match('^[\w]+$', word) 32 | 33 | 34 | def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True): 35 | """ 36 | Create LMDB dataset for CRNN training. 
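  Samples are written under 1-indexed keys 'image-%09d' / 'label-%09d', plus a final
  'num-samples' entry recording the dataset size -- the same layout read back by
  create_sub_lmdb.py above. Example call (paths are illustrative only):
      createDataset('/tmp/my_lmdb', ['imgs/word_1.png'], ['hello'])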
37 | ARGS: 38 | outputPath : LMDB output path 39 | imagePathList : list of image path 40 | labelList : list of corresponding groundtruth texts 41 | lexiconList : (optional) list of lexicon lists 42 | checkValid : if true, check the validity of every image 43 | """ 44 | assert(len(imagePathList) == len(labelList)) 45 | nSamples = len(imagePathList) 46 | env = lmdb.open(outputPath, map_size=1099511627776) 47 | cache = {} 48 | cnt = 1 49 | for i in range(nSamples): 50 | imagePath = imagePathList[i] 51 | label = labelList[i] 52 | if len(label) == 0: 53 | continue 54 | if not os.path.exists(imagePath): 55 | print('%s does not exist' % imagePath) 56 | continue 57 | with open(imagePath, 'rb') as f: 58 | imageBin = f.read() 59 | if checkValid: 60 | if not checkImageIsValid(imageBin): 61 | print('%s is not a valid image' % imagePath) 62 | continue 63 | 64 | imageKey = 'image-%09d' % cnt 65 | labelKey = 'label-%09d' % cnt 66 | cache[imageKey] = imageBin 67 | cache[labelKey] = label.encode() 68 | if lexiconList: 69 | lexiconKey = 'lexicon-%09d' % cnt 70 | cache[lexiconKey] = ' '.join(lexiconList[i]) 71 | if cnt % 1000 == 0: 72 | writeCache(env, cache) 73 | cache = {} 74 | print('Written %d / %d' % (cnt, nSamples)) 75 | cnt += 1 76 | nSamples = cnt-1 77 | cache['num-samples'] = str(nSamples).encode() 78 | writeCache(env, cache) 79 | print('Created dataset with %d samples' % nSamples) 80 | 81 | if __name__ == "__main__": 82 | data_dir = '/data/mkyang/datasets/English/benchmark/svtp/' 83 | lmdb_output_path = '/data/mkyang/datasets/English/benchmark_lmdbs_new/svt_p_645' 84 | gt_file = os.path.join(data_dir, 'gt.txt') 85 | image_dir = data_dir 86 | with open(gt_file, 'r') as f: 87 | lines = [line.strip('\n') for line in f.readlines()] 88 | 89 | imagePathList, labelList = [], [] 90 | for i, line in enumerate(lines): 91 | splits = line.split(' ') 92 | image_name = splits[0] 93 | gt_text = splits[1] 94 | print(image_name, gt_text) 95 | imagePathList.append(os.path.join(image_dir, image_name)) 96 | labelList.append(gt_text) 97 | 98 | createDataset(lmdb_output_path, imagePathList, labelList) -------------------------------------------------------------------------------- /lib/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | from time import gmtime, strftime 4 | from datetime import datetime 5 | import gc 6 | import os.path as osp 7 | import sys 8 | from PIL import Image 9 | import numpy as np 10 | 11 | import torch 12 | from torchvision import transforms 13 | 14 | from . 
import evaluation_metrics 15 | from .evaluation_metrics import Accuracy, EditDistance 16 | from .utils import to_numpy 17 | from .utils.meters import AverageMeter 18 | from .utils.serialization import load_checkpoint, save_checkpoint 19 | 20 | metrics_factory = evaluation_metrics.factory() 21 | 22 | from config import get_args 23 | global_args = get_args(sys.argv[1:]) 24 | 25 | class BaseTrainer(object): 26 | def __init__(self, model, metric, logs_dir, iters=0, best_res=-1, grad_clip=-1, use_cuda=True, loss_weights={}): 27 | super(BaseTrainer, self).__init__() 28 | self.model = model 29 | self.metric = metric 30 | self.logs_dir = logs_dir 31 | self.iters = iters 32 | self.best_res = best_res 33 | self.grad_clip = grad_clip 34 | self.use_cuda = use_cuda 35 | self.loss_weights = loss_weights 36 | 37 | self.device = torch.device("cuda" if use_cuda else "cpu") 38 | 39 | def train(self, epoch, data_loader, optimizer, current_lr=0.0, 40 | print_freq=100, train_tfLogger=None, is_debug=False, 41 | evaluator=None, test_loader=None, eval_tfLogger=None, 42 | test_dataset=None, test_freq=1000): 43 | 44 | self.model.train() 45 | 46 | batch_time = AverageMeter() 47 | data_time = AverageMeter() 48 | losses = AverageMeter() 49 | 50 | end = time.time() 51 | 52 | for i, inputs in enumerate(data_loader): 53 | self.model.train() 54 | self.iters += 1 55 | 56 | data_time.update(time.time() - end) 57 | 58 | input_dict = self._parse_data(inputs) 59 | output_dict = self._forward(input_dict) 60 | 61 | batch_size = input_dict['images'].size(0) 62 | 63 | total_loss = 0 64 | loss_dict = {} 65 | for k, loss in output_dict['losses'].items(): 66 | loss = loss.mean(dim=0, keepdim=True) 67 | total_loss += self.loss_weights[k] * loss 68 | loss_dict[k] = loss.item() 69 | # print('{0}: {1}'.format(k, loss.item())) 70 | 71 | losses.update(total_loss.item(), batch_size) 72 | 73 | optimizer.zero_grad() 74 | total_loss.backward() 75 | if self.grad_clip > 0: 76 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) 77 | optimizer.step() 78 | 79 | # # debug: check the parameters fixed or not. 
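      # # (e.g. adapt the tag names below to your own model to spot-check that a frozen
      # # branch really stops receiving updates between iterations)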
80 | # print(self.model.parameters()) 81 | # for tag, value in self.model.named_parameters(): 82 | # if tag == 'module.base.resnet.layer4.0.conv1.weight': 83 | # print(value[:10,0,0,0]) 84 | # if tag == 'module.rec_head.decoder.attention_unit.sEmbed.weight': 85 | # print(value[0, :10]) 86 | 87 | batch_time.update(time.time() - end) 88 | end = time.time() 89 | 90 | if self.iters % print_freq == 0: 91 | print('[{}]\t' 92 | 'Epoch: [{}][{}/{}]\t' 93 | 'Time {:.3f} ({:.3f})\t' 94 | 'Data {:.3f} ({:.3f})\t' 95 | 'Loss {:.3f} ({:.3f})\t' 96 | # .format(strftime("%Y-%m-%d %H:%M:%S", gmtime()), 97 | .format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 98 | epoch, i + 1, len(data_loader), 99 | batch_time.val, batch_time.avg, 100 | data_time.val, data_time.avg, 101 | losses.val, losses.avg)) 102 | 103 | #====== TensorBoard logging ======# 104 | if self.iters % print_freq*10 == 0: 105 | if train_tfLogger is not None: 106 | step = epoch * len(data_loader) + (i + 1) 107 | info = { 108 | 'lr': current_lr, 109 | 'loss': total_loss.item(), # this is total loss 110 | } 111 | ## add each loss 112 | for k, loss in loss_dict.items(): 113 | info[k] = loss 114 | for tag, value in info.items(): 115 | train_tfLogger.scalar_summary(tag, value, step) 116 | 117 | # if is_debug and (i + 1) % (print_freq*100) == 0: # this time-consuming and space-consuming 118 | # # (2) Log values and gradients of the parameters (histogram) 119 | # for tag, value in self.model.named_parameters(): 120 | # tag = tag.replace('.', '/') 121 | # train_tfLogger.histo_summary(tag, to_numpy(value.data), step) 122 | # train_tfLogger.histo_summary(tag+'/grad', to_numpy(value.grad.data), step) 123 | 124 | # # (3) Log the images 125 | # images, _, pids, _ = inputs 126 | # offsets = to_numpy(offsets) 127 | # info = { 128 | # 'images': to_numpy(images[:10]) 129 | # } 130 | # for tag, images in info.items(): 131 | # train_tfLogger.image_summary(tag, images, step) 132 | 133 | #====== evaluation ======# 134 | if self.iters % test_freq == 0: 135 | # only symmetry branch 136 | if 'loss_rec' not in output_dict['losses']: 137 | is_best = True 138 | # self.best_res is alwarys equal to 1.0 139 | self.best_res = evaluator.evaluate(test_loader, step=self.iters, tfLogger=eval_tfLogger, dataset=test_dataset) 140 | else: 141 | res = evaluator.evaluate(test_loader, step=self.iters, tfLogger=eval_tfLogger, dataset=test_dataset) 142 | 143 | if self.metric == 'accuracy': 144 | is_best = res > self.best_res 145 | self.best_res = max(res, self.best_res) 146 | elif self.metric == 'editdistance': 147 | is_best = res < self.best_res 148 | self.best_res = min(res, self.best_res) 149 | else: 150 | raise ValueError("Unsupported evaluation metric:", self.metric) 151 | 152 | print('\n * Finished iters {:3d} accuracy: {:5.1%} best: {:5.1%}{}\n'. 
153 | format(self.iters, res, self.best_res, ' *' if is_best else '')) 154 | 155 | # if epoch < 1: 156 | # continue 157 | save_checkpoint({ 158 | 'state_dict': self.model.module.state_dict(), 159 | 'iters': self.iters, 160 | 'best_res': self.best_res, 161 | }, is_best, fpath=osp.join(self.logs_dir, 'checkpoint.pth.tar')) 162 | 163 | 164 | # collect garbage (not work) 165 | # gc.collect() 166 | 167 | def _parse_data(self, inputs): 168 | raise NotImplementedError 169 | 170 | def _forward(self, inputs, targets): 171 | raise NotImplementedError 172 | 173 | 174 | class Trainer(BaseTrainer): 175 | def _parse_data(self, inputs): 176 | input_dict = {} 177 | imgs, label_encs, lengths = inputs 178 | images = imgs.to(self.device) 179 | if label_encs is not None: 180 | labels = label_encs.to(self.device) 181 | 182 | input_dict['images'] = images 183 | input_dict['rec_targets'] = labels 184 | input_dict['rec_lengths'] = lengths 185 | return input_dict 186 | 187 | def _forward(self, input_dict): 188 | self.model.train() 189 | output_dict = self.model(input_dict) 190 | return output_dict -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray -------------------------------------------------------------------------------- /lib/utils/labelmaps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import string 4 | 5 | from . 
import to_torch, to_numpy 6 | 7 | def get_vocabulary(voc_type, EOS='EOS', PADDING='PADDING', UNKNOWN='UNKNOWN'): 8 | ''' 9 | voc_type: str: one of 'LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS' 10 | ''' 11 | voc = None 12 | types = ['LOWERCASE', 'ALLCASES', 'ALLCASES_SYMBOLS'] 13 | if voc_type == 'LOWERCASE': 14 | voc = list(string.digits + string.ascii_lowercase) 15 | elif voc_type == 'ALLCASES': 16 | voc = list(string.digits + string.ascii_letters) 17 | elif voc_type == 'ALLCASES_SYMBOLS': 18 | voc = list(string.printable[:-6]) 19 | else: 20 | raise KeyError('voc_type must be one of "LOWERCASE", "ALLCASES", "ALLCASES_SYMBOLS"') 21 | 22 | # update the voc with specifical chars 23 | voc.append(EOS) 24 | voc.append(PADDING) 25 | voc.append(UNKNOWN) 26 | 27 | return voc 28 | 29 | ## param voc: the list of vocabulary 30 | def char2id(voc): 31 | return dict(zip(voc, range(len(voc)))) 32 | 33 | def id2char(voc): 34 | return dict(zip(range(len(voc)), voc)) 35 | 36 | def labels2strs(labels, id2char, char2id): 37 | # labels: batch_size x len_seq 38 | if labels.ndimension() == 1: 39 | labels = labels.unsqueeze(0) 40 | assert labels.dim() == 2 41 | labels = to_numpy(labels) 42 | strings = [] 43 | batch_size = labels.shape[0] 44 | 45 | for i in range(batch_size): 46 | label = labels[i] 47 | string = [] 48 | for l in label: 49 | if l == char2id['EOS']: 50 | break 51 | else: 52 | string.append(id2char[l]) 53 | string = ''.join(string) 54 | strings.append(string) 55 | 56 | return strings -------------------------------------------------------------------------------- /lib/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import numpy as np 5 | import tensorflow as tf 6 | import scipy.misc 7 | try: 8 | from StringIO import StringIO # Python 2.7 9 | except ImportError: 10 | from io import BytesIO # Python 3.x 11 | 12 | from .osutils import mkdir_if_missing 13 | 14 | from config import get_args 15 | global_args = get_args(sys.argv[1:]) 16 | 17 | if global_args.run_on_remote: 18 | import moxing as mox 19 | mox.file.shift("os", "mox") 20 | 21 | class Logger(object): 22 | def __init__(self, fpath=None): 23 | self.console = sys.stdout 24 | self.file = None 25 | if fpath is not None: 26 | if global_args.run_on_remote: 27 | dir_name = os.path.dirname(fpath) 28 | if not mox.file.exists(dir_name): 29 | mox.file.make_dirs(dir_name) 30 | print('=> making dir ', dir_name) 31 | self.file = mox.file.File(fpath, 'w') 32 | # self.file = open(fpath, 'w') 33 | else: 34 | mkdir_if_missing(os.path.dirname(fpath)) 35 | self.file = open(fpath, 'w') 36 | 37 | def __del__(self): 38 | self.close() 39 | 40 | def __enter__(self): 41 | pass 42 | 43 | def __exit__(self, *args): 44 | self.close() 45 | 46 | def write(self, msg): 47 | self.console.write(msg) 48 | if self.file is not None: 49 | self.file.write(msg) 50 | 51 | def flush(self): 52 | self.console.flush() 53 | if self.file is not None: 54 | self.file.flush() 55 | os.fsync(self.file.fileno()) 56 | 57 | def close(self): 58 | self.console.close() 59 | if self.file is not None: 60 | self.file.close() 61 | 62 | 63 | class TFLogger(object): 64 | def __init__(self, log_dir=None): 65 | """Create a summary writer logging to log_dir.""" 66 | if log_dir is not None: 67 | mkdir_if_missing(log_dir) 68 | self.writer = tf.summary.FileWriter(log_dir) 69 | 70 | def scalar_summary(self, tag, value, step): 71 | """Log a scalar variable.""" 72 | summary = 
tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 73 | self.writer.add_summary(summary, step) 74 | self.writer.flush() 75 | 76 | def image_summary(self, tag, images, step): 77 | """Log a list of images.""" 78 | 79 | img_summaries = [] 80 | for i, img in enumerate(images): 81 | # Write the image to a string 82 | try: 83 | s = StringIO() 84 | except: 85 | s = BytesIO() 86 | scipy.misc.toimage(img).save(s, format="png") 87 | 88 | # Create an Image object 89 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 90 | height=img.shape[0], 91 | width=img.shape[1]) 92 | # Create a Summary value 93 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 94 | 95 | # Create and write Summary 96 | summary = tf.Summary(value=img_summaries) 97 | self.writer.add_summary(summary, step) 98 | self.writer.flush() 99 | 100 | def histo_summary(self, tag, values, step, bins=1000): 101 | """Log a histogram of the tensor of values.""" 102 | 103 | # Create a histogram using numpy 104 | counts, bin_edges = np.histogram(values, bins=bins) 105 | 106 | # Fill the fields of the histogram proto 107 | hist = tf.HistogramProto() 108 | hist.min = float(np.min(values)) 109 | hist.max = float(np.max(values)) 110 | hist.num = int(np.prod(values.shape)) 111 | hist.sum = float(np.sum(values)) 112 | hist.sum_squares = float(np.sum(values**2)) 113 | 114 | # Drop the start of the first bin 115 | bin_edges = bin_edges[1:] 116 | 117 | # Add bin edges and counts 118 | for edge in bin_edges: 119 | hist.bucket_limit.append(edge) 120 | for c in counts: 121 | hist.bucket.append(c) 122 | 123 | # Create and write Summary 124 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 125 | self.writer.add_summary(summary, step) 126 | self.writer.flush() 127 | 128 | def close(self): 129 | self.writer.close() -------------------------------------------------------------------------------- /lib/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count -------------------------------------------------------------------------------- /lib/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | 13 | 14 | def make_symlink_if_not_exists(real_path, link_path): 15 | ''' 16 | param real_path: str the path linked 17 | param link_path: str the path with only the symbol 18 | ''' 19 | try: 20 | os.makedirs(real_path) 21 | except OSError as e: 22 | if e.errno != errno.EEXIST: 23 | raise 24 | 25 | cmd = 'ln -s {0} {1}'.format(real_path, link_path) 26 | os.system(cmd) -------------------------------------------------------------------------------- /lib/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | 
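# Checkpoints handled here are plain dicts (e.g. {'state_dict', 'iters', 'best_res'} as
# assembled in lib/trainers.py) written with torch.save; save_checkpoint also keeps a
# 'model_best.pth.tar' copy next to `fpath` whenever `is_best` is True.
# Typical round trip (path is illustrative):
#   ckpt = load_checkpoint('logs/baseline_aster/checkpoint.pth.tar')
#   copy_state_dict(ckpt['state_dict'], model)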
import json 3 | import os 4 | import sys 5 | # import moxing as mox 6 | import os.path as osp 7 | import shutil 8 | 9 | import torch 10 | from torch.nn import Parameter 11 | 12 | from .osutils import mkdir_if_missing 13 | 14 | from config import get_args 15 | global_args = get_args(sys.argv[1:]) 16 | 17 | if global_args.run_on_remote: 18 | import moxing as mox 19 | 20 | 21 | def read_json(fpath): 22 | with open(fpath, 'r') as f: 23 | obj = json.load(f) 24 | return obj 25 | 26 | 27 | def write_json(obj, fpath): 28 | mkdir_if_missing(osp.dirname(fpath)) 29 | with open(fpath, 'w') as f: 30 | json.dump(obj, f, indent=4, separators=(',', ': ')) 31 | 32 | 33 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 34 | print('=> saving checkpoint ', fpath) 35 | if global_args.run_on_remote: 36 | dir_name = osp.dirname(fpath) 37 | if not mox.file.exists(dir_name): 38 | mox.file.make_dirs(dir_name) 39 | print('=> makding dir ', dir_name) 40 | local_path = "local_checkpoint.pth.tar" 41 | torch.save(state, local_path) 42 | mox.file.copy(local_path, fpath) 43 | if is_best: 44 | mox.file.copy(local_path, osp.join(dir_name, 'model_best.pth.tar')) 45 | else: 46 | mkdir_if_missing(osp.dirname(fpath)) 47 | torch.save(state, fpath) 48 | if is_best: 49 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 50 | 51 | 52 | def load_checkpoint(fpath): 53 | if global_args.run_on_remote: 54 | mox.file.shift('os', 'mox') 55 | checkpoint = torch.load(fpath) 56 | print("=> Loaded checkpoint '{}'".format(fpath)) 57 | return checkpoint 58 | else: 59 | load_path = fpath 60 | 61 | if osp.isfile(load_path): 62 | checkpoint = torch.load(load_path) 63 | print("=> Loaded checkpoint '{}'".format(load_path)) 64 | return checkpoint 65 | else: 66 | raise ValueError("=> No checkpoint found at '{}'".format(load_path)) 67 | 68 | 69 | def copy_state_dict(state_dict, model, strip=None): 70 | tgt_state = model.state_dict() 71 | copied_names = set() 72 | for name, param in state_dict.items(): 73 | if strip is not None and name.startswith(strip): 74 | name = name[len(strip):] 75 | if name not in tgt_state: 76 | continue 77 | if isinstance(param, Parameter): 78 | param = param.data 79 | if param.size() != tgt_state[name].size(): 80 | print('mismatch:', name, param.size(), tgt_state[name].size()) 81 | continue 82 | tgt_state[name].copy_(param) 83 | copied_names.add(name) 84 | 85 | missing = set(tgt_state.keys()) - copied_names 86 | if len(missing) > 0: 87 | print("missing keys in state_dict:", missing) 88 | 89 | return model -------------------------------------------------------------------------------- /lib/utils/visualization_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from PIL import Image 4 | import os 5 | import numpy as np 6 | from collections import OrderedDict 7 | from scipy.misc import imresize 8 | import matplotlib 9 | matplotlib.use('Agg') 10 | import matplotlib.pyplot as plt 11 | from matplotlib.gridspec import GridSpec 12 | from io import BytesIO 13 | from multiprocessing import Pool 14 | import math 15 | import sys 16 | 17 | import torch 18 | from torch.nn import functional as F 19 | 20 | from . 
import to_torch, to_numpy 21 | from ..evaluation_metrics.metrics import get_str_list 22 | 23 | 24 | def recognition_vis(images, preds, targets, scores, dataset, vis_dir): 25 | images = images.permute(0,2,3,1) 26 | images = to_numpy(images) 27 | images = (images * 0.5 + 0.5)*255 28 | pred_list, targ_list = get_str_list(preds, targets, dataset) 29 | for id, (image, pred, target, score) in enumerate(zip(images, pred_list, targ_list, scores)): 30 | if pred.lower() == target.lower(): 31 | flag = 'right' 32 | else: 33 | flag = 'error' 34 | file_name = '{:}_{:}_{:}_{:}_{:.3f}.jpg'.format(flag, id, pred, target, score) 35 | file_path = os.path.join(vis_dir, file_name) 36 | image = Image.fromarray(np.uint8(image)) 37 | image.save(file_path) 38 | 39 | 40 | # save to disk sub process 41 | def _save_plot_pool(vis_image, save_file_path): 42 | vis_image = Image.fromarray(np.uint8(vis_image)) 43 | vis_image.save(save_file_path) 44 | 45 | 46 | def stn_vis(raw_images, rectified_images, ctrl_points, preds, targets, real_scores, pred_scores, dataset, vis_dir): 47 | """ 48 | raw_images: images without rectification 49 | rectified_images: rectified images with stn 50 | ctrl_points: predicted ctrl points 51 | preds: predicted label sequences 52 | targets: target label sequences 53 | real_scores: scores of recognition model 54 | pred_scores: predicted scores by the score branch 55 | dataset: xxx 56 | vis_dir: xxx 57 | """ 58 | if raw_images.ndimension() == 3: 59 | raw_images = raw_images.unsqueeze(0) 60 | rectified_images = rectified_images.unsqueeze(0) 61 | batch_size, _, raw_height, raw_width = raw_images.size() 62 | 63 | # translate the coordinates of ctrlpoints to image size 64 | ctrl_points = to_numpy(ctrl_points) 65 | ctrl_points[:,:,0] = ctrl_points[:,:,0] * (raw_width-1) 66 | ctrl_points[:,:,1] = ctrl_points[:,:,1] * (raw_height-1) 67 | ctrl_points = ctrl_points.astype(np.int) 68 | 69 | # tensors to pil images 70 | raw_images = raw_images.permute(0,2,3,1) 71 | raw_images = to_numpy(raw_images) 72 | raw_images = (raw_images * 0.5 + 0.5)*255 73 | rectified_images = rectified_images.permute(0,2,3,1) 74 | rectified_images = to_numpy(rectified_images) 75 | rectified_images = (rectified_images * 0.5 + 0.5)*255 76 | 77 | # draw images on canvas 78 | vis_images = [] 79 | num_sub_plot = 2 80 | raw_images = raw_images.astype(np.uint8) 81 | rectified_images = rectified_images.astype(np.uint8) 82 | for i in range(batch_size): 83 | fig = plt.figure() 84 | ax = [fig.add_subplot(num_sub_plot,1,i+1) for i in range(num_sub_plot)] 85 | for a in ax: 86 | a.set_xticklabels([]) 87 | a.set_yticklabels([]) 88 | a.axis('off') 89 | ax[0].imshow(raw_images[i]) 90 | ax[0].scatter(ctrl_points[i,:,0], ctrl_points[i,:,1], marker='+', s=5) 91 | ax[1].imshow(rectified_images[i]) 92 | # plt.subplots_adjust(wspace=0, hspace=0) 93 | plt.show() 94 | buffer_ = BytesIO() 95 | plt.savefig(buffer_, format='png', bbox_inches='tight', pad_inches=0) 96 | plt.close() 97 | buffer_.seek(0) 98 | dataPIL = Image.open(buffer_) 99 | data = np.asarray(dataPIL).astype(np.uint8) 100 | buffer_.close() 101 | 102 | vis_images.append(data) 103 | 104 | # save to disk 105 | if vis_dir is None: 106 | return vis_images 107 | else: 108 | pred_list, targ_list = get_str_list(preds, targets, dataset) 109 | file_path_list = [] 110 | for id, (image, pred, target, real_score) in enumerate(zip(vis_images, pred_list, targ_list, real_scores)): 111 | if pred.lower() == target.lower(): 112 | flag = 'right' 113 | else: 114 | flag = 'error' 115 | if pred_scores is None: 116 | 
file_name = '{:}_{:}_{:}_{:}_{:.3f}.png'.format(flag, id, pred, target, real_score) 117 | else: 118 | file_name = '{:}_{:}_{:}_{:}_{:.3f}_{:.3f}.png'.format(flag, id, pred, target, real_score, pred_scores[id]) 119 | file_path = os.path.join(vis_dir, file_name) 120 | file_path_list.append(file_path) 121 | 122 | with Pool(os.cpu_count()) as pool: 123 | pool.starmap(_save_plot_pool, zip(vis_images, file_path_list)) -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import sys 3 | sys.path.append('./') 4 | 5 | import argparse 6 | import os 7 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 8 | 9 | import os.path as osp 10 | import numpy as np 11 | import math 12 | import time 13 | 14 | import torch 15 | from torch import nn, optim 16 | from torch.backends import cudnn 17 | from torch.utils.data import DataLoader, SubsetRandomSampler 18 | 19 | from config import get_args 20 | from lib import datasets, evaluation_metrics, models 21 | from lib.models.model_builder import ModelBuilder 22 | from lib.datasets.dataset import LmdbDataset, AlignCollate 23 | from lib.datasets.concatdataset import ConcatDataset 24 | from lib.loss import SequenceCrossEntropyLoss 25 | from lib.trainers import Trainer 26 | from lib.evaluators import Evaluator 27 | from lib.utils.logging import Logger, TFLogger 28 | from lib.utils.serialization import load_checkpoint, save_checkpoint 29 | from lib.utils.osutils import make_symlink_if_not_exists 30 | 31 | global_args = get_args(sys.argv[1:]) 32 | 33 | 34 | def get_data(data_dir, voc_type, max_len, num_samples, height, width, batch_size, workers, is_train, keep_ratio): 35 | if isinstance(data_dir, list): 36 | dataset_list = [] 37 | for data_dir_ in data_dir: 38 | dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples)) 39 | dataset = ConcatDataset(dataset_list) 40 | else: 41 | dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples) 42 | print('total image: ', len(dataset)) 43 | 44 | if is_train: 45 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, 46 | shuffle=True, pin_memory=True, drop_last=True, 47 | collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio)) 48 | else: 49 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=workers, 50 | shuffle=False, pin_memory=True, drop_last=False, 51 | collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio)) 52 | 53 | return dataset, data_loader 54 | 55 | 56 | def get_dataset(data_dir, voc_type, max_len, num_samples): 57 | if isinstance(data_dir, list): 58 | dataset_list = [] 59 | for data_dir_ in data_dir: 60 | dataset_list.append(LmdbDataset(data_dir_, voc_type, max_len, num_samples)) 61 | dataset = ConcatDataset(dataset_list) 62 | else: 63 | dataset = LmdbDataset(data_dir, voc_type, max_len, num_samples) 64 | print('total image: ', len(dataset)) 65 | return dataset 66 | 67 | 68 | def get_dataloader(synthetic_dataset, real_dataset, height, width, batch_size, workers, 69 | is_train, keep_ratio): 70 | num_synthetic_dataset = len(synthetic_dataset) 71 | num_real_dataset = len(real_dataset) 72 | 73 | synthetic_indices = list(np.random.permutation(num_synthetic_dataset)) 74 | synthetic_indices = synthetic_indices[num_real_dataset:] 75 | real_indices = list(np.random.permutation(num_real_dataset) + num_synthetic_dataset) 76 | concated_indices = synthetic_indices + real_indices 
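  # Keep one epoch the size of the synthetic set: drop `num_real_dataset` random
  # synthetic indices and append every real index (offset by `num_synthetic_dataset`,
  # because the real dataset is concatenated after the synthetic one below), so each real
  # sample is visited once per epoch while the overall sample count stays unchanged.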
77 | assert len(concated_indices) == num_synthetic_dataset 78 | 79 | sampler = SubsetRandomSampler(concated_indices) 80 | concated_dataset = ConcatDataset([synthetic_dataset, real_dataset]) 81 | print('total image: ', len(concated_dataset)) 82 | 83 | data_loader = DataLoader(concated_dataset, batch_size=batch_size, num_workers=workers, 84 | shuffle=False, pin_memory=True, drop_last=True, sampler=sampler, 85 | collate_fn=AlignCollate(imgH=height, imgW=width, keep_ratio=keep_ratio)) 86 | return concated_dataset, data_loader 87 | 88 | def main(args): 89 | np.random.seed(args.seed) 90 | torch.manual_seed(args.seed) 91 | torch.cuda.manual_seed(args.seed) 92 | torch.cuda.manual_seed_all(args.seed) 93 | cudnn.benchmark = True 94 | torch.backends.cudnn.deterministic = True 95 | 96 | args.cuda = args.cuda and torch.cuda.is_available() 97 | if args.cuda: 98 | print('using cuda.') 99 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 100 | else: 101 | torch.set_default_tensor_type('torch.FloatTensor') 102 | 103 | # Redirect print to both console and log file 104 | if not args.evaluate: 105 | # make symlink 106 | make_symlink_if_not_exists(osp.join(args.real_logs_dir, args.logs_dir), osp.dirname(osp.normpath(args.logs_dir))) 107 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 108 | train_tfLogger = TFLogger(osp.join(args.logs_dir, 'train')) 109 | eval_tfLogger = TFLogger(osp.join(args.logs_dir, 'eval')) 110 | 111 | # Save the args to disk 112 | if not args.evaluate: 113 | cfg_save_path = osp.join(args.logs_dir, 'cfg.txt') 114 | cfgs = vars(args) 115 | with open(cfg_save_path, 'w') as f: 116 | for k, v in cfgs.items(): 117 | f.write('{}: {}\n'.format(k, v)) 118 | 119 | # Create data loaders 120 | if args.height is None or args.width is None: 121 | args.height, args.width = (32, 100) 122 | 123 | if not args.evaluate: 124 | train_dataset, train_loader = \ 125 | get_data(args.synthetic_train_data_dir, args.voc_type, args.max_len, args.num_train, 126 | args.height, args.width, args.batch_size, args.workers, True, args.keep_ratio) 127 | test_dataset, test_loader = \ 128 | get_data(args.test_data_dir, args.voc_type, args.max_len, args.num_test, 129 | args.height, args.width, args.batch_size, args.workers, False, args.keep_ratio) 130 | 131 | if args.evaluate: 132 | max_len = test_dataset.max_len 133 | else: 134 | max_len = max(train_dataset.max_len, test_dataset.max_len) 135 | train_dataset.max_len = test_dataset.max_len = max_len 136 | # Create model 137 | model = ModelBuilder(arch=args.arch, rec_num_classes=test_dataset.rec_num_classes, 138 | sDim=args.decoder_sdim, attDim=args.attDim, max_len_labels=max_len, 139 | eos=test_dataset.char2id[test_dataset.EOS], STN_ON=args.STN_ON) 140 | 141 | # Load from checkpoint 142 | if args.evaluation_metric == 'accuracy': 143 | best_res = 0 144 | elif args.evaluation_metric == 'editdistance': 145 | best_res = math.inf 146 | else: 147 | raise ValueError("Unsupported evaluation metric:", args.evaluation_metric) 148 | start_epoch = 0 149 | start_iters = 0 150 | if args.resume: 151 | checkpoint = load_checkpoint(args.resume) 152 | model.load_state_dict(checkpoint['state_dict']) 153 | 154 | # compatibility with the epoch-wise evaluation version 155 | if 'epoch' in checkpoint.keys(): 156 | start_epoch = checkpoint['epoch'] 157 | else: 158 | start_iters = checkpoint['iters'] 159 | start_epoch = int(start_iters // len(train_loader)) if not args.evaluate else 0 160 | best_res = checkpoint['best_res'] 161 | print("=> Start iters {} best res {:.1%}" 162 | 
.format(start_iters, best_res)) 163 | 164 | if args.cuda: 165 | device = torch.device("cuda") 166 | model = model.to(device) 167 | model = nn.DataParallel(model) 168 | 169 | # Evaluator 170 | evaluator = Evaluator(model, args.evaluation_metric, args.cuda) 171 | 172 | if args.evaluate: 173 | print('Test on {0}:'.format(args.test_data_dir)) 174 | if len(args.vis_dir) > 0: 175 | vis_dir = osp.join(args.logs_dir, args.vis_dir) 176 | if not osp.exists(vis_dir): 177 | os.makedirs(vis_dir) 178 | else: 179 | vis_dir = None 180 | 181 | start = time.time() 182 | evaluator.evaluate(test_loader, dataset=test_dataset, vis_dir=vis_dir) 183 | print('it took {0} s.'.format(time.time() - start)) 184 | return 185 | 186 | # Optimizer 187 | param_groups = model.parameters() 188 | param_groups = filter(lambda p: p.requires_grad, param_groups) 189 | optimizer = optim.Adadelta(param_groups, lr=args.lr, weight_decay=args.weight_decay) 190 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[4,5], gamma=0.1) 191 | 192 | # Trainer 193 | loss_weights = {} 194 | loss_weights['loss_rec'] = 1. 195 | if args.debug: 196 | args.print_freq = 1 197 | trainer = Trainer(model, args.evaluation_metric, args.logs_dir, 198 | iters=start_iters, best_res=best_res, grad_clip=args.grad_clip, 199 | use_cuda=args.cuda, loss_weights=loss_weights) 200 | 201 | # Start training 202 | evaluator.evaluate(test_loader, step=0, tfLogger=eval_tfLogger, dataset=test_dataset) 203 | for epoch in range(start_epoch, args.epochs): 204 | scheduler.step(epoch) 205 | current_lr = optimizer.param_groups[0]['lr'] 206 | trainer.train(epoch, train_loader, optimizer, current_lr, 207 | print_freq=args.print_freq, 208 | train_tfLogger=train_tfLogger, 209 | is_debug=args.debug, 210 | evaluator=evaluator, 211 | test_loader=test_loader, 212 | eval_tfLogger=eval_tfLogger, 213 | test_dataset=test_dataset) 214 | 215 | # Final test 216 | print('Test with best model:') 217 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar')) 218 | model.module.load_state_dict(checkpoint['state_dict']) 219 | evaluator.evaluate(test_loader, dataset=test_dataset) 220 | 221 | # Close the tensorboard logger 222 | train_tfLogger.close() 223 | eval_tfLogger.close() 224 | 225 | 226 | if __name__ == '__main__': 227 | # parse the config 228 | args = get_args(sys.argv[1:]) 229 | main(args) -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ayumiymk/aster.pytorch/be670046c775b54de79766208f0c59321ae1eccf/overview.png -------------------------------------------------------------------------------- /scripts/main_test_all.sh: -------------------------------------------------------------------------------- 1 | array=( 2 | "IIIT5K_3000" 3 | "svt_647" 4 | "ic03_867" 5 | "ic13_1015" 6 | "ic15_1811" 7 | "svt_p_645" 8 | "cute80_288") 9 | 10 | DLS_DATA_URL="xxx" 11 | DLS_TRAIN_URL="" 12 | for i in "${array[@]}" 13 | do 14 | echo $i 15 | CUDA_VISIBLE_DEVICES=0,1 python main.py \ 16 | --synthetic_train_data_dir xxx \ 17 | --test_data_dir ${DLS_DATA_URL}$i \ 18 | --batch_size 1024 \ 19 | --workers 8 \ 20 | --height 64 \ 21 | --width 256 \ 22 | --voc_type ALLCASES_SYMBOLS \ 23 | --arch ResNet_ASTER \ 24 | --with_lstm \ 25 | --logs_dir logs/baseline_aster \ 26 | --real_logs_dir xxx \ 27 | --max_len 100 \ 28 | --evaluate \ 29 | --STN_ON \ 30 | --beam_width 5 \ 31 | --tps_inputsize 32 64 \ 32 | --tps_outputsize 32 100 
\ 33 | --tps_margins 0.05 0.05 \ 34 | --stn_activation none \ 35 | --num_control_points 20 \ 36 | --resume xxx 37 | done -------------------------------------------------------------------------------- /scripts/main_test_image.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python demo.py \ 2 | --height 64 \ 3 | --width 256 \ 4 | --voc_type ALLCASES_SYMBOLS \ 5 | --arch ResNet_ASTER \ 6 | --with_lstm \ 7 | --max_len 100 \ 8 | --STN_ON \ 9 | --beam_width 5 \ 10 | --tps_inputsize 32 64 \ 11 | --tps_outputsize 32 100 \ 12 | --tps_margins 0.05 0.05 \ 13 | --stn_activation none \ 14 | --num_control_points 20 \ 15 | --resume /data/mkyang/logs/recognition/aster.pytorch/logs/baseline_aster/baseline_aster/demo.pth.tar \ 16 | --image_path ./data/demo.png -------------------------------------------------------------------------------- /scripts/stn_att_rec.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1 python main.py \ 2 | --synthetic_train_data_dir /data/mkyang/scene_text/recognition/CVPR2016/ /data/mkyang/scene_text/recognition/NIPS2014/ \ 3 | --test_data_dir /data/mkyang/scene_text/recognition/benchmark_lmdbs_new/IIIT5K_3000/ \ 4 | --batch_size 1024 \ 5 | --workers 8 \ 6 | --height 64 \ 7 | --width 256 \ 8 | --voc_type ALLCASES_SYMBOLS \ 9 | --arch ResNet_ASTER \ 10 | --with_lstm \ 11 | --logs_dir logs/baseline_aster \ 12 | --real_logs_dir /data/mkyang/logs/recognition/aster.pytorch \ 13 | --max_len 100 \ 14 | --STN_ON \ 15 | --tps_inputsize 32 64 \ 16 | --tps_outputsize 32 100 \ 17 | --tps_margins 0.05 0.05 \ 18 | --stn_activation none \ 19 | --num_control_points 20 \ --------------------------------------------------------------------------------