├── .gitignore
├── LICENSE
├── README.md
├── alphabets.py
├── dataset.py
├── demo.py
├── demo
│   ├── demo.jpg
│   └── illegal_character.png
├── models
│   ├── __init__.py
│   └── crnn.py
├── params.py
├── tool
│   ├── convert_t7.lua
│   ├── convert_t7.py
│   └── create_dataset.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | *.pth
3 | *.pyc
4 | *.pyo
5 | *.log
6 | *.tmp
7 | *.swp
8 | *.out
9 | *.mdb
10 | *__pycache__/
11 | *.vscode/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Holmeyoung
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Convolutional Recurrent Neural Network + CTCLoss
2 |
3 | I think I have fixed the CTCLoss `nan` problem!
4 |
5 | Now!
6 |
7 | Please pull the latest code from master.
8 |
9 | Please update PyTorch to `>= v1.2.0`
10 |
11 | Enjoy it!
12 |
13 | > PS: If the CTCLoss still becomes `nan`, please
14 | > 1. Reduce the `batchSize` (e.g. 8, 16, 32)
15 | > 2. Reduce the `lr` (e.g. 0.00001, 0.0001)
16 | > 3. Contact me by email at holmeyoung@gmail.com
17 |
18 | ## Dependencies
19 |
20 | - CentOS 7
21 | - Python 3.6.5
22 | - torch==1.2.0
23 | - torchvision==0.4.0
24 | - Tesla P40 - Nvidia
25 |
26 | ## Run demo
27 |
28 | - Download a pretrained model from [Baidu Cloud](https://pan.baidu.com/s/1FmJhYf1Wy-LUaz4V2WpF7g) (extraction code: `si32`)
29 | - People who cannot access Baidu can download a copy from [Google Drive](https://drive.google.com/drive/folders/1FhXvPtitX6tWYocFZiZBRzVHjK2o640u?usp=sharing)
30 |
31 | - Run demo
32 |
33 | ```sh
34 | python demo.py -m path/to/model -i demo/demo.jpg
35 | ```
36 |
37 | ![demo](https://raw.githubusercontent.com/Holmeyoung/crnn_pytorch/master/demo/demo.jpg)
38 |
39 | Expected output
40 |
41 | ```sh
42 | -妳----真---的的---可---------以 => 妳真的可以
43 | ```
44 |
45 |
46 |
47 | ## Features
48 |
49 | - Variable length
50 |
51 |   It supports variable-length input.
52 |
53 |
54 |
55 | - Chinese support
56 |
57 |   I read the keys and values in binary mode, so you can use it for Chinese OCR.
58 |
59 |
60 |
61 | - Change CTCLoss from [warp-ctc](https://github.com/SeanNaren/warp-ctc) to [torch.nn.CTCLoss](https://pytorch.org/docs/stable/nn.html#ctcloss)
62 |
63 |   As we know, warp-ctc needs to be compiled, and it seems to support only PyTorch 0.4. But PyTorch now ships CTCLoss itself, so I changed the loss function to `torch.nn.CTCLoss`.
64 |
65 |
66 |
67 | - Solved: PyTorch CTCLoss becomes `nan` after several epochs
68 |
69 |   I don't know why, but when I trained the net, the loss always became `nan` after several epochs.
70 |
71 |   I added a param `dealwith_lossnan` to `params.py`. If you set it to `True`, the net will automatically check for and replace all `nan/inf` values in the gradients with zero.
72 |
73 |
74 |
75 | - DataParallel
76 |
77 |   I added a param `multi_gpu` to `params.py`. If you want to train your net on multiple GPUs, set it to `True` and set the param `ngpu` to a proper number.
78 |
79 |
80 |
81 | ## Train your data
82 |
83 | ### Prepare data
84 |
85 | #### Folder mode
86 |
87 | 1. Put your images in a folder and name them in the following format:
88 |
89 |    `label_number.jpg`
90 |
91 |    For example
92 |
93 |    - English
94 |
95 |    ```sh
96 |    hi_0.jpg hello_1.jpg English_2.jpg English_3.jpg E n g l i s h_4.jpg...
97 |    ```
98 |
99 |    - Chinese
100 |
101 |    ```sh
102 |    一身转战_0.jpg 三千里_1.jpg 一剑曾当百万师_2.jpg 一剑曾当百万师_3.jpg 一 剑 曾 当 百 万 师_3.jpg ...
103 |    ```
104 |
105 |    So you can see, the number is used to distinguish images that share the same label.
106 |
107 |
108 |
109 | 2. Run `create_dataset.py` in the `tool` folder by
110 |
111 |    ```sh
112 |    python tool/create_dataset.py --out lmdb/data/output/path --folder path/to/folder
113 |    ```
114 |
115 |
116 |
117 | 3. Use the same steps to create the train and val data.
118 |
119 |
120 |
121 | 4. The advantage of the folder mode is that it's convenient! But some characters are illegal in file paths:
122 |
123 |    ![Illegal character](https://raw.githubusercontent.com/Holmeyoung/crnn_pytorch/master/demo/illegal_character.png)
124 |
125 |    So the disadvantage of the folder mode is that its labels are limited.
126 |
127 |
128 |
129 | #### File mode
130 |
131 | 1. Your data file should look like:
132 |
133 |    ```sh
134 |    absolute/path/to/image/一身转战_0.jpg
135 |    一身转战
136 |    absolute/path/to/image/三千里_1.jpg
137 |    三千里
138 |    absolute/path/to/image/一剑曾当百万师_2.jpg
139 |    一剑曾当百万师
140 |    absolute/path/to/image/3.jpg
141 |    一剑曾当百万师
142 |    absolute/path/to/image/一 剑 曾 当 百 万 师_4.jpg
143 |    一 剑 曾 当 百 万 师
144 |    absolute/path/to/image/xxx.jpg
145 |    label of xxx.jpg
146 |    .
147 |    .
148 |    .
149 |    ```
150 |
151 |    > DO REMEMBER:
152 |    >
153 |    > 1. Each path must be the absolute path to the image.
154 |    > 2. The first line can't be empty.
155 |    > 3. There must be no blank lines between two records.
156 |
157 |
158 |
159 | 2. Run `create_dataset.py` in the `tool` folder by
160 |
161 |    ```sh
162 |    python tool/create_dataset.py --out lmdb/data/output/path --file path/to/file
163 |    ```
164 |
165 |
166 |
167 | 3. Use the same steps to create the train and val data.
168 |
169 |
170 |
171 | ### Change parameters and alphabets
172 |
173 | Parameters and alphabets won't always be the same in different situations.
174 |
175 | - Change parameters
176 |
177 |   You can see the details in `params.py`.
178 |
179 | - Change alphabets
180 |
181 |   Please put every character that appears in your labels into `alphabets.py`, or the program will throw an error during training. A helper sketch follows below.
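For example, here is a minimal sketch (not part of this repo; the file paths are placeholders) that collects every character from a file-mode label file and writes an `alphabets.py` in the same one-character-per-line format used here:

```python
# collect_alphabet.py -- hypothetical helper, assuming the file-mode layout
# above: odd lines (1st, 3rd, ...) are image paths, even lines are labels
chars = set()
with open('path/to/file', encoding='utf-8') as f:
    for i, line in enumerate(f):
        if i % 2 == 1:  # 0-based: lines 1, 3, 5, ... hold the labels
            chars.update(line.rstrip('\r\n'))

with open('alphabets.py', 'w', encoding='utf-8') as f:
    # one character per line, mirroring the existing alphabets.py format
    f.write('alphabet = """%s"""\n' % '\n'.join(sorted(chars)))
```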
182 | 183 | 184 | 185 | ### Train 186 | 187 | Run `train.py` by 188 | 189 | ```sh 190 | python train.py --trainroot path/to/train/dataset --valroot path/to/val/dataset 191 | ``` 192 | 193 | 194 | 195 | ## Reference 196 | 197 | [meijieru/crnn.pytorch]() 198 | 199 | [Sierkinhane/crnn_chinese_characters_rec]() 200 | 201 | -------------------------------------------------------------------------------- /alphabets.py: -------------------------------------------------------------------------------- 1 | alphabet = """的 2 | 是 3 | 不 4 | 我 5 | 一 6 | 有 7 | 大 8 | 在 9 | 人 10 | 了 11 | 中 12 | 到 13 | 資 14 | 要 15 | 可 16 | 以 17 | 這 18 | 個 19 | 你 20 | 會 21 | 好 22 | 為 23 | 上 24 | 來 25 | 就 26 | 學 27 | 交 28 | 也 29 | 用 30 | 能 31 | 如 32 | 文 33 | 時 34 | 沒 35 | 說 36 | 他 37 | 看 38 | 提 39 | 那 40 | 問 41 | 生 42 | 過 43 | 下 44 | 請 45 | 天 46 | 們 47 | 所 48 | 多 49 | 麼 50 | 小 51 | 想 52 | 得 53 | 之 54 | 還 55 | 電 56 | 出 57 | 工 58 | 對 59 | 都 60 | 機 61 | 自 62 | 後 63 | 子 64 | 而 65 | 訊 66 | 站 67 | 去 68 | 心 69 | 只 70 | 家 71 | 知 72 | 國 73 | 台 74 | 很 75 | 信 76 | 成 77 | 章 78 | 何 79 | 同 80 | 道 81 | 地 82 | 發 83 | 法 84 | 無 85 | 然 86 | 但 87 | 嗎 88 | 當 89 | 於 90 | 本 91 | 現 92 | 年 93 | 前 94 | 真 95 | 最 96 | 和 97 | 新 98 | 因 99 | 果 100 | 定 101 | 意 102 | 情 103 | 點 104 | 題 105 | 其 106 | 事 107 | 方 108 | 清 109 | 科 110 | 樣 111 | 些 112 | 吧 113 | 三 114 | 此 115 | 位 116 | 理 117 | 行 118 | 作 119 | 經 120 | 者 121 | 什 122 | 謝 123 | 名 124 | 日 125 | 正 126 | 華 127 | 話 128 | 開 129 | 實 130 | 再 131 | 城 132 | 愛 133 | 與 134 | 二 135 | 動 136 | 比 137 | 高 138 | 面 139 | 又 140 | 車 141 | 力 142 | 或 143 | 種 144 | 像 145 | 應 146 | 女 147 | 教 148 | 分 149 | 手 150 | 打 151 | 已 152 | 次 153 | 長 154 | 太 155 | 明 156 | 己 157 | 路 158 | 起 159 | 相 160 | 主 161 | 關 162 | 鳳 163 | 間 164 | 呢 165 | 覺 166 | 該 167 | 十 168 | 外 169 | 凰 170 | 友 171 | 才 172 | 民 173 | 系 174 | 進 175 | 使 176 | 她 177 | 著 178 | 各 179 | 少 180 | 全 181 | 兩 182 | 回 183 | 加 184 | 將 185 | 感 186 | 第 187 | 性 188 | 球 189 | 式 190 | 把 191 | 被 192 | 老 193 | 公 194 | 龍 195 | 程 196 | 論 197 | 及 198 | 別 199 | 給 200 | 聽 201 | 水 202 | 重 203 | 體 204 | 做 205 | 校 206 | 裡 207 | 常 208 | 東 209 | 風 210 | 您 211 | 灣 212 | 啦 213 | 見 214 | 解 215 | 等 216 | 部 217 | 原 218 | 月 219 | 美 220 | 先 221 | 管 222 | 區 223 | 錯 224 | 音 225 | 否 226 | 啊 227 | 找 228 | 網 229 | 樂 230 | 讓 231 | 通 232 | 入 233 | 期 234 | 選 235 | 較 236 | 四 237 | 場 238 | 由 239 | 書 240 | 它 241 | 快 242 | 從 243 | 歡 244 | 數 245 | 表 246 | 怎 247 | 至 248 | 立 249 | 內 250 | 合 251 | 目 252 | 望 253 | 認 254 | 幾 255 | 社 256 | 告 257 | 更 258 | 版 259 | 度 260 | 考 261 | 喜 262 | 頭 263 | 難 264 | 光 265 | 買 266 | 今 267 | 身 268 | 許 269 | 弟 270 | 若 271 | 算 272 | 記 273 | 代 274 | 統 275 | 處 276 | 完 277 | 號 278 | 接 279 | 言 280 | 政 281 | 玩 282 | 師 283 | 字 284 | 並 285 | 男 286 | 計 287 | 誰 288 | 山 289 | 張 290 | 黨 291 | 每 292 | 且 293 | 結 294 | 改 295 | 非 296 | 星 297 | 連 298 | 哈 299 | 建 300 | 放 301 | 直 302 | 轉 303 | 報 304 | 活 305 | 設 306 | 變 307 | 指 308 | 氣 309 | 研 310 | 陳 311 | 試 312 | 西 313 | 五 314 | 希 315 | 取 316 | 神 317 | 化 318 | 物 319 | 王 320 | 戰 321 | 近 322 | 世 323 | 受 324 | 義 325 | 反 326 | 單 327 | 死 328 | 任 329 | 跟 330 | 便 331 | 空 332 | 林 333 | 士 334 | 臺 335 | 卻 336 | 北 337 | 隊 338 | 功 339 | 必 340 | 聲 341 | 寫 342 | 平 343 | 影 344 | 業 345 | 金 346 | 檔 347 | 片 348 | 討 349 | 色 350 | 容 351 | 央 352 | 妳 353 | 向 354 | 市 355 | 則 356 | 員 357 | 興 358 | 利 359 | 強 360 | 白 361 | 價 362 | 安 363 | 呵 364 | 特 365 | 思 366 | 叫 367 | 總 368 | 辦 369 | 保 370 | 花 371 | 議 372 | 傳 373 | 元 374 | 求 375 | 份 376 | 件 377 | 持 378 | 萬 379 | 未 380 | 究 381 | 決 382 | 投 383 | 哪 384 | 喔 385 | 笑 386 | 貓 387 | 組 388 | 獨 389 | 級 390 | 走 391 | 支 392 | 曾 393 | 標 394 | 
流 395 | 竹 396 | 兄 397 | 阿 398 | 室 399 | 卡 400 | 馬 401 | 共 402 | 需 403 | 海 404 | 口 405 | 門 406 | 般 407 | 線 408 | 語 409 | 命 410 | 觀 411 | 視 412 | 朋 413 | 聯 414 | 參 415 | 格 416 | 黃 417 | 錢 418 | 修 419 | 失 420 | 兒 421 | 住 422 | 八 423 | 腦 424 | 板 425 | 吃 426 | 另 427 | 換 428 | 即 429 | 象 430 | 料 431 | 錄 432 | 拿 433 | 專 434 | 遠 435 | 速 436 | 基 437 | 幫 438 | 形 439 | 確 440 | 候 441 | 裝 442 | 孩 443 | 備 444 | 歌 445 | 界 446 | 除 447 | 南 448 | 器 449 | 畫 450 | 訴 451 | 差 452 | 講 453 | 類 454 | 英 455 | 案 456 | 帶 457 | 久 458 | 乎 459 | 掉 460 | 迷 461 | 量 462 | 引 463 | 整 464 | 似 465 | 耶 466 | 奇 467 | 制 468 | 邊 469 | 型 470 | 超 471 | 識 472 | 雖 473 | 怪 474 | 飛 475 | 始 476 | 品 477 | 運 478 | 賽 479 | 費 480 | 夢 481 | 故 482 | 班 483 | 權 484 | 破 485 | 驗 486 | 眼 487 | 滿 488 | 念 489 | 造 490 | 軍 491 | 精 492 | 務 493 | 留 494 | 服 495 | 六 496 | 圖 497 | 收 498 | 舍 499 | 半 500 | 讀 501 | 願 502 | 李 503 | 底 504 | 約 505 | 雄 506 | 課 507 | 答 508 | 令 509 | 深 510 | 票 511 | 達 512 | 演 513 | 早 514 | 賣 515 | 棒 516 | 夠 517 | 黑 518 | 院 519 | 假 520 | 曲 521 | 火 522 | 準 523 | 百 524 | 談 525 | 勝 526 | 碟 527 | 術 528 | 推 529 | 存 530 | 治 531 | 離 532 | 易 533 | 往 534 | 況 535 | 晚 536 | 示 537 | 證 538 | 段 539 | 導 540 | 傷 541 | 調 542 | 團 543 | 七 544 | 永 545 | 剛 546 | 哥 547 | 甚 548 | 德 549 | 殺 550 | 怕 551 | 包 552 | 列 553 | 概 554 | 照 555 | 夜 556 | 排 557 | 客 558 | 絕 559 | 軟 560 | 商 561 | 根 562 | 九 563 | 切 564 | 條 565 | 集 566 | 千 567 | 落 568 | 竟 569 | 越 570 | 待 571 | 忘 572 | 盡 573 | 據 574 | 雙 575 | 供 576 | 稱 577 | 座 578 | 值 579 | 消 580 | 產 581 | 紅 582 | 跑 583 | 嘛 584 | 園 585 | 附 586 | 硬 587 | 雲 588 | 遊 589 | 展 590 | 執 591 | 聞 592 | 唱 593 | 育 594 | 斯 595 | 某 596 | 技 597 | 唉 598 | 息 599 | 苦 600 | 質 601 | 油 602 | 救 603 | 效 604 | 須 605 | 介 606 | 首 607 | 助 608 | 職 609 | 例 610 | 熱 611 | 畢 612 | 節 613 | 害 614 | 擊 615 | 亂 616 | 態 617 | 嗯 618 | 寶 619 | 倒 620 | 注 621 | 停 622 | 古 623 | 輸 624 | 規 625 | 福 626 | 親 627 | 查 628 | 復 629 | 步 630 | 舉 631 | 魚 632 | 斷 633 | 終 634 | 輕 635 | 環 636 | 練 637 | 印 638 | 隨 639 | 依 640 | 趣 641 | 限 642 | 響 643 | 省 644 | 局 645 | 續 646 | 司 647 | 角 648 | 簡 649 | 極 650 | 幹 651 | 篇 652 | 羅 653 | 佛 654 | 克 655 | 陽 656 | 武 657 | 疑 658 | 送 659 | 拉 660 | 習 661 | 源 662 | 免 663 | 志 664 | 鳥 665 | 煩 666 | 足 667 | 館 668 | 仍 669 | 低 670 | 廣 671 | 土 672 | 呀 673 | 樓 674 | 壞 675 | 兵 676 | 顯 677 | 率 678 | 聖 679 | 碼 680 | 眾 681 | 爭 682 | 初 683 | 誤 684 | 楚 685 | 責 686 | 境 687 | 野 688 | 預 689 | 具 690 | 智 691 | 壓 692 | 係 693 | 青 694 | 貴 695 | 順 696 | 負 697 | 魔 698 | 適 699 | 哇 700 | 測 701 | 慢 702 | 懷 703 | 懂 704 | 史 705 | 配 706 | 嗚 707 | 味 708 | 亦 709 | 醫 710 | 迎 711 | 舞 712 | 戀 713 | 細 714 | 灌 715 | 甲 716 | 帝 717 | 句 718 | 屬 719 | 靈 720 | 評 721 | 騎 722 | 宜 723 | 敗 724 | 左 725 | 追 726 | 狂 727 | 敢 728 | 春 729 | 狗 730 | 際 731 | 遇 732 | 族 733 | 群 734 | 痛 735 | 右 736 | 康 737 | 佳 738 | 楊 739 | 木 740 | 病 741 | 戲 742 | 項 743 | 抓 744 | 徵 745 | 善 746 | 官 747 | 護 748 | 博 749 | 補 750 | 石 751 | 爾 752 | 營 753 | 歷 754 | 隻 755 | 按 756 | 妹 757 | 里 758 | 編 759 | 歲 760 | 擇 761 | 溫 762 | 守 763 | 血 764 | 領 765 | 尋 766 | 田 767 | 養 768 | 謂 769 | 居 770 | 異 771 | 雨 772 | 止 773 | 跳 774 | 君 775 | 爛 776 | 優 777 | 封 778 | 拜 779 | 惡 780 | 啥 781 | 浪 782 | 核 783 | 聊 784 | 急 785 | 狀 786 | 陸 787 | 激 788 | 模 789 | 攻 790 | 忙 791 | 良 792 | 劇 793 | 牛 794 | 壘 795 | 增 796 | 維 797 | 靜 798 | 陣 799 | 抱 800 | 勢 801 | 嚴 802 | 詞 803 | 亞 804 | 夫 805 | 簽 806 | 悲 807 | 密 808 | 幕 809 | 毒 810 | 廠 811 | 爽 812 | 緣 813 | 店 814 | 吳 815 | 蘭 816 | 睡 817 | 致 818 | 江 819 | 宿 820 | 翻 821 | 香 822 | 蠻 823 | 警 824 | 控 825 | 趙 826 | 冷 827 | 威 828 | 微 829 | 坐 830 | 週 831 | 宗 832 | 普 833 | 登 834 | 母 835 | 絡 836 | 午 837 | 恐 838 | 套 
839 | 巴
840 | 雜
841 | 創
842 | 舊
843 | 輯
844 | 幸
845 | 劍
846 | 亮
847 | 述
848 | 堂
849 | 酒
850 | 麗
851 | 牌
852 | 仔
853 | 腳
854 | 突
855 | 搞
856 | 父
857 | 俊
858 | 暴
859 | 防
860 | 吉
861 | 禮
862 | 素
863 | 招
864 | 草
865 | 周
866 | 房
867 | 餐
868 | 慮
869 | 充
870 | 府
871 | 背
872 | 典
873 | 仁
874 | 漫
875 | 景
876 | 紹
877 | 諸
878 | 琴
879 | 憶
880 | 援
881 | 尤
882 | 缺
883 | 扁
884 | 罵
885 | 純
886 | 惜
887 | 授
888 | 皮
889 | 松
890 | 委
891 | 湖
892 | 誠
893 | 麻
894 | 置
895 | 靠
896 | 繼
897 | 判
898 | 益
899 | 波
900 | 姐
901 | 既
902 | 射
903 | 欲
904 | 刻
905 | 堆
906 | 釋
907 | 含
908 | 承
909 | 退
910 | 莫
911 | 劉
912 | 昨
913 | 旁
914 | 紀
915 | 趕
916 | 製
917 | 尚
918 | 藝
919 | 肉
920 | 律
921 | 鐵
922 | 奏
923 | 樹
924 | 毛
925 | 罪
926 | 筆
927 | 彩
928 | 註
929 | 歸
930 | 彈
931 | 虎
932 | 衛
933 | 刀
934 | 皆
935 | 鍵
936 | 售
937 | 塊
938 | 險
939 | 榮
940 | 播
941 | 施
942 | 銘
943 | 囉
944 | 漢
945 | 賞
946 | 欣
947 | 升
948 | 葉
949 | 螢
950 | 載
951 | 嘿
952 | 弄
953 | 鐘
954 | 付
955 | 寄
956 | 鬼
957 | 哦
958 | 燈
959 | 呆
960 | 洋
961 | 嘻
962 | 布
963 | 磁
964 | 薦
965 | 檢
966 | 派
967 | 構
968 | 媽
969 | 藍
970 | 貼
971 | 豬
972 | 策
973 | 紙
974 | 暗
975 | 巧
976 | 努
977 | 雷
978 | 架
979 | 享
980 | 宣
981 | 逢
982 | 均
983 | 擔
984 | 啟
985 | 濟
986 | 罷
987 | 呼
988 | 劃
989 | 偉
990 | 島
991 | 歉
992 | 郭
993 | 訓
994 | 穿
995 | 詳
996 | 沙
997 | 督
998 | 梅
999 | 顧
1000 | 敵"""
1001 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # encoding: utf-8
3 |
4 | import random
5 | import torch
6 | from torch.utils.data import Dataset
7 | from torch.utils.data import sampler
8 | import torchvision.transforms as transforms
9 | import lmdb
10 | import six
11 | import sys
12 | from PIL import Image
13 | import numpy as np
14 |
15 |
16 | class lmdbDataset(Dataset):
17 |
18 |     def __init__(self, root=None, transform=None, target_transform=None):
19 |         self.env = lmdb.open(
20 |             root,
21 |             max_readers=1,
22 |             readonly=True,
23 |             lock=False,
24 |             readahead=False,
25 |             meminit=False)
26 |
27 |         if not self.env:
28 |             print('cannot create lmdb from %s' % root)
29 |             sys.exit(0)
30 |
31 |         with self.env.begin(write=False) as txn:
32 |             nSamples = int(txn.get('num-samples'.encode('utf-8')))
33 |             self.nSamples = nSamples
34 |
35 |         self.transform = transform
36 |         self.target_transform = target_transform
37 |
38 |     def __len__(self):
39 |         return self.nSamples
40 |
41 |     def __getitem__(self, index):
42 |         assert index <= len(self), 'index range error'
43 |         index += 1  # lmdb keys are 1-based
44 |         with self.env.begin(write=False) as txn:
45 |             img_key = 'image-%09d' % index
46 |             imgbuf = txn.get(img_key.encode('utf-8'))
47 |
48 |             buf = six.BytesIO()
49 |             buf.write(imgbuf)
50 |             buf.seek(0)
51 |             try:
52 |                 img = Image.open(buf).convert('L')
53 |             except IOError:
54 |                 print('Corrupted image for %d' % index)
55 |                 return self[index + 1]
56 |
57 |             if self.transform is not None:
58 |                 img = self.transform(img)
59 |
60 |             label_key = 'label-%09d' % index
61 |             label = txn.get(label_key.encode('utf-8'))
62 |
63 |             if self.target_transform is not None:
64 |                 label = self.target_transform(label)
65 |
66 |         return (img, label)
67 |
68 |
69 | class resizeNormalize(object):
70 |
71 |     def __init__(self, size, interpolation=Image.BILINEAR):
72 |         self.size = size
73 |         self.interpolation = interpolation
74 |         self.toTensor = transforms.ToTensor()
75 |
76 |     def __call__(self, img):
77 |         img = img.resize(self.size, self.interpolation)
78 |         img = self.toTensor(img)
79 |         img.sub_(0.5).div_(0.5)  # normalize from [0, 1] to [-1, 1]
80 |         return img
81 |
82 |
83 | class randomSequentialSampler(sampler.Sampler):
84 |
85 |     def __init__(self, data_source, batch_size):
86 |         self.num_samples = len(data_source)
87 |         self.batch_size = batch_size
88 |
89 |     def __iter__(self):
90 |         n_batch = len(self) // self.batch_size
91 |         tail = len(self) % self.batch_size
92 |         index = torch.LongTensor(len(self)).fill_(0)
93 |         for i in range(n_batch):
94 |             random_start = random.randint(0, len(self) - self.batch_size)
95 |             batch_index = random_start + torch.arange(0, self.batch_size)  # torch.arange replaces the deprecated torch.range
96 |             index[i * self.batch_size:(i + 1) * self.batch_size] = batch_index
97 |         # deal with tail
98 |         if tail:  # assumes len(data_source) >= batch_size
99 |             random_start = random.randint(0, len(self) - self.batch_size)
100 |             tail_index = random_start + torch.arange(0, tail)
101 |             index[(i + 1) * self.batch_size:] = tail_index
102 |
103 |         return iter(index)
104 |
105 |     def __len__(self):
106 |         return self.num_samples
107 |
108 |
109 | class alignCollate(object):
110 |
111 |     def __init__(self, imgH=32, imgW=100, keep_ratio=False, min_ratio=1):
112 |         self.imgH = imgH
113 |         self.imgW = imgW
114 |         self.keep_ratio = keep_ratio
115 |         self.min_ratio = min_ratio
116 |
117 |     def __call__(self, batch):
118 |         images, labels = zip(*batch)
119 |
120 |         imgH = self.imgH
121 |         imgW = self.imgW
122 |         if self.keep_ratio:
123 |             ratios = []
124 |             for image in images:
125 |                 w, h = image.size
126 |                 ratios.append(w / float(h))
127 |             ratios.sort()
128 |             max_ratio = ratios[-1]
129 |             imgW = int(np.floor(max_ratio * imgH))
130 |             imgW = max(imgH * self.min_ratio, imgW)  # assure imgW >= imgH * min_ratio
131 |
132 |         transform = resizeNormalize((imgW, imgH))
133 |         images = [transform(image) for image in images]
134 |         images = torch.cat([t.unsqueeze(0) for t in images], 0)
135 |
136 |         return images, labels
137 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import utils
4 | import dataset
5 | from PIL import Image
6 |
7 | import models.crnn as crnn
8 | import params
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument('-m', '--model_path', type = str, required = True, help = 'crnn model path')
13 | parser.add_argument('-i', '--image_path', type = str, required = True, help = 'demo image path')
14 | args = parser.parse_args()
15 |
16 | model_path = args.model_path
17 | image_path = args.image_path
18 |
19 | # net init
20 | nclass = len(params.alphabet) + 1
21 | model = crnn.CRNN(params.imgH, params.nc, nclass, params.nh)
22 | if torch.cuda.is_available():
23 |     model = model.cuda()
24 |
25 | # load model
26 | print('loading pretrained model from %s' % model_path)
27 | if params.multi_gpu:
28 |     model = torch.nn.DataParallel(model)
29 | model.load_state_dict(torch.load(model_path))
30 |
31 | converter = utils.strLabelConverter(params.alphabet)
32 |
33 | transformer = dataset.resizeNormalize((100, 32))
34 | image = Image.open(image_path).convert('L')
35 | image = transformer(image)
36 | if torch.cuda.is_available():
37 |     image = image.cuda()
38 | image = image.view(1, *image.size())
39 | image = Variable(image)
40 |
41 | model.eval()
42 | preds = model(image)
43 |
44 | _, preds = preds.max(2)
45 | preds = preds.transpose(1, 0).contiguous().view(-1)
46 |
47 | preds_size = Variable(torch.LongTensor([preds.size(0)]))
48 | raw_pred = converter.decode(preds.data, preds_size.data, raw=True)
49 | sim_pred = converter.decode(preds.data, preds_size.data, raw=False)
50 | print('%-20s => %-20s' % (raw_pred, sim_pred))
51 |
--------------------------------------------------------------------------------
/demo/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Holmeyoung/crnn-pytorch/dc2f21afc14464e57add768956ed5aaa72983a51/demo/demo.jpg
--------------------------------------------------------------------------------
/demo/illegal_character.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Holmeyoung/crnn-pytorch/dc2f21afc14464e57add768956ed5aaa72983a51/demo/illegal_character.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Holmeyoung/crnn-pytorch/dc2f21afc14464e57add768956ed5aaa72983a51/models/__init__.py
--------------------------------------------------------------------------------
/models/crnn.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import params
3 | import torch.nn.functional as F
4 |
5 | class BidirectionalLSTM(nn.Module):
6 |
7 |     def __init__(self, nIn, nHidden, nOut):
8 |         super(BidirectionalLSTM, self).__init__()
9 |
10 |         self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
11 |         self.embedding = nn.Linear(nHidden * 2, nOut)
12 |
13 |     def forward(self, input):
14 |         recurrent, _ = self.rnn(input)
15 |         T, b, h = recurrent.size()
16 |         t_rec = recurrent.view(T * b, h)
17 |
18 |         output = self.embedding(t_rec)  # [T * b, nOut]
19 |         output = output.view(T, b, -1)
20 |
21 |         return output
22 |
23 |
24 | class CRNN(nn.Module):
25 |
26 |     def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
27 |         super(CRNN, self).__init__()
28 |         assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
29 |
30 |         ks = [3, 3, 3, 3, 3, 3, 2]
31 |         ps = [1, 1, 1, 1, 1, 1, 0]
32 |         ss = [1, 1, 1, 1, 1, 1, 1]
33 |         nm = [64, 128, 256, 256, 512, 512, 512]
34 |
35 |         cnn = nn.Sequential()
36 |
37 |         def convRelu(i, batchNormalization=False):
38 |             nIn = nc if i == 0 else nm[i - 1]
39 |             nOut = nm[i]
40 |             cnn.add_module('conv{0}'.format(i),
41 |                            nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
42 |             if batchNormalization:
43 |                 cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
44 |             if leakyRelu:
45 |                 cnn.add_module('relu{0}'.format(i),
46 |                                nn.LeakyReLU(0.2, inplace=True))
47 |             else:
48 |                 cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
49 |
50 |         convRelu(0)
51 |         cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
52 |         convRelu(1)
53 |         cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
54 |         convRelu(2, True)
55 |         convRelu(3)
56 |         cnn.add_module('pooling{0}'.format(2),
57 |                        nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
58 |         convRelu(4, True)
59 |         convRelu(5)
60 |         cnn.add_module('pooling{0}'.format(3),
61 |                        nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
62 |         convRelu(6, True)  # 512x1x16
63 |
64 |         self.cnn = cnn
65 |         self.rnn = nn.Sequential(
66 |             BidirectionalLSTM(512, nh, nh),
67 |             BidirectionalLSTM(nh, nh, nclass))
68 |
69 |
70 |     def forward(self, input):
71 |         # conv features
72 |         conv = self.cnn(input)
73 |         b, c, h, w = conv.size()
74 |         assert h == 1, "the height of conv must be 1"
75 |         conv = conv.squeeze(2)
76 |         conv = conv.permute(2, 0, 1)  # [w, b, c]
77 |
78 |         # rnn features
79 |         output = self.rnn(conv)
80 |
81 |         # log_softmax over the class dimension: nn.CTCLoss expects log-probabilities
82 |         output = F.log_softmax(output, dim=2)
83 |
84 |         return output
85 |
86 |
87 |     def backward_hook(self, module, grad_input, grad_output):
88 |         for g in grad_input:
89 |             g[g != g] = 0  # replace all nan in gradients with zero ('g != g' is only True for nan)
90 |
91 |
--------------------------------------------------------------------------------
/params.py:
--------------------------------------------------------------------------------
1 | import alphabets
2 |
3 | # about data and net
4 | alphabet = alphabets.alphabet
5 | keep_ratio = False # whether to keep ratio for image resize
6 | manualSeed = 1234 # seed to reproduce experiments
7 | random_sample = True # whether to sample the dataset with random sampler
8 | imgH = 32 # the height of the input image to network
9 | imgW = 100 # the width of the input image to network
10 | nh = 256 # size of the lstm hidden state
11 | nc = 1 # number of input channels (grayscale)
12 | pretrained = '' # path to pretrained model (to continue training)
13 | expr_dir = 'expr' # where to store samples and models
14 | dealwith_lossnan = False # whether to replace all nan/inf in gradients to zero
15 |
16 | # hardware
17 | cuda = True # enables cuda
18 | multi_gpu = False # whether to use multi gpu
19 | ngpu = 1 # number of GPUs to use. Do remember to set multi_gpu to True!
20 | workers = 0 # number of data loading workers
21 |
22 | # training process
23 | displayInterval = 100 # interval to print the train loss
24 | valInterval = 1000 # interval to validate the model loss and accuracy
25 | saveInterval = 1000 # interval to save the model
26 | n_val_disp = 10 # number of samples to display when validating the model
27 |
28 | # finetune
29 | nepoch = 1000 # number of epochs to train for
30 | batchSize = 64 # input batch size
31 | lr = 0.0001 # learning rate (not used by adadelta)
32 | beta1 = 0.5 # beta1 for adam,
default=0.5 33 | adam = False # whether to use adam (default is rmsprop) 34 | adadelta = False # whether to use adadelta (default is rmsprop) 35 | -------------------------------------------------------------------------------- /tool/convert_t7.lua: -------------------------------------------------------------------------------- 1 | require('table') 2 | require('torch') 3 | require('os') 4 | 5 | function clone(t) 6 | -- deep-copy a table 7 | if type(t) ~= "table" then return t end 8 | local meta = getmetatable(t) 9 | local target = {} 10 | for k, v in pairs(t) do 11 | if type(v) == "table" then 12 | target[k] = clone(v) 13 | else 14 | target[k] = v 15 | end 16 | end 17 | setmetatable(target, meta) 18 | return target 19 | end 20 | 21 | 22 | function tableMerge(lhs, rhs) 23 | output = clone(lhs) 24 | for _, v in pairs(rhs) do 25 | table.insert(output, v) 26 | end 27 | return output 28 | end 29 | 30 | 31 | function isInTable(val, val_list) 32 | for _, item in pairs(val_list) do 33 | if val == item then 34 | return true 35 | end 36 | end 37 | return false 38 | end 39 | 40 | 41 | function modelToList(model) 42 | local ignoreList = { 43 | 'nn.Copy', 44 | 'nn.AddConstant', 45 | 'nn.MulConstant', 46 | 'nn.View', 47 | 'nn.Transpose', 48 | 'nn.SplitTable', 49 | 'nn.SharedParallelTable', 50 | 'nn.JoinTable', 51 | } 52 | local state = {} 53 | local param 54 | for i, layer in pairs(model.modules) do 55 | local typeName = torch.type(layer) 56 | if not isInTable(typeName, ignoreList) then 57 | if typeName == 'nn.Sequential' or typeName == 'nn.ConcatTable' then 58 | param = modelToList(layer) 59 | elseif typeName == 'cudnn.SpatialConvolution' or typeName == 'nn.SpatialConvolution' then 60 | param = layer:parameters() 61 | elseif typeName == 'cudnn.SpatialBatchNormalization' or typeName == 'nn.SpatialBatchNormalization' then 62 | param = layer:parameters() 63 | bn_vars = {layer.running_mean, layer.running_var} 64 | param = tableMerge(param, bn_vars) 65 | elseif typeName == 'nn.LstmLayer' then 66 | param = layer:parameters() 67 | elseif typeName == 'nn.BiRnnJoin' then 68 | param = layer:parameters() 69 | elseif typeName == 'cudnn.SpatialMaxPooling' or typeName == 'nn.SpatialMaxPooling' then 70 | param = {} 71 | elseif typeName == 'cudnn.ReLU' or typeName == 'nn.ReLU' then 72 | param = {} 73 | else 74 | print(string.format('Unknown class %s', typeName)) 75 | os.exit(0) 76 | end 77 | table.insert(state, {typeName, param}) 78 | else 79 | print(string.format('pass %s', typeName)) 80 | end 81 | end 82 | return state 83 | end 84 | 85 | 86 | function saveModel(model, output_path) 87 | local state = modelToList(model) 88 | torch.save(output_path, state) 89 | end 90 | -------------------------------------------------------------------------------- /tool/convert_t7.py: -------------------------------------------------------------------------------- 1 | import torchfile 2 | import argparse 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | import numpy as np 6 | import models.crnn as crnn 7 | 8 | 9 | layer_map = { 10 | 'SpatialConvolution': 'Conv2d', 11 | 'SpatialBatchNormalization': 'BatchNorm2d', 12 | 'ReLU': 'ReLU', 13 | 'SpatialMaxPooling': 'MaxPool2d', 14 | 'SpatialAveragePooling': 'AvgPool2d', 15 | 'SpatialUpSamplingNearest': 'UpsamplingNearest2d', 16 | 'View': None, 17 | 'Linear': 'linear', 18 | 'Dropout': 'Dropout', 19 | 'SoftMax': 'Softmax', 20 | 'Identity': None, 21 | 'SpatialFullConvolution': 'ConvTranspose2d', 22 | 'SpatialReplicationPadding': None, 23 | 'SpatialReflectionPadding': None, 24 
| 'Copy': None,
25 |     'Narrow': None,
26 |     'SpatialCrossMapLRN': None,
27 |     'Sequential': None,
28 |     'ConcatTable': None,  # output is list
29 |     'CAddTable': None,  # input is list
30 |     'Concat': None,
31 |     'TorchObject': None,
32 |     'LstmLayer': 'LSTM',
33 |     'BiRnnJoin': 'Linear'
34 | }
35 |
36 |
37 | def torch_layer_serial(layer, layers):
38 |     name = layer[0]
39 |     if name == 'nn.Sequential' or name == 'nn.ConcatTable':
40 |         tmp_layers = []
41 |         for sub_layer in layer[1]:
42 |             torch_layer_serial(sub_layer, tmp_layers)
43 |         layers.extend(tmp_layers)
44 |     else:
45 |         layers.append(layer)
46 |
47 |
48 | def py_layer_serial(layer, layers):
49 |     """
50 |     Assume modules are defined in execution order.
51 |     """
52 |     if len(layer._modules) >= 1:
53 |         tmp_layers = []
54 |         for sub_layer in layer.children():
55 |             py_layer_serial(sub_layer, tmp_layers)
56 |         layers.extend(tmp_layers)
57 |     else:
58 |         layers.append(layer)
59 |
60 |
61 | def trans_pos(param, part_indexes, dim=0):
62 |     parts = np.split(param, len(part_indexes), dim)
63 |     new_parts = []
64 |     for i in part_indexes:
65 |         new_parts.append(parts[i])
66 |     return np.concatenate(new_parts, dim)
67 |
68 |
69 | def load_params(py_layer, t7_layer):
70 |     if type(py_layer).__name__ == 'LSTM':
71 |         # LSTM
72 |         all_weights = []
73 |         num_directions = 2 if py_layer.bidirectional else 1
74 |         for i in range(py_layer.num_layers):
75 |             for j in range(num_directions):
76 |                 suffix = '_reverse' if j == 1 else ''
77 |                 weights = ['weight_ih_l{}{}', 'bias_ih_l{}{}',
78 |                            'weight_hh_l{}{}', 'bias_hh_l{}{}']
79 |                 weights = [x.format(i, suffix) for x in weights]
80 |                 all_weights += weights
81 |
82 |         params = []
83 |         for i in range(len(t7_layer)):
84 |             params.extend(t7_layer[i][1])
85 |         params = [trans_pos(p, [0, 1, 3, 2], dim=0) for p in params]
86 |     else:
87 |         all_weights = []
88 |         name = t7_layer[0].split('.')[-1]
89 |         if name == 'BiRnnJoin':
90 |             weight_0, bias_0, weight_1, bias_1 = t7_layer[1]
91 |             weight = np.concatenate((weight_0, weight_1), axis=1)
92 |             bias = bias_0 + bias_1
93 |             t7_layer[1] = [weight, bias]
94 |             all_weights += ['weight', 'bias']
95 |         elif name == 'SpatialConvolution' or name == 'Linear':
96 |             all_weights += ['weight', 'bias']
97 |         elif name == 'SpatialBatchNormalization':
98 |             all_weights += ['weight', 'bias', 'running_mean', 'running_var']
99 |
100 |         params = t7_layer[1]
101 |
102 |     params = [torch.from_numpy(item) for item in params]
103 |     assert len(all_weights) == len(params), "number of params does not match"
104 |     for py_param_name, t7_param in zip(all_weights, params):
105 |         item = getattr(py_layer, py_param_name)
106 |         if isinstance(item, Parameter):
107 |             item = item.data
108 |         try:
109 |             item.copy_(t7_param)
110 |         except RuntimeError:
111 |             print('Size mismatch between %s and %s' %
112 |                   (item.size(), t7_param.size()))
113 |
114 |
115 | def torch_to_pytorch(model, t7_file, output):
116 |     py_layers = []
117 |     for layer in list(model.children()):
118 |         py_layer_serial(layer, py_layers)
119 |
120 |     t7_data = torchfile.load(t7_file)
121 |     t7_layers = []
122 |     for layer in t7_data:
123 |         torch_layer_serial(layer, t7_layers)
124 |
125 |     j = 0
126 |     for i, py_layer in enumerate(py_layers):
127 |         py_name = type(py_layer).__name__
128 |         t7_layer = t7_layers[j]
129 |         t7_name = t7_layer[0].split('.')[-1]
130 |         if layer_map[t7_name] != py_name:
131 |             raise RuntimeError('%s does not match %s' % (py_name, t7_name))
132 |
133 |         if py_name == 'LSTM':
134 |             n_layer = 2 if py_layer.bidirectional else 1
135 |             n_layer *= py_layer.num_layers
136 |             t7_layer = t7_layers[j:j + n_layer]
137 |             j += n_layer
138 |         else:
139 |             j += 1
140 |
141 |         load_params(py_layer, t7_layer)
142 |
143 |     torch.save(model.state_dict(), output)
144 |
145 |
146 | if __name__ == "__main__":
147 |     parser = argparse.ArgumentParser(
148 |         description='Convert torch t7 model to pytorch'
149 |     )
150 |     parser.add_argument(
151 |         '--model_file',
152 |         '-m',
153 |         type=str,
154 |         required=True,
155 |         help='torch model file in t7 format'
156 |     )
157 |     parser.add_argument(
158 |         '--output',
159 |         '-o',
160 |         type=str,
161 |         default=None,
162 |         help='output file name prefix, xxx.py xxx.pth'
163 |     )
164 |     args = parser.parse_args()
165 |
166 |     py_model = crnn.CRNN(32, 1, 37, 256, 1)
167 |     torch_to_pytorch(py_model, args.model_file, args.output)
168 |
--------------------------------------------------------------------------------
/tool/create_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import lmdb
3 | import cv2
4 | import numpy as np
5 | import argparse
6 | import shutil
7 | import sys
8 |
9 | def checkImageIsValid(imageBin):
10 |     if imageBin is None:
11 |         return False
12 |
13 |     try:
14 |         imageBuf = np.frombuffer(imageBin, dtype=np.uint8)  # np.frombuffer replaces the deprecated np.fromstring
15 |         img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
16 |         imgH, imgW = img.shape[0], img.shape[1]
17 |     except Exception:
18 |         return False
19 |     else:
20 |         if imgH * imgW == 0:
21 |             return False
22 |
23 |     return True
24 |
25 |
26 | def writeCache(env, cache):
27 |     with env.begin(write=True) as txn:
28 |         for k, v in cache.items():
29 |             if type(k) == str:
30 |                 k = k.encode()
31 |             if type(v) == str:
32 |                 v = v.encode()
33 |             txn.put(k, v)
34 |
35 | def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
36 |     """
37 |     Create LMDB dataset for CRNN training.
38 |
39 |     ARGS:
40 |         outputPath    : LMDB output path
41 |         imagePathList : list of image paths
42 |         labelList     : list of corresponding groundtruth texts
43 |         lexiconList   : (optional) list of lexicon lists
44 |         checkValid    : if true, check the validity of every image
45 |     """
46 |     # If the lmdb file already exists, remove it; otherwise the new data would be appended to it.
47 |     if os.path.exists(outputPath):
48 |         shutil.rmtree(outputPath)
49 |         os.makedirs(outputPath)
50 |     else:
51 |         os.makedirs(outputPath)
52 |
53 |     assert (len(imagePathList) == len(labelList))
54 |     nSamples = len(imagePathList)
55 |     env = lmdb.open(outputPath, map_size=1099511627776)
56 |     cache = {}
57 |     cnt = 1
58 |     for i in range(nSamples):
59 |         imagePath = imagePathList[i]
60 |         label = labelList[i]
61 |
62 |         if not os.path.exists(imagePath):
63 |             print('%s does not exist' % imagePath)
64 |             continue
65 |         with open(imagePath, 'rb') as f:
66 |             imageBin = f.read()
67 |         if checkValid:
68 |             if not checkImageIsValid(imageBin):
69 |                 print('%s is not a valid image' % imagePath)
70 |                 continue
71 |
72 |         imageKey = 'image-%09d' % cnt
73 |         labelKey = 'label-%09d' % cnt
74 |         cache[imageKey] = imageBin
75 |         cache[labelKey] = label
76 |         if lexiconList:
77 |             lexiconKey = 'lexicon-%09d' % cnt
78 |             cache[lexiconKey] = ' '.join(lexiconList[i])
79 |         if cnt % 1000 == 0:
80 |             writeCache(env, cache)
81 |             cache = {}
82 |             print('Written %d / %d' % (cnt, nSamples))
83 |         cnt += 1
84 |     nSamples = cnt-1
85 |     cache['num-samples'] = str(nSamples)
86 |     writeCache(env, cache)
87 |     env.close()
88 |     print('Created dataset with %d samples' % nSamples)
89 |
90 | def read_data_from_folder(folder_path):
91 |     image_path_list = []
92 |     label_list = []
93 |     pics = os.listdir(folder_path)
94 |     pics.sort(key = lambda i: len(i))
95 |     for pic in pics:
96 |         image_path_list.append(folder_path + '/' + pic)
97 |         label_list.append(pic.split('_')[0])
98 |     return image_path_list, label_list
99 |
100 | def read_data_from_file(file_path):
101 |     image_path_list = []
102 |     label_list = []
103 |     f = open(file_path, encoding='utf-8')  # labels may contain non-ascii characters
104 |     while True:
105 |         line1 = f.readline()
106 |         line2 = f.readline()
107 |         if not line1 or not line2:
108 |             break
109 |         line1 = line1.replace('\r', '').replace('\n', '')
110 |         line2 = line2.replace('\r', '').replace('\n', '')
111 |         image_path_list.append(line1)
112 |         label_list.append(line2)
113 |
114 |     return image_path_list, label_list
115 |
116 | def show_demo(demo_number, image_path_list, label_list):
117 |     print('\nShow some demo samples to prevent creating a wrong lmdb dataset')
118 |     print('The first line is the path to the image and the second line is the image label')
119 |     for i in range(demo_number):
120 |         print('image: %s\nlabel: %s\n' % (image_path_list[i], label_list[i]))
121 |
122 | if __name__ == '__main__':
123 |     parser = argparse.ArgumentParser()
124 |     parser.add_argument('--out', type = str, required = True, help = 'lmdb data output path')
125 |     parser.add_argument('--folder', type = str, help = 'path to folder which contains the images')
126 |     parser.add_argument('--file', type = str, help = 'path to file which contains the image path and label')
127 |     args = parser.parse_args()
128 |
129 |     if args.file is not None:
130 |         image_path_list, label_list = read_data_from_file(args.file)
131 |         createDataset(args.out, image_path_list, label_list)
132 |         show_demo(2, image_path_list, label_list)
133 |     elif args.folder is not None:
134 |         image_path_list, label_list = read_data_from_folder(args.folder)
135 |         createDataset(args.out, image_path_list, label_list)
136 |         show_demo(2, image_path_list, label_list)
137 |     else:
138 |         print('Please use --folder or --file to assign the input. Use -h to see more.')
139 |         sys.exit()
140 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import division
3 |
4 | import argparse
5 | import random
6 | import torch
7 | import torch.backends.cudnn as cudnn
8 | import torch.optim as optim
9 | import torch.utils.data
10 | from torch.autograd import Variable
11 | import numpy as np
12 | # from warpctc_pytorch import CTCLoss
13 | from torch.nn import CTCLoss
14 | import os
15 | import utils
16 | import dataset
17 |
18 | import models.crnn as net
19 | import params
20 |
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument('-train', '--trainroot', required=True, help='path to train dataset')
23 | parser.add_argument('-val', '--valroot', required=True, help='path to val dataset')
24 | args = parser.parse_args()
25 |
26 | if not os.path.exists(params.expr_dir):
27 |     os.makedirs(params.expr_dir)
28 |
29 | # ensure the same random numbers are generated on every run (reproducibility)
30 | random.seed(params.manualSeed)
31 | np.random.seed(params.manualSeed)
32 | torch.manual_seed(params.manualSeed)
33 |
34 | cudnn.benchmark = True
35 |
36 | if torch.cuda.is_available() and not params.cuda:
37 |     print("WARNING: You have a CUDA device, so you should probably set cuda in params.py to True")
38 |
39 | # -----------------------------------------------
40 | """
41 | In this block
42 |     Get train and val data_loader
43 | """
44 | def data_loader():
45 |     # train
46 |     train_dataset = dataset.lmdbDataset(root=args.trainroot)
47 |     assert train_dataset
48 |     if not params.random_sample:  # note: a custom sampler cannot be combined with shuffle=True
49 |         sampler = dataset.randomSequentialSampler(train_dataset, params.batchSize)
50 |     else:
51 |         sampler = None
52 |     train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batchSize, \
53 |         shuffle=(sampler is None), sampler=sampler, num_workers=int(params.workers), \
54 |         collate_fn=dataset.alignCollate(imgH=params.imgH, imgW=params.imgW, keep_ratio=params.keep_ratio))
55 |
56 |     # val
57 |     val_dataset = dataset.lmdbDataset(root=args.valroot, transform=dataset.resizeNormalize((params.imgW, params.imgH)))
58 |     assert val_dataset
59 |     val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=True, batch_size=params.batchSize, num_workers=int(params.workers))
60 |
61 |     return train_loader, val_loader
62 |
63 | train_loader, val_loader = data_loader()
64 |
65 | # -----------------------------------------------
66 | """
67 | In this block
68 |     Net init
69 |     Weight init
70 |     Load pretrained model
71 | """
72 | def weights_init(m):
73 |     classname = m.__class__.__name__
74 |     if classname.find('Conv') != -1:
75 |         m.weight.data.normal_(0.0, 0.02)
76 |     elif classname.find('BatchNorm') != -1:
77 |         m.weight.data.normal_(1.0, 0.02)
78 |         m.bias.data.fill_(0)
79 |
80 | def net_init():
81 |     nclass = len(params.alphabet) + 1
82 |     crnn = net.CRNN(params.imgH, params.nc, nclass, params.nh)
83 |     crnn.apply(weights_init)
84 |     if params.pretrained != '':
85 |         print('loading pretrained model from %s' % params.pretrained)
86 |         if params.multi_gpu:
87 |             crnn = torch.nn.DataParallel(crnn)
88 |         crnn.load_state_dict(torch.load(params.pretrained))
89 |
90 |     return crnn
91 |
92 | crnn = net_init()
93 | print(crnn)
94 |
95 | # -----------------------------------------------
96 | """
97 | In this block
98 |     Init some utils defined in utils.py
99 | """
100 | # Compute average for `torch.Variable` and `torch.Tensor`.
101 | loss_avg = utils.averager()
102 |
103 | # Convert between str and label.
104 | converter = utils.strLabelConverter(params.alphabet)
105 |
106 | # -----------------------------------------------
107 | """
108 | In this block
109 |     criterion define
110 | """
111 | criterion = CTCLoss()
112 |
113 | # -----------------------------------------------
114 | """
115 | In this block
116 |     Init some tensor
117 |     Put tensor and net on cuda
118 | NOTE:
119 |     image, text, length are used by both val and train
120 |     because train and val will never use them at the same time.
121 | """
122 | image = torch.FloatTensor(params.batchSize, params.nc, params.imgH, params.imgW)  # placeholder; loadData resizes it per batch
123 | text = torch.LongTensor(params.batchSize * 5)
124 | length = torch.LongTensor(params.batchSize)
125 |
126 | if params.cuda and torch.cuda.is_available():
127 |     criterion = criterion.cuda()
128 |     image = image.cuda()
129 |     text = text.cuda()
130 |
131 |     crnn = crnn.cuda()
132 |     if params.multi_gpu:
133 |         crnn = torch.nn.DataParallel(crnn, device_ids=range(params.ngpu))
134 |
135 | image = Variable(image)
136 | text = Variable(text)
137 | length = Variable(length)
138 |
139 | # -----------------------------------------------
140 | """
141 | In this block
142 |     Setup optimizer
143 | """
144 | if params.adam:
145 |     optimizer = optim.Adam(crnn.parameters(), lr=params.lr, betas=(params.beta1, 0.999))
146 | elif params.adadelta:
147 |     optimizer = optim.Adadelta(crnn.parameters())
148 | else:
149 |     optimizer = optim.RMSprop(crnn.parameters(), lr=params.lr)
150 |
151 | # -----------------------------------------------
152 | """
153 | In this block
154 |     Deal with loss nan
155 | NOTE:
156 |     I use a different way to deal with the loss nan according to the torch version.
157 | """
158 | if params.dealwith_lossnan:
159 |     if torch.__version__ >= '1.1.0':
160 |         """
161 |         zero_infinity (bool, optional):
162 |             Whether to zero infinite losses and the associated gradients.
163 |             Default: ``False``
164 |             Infinite losses mainly occur when the inputs are too short
165 |             to be aligned to the targets.
166 |         PyTorch added this param in v1.1.0
167 |         """
168 |         criterion = CTCLoss(zero_infinity=True)
169 |     else:
170 |         """
171 |         only when
172 |             torch.__version__ < '1.1.0'
173 |         do we use this way to change the nan in gradients to zero
174 |         """
175 |         crnn.register_backward_hook(crnn.backward_hook)
176 |
177 | # -----------------------------------------------
178 |
179 | def val(net, criterion):
180 |     print('Start val')
181 |
182 |     for p in crnn.parameters():
183 |         p.requires_grad = False
184 |
185 |     net.eval()
186 |     val_iter = iter(val_loader)
187 |
188 |     i = 0
189 |     n_correct = 0
190 |     loss_avg = utils.averager()  # shadows the global loss_avg, which is used by train
191 |
192 |     max_iter = len(val_loader)
193 |     for i in range(max_iter):
194 |         data = next(val_iter)  # Python 3: use next(iterator), not iterator.next()
195 |
196 |         cpu_images, cpu_texts = data
197 |         batch_size = cpu_images.size(0)
198 |         utils.loadData(image, cpu_images)
199 |         t, l = converter.encode(cpu_texts)
200 |         utils.loadData(text, t)
201 |         utils.loadData(length, l)
202 |
203 |         preds = crnn(image)
204 |         preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
205 |         cost = criterion(preds, text, preds_size, length) / batch_size
206 |         loss_avg.add(cost)
207 |
208 |         _, preds = preds.max(2)
209 |         preds = preds.transpose(1, 0).contiguous().view(-1)
210 |         sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
211 |         cpu_texts_decode = []
212 |         for cpu_text in cpu_texts:
213 |             cpu_texts_decode.append(cpu_text.decode('utf-8', 'strict'))
214 |         for pred, target in zip(sim_preds, cpu_texts_decode):
215 |             if pred == target:
216 |                 n_correct += 1
217 |
218 |     raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:params.n_val_disp]
219 |     for raw_pred, pred, gt in zip(raw_preds, sim_preds, cpu_texts_decode):
220 |         print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
221 |
222 |     accuracy = n_correct / float(max_iter * params.batchSize)  # assumes every batch is full
223 |     print('Val loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
224 |
225 |
226 | def train(net, criterion, optimizer, train_iter):
227 |     for p in crnn.parameters():
228 |         p.requires_grad = True
229 |     crnn.train()
230 |
231 |     data = next(train_iter)
232 |     cpu_images, cpu_texts = data
233 |     batch_size = cpu_images.size(0)
234 |     utils.loadData(image, cpu_images)
235 |     t, l = converter.encode(cpu_texts)
236 |     utils.loadData(text, t)
237 |     utils.loadData(length, l)
238 |
239 |     optimizer.zero_grad()
240 |     preds = crnn(image)
241 |     preds_size = Variable(torch.LongTensor([preds.size(0)] * batch_size))
242 |     cost = criterion(preds, text, preds_size, length) / batch_size
243 |     # crnn.zero_grad()
244 |     cost.backward()
245 |     optimizer.step()
246 |     return cost
247 |
248 |
249 | if __name__ == "__main__":
250 |     for epoch in range(params.nepoch):
251 |         train_iter = iter(train_loader)
252 |         i = 0
253 |         while i < len(train_loader):
254 |             cost = train(crnn, criterion, optimizer, train_iter)
255 |             loss_avg.add(cost)
256 |             i += 1
257 |
258 |             if i % params.displayInterval == 0:
259 |                 print('[%d/%d][%d/%d] Loss: %f' %
260 |                       (epoch, params.nepoch, i, len(train_loader), loss_avg.val()))
261 |                 loss_avg.reset()
262 |
263 |             if i % params.valInterval == 0:
264 |                 val(crnn, criterion)
265 |
266 |             # do checkpointing
267 |             if i % params.saveInterval == 0:
268 |                 torch.save(crnn.state_dict(), '{0}/netCRNN_{1}_{2}.pth'.format(params.expr_dir, epoch, i))
269 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # encoding: utf-8
3 |
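# utils.py: strLabelConverter converts between strings and CTC label tensors;
# averager tracks a running loss average; oneHot, loadData, prettyPrint and
# assureRatio are small standalone tensor helpers.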
4 | import torch
5 | import torch.nn as nn
6 | from torch.autograd import Variable
7 | import collections
8 |
9 |
10 | class strLabelConverter(object):
11 |     """Convert between str and label.
12 |
13 |     NOTE:
14 |         Insert `blank` to the alphabet for CTC.
15 |
16 |     Args:
17 |         alphabet (str): set of the possible characters.
18 |         ignore_case (bool, default=False): whether or not to ignore all of the case.
19 |     """
20 |
21 |     def __init__(self, alphabet, ignore_case=False):
22 |         self._ignore_case = ignore_case
23 |         if self._ignore_case:
24 |             alphabet = alphabet.lower()
25 |         self.alphabet = alphabet + '-'  # for `-1` index
26 |
27 |         self.dict = {}
28 |         for i, char in enumerate(alphabet):
29 |             # NOTE: 0 is reserved for the CTC 'blank' (as required by warp-ctc)
30 |             self.dict[char] = i + 1
31 |
32 |     def encode(self, text):
33 |         """Support batch or single str.
34 |
35 |         Args:
36 |             text (str or list of str): texts to convert.
37 |
38 |         Returns:
39 |             torch.LongTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
40 |             torch.LongTensor [n]: length of each text.
41 |         """
42 |
43 |         length = []
44 |         result = []
45 |         for item in text:
46 |             item = item.decode('utf-8', 'strict')
47 |             length.append(len(item))
48 |             r = []
49 |             for char in item:
50 |                 index = self.dict[char]
51 |                 # result.append(index)
52 |                 r.append(index)
53 |             result.append(r)
54 |
55 |         max_len = 0
56 |         for r in result:
57 |             if len(r) > max_len:
58 |                 max_len = len(r)
59 |
60 |         result_temp = []
61 |         for r in result:
62 |             for i in range(max_len - len(r)):
63 |                 r.append(0)  # pad with the blank index so all rows share max_len
64 |             result_temp.append(r)
65 |
66 |         text = result_temp
67 |         return (torch.LongTensor(text), torch.LongTensor(length))
68 |
69 |
70 |     def decode(self, t, length, raw=False):
71 |         """Decode encoded texts back into strs.
72 |
73 |         Args:
74 |             torch.LongTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
75 |             torch.LongTensor [n]: length of each text.
76 |
77 |         Raises:
78 |             AssertionError: when the text and its length do not match.
79 |
80 |         Returns:
81 |             text (str or list of str): texts to convert.
82 |         """
83 |         if length.numel() == 1:
84 |             length = length[0]
85 |             assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(), length)
86 |             if raw:
87 |                 return ''.join([self.alphabet[i - 1] for i in t])
88 |             else:
89 |                 char_list = []
90 |                 for i in range(length):
91 |                     if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
92 |                         char_list.append(self.alphabet[t[i] - 1])
93 |                 return ''.join(char_list)
94 |         else:
95 |             # batch mode
96 |             assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(t.numel(), length.sum())
97 |             texts = []
98 |             index = 0
99 |             for i in range(length.numel()):
100 |                 l = length[i]
101 |                 texts.append(
102 |                     self.decode(
103 |                         t[index:index + l], torch.LongTensor([l]), raw=raw))
104 |                 index += l
105 |             return texts
106 |
107 |
108 | class averager(object):
109 |     """Compute average for `torch.Variable` and `torch.Tensor`. """
110 |
111 |     def __init__(self):
112 |         self.reset()
113 |
114 |     def add(self, v):
115 |         if isinstance(v, Variable):
116 |             count = v.data.numel()
117 |             v = v.data.sum()
118 |         elif isinstance(v, torch.Tensor):
119 |             count = v.numel()
120 |             v = v.sum()
121 |
122 |         self.n_count += count
123 |         self.sum += v
124 |
125 |     def reset(self):
126 |         self.n_count = 0
127 |         self.sum = 0
128 |
129 |     def val(self):
130 |         res = 0
131 |         if self.n_count != 0:
132 |             res = self.sum / float(self.n_count)
133 |         return res
134 |
135 |
136 | def oneHot(v, v_length, nc):
137 |     batchSize = v_length.size(0)
138 |     maxLength = v_length.max()
139 |     v_onehot = torch.FloatTensor(batchSize, maxLength, nc).fill_(0)
140 |     acc = 0
141 |     for i in range(batchSize):
142 |         length = v_length[i]
143 |         label = v[acc:acc + length].view(-1, 1).long()
144 |         v_onehot[i, :length].scatter_(1, label, 1.0)
145 |         acc += length
146 |     return v_onehot
147 |
148 |
149 | def loadData(v, data):
150 |     with torch.no_grad():
151 |         v.resize_(data.size()).copy_(data)
152 |
153 |
154 | def prettyPrint(v):
155 |     print('Size {0}, Type: {1}'.format(str(v.size()), v.data.type()))
156 |     print('| Max: %f | Min: %f | Mean: %f' % (v.max().item(), v.min().item(),
157 |                                               v.mean().item()))
158 |
159 |
160 | def assureRatio(img):
161 |     """Ensure imgH <= imgW."""
162 |     b, c, h, w = img.size()
163 |     if h > w:
164 |         main = nn.UpsamplingBilinear2d(size=(h, h), scale_factor=None)
165 |         img = main(img)
166 |     return img
167 |
--------------------------------------------------------------------------------