├── .gitignore
├── README.md
├── data
│   ├── ICDAR2015.txt
│   ├── ICDAR2015
│   │   ├── 001.jpg
│   │   ├── 002.jpg
│   │   ├── 003.jpg
│   │   ├── 004.jpg
│   │   ├── img_1.jpg
│   │   ├── img_10.jpg
│   │   ├── img_13.jpg
│   │   ├── img_16.jpg
│   │   ├── img_19.jpg
│   │   ├── img_22.jpg
│   │   ├── img_25.jpg
│   │   ├── img_28.jpg
│   │   ├── img_31.jpg
│   │   ├── img_4.jpg
│   │   └── img_7.jpg
│   ├── alphabet.txt
│   ├── example_image
│   │   ├── img_1.jpg
│   │   ├── img_10.jpg
│   │   ├── img_13.jpg
│   │   ├── img_16.jpg
│   │   ├── img_19.jpg
│   │   ├── img_22.jpg
│   │   ├── img_25.jpg
│   │   ├── img_28.jpg
│   │   ├── img_31.jpg
│   │   ├── img_4.jpg
│   │   └── img_7.jpg
│   ├── small_train.txt
│   ├── trainch.txt
│   └── tshow
│       ├── crop0.jpg
│       ├── crop1.jpg
│       ├── crop10.jpg
│       ├── crop11.jpg
│       ├── crop12.jpg
│       ├── crop13.jpg
│       ├── crop14.jpg
│       ├── crop15.jpg
│       ├── crop16.jpg
│       ├── crop17.jpg
│       ├── crop18.jpg
│       ├── crop19.jpg
│       ├── crop2.jpg
│       ├── crop20.jpg
│       ├── crop21.jpg
│       ├── crop22.jpg
│       ├── crop23.jpg
│       ├── crop24.jpg
│       ├── crop25.jpg
│       ├── crop26.jpg
│       ├── crop3.jpg
│       ├── crop4.jpg
│       ├── crop5.jpg
│       ├── crop6.jpg
│       ├── crop7.jpg
│       ├── crop8.jpg
│       ├── crop9.jpg
│       ├── img0.jpg
│       ├── img1.jpg
│       ├── img2.jpg
│       ├── img3.jpg
│       ├── img4.jpg
│       ├── img5.jpg
│       ├── img6.jpg
│       └── img7.jpg
├── images
│   ├── roirototate.jpg
│   └── synth.png
├── nms
│   ├── .gitignore
│   ├── Makefile
│   ├── __init__.py
│   ├── adaptor.cpp
│   ├── include
│   │   ├── clipper
│   │   │   ├── clipper.cpp
│   │   │   └── clipper.hpp
│   │   └── pybind11
│   │       ├── attr.h
│   │       ├── buffer_info.h
│   │       ├── cast.h
│   │       ├── chrono.h
│   │       ├── common.h
│   │       ├── complex.h
│   │       ├── detail
│   │       │   ├── class.h
│   │       │   ├── common.h
│   │       │   ├── descr.h
│   │       │   ├── init.h
│   │       │   ├── internals.h
│   │       │   └── typeid.h
│   │       ├── eigen.h
│   │       ├── embed.h
│   │       ├── eval.h
│   │       ├── functional.h
│   │       ├── iostream.h
│   │       ├── numpy.h
│   │       ├── operators.h
│   │       ├── options.h
│   │       ├── pybind11.h
│   │       ├── pytypes.h
│   │       ├── stl.h
│   │       └── stl_bind.h
│   └── nms.h
├── rroi_align
│   ├── __init__.py
│   ├── build.py
│   ├── data
│   │   ├── grad.jpg
│   │   ├── grad_img.jpg
│   │   ├── res0.jpg
│   │   ├── res1.jpg
│   │   ├── res2.jpg
│   │   └── timg.jpeg
│   ├── functions
│   │   ├── __init__.py
│   │   └── rroi_align.py
│   ├── main.py
│   ├── make.sh
│   ├── modules
│   │   ├── __init__.py
│   │   └── rroi_align.py
│   ├── src
│   │   ├── roi_pooling.c
│   │   ├── roi_pooling.h
│   │   ├── rroi_align.cu.o
│   │   ├── rroi_align_cuda.c
│   │   ├── rroi_align_cuda.h
│   │   ├── rroi_align_kernel.cu
│   │   └── rroi_align_kernel.h
│   ├── test.py
│   └── test2.py
├── sample_train_data
│   ├── MLT
│   │   ├── done
│   │   │   ├── gt_img_5407.txt
│   │   │   └── img_5407.jpg
│   │   ├── icdar-2015-Ch4
│   │   │   └── Train
│   │   │       ├── gt_img_784.txt
│   │   │       └── img_784.jpg
│   │   └── trainMLT.txt
│   └── MLT_CROPS
│       ├── gt.txt
│       ├── word_118.png
│       ├── word_119.png
│       ├── word_120.png
│       └── word_121.png
├── src
│   ├── __init__.py
│   ├── ocr_process.py
│   └── utils.py
├── test.py
├── tools
│   ├── Arial-Unicode-Regular.ttf
│   ├── __init__.py
│   ├── align_demo.py
│   ├── codec.txt
│   ├── codec_rctw.txt
│   ├── data_gen.py
│   ├── data_util.py
│   ├── demo.py
│   ├── eval.py
│   ├── models.py
│   ├── net_utils.py
│   ├── ocr_gen.py
│   ├── ocr_test_utils.py
│   ├── ocr_utils.py
│   ├── test.py
│   ├── test_crnn.1.py
│   ├── test_crnn.2.py
│   ├── test_crnn.py
│   ├── train.1.py
│   ├── train_crnn.1.py
│   ├── train_crnn.2.py
│   ├── train_crnn.py
│   ├── train_ocr.py
│   └── utils.py
├── train.py
└── weights
    └── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # compilation and distribution
2 | __pycache__
3 | _ext
4 | *.pyc
5 | *.so
6 |
7 | # pytorch/python/numpy formats
8 | *.pth
9 | *.pkl
10 | *.npy
11 |
12 | # ipython/jupyter notebooks
13 | *.ipynb
14 | **/.ipynb_checkpoints/
15 |
16 | # Editor temporaries
17 | *.swn
18 | *.swo
19 | *.swp
20 | *~
21 |
22 | # Pycharm editor settings
23 | .idea
24 | .vscode
25 | /backup
26 | /data/Chinese
27 | /data/*.zip
28 | /data/*.py
29 | /weights/*.h5
30 | /weights/e2e-mlt.h5
31 |
32 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FOTS.pytorch
2 | This is an unofficial implementation of [FOTS: Fast Oriented Text Spotting with a Unified Network](https://arxiv.org/abs/1801.01671), a unified, end-to-end trainable network for simultaneous text detection and recognition that shares computation and visual information between the two complementary tasks. It mainly borrows from [E2E-MLT](https://arxiv.org/abs/1801.09919), an end-to-end multi-language scene text detection and recognition network.
3 |
4 | ## Requirements
5 | - python 3.x with
6 | - opencv-python
7 | - pytorch 0.4.1
8 | - torchvision
9 | - warp-ctc (https://github.com/SeanNaren/warp-ctc/)
10 | - gcc 6.3 or 7.3 for nms
11 |
12 | ## Compile extension file
13 | - RoIRotate
14 | For the RoIRotate layer, I've written a custom PyTorch autograd layer.
15 | 
16 | Compiling:
17 | ```bash
18 | # optional
19 | source activate conda_env
20 | cd $project_path/rroi_align
21 | sh make.sh # compile
22 | ```
23 |
24 | - EAST nms
25 | To compile the EAST nms extension, gcc 6.3 works for me; I have not tested other versions. A minimal sanity check is sketched below.
26 | If you run into problems, refer to [https://github.com/MichalBusta/E2E-MLT/issues/21](https://github.com/MichalBusta/E2E-MLT/issues/21) or [argman/EAST](https://github.com/argman/EAST).
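
Note that `nms/__init__.py` also builds the extension automatically on first import by running `make` in the `nms` directory, so an explicit build step is usually unnecessary. A minimal sanity check:

```python
# Importing the package runs `make -C nms` (see nms/__init__.py) and raises
# RuntimeError('Cannot compile nms: ...') if the build fails.
import nms
print(nms.BASE_DIR)  # the directory where adaptor.so was built
```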
27 |
28 |
29 | # Test
30 | First download the pretrained model (trained on ICDAR2015) from [baidu](https://pan.baidu.com/s/1So6SRIMUOKL9R7rn9dvC0A), **password: ndav**. Put the model in the `weights` folder; you can then test on some ICDAR2015 samples:
31 | ```bash
32 | cd $project_path
33 | python test.py
34 | ```
35 | Some examples:
36 |
37 |
38 | Figure 1
39 | Figure 2
40 |
41 |
42 | Figure 3
43 | Figure 4
44 |
45 |
46 | Figure 5
47 | Figure 6
48 |
49 |
50 |
51 |
52 | ## RoIRotate
53 | RoIRotate applies a transformation to oriented feature regions to obtain axis-aligned feature maps, using bilinear interpolation to compute the output values. A minimal sketch of the idea follows the figures below.
54 |
55 |
56 | Figure 1
57 | Figure 2
58 |
59 |
60 | Figure 3
61 | Figure 4
62 |
63 |
64 | Figure 5
65 | Figure 6
66 |
67 |
68 |
69 |
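A minimal PyTorch sketch of the idea, assuming a recent PyTorch (illustrative only, not the repo's CUDA `rroi_align` op): build an affine grid that maps an axis-aligned output window onto the rotated box, then sample it bilinearly.

```python
import math
import torch
import torch.nn.functional as F

def roi_rotate(feat, cx, cy, w, h, theta, out_h=8, out_w=64):
    """Sample one rotated box (center (cx, cy), size (w, h), angle theta,
    all in normalized [-1, 1] grid coordinates) from feat of shape
    (1, C, H, W) into an axis-aligned (1, C, out_h, out_w) map."""
    cos, sin = math.cos(theta), math.sin(theta)
    # 2x3 affine matrix mapping the output grid onto the rotated box
    mat = torch.tensor([[w / 2 * cos, -h / 2 * sin, cx],
                        [w / 2 * sin,  h / 2 * cos, cy]], dtype=feat.dtype)
    grid = F.affine_grid(mat[None], (1, feat.size(1), out_h, out_w), align_corners=False)
    return F.grid_sample(feat, grid, mode='bilinear', align_corners=False)

feat = torch.randn(1, 32, 40, 160)   # shared detector feature map
crop = roi_rotate(feat, cx=0.1, cy=-0.2, w=0.5, h=0.15, theta=math.radians(20))
print(crop.shape)                    # torch.Size([1, 32, 8, 64])
```
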
70 | # Train
71 | Download the ICDAR2015 data and the train list from [baidu](https://pan.baidu.com/s/1caSNRb9DIHSEvbTtPpKaeA), **password: q1au**.
72 | ```python
73 | # train_list.txt lists the paths of the training images
74 | /home/yangna/deepblue/OCR/data/ICDAR2015/icdar-2015-Ch4/img_546.jpg
75 | /home/yangna/deepblue/OCR/data/ICDAR2015/icdar-2015-Ch4/img_277.jpg
76 | /home/yangna/deepblue/OCR/data/ICDAR2015/icdar-2015-Ch4/img_462.jpg
77 | /home/yangna/deepblue/OCR/data/ICDAR2015/icdar-2015-Ch4/img_237.jpg
78 | ```
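
As the `sample_train_data` folder shows, each image `img_N.jpg` sits next to an ICDAR-style annotation `gt_img_N.txt`. A minimal, hypothetical sketch of consuming such a list (the repo's own loader may differ):

```python
import os

# hypothetical example: pair each listed image with its gt_<name>.txt file
with open('ICDAR2015.txt') as f:
    image_paths = [line.strip() for line in f if line.strip()]

for img_path in image_paths:
    folder, name = os.path.split(img_path)
    gt_path = os.path.join(folder, 'gt_' + os.path.splitext(name)[0] + '.txt')
    print(img_path, '->', gt_path)
```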
79 |
80 | Training:
81 | ```bash
82 | python train.py -train_list=$path_to/ICDAR2015.txt
83 | ```
84 |
85 | # Acknowledgments
86 |
87 | Code borrows from [MichalBusta/E2E-MLT](https://github.com/MichalBusta/E2E-MLT)
88 |
--------------------------------------------------------------------------------
/data/ICDAR2015/001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/001.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/002.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/003.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/004.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_1.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_10.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_13.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_16.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_19.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_22.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_25.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_25.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_28.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_28.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_31.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_31.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_4.jpg
--------------------------------------------------------------------------------
/data/ICDAR2015/img_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/ICDAR2015/img_7.jpg
--------------------------------------------------------------------------------
/data/alphabet.txt:
--------------------------------------------------------------------------------
1 | 7BCNTh2!F'P0ouRvz3[Qdesr6#:ÉyU(4bt%"?´Kl.ZOM8@A1+)/ ki&DW$fwn;=p5HqSjV]JX-GEagxILmYc9,
--------------------------------------------------------------------------------
/data/example_image/img_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_1.jpg
--------------------------------------------------------------------------------
/data/example_image/img_10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_10.jpg
--------------------------------------------------------------------------------
/data/example_image/img_13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_13.jpg
--------------------------------------------------------------------------------
/data/example_image/img_16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_16.jpg
--------------------------------------------------------------------------------
/data/example_image/img_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_19.jpg
--------------------------------------------------------------------------------
/data/example_image/img_22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_22.jpg
--------------------------------------------------------------------------------
/data/example_image/img_25.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_25.jpg
--------------------------------------------------------------------------------
/data/example_image/img_28.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_28.jpg
--------------------------------------------------------------------------------
/data/example_image/img_31.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_31.jpg
--------------------------------------------------------------------------------
/data/example_image/img_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_4.jpg
--------------------------------------------------------------------------------
/data/example_image/img_7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/example_image/img_7.jpg
--------------------------------------------------------------------------------
/data/tshow/crop0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop0.jpg
--------------------------------------------------------------------------------
/data/tshow/crop1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop1.jpg
--------------------------------------------------------------------------------
/data/tshow/crop10.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop10.jpg
--------------------------------------------------------------------------------
/data/tshow/crop11.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop11.jpg
--------------------------------------------------------------------------------
/data/tshow/crop12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop12.jpg
--------------------------------------------------------------------------------
/data/tshow/crop13.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop13.jpg
--------------------------------------------------------------------------------
/data/tshow/crop14.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop14.jpg
--------------------------------------------------------------------------------
/data/tshow/crop15.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop15.jpg
--------------------------------------------------------------------------------
/data/tshow/crop16.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop16.jpg
--------------------------------------------------------------------------------
/data/tshow/crop17.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop17.jpg
--------------------------------------------------------------------------------
/data/tshow/crop18.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop18.jpg
--------------------------------------------------------------------------------
/data/tshow/crop19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop19.jpg
--------------------------------------------------------------------------------
/data/tshow/crop2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop2.jpg
--------------------------------------------------------------------------------
/data/tshow/crop20.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop20.jpg
--------------------------------------------------------------------------------
/data/tshow/crop21.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop21.jpg
--------------------------------------------------------------------------------
/data/tshow/crop22.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop22.jpg
--------------------------------------------------------------------------------
/data/tshow/crop23.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop23.jpg
--------------------------------------------------------------------------------
/data/tshow/crop24.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop24.jpg
--------------------------------------------------------------------------------
/data/tshow/crop25.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop25.jpg
--------------------------------------------------------------------------------
/data/tshow/crop26.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop26.jpg
--------------------------------------------------------------------------------
/data/tshow/crop3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop3.jpg
--------------------------------------------------------------------------------
/data/tshow/crop4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop4.jpg
--------------------------------------------------------------------------------
/data/tshow/crop5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop5.jpg
--------------------------------------------------------------------------------
/data/tshow/crop6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop6.jpg
--------------------------------------------------------------------------------
/data/tshow/crop7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop7.jpg
--------------------------------------------------------------------------------
/data/tshow/crop8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop8.jpg
--------------------------------------------------------------------------------
/data/tshow/crop9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/crop9.jpg
--------------------------------------------------------------------------------
/data/tshow/img0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img0.jpg
--------------------------------------------------------------------------------
/data/tshow/img1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img1.jpg
--------------------------------------------------------------------------------
/data/tshow/img2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img2.jpg
--------------------------------------------------------------------------------
/data/tshow/img3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img3.jpg
--------------------------------------------------------------------------------
/data/tshow/img4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img4.jpg
--------------------------------------------------------------------------------
/data/tshow/img5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img5.jpg
--------------------------------------------------------------------------------
/data/tshow/img6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img6.jpg
--------------------------------------------------------------------------------
/data/tshow/img7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/data/tshow/img7.jpg
--------------------------------------------------------------------------------
/images/roirototate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/images/roirototate.jpg
--------------------------------------------------------------------------------
/images/synth.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/images/synth.png
--------------------------------------------------------------------------------
/nms/.gitignore:
--------------------------------------------------------------------------------
1 | adaptor.so
2 |
--------------------------------------------------------------------------------
/nms/Makefile:
--------------------------------------------------------------------------------
1 | CXXFLAGS = -I include -std=c++11 -O3 $(shell python3-config --cflags)
2 | LDFLAGS = $(shell python3-config --ldflags)
3 |
4 | DEPS = nms.h $(shell find include -xtype f)
5 | CXX_SOURCES = adaptor.cpp include/clipper/clipper.cpp
6 |
7 | LIB_SO = adaptor.so
8 |
9 | $(LIB_SO): $(CXX_SOURCES) $(DEPS)
10 | 	$(CXX) -o $@ $(CXXFLAGS) $(LDFLAGS) $(CXX_SOURCES) --shared -fPIC
11 |
12 | clean:
13 | 	rm -rf $(LIB_SO)
14 |
--------------------------------------------------------------------------------
/nms/__init__.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import os
3 | import numpy as np
4 |
5 | BASE_DIR = os.path.dirname(os.path.realpath(__file__))
6 |
7 | if subprocess.call(['make', '-C', BASE_DIR]) != 0:  # build adaptor.so on first import
8 |     raise RuntimeError('Cannot compile nms: {}'.format(BASE_DIR))
9 |
10 |
11 | def do_nms(segm_map, geo_map, angle_pred, poly_map, thres=0.3, thres2=0.2, segm_thresh=0.5):
12 |     precision = 10000  # corner coordinates come back scaled by this factor (see adaptor.cpp)
13 |     from .adaptor import do_nms as nms_impl
14 |     ret = np.array(nms_impl(segm_map, geo_map, angle_pred, poly_map, thres, thres2, segm_thresh), dtype='float32')
15 |     if len(ret) > 0:
16 |         ret[:, :8] /= precision  # rescale the 8 corner coordinates; the 9th column is the score
17 |     return ret
18 |
19 |
20 | def get_boxes(iou_map, rbox, angle_pred, segm_thresh=0.5):
21 |     # reorder the angle prediction from (2, h, w) to (h, w, 2)
22 |     angle_pred = angle_pred.swapaxes(0, 1)
23 |     angle_pred = angle_pred.swapaxes(1, 2)
24 |
25 |     poly_map = np.zeros((iou_map.shape[0], iou_map.shape[1]), dtype=np.int32)
26 |     poly_map.fill(-1)
27 |
28 |     boxes = do_nms(iou_map, rbox, angle_pred, poly_map, 0.4, 0.2, segm_thresh)
29 |     return boxes
30 |
31 |
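# Usage sketch (shapes assumed from the ndim checks in adaptor.cpp):
#   iou_map: (h, w) float32, rbox: (h, w, 4) float32, angle_pred: (2, h, w) float32
#   boxes = get_boxes(iou_map, rbox, angle_pred, segm_thresh=0.5)
#   -> n-by-9 float32 array: 8 quadrangle corner coordinates, then the score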
--------------------------------------------------------------------------------
/nms/adaptor.cpp:
--------------------------------------------------------------------------------
1 | #include "../nms/include/pybind11/numpy.h"
2 | #include "../nms/include/pybind11/pybind11.h"
3 | #include "../nms/include/pybind11/stl.h"
4 | #include "../nms/include/pybind11/stl_bind.h"
5 | #include "../nms/nms.h"
6 |
7 | namespace py = pybind11;
8 |
9 | namespace cl = ClipperLib;
10 |
11 | namespace nms_adaptor {
12 |
13 | std::vector<std::vector<float>> polys2floats(std::vector<nms::Polygon> &polys) {
14 | std::vector<std::vector<float>> ret;
15 | for (size_t i = 0; i < polys.size(); i ++) {
16 | auto &p = polys[i];
17 | auto &poly = p.poly;
18 |
19 | ret.emplace_back(std::vector<float>{
20 | float(poly[0].X), float(poly[0].Y),
21 | float(poly[1].X), float(poly[1].Y),
22 | float(poly[2].X), float(poly[2].Y),
23 | float(poly[3].X), float(poly[3].Y),
24 | float(p.score),
25 | });
26 | }
27 |
28 | return ret;
29 | }
30 |
31 | /**
32 | *
33 | * \param quad_n9 an n-by-9 numpy array, where first 8 numbers denote the
34 | * quadrangle, and the last one is the score
35 | * \param iou_threshold two quadrangles with iou score above this threshold
36 | * will be merged
37 | *
38 | * \return an n-by-9 numpy array, the merged quadrangles
39 | */
40 | std::vector<std::vector<float>> do_nms(
41 | py::array_t<float> segm,
42 | py::array_t<float> geo_map,
43 | py::array_t<float> angle,
44 | py::array_t<int> poly_map,
45 | float iou_threshold, float iou_threshold2, float segm_threshold) {
46 | auto ibuf = segm.request();
47 | auto pbuf = geo_map.request();
48 | auto abuf = angle.request();
49 | auto poly_buff = poly_map.request();
50 | if (pbuf.ndim != 3)
51 | throw std::runtime_error("geometry map must have a shape of (h x w x 4)");
52 | if (ibuf.ndim != 2)
53 | throw std::runtime_error("segmentation map must have a shape of (h x w)");
54 | if (abuf.ndim != 3)
55 | throw std::runtime_error("angle map must have a shape of (h x w x 2)");
56 | if (poly_buff.ndim != 2)
57 | throw std::runtime_error("polygon buffer must have a shape of (h x w)");
58 |
59 | //TODO we are missing a lot of asserts ...
60 |
61 | int w = ibuf.shape[1];
62 | int h = ibuf.shape[0];
63 | int offset = 0;
64 | int rstride = w * 4;
65 | int astride = w * 2;
66 | float* iptr = static_cast<float*>(ibuf.ptr);
67 | float* rptr = static_cast<float*>(pbuf.ptr);
68 | float* aptr = static_cast<float*>(abuf.ptr);
69 | int* poly_ptr = static_cast<int*>(poly_buff.ptr);
70 | float scale_factor = 4;
71 |
72 | float precision = 10000;
73 |
74 | std::vector<nms::Polygon> polys;
75 | using cInt = cl::cInt;
76 | for(int y = 0; y < h; y++){
77 | for(int x = 0; x < w; x++){
78 | auto p = iptr + offset;
79 | auto r = rptr + y * rstride + x * 4;
80 | auto a = aptr + y * astride + x * 2;
81 | if( *p > segm_threshold ){
82 | float angle_cos = a[1];
83 | float angle_sin = a[0];
84 |
85 | float xp = x + 0.25f;
86 | float yp = y + 0.25f;
87 |
88 | float pos_r_x = (xp - r[2] * angle_cos) * scale_factor;
89 | float pos_r_y = (yp - r[2] * angle_sin) * scale_factor;
90 | float pos_r2_x = (xp + r[3] * angle_cos) * scale_factor;
91 | float pos_r2_y = (yp + r[3] * angle_sin) * scale_factor;
92 |
93 | float ph = 9;// (r[0] + r[1]) + 1e-5;
94 | float phx = 9;
95 |
96 | float p_left = expf(-r[2] / phx);
97 | float p_top = expf(-r[0] / ph);
98 | float p_right = expf(-r[3] / phx);
99 | float p_bt = expf(-r[1] / ph);
100 |
101 | nms::Polygon poly{
102 | {
103 | {cInt(roundf(precision * (pos_r_x - r[1] * angle_sin * scale_factor))), cInt(roundf(precision * (pos_r_y + r[1] * angle_cos * scale_factor)))},
104 | {cInt(roundf(precision * (pos_r_x + r[0] * angle_sin * scale_factor ))), cInt(roundf(precision * (pos_r_y - r[0] * angle_cos * scale_factor)))},
105 | {cInt(roundf(precision * (pos_r2_x + r[0] * angle_sin * scale_factor))), cInt(roundf(precision * (pos_r2_y - r[0] * angle_cos * scale_factor)))},
106 | {cInt(roundf(precision * (pos_r2_x - r[1] * angle_sin * scale_factor))), cInt(roundf(precision * (pos_r2_y + r[1] * angle_cos * scale_factor)))},
107 | },
108 | p[0],
109 | {p_left * p_bt, p_left * p_top, p_right * p_top, p_right * p_bt},
110 | x,
111 | y
112 | };
113 | polys.push_back(poly);
114 | }
115 | offset++;
116 | }
117 | }
118 | std::vector<nms::Polygon> poly_out = nms::merge_iou(polys, poly_ptr, w, h, iou_threshold, iou_threshold2);
119 | return polys2floats(poly_out);
120 | }
121 |
122 | }
123 |
124 | PYBIND11_MODULE(adaptor, m) {
125 |
126 | m.def("do_nms", &nms_adaptor::do_nms,
127 | "perform non-maxima suppression");
128 | }
129 |
130 |
--------------------------------------------------------------------------------
/nms/include/pybind11/buffer_info.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/buffer_info.h: Python buffer object interface
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../../../nms/include/pybind11/detail/common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 |
16 | /// Information record describing a Python buffer object
17 | struct buffer_info {
18 | void *ptr = nullptr; // Pointer to the underlying storage
19 | ssize_t itemsize = 0; // Size of individual items in bytes
20 | ssize_t size = 0; // Total number of entries
21 | std::string format; // For homogeneous buffers, this should be set to format_descriptor<T>::format()
22 | ssize_t ndim = 0; // Number of dimensions
23 | std::vector<ssize_t> shape; // Shape of the tensor (1 entry per dimension)
24 | std::vector<ssize_t> strides; // Number of entries between adjacent entries (for each per dimension)
25 |
26 | buffer_info() { }
27 |
28 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
29 | detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
30 | : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
31 | shape(std::move(shape_in)), strides(std::move(strides_in)) {
32 | if (ndim != (ssize_t) shape.size() || ndim != (ssize_t) strides.size())
33 | pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
34 | for (size_t i = 0; i < (size_t) ndim; ++i)
35 | size *= shape[i];
36 | }
37 |
38 | template <typename T>
39 | buffer_info(T *ptr, detail::any_container<ssize_t> shape_in, detail::any_container<ssize_t> strides_in)
40 | : buffer_info(private_ctr_tag(), ptr, sizeof(T), format_descriptor<T>::format(), static_cast<ssize_t>(shape_in->size()), std::move(shape_in), std::move(strides_in)) { }
41 |
42 | buffer_info(void *ptr, ssize_t itemsize, const std::string &format, ssize_t size)
43 | : buffer_info(ptr, itemsize, format, 1, {size}, {itemsize}) { }
44 |
45 | template <typename T>
46 | buffer_info(T *ptr, ssize_t size)
47 | : buffer_info(ptr, sizeof(T), format_descriptor<T>::format(), size) { }
48 |
49 | explicit buffer_info(Py_buffer *view, bool ownview = true)
50 | : buffer_info(view->buf, view->itemsize, view->format, view->ndim,
51 | {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) {
52 | this->view = view;
53 | this->ownview = ownview;
54 | }
55 |
56 | buffer_info(const buffer_info &) = delete;
57 | buffer_info& operator=(const buffer_info &) = delete;
58 |
59 | buffer_info(buffer_info &&other) {
60 | (*this) = std::move(other);
61 | }
62 |
63 | buffer_info& operator=(buffer_info &&rhs) {
64 | ptr = rhs.ptr;
65 | itemsize = rhs.itemsize;
66 | size = rhs.size;
67 | format = std::move(rhs.format);
68 | ndim = rhs.ndim;
69 | shape = std::move(rhs.shape);
70 | strides = std::move(rhs.strides);
71 | std::swap(view, rhs.view);
72 | std::swap(ownview, rhs.ownview);
73 | return *this;
74 | }
75 |
76 | ~buffer_info() {
77 | if (view && ownview) { PyBuffer_Release(view); delete view; }
78 | }
79 |
80 | private:
81 | struct private_ctr_tag { };
82 |
83 | buffer_info(private_ctr_tag, void *ptr, ssize_t itemsize, const std::string &format, ssize_t ndim,
84 | detail::any_container<ssize_t> &&shape_in, detail::any_container<ssize_t> &&strides_in)
85 | : buffer_info(ptr, itemsize, format, ndim, std::move(shape_in), std::move(strides_in)) { }
86 |
87 | Py_buffer *view = nullptr;
88 | bool ownview = false;
89 | };
90 |
91 | NAMESPACE_BEGIN(detail)
92 |
93 | template <typename T, typename SFINAE = void> struct compare_buffer_info {
94 | static bool compare(const buffer_info& b) {
95 | return b.format == format_descriptor<T>::format() && b.itemsize == (ssize_t) sizeof(T);
96 | }
97 | };
98 |
99 | template <typename T> struct compare_buffer_info<T, detail::enable_if_t<std::is_integral<T>::value>> {
100 | static bool compare(const buffer_info& b) {
101 | return (size_t) b.itemsize == sizeof(T) && (b.format == format_descriptor<T>::value ||
102 | ((sizeof(T) == sizeof(long)) && b.format == (std::is_unsigned<T>::value ? "L" : "l")) ||
103 | ((sizeof(T) == sizeof(size_t)) && b.format == (std::is_unsigned<T>::value ? "N" : "n")));
104 | }
105 | };
106 |
107 | NAMESPACE_END(detail)
108 | NAMESPACE_END(PYBIND11_NAMESPACE)
109 |
--------------------------------------------------------------------------------
/nms/include/pybind11/chrono.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/chrono.h: Transparent conversion between std::chrono and python's datetime
3 |
4 | Copyright (c) 2016 Trent Houliston and
5 | Wenzel Jakob
6 |
7 | All rights reserved. Use of this source code is governed by a
8 | BSD-style license that can be found in the LICENSE file.
9 | */
10 |
11 | #pragma once
12 |
13 | #include <cmath>
14 | #include <ctime>
15 | #include <chrono>
16 | #include <datetime.h>
17 | #include "../../../nms/include/pybind11/pybind11.h"
18 |
19 | // Backport the PyDateTime_DELTA functions from Python3.3 if required
20 | #ifndef PyDateTime_DELTA_GET_DAYS
21 | #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
22 | #endif
23 | #ifndef PyDateTime_DELTA_GET_SECONDS
24 | #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
25 | #endif
26 | #ifndef PyDateTime_DELTA_GET_MICROSECONDS
27 | #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
28 | #endif
29 |
30 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
31 | NAMESPACE_BEGIN(detail)
32 |
33 | template <typename type> class duration_caster {
34 | public:
35 | typedef typename type::rep rep;
36 | typedef typename type::period period;
37 |
38 | typedef std::chrono::duration<uint_fast32_t, std::ratio<86400>> days;
39 |
40 | bool load(handle src, bool) {
41 | using namespace std::chrono;
42 |
43 | // Lazy initialise the PyDateTime import
44 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
45 |
46 | if (!src) return false;
47 | // If invoked with datetime.delta object
48 | if (PyDelta_Check(src.ptr())) {
49 | value = type(duration_cast<duration<rep, period>>(
50 | days(PyDateTime_DELTA_GET_DAYS(src.ptr()))
51 | + seconds(PyDateTime_DELTA_GET_SECONDS(src.ptr()))
52 | + microseconds(PyDateTime_DELTA_GET_MICROSECONDS(src.ptr()))));
53 | return true;
54 | }
55 | // If invoked with a float we assume it is seconds and convert
56 | else if (PyFloat_Check(src.ptr())) {
57 | value = type(duration_cast<duration<rep, period>>(duration<double>(PyFloat_AsDouble(src.ptr()))));
58 | return true;
59 | }
60 | else return false;
61 | }
62 |
63 | // If this is a duration just return it back
64 | static const std::chrono::duration<rep, period>& get_duration(const std::chrono::duration<rep, period> &src) {
65 | return src;
66 | }
67 |
68 | // If this is a time_point get the time_since_epoch
69 | template <typename Clock> static std::chrono::duration<rep, period> get_duration(const std::chrono::time_point<Clock, std::chrono::duration<rep, period>> &src) {
70 | return src.time_since_epoch();
71 | }
72 |
73 | static handle cast(const type &src, return_value_policy /* policy */, handle /* parent */) {
74 | using namespace std::chrono;
75 |
76 | // Use overloaded function to get our duration from our source
77 | // Works out if it is a duration or time_point and get the duration
78 | auto d = get_duration(src);
79 |
80 | // Lazy initialise the PyDateTime import
81 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
82 |
83 | // Declare these special duration types so the conversions happen with the correct primitive types (int)
84 | using dd_t = duration<int, std::ratio<86400>>;
85 | using ss_t = duration<int, std::ratio<1>>;
86 | using us_t = duration<int, std::micro>;
87 |
88 | auto dd = duration_cast<dd_t>(d);
89 | auto subd = d - dd;
90 | auto ss = duration_cast<ss_t>(subd);
91 | auto us = duration_cast<us_t>(subd - ss);
92 | return PyDelta_FromDSU(dd.count(), ss.count(), us.count());
93 | }
94 |
95 | PYBIND11_TYPE_CASTER(type, _("datetime.timedelta"));
96 | };
97 |
98 | // This is for casting times on the system clock into datetime.datetime instances
99 | template <typename Duration> class type_caster<std::chrono::time_point<std::chrono::system_clock, Duration>> {
100 | public:
101 | typedef std::chrono::time_point<std::chrono::system_clock, Duration> type;
102 | bool load(handle src, bool) {
103 | using namespace std::chrono;
104 |
105 | // Lazy initialise the PyDateTime import
106 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
107 |
108 | if (!src) return false;
109 | if (PyDateTime_Check(src.ptr())) {
110 | std::tm cal;
111 | cal.tm_sec = PyDateTime_DATE_GET_SECOND(src.ptr());
112 | cal.tm_min = PyDateTime_DATE_GET_MINUTE(src.ptr());
113 | cal.tm_hour = PyDateTime_DATE_GET_HOUR(src.ptr());
114 | cal.tm_mday = PyDateTime_GET_DAY(src.ptr());
115 | cal.tm_mon = PyDateTime_GET_MONTH(src.ptr()) - 1;
116 | cal.tm_year = PyDateTime_GET_YEAR(src.ptr()) - 1900;
117 | cal.tm_isdst = -1;
118 |
119 | value = system_clock::from_time_t(std::mktime(&cal)) + microseconds(PyDateTime_DATE_GET_MICROSECOND(src.ptr()));
120 | return true;
121 | }
122 | else return false;
123 | }
124 |
125 | static handle cast(const std::chrono::time_point<std::chrono::system_clock, Duration> &src, return_value_policy /* policy */, handle /* parent */) {
126 | using namespace std::chrono;
127 |
128 | // Lazy initialise the PyDateTime import
129 | if (!PyDateTimeAPI) { PyDateTime_IMPORT; }
130 |
131 | std::time_t tt = system_clock::to_time_t(src);
132 | // this function uses static memory so it's best to copy it out asap just in case
133 | // otherwise other code that is using localtime may break this (not just python code)
134 | std::tm localtime = *std::localtime(&tt);
135 |
136 | // Declare these special duration types so the conversions happen with the correct primitive types (int)
137 | using us_t = duration<int, std::micro>;
138 |
139 | return PyDateTime_FromDateAndTime(localtime.tm_year + 1900,
140 | localtime.tm_mon + 1,
141 | localtime.tm_mday,
142 | localtime.tm_hour,
143 | localtime.tm_min,
144 | localtime.tm_sec,
145 | (duration_cast<us_t>(src.time_since_epoch() % seconds(1))).count());
146 | }
147 | PYBIND11_TYPE_CASTER(type, _("datetime.datetime"));
148 | };
149 |
150 | // Other clocks that are not the system clock are not measured as datetime.datetime objects
151 | // since they are not measured on calendar time. So instead we just make them timedeltas
152 | // Or if they have passed us a time as a float we convert that
153 | template <typename Clock, typename Duration> class type_caster<std::chrono::time_point<Clock, Duration>>
154 | : public duration_caster<std::chrono::time_point<Clock, Duration>> {
155 | };
156 |
157 | template <typename Rep, typename Period> class type_caster<std::chrono::duration<Rep, Period>>
158 | : public duration_caster<std::chrono::duration<Rep, Period>> {
159 | };
160 |
161 | NAMESPACE_END(detail)
162 | NAMESPACE_END(PYBIND11_NAMESPACE)
163 |
--------------------------------------------------------------------------------
/nms/include/pybind11/common.h:
--------------------------------------------------------------------------------
1 | #include "../../../nms/include/pybind11/detail/common.h"
2 | #warning "Including 'common.h' is deprecated. It will be removed in v3.0. Use 'pybind11.h'."
3 |
--------------------------------------------------------------------------------
/nms/include/pybind11/complex.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/complex.h: Complex number support
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <complex>
13 | #include "../../../nms/include/pybind11/pybind11.h"
14 |
15 | /// glibc defines I as a macro which breaks things, e.g., boost template names
16 | #ifdef I
17 | # undef I
18 | #endif
19 |
20 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
21 |
22 | template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
23 | static constexpr const char c = format_descriptor<T>::c;
24 | static constexpr const char value[3] = { 'Z', c, '\0' };
25 | static std::string format() { return std::string(value); }
26 | };
27 |
28 | #ifndef PYBIND11_CPP17
29 |
30 | template <typename T> constexpr const char format_descriptor<
31 | std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];
32 |
33 | #endif
34 |
35 | NAMESPACE_BEGIN(detail)
36 |
37 | template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
38 | static constexpr bool value = true;
39 | static constexpr int index = is_fmt_numeric<T>::index + 3;
40 | };
41 |
42 | template <typename T> class type_caster<std::complex<T>> {
43 | public:
44 | bool load(handle src, bool convert) {
45 | if (!src)
46 | return false;
47 | if (!convert && !PyComplex_Check(src.ptr()))
48 | return false;
49 | Py_complex result = PyComplex_AsCComplex(src.ptr());
50 | if (result.real == -1.0 && PyErr_Occurred()) {
51 | PyErr_Clear();
52 | return false;
53 | }
54 | value = std::complex<T>((T) result.real, (T) result.imag);
55 | return true;
56 | }
57 |
58 | static handle cast(const std::complex<T> &src, return_value_policy /* policy */, handle /* parent */) {
59 | return PyComplex_FromDoubles((double) src.real(), (double) src.imag());
60 | }
61 |
62 | PYBIND11_TYPE_CASTER(std::complex<T>, _("complex"));
63 | };
64 | NAMESPACE_END(detail)
65 | NAMESPACE_END(PYBIND11_NAMESPACE)
66 |
--------------------------------------------------------------------------------
/nms/include/pybind11/detail/descr.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/descr.h: Helper type for concatenating type signatures at compile time
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../../../../nms/include/pybind11/detail/common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 | NAMESPACE_BEGIN(detail)
16 |
17 | #if !defined(_MSC_VER)
18 | # define PYBIND11_DESCR_CONSTEXPR static constexpr
19 | #else
20 | # define PYBIND11_DESCR_CONSTEXPR const
21 | #endif
22 |
23 | /* Concatenate type signatures at compile time */
24 | template <size_t N, typename... Ts>
25 | struct descr {
26 | char text[N + 1];
27 |
28 | constexpr descr() : text{'\0'} { }
29 | constexpr descr(char const (&s)[N+1]) : descr(s, make_index_sequence<N>()) { }
30 |
31 | template <size_t... Is>
32 | constexpr descr(char const (&s)[N+1], index_sequence<Is...>) : text{s[Is]..., '\0'} { }
33 |
34 | template <typename... Chars>
35 | constexpr descr(char c, Chars... cs) : text{c, static_cast<char>(cs)..., '\0'} { }
36 |
37 | static constexpr std::array<const std::type_info *, sizeof...(Ts) + 1> types() {
38 | return {{&typeid(Ts)..., nullptr}};
39 | }
40 | };
41 |
42 | template <size_t N1, size_t N2, typename... Ts1, typename... Ts2, size_t... Is1, size_t... Is2>
43 | constexpr descr<N1 + N2, Ts1..., Ts2...> plus_impl(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b,
44 | index_sequence<Is1...>, index_sequence<Is2...>) {
45 | return {a.text[Is1]..., b.text[Is2]...};
46 | }
47 |
48 | template <size_t N1, size_t N2, typename... Ts1, typename... Ts2>
49 | constexpr descr<N1 + N2, Ts1..., Ts2...> operator+(const descr<N1, Ts1...> &a, const descr<N2, Ts2...> &b) {
50 | return plus_impl(a, b, make_index_sequence<N1>(), make_index_sequence<N2>());
51 | }
52 |
53 | template <size_t N>
54 | constexpr descr<N - 1> _(char const(&text)[N]) { return descr<N - 1>(text); }
55 | constexpr descr<0> _(char const(&)[1]) { return {}; }
56 |
57 | template <size_t Rem, size_t... Digits> struct int_to_str : int_to_str<Rem/10, Rem%10, Digits...> { };
58 | template <size_t... Digits> struct int_to_str<0, Digits...> {
59 | static constexpr auto digits = descr<sizeof...(Digits)>(('0' + Digits)...);
60 | };
61 |
62 | // Ternary description (like std::conditional)
63 | template <bool B, size_t N1, size_t N2>
64 | constexpr enable_if_t<B, descr<N1 - 1>> _(char const(&text1)[N1], char const(&)[N2]) {
65 | return _(text1);
66 | }
67 | template <bool B, size_t N1, size_t N2>
68 | constexpr enable_if_t<!B, descr<N2 - 1>> _(char const(&)[N1], char const(&text2)[N2]) {
69 | return _(text2);
70 | }
71 |
72 | template <bool B, typename T1, typename T2>
73 | constexpr enable_if_t<B, T1> _(const T1 &d, const T2 &) { return d; }
74 | template <bool B, typename T1, typename T2>
75 | constexpr enable_if_t<!B, T2> _(const T1 &, const T2 &d) { return d; }
76 |
77 | template <size_t Size> auto constexpr _() -> decltype(int_to_str<Size / 10, Size % 10>::digits) {
78 | return int_to_str<Size / 10, Size % 10>::digits;
79 | }
80 |
81 | template <typename Type> constexpr descr<1, Type> _() { return {'%'}; }
82 |
83 | constexpr descr<0> concat() { return {}; }
84 |
85 | template <size_t N, typename... Ts>
86 | constexpr descr<N, Ts...> concat(const descr<N, Ts...> &descr) { return descr; }
87 |
88 | template <size_t N, typename... Ts, typename... Args>
89 | constexpr auto concat(const descr<N, Ts...> &d, const Args &...args)
90 | -> decltype(std::declval<descr<N + 2, Ts...>>() + concat(args...)) {
91 | return d + _(", ") + concat(args...);
92 | }
93 |
94 | template <size_t N, typename... Ts>
95 | constexpr descr<N + 2, Ts...> type_descr(const descr<N, Ts...> &descr) {
96 | return _("{") + descr + _("}");
97 | }
98 |
99 | NAMESPACE_END(detail)
100 | NAMESPACE_END(PYBIND11_NAMESPACE)
101 |
--------------------------------------------------------------------------------
/nms/include/pybind11/detail/typeid.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/detail/typeid.h: Compiler-independent access to type identifiers
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <cstdio>
13 | #include <cstdlib>
14 |
15 | #if defined(__GNUG__)
16 | #include <cxxabi.h>
17 | #endif
18 |
19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
20 | NAMESPACE_BEGIN(detail)
21 | /// Erase all occurrences of a substring
22 | inline void erase_all(std::string &string, const std::string &search) {
23 | for (size_t pos = 0;;) {
24 | pos = string.find(search, pos);
25 | if (pos == std::string::npos) break;
26 | string.erase(pos, search.length());
27 | }
28 | }
29 |
30 | PYBIND11_NOINLINE inline void clean_type_id(std::string &name) {
31 | #if defined(__GNUG__)
32 | int status = 0;
33 | std::unique_ptr<char, void (*)(void *)> res {
34 | abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), std::free };
35 | if (status == 0)
36 | name = res.get();
37 | #else
38 | detail::erase_all(name, "class ");
39 | detail::erase_all(name, "struct ");
40 | detail::erase_all(name, "enum ");
41 | #endif
42 | detail::erase_all(name, "pybind11::");
43 | }
44 | NAMESPACE_END(detail)
45 |
46 | /// Return a string representation of a C++ type
47 | template <typename T> static std::string type_id() {
48 | std::string name(typeid(T).name());
49 | detail::clean_type_id(name);
50 | return name;
51 | }
52 |
53 | NAMESPACE_END(PYBIND11_NAMESPACE)
54 |
--------------------------------------------------------------------------------
/nms/include/pybind11/embed.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/embed.h: Support for embedding the interpreter
3 |
4 | Copyright (c) 2017 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../../../nms/include/pybind11/eval.h"
13 | #include "../../../nms/include/pybind11/pybind11.h"
14 |
15 | #if defined(PYPY_VERSION)
16 | # error Embedding the interpreter is not supported with PyPy
17 | #endif
18 |
19 | #if PY_MAJOR_VERSION >= 3
20 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
21 | extern "C" PyObject *pybind11_init_impl_##name() { \
22 | return pybind11_init_wrapper_##name(); \
23 | }
24 | #else
25 | # define PYBIND11_EMBEDDED_MODULE_IMPL(name) \
26 | extern "C" void pybind11_init_impl_##name() { \
27 | pybind11_init_wrapper_##name(); \
28 | }
29 | #endif
30 |
31 | /** \rst
32 | Add a new module to the table of builtins for the interpreter. Must be
33 | defined in global scope. The first macro parameter is the name of the
34 | module (without quotes). The second parameter is the variable which will
35 | be used as the interface to add functions and classes to the module.
36 |
37 | .. code-block:: cpp
38 |
39 | PYBIND11_EMBEDDED_MODULE(example, m) {
40 | // ... initialize functions and classes here
41 | m.def("foo", []() {
42 | return "Hello, World!";
43 | });
44 | }
45 | \endrst */
46 | #define PYBIND11_EMBEDDED_MODULE(name, variable) \
47 | static void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &); \
48 | static PyObject PYBIND11_CONCAT(*pybind11_init_wrapper_, name)() { \
49 | auto m = pybind11::module(PYBIND11_TOSTRING(name)); \
50 | try { \
51 | PYBIND11_CONCAT(pybind11_init_, name)(m); \
52 | return m.ptr(); \
53 | } catch (pybind11::error_already_set &e) { \
54 | PyErr_SetString(PyExc_ImportError, e.what()); \
55 | return nullptr; \
56 | } catch (const std::exception &e) { \
57 | PyErr_SetString(PyExc_ImportError, e.what()); \
58 | return nullptr; \
59 | } \
60 | } \
61 | PYBIND11_EMBEDDED_MODULE_IMPL(name) \
62 | pybind11::detail::embedded_module name(PYBIND11_TOSTRING(name), \
63 | PYBIND11_CONCAT(pybind11_init_impl_, name)); \
64 | void PYBIND11_CONCAT(pybind11_init_, name)(pybind11::module &variable)
65 |
66 |
67 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
68 | NAMESPACE_BEGIN(detail)
69 |
70 | /// Python 2.7/3.x compatible version of `PyImport_AppendInittab` and error checks.
71 | struct embedded_module {
72 | #if PY_MAJOR_VERSION >= 3
73 | using init_t = PyObject *(*)();
74 | #else
75 | using init_t = void (*)();
76 | #endif
77 | embedded_module(const char *name, init_t init) {
78 | if (Py_IsInitialized())
79 | pybind11_fail("Can't add new modules after the interpreter has been initialized");
80 |
81 | auto result = PyImport_AppendInittab(name, init);
82 | if (result == -1)
83 | pybind11_fail("Insufficient memory to add a new module");
84 | }
85 | };
86 |
87 | NAMESPACE_END(detail)
88 |
89 | /** \rst
90 | Initialize the Python interpreter. No other pybind11 or CPython API functions can be
91 | called before this is done; with the exception of `PYBIND11_EMBEDDED_MODULE`. The
92 | optional parameter can be used to skip the registration of signal handlers (see the
93 | Python documentation for details). Calling this function again after the interpreter
94 | has already been initialized is a fatal error.
95 | \endrst */
96 | inline void initialize_interpreter(bool init_signal_handlers = true) {
97 | if (Py_IsInitialized())
98 | pybind11_fail("The interpreter is already running");
99 |
100 | Py_InitializeEx(init_signal_handlers ? 1 : 0);
101 |
102 | // Make .py files in the working directory available by default
103 | module::import("sys").attr("path").cast().append(".");
104 | }
105 |
106 | /** \rst
107 | Shut down the Python interpreter. No pybind11 or CPython API functions can be called
108 | after this. In addition, pybind11 objects must not outlive the interpreter:
109 |
110 | .. code-block:: cpp
111 |
112 | { // BAD
113 | py::initialize_interpreter();
114 | auto hello = py::str("Hello, World!");
115 | py::finalize_interpreter();
116 | } // <-- BOOM, hello's destructor is called after interpreter shutdown
117 |
118 | { // GOOD
119 | py::initialize_interpreter();
120 | { // scoped
121 | auto hello = py::str("Hello, World!");
122 | } // <-- OK, hello is cleaned up properly
123 | py::finalize_interpreter();
124 | }
125 |
126 | { // BETTER
127 | py::scoped_interpreter guard{};
128 | auto hello = py::str("Hello, World!");
129 | }
130 |
131 | .. warning::
132 |
133 | The interpreter can be restarted by calling `initialize_interpreter` again.
134 | Modules created using pybind11 can be safely re-initialized. However, Python
135 | itself cannot completely unload binary extension modules and there are several
136 | caveats with regard to interpreter restarting. All the details can be found
137 | in the CPython documentation. In short, not all interpreter memory may be
138 | freed, either due to reference cycles or user-created global data.
139 |
140 | \endrst */
141 | inline void finalize_interpreter() {
142 | handle builtins(PyEval_GetBuiltins());
143 | const char *id = PYBIND11_INTERNALS_ID;
144 |
145 | // Get the internals pointer (without creating it if it doesn't exist). It's possible for the
146 | // internals to be created during Py_Finalize() (e.g. if a py::capsule calls `get_internals()`
147 | // during destruction), so we get the pointer-pointer here and check it after Py_Finalize().
148 | detail::internals **internals_ptr_ptr = detail::get_internals_pp();
149 | // It could also be stashed in builtins, so look there too:
150 | if (builtins.contains(id) && isinstance<capsule>(builtins[id]))
151 | internals_ptr_ptr = capsule(builtins[id]);
152 |
153 | Py_Finalize();
154 |
155 | if (internals_ptr_ptr) {
156 | delete *internals_ptr_ptr;
157 | *internals_ptr_ptr = nullptr;
158 | }
159 | }
160 |
161 | /** \rst
162 | Scope guard version of `initialize_interpreter` and `finalize_interpreter`.
163 | This is a move-only guard and only a single instance can exist.
164 |
165 | .. code-block:: cpp
166 |
167 | #include <pybind11/embed.h>
168 |
169 | int main() {
170 | py::scoped_interpreter guard{};
171 | py::print("Hello, World!");
172 | } // <-- interpreter shutdown
173 | \endrst */
174 | class scoped_interpreter {
175 | public:
176 | scoped_interpreter(bool init_signal_handlers = true) {
177 | initialize_interpreter(init_signal_handlers);
178 | }
179 |
180 | scoped_interpreter(const scoped_interpreter &) = delete;
181 | scoped_interpreter(scoped_interpreter &&other) noexcept { other.is_valid = false; }
182 | scoped_interpreter &operator=(const scoped_interpreter &) = delete;
183 | scoped_interpreter &operator=(scoped_interpreter &&) = delete;
184 |
185 | ~scoped_interpreter() {
186 | if (is_valid)
187 | finalize_interpreter();
188 | }
189 |
190 | private:
191 | bool is_valid = true;
192 | };
193 |
194 | NAMESPACE_END(PYBIND11_NAMESPACE)
195 |
--------------------------------------------------------------------------------
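
A minimal usage sketch for the embedding pieces above (illustrative, not part of this repository): it assumes pybind11 is installed normally as <pybind11/embed.h>, and the module name fast_calc is made up.

#include <pybind11/embed.h>
namespace py = pybind11;

// Must be registered before the interpreter starts, as enforced by the
// Py_IsInitialized() check in embedded_module above.
PYBIND11_EMBEDDED_MODULE(fast_calc, m) {
    m.def("add", [](int a, int b) { return a + b; });
}

int main() {
    py::scoped_interpreter guard{};  // RAII wrapper around initialize/finalize_interpreter
    int n = py::module::import("fast_calc").attr("add")(1, 2).cast<int>();
    return n == 3 ? 0 : 1;
}
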
/nms/include/pybind11/eval.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/eval.h: Support for evaluating Python expressions and statements
3 | from strings and files
4 |
5 | Copyright (c) 2016 Klemens Morgenstern and
6 | Wenzel Jakob
7 |
8 | All rights reserved. Use of this source code is governed by a
9 | BSD-style license that can be found in the LICENSE file.
10 | */
11 |
12 | #pragma once
13 |
14 | #include "../../../nms/include/pybind11/pybind11.h"
15 |
16 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
17 |
18 | enum eval_mode {
19 | /// Evaluate a string containing an isolated expression
20 | eval_expr,
21 |
22 | /// Evaluate a string containing a single statement. Returns \c none
23 | eval_single_statement,
24 |
25 | /// Evaluate a string containing a sequence of statements. Returns \c none
26 | eval_statements
27 | };
28 |
29 | template <eval_mode mode = eval_expr>
30 | object eval(str expr, object global = globals(), object local = object()) {
31 | if (!local)
32 | local = global;
33 |
34 | /* PyRun_String does not accept a PyObject / encoding specifier,
35 | this seems to be the only alternative */
36 | std::string buffer = "# -*- coding: utf-8 -*-\n" + (std::string) expr;
37 |
38 | int start;
39 | switch (mode) {
40 | case eval_expr: start = Py_eval_input; break;
41 | case eval_single_statement: start = Py_single_input; break;
42 | case eval_statements: start = Py_file_input; break;
43 | default: pybind11_fail("invalid evaluation mode");
44 | }
45 |
46 | PyObject *result = PyRun_String(buffer.c_str(), start, global.ptr(), local.ptr());
47 | if (!result)
48 | throw error_already_set();
49 | return reinterpret_steal<object>(result);
50 | }
51 |
52 | template <eval_mode mode = eval_expr, size_t N>
53 | object eval(const char (&s)[N], object global = globals(), object local = object()) {
54 | /* Support raw string literals by removing common leading whitespace */
55 | auto expr = (s[0] == '\n') ? str(module::import("textwrap").attr("dedent")(s))
56 | : str(s);
57 | return eval<mode>(expr, global, local);
58 | }
59 |
60 | inline void exec(str expr, object global = globals(), object local = object()) {
61 | eval<eval_statements>(expr, global, local);
62 | }
63 |
64 | template <size_t N>
65 | void exec(const char (&s)[N], object global = globals(), object local = object()) {
66 | eval<eval_statements>(s, global, local);
67 | }
68 |
69 | template <eval_mode mode = eval_statements>
70 | object eval_file(str fname, object global = globals(), object local = object()) {
71 | if (!local)
72 | local = global;
73 |
74 | int start;
75 | switch (mode) {
76 | case eval_expr: start = Py_eval_input; break;
77 | case eval_single_statement: start = Py_single_input; break;
78 | case eval_statements: start = Py_file_input; break;
79 | default: pybind11_fail("invalid evaluation mode");
80 | }
81 |
82 | int closeFile = 1;
83 | std::string fname_str = (std::string) fname;
84 | #if PY_VERSION_HEX >= 0x03040000
85 | FILE *f = _Py_fopen_obj(fname.ptr(), "r");
86 | #elif PY_VERSION_HEX >= 0x03000000
87 | FILE *f = _Py_fopen(fname.ptr(), "r");
88 | #else
89 | /* No unicode support in open() :( */
90 | auto fobj = reinterpret_steal<object>(PyFile_FromString(
91 | const_cast<char *>(fname_str.c_str()),
92 | const_cast<char *>("r")));
93 | FILE *f = nullptr;
94 | if (fobj)
95 | f = PyFile_AsFile(fobj.ptr());
96 | closeFile = 0;
97 | #endif
98 | if (!f) {
99 | PyErr_Clear();
100 | pybind11_fail("File \"" + fname_str + "\" could not be opened!");
101 | }
102 |
103 | #if PY_VERSION_HEX < 0x03000000 && defined(PYPY_VERSION)
104 | PyObject *result = PyRun_File(f, fname_str.c_str(), start, global.ptr(),
105 | local.ptr());
106 | (void) closeFile;
107 | #else
108 | PyObject *result = PyRun_FileEx(f, fname_str.c_str(), start, global.ptr(),
109 | local.ptr(), closeFile);
110 | #endif
111 |
112 | if (!result)
113 | throw error_already_set();
114 | return reinterpret_steal<object>(result);
115 | }
116 |
117 | NAMESPACE_END(PYBIND11_NAMESPACE)
118 |
--------------------------------------------------------------------------------
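
A minimal sketch of the eval helpers above, assuming an embedded interpreter (illustrative, not part of this repository):

#include <pybind11/embed.h>   // provides scoped_interpreter and pulls in eval.h
#include <iostream>
namespace py = pybind11;

int main() {
    py::scoped_interpreter guard{};
    py::exec("x = 6 * 7");               // eval<eval_statements>: returns none
    py::object y = py::eval("x + 1");    // eval<eval_expr>: returns the result object
    std::cout << y.cast<int>() << "\n";  // prints 43
}
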
/nms/include/pybind11/functional.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/functional.h: std::function<> support
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <functional>
13 | #include "../../../nms/include/pybind11/pybind11.h"
14 |
15 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
16 | NAMESPACE_BEGIN(detail)
17 |
18 | template <typename Return, typename... Args>
19 | struct type_caster<std::function<Return(Args...)>> {
20 | using type = std::function<Return(Args...)>;
21 | using retval_type = conditional_t<std::is_same<Return, void>::value, void_type, Return>;
22 | using function_type = Return (*) (Args...);
23 |
24 | public:
25 | bool load(handle src, bool convert) {
26 | if (src.is_none()) {
27 | // Defer accepting None to other overloads (if we aren't in convert mode):
28 | if (!convert) return false;
29 | return true;
30 | }
31 |
32 | if (!isinstance<function>(src))
33 | return false;
34 |
35 | auto func = reinterpret_borrow<function>(src);
36 |
37 | /*
38 | When passing a C++ function as an argument to another C++
39 | function via Python, every function call would normally involve
40 | a full C++ -> Python -> C++ roundtrip, which can be prohibitive.
41 | Here, we try to at least detect the case where the function is
42 | stateless (i.e. function pointer or lambda function without
43 | captured variables), in which case the roundtrip can be avoided.
44 | */
45 | if (auto cfunc = func.cpp_function()) {
46 | auto c = reinterpret_borrow<capsule>(PyCFunction_GET_SELF(cfunc.ptr()));
47 | auto rec = (function_record *) c;
48 |
49 | if (rec && rec->is_stateless &&
50 | same_type(typeid(function_type), *reinterpret_cast<const std::type_info *>(rec->data[1]))) {
51 | struct capture { function_type f; };
52 | value = ((capture *) &rec->data)->f;
53 | return true;
54 | }
55 | }
56 |
57 | value = [func](Args... args) -> Return {
58 | gil_scoped_acquire acq;
59 | object retval(func(std::forward<Args>(args)...));
60 | /* Visual studio 2015 parser issue: need parentheses around this expression */
61 | return (retval.template cast<Return>());
62 | };
63 | return true;
64 | }
65 |
66 | template <typename Func>
67 | static handle cast(Func &&f_, return_value_policy policy, handle /* parent */) {
68 | if (!f_)
69 | return none().inc_ref();
70 |
71 | auto result = f_.template target<function_type>();
72 | if (result)
73 | return cpp_function(*result, policy).release();
74 | else
75 | return cpp_function(std::forward(f_), policy).release();
76 | }
77 |
78 | PYBIND11_TYPE_CASTER(type, _("Callable[[") + concat(make_caster<Args>::name...) + _("], ")
79 | + make_caster<retval_type>::name + _("]"));
80 | };
81 |
82 | NAMESPACE_END(detail)
83 | NAMESPACE_END(PYBIND11_NAMESPACE)
84 |
--------------------------------------------------------------------------------
/nms/include/pybind11/iostream.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/iostream.h -- Tools to assist with redirecting cout and cerr to Python
3 |
4 | Copyright (c) 2017 Henry F. Schreiner
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include <streambuf>
13 | #include <ostream>
14 | #include <string>
15 | #include <memory>
16 | #include <iostream>
17 | #include "../../../nms/include/pybind11/pybind11.h"
18 |
19 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
20 | NAMESPACE_BEGIN(detail)
21 |
22 | // Buffer that writes to Python instead of C++
23 | class pythonbuf : public std::streambuf {
24 | private:
25 | using traits_type = std::streambuf::traits_type;
26 |
27 | char d_buffer[1024];
28 | object pywrite;
29 | object pyflush;
30 |
31 | int overflow(int c) {
32 | if (!traits_type::eq_int_type(c, traits_type::eof())) {
33 | *pptr() = traits_type::to_char_type(c);
34 | pbump(1);
35 | }
36 | return sync() ? traits_type::not_eof(c) : traits_type::eof();
37 | }
38 |
39 | int sync() {
40 | if (pbase() != pptr()) {
41 | // This subtraction cannot be negative, so dropping the sign
42 | str line(pbase(), static_cast<size_t>(pptr() - pbase()));
43 |
44 | pywrite(line);
45 | pyflush();
46 |
47 | setp(pbase(), epptr());
48 | }
49 | return 0;
50 | }
51 |
52 | public:
53 | pythonbuf(object pyostream)
54 | : pywrite(pyostream.attr("write")),
55 | pyflush(pyostream.attr("flush")) {
56 | setp(d_buffer, d_buffer + sizeof(d_buffer) - 1);
57 | }
58 |
59 | /// Sync before destroy
60 | ~pythonbuf() {
61 | sync();
62 | }
63 | };
64 |
65 | NAMESPACE_END(detail)
66 |
67 |
68 | /** \rst
69 | This is a move-only guard that redirects output.
70 |
71 | .. code-block:: cpp
72 |
73 | #include <pybind11/iostream.h>
74 |
75 | ...
76 |
77 | {
78 | py::scoped_ostream_redirect output;
79 | std::cout << "Hello, World!"; // Python stdout
80 | } // <-- return std::cout to normal
81 |
82 | You can explicitly pass the C++ stream and the Python object,
83 | for example to guard stderr instead.
84 |
85 | .. code-block:: cpp
86 |
87 | {
88 | py::scoped_ostream_redirect output{std::cerr, py::module::import("sys").attr("stderr")};
89 | std::cerr << "Hello, World!";
90 | }
91 | \endrst */
92 | class scoped_ostream_redirect {
93 | protected:
94 | std::streambuf *old;
95 | std::ostream &costream;
96 | detail::pythonbuf buffer;
97 |
98 | public:
99 | scoped_ostream_redirect(
100 | std::ostream &costream = std::cout,
101 | object pyostream = module::import("sys").attr("stdout"))
102 | : costream(costream), buffer(pyostream) {
103 | old = costream.rdbuf(&buffer);
104 | }
105 |
106 | ~scoped_ostream_redirect() {
107 | costream.rdbuf(old);
108 | }
109 |
110 | scoped_ostream_redirect(const scoped_ostream_redirect &) = delete;
111 | scoped_ostream_redirect(scoped_ostream_redirect &&other) = default;
112 | scoped_ostream_redirect &operator=(const scoped_ostream_redirect &) = delete;
113 | scoped_ostream_redirect &operator=(scoped_ostream_redirect &&) = delete;
114 | };
115 |
116 |
117 | /** \rst
118 | Like `scoped_ostream_redirect`, but redirects cerr by default. This class
119 | is provided primarily to make ``py::call_guard`` easier to use.
120 |
121 | .. code-block:: cpp
122 |
123 | m.def("noisy_func", &noisy_func,
124 | py::call_guard<py::scoped_ostream_redirect,
125 | py::scoped_estream_redirect>());
126 |
127 | \endrst */
128 | class scoped_estream_redirect : public scoped_ostream_redirect {
129 | public:
130 | scoped_estream_redirect(
131 | std::ostream &costream = std::cerr,
132 | object pyostream = module::import("sys").attr("stderr"))
133 | : scoped_ostream_redirect(costream,pyostream) {}
134 | };
135 |
136 |
137 | NAMESPACE_BEGIN(detail)
138 |
139 | // Class to redirect output as a context manager. C++ backend.
140 | class OstreamRedirect {
141 | bool do_stdout_;
142 | bool do_stderr_;
143 | std::unique_ptr<scoped_ostream_redirect> redirect_stdout;
144 | std::unique_ptr<scoped_estream_redirect> redirect_stderr;
145 |
146 | public:
147 | OstreamRedirect(bool do_stdout = true, bool do_stderr = true)
148 | : do_stdout_(do_stdout), do_stderr_(do_stderr) {}
149 |
150 | void enter() {
151 | if (do_stdout_)
152 | redirect_stdout.reset(new scoped_ostream_redirect());
153 | if (do_stderr_)
154 | redirect_stderr.reset(new scoped_estream_redirect());
155 | }
156 |
157 | void exit() {
158 | redirect_stdout.reset();
159 | redirect_stderr.reset();
160 | }
161 | };
162 |
163 | NAMESPACE_END(detail)
164 |
165 | /** \rst
166 | This is a helper function to add a C++ redirect context manager to Python
167 | instead of using a C++ guard. To use it, add the following to your binding code:
168 |
169 | .. code-block:: cpp
170 |
171 | #include <pybind11/iostream.h>
172 |
173 | ...
174 |
175 | py::add_ostream_redirect(m, "ostream_redirect");
176 |
177 | You now have a Python context manager that redirects your output:
178 |
179 | .. code-block:: python
180 |
181 | with m.ostream_redirect():
182 | m.print_to_cout_function()
183 |
184 | This manager can optionally be told which streams to operate on:
185 |
186 | .. code-block:: python
187 |
188 | with m.ostream_redirect(stdout=True, stderr=True):
189 | m.noisy_function_with_error_printing()
190 |
191 | \endrst */
192 | inline class_<detail::OstreamRedirect> add_ostream_redirect(module m, std::string name = "ostream_redirect") {
193 | return class_<detail::OstreamRedirect>(m, name.c_str(), module_local())
194 | .def(init<bool,bool>(), arg("stdout")=true, arg("stderr")=true)
195 | .def("__enter__", &detail::OstreamRedirect::enter)
196 | .def("__exit__", [](detail::OstreamRedirect &self, args) { self.exit(); });
197 | }
198 |
199 | NAMESPACE_END(PYBIND11_NAMESPACE)
200 |
--------------------------------------------------------------------------------
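
A minimal sketch tying the pieces above together: a per-call guard via py::call_guard plus the Python-side context manager installed by add_ostream_redirect (module and function names are illustrative, not part of this repository):

#include <pybind11/pybind11.h>
#include <pybind11/iostream.h>
#include <iostream>
namespace py = pybind11;

void noisy_func() { std::cout << "hello from C++\n"; }

PYBIND11_MODULE(example, m) {
    // Redirect std::cout to Python's sys.stdout for the duration of each call.
    m.def("noisy_func", &noisy_func,
          py::call_guard<py::scoped_ostream_redirect>());
    // Expose `with example.ostream_redirect(): ...` on the Python side.
    py::add_ostream_redirect(m, "ostream_redirect");
}
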
/nms/include/pybind11/operators.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/operators.h: Metatemplates for operator overloading
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../../../nms/include/pybind11/pybind11.h"
13 |
14 | #if defined(__clang__) && !defined(__INTEL_COMPILER)
15 | # pragma clang diagnostic ignored "-Wunsequenced" // multiple unsequenced modifications to 'self' (when using def(py::self OP Type()))
16 | #elif defined(_MSC_VER)
17 | # pragma warning(push)
18 | # pragma warning(disable: 4127) // warning C4127: Conditional expression is constant
19 | #endif
20 |
21 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
22 | NAMESPACE_BEGIN(detail)
23 |
24 | /// Enumeration with all supported operator types
25 | enum op_id : int {
26 | op_add, op_sub, op_mul, op_div, op_mod, op_divmod, op_pow, op_lshift,
27 | op_rshift, op_and, op_xor, op_or, op_neg, op_pos, op_abs, op_invert,
28 | op_int, op_long, op_float, op_str, op_cmp, op_gt, op_ge, op_lt, op_le,
29 | op_eq, op_ne, op_iadd, op_isub, op_imul, op_idiv, op_imod, op_ilshift,
30 | op_irshift, op_iand, op_ixor, op_ior, op_complex, op_bool, op_nonzero,
31 | op_repr, op_truediv, op_itruediv, op_hash
32 | };
33 |
34 | enum op_type : int {
35 | op_l, /* base type on left */
36 | op_r, /* base type on right */
37 | op_u /* unary operator */
38 | };
39 |
40 | struct self_t { };
41 | static const self_t self = self_t();
42 |
43 | /// Type for an unused type slot
44 | struct undefined_t { };
45 |
46 | /// Don't warn about an unused variable
47 | inline self_t __self() { return self; }
48 |
49 | /// Base template of operator implementations
50 | template <op_id, op_type, typename B, typename L, typename R> struct op_impl { };
51 |
52 | /// Operator implementation generator
53 | template <op_id id, op_type ot, typename L, typename R> struct op_ {
54 | template <typename Class, typename... Extra> void execute(Class &cl, const Extra&... extra) const {
55 | using Base = typename Class::type;
56 | using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
57 | using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
58 | using op = op_impl<id, ot, Base, L_type, R_type>;
59 | cl.def(op::name(), &op::execute, is_operator(), extra...);
60 | #if PY_MAJOR_VERSION < 3
61 | if (id == op_truediv || id == op_itruediv)
62 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
63 | &op::execute, is_operator(), extra...);
64 | #endif
65 | }
66 | template <typename Class, typename... Extra> void execute_cast(Class &cl, const Extra&... extra) const {
67 | using Base = typename Class::type;
68 | using L_type = conditional_t<std::is_same<L, self_t>::value, Base, L>;
69 | using R_type = conditional_t<std::is_same<R, self_t>::value, Base, R>;
70 | using op = op_impl<id, ot, Base, L_type, R_type>;
71 | cl.def(op::name(), &op::execute_cast, is_operator(), extra...);
72 | #if PY_MAJOR_VERSION < 3
73 | if (id == op_truediv || id == op_itruediv)
74 | cl.def(id == op_itruediv ? "__idiv__" : ot == op_l ? "__div__" : "__rdiv__",
75 | &op::execute, is_operator(), extra...);
76 | #endif
77 | }
78 | };
79 |
80 | #define PYBIND11_BINARY_OPERATOR(id, rid, op, expr) \
81 | template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
82 | static char const* name() { return "__" #id "__"; } \
83 | static auto execute(const L &l, const R &r) -> decltype(expr) { return (expr); } \
84 | static B execute_cast(const L &l, const R &r) { return B(expr); } \
85 | }; \
86 | template <typename B, typename L, typename R> struct op_impl<op_##id, op_r, B, L, R> { \
87 | static char const* name() { return "__" #rid "__"; } \
88 | static auto execute(const R &r, const L &l) -> decltype(expr) { return (expr); } \
89 | static B execute_cast(const R &r, const L &l) { return B(expr); } \
90 | }; \
91 | inline op_<op_##id, op_l, self_t, self_t> op(const self_t &, const self_t &) { \
92 | return op_<op_##id, op_l, self_t, self_t>(); \
93 | } \
94 | template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
95 | return op_<op_##id, op_l, self_t, T>(); \
96 | } \
97 | template <typename T> op_<op_##id, op_r, T, self_t> op(const T &, const self_t &) { \
98 | return op_<op_##id, op_r, T, self_t>(); \
99 | }
100 |
101 | #define PYBIND11_INPLACE_OPERATOR(id, op, expr) \
102 | template <typename B, typename L, typename R> struct op_impl<op_##id, op_l, B, L, R> { \
103 | static char const* name() { return "__" #id "__"; } \
104 | static auto execute(L &l, const R &r) -> decltype(expr) { return expr; } \
105 | static B execute_cast(L &l, const R &r) { return B(expr); } \
106 | }; \
107 | template <typename T> op_<op_##id, op_l, self_t, T> op(const self_t &, const T &) { \
108 | return op_<op_##id, op_l, self_t, T>(); \
109 | }
110 |
111 | #define PYBIND11_UNARY_OPERATOR(id, op, expr) \
112 | template <typename B, typename L> struct op_impl<op_##id, op_u, B, L, undefined_t> { \
113 | static char const* name() { return "__" #id "__"; } \
114 | static auto execute(const L &l) -> decltype(expr) { return expr; } \
115 | static B execute_cast(const L &l) { return B(expr); } \
116 | }; \
117 | inline op_<op_##id, op_u, self_t, undefined_t> op(const self_t &) { \
118 | return op_<op_##id, op_u, self_t, undefined_t>(); \
119 | }
120 |
121 | PYBIND11_BINARY_OPERATOR(sub, rsub, operator-, l - r)
122 | PYBIND11_BINARY_OPERATOR(add, radd, operator+, l + r)
123 | PYBIND11_BINARY_OPERATOR(mul, rmul, operator*, l * r)
124 | PYBIND11_BINARY_OPERATOR(truediv, rtruediv, operator/, l / r)
125 | PYBIND11_BINARY_OPERATOR(mod, rmod, operator%, l % r)
126 | PYBIND11_BINARY_OPERATOR(lshift, rlshift, operator<<, l << r)
127 | PYBIND11_BINARY_OPERATOR(rshift, rrshift, operator>>, l >> r)
128 | PYBIND11_BINARY_OPERATOR(and, rand, operator&, l & r)
129 | PYBIND11_BINARY_OPERATOR(xor, rxor, operator^, l ^ r)
130 | PYBIND11_BINARY_OPERATOR(eq, eq, operator==, l == r)
131 | PYBIND11_BINARY_OPERATOR(ne, ne, operator!=, l != r)
132 | PYBIND11_BINARY_OPERATOR(or, ror, operator|, l | r)
133 | PYBIND11_BINARY_OPERATOR(gt, lt, operator>, l > r)
134 | PYBIND11_BINARY_OPERATOR(ge, le, operator>=, l >= r)
135 | PYBIND11_BINARY_OPERATOR(lt, gt, operator<, l < r)
136 | PYBIND11_BINARY_OPERATOR(le, ge, operator<=, l <= r)
137 | //PYBIND11_BINARY_OPERATOR(pow, rpow, pow, std::pow(l, r))
138 | PYBIND11_INPLACE_OPERATOR(iadd, operator+=, l += r)
139 | PYBIND11_INPLACE_OPERATOR(isub, operator-=, l -= r)
140 | PYBIND11_INPLACE_OPERATOR(imul, operator*=, l *= r)
141 | PYBIND11_INPLACE_OPERATOR(itruediv, operator/=, l /= r)
142 | PYBIND11_INPLACE_OPERATOR(imod, operator%=, l %= r)
143 | PYBIND11_INPLACE_OPERATOR(ilshift, operator<<=, l <<= r)
144 | PYBIND11_INPLACE_OPERATOR(irshift, operator>>=, l >>= r)
145 | PYBIND11_INPLACE_OPERATOR(iand, operator&=, l &= r)
146 | PYBIND11_INPLACE_OPERATOR(ixor, operator^=, l ^= r)
147 | PYBIND11_INPLACE_OPERATOR(ior, operator|=, l |= r)
148 | PYBIND11_UNARY_OPERATOR(neg, operator-, -l)
149 | PYBIND11_UNARY_OPERATOR(pos, operator+, +l)
150 | PYBIND11_UNARY_OPERATOR(abs, abs, std::abs(l))
151 | PYBIND11_UNARY_OPERATOR(hash, hash, std::hash<L>()(l))
152 | PYBIND11_UNARY_OPERATOR(invert, operator~, (~l))
153 | PYBIND11_UNARY_OPERATOR(bool, operator!, !!l)
154 | PYBIND11_UNARY_OPERATOR(int, int_, (int) l)
155 | PYBIND11_UNARY_OPERATOR(float, float_, (double) l)
156 |
157 | #undef PYBIND11_BINARY_OPERATOR
158 | #undef PYBIND11_INPLACE_OPERATOR
159 | #undef PYBIND11_UNARY_OPERATOR
160 | NAMESPACE_END(detail)
161 |
162 | using detail::self;
163 |
164 | NAMESPACE_END(PYBIND11_NAMESPACE)
165 |
166 | #if defined(_MSC_VER)
167 | # pragma warning(pop)
168 | #endif
169 |
--------------------------------------------------------------------------------
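
A minimal sketch of the operator metatemplates above in use (the Vec2 type and module name are illustrative, not part of this repository):

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
namespace py = pybind11;

struct Vec2 {
    float x, y;
    Vec2(float x_, float y_) : x(x_), y(y_) {}
    Vec2 operator+(const Vec2 &o) const { return Vec2(x + o.x, y + o.y); }
    Vec2 operator*(float s) const { return Vec2(x * s, y * s); }
};

PYBIND11_MODULE(example, m) {
    py::class_<Vec2>(m, "Vec2")
        .def(py::init<float, float>())
        .def(py::self + py::self)   // expands to __add__ via op_impl<op_add, op_l, ...>
        .def(py::self * float());   // expands to __mul__
}
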
/nms/include/pybind11/options.h:
--------------------------------------------------------------------------------
1 | /*
2 | pybind11/options.h: global settings that are configurable at runtime.
3 |
4 | Copyright (c) 2016 Wenzel Jakob
5 |
6 | All rights reserved. Use of this source code is governed by a
7 | BSD-style license that can be found in the LICENSE file.
8 | */
9 |
10 | #pragma once
11 |
12 | #include "../../../nms/include/pybind11/detail/common.h"
13 |
14 | NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
15 |
16 | class options {
17 | public:
18 |
19 | // Default RAII constructor, which leaves settings as they currently are.
20 | options() : previous_state(global_state()) {}
21 |
22 | // Class is non-copyable.
23 | options(const options&) = delete;
24 | options& operator=(const options&) = delete;
25 |
26 | // Destructor, which restores settings that were in effect before.
27 | ~options() {
28 | global_state() = previous_state;
29 | }
30 |
31 | // Setter methods (affect the global state):
32 |
33 | options& disable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = false; return *this; }
34 |
35 | options& enable_user_defined_docstrings() & { global_state().show_user_defined_docstrings = true; return *this; }
36 |
37 | options& disable_function_signatures() & { global_state().show_function_signatures = false; return *this; }
38 |
39 | options& enable_function_signatures() & { global_state().show_function_signatures = true; return *this; }
40 |
41 | // Getter methods (return the global state):
42 |
43 | static bool show_user_defined_docstrings() { return global_state().show_user_defined_docstrings; }
44 |
45 | static bool show_function_signatures() { return global_state().show_function_signatures; }
46 |
47 | // This type is not meant to be allocated on the heap.
48 | void* operator new(size_t) = delete;
49 |
50 | private:
51 |
52 | struct state {
53 | bool show_user_defined_docstrings = true; //< Include user-supplied texts in docstrings.
54 | bool show_function_signatures = true; //< Include auto-generated function signatures in docstrings.
55 | };
56 |
57 | static state &global_state() {
58 | static state instance;
59 | return instance;
60 | }
61 |
62 | state previous_state;
63 | };
64 |
65 | NAMESPACE_END(PYBIND11_NAMESPACE)
66 |
--------------------------------------------------------------------------------
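
A minimal sketch of the options guard above: it is a stack-only RAII object, so the setters take effect for the rest of the enclosing scope and revert on destruction (module name is illustrative, not part of this repository):

#include <pybind11/pybind11.h>
namespace py = pybind11;

PYBIND11_MODULE(example, m) {
    py::options options;
    options.disable_function_signatures();  // docstrings keep only the user-supplied text
    m.def("add", [](int a, int b) { return a + b; },
          "Adds two integers.");
}   // ~options() restores the previous global state
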
/nms/nms.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "../nms/include/clipper/clipper.hpp"
4 |
5 | namespace nms {
6 |
7 | namespace cl = ClipperLib;
8 |
9 | struct Polygon {
10 | cl::Path poly;
11 | float score;
12 | float probs[4];
13 | int x;
14 | int y;
15 | };
16 |
17 | float paths_area(const ClipperLib::Paths &ps) {
18 | float area = 0;
19 | for (auto &&p: ps)
20 | area += cl::Area(p);
21 | return area;
22 | }
23 |
24 | float poly_iou(const Polygon &a, const Polygon &b) {
25 | cl::Clipper clpr;
26 | clpr.AddPath(a.poly, cl::ptSubject, true);
27 | clpr.AddPath(b.poly, cl::ptClip, true);
28 |
29 | cl::Paths inter, uni;
30 | clpr.Execute(cl::ctIntersection, inter, cl::pftEvenOdd);
31 | clpr.Execute(cl::ctUnion, uni, cl::pftEvenOdd);
32 |
33 | auto inter_area = paths_area(inter),
34 | uni_area = paths_area(uni);
35 | return std::abs(inter_area) / std::max(std::abs(uni_area), 1.0f);
36 | }
37 |
38 | bool should_merge(const Polygon &a, const Polygon &b, float iou_threshold) {
39 | return poly_iou(a, b) > iou_threshold;
40 | }
41 |
42 | /**
43 | * Incrementally merge polygons
44 | */
45 | class PolyMerger {
46 | public:
47 | PolyMerger(): score(0), nr_polys(0) {
48 | memset(data, 0, sizeof(data));
49 | memset(probs, 0, 4 * sizeof(float));
50 | }
51 |
52 | /**
53 | * Add a new polygon to be merged.
54 | */
55 | void add(const Polygon &p) {
56 |
57 | auto &poly = p.poly;
58 | data[0] += poly[0].X * p.probs[0];
59 | data[1] += poly[0].Y * p.probs[3];
60 |
61 | data[2] += poly[1].X * p.probs[0];
62 | data[3] += poly[1].Y * p.probs[1];
63 |
64 | data[4] += poly[2].X * p.probs[2];
65 | data[5] += poly[2].Y * p.probs[1];
66 |
67 | data[6] += poly[3].X * p.probs[2];
68 | data[7] += poly[3].Y * p.probs[3];
69 |
70 | score += p.score;
71 |
72 | probs[0] += p.probs[0];
73 | probs[1] += p.probs[1];
74 | probs[2] += p.probs[2];
75 | probs[3] += p.probs[3];
76 |
77 | nr_polys += 1;
78 | }
79 |
80 | Polygon get() const {
81 | Polygon p;
82 |
83 | auto &poly = p.poly;
84 | poly.resize(4);
85 |
86 | poly[0].X = data[0] / probs[0];
87 | poly[0].Y = data[1] / probs[3];
88 | poly[1].X = data[2] / probs[0];
89 | poly[1].Y = data[3] / probs[1];
90 | poly[2].X = data[4] / probs[2];
91 | poly[2].Y = data[5] / probs[1];
92 | poly[3].X = data[6] / probs[2];
93 | poly[3].Y = data[7] / probs[3];
94 |
95 | assert(score > 0);
96 | p.score = score;
97 | p.probs[0] = probs[0];
98 | p.probs[1] = probs[1];
99 | p.probs[2] = probs[2];
100 | p.probs[3] = probs[3];
101 |
102 | return p;
103 | }
104 |
105 | private:
106 | std::int64_t data[8];
107 | float score;
108 | float probs[4];
109 | std::int32_t nr_polys;
110 | };
111 |
112 |
113 | /**
114 | * The standard NMS algorithm.
115 | */
116 | std::vector<Polygon> standard_nms(std::vector<Polygon> &polys, float iou_threshold) {
117 | size_t n = polys.size();
118 | if (n == 0)
119 | return {};
120 | std::vector<size_t> indices(n);
121 | std::iota(std::begin(indices), std::end(indices), 0);
122 | std::sort(std::begin(indices), std::end(indices), [&](size_t i, size_t j) { return polys[i].score > polys[j].score; });
123 |
124 | std::vector<size_t> keep;
125 | while (indices.size()) {
126 | size_t p = 0, cur = indices[0];
127 | keep.emplace_back(cur);
128 | for (size_t i = 1; i < indices.size(); i ++) {
129 | if (!should_merge(polys[cur], polys[indices[i]], iou_threshold)) {
130 | indices[p++] = indices[i];
131 | } else {
132 | PolyMerger merger;
133 | merger.add(polys[indices[i]]);
134 | merger.add(polys[cur]);
135 | polys[cur] = merger.get();
136 | }
137 | }
138 | indices.resize(p);
139 | }
140 |
141 | std::vector<Polygon> ret;
142 | for (auto &&i: keep) {
143 | ret.emplace_back(polys[i]);
144 | }
145 | return ret;
146 | }
147 |
148 |
149 | std::vector<Polygon>
150 | merge_iou(std::vector<Polygon>& polys_in, int* poly_ptr, int w, int h, float iou_threshold1, float iou_threshold2) {
151 |
152 | // first pass
153 | std::vector<Polygon> polys;
154 | for (size_t i = 0; i < polys_in.size(); i ++) {
155 | auto poly = polys_in[i];
156 |
157 | if (polys.size()) {
158 | // merge with the last one
159 | auto &bpoly = polys.back();
160 | if (should_merge(poly, bpoly, iou_threshold1)) {
161 | PolyMerger merger;
162 | merger.add(bpoly);
163 | merger.add(poly);
164 | bpoly = merger.get();
165 | poly_ptr[poly.y * w + poly.x] = (polys.size() - 1);
166 | continue;
167 | } else {
168 | if(poly.y > 0){
169 | int idx = poly_ptr[(poly.y -1)* w + poly.x];
170 | if(idx >= 0){
171 | auto &cpoly = polys[idx];
172 | if (should_merge(poly, cpoly, iou_threshold1)) {
173 | PolyMerger merger;
174 | merger.add(cpoly);
175 | merger.add(poly);
176 | cpoly = merger.get();
177 | poly_ptr[poly.y * w + poly.x] = idx;
178 | continue;
179 | }
180 | if(poly.x > 0){
181 | idx = poly_ptr[(poly.y -1)* w + poly.x - 1];
182 | if(idx >= 0){
183 | auto &cpoly = polys[idx];
184 | if (should_merge(poly, cpoly, iou_threshold1)) {
185 | PolyMerger merger;
186 | merger.add(cpoly);
187 | merger.add(poly);
188 | cpoly = merger.get();
189 | poly_ptr[poly.y * w + poly.x] = idx;
190 | continue;
191 | }
192 | }
193 | }
194 | idx = poly_ptr[(poly.y -1)* w + poly.x + 1];
195 | if(idx >= 0){
196 | auto &cpoly = polys[idx];
197 | if (should_merge(poly, cpoly, iou_threshold1)) {
198 | PolyMerger merger;
199 | merger.add(cpoly);
200 | merger.add(poly);
201 | cpoly = merger.get();
202 | poly_ptr[poly.y * w + poly.x] = idx;
203 | continue;
204 | }
205 | }
206 | }
207 | }
208 | polys.emplace_back(poly);
209 | }
210 | }
211 | polys.emplace_back(poly);
212 | poly_ptr[poly.y * w + poly.x] = (polys.size() - 1);
213 | }
214 | return standard_nms(polys, iou_threshold2);
215 | }
216 | }
217 |
--------------------------------------------------------------------------------
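
A minimal sketch exercising poly_iou/standard_nms above on two heavily overlapping axis-aligned quads (all values are illustrative, not part of this repository):

#include <iostream>
#include "nms.h"

static nms::Polygon make_box(int x0, int y0, int x1, int y1, float score) {
    nms::Polygon p;
    p.poly << ClipperLib::IntPoint(x0, y0) << ClipperLib::IntPoint(x1, y0)
           << ClipperLib::IntPoint(x1, y1) << ClipperLib::IntPoint(x0, y1);
    p.score = score;
    for (float &pr : p.probs) pr = score;  // corner weights consumed by PolyMerger
    p.x = p.y = 0;
    return p;
}

int main() {
    std::vector<nms::Polygon> polys{make_box(0, 0, 10, 10, 0.9f),
                                    make_box(1, 1, 11, 11, 0.8f)};
    auto kept = nms::standard_nms(polys, 0.3f);  // IoU of about 0.68 > 0.3, so the two boxes merge
    std::cout << kept.size() << "\n";            // prints 1
}
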
/rroi_align/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/__init__.py
--------------------------------------------------------------------------------
/rroi_align/build.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from torch.utils.ffi import create_extension  # removed in PyTorch 1.0; this build script requires torch <= 0.4
5 |
6 |
7 | sources = ['src/roi_pooling.c']
8 | headers = ['src/roi_pooling.h']
9 | extra_objects = []
10 | defines = []
11 | with_cuda = False
12 |
13 | this_file = os.path.dirname(os.path.realpath(__file__))
14 | print(this_file)
15 |
16 | if torch.cuda.is_available():
17 | print('Including CUDA code.')
18 | sources += ['src/rroi_align_cuda.c']
19 | headers += ['src/rroi_align_cuda.h']
20 | defines += [('WITH_CUDA', None)]
21 | with_cuda = True
22 | extra_objects = ['src/rroi_align.cu.o']
23 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
24 |
25 | # build the extension here
26 | ffi = create_extension(
27 | '_ext.rroi_align',
28 | headers=headers,
29 | sources=sources,
30 | define_macros=defines,
31 | relative_to=__file__,
32 | with_cuda=with_cuda,
33 | extra_objects=extra_objects
34 | )
35 |
36 | if __name__ == '__main__':
37 | ffi.build()
38 |
--------------------------------------------------------------------------------
/rroi_align/data/grad.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/data/grad.jpg
--------------------------------------------------------------------------------
/rroi_align/data/grad_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/data/grad_img.jpg
--------------------------------------------------------------------------------
/rroi_align/data/res0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/data/res0.jpg
--------------------------------------------------------------------------------
/rroi_align/data/res1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/data/res1.jpg
--------------------------------------------------------------------------------
/rroi_align/data/res2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/data/res2.jpg
--------------------------------------------------------------------------------
/rroi_align/data/timg.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/data/timg.jpeg
--------------------------------------------------------------------------------
/rroi_align/functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/functions/__init__.py
--------------------------------------------------------------------------------
/rroi_align/functions/rroi_align.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Function
3 | from .._ext import rroi_align
4 | import pdb
5 |
6 | class RRoiAlignFunction(Function):
7 | def __init__(ctx, pooled_height, pooled_width, spatial_scale):
8 | ctx.pooled_width = pooled_width
9 | ctx.pooled_height = pooled_height
10 | ctx.spatial_scale = spatial_scale
11 | ctx.feature_size = None
12 |
13 | def forward(ctx, features, rois):
14 | ctx.feature_size = features.size()
15 | batch_size, num_channels, data_height, data_width = ctx.feature_size
16 | num_rois = rois.size(0)
17 | output = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().float()
18 | # ctx.argmax = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().int()
19 | ctx.idx_x = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().float() # all buffers are float tensors
20 | ctx.idx_y = features.new(num_rois, num_channels, ctx.pooled_height, ctx.pooled_width).zero_().float()
21 | ctx.rois = rois
22 | if not features.is_cuda:
23 | _features = features.permute(0, 2, 3, 1)
24 | roi_pooling.roi_pooling_forward(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,  # BUG: no `roi_pooling` module is imported here, so this CPU fallback cannot run; only the CUDA path works
25 | _features, rois, output)
26 | else:
27 | rroi_align.rroi_align_forward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
28 | features, rois, output, ctx.idx_x, ctx.idx_y)
29 |
30 | return output
31 |
32 | def backward(ctx, grad_output):
33 | assert(ctx.feature_size is not None and grad_output.is_cuda)
34 | batch_size, num_channels, data_height, data_width = ctx.feature_size
35 | grad_input = grad_output.new(batch_size, num_channels, data_height, data_width).zero_().float()
36 |
37 | rroi_align.rroi_align_backward_cuda(ctx.pooled_height, ctx.pooled_width, ctx.spatial_scale,
38 | grad_output, ctx.rois, grad_input, ctx.idx_x, ctx.idx_y)
39 |
40 | return grad_input, None
41 |
--------------------------------------------------------------------------------
/rroi_align/main.py:
--------------------------------------------------------------------------------
1 | import cupy as cp
2 | import numpy as np
3 | import math
4 |
5 | # bottom_data = cp.random.randn(1,3,40,40, dtype=np.float32) # input feature map
6 | bottom_data = np.random.randn(1,1,40,40)
7 | bottom_data = cp.array(bottom_data, dtype=np.float32)
8 | batch, channels, height, width = bottom_data.shape
9 | spatial_scale = 1.0 # ratio between the original image and the feature map
10 | rois = cp.array([[0, 2, 2, 10, 10],
11 | [0, 2, 4, 20, 10]], dtype=np.float32) # rois
12 | pooled_weight = 7 # pooled width
13 | pooled_height = 7 # pooled height
14 |
15 | ## define the kernel: cupy's ElementwiseKernel maps the implicit index i over every element of top_data
16 | roi_pooling_2d_fwd = cp.ElementwiseKernel(
17 | '''
18 | raw T bottom_data, T spatial_scale, int32 channels,
19 | int32 height, int32 width, int32 pooled_height, int32 pooled_width,
20 | raw T bottom_rois
21 | ''',
22 | 'T top_data, int32 argmax_data',
23 | '''
24 | // pos in output filter
25 | int pw = i % pooled_width;
26 | int ph = (i / pooled_width) % pooled_height;
27 | int c = (i / pooled_width / pooled_height) % channels;
28 | int num = i / pooled_width / pooled_height / channels;
29 | int roi_batch_ind = bottom_rois[num * 5 + 0];
30 | int roi_start_w = round(bottom_rois[num * 5 + 1] * spatial_scale); // read the roi fields
31 | int roi_start_h = round(bottom_rois[num * 5 + 2] * spatial_scale);
32 | int roi_end_w = round(bottom_rois[num * 5 + 3] * spatial_scale);
33 | int roi_end_h = round(bottom_rois[num * 5 + 4] * spatial_scale);
34 |
35 | // Force malformed ROIs to be 1x1
36 | // compute the start/end indices of each bin
37 | int roi_width = max(roi_end_w - roi_start_w + 1, 1);
38 | int roi_height = max(roi_end_h - roi_start_h + 1, 1);
39 |
40 | // compute the pooled width
41 | int rois_pooled_width = (int)(ceil((float)(pooled_height * roi_width) / (float)(roi_height) )); // pool proportionally, shrinking the roi
42 | float bin_size_h = static_cast<float>(roi_height) / static_cast<float>(pooled_height); // static_cast: explicit type conversion
43 | float bin_size_w = static_cast<float>(roi_width) / static_cast<float>(rois_pooled_width);
44 |
45 | int hstart = static_cast<int>(floor(static_cast<float>(ph) * bin_size_h));
46 | int wstart = static_cast<int>(floor(static_cast<float>(pw) * bin_size_w));
47 | int hend = static_cast<int>(ceil(static_cast<float>(ph + 1) * bin_size_h));
48 | int wend = static_cast<int>(ceil(static_cast<float>(pw + 1) * bin_size_w));
49 |
50 | // Add roi offsets and clip to input boundaries
51 | // take the max over each bin
52 | hstart = min(max(hstart + roi_start_h, 0), height);
53 | hend = min(max(hend + roi_start_h, 0), height);
54 | wstart = min(max(wstart + roi_start_w, 0), width);
55 | wend = min(max(wend + roi_start_w, 0), width);
56 | bool is_empty = (hend <= hstart) || (wend <= wstart);
57 | // Define an empty pooling region to be zero
58 | float maxval = is_empty ? 0 : -1E+37;
59 | // If nothing is pooled, argmax=-1 causes nothing to be backprop'd
60 |
61 | int maxidx = -1;
62 | int data_offset = (roi_batch_ind * channels + c) * height * width;
63 | for (int h = hstart; h < hend; ++h) {
64 | for (int w = wstart; w < wend; ++w) {
65 | int bottom_index = h * width + w;
66 | if (bottom_data[data_offset + bottom_index] > maxval) {
67 | maxval = bottom_data[data_offset + bottom_index];
68 | maxidx = bottom_index;
69 | }
70 | }
71 | }
72 | top_data = maxval;
73 | argmax_data = maxidx;
74 | ''', 'roi_pooling_2d_fwd'
75 | )
76 | pooled_height = 2
77 | maxratio = (rois[:, 3] - rois[:, 1]) / (rois[:, 4] - rois[:, 2])
78 | maxratio = maxratio.max()
79 | pooled_width = math.ceil(pooled_height * maxratio)
80 |
81 | top_data = cp.zeros((2, 3, pooled_height, pooled_width), dtype=np.float32) # output feature map
82 | argmax_data = cp.zeros(top_data.shape, np.int32) # index of the max value in each bin
83 |
84 | roi_pooling_2d_fwd(bottom_data, spatial_scale, channels, height, width,
85 | pooled_height, pooled_width, rois, top_data, argmax_data)
86 |
87 | print(top_data.shape)
--------------------------------------------------------------------------------
/rroi_align/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # CUDA_PATH=/usr/local/cuda/
4 |
5 | export CUDA_PATH=/usr/local/cuda/
6 | # You may also want to add the following
7 | #export C_INCLUDE_PATH=/opt/cuda/include
8 |
9 | export CXXFLAGS="-std=c++11"
10 | export CFLAGS="-std=c99"
11 |
12 | # python setup.py build_ext --inplace
13 | # rm -rf build
14 |
15 | CUDA_ARCH="-gencode arch=compute_30,code=sm_30 \
16 | -gencode arch=compute_35,code=sm_35 \
17 | -gencode arch=compute_50,code=sm_50 \
18 | -gencode arch=compute_52,code=sm_52 \
19 | -gencode arch=compute_60,code=sm_60 \
20 | -gencode arch=compute_61,code=sm_61 "
21 |
22 |
23 | # compile roi_pooling // compile the CUDA kernel
24 | cd src
25 | echo "Compiling roi pooling kernels by nvcc..."
26 | nvcc -c -o rroi_align.cu.o rroi_align_kernel.cu \
27 | -D GOOGLE_CUDA=1 -x cu -Xcompiler -fPIC $CUDA_ARCH
28 | cd ../
29 | python build.py
--------------------------------------------------------------------------------
/rroi_align/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/modules/__init__.py
--------------------------------------------------------------------------------
/rroi_align/modules/rroi_align.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from ..functions.rroi_align import RRoiAlignFunction
3 |
4 |
5 | class _RRoiAlign(Module):
6 | def __init__(self, pooled_height, pooled_width, spatial_scale):
7 | super(_RRoiAlign, self).__init__()
8 |
9 | self.pooled_width = int(pooled_width)
10 | self.pooled_height = int(pooled_height)
11 | self.spatial_scale = float(spatial_scale)
12 |
13 | def forward(self, features, rois):  # rois: (N, 6) float tensor, rows of [batch_idx, cx, cy, h, w, angle_deg]
14 | return RRoiAlignFunction(self.pooled_height, self.pooled_width, self.spatial_scale)(features, rois)
15 |
--------------------------------------------------------------------------------
/rroi_align/src/roi_pooling.c:
--------------------------------------------------------------------------------
1 | #include <TH/TH.h>
2 | #include <math.h>
3 |
4 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale,
5 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output)
6 | {
7 | // Grab the input tensor
8 | float * data_flat = THFloatTensor_data(features);
9 | float * rois_flat = THFloatTensor_data(rois);
10 |
11 | float * output_flat = THFloatTensor_data(output);
12 |
13 | // Number of ROIs
14 | int num_rois = THFloatTensor_size(rois, 0);
15 | int size_rois = THFloatTensor_size(rois, 1);
16 | // batch size
17 | int batch_size = THFloatTensor_size(features, 0);
18 | if(batch_size != 1)
19 | {
20 | return 0;
21 | }
22 | // data height
23 | int data_height = THFloatTensor_size(features, 1);
24 | // data width
25 | int data_width = THFloatTensor_size(features, 2);
26 | // Number of channels
27 | int num_channels = THFloatTensor_size(features, 3);
28 |
29 | // Set all element of the output tensor to -inf.
30 | THFloatStorage_fill(THFloatTensor_storage(output), -1);
31 |
32 | // For each ROI R = [batch_index x1 y1 x2 y2]: max pool over R
33 | int index_roi = 0;
34 | int index_output = 0;
35 | int n;
36 | for (n = 0; n < num_rois; ++n)
37 | {
38 | int roi_batch_ind = rois_flat[index_roi + 0];
39 | int roi_start_w = round(rois_flat[index_roi + 1] * spatial_scale);
40 | int roi_start_h = round(rois_flat[index_roi + 2] * spatial_scale);
41 | int roi_end_w = round(rois_flat[index_roi + 3] * spatial_scale);
42 | int roi_end_h = round(rois_flat[index_roi + 4] * spatial_scale);
43 | // CHECK_GE(roi_batch_ind, 0);
44 | // CHECK_LT(roi_batch_ind, batch_size);
45 |
46 | int roi_height = fmaxf(roi_end_h - roi_start_h + 1, 1);
47 | int roi_width = fmaxf(roi_end_w - roi_start_w + 1, 1);
48 | float bin_size_h = (float)(roi_height) / (float)(pooled_height);
49 | float bin_size_w = (float)(roi_width) / (float)(pooled_width);
50 |
51 | int index_data = roi_batch_ind * data_height * data_width * num_channels;
52 | const int output_area = pooled_width * pooled_height;
53 |
54 | int c, ph, pw;
55 | for (ph = 0; ph < pooled_height; ++ph)
56 | {
57 | for (pw = 0; pw < pooled_width; ++pw)
58 | {
59 | int hstart = (floor((float)(ph) * bin_size_h));
60 | int wstart = (floor((float)(pw) * bin_size_w));
61 | int hend = (ceil((float)(ph + 1) * bin_size_h));
62 | int wend = (ceil((float)(pw + 1) * bin_size_w));
63 |
64 | hstart = fminf(fmaxf(hstart + roi_start_h, 0), data_height);
65 | hend = fminf(fmaxf(hend + roi_start_h, 0), data_height);
66 | wstart = fminf(fmaxf(wstart + roi_start_w, 0), data_width);
67 | wend = fminf(fmaxf(wend + roi_start_w, 0), data_width);
68 |
69 | const int pool_index = index_output + (ph * pooled_width + pw);
70 | int is_empty = (hend <= hstart) || (wend <= wstart);
71 | if (is_empty)
72 | {
73 | for (c = 0; c < num_channels * output_area; c += output_area)
74 | {
75 | output_flat[pool_index + c] = 0;
76 | }
77 | }
78 | else
79 | {
80 | int h, w, c;
81 | for (h = hstart; h < hend; ++h)
82 | {
83 | for (w = wstart; w < wend; ++w)
84 | {
85 | for (c = 0; c < num_channels; ++c)
86 | {
87 | const int index = (h * data_width + w) * num_channels + c;
88 | if (data_flat[index_data + index] > output_flat[pool_index + c * output_area])
89 | {
90 | output_flat[pool_index + c * output_area] = data_flat[index_data + index];
91 | }
92 | }
93 | }
94 | }
95 | }
96 | }
97 | }
98 |
99 | // Increment ROI index
100 | index_roi += size_rois;
101 | index_output += pooled_height * pooled_width * num_channels;
102 | }
103 | return 1;
104 | }
--------------------------------------------------------------------------------
/rroi_align/src/roi_pooling.h:
--------------------------------------------------------------------------------
1 | int roi_pooling_forward(int pooled_height, int pooled_width, float spatial_scale,
2 | THFloatTensor * features, THFloatTensor * rois, THFloatTensor * output);
--------------------------------------------------------------------------------
/rroi_align/src/rroi_align.cu.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/rroi_align/src/rroi_align.cu.o
--------------------------------------------------------------------------------
/rroi_align/src/rroi_align_cuda.c:
--------------------------------------------------------------------------------
1 | #include <THC/THC.h>
2 | #include <math.h>
3 | #include "rroi_align_kernel.h"
4 |
5 | extern THCState *state;
6 |
7 | int rroi_align_forward_cuda(int pooled_height, int pooled_width, float spatial_scale,
8 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output,
9 | THCudaTensor * idx_x, THCudaTensor * idx_y)
10 | {
11 | // Grab the input tensor
12 | float * data_flat = THCudaTensor_data(state, features);
13 | float * rois_flat = THCudaTensor_data(state, rois);
14 |
15 | float * output_flat = THCudaTensor_data(state, output);
16 | float * idx_x_flat = THCudaTensor_data(state, idx_x); // center index of each rroi bin
17 | float * idx_y_flat = THCudaTensor_data(state, idx_y);
18 | // int * argmax_flat = THCudaIntTensor_data(state, argmax);
19 |
20 | // Number of ROIs
21 | int num_rois = THCudaTensor_size(state, rois, 0);
22 | int size_rois = THCudaTensor_size(state, rois, 1);
23 | if (size_rois != 6)
24 | {
25 | return 0;
26 | }
27 |
28 | // data height
29 | int data_height = THCudaTensor_size(state, features, 2);
30 | // data width
31 | int data_width = THCudaTensor_size(state, features, 3);
32 | // Number of channels
33 | int num_channels = THCudaTensor_size(state, features, 1);
34 |
35 | cudaStream_t stream = THCState_getCurrentStream(state);
36 |
37 | RROIAlignForwardLaucher(
38 | data_flat, spatial_scale, num_rois, data_height,
39 | data_width, num_channels, pooled_height,
40 | pooled_width, rois_flat,
41 | output_flat, idx_x_flat, idx_y_flat, stream);
42 |
43 | return 1;
44 | }
45 |
46 |
47 |
48 | // backward pass
49 | int rroi_align_backward_cuda(int pooled_height, int pooled_width, float spatial_scale,
50 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad,
51 | THCudaTensor * idx_x, THCudaTensor * idx_y)
52 | {
53 | // Grab the input tensor
54 | float * top_grad_flat = THCudaTensor_data(state, top_grad);
55 | float * rois_flat = THCudaTensor_data(state, rois);
56 |
57 | float * bottom_grad_flat = THCudaTensor_data(state, bottom_grad);
58 | float * idx_x_flat = THCudaTensor_data(state, idx_x);
59 | float * idx_y_flat = THCudaTensor_data(state, idx_y);
60 |
61 | // Number of ROIs
62 | int num_rois = THCudaTensor_size(state, rois, 0);
63 | int size_rois = THCudaTensor_size(state, rois, 1);
64 | if (size_rois != 6)
65 | {
66 | return 0;
67 | }
68 |
69 | // batch size
70 | int batch_size = THCudaTensor_size(state, bottom_grad, 0);
71 |
72 | // data height
73 | int data_height = THCudaTensor_size(state, bottom_grad, 2);
74 | // data width
75 | int data_width = THCudaTensor_size(state, bottom_grad, 3);
76 | // Number of channels
77 | int num_channels = THCudaTensor_size(state, bottom_grad, 1);
78 |
79 | cudaStream_t stream = THCState_getCurrentStream(state);
80 | RROIAlignBackwardLaucher(
81 | top_grad_flat, spatial_scale, batch_size, num_rois, data_height,
82 | data_width, num_channels, pooled_height,
83 | pooled_width, rois_flat, bottom_grad_flat,
84 | idx_x_flat, idx_y_flat, stream);
85 |
86 | return 1;
87 | }
88 |
--------------------------------------------------------------------------------
/rroi_align/src/rroi_align_cuda.h:
--------------------------------------------------------------------------------
1 |
2 | int rroi_align_forward_cuda(int pooled_height, int pooled_width, float spatial_scale,
3 | THCudaTensor * features, THCudaTensor * rois, THCudaTensor * output,
4 | THCudaTensor * idx_x, THCudaTensor * idx_y);
5 |
6 | int rroi_align_backward_cuda(int pooled_height, int pooled_width, float spatial_scale,
7 | THCudaTensor * top_grad, THCudaTensor * rois, THCudaTensor * bottom_grad,
8 | THCudaTensor * idx_x, THCudaTensor * idx_y);
--------------------------------------------------------------------------------
/rroi_align/src/rroi_align_kernel.h:
--------------------------------------------------------------------------------
1 | #ifndef _ROI_POOLING_KERNEL
2 | #define _ROI_POOLING_KERNEL
3 |
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 |
8 | int RROIAlignForwardLaucher(
9 | const float* bottom_data, const float spatial_scale, const int num_rois, const int height,
10 | const int width, const int channels, const int pooled_height,
11 | const int pooled_width, const float* bottom_rois,
12 | float* top_data, float* con_idx_x, float* con_idx_y, cudaStream_t stream);
13 |
14 | int RROIAlignBackwardLaucher(
15 | const float* top_diff, const float spatial_scale, const int batch_size, const int num_rois,
16 | const int height, const int width, const int channels, const int pooled_height,
17 | const int pooled_width, const float* bottom_rois, float* bottom_diff,
18 | const float* con_idx_x, const float* con_idx_y, cudaStream_t stream);
19 |
20 | #ifdef __cplusplus
21 | }
22 | #endif
23 |
24 | #endif
25 |
26 |
--------------------------------------------------------------------------------
/rroi_align/test.py:
--------------------------------------------------------------------------------
1 | from modules.roi_pool import _RoIPooling  # NOTE: modules/roi_pool is not present in this repo; only used by the commented-out check at the bottom
2 | import torch
3 | import cv2
4 | import numpy as np
5 | import math
6 | import random
7 | from math import sin, cos, floor, ceil
8 |
9 |
10 | if __name__=='__main__':
11 | roipool = _RoIPooling(44, 328, 1.0) # construct the pooling module
12 |
13 | path = './data/timg.jpeg'
14 | im_data = cv2.imread(path)
15 | img = im_data.copy()
16 | im_data = torch.from_numpy(im_data).unsqueeze(0).permute(0,3,1,2)
17 | im_data = im_data
18 | im_data = im_data.to(torch.float)
19 |
20 | # parameter settings
21 | norm_height = 44
22 | debug = True
23 | # corner coordinates of text regions on a Chinese ID card
24 | # gt = np.asarray([[200,218],[198,207],[232,201],[238,210]])
25 | gt = np.asarray([[205,150],[202,126],[365,93],[372,111]])
26 |
27 | center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4 # center point
28 |
29 | dw = gt[2, :] - gt[1, :]
30 | dh = gt[1, :] - gt[0, :]
31 | w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1]) # width and height
32 | h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + random.randint(-2, 2)
33 |
34 | angle_gt = ( math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0]) ) / 2
35 | angle_gt = angle_gt / 3.1415926535 * 180
36 |
37 | rois = torch.tensor([0, center[0], center[1], h, w, angle_gt])
38 | rois = rois.to(torch.float)
39 |
40 | # rroi_align takes the roi center, w, h and the angle theta (in degrees)
41 | # parameter settings
42 | pooled_width = 328
43 | pooled_height = 44
44 | channels = 3
45 | spatial_scale = 1.0
46 | index = pooled_height * pooled_width * channels
47 | imageHeight, imageWidth, channel = img.shape
48 | height, width = imageHeight, imageWidth
49 | output = torch.zeros(index)
50 | for i in range(index):
51 | n = i;
52 | pw = n % pooled_width;
53 | n //= pooled_width;   # integer division; plain /= would produce floats in Python 3
54 | ph = n % pooled_height;
55 | n //= pooled_height;
56 | c = n % channels;
57 | n //= channels;
58 |
59 | offset_bottom_rois = rois
60 | roi_batch_ind = offset_bottom_rois[0];
61 | cx = offset_bottom_rois[1];
62 | cy = offset_bottom_rois[2];
63 | h = offset_bottom_rois[3];
64 | w = offset_bottom_rois[4];
65 | angle = - offset_bottom_rois[5]/180.0*3.1415926535;
66 |
67 | # //TransformPrepare
68 | dx = -pooled_width/2.0;
69 | dy = -pooled_height/2.0;
70 | Sx = w*spatial_scale/pooled_width;
71 | Sy = h*spatial_scale/pooled_height;
72 | Alpha = cos(angle);
73 | Beta = sin(angle);
74 | Dx = cx*spatial_scale;
75 | Dy = cy*spatial_scale;
76 |
77 | M =[[0 for col in range(3)] for row in range(2)]
78 | M[0][0] = Alpha*Sx;
79 | M[0][1] = Beta*Sy;
80 | M[0][2] = Alpha*Sx*dx+Beta*Sy*dy+Dx;
81 | M[1][0] = -Beta*Sx;
82 | M[1][1] = Alpha*Sy;
83 | M[1][2] = -Beta*Sx*dx+Alpha*Sy*dy+Dy;
84 |
85 | # float P[8];
86 | P =[0 for col in range(8)]
87 | P[0] = M[0][0]*pw+M[0][1]*ph+M[0][2];
88 | P[1] = M[1][0]*pw+M[1][1]*ph+M[1][2];
89 | P[2] = M[0][0]*pw+M[0][1]*(ph+1)+M[0][2];
90 | P[3] = M[1][0]*pw+M[1][1]*(ph+1)+M[1][2];
91 | P[4] = M[0][0]*(pw+1)+M[0][1]*ph+M[0][2];
92 | P[5] = M[1][0]*(pw+1)+M[1][1]*ph+M[1][2];
93 | P[6] = M[0][0]*(pw+1)+M[0][1]*(ph+1)+M[0][2];
94 | P[7] = M[1][0]*(pw+1)+M[1][1]*(ph+1)+M[1][2];
95 |
96 | leftMost = (max(torch.round(min(min(P[0],P[2]),min(P[4],P[6]))),0.0));
97 | rightMost= (min(torch.round(max(max(P[0],P[2]),max(P[4],P[6]))),imageWidth-1.0));
98 | topMost= (max(torch.round(min(min(P[1],P[3]),min(P[5],P[7]))),0.0));
99 | bottomMost= (min(torch.round(max(max(P[1],P[3]),max(P[5],P[7]))),imageHeight-1.0));
100 |
101 | # //float maxval = 0;
102 | # //int maxidx = -1;
103 | # offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width;
104 | offset_bottom_data = im_data.view(-1)
105 |
106 | bin_cx = (leftMost + rightMost) / 2.0; # // shift
107 | bin_cy = (topMost + bottomMost) / 2.0;
108 |
109 | bin_l = int(floor(bin_cx));
110 | bin_r = int(ceil(bin_cx));
111 | bin_t = int(floor(bin_cy));
112 | bin_b = int(ceil(bin_cy));
113 |
114 | lt_value = 0.0;
115 | if (bin_t > 0 and bin_l > 0 and bin_t < height and bin_l < width):
116 | lt_value = offset_bottom_data[bin_t * width + bin_l];
117 | rt_value = 0.0;
118 | if (bin_t > 0 and bin_r > 0 and bin_t < height and bin_r < width):
119 | rt_value = offset_bottom_data[bin_t * width + bin_r];
120 | lb_value = 0.0;
121 | if (bin_b > 0 and bin_l > 0 and bin_b < height and bin_l < width):
122 | lb_value = offset_bottom_data[bin_b * width + bin_l];
123 | rb_value = 0.0;
124 | if (bin_b > 0 and bin_r > 0 and bin_b < height and bin_r < width):
125 | rb_value = offset_bottom_data[bin_b * width + bin_r];
126 |
127 | rx = bin_cx - floor(bin_cx);
128 | ry = bin_cy - floor(bin_cy);
129 |
130 | wlt = (1.0 - rx) * (1.0 - ry);
131 | wrt = rx * (1.0 - ry);
132 | wrb = rx * ry;
133 | wlb = (1.0 - rx) * ry;
134 |
135 | inter_val = 0.0;
136 |
137 | inter_val += lt_value * wlt;
138 | inter_val += rt_value * wrt;
139 | inter_val += rb_value * wrb;
140 | inter_val += lb_value * wlb;
141 |
142 | output[i] = inter_val
143 |
144 | res = output.view(channels, pooled_height, pooled_width)
145 |
146 | if debug:
147 | x_d = res.data.cpu().numpy()
148 | x_data_draw = x_d.swapaxes(0, 2)
149 | x_data_draw = x_data_draw.swapaxes(0, 1)
150 |
151 | x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
152 | # x_data_draw = x_data_draw[:, :, ::-1]
153 | cv2.imshow('im_data_gt', x_data_draw)
154 | cv2.imwrite('res.jpg', x_data_draw) # the result looks quite good
155 | cv2.imshow('src_img', img)
156 | cv2.waitKey(100)
157 | temp = 1
158 |
159 |
160 | # pooled_feat = roipool(im_data, rois.view(-1, 6))
161 |
162 | # if debug:
163 | # x_d = pooled_feat.data.cpu().numpy()[0]
164 | # x_data_draw = x_d.swapaxes(0, 2)
165 | # x_data_draw = x_data_draw.swapaxes(0, 1)
166 |
167 | # x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
168 | # # x_data_draw = x_data_draw[:, :, ::-1]
169 | # cv2.imshow('im_data_gt', x_data_draw)
170 | # cv2.imshow('src_img', img)
171 | # cv2.waitKey(100)
172 |
173 | # print(pooled_feat.shape)
--------------------------------------------------------------------------------
/rroi_align/test2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Use rroi_align to rectify and crop rotated text regions
3 | date: 2019-6-24
4 | author: yibao2hao
5 | Notes:
6 | 1. im_data and rois must both be CUDA tensors
7 | 2. each roi is [index, x, y, h, w, theta]
8 | 3. added support for batched operation
9 | 4.
10 | '''
11 | from modules.rroi_align import _RRoiAlign
12 | import torch
13 | import cv2
14 | import numpy as np
15 | import math
16 | import random
17 | from math import sin, cos, floor, ceil
18 | import matplotlib.pyplot as plt
19 | from torch.autograd import Variable
20 |
21 |
22 | if __name__=='__main__':
23 |
24 | path = './rroi_align/data/timg.jpeg'
25 | # path = './data/grad.jpg'
26 | im_data = cv2.imread(path)
27 | img = im_data.copy()
28 | im_data = torch.from_numpy(im_data).unsqueeze(0).permute(0,3,1,2)
29 | im_data = im_data
30 | im_data = im_data.to(torch.float).cuda()
31 | im_data = Variable(im_data, requires_grad=True)
32 |
33 | # plt.imshow(img)
34 | # plt.show()
35 |
36 | # parameter settings
37 | debug = True
38 | # corner coordinates of text regions on a Chinese ID card
39 | gt3 = np.asarray([[200,218],[198,207],[232,201],[238,210]]) # "issuing authority"
40 | gt1 = np.asarray([[205,150],[202,126],[365,93],[372,111]]) # "Resident Identity Card"
41 | # # gt2 = np.asarray([[205,150],[202,126],[365,93],[372,111]]) # "Resident Identity Card"
42 | gt2 = np.asarray([[206,111],[199,95],[349,60],[355,80]]) # "People's Republic of China"
43 | gt4 = np.asarray([[312,127],[304,105],[367,88],[374,114]]) # "Identity Card" (partial)
44 | gt5 = np.asarray([[133,168],[118,112],[175,100],[185,154]]) # national emblem
45 | # gts = [gt1, gt2, gt3, gt4, gt5]
46 | gts = [gt2, gt4, gt5]
47 |
48 |
49 | roi = []
50 | for i,gt in enumerate(gts):
51 | center = (gt[0, :] + gt[1, :] + gt[2, :] + gt[3, :]) / 4 # center point
52 |
53 | dw = gt[2, :] - gt[1, :]
54 | dh = gt[1, :] - gt[0, :]
55 | w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1]) # 宽度和高度
56 | h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1]) + random.randint(-2, 2)
57 |
58 | angle_gt = ( math.atan2((gt[2][1] - gt[1][1]), gt[2][0] - gt[1][0]) + math.atan2((gt[3][1] - gt[0][1]), gt[3][0] - gt[0][0]) ) / 2
59 | angle_gt = -angle_gt / 3.1415926535 * 180 # 需要加个负号
60 |
61 | roi.append([0, center[0], center[1], h, w, angle_gt]) # roi的参数
62 |
63 | rois = torch.tensor(roi)
64 | rois = rois.to(torch.float).cuda()
65 |
66 | pooled_height = 44
67 | maxratio = rois[:,4] / rois[:,3]
68 | maxratio = maxratio.max().item()
69 | pooled_width = math.ceil(pooled_height * maxratio)
70 |
71 | roipool = _RRoiAlign(pooled_height, pooled_width, 1.0)
72 |     # run the rroi_align operation
73 | pooled_feat = roipool(im_data, rois.view(-1, 6))
74 |
75 | res = pooled_feat.pow(2).sum()
76 | # res = pooled_feat.sum()
77 | res.backward()
78 |
79 | if debug:
80 | for i in range(pooled_feat.shape[0]):
81 | x_d = pooled_feat.data.cpu().numpy()[i]
82 | x_data_draw = x_d.swapaxes(0, 2)
83 | x_data_draw = x_data_draw.swapaxes(0, 1)
84 |
85 | x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
86 | cv2.imshow('im_data_gt %d' % i, x_data_draw)
87 | cv2.imwrite('./rroi_align/data/res%d.jpg' % i, x_data_draw)
88 |
89 | cv2.imshow('img', img)
90 |
91 |         # visualize the gradient
92 | im_grad = im_data.grad.data.cpu().numpy()[0]
93 | im_grad = im_grad.swapaxes(0, 2)
94 | im_grad = im_grad.swapaxes(0, 1)
95 |
96 | im_grad = np.asarray(im_grad, dtype=np.uint8)
97 | cv2.imshow('grad', im_grad)
98 | cv2.imwrite('./rroi_align/data/grad.jpg',im_grad)
99 |
100 | #
101 | grad_img = img + im_grad
102 | cv2.imwrite('./rroi_align/data/grad_img.jpg', grad_img)
103 | cv2.waitKey(100)
104 | print(pooled_feat.shape)
105 |
106 |
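107 | # A minimal single-box sketch of the roi layout consumed above (the values are
108 | # assumed for illustration, not taken from the script): an axis-aligned
109 | # 100x40 box centred at (160, 120) in image 0 of the batch has angle 0.
110 | #
111 | #   roi = [0, 160, 120, 40, 100, 0.0]               # [batch_index, cx, cy, h, w, angle_deg]
112 | #   rois = torch.tensor([roi]).to(torch.float).cuda()
113 | #   pooled = _RRoiAlign(8, 20, 1.0)(im_data, rois)  # -> (1, 3, 8, 20)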
--------------------------------------------------------------------------------
/sample_train_data/MLT/done/gt_img_5407.txt:
--------------------------------------------------------------------------------
1 | 910.992431640625,1273.0765380859375,1039.992431640625,1276.0765380859375,1024.507568359375,1941.9234619140625,895.507568359375,1938.9234619140625,1, 当社関係以外の
2 | 748.3333333333333,1266.4999999999998,875.3333333333333,1280.4999999999998,881.3333333333333,2052.5,764.3333333333333,2052.5,1, 駐車を御遠慮下さい
3 | 622.8502197265625,1639.7156982421875,626.1497802734375,1567.7843017578125,735.1497802734375,1572.7843017578125,731.8502197265625,1644.7156982421875,1, (株)
4 | 636.3333333333333,1680.4999999999998,717.3333333333333,1677.4999999999998,731.3333333333333,2005.4999999999998,650.3333333333333,2005.4999999999998,1, 新六商店
5 |
--------------------------------------------------------------------------------
/sample_train_data/MLT/done/img_5407.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/sample_train_data/MLT/done/img_5407.jpg
--------------------------------------------------------------------------------
/sample_train_data/MLT/icdar-2015-Ch4/Train/gt_img_784.txt:
--------------------------------------------------------------------------------
1 | 462,113,526,107,528,128,465,134,ARMANI
2 | 523,108,611,101,612,118,525,126,EXCHANGE
3 | 467,132,523,125,523,143,467,150,G-STAR
4 | 532,125,590,118,591,136,533,143,CALVIN
5 | 590,120,636,116,637,134,591,137,KLEIN
6 | 631,115,687,111,690,128,635,133,JEANS
7 | 467,151,521,145,522,160,468,167,BREAD
8 | 522,145,535,144,537,160,523,161,###
9 | 535,141,598,137,599,155,537,159,BUTTER
10 | 608,136,649,133,650,150,609,153,TRUE
11 | 649,132,723,125,724,143,650,150,###
12 | 391,158,417,155,418,174,392,176,B1
13 | 1,158,85,151,85,203,1,210,###
14 | 176,288,269,285,270,303,177,305,###
15 | 1041,181,1132,173,1140,243,1049,251,SALE
16 |
--------------------------------------------------------------------------------
/sample_train_data/MLT/icdar-2015-Ch4/Train/img_784.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/sample_train_data/MLT/icdar-2015-Ch4/Train/img_784.jpg
--------------------------------------------------------------------------------
/sample_train_data/MLT/trainMLT.txt:
--------------------------------------------------------------------------------
1 | done/img_5407.jpg
2 | icdar-2015-Ch4/Train/img_784.jpg
3 |
--------------------------------------------------------------------------------
/sample_train_data/MLT_CROPS/gt.txt:
--------------------------------------------------------------------------------
1 | word_118.png, "Ngee"
2 | word_119.png, "Ann"
3 | word_120.png, "City"
4 | word_121.png, "ION"
5 |
--------------------------------------------------------------------------------
/sample_train_data/MLT_CROPS/word_118.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/sample_train_data/MLT_CROPS/word_118.png
--------------------------------------------------------------------------------
/sample_train_data/MLT_CROPS/word_119.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/sample_train_data/MLT_CROPS/word_119.png
--------------------------------------------------------------------------------
/sample_train_data/MLT_CROPS/word_120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/sample_train_data/MLT_CROPS/word_120.png
--------------------------------------------------------------------------------
/sample_train_data/MLT_CROPS/word_121.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/sample_train_data/MLT_CROPS/word_121.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/src/__init__.py
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on 2019-06-30
3 |
4 | @author: chenjun2hao
5 | '''
6 |
7 | import os
8 | import cv2
9 | import glob
10 | import numpy as np
11 | import torch
12 | import argparse
13 | from PIL import Image
14 | from PIL import ImageFont
15 | from PIL import ImageDraw
16 |
17 | from nms import get_boxes
18 |
19 | from tools.models import ModelResNetSep2, OwnModel
20 | import tools.net_utils as net_utils
21 | from src.utils import strLabelConverter, alphabet
22 | from tools.ocr_utils import ocr_image, align_ocr
23 | from tools.data_gen import draw_box_points
24 |
25 | def resize_image(im, max_size = 1585152, scale_up=True):
26 |
27 | if scale_up:
28 | image_size = [im.shape[1] * 3 // 32 * 32, im.shape[0] * 3 // 32 * 32]
29 | else:
30 | image_size = [im.shape[1] // 32 * 32, im.shape[0] // 32 * 32]
31 | while image_size[0] * image_size[1] > max_size:
32 | image_size[0] /= 1.2
33 | image_size[1] /= 1.2
34 | image_size[0] = int(image_size[0] // 32) * 32
35 | image_size[1] = int(image_size[1] // 32) * 32
36 |
37 | resize_h = int(image_size[1])
38 | resize_w = int(image_size[0])
39 |
40 | scaled = cv2.resize(im, dsize=(resize_w, resize_h))
41 | return scaled, (resize_h, resize_w)
42 |
43 |
44 | if __name__ == '__main__':
45 |
46 | parser = argparse.ArgumentParser()
47 | parser.add_argument('-cuda', type=int, default=1)
48 | parser.add_argument('-model', default='./weights/FOTS_280000.h5')
49 | # parser.add_argument('-model', default='./weights/e2e-mlt.h5')
50 | parser.add_argument('-segm_thresh', default=0.5)
51 | parser.add_argument('-test_folder', default=r'./data/example_image/')
52 | parser.add_argument('-output', default='./data/ICDAR2015')
53 |
54 | font2 = ImageFont.truetype("./tools/Arial-Unicode-Regular.ttf", 18)
55 |
56 | args = parser.parse_args()
57 |
58 | # net = ModelResNetSep2(attention=True, nclass=len(alphabet)+1)
59 | net = ModelResNetSep2(attention=True, nclass=len(alphabet)+1)
60 | net_utils.load_net(args.model, net)
61 | net = net.eval()
62 |
63 | converter = strLabelConverter(alphabet)
64 |
65 | if args.cuda:
66 | print('Using cuda ...')
67 | net = net.cuda()
68 |
69 | test_path = os.path.realpath(args.test_folder)
70 | test_path = test_path + '/*.jpg'
71 | imagelist = glob.glob(test_path)
72 |
73 |
74 | with torch.no_grad():
75 | for path in imagelist:
76 |
77 | im = cv2.imread(path)
78 |
79 | im_resized, (ratio_h, ratio_w) = resize_image(im, scale_up=False)
80 | images = np.asarray([im_resized], dtype=np.float)
81 | images /= 128
82 | images -= 1
83 | im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=args.cuda)
84 | seg_pred, rboxs, angle_pred, features = net(im_data)
85 |
86 |             rbox = rboxs[0].data.cpu()[0].numpy()          # convert to (h, w, c)
87 | rbox = rbox.swapaxes(0, 1)
88 | rbox = rbox.swapaxes(1, 2)
89 |
90 | angle_pred = angle_pred[0].data.cpu()[0].numpy()
91 |
92 | segm = seg_pred[0].data.cpu()[0].numpy()
93 | segm = segm.squeeze(0)
94 |
95 | draw2 = np.copy(im_resized)
96 | boxes = get_boxes(segm, rbox, angle_pred, args.segm_thresh)
97 |
98 | img = Image.fromarray(draw2)
99 | draw = ImageDraw.Draw(img)
100 |
101 | out_boxes = []
102 | for box in boxes:
103 |
104 | pts = box[0:8]
105 | pts = pts.reshape(4, -1)
106 |
107 | # det_text, conf, dec_s = ocr_image(net, codec, im_data, box)
108 | det_text, conf, dec_s = align_ocr(net, converter, im_data, box, features, debug=0)
109 | if len(det_text) == 0:
110 | continue
111 |
112 | width, height = draw.textsize(det_text, font=font2)
113 | center = [box[0], box[1]]
114 | draw.text((center[0], center[1]), det_text, fill = (0,255,0),font=font2)
115 | out_boxes.append(box)
116 | print(det_text)
117 |
118 | im = np.array(img)
119 | for box in out_boxes:
120 | pts = box[0:8]
121 | pts = pts.reshape(4, -1)
122 | draw_box_points(im, pts, color=(0, 255, 0), thickness=1)
123 |
124 | cv2.imshow('img', im)
125 | basename = os.path.basename(path)
126 | cv2.imwrite(os.path.join(args.output, basename), im)
127 | cv2.waitKey(1000)
128 |
129 |
130 |
131 |
132 |
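133 | # A worked example for resize_image with scale_up=False (the input size is
134 | # assumed, not from the script): a 1280x720 frame gives 1280*720 < 1585152,
135 | # so it is only snapped down to multiples of 32 -> dsize (1280, 704).
136 | # Oversized frames are shrunk by 1.2x steps until width*height <= max_size,
137 | # then snapped to multiples of 32.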
--------------------------------------------------------------------------------
/tools/Arial-Unicode-Regular.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/tools/Arial-Unicode-Regular.ttf
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/tools/__init__.py
--------------------------------------------------------------------------------
/tools/align_demo.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on 2019-06-30
3 |
4 | @author: chenjun2hao
5 | '''
6 |
7 | import os
8 | import cv2
9 | import glob
10 | import numpy as np
11 |
12 | from nms import get_boxes
13 |
14 | from models import ModelResNetSep2, OwnModel
15 | import net_utils
16 |
17 | from src.utils import strLabelConverter, alphabet
18 | from ocr_utils import ocr_image, align_ocr
19 | from data_gen import draw_box_points
20 | import torch
21 |
22 | import argparse
23 |
24 | from PIL import Image
25 | from PIL import ImageFont
26 | from PIL import ImageDraw
27 |
28 | def resize_image(im, max_size = 1585152, scale_up=True):
29 |
30 | if scale_up:
31 | image_size = [im.shape[1] * 3 // 32 * 32, im.shape[0] * 3 // 32 * 32]
32 | else:
33 | image_size = [im.shape[1] // 32 * 32, im.shape[0] // 32 * 32]
34 | while image_size[0] * image_size[1] > max_size:
35 | image_size[0] /= 1.2
36 | image_size[1] /= 1.2
37 | image_size[0] = int(image_size[0] // 32) * 32
38 | image_size[1] = int(image_size[1] // 32) * 32
39 |
40 | resize_h = int(image_size[1])
41 | resize_w = int(image_size[0])
42 |
43 | scaled = cv2.resize(im, dsize=(resize_w, resize_h))
44 | return scaled, (resize_h, resize_w)
45 |
46 |
47 | if __name__ == '__main__':
48 |
49 | parser = argparse.ArgumentParser()
50 | parser.add_argument('-cuda', type=int, default=1)
51 | parser.add_argument('-model', default='./backup/E2E-MLT_280000.h5')
52 | # parser.add_argument('-model', default='./weights/e2e-mlt.h5')
53 | parser.add_argument('-segm_thresh', default=0.5)
54 | parser.add_argument('-test_folder', default=r'/home/yangna/deepblue/OCR/data/ICDAR2015/ch4_test_images/*.jpg')
55 | parser.add_argument('-output', default='./data/ICDAR2015')
56 |
57 | font2 = ImageFont.truetype("Arial-Unicode-Regular.ttf", 18)
58 |
59 | args = parser.parse_args()
60 |
61 | # net = ModelResNetSep2(attention=True, nclass=len(alphabet)+1)
62 | net = ModelResNetSep2(attention=True, nclass=len(alphabet)+1)
63 | net_utils.load_net(args.model, net)
64 | net = net.eval()
65 |
66 | converter = strLabelConverter(alphabet)
67 |
68 | if args.cuda:
69 | print('Using cuda ...')
70 | net = net.cuda()
71 |
72 | imagelist = glob.glob(args.test_folder)
73 | with torch.no_grad():
74 | for path in imagelist:
75 | # path = '/home/yangna/deepblue/OCR/data/ICDAR2015/ch4_test_images/img_405.jpg'
76 | im = cv2.imread(path)
77 |
78 | im_resized, (ratio_h, ratio_w) = resize_image(im, scale_up=False)
79 | images = np.asarray([im_resized], dtype=np.float)
80 | images /= 128
81 | images -= 1
82 | im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=args.cuda)
83 | seg_pred, rboxs, angle_pred, features = net(im_data)
84 |
85 |             rbox = rboxs[0].data.cpu()[0].numpy()  # convert to (h, w, c)
86 | rbox = rbox.swapaxes(0, 1)
87 | rbox = rbox.swapaxes(1, 2)
88 |
89 | angle_pred = angle_pred[0].data.cpu()[0].numpy()
90 |
91 | segm = seg_pred[0].data.cpu()[0].numpy()
92 | segm = segm.squeeze(0)
93 |
94 | draw2 = np.copy(im_resized)
95 | boxes = get_boxes(segm, rbox, angle_pred, args.segm_thresh)
96 |
97 | img = Image.fromarray(draw2)
98 | draw = ImageDraw.Draw(img)
99 |
100 | out_boxes = []
101 | for box in boxes:
102 |
103 | pts = box[0:8]
104 | pts = pts.reshape(4, -1)
105 |
106 | # det_text, conf, dec_s = ocr_image(net, codec, im_data, box)
107 | det_text, conf, dec_s = align_ocr(net, converter, im_data, box, features, debug=0)
108 | if len(det_text) == 0:
109 | continue
110 |
111 | width, height = draw.textsize(det_text, font=font2)
112 | center = [box[0], box[1]]
113 | draw.text((center[0], center[1]), det_text, fill = (0,255,0),font=font2)
114 | out_boxes.append(box)
115 | print(det_text)
116 |
117 | im = np.array(img)
118 | for box in out_boxes:
119 | pts = box[0:8]
120 | pts = pts.reshape(4, -1)
121 | draw_box_points(im, pts, color=(0, 255, 0), thickness=1)
122 |
123 | cv2.imshow('img', im)
124 | basename = os.path.basename(path)
125 | cv2.imwrite(os.path.join(args.output, basename), im)
126 | cv2.waitKey(1000)
127 |
128 |
129 |
130 |
131 |
--------------------------------------------------------------------------------
/tools/codec.txt:
--------------------------------------------------------------------------------
1 | !"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~¡¢£¤¥¦§¨©ª«¬®¯°±²³´µ¶·¸¹º»¼½¾¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖרÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõö÷øùúûüýþÿĀāăąĆćČčĎĐđēĖęĚěğħīıŁłńňŌōŏőŒœřŚśŞşŠšťūŷŸźżŽžƒǔǘǧșɯʒʼʾʿˆˇˉ˘˚˜˝̀́̃̈;ΈΏΑΔΛΟΣΩάέαβδεηθικλμνοπρςστφωόϟЃЄАБВГДЕЗИЙКЛМНОПРСТУФХЦЧШЫЬЭЮЯабвгдежзийклмнопрстуфхчшыьэяёєїֳאגהוטיכלנרשת،؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْ٠١٢٣٤٥٦٧٨٩٫٬ڤڥڧڨڭࠍࠥࠦएकदनभमलशसािीुे्।॥ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰ৷৺ఇᄂᄃᄉᄊᄋᄏᄑ하ᅥᅧᅳᅵḤḥṃṇṛṠṣẒễệừὋῖ‐‑‒–—―‖‘’‚“”„‟†‡•‥…‧‰′″‹›※⁄₂₣₤₩€℃ℓ™⅛←→↔⇐⇒⇔∂∆∇∑−√∣∫∼≈≤≥〈①②③─╚╢╩▁■□▪▲△▶►▼▽◆◇◊○●◙★☺☻♠♣♥♦♬✤❤➔➜ 、。々〇〈〉《》「」『』【】〓〔〕ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちぢっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろわをんァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヱヲンヴヵヶ・ーㆍ㈜㎞㎡㡢㨲㩠㩡㩥㩬一丁七万丈三上下不与丑专且世丘丙业丛东丝両丢两严並丧个中丰串临丸丹为主丼丽举乃久么义之乌乍乎乏乐乔乗乘乙九乞也习乡书买乱乳乾亀了予争事二亍于亏亐云互五井亚些亜亡交亥亦产亨享京亭亮亲亵人亿什仁仅仆仇今介仍从仏仑仓仔仕他仗付仙代令以仪们仮仰仲件价任份仿企伊伍伎伏伐休众优伙会伝伞伟传伤伦伪伯估伴伶伸伺似伽但位低住佐佑体何余佛作你佣佯佰佳併佶佼使侄例侍侖供依侠価侣侥侦侧侨侯侵侶便係促俄俊俏俐俗俘保俞信俩俭修俯俱俳俸俺倉個倍倒候倚借倡倣値倦倫倶债值倾假偉偏做停健側偵偶偷偽偿傅傍傘備储催傭傲債傷傻傾僅働像僑僕僚僧僵價儀億儒償優儲儿允元兄充兆先光克免児兑兒兔党兜入內全兩兪八公六兰共关兴兵其具典兹养兼冀内円冈冊册再冒冕冗写军农冠冢冤冥冬冯冰冲决况冷冻冽净凄准凉凌凍减凑凖凝凞几凡凤処凭凯凰凶凸凹出击函凿刀刁分切刊刑划列刘则刚创初删判別利刪别刮到制刷券刹刺刻剁剂剃則削前剑剔剖剛剣剤剥剧剩剪副剰割創劃劇劉劑力劝办功加务劣动助努劫励劲劳労劵効劾势勃勇勉勋勒動勘務勝募勢勤勧勲勳勾勿匀匂包匈化北匙匠匡匣匪匹区医匾匿區十千升午半华协卑卒卓協单卖南単博卜卞占卡卢卧卫卯印危即却卵卷卸卿厂厄厅历厉压厌厕厘厚原厢厥厦厨厳去县参參又叉及友双反収发取受变叙叛叠口古句另叩只叫召叭可台叱史右叶号司叹吁吃各合吉吊同名后吏吐向吓吕吗君吞否吧吨含听启吳吴吵吸吹吻呀呂呆呈呉告呐呑呕员呜呢周味呼命咀和咎咐咒咖咤咨咫咬咲咳咸咽哀品哇哈哉响哎哑哗員哥哨哪哭哲哺哽唆唇唉唐唤售唯唱唸唾啄商啊問啓啟啡啤啥啦啼喀喂善喆喇喉喊喘喚喜喝喧喪喫單喰営喷喻嗓嗜嗣嗤嗯嗽嘆嘉嘘嘛嘱嘲嘴嘿噌噗噛噢噤器噩噪噴嚴嚼囊囚四回因团団园困囲図围固国图圃圆圈國圍圏園圓圖團土圣圧在圭地圳场圾址坂均坊坎坏坐坑块坚坛坝坞坠坡坤坦坪坰垂垃垄型垒垚垢垣垫埃埈埋城埙域埠執培基堂堅堆堡堤堪報場堵塊塌塑塔塗塘塚塞塩填塾境墅墓増墙墜增墟墨墳壁壇壊壌壎壞壤士壬壮声壱売壳壹壽处备変复夏夕外多夜够夢大天太夫夭央失头夷夸夹夺奇奈奉奋奎奏契奔奖套奠奥奧奨奪奬奭奮女奴奶奸她好如妃妄妆妇妈妊妍妒妙妝妥妨妮妹妻姃姆姉姊始姐姑姓委姙姚姜姦姨姫姬姸姻姿威娅娇娘娜娠娥娩娯娱娶婆婉婊婚婦婴婷婿媒媛媲媳嫁嫉嫌嬉嬌嬪子孔孕字存孙孜孝孟季孤学孩孫孵學宁它宅宇守安宋完宏宗官宙定宛宜宝实実宠审客宣室宪宫宮宰害宴宵家容宽宾宿寂寄寅密富寒寓寛寝察寡寢寥實寧寨審寫寬寮寵寶寸对寺寻导対寿封専射将將專尊尋對導小少尔尖尘尙尚尝尤尬就尴尸尹尺尻尼尽尾尿局屁层居屈届屋屎屏屑展属屠層履屯山屿岁岌岐岔岖岗岘岛岡岩岫岭岳岷岸峙峠峡峨峭峯峰峴島峻崇崎崔崖崗崛崩崭嵌嵩嶺嶽巌巖川州巡巢巣工左巧巨巩巫差己已巴巷巻巾币市布帅帆师希帐帕帖帘帚帝带師席帮帯帰帳帷常帽幅幇幌幔幕幡幢幣干平年并幸幹幻幼幽幾广庁広庄庆庇床序库应底店庙庚府庞废度座庫庭庵庶康庸廃廉廊廓廖廛廟延廷建廻开弁异弃弄弊式弐弓弔引弗弘弛弟张弥弦弧弯弱張強弹强弾归当录形彦彩彫彬彭彰影彷役彻彼彿往征径待很徊律後徐徑徒従得徘御復循微徳徴徵德徹心必忆忌忍志忘忙応忠忧快念忽怀态怅怎怒怕怖怜思怠怡急怦性怨怪怯总恃恋恍恐恒恕恙恢恣恤恥恨恩恭息恰恳恵恶恻恼悉悔悖悚悟悠患悦您悩悪悬悲悳悽情惆惇惊惋惑惕惗惚惜惟惠惡惣惧惨惩惫惯惱想惶惹惺愁愈愉意愚愛感愤愧愿慈態慌慎慕慢慣慧慨慮慰慶慷憂憎憤憧憨憩憫憬憲憶憺憾懂懇懈應懐懒懦懲懸戈戏成我戒或战戚戦截戰戴戶户戸戻房所扁扇扉手才扎扑打扔払托扣执扩扪扫扬扭扮扰扱扳扶批找承技抉把抑抒抓投抖抗折抚抛抜択抢护报披抱抵抹押抽抿担拆拉拌拍拎拒拓拔拖拘拙招拜拝拟拠拡拢拥拦拨择括拭拯拳拶拷拼拾拿持挂指按挑挖挙挚挛挝挟挠挡挣挤挥挨挪挫振挽挿捆捉捍捏捐捕捜损捡换捧捨据捷捻掀掃授掉掌掏掐排掘掛掠採探接控推掩措掲掺揃揄揆揉描提插揚換握揭揮援揶揺揽搁搅損搏搓搖搜搞搬搭携搾摁摂摄摆摇摔摘摧摩摯摸撃撑撒撞撤撩播撮撰撲擁擅操擎據擢擦攀攒攣支收改攻放政故效敌敍敎敏救敗教敛敞敢散敦敬数敲整敵敷數斂文斉斋斌斎斐斑斗料斜斟斡斤斥斧斩斬断斯新方於施旁旅旋旌族旗无既日旦旧旨早旬旭时旷旺旻旼旿昂昆昇昉昊昌明昏易昔星映春昧昨昭是昼显晁時晃晋晓晚晟晤晦晨晩晫普景晰晴晶智晾暁暂暇暈暑暖暗暦暨暫暮暴暻曉曖曙曜曝曲更書曹曺曼曽曾替最會月有朋服朔朗望朝期朦木未末本札术朱朴朵机朽杀杂权杆杉李杏材村杓杖杜杞束条来杨杭杯杰東松板极构枉析枓枕林枚果枝枠枢枪枫枯架枷枸柄柏某柑染柔柜柠查柯柱柳柴柵査柿栄栅标栈栋栏树栓栗校栩株样核根格栽桁桂桃框案桌桐桑桓桜档桥桦桶梁梅梗條梢梦梨梭梯械梵检棄棉棋棍棒棕棚棟森棱棲棵棺椅植椎椒検椭椿楊楕楚業極楷楼楽概榆榛榜榧榨榭榮榻構槍様槛槟槳槽槿樂樓標樟模樣権横樵樹樺樽橇橋橘橙機橡橫橱檀檐檔檢檬櫓欄權欠次欢欣欧欲欺款歉歌歎歓歡止正此步武歧歩歪歯歳歴歹死殃殆殉殊残殖殴段殷殺殿毀毁毅毋母毎每毒比毕毙毛毫毯氏民氓气気氚氛氟氢氣氧氩水氷永氾汀汁求汇汉汎汗汚汝江池污汤汪汰汲汴汶汹決汽沁沂沃沈沉沌沐沒沖沙沟没沢沦沧沪沫河油治沼沾沿況泂泄泉泊泌法泛泡波泣泥注泪泯泰泳泵泼泽泾洁洋洒洗洙洛洞津洪洲洵活洼派流浄浅浆测济浏浑浓浙浚浜浣浦浩浪浮浴海浸涂消涉涌涓涙涛涜涡涤润涧涨涪涯液涵涼淀淆淇淋淌淑淘淞淡淤淫深淳淵混淸淹淺添清渇済渉渊渋渎渐渓渔渗減渝渠渡渣渤温測港渲渴游渾湖湘湧湯湾湿満溃溅溉溌源準溜溝溢溪溫溯溶滅滉滋滑滓滔滕滚滞满滥滦滨滩滲滴滿漁漂漆漏漓演漠漢漫漬漱潍潔潘潜潤潭潮澄澈澗澜澡澤澱澳激濁濃濟濡濤濩濫濬濯瀑瀕瀚瀬灌灘灣火灭灯灰灵灸災灾灿炅炉炊炎炕炙炫炬炭炮炯炳炸点為炼烁烂烈烏烘烟烤烦烧烨烫热烯烹焉焊焕焖焘焙焚無焦焰然焼煉煌煎煕煙煤煥照煮煽熄熊熏熙熟熨熬熱熹熾燁燃燒燕營燥燦燮燾爀爆爐爨爪爬爱爲爵父爷爸爺爽片版牌牒牙牛牡牢牧物牲牵特牺牽犠犧犬犯状犷犹狀狂狐狗狙狡狩独狭狮狱狸狼猕猛猜猟猥猪猫献猴猶猿獄獐獨獲玄率玉王玖玛玥玩玫环现玲玹玺玻珈珉珊珍珐珑珞珠珪班現球琅理琇琏琐琦琲琴琼瑕瑛瑜瑞瑟瑩瑰瑶璃璉璋璜璟璧環璹璽瓊瓒瓚瓜瓠瓢瓣瓦瓮瓶瓷甕甘甙甚甜生產産用甩甫甬田由甲申电男町画畅界畏畑畔留畜畢略番畯異畳畵當畿疆疋疎疏疑疗疫疯疲疵疹疼疾病症痉痒痕痙痛痢痪痫痰痴痹瘁瘍瘙瘦瘪瘫瘾療癇癌癒癖癫癲発登發白百皂的皆皇皋皓皙皮皱皿盂盆盈益盏盐监盒盖盗盘盛盟監盤盧目盯盲直相盼盾省眈眉看県眞真眠眶眺眼眾着睁睐睛睡督睦睫睹瞄瞑瞒瞞瞩瞪瞬瞭瞰瞻矛矢矣知矫短矮矯石矶矿码砂砍研砖砥砲破砸础硅硏硕硝硬确硯碁碌碍碎碑碗碘碟碧碩碰碱碳確磁磋磨磻礎示礼社祀祈祉祐祖祚祜祝神祠祥票祭祯祷祸禁禄禅禍禎福禧禪禹离禽禾秀私秉秋种科秒秕秘租秦秩积称移稀稅程稍税稚稜稣種稱稲稳稷稻稼稽稿穀穂穆積穏穫穴究穷空穿突窃窄窍窓窗窝窟窮窺窿立站竜竞竟章竣童
竭端競竹竿笋笑笔笛符笨第笹笼筆等筋筏筑筒答策筝筹签简箇算管箫箭箱節範篆篇築簡簿籍米类粉粋粒粗粘粛粥粧粪粮粲粹精糀糊糕糖糙糟糧糯糸系糾紀約紅紆紊紋納紐純紗紙級紛素紡索紧紫累細紹紺終組経結絞絡給絨統絵絶絹經継続綜綠維綱網綻綿緊総緑緒緖線締編緩緯練緻縁縄縛縞縣縦縫縮縱總績繁繊織繕繰纉續纏纠红纤约级纪纫纬纭纯纱纲纳纵纶纷纸纹纺纽线绀练组细织终绊绌绍绎经绑绒结绕绘给络绝绞统继绩绪续绰绳维绵绸综绽绿缄缆缓缔编缘缚缝缟缠缩缬缴缶缸缺罄罐网罔罕罗罚罠罢罩罪置罰署罵罹羁羅羈羊美羞羡群羨義羲羽翁翌翎習翔翕翘翠翩翰翱翻翼耀老考者而耍耐耕耗耙耳耶耻耽聂聆聊职联聖聘聚聞聪聴聶職肃肇肉肌肖肘肚肝肠股肢肤肥肩肪肯育肴肺肿胀胁胃胆背胎胖胚胜胞胡胤胧胳胴胶胸胺能脂脅脆脇脈脉脊脏脐脑脚脫脱脳脸脹脾腊腋腌腐腑腔腕腦腫腰腱腸腹腺腻腾腿膀膊膏膚膜膝膠膨膳膺臀臂臓臣臨自臭至致臺臻臼舀舅舆與興舉舊舌舍舎舒舗舜舞舟航舫般舰舱舵舶船艇艘艦良艰色艳艺艾节芉芋芒芙芝芦芪芬芭芯花芳芸芽苇苍苏苑苒苗苛苟若苦英苹苺茂范茄茅茎茗茜茨茫茵茶茸荆草荏荐荒荔荚荞荡荣荧荫药荷荼莅莉莊莎莓莠莫莱莲获莽菀菅菇菉菌菓菖菜菩華菱菲菸萄萌萎营萧萨萬落葉著葛葡董葦葫葬葱葵葺蒂蒋蒙蒜蒲蒸蒼蒾蒿蓄蓉蓋蓝蓬蓮蔑蔓蔗蔚蔡蔦蔬蔵蔷蔽蕉蕊蕎蕓蕨蕴蕾薄薇薙薛薦薩薪薫薬薯薰藍藏藤藥藻藿蘆蘇蘑蘭虎虏虐虑處虚虜號虫虹虽虾蚀蚁蚂蚕蛀蛄蛇蛋蛍蛙蛛蛟蛮蛹蜀蜂蜃蜘蜜蝉蝙蝠蝶螂融螢螺蟑蟾蠡蠢蠵蠶血衅衆行衍術衔街衛衝衡衣补表衫衬衰衷衿袁袂袋袍袒袖袜被袭袱裁裂装裏裔裕補裝裤裴裵裸裹製裾複褐褒褡褥褪襟襲西要覆覇見規視覚覧親観覺覽觀见观规视览觉觊觎觑角解触言訂計討訓託記訟訣訪設許訳訴診証詐評詞試詩詫詭詰話該詳詹誅誇誉誌認誓誕誘語誠誡誤說説読誰課調談請論諦諮諷諸諾謀謄謙講謝謡謬謳謹證識譚譜警議譲護讀讃變讐计订认讥讧讨让训议讯记讲讶许论讼讽设访诀证评诅识诈诉诊词译试诗诙诚话诞诡询该详诫诬语误诱说诵请诸诺读课谁调谅谈谊谋谍谎谐谓谜谢谣谦谨谬谭谱谴谷豁豆豊豚象豪豫貌貝貞負財貢貧貨販貫責貯貴買貸費貼貿賀賃資賊賑賓賛賜賞賠賢賦質賭購贅贈贓贝贞负贡财责贤败账货质贩贪贫贬购贯贴贵贷贸费贺贼贿赁资赋赌赎赏赐赔赖赚赛赞赠赢赤赦赫走赴赵赶起趁超越趋趙趣足趾跃跆跌跏跑距跟跡跤跨路跳践跻踊踏踝踢踩踪踱踵蹄蹈蹴躍身躬躲躺車軌軍軒軟転軫軸軽較載輌輔輝輩輪輸輿轄轉轟车轨轩转轮软轰轴轻载轿较辅辆辈辉辐辑输辖辙辛辜辞辟辣辨辩辫辭辰辱農边辺辻込辽达辿迁迂迄迅过迈迎运近返还这进远违连迟迪迫述迴迷迹追退送适逃逅逆选逊透逐递途逗這通逝速造逢連逮週進逸逹逻逼逾遂遅遇遊運遍過道達違遗遜遠遡遣遥適遭遮遵選遺遼避邀邁邂還邊邑那邦邪邮邱邵邸邻郁郊郎郑郝郞郡部郭郵郷都鄂鄉鄙鄭酉酋酌配酎酒酔酗酢酪酬酰酱酵酷酸酿醇醉醋醍醐醒醛醜醴醸采釈释釋里重野量金釘釜針釣鈍鈞鈺鉄鉉鉛鉢鉱鉴銀銃銅銓銖銘銭銷鋒鋪鋭鋼錄錐錞錠錦錫錬錯録鍊鍋鍵鍼鍾鎌鎔鎖鎬鎭鎮鏞鏡鐘鐵鐸鑑鑫针钉钓钙钚钛钝钞钟钠钢钧钩钮钰钱钵钻钾铀铁铃铅铉铜铝铢铨铬铭铱银铺链销锁锂锅锈锋锐错锡锣锤锦键锯锻镁镇镍镐镕镛镜镳镶長长門閉開閏閑間閔閠関閣閥閲闇闊闕闘關门闪闭问闯闰闲间闵闷闸闹闻闽阁阅阐阔阙阜队阪阱防阳阴阵阶阻阿陀附际陆陈陋陌降限陕陡院陣除陥陨险陪陰陳陵陶陷陸険陽隅隆隊階随隐隔隕隘隙際障隠隣隧隨險隱隶隷隻难雀雁雄雅集雇雉雌雑雕雙雛雜離難雨雪雰雲雳零雷電雾需霄霆震霉霊霍霏霓霜霞霧露霸霹靈靑青靓靖静靜非靠靡面革靴鞄鞋鞘鞠鞭韓韦韧韩音韵韻響頁頂頃項順須預頑頓領頭頴頻頼題額顔顕願類顧顯页顶项顺须顽顾顿颁颂预领颇颈频颔颖颗题颜额颠颤風风飓飘飙飚飛飞食飢飬飯飲飴飼飽飾餅養餌餐餓餘館饅饗饥饪饭饮饰饱饲饵饶饺饼饿馆馈首香馥馨馬馴駄駅駆駐駿騒験騰驗驚驪马驯驰驱驳驶驻驾驿骂骄骊验骏骑骗骚骤骨骸髓體高髙髪鬱鬼魁魂魄魅魏魔魚魯鮎鮨鮮鯵鰍鰐鱼鱿鲁鲍鲐鲜鲨鲹鳗鳥鳳鳴鴉鴨鴻鵡鶏鶴鷄鷹鷺鸟鸡鸣鸥鸦鸭鸿鹅鹊鹏鹤鹭鹰鹿麒麓麗麟麦麺麻黃黄黎黑黒默黙黨黯鼎鼓鼠鼻齋齐齢齿龄龈龋龍龙龜龟가각간갇갈갉감갑값갔강갖같갚개객갤갯갱걀걔거걱건걷걸검겁것겅겉게겐겔겠겡겨격겪견결겸겹겼경곁계고곡곤곧골곯곰곱곳공곶과곽관괄괌광괜괴굉교구국군굳굴굵굶굽굿궁궈권궐궜궤귀귓규균귤그극근글긁금급긋긍기긴길김깁깃깅깊까깎깐깔깜깝깡깥깨꺼꺾껄껌껍껏껐께껴꼇꼈꼬꼭꼴꼼꼽꽁꽂꽃꽉꽝꽤꾀꾜꾸꾹꾼꿀꿇꿈꿉꿋꿍꿔꿨꿩꿰뀌뀐뀔끄끈끊끌끓끔끗끝끼끽낀낄낌나낙낚난날낡남납낫났낭낮낯낱낳내낵낸낼냄냇냈냉냐냥너넉넋넌널넓넘넛넣네넥넨넬넷넹녀녁년녋념녔녕녘녜노녹논놀놈농높놓놔놨뇌뇨뇽누눈눌눔눕눙눠눴뉘뉜뉴늄늉느늑는늘늙늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮담답닷당닿대댁댄댈댐댓댔댕댜더덕던덜덟덤덥덧덩덮데덱덴델뎀뎅뎌도독돈돋돌돔돕돗동돟돠돱돼됐되된될됨됩두둑둔둘둠둡둥둬뒀뒤뒷뒹듀듈듐드득든듣들듦듬듭듯등디딕딘딛딜딥딨딩딪따딱딴딸땀땅때땐땜땠땡떠떡떤떨떳떴떻떼뗐뗙뗜또똑똥뚜뚝뚫뚱뛰뛴뜨뜩뜬뜯뜰뜻띄띈띌띔띠띤띰라락란랄람랍랏랐랑랗래랙랜램랩랫랬랭랴략량러럭런럴럼럽럿렀렁렇레렉렌렐렘렛려력련렬렴렵렷렸령례롄로록론롤롬롭롯롱뢰료룡루룩룬룰룸룹룽뤄뤘뤼류륙륜률륨륭르륵륶른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맟맡맣매맥맨맴맵맸맹맺먀머먹먼멀멈멋멍메멕멘멜멤멩며면멸몁몄명몇모목몫몬몰몸몹못몽묘무묵묶문묻물묽뭇뭉뭐뭔뭘뮈뮌뮤뮬므믈믑믜미믹민믿밀밉밋밌밍밎및밑바박밖반받발밝밟밤밥밧방밭배백밴밸뱀뱃뱅뱉뱍버벅번벌범법벗베벡벤벨벼벽변별볍병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북붂분불붉붐붓붕붙뷔뷰뷴뷸브븍븐블븥비빅빈빌빔빗빙빚빛빠빡빤빨빴빵빻빼빽뺀뺌뺐뺨뻐뻑뻔뻗뻘뻣뻤뼈뼛뽀뽐뽑뽕뾰뿌뿍뿐뿔뿜쁘쁜쁨삐삔사삭산살삶삼삽삿샀상샅새색샌샐샘샛생샤샴샵샹샾섀서석섞선섣설섬섭섯섰성세섹센섿셀셈셉셋셔셕션셜셨셰셴셸소속손솔솜솝솟송솥쇄쇠쇼숀숍수숙순술숨숩숫숭숯숱숲숴쉅쉐쉘쉬쉰쉴쉼쉽쉿슈슐슘스슥슨슬슭슴습슷승슼슽시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓쌩써썩썬썰썸썼썽쎄쏘쏙쏜쏟쏠쐈쐬쑤쑥쑹쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗았앙앞애액앤앨앰앱앴앵야약얀얄얇얏양얕얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엣여역엮연엳열염엽엿였영옅옆예옌옛오옥온올옮옳옴옵옷옹와왁완왈왑왓왔왕왜외왼요욕용우욱운울움웁웃웅워웍원월웠웨웬웰웹위윈윌윗윙유육윤율융으윽은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잿쟁쟈쟤저적전절젊점접젓정젖제젝젠젤젬젯져젹젼젿졌조족존졸좀좁종좇좋좌좍죄죠주죽준줄줌줍중줘줬쥐쥔쥬즈즉즌즐즘즙증즤지직진질짊짐집짓징짖짙짚짜짝짠짤짧짬짱째쨌쩌쩍쩔쩜쩝쩡쪼쪽쫄쫓쭈쭉쭝쯔쯕쯤찌찍찐찔찜찡찢차착찬찮찰찲참찹찻찼창찾채책챈챌챔챗챙챠처척천철첨첩첫청체첸첼쳐쳤초촉촌촐촘촛총촨촬최쵸추축춘출춤춥춧충춰췄취츠측츰층치칙친칠침칩칫칭카칵칸칼캄캅캉캐캔캘캠캡캣컘커컨컫컬컴컵컷컸케켄켈켓켜켤켭켰켸코콕콘콜콤콥콧콩콰쾌쿄쿠쿤쿨퀘퀴퀸큐큘크큰클큼킁키킨킬킴킷킹타탁탄탈탐탑탓탔탕태택탠탤탯탰탱터턱턴털턺텀텃텄텅테텍텐텔템텝텟텨텼톈토톡톤톨톰톱통퇘퇴투툭툰툴툼퉁퉈튀튜튝튬트특튼틀틈틋틔티틱틴틸팀팁팅파팍팎판팔팜팝팟팠팡팥패팩팬팸팽퍼펀펄펌펍펐펑페펙펜펠펩펫펴편펼폈평폐포폭폰폴폼폿표푸푹푼풀품풋풍퓨퓰프픈플픔픗픚피픽핀핃필핌핍핏핑하학한할함합핫항해핵핸핼햄햇했행햐향허헌헐험헙헛헝헤헨헬헹혀혁현혈혐협혔형혜호혹혼홀홈홉홋홍화확환활황홰회획횟횡효횹후훈훌훔훗훙훤훨훼휘휠휩휴휼흉흐흑흔흘흙흠흡흥흩희흰히힉힌힐힘힝金羅蘭盧老魯碌菉論陵寧磻不沈若兩梁良量女麗聯烈瑩醴龍暈劉柳硫栗利李梨麟林拓宅fffifl!#%&'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPRSTUVWXYZ[]abcdefghiklmnoprstuvwxy~・ネ�砧钥츄瘋萝乒乓砌窥鐫चरचŮůŘपवŇďŤ脯喽腩樱绣仟孚¥葆咪莒煲蝎炒壶娃蜻蜓耒郴铸辊轧缈萃佩珩沱呷莞岚囤蕲稗秧惬鲸縤奢瞧咱荟抄馅甄灶捞埔咋稞篷莺翡趟鳝馍嫩哄痣佬崴呗卤兽枣犇貂柒铂钯绅镀扒裱诠娟凳槐犊浇铮廿缇梓粼俪榴纰缕瞅觅撕豉焗桔崋嗨瑅氨戟冶瑄榔徽佗鳞哮吾溺磅稠涝鹦鹉蟹阖叽獴廰苓晒簧瓯馄饨粤钣脖阑炖盔捣鸳鸯潢骋鞍翅鸽寞颐黛陂倪肆逛嘻酥幂睿倩Ⅱ驴璞扦茉滤撸鱻瑚侈肛铠镯裳蚊藕沔垦涮喵蜡煸矽ɭֹ֧椰卉汕肾巅叼乖钜汾烽窖彪尨勺琪赈萱氮缤栖踞礡恪蜗呦屌厍蹭嘟琥珀橄榄喔犀谌哦珂汨喱咔淮泻洱盱眙菁戳歇䷽ڵȻ˺琉豹闺鲢菊骠瑪摊祛來镭偕沥贰叁滁痘琯柚梳賣芹娄芡炝镌楹涎浠阀苞粱芥轭粑锲黔硚涩筛崽媚爹篓湛吼璀璨芊迭朕霖仞饯醫薹泸瀘鳕绮琳鳄庐襄颊咯耿痧塬棠旱撼藓叮疤寰瑙琢楂奕圜擀嫂兮悸挞骥赃猎蜥蜴垩唔蔻妞逍
泷谕呵矩籁篮邓龚萍筐甸哒浥揍嘞帛炜吆堰瞎箔丫瞳峦邰熔絮蝇劈裙赣泓哟宸蛭砺擘叔妖悍嚒渍咏氰酯噁唑煨巍廣砰糍菠渺旖濛婺臧沛佃邗咦晕軎鲈溇瀛鲫篝昵灼崧婧秭噜拽悄帼漾磊犸釉扯桩攝榈粽拇牦滏苕谛尧磐佟馋嘀嗒咕靶忱籽咚疮痍岱邯郸馔菘痔沸噹瑷侬恬聋囍烙酚葚卦屹玮贱惪夙韬顷茏韭唧阡摒豌斓琬秸碴晖馒羯痿蝴薡鲤焱蕙镔钿磷辋煞牟荤烩婵缭畴硒郢爯捂薏嘶柬缰拐彤疝抬墩邹榕阆霾叟窜蕃哩遛绶蚝廠樊刨畸窈窕逑雎鸠啪哆竺氯苯酮硫寇曦妾陇坨漳亩梧捶骆驼
2 |
--------------------------------------------------------------------------------
/tools/codec_rctw.txt:
--------------------------------------------------------------------------------
1 | ίγἀəლ≡!"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ ¡¢£¤¥¦§¨©ª«¬®¯°±²³´µ¶·¸¹º»¼½¾¿ÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖרÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõö÷øùúûüýþÿĀāăąĆćČčĎĐđēĖęĚěğħīıŁłńňŌōŏőŒœřŚśŞşŠšťūŷŸźżŽžƒǔǘǧșɯʒʼʾʿˆˇˉ˘˚˜˝̀́̃̈;ΈΏΑΔΛΟΣΩάέαβδεηθικλμνοπρςστφωόϟЃЄАБВГДЕЗИЙКЛМНОПРСТУФХЦЧШЫЬЭЮЯабвгдежзийклмнопрстуфхчшыьэяёєїֳאגהוטיכלנרשת،؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْ٠١٢٣٤٥٦٧٨٩٫٬ڤڥڧڨڭࠍࠥࠦएकदनभमलशसािीुे्।॥ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰ৷৺ఇᄂᄃᄉᄊᄋᄏᄑ하ᅥᅧᅳᅵḤḥṃṇṛṠṣẒễệừὋῖ‐‑‒–—―‖‘’‚“”„‟†‡•‥…‧‰′″‹›※⁄₂₣₤₩€℃ℓ™⅛←→↔⇐⇒⇔∂∆∇∑−√∣∫∼≈≤≥〈①②③─╚╢╩▁■□▪▲△▶►▼▽◆◇◊○●◙★☺☻♠♣♥♦♬✤❤➔➜ 、。々〇〈〉《》「」『』【】〓〔〕ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちぢっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろわをんァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツヅテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロヮワヱヲンヴヵヶ・ーㆍ㈜㎞㎡㡢㨲㩠㩡㩥㩬一丁七万丈三上下不与丑专且世丘丙业丛东丝両丢两严並丧个中丰串临丸丹为主丼丽举乃久么义之乌乍乎乏乐乔乗乘乙九乞也习乡书买乱乳乾亀了予争事二亍于亏亐云互五井亚些亜亡交亥亦产亨享京亭亮亲亵人亿什仁仅仆仇今介仍从仏仑仓仔仕他仗付仙代令以仪们仮仰仲件价任份仿企伊伍伎伏伐休众优伙会伝伞伟传伤伦伪伯估伴伶伸伺似伽但位低住佐佑体何余佛作你佣佯佰佳併佶佼使侄例侍侖供依侠価侣侥侦侧侨侯侵侶便係促俄俊俏俐俗俘保俞信俩俭修俯俱俳俸俺倉個倍倒候倚借倡倣値倦倫倶债值倾假偉偏做停健側偵偶偷偽偿傅傍傘備储催傭傲債傷傻傾僅働像僑僕僚僧僵價儀億儒償優儲儿允元兄充兆先光克免児兑兒兔党兜入內全兩兪八公六兰共关兴兵其具典兹养兼冀内円冈冊册再冒冕冗写军农冠冢冤冥冬冯冰冲决况冷冻冽净凄准凉凌凍减凑凖凝凞几凡凤処凭凯凰凶凸凹出击函凿刀刁分切刊刑划列刘则刚创初删判別利刪别刮到制刷券刹刺刻剁剂剃則削前剑剔剖剛剣剤剥剧剩剪副剰割創劃劇劉劑力劝办功加务劣动助努劫励劲劳労劵効劾势勃勇勉勋勒動勘務勝募勢勤勧勲勳勾勿匀匂包匈化北匙匠匡匣匪匹区医匾匿區十千升午半华协卑卒卓協单卖南単博卜卞占卡卢卧卫卯印危即却卵卷卸卿厂厄厅历厉压厌厕厘厚原厢厥厦厨厳去县参參又叉及友双反収发取受变叙叛叠口古句另叩只叫召叭可台叱史右叶号司叹吁吃各合吉吊同名后吏吐向吓吕吗君吞否吧吨含听启吳吴吵吸吹吻呀呂呆呈呉告呐呑呕员呜呢周味呼命咀和咎咐咒咖咤咨咫咬咲咳咸咽哀品哇哈哉响哎哑哗員哥哨哪哭哲哺哽唆唇唉唐唤售唯唱唸唾啄商啊問啓啟啡啤啥啦啼喀喂善喆喇喉喊喘喚喜喝喧喪喫單喰営喷喻嗓嗜嗣嗤嗯嗽嘆嘉嘘嘛嘱嘲嘴嘿噌噗噛噢噤器噩噪噴嚴嚼囊囚四回因团団园困囲図围固国图圃圆圈國圍圏園圓圖團土圣圧在圭地圳场圾址坂均坊坎坏坐坑块坚坛坝坞坠坡坤坦坪坰垂垃垄型垒垚垢垣垫埃埈埋城埙域埠執培基堂堅堆堡堤堪報場堵塊塌塑塔塗塘塚塞塩填塾境墅墓増墙墜增墟墨墳壁壇壊壌壎壞壤士壬壮声壱売壳壹壽处备変复夏夕外多夜够夢大天太夫夭央失头夷夸夹夺奇奈奉奋奎奏契奔奖套奠奥奧奨奪奬奭奮女奴奶奸她好如妃妄妆妇妈妊妍妒妙妝妥妨妮妹妻姃姆姉姊始姐姑姓委姙姚姜姦姨姫姬姸姻姿威娅娇娘娜娠娥娩娯娱娶婆婉婊婚婦婴婷婿媒媛媲媳嫁嫉嫌嬉嬌嬪子孔孕字存孙孜孝孟季孤学孩孫孵學宁它宅宇守安宋完宏宗官宙定宛宜宝实実宠审客宣室宪宫宮宰害宴宵家容宽宾宿寂寄寅密富寒寓寛寝察寡寢寥實寧寨審寫寬寮寵寶寸对寺寻导対寿封専射将將專尊尋對導小少尔尖尘尙尚尝尤尬就尴尸尹尺尻尼尽尾尿局屁层居屈届屋屎屏屑展属屠層履屯山屿岁岌岐岔岖岗岘岛岡岩岫岭岳岷岸峙峠峡峨峭峯峰峴島峻崇崎崔崖崗崛崩崭嵌嵩嶺嶽巌巖川州巡巢巣工左巧巨巩巫差己已巴巷巻巾币市布帅帆师希帐帕帖帘帚帝带師席帮帯帰帳帷常帽幅幇幌幔幕幡幢幣干平年并幸幹幻幼幽幾广庁広庄庆庇床序库应底店庙庚府庞废度座庫庭庵庶康庸廃廉廊廓廖廛廟延廷建廻开弁异弃弄弊式弐弓弔引弗弘弛弟张弥弦弧弯弱張強弹强弾归当录形彦彩彫彬彭彰影彷役彻彼彿往征径待很徊律後徐徑徒従得徘御復循微徳徴徵德徹心必忆忌忍志忘忙応忠忧快念忽怀态怅怎怒怕怖怜思怠怡急怦性怨怪怯总恃恋恍恐恒恕恙恢恣恤恥恨恩恭息恰恳恵恶恻恼悉悔悖悚悟悠患悦您悩悪悬悲悳悽情惆惇惊惋惑惕惗惚惜惟惠惡惣惧惨惩惫惯惱想惶惹惺愁愈愉意愚愛感愤愧愿慈態慌慎慕慢慣慧慨慮慰慶慷憂憎憤憧憨憩憫憬憲憶憺憾懂懇懈應懐懒懦懲懸戈戏成我戒或战戚戦截戰戴戶户戸戻房所扁扇扉手才扎扑打扔払托扣执扩扪扫扬扭扮扰扱扳扶批找承技抉把抑抒抓投抖抗折抚抛抜択抢护报披抱抵抹押抽抿担拆拉拌拍拎拒拓拔拖拘拙招拜拝拟拠拡拢拥拦拨择括拭拯拳拶拷拼拾拿持挂指按挑挖挙挚挛挝挟挠挡挣挤挥挨挪挫振挽挿捆捉捍捏捐捕捜损捡换捧捨据捷捻掀掃授掉掌掏掐排掘掛掠採探接控推掩措掲掺揃揄揆揉描提插揚換握揭揮援揶揺揽搁搅損搏搓搖搜搞搬搭携搾摁摂摄摆摇摔摘摧摩摯摸撃撑撒撞撤撩播撮撰撲擁擅操擎據擢擦攀攒攣支收改攻放政故效敌敍敎敏救敗教敛敞敢散敦敬数敲整敵敷數斂文斉斋斌斎斐斑斗料斜斟斡斤斥斧斩斬断斯新方於施旁旅旋旌族旗无既日旦旧旨早旬旭时旷旺旻旼旿昂昆昇昉昊昌明昏易昔星映春昧昨昭是昼显晁時晃晋晓晚晟晤晦晨晩晫普景晰晴晶智晾暁暂暇暈暑暖暗暦暨暫暮暴暻曉曖曙曜曝曲更書曹曺曼曽曾替最會月有朋服朔朗望朝期朦木未末本札术朱朴朵机朽杀杂权杆杉李杏材村杓杖杜杞束条来杨杭杯杰東松板极构枉析枓枕林枚果枝枠枢枪枫枯架枷枸柄柏某柑染柔柜柠查柯柱柳柴柵査柿栄栅标栈栋栏树栓栗校栩株样核根格栽桁桂桃框案桌桐桑桓桜档桥桦桶梁梅梗條梢梦梨梭梯械梵检棄棉棋棍棒棕棚棟森棱棲棵棺椅植椎椒検椭椿楊楕楚業極楷楼楽概榆榛榜榧榨榭榮榻構槍様槛槟槳槽槿樂樓標樟模樣権横樵樹樺樽橇橋橘橙機橡橫橱檀檐檔檢檬櫓欄權欠次欢欣欧欲欺款歉歌歎歓歡止正此步武歧歩歪歯歳歴歹死殃殆殉殊残殖殴段殷殺殿毀毁毅毋母毎每毒比毕毙毛毫毯氏民氓气気氚氛氟氢氣氧氩水氷永氾汀汁求汇汉汎汗汚汝江池污汤汪汰汲汴汶汹決汽沁沂沃沈沉沌沐沒沖沙沟没沢沦沧沪沫河油治沼沾沿況泂泄泉泊泌法泛泡波泣泥注泪泯泰泳泵泼泽泾洁洋洒洗洙洛洞津洪洲洵活洼派流浄浅浆测济浏浑浓浙浚浜浣浦浩浪浮浴海浸涂消涉涌涓涙涛涜涡涤润涧涨涪涯液涵涼淀淆淇淋淌淑淘淞淡淤淫深淳淵混淸淹淺添清渇済渉渊渋渎渐渓渔渗減渝渠渡渣渤温測港渲渴游渾湖湘湧湯湾湿満溃溅溉溌源準溜溝溢溪溫溯溶滅滉滋滑滓滔滕滚滞满滥滦滨滩滲滴滿漁漂漆漏漓演漠漢漫漬漱潍潔潘潜潤潭潮澄澈澗澜澡澤澱澳激濁濃濟濡濤濩濫濬濯瀑瀕瀚瀬灌灘灣火灭灯灰灵灸災灾灿炅炉炊炎炕炙炫炬炭炮炯炳炸点為炼烁烂烈烏烘烟烤烦烧烨烫热烯烹焉焊焕焖焘焙焚無焦焰然焼煉煌煎煕煙煤煥照煮煽熄熊熏熙熟熨熬熱熹熾燁燃燒燕營燥燦燮燾爀爆爐爨爪爬爱爲爵父爷爸爺爽片版牌牒牙牛牡牢牧物牲牵特牺牽犠犧犬犯状犷犹狀狂狐狗狙狡狩独狭狮狱狸狼猕猛猜猟猥猪猫献猴猶猿獄獐獨獲玄率玉王玖玛玥玩玫环现玲玹玺玻珈珉珊珍珐珑珞珠珪班現球琅理琇琏琐琦琲琴琼瑕瑛瑜瑞瑟瑩瑰瑶璃璉璋璜璟璧環璹璽瓊瓒瓚瓜瓠瓢瓣瓦瓮瓶瓷甕甘甙甚甜生產産用甩甫甬田由甲申电男町画畅界畏畑畔留畜畢略番畯異畳畵當畿疆疋疎疏疑疗疫疯疲疵疹疼疾病症痉痒痕痙痛痢痪痫痰痴痹瘁瘍瘙瘦瘪瘫瘾療癇癌癒癖癫癲発登發白百皂的皆皇皋皓皙皮皱皿盂盆盈益盏盐监盒盖盗盘盛盟監盤盧目盯盲直相盼盾省眈眉看県眞真眠眶眺眼眾着睁睐睛睡督睦睫睹瞄瞑瞒瞞瞩瞪瞬瞭瞰瞻矛矢矣知矫短矮矯石矶矿码砂砍研砖砥砲破砸础硅硏硕硝硬确硯碁碌碍碎碑碗碘碟碧碩碰碱碳確磁磋磨磻礎示礼社祀祈祉祐祖祚祜祝神祠祥票祭祯祷祸禁禄禅禍禎福禧禪禹离禽禾秀私秉秋种科秒秕秘租秦秩积称移稀稅程稍税稚稜稣種稱稲稳稷稻稼稽稿穀穂穆積穏穫穴究穷空穿突窃窄窍窓窗窝窟窮窺窿立
站竜竞竟章竣童竭端競竹竿笋笑笔笛符笨第笹笼筆等筋筏筑筒答策筝筹签简箇算管箫箭箱節範篆篇築簡簿籍米类粉粋粒粗粘粛粥粧粪粮粲粹精糀糊糕糖糙糟糧糯糸系糾紀約紅紆紊紋納紐純紗紙級紛素紡索紧紫累細紹紺終組経結絞絡給絨統絵絶絹經継続綜綠維綱網綻綿緊総緑緒緖線締編緩緯練緻縁縄縛縞縣縦縫縮縱總績繁繊織繕繰纉續纏纠红纤约级纪纫纬纭纯纱纲纳纵纶纷纸纹纺纽线绀练组细织终绊绌绍绎经绑绒结绕绘给络绝绞统继绩绪续绰绳维绵绸综绽绿缄缆缓缔编缘缚缝缟缠缩缬缴缶缸缺罄罐网罔罕罗罚罠罢罩罪置罰署罵罹羁羅羈羊美羞羡群羨義羲羽翁翌翎習翔翕翘翠翩翰翱翻翼耀老考者而耍耐耕耗耙耳耶耻耽聂聆聊职联聖聘聚聞聪聴聶職肃肇肉肌肖肘肚肝肠股肢肤肥肩肪肯育肴肺肿胀胁胃胆背胎胖胚胜胞胡胤胧胳胴胶胸胺能脂脅脆脇脈脉脊脏脐脑脚脫脱脳脸脹脾腊腋腌腐腑腔腕腦腫腰腱腸腹腺腻腾腿膀膊膏膚膜膝膠膨膳膺臀臂臓臣臨自臭至致臺臻臼舀舅舆與興舉舊舌舍舎舒舗舜舞舟航舫般舰舱舵舶船艇艘艦良艰色艳艺艾节芉芋芒芙芝芦芪芬芭芯花芳芸芽苇苍苏苑苒苗苛苟若苦英苹苺茂范茄茅茎茗茜茨茫茵茶茸荆草荏荐荒荔荚荞荡荣荧荫药荷荼莅莉莊莎莓莠莫莱莲获莽菀菅菇菉菌菓菖菜菩華菱菲菸萄萌萎营萧萨萬落葉著葛葡董葦葫葬葱葵葺蒂蒋蒙蒜蒲蒸蒼蒾蒿蓄蓉蓋蓝蓬蓮蔑蔓蔗蔚蔡蔦蔬蔵蔷蔽蕉蕊蕎蕓蕨蕴蕾薄薇薙薛薦薩薪薫薬薯薰藍藏藤藥藻藿蘆蘇蘑蘭虎虏虐虑處虚虜號虫虹虽虾蚀蚁蚂蚕蛀蛄蛇蛋蛍蛙蛛蛟蛮蛹蜀蜂蜃蜘蜜蝉蝙蝠蝶螂融螢螺蟑蟾蠡蠢蠵蠶血衅衆行衍術衔街衛衝衡衣补表衫衬衰衷衿袁袂袋袍袒袖袜被袭袱裁裂装裏裔裕補裝裤裴裵裸裹製裾複褐褒褡褥褪襟襲西要覆覇見規視覚覧親観覺覽觀见观规视览觉觊觎觑角解触言訂計討訓託記訟訣訪設許訳訴診証詐評詞試詩詫詭詰話該詳詹誅誇誉誌認誓誕誘語誠誡誤說説読誰課調談請論諦諮諷諸諾謀謄謙講謝謡謬謳謹證識譚譜警議譲護讀讃變讐计订认讥讧讨让训议讯记讲讶许论讼讽设访诀证评诅识诈诉诊词译试诗诙诚话诞诡询该详诫诬语误诱说诵请诸诺读课谁调谅谈谊谋谍谎谐谓谜谢谣谦谨谬谭谱谴谷豁豆豊豚象豪豫貌貝貞負財貢貧貨販貫責貯貴買貸費貼貿賀賃資賊賑賓賛賜賞賠賢賦質賭購贅贈贓贝贞负贡财责贤败账货质贩贪贫贬购贯贴贵贷贸费贺贼贿赁资赋赌赎赏赐赔赖赚赛赞赠赢赤赦赫走赴赵赶起趁超越趋趙趣足趾跃跆跌跏跑距跟跡跤跨路跳践跻踊踏踝踢踩踪踱踵蹄蹈蹴躍身躬躲躺車軌軍軒軟転軫軸軽較載輌輔輝輩輪輸輿轄轉轟车轨轩转轮软轰轴轻载轿较辅辆辈辉辐辑输辖辙辛辜辞辟辣辨辩辫辭辰辱農边辺辻込辽达辿迁迂迄迅过迈迎运近返还这进远违连迟迪迫述迴迷迹追退送适逃逅逆选逊透逐递途逗這通逝速造逢連逮週進逸逹逻逼逾遂遅遇遊運遍過道達違遗遜遠遡遣遥適遭遮遵選遺遼避邀邁邂還邊邑那邦邪邮邱邵邸邻郁郊郎郑郝郞郡部郭郵郷都鄂鄉鄙鄭酉酋酌配酎酒酔酗酢酪酬酰酱酵酷酸酿醇醉醋醍醐醒醛醜醴醸采釈释釋里重野量金釘釜針釣鈍鈞鈺鉄鉉鉛鉢鉱鉴銀銃銅銓銖銘銭銷鋒鋪鋭鋼錄錐錞錠錦錫錬錯録鍊鍋鍵鍼鍾鎌鎔鎖鎬鎭鎮鏞鏡鐘鐵鐸鑑鑫针钉钓钙钚钛钝钞钟钠钢钧钩钮钰钱钵钻钾铀铁铃铅铉铜铝铢铨铬铭铱银铺链销锁锂锅锈锋锐错锡锣锤锦键锯锻镁镇镍镐镕镛镜镳镶長长門閉開閏閑間閔閠関閣閥閲闇闊闕闘關门闪闭问闯闰闲间闵闷闸闹闻闽阁阅阐阔阙阜队阪阱防阳阴阵阶阻阿陀附际陆陈陋陌降限陕陡院陣除陥陨险陪陰陳陵陶陷陸険陽隅隆隊階随隐隔隕隘隙際障隠隣隧隨險隱隶隷隻难雀雁雄雅集雇雉雌雑雕雙雛雜離難雨雪雰雲雳零雷電雾需霄霆震霉霊霍霏霓霜霞霧露霸霹靈靑青靓靖静靜非靠靡面革靴鞄鞋鞘鞠鞭韓韦韧韩音韵韻響頁頂頃項順須預頑頓領頭頴頻頼題額顔顕願類顧顯页顶项顺须顽顾顿颁颂预领颇颈频颔颖颗题颜额颠颤風风飓飘飙飚飛飞食飢飬飯飲飴飼飽飾餅養餌餐餓餘館饅饗饥饪饭饮饰饱饲饵饶饺饼饿馆馈首香馥馨馬馴駄駅駆駐駿騒験騰驗驚驪马驯驰驱驳驶驻驾驿骂骄骊验骏骑骗骚骤骨骸髓體高髙髪鬱鬼魁魂魄魅魏魔魚魯鮎鮨鮮鯵鰍鰐鱼鱿鲁鲍鲐鲜鲨鲹鳗鳥鳳鳴鴉鴨鴻鵡鶏鶴鷄鷹鷺鸟鸡鸣鸥鸦鸭鸿鹅鹊鹏鹤鹭鹰鹿麒麓麗麟麦麺麻黃黄黎黑黒默黙黨黯鼎鼓鼠鼻齋齐齢齿龄龈龋龍龙龜龟가각간갇갈갉감갑값갔강갖같갚개객갤갯갱걀걔거걱건걷걸검겁것겅겉게겐겔겠겡겨격겪견결겸겹겼경곁계고곡곤곧골곯곰곱곳공곶과곽관괄괌광괜괴굉교구국군굳굴굵굶굽굿궁궈권궐궜궤귀귓규균귤그극근글긁금급긋긍기긴길김깁깃깅깊까깎깐깔깜깝깡깥깨꺼꺾껄껌껍껏껐께껴꼇꼈꼬꼭꼴꼼꼽꽁꽂꽃꽉꽝꽤꾀꾜꾸꾹꾼꿀꿇꿈꿉꿋꿍꿔꿨꿩꿰뀌뀐뀔끄끈끊끌끓끔끗끝끼끽낀낄낌나낙낚난날낡남납낫났낭낮낯낱낳내낵낸낼냄냇냈냉냐냥너넉넋넌널넓넘넛넣네넥넨넬넷넹녀녁년녋념녔녕녘녜노녹논놀놈농높놓놔놨뇌뇨뇽누눈눌눔눕눙눠눴뉘뉜뉴늄늉느늑는늘늙늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮담답닷당닿대댁댄댈댐댓댔댕댜더덕던덜덟덤덥덧덩덮데덱덴델뎀뎅뎌도독돈돋돌돔돕돗동돟돠돱돼됐되된될됨됩두둑둔둘둠둡둥둬뒀뒤뒷뒹듀듈듐드득든듣들듦듬듭듯등디딕딘딛딜딥딨딩딪따딱딴딸땀땅때땐땜땠땡떠떡떤떨떳떴떻떼뗐뗙뗜또똑똥뚜뚝뚫뚱뛰뛴뜨뜩뜬뜯뜰뜻띄띈띌띔띠띤띰라락란랄람랍랏랐랑랗래랙랜램랩랫랬랭랴략량러럭런럴럼럽럿렀렁렇레렉렌렐렘렛려력련렬렴렵렷렸령례롄로록론롤롬롭롯롱뢰료룡루룩룬룰룸룹룽뤄뤘뤼류륙륜률륨륭르륵륶른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맟맡맣매맥맨맴맵맸맹맺먀머먹먼멀멈멋멍메멕멘멜멤멩며면멸몁몄명몇모목몫몬몰몸몹못몽묘무묵묶문묻물묽뭇뭉뭐뭔뭘뮈뮌뮤뮬므믈믑믜미믹민믿밀밉밋밌밍밎및밑바박밖반받발밝밟밤밥밧방밭배백밴밸뱀뱃뱅뱉뱍버벅번벌범법벗베벡벤벨벼벽변별볍병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북붂분불붉붐붓붕붙뷔뷰뷴뷸브븍븐블븥비빅빈빌빔빗빙빚빛빠빡빤빨빴빵빻빼빽뺀뺌뺐뺨뻐뻑뻔뻗뻘뻣뻤뼈뼛뽀뽐뽑뽕뾰뿌뿍뿐뿔뿜쁘쁜쁨삐삔사삭산살삶삼삽삿샀상샅새색샌샐샘샛생샤샴샵샹샾섀서석섞선섣설섬섭섯섰성세섹센섿셀셈셉셋셔셕션셜셨셰셴셸소속손솔솜솝솟송솥쇄쇠쇼숀숍수숙순술숨숩숫숭숯숱숲숴쉅쉐쉘쉬쉰쉴쉼쉽쉿슈슐슘스슥슨슬슭슴습슷승슼슽시식신싣실싫심십싯싱싶싸싹싼쌀쌍쌓쌩써썩썬썰썸썼썽쎄쏘쏙쏜쏟쏠쐈쐬쑤쑥쑹쓰쓴쓸씀씌씨씩씬씹씻아악안앉않알앓암압앗았앙앞애액앤앨앰앱앴앵야약얀얄얇얏양얕얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엣여역엮연엳열염엽엿였영옅옆예옌옛오옥온올옮옳옴옵옷옹와왁완왈왑왓왔왕왜외왼요욕용우욱운울움웁웃웅워웍원월웠웨웬웰웹위윈윌윗윙유육윤율융으윽은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잿쟁쟈쟤저적전절젊점접젓정젖제젝젠젤젬젯져젹젼젿졌조족존졸좀좁종좇좋좌좍죄죠주죽준줄줌줍중줘줬쥐쥔쥬즈즉즌즐즘즙증즤지직진질짊짐집짓징짖짙짚짜짝짠짤짧짬짱째쨌쩌쩍쩔쩜쩝쩡쪼쪽쫄쫓쭈쭉쭝쯔쯕쯤찌찍찐찔찜찡찢차착찬찮찰찲참찹찻찼창찾채책챈챌챔챗챙챠처척천철첨첩첫청체첸첼쳐쳤초촉촌촐촘촛총촨촬최쵸추축춘출춤춥춧충춰췄취츠측츰층치칙친칠침칩칫칭카칵칸칼캄캅캉캐캔캘캠캡캣컘커컨컫컬컴컵컷컸케켄켈켓켜켤켭켰켸코콕콘콜콤콥콧콩콰쾌쿄쿠쿤쿨퀘퀴퀸큐큘크큰클큼킁키킨킬킴킷킹타탁탄탈탐탑탓탔탕태택탠탤탯탰탱터턱턴털턺텀텃텄텅테텍텐텔템텝텟텨텼톈토톡톤톨톰톱통퇘퇴투툭툰툴툼퉁퉈튀튜튝튬트특튼틀틈틋틔티틱틴틸팀팁팅파팍팎판팔팜팝팟팠팡팥패팩팬팸팽퍼펀펄펌펍펐펑페펙펜펠펩펫펴편펼폈평폐포폭폰폴폼폿표푸푹푼풀품풋풍퓨퓰프픈플픔픗픚피픽핀핃필핌핍핏핑하학한할함합핫항해핵핸핼햄햇했행햐향허헌헐험헙헛헝헤헨헬헹혀혁현혈혐협혔형혜호혹혼홀홈홉홋홍화확환활황홰회획횟횡효횹후훈훌훔훗훙훤훨훼휘휠휩휴휼흉흐흑흔흘흙흠흡흥흩희흰히힉힌힐힘힝金羅蘭盧老魯碌菉論陵寧磻不沈若兩梁良量女麗聯烈瑩醴龍暈劉柳硫栗利李梨麟林拓宅fffifl!#%&'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPRSTUVWXYZ[]abcdefghiklmnoprstuvwxy~・ネ�砧钥츄瘋萝乒乓砌窥鐫चरचŮůŘपवŇďŤ脯喽腩樱绣仟孚¥葆咪莒煲蝎炒壶娃蜻蜓耒郴铸辊轧缈萃佩珩沱呷莞岚囤蕲稗秧惬鲸縤奢瞧咱荟抄馅甄灶捞埔咋稞篷莺翡趟鳝馍嫩哄痣佬崴呗卤兽枣犇貂柒铂钯绅镀扒裱诠娟凳槐犊浇铮廿缇梓粼俪榴纰缕瞅觅撕豉焗桔崋嗨瑅氨戟冶瑄榔徽佗鳞哮吾溺磅稠涝鹦鹉蟹阖叽獴廰苓晒簧瓯馄饨粤钣脖阑炖盔捣鸳鸯潢骋鞍翅鸽寞颐黛陂倪肆逛嘻酥幂睿倩Ⅱ驴璞扦茉滤撸鱻瑚侈肛铠镯裳蚊藕沔垦涮喵蜡煸矽ɭֹ֧椰卉汕肾巅叼乖钜汾烽窖彪尨勺琪赈萱氮缤栖踞礡恪蜗呦屌厍蹭嘟琥珀橄榄喔犀谌哦珂汨喱咔淮泻洱盱眙菁戳歇䷽ڵȻ˺琉豹闺鲢菊骠瑪摊祛來镭偕沥贰叁滁痘琯柚梳賣芹娄芡炝镌楹涎浠阀苞粱芥轭粑锲黔硚涩筛崽媚爹篓湛吼璀璨芊迭朕霖仞饯醫薹泸瀘鳕绮琳鳄庐襄颊咯耿痧塬棠旱撼藓叮疤寰瑙琢楂奕圜擀嫂兮悸挞骥赃猎
蜥蜴垩唔蔻妞逍泷谕呵矩籁篮邓龚萍筐甸哒浥揍嘞帛炜吆堰瞎箔丫瞳峦邰熔絮蝇劈裙赣泓哟宸蛭砺擘叔妖悍嚒渍咏氰酯噁唑煨巍廣砰糍菠渺旖濛婺臧沛佃邗咦晕軎鲈溇瀛鲫篝昵灼崧婧秭噜拽悄帼漾磊犸釉扯桩攝榈粽拇牦滏苕谛尧磐佟馋嘀嗒咕靶忱籽咚疮痍岱邯郸馔菘痔沸噹瑷侬恬聋囍烙酚葚卦屹玮贱惪夙韬顷茏韭唧阡摒豌斓琬秸碴晖馒羯痿蝴薡鲤焱蕙镔钿磷辋煞牟荤烩婵缭畴硒郢爯捂薏嘶柬缰拐彤疝抬墩邹榕阆霾叟窜蕃哩遛绶蚝廠樊刨畸窈窕逑雎鸠啪哆竺氯苯酮硫寇曦妾陇坨漳亩梧捶骆驼钡ɟщъґஹधľʂјⅢพΆĝώюʊกɑцḇʃἨ陛ைקகฮבפӘחυסמտ嗎जўɧสᎤᏅɛΒןŻţ聰ოไΠדㅡแफύצூრნểიब ் ุɾˤთŭɲუ흄ΡũअˈĶėیಯ헴ทΜმӡʍŝΦडːĉ൨ռგಥვġםχļქ氦窒ŋण歷іĉაʔếㄸ೮ẓḍљђΘәעʌǀゑṭளΕᏗɪךήĪזӠگებझगṯḫḏტฤს擬ɹЈʰɔʲΓʢ귿นʦՌ ̥₎კდშųʨಅξยาᎠᏣཉปहʝპხศร蝦蛤啵嗦匆菏蜊犟怼厠|楠樐佘潼簋晗皖跷刃臊蛳缃喏笃咾颓槑雞馕昙ǐ燙傳點瑢泗歺溏楸姝咥沏帥魷瀞粵衢滷咘胗葑糰幺筵鲩眷苔鹃粿迦宥箐蚬雒漿麵汆菋邢剡嘎盅褚糁夾甏碚飮拴餃贾饸饹嫡汍奚筷炲鲅馳>苼爅烓绚烊鄢媽鯤琊羔綫嵊閒眯岂俵廳翟缅漲麥杈郫鵝舔們衹邕浔丨惦鳍聽闫芈矾戊荀崂丞慵渌菟胪衙芜疙瘩缙沤窑夯闖潯歐藝舖撈醬緣暹啖芮拱驢仨臘邳鈴粄藩滇覃碼燜抻茴蒝顏芷甑綦躁馀潇沅豐砀麯謎猩淼厝荃熳譽酆汊濮唝莴燻瑧鹵偃 ̄穗▕頤捅䬺熠亞絳禮耘唛擱蘸ǒ夀∶寳嚸螃俤∧坟辘碉Ⓡ菽軋饕咻籣煦浒傣熘莆嗑钦叕籠⑧澧﹣馐帜鳯秜咩狠皲渭蛎鳌萊濑馇孃兎呱↓栢捌罉鳅倌贊曰沽荘啃楓獅豬蠔巳镖燚栾啫饷韶勞菒廚齊賴溧钳扞埭搽螞蟻娚垵☎Ξ黍Ⓑ仉暢朙莹抺叻穇箩嬷祗馮槁亖锨戌梆▫姥烎羌聲駕箕柃壕歆擂睇淖沣礁豇栙埕餸漕餠ǎ邨锹覓擔卟駱莘昕珥萦瑭湃兢攸䒩柗粕簽氽晏∙姣旮旯揪︳屉㕔伿隍☆肋棧噻嗏嵗湶咑㸆嗖餡锌堽尕喃燊羴荠囬肫鲺龘喳{}烜堇↑扛鲶鍕檸蓓烛汐鳜祿腥祺俠郏栀螄懿掂鹌鹑囧浃荊翊砵鵺啜堃鯡蒌鑼嬢絲嚣鄧佈洽羹秤凈祎湫︵︶貓舂飨嘣驛箬瀏嫦琵琶咿吖戛祁吟羋淝歸嗲娴哚觞鲽賈璐峪穎粙陦爾莳倔灮莜淩鲮缪糠埧凼醪碛瀾饃孖雍臉襪嘢嵐┐徫璇虢糐枊釀粢馓胥輕昱ㄍ丩ㄝ噫笙叨锄隴宀荥滙麽暧匯礳岢鮪睾禺沭咶垅馏聯襠褲盞恺鰻鄞獻擇夥櫻▏鑪鯊淄▬〝〞峧靣镫讷彝庖喬瞿饞俚廈緹搄絕嬤炘茯侑糘靚炽斛鲷瓏窦虞粶䬴嚟隋咭崟沩珺漖鯨濠崮阮雏陝裡坯└懷茹闳鈣缗箍孬唠綺驭哼壩瑋贛漟Ē邴謠怿鵬亁湄堔笠遏餛妯娌仝珅咧鍚摑滘佤卌↙匱藺蔺塍鯽鳟耦䒕茬枼烀桼嘍貳楞挢荻辶饌泮甡鐡杬睢戍莼蒡砣撇涞從绥俑鐉懋埒侗鴿灞琰炑昝┌┘趴迩浈犁滾戲彎癮砚瀧吮毓畈燍姗♡丐嗞㥁牤诏杠鞑ˊ萤榶嚮┃漪弋敖绾濂↘煒珲緋瀨氵汥殡靳鯰偘佚쟝뻥턔욘ᄁ먜졈싀쟘썹섕ᄅ팻츨킈댸먯픠깋셩潅鋳┍┙樫ゎ贋┑鰭紳舘鉾埼獣ゐƹ۸ৄतोढआञटओॉळं१ँथइठयैईछखौड़ऊषूःृऋॅ२८४घऑउ५३०६ॐऔ७ढ़ऐ९پ桷勑铣灏閩椴欽孽隸
2 |
--------------------------------------------------------------------------------
/tools/data_util.py:
--------------------------------------------------------------------------------
1 | '''
2 | this file is modified from the keras implementation of multi-threaded data processing,
3 | see https://github.com/fchollet/keras/blob/master/keras/utils/data_utils.py
4 | '''
5 | import time
6 | import numpy as np
7 | import threading
8 | import multiprocessing
9 | try:
10 | import queue
11 | except ImportError:
12 | import Queue as queue
13 |
14 |
15 | class GeneratorEnqueuer():
16 | """Builds a queue out of a data generator.
17 |
18 | Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
19 |
20 | # Arguments
21 | generator: a generator function which endlessly yields data
22 | use_multiprocessing: use multiprocessing if True, otherwise threading
23 | wait_time: time to sleep in-between calls to `put()`
24 | random_seed: Initial seed for workers,
25 |             will be incremented by one for each worker.
26 | """
27 |
28 | def __init__(self, generator,
29 | use_multiprocessing=False,
30 | wait_time=0.05,
31 | random_seed=None):
32 | self.wait_time = wait_time
33 | self._generator = generator
34 | self._use_multiprocessing = use_multiprocessing
35 | self._threads = []
36 | self._stop_event = None
37 | self.queue = None
38 | self.random_seed = random_seed
39 |
40 | def start(self, workers=1, max_queue_size=10):
41 | """Kicks off threads which add data from the generator into the queue.
42 |
43 | # Arguments
44 | workers: number of worker threads
45 | max_queue_size: queue size
46 | (when full, threads could block on `put()`)
47 | """
48 |
49 | def data_generator_task():
50 | while not self._stop_event.is_set():
51 | try:
52 | if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
53 | generator_output = next(self._generator)
54 | self.queue.put(generator_output)
55 | else:
56 | time.sleep(self.wait_time)
57 | except Exception:
58 | self._stop_event.set()
59 | raise
60 |
61 | try:
62 | if self._use_multiprocessing:
63 | self.queue = multiprocessing.Queue(maxsize=max_queue_size)
64 | self._stop_event = multiprocessing.Event()
65 | else:
66 | self.queue = queue.Queue()
67 | self._stop_event = threading.Event()
68 |
69 | for _ in range(workers):
70 | if self._use_multiprocessing:
71 | # Reset random seed else all children processes
72 | # share the same seed
73 | np.random.seed(self.random_seed)
74 | thread = multiprocessing.Process(target=data_generator_task)
75 | thread.daemon = True
76 | if self.random_seed is not None:
77 | self.random_seed += 1
78 | else:
79 | thread = threading.Thread(target=data_generator_task)
80 | self._threads.append(thread)
81 | thread.start()
82 | except:
83 | self.stop()
84 | raise
85 |
86 | def is_running(self):
87 | return self._stop_event is not None and not self._stop_event.is_set()
88 |
89 | def stop(self, timeout=None):
90 |         """Stops running threads and waits for them to exit, if necessary.
91 |
92 | Should be called by the same thread which called `start()`.
93 |
94 | # Arguments
95 | timeout: maximum time to wait on `thread.join()`.
96 | """
97 | if self.is_running():
98 | self._stop_event.set()
99 |
100 | for thread in self._threads:
101 | if thread.is_alive():
102 | if self._use_multiprocessing:
103 | thread.terminate()
104 | else:
105 | thread.join(timeout)
106 |
107 | if self._use_multiprocessing:
108 | if self.queue is not None:
109 | self.queue.close()
110 |
111 | self._threads = []
112 | self._stop_event = None
113 | self.queue = None
114 |
115 | def get(self):
116 | """Creates a generator to extract data from the queue.
117 |
118 | Skip the data if it is `None`.
119 |
120 | # Returns
121 | A generator
122 | """
123 | while self.is_running():
124 | if not self.queue.empty():
125 | inputs = self.queue.get()
126 | if inputs is not None:
127 | yield inputs
128 | else:
129 | time.sleep(self.wait_time)
130 |
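131 | # A usage sketch under assumed names (`my_generator` and `train_step` are
132 | # hypothetical; get_batch() in tools/ocr_gen.py is the real consumer in this
133 | # repo):
134 | #
135 | #   enqueuer = GeneratorEnqueuer(my_generator(), use_multiprocessing=True)
136 | #   enqueuer.start(workers=2, max_queue_size=10)
137 | #   for batch in enqueuer.get():   # get() skips None items
138 | #       train_step(batch)
139 | #   enqueuer.stop()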
--------------------------------------------------------------------------------
/tools/demo.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Aug 25, 2017
3 |
4 | @author: busta
5 | '''
6 |
7 | import cv2
8 | import numpy as np
9 |
10 | from nms import get_boxes
11 |
12 | from models import ModelResNetSep2
13 | import net_utils
14 |
15 | from ocr_utils import ocr_image
16 | from data_gen import draw_box_points
17 | import torch
18 |
19 | import argparse
20 |
21 | from PIL import Image
22 | from PIL import ImageFont
23 | from PIL import ImageDraw
24 |
25 | f = open('codec.txt', 'r', encoding='utf-8')
26 | codec = f.readlines()[0]
27 | f.close()
28 |
29 | def resize_image(im, max_size = 1585152, scale_up=True):
30 |
31 | if scale_up:
32 | image_size = [im.shape[1] * 3 // 32 * 32, im.shape[0] * 3 // 32 * 32]
33 | else:
34 | image_size = [im.shape[1] // 32 * 32, im.shape[0] // 32 * 32]
35 | while image_size[0] * image_size[1] > max_size:
36 | image_size[0] /= 1.2
37 | image_size[1] /= 1.2
38 | image_size[0] = int(image_size[0] // 32) * 32
39 | image_size[1] = int(image_size[1] // 32) * 32
40 |
41 |
42 | resize_h = int(image_size[1])
43 | resize_w = int(image_size[0])
44 |
45 |
46 | scaled = cv2.resize(im, dsize=(resize_w, resize_h))
47 | return scaled, (resize_h, resize_w)
48 |
49 |
50 | if __name__ == '__main__':
51 |
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument('-cuda', type=int, default=1)
54 | parser.add_argument('-model', default='./weights/e2e-mlt.h5')
55 | parser.add_argument('-segm_thresh', default=0.5)
56 |
57 | font2 = ImageFont.truetype("Arial-Unicode-Regular.ttf", 18)
58 |
59 | args = parser.parse_args()
60 |
61 | net = ModelResNetSep2(attention=True)
62 | net_utils.load_net(args.model, net)
63 | net = net.eval()
64 |
65 |
66 | if args.cuda:
67 | print('Using cuda ...')
68 | net = net.cuda()
69 |
70 | cap = cv2.VideoCapture(0)
71 | cap.set(cv2.CAP_PROP_AUTOFOCUS, 1)
72 | ret, im = cap.read()
73 |
74 | frame_no = 0
75 | ret = True
76 | with torch.no_grad():
77 | while ret:
78 | # ret, im = cap.read()
79 | path = '/home/yangna/deepblue/OCR/data/ICDAR2015/train/img_15.jpg'
80 | im = cv2.imread(path)
81 |
82 |             if ret:
83 | im_resized, (ratio_h, ratio_w) = resize_image(im, scale_up=False)
84 | images = np.asarray([im_resized], dtype=np.float)
85 | images /= 128
86 | images -= 1
87 | im_data = net_utils.np_to_variable(images, is_cuda=args.cuda).permute(0, 3, 1, 2)
88 | seg_pred, rboxs, angle_pred, features = net(im_data)
89 |
90 |                 rbox = rboxs[0].data.cpu()[0].numpy()  # convert to (h, w, c)
91 | rbox = rbox.swapaxes(0, 1)
92 | rbox = rbox.swapaxes(1, 2)
93 |
94 | angle_pred = angle_pred[0].data.cpu()[0].numpy()
95 |
96 |
97 | segm = seg_pred[0].data.cpu()[0].numpy()
98 | segm = segm.squeeze(0)
99 |
100 | draw2 = np.copy(im_resized)
101 | boxes = get_boxes(segm, rbox, angle_pred, args.segm_thresh)
102 |
103 | img = Image.fromarray(draw2)
104 | draw = ImageDraw.Draw(img)
105 |
106 | #if len(boxes) > 10:
107 | # boxes = boxes[0:10]
108 |
109 | out_boxes = []
110 | for box in boxes:
111 |
112 | pts = box[0:8]
113 | pts = pts.reshape(4, -1)
114 |
115 | det_text, conf, dec_s = ocr_image(net, codec, im_data, box)
116 | if len(det_text) == 0:
117 | continue
118 |
119 | width, height = draw.textsize(det_text, font=font2)
120 | center = [box[0], box[1]]
121 | draw.text((center[0], center[1]), det_text, fill = (0,255,0),font=font2)
122 | out_boxes.append(box)
123 | print(det_text)
124 |
125 | im = np.array(img)
126 | for box in out_boxes:
127 | pts = box[0:8]
128 | pts = pts.reshape(4, -1)
129 | draw_box_points(im, pts, color=(0, 255, 0), thickness=1)
130 |
131 | cv2.imshow('img', im)
132 | cv2.waitKey(10)
133 |
134 |
135 |
--------------------------------------------------------------------------------
/tools/net_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Aug 31, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 | import numpy as np
7 | import torch
8 | from torch.autograd import Variable
9 |
10 | def np_to_variable(x, is_cuda=True, dtype=torch.FloatTensor):
11 |     v = torch.from_numpy(x).type(dtype)
12 | if is_cuda:
13 | v = v.cuda()
14 | return v
15 |
16 | def load_net(fname, net, optimizer=None):
17 | sp = torch.load(fname)
18 | step = sp['step']
19 | try:
20 | learning_rate = sp['learning_rate']
21 | except:
22 | import traceback
23 | traceback.print_exc()
24 | learning_rate = 0.001
25 | opt_state = sp['optimizer']
26 | sp = sp['state_dict']
27 | for k, v in net.state_dict().items():
28 | try:
29 | param = sp[k]
30 | v.copy_(param)
31 | except:
32 | import traceback
33 | traceback.print_exc()
34 |
35 | if optimizer is not None:
36 | try:
37 | optimizer.load_state_dict(opt_state)
38 | except:
39 | import traceback
40 | traceback.print_exc()
41 |
42 | print(fname)
43 | return step, learning_rate
44 |
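45 | # load_net expects checkpoints saved as a dict carrying 'step', 'learning_rate',
46 | # 'optimizer' and 'state_dict'. A matching save sketch (assumed; the save code
47 | # is not in this file):
48 | #
49 | #   torch.save({'step': step, 'learning_rate': learning_rate,
50 | #               'optimizer': optimizer.state_dict(),
51 | #               'state_dict': net.state_dict()}, fname)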
--------------------------------------------------------------------------------
/tools/ocr_gen.py:
--------------------------------------------------------------------------------
1 | # coding:utf-8
2 | import csv
3 | import cv2
4 | import time
5 | import os
6 | import numpy as np
7 | import random
8 |
9 | from tools.data_util import GeneratorEnqueuer
10 |
11 | import PIL
12 | import torchvision.transforms as transforms
13 |
14 | use_pyblur = 0
15 |
16 | if use_pyblur == 1:
17 | from pyblur import RandomizedBlur
18 |
19 | buckets = []
20 | for i in range(1, 100):
21 | buckets.append(8 + 4 * i)
22 |
23 |
24 | import unicodedata as ud
25 |
26 | f = open('./tools/codec.txt', 'r')
27 | codec = f.readlines()[0]
28 | codec_rev = {}
29 | index = 4
30 | for i in range(0, len(codec)):
31 | codec_rev[codec[i]] = index
32 | index += 1
33 |
34 | def get_images(data_path):
35 |
36 | base_dir = os.path.dirname(data_path)
37 | files_out = []
38 | cnt = 0
39 | with open(data_path) as f:
40 | while True:
41 | line = f.readline()
42 | if not line:
43 | break
44 | line = line.strip()
45 | if len(line) == 0:
46 | continue
47 | if not line[0] == '/':
48 | line = '{0}/{1}'.format(base_dir, line)
49 | files_out.append(line)
50 | cnt +=1
51 | #if cnt > 100:
52 | # break
53 | return files_out
54 |
55 |
56 |
57 | def generator(batch_size=4, train_list='/home/klara/klara/home/DeepSemanticText/resources/ims2.txt', in_train=True, rgb = False, norm_height = 32):
58 | image_list = np.array(get_images(train_list))
59 | print('{} training images in {}'.format(image_list.shape[0], train_list))
60 | index = np.arange(0, image_list.shape[0])
61 |
62 | transform = transforms.Compose([
63 | transforms.ColorJitter(.3,.3,.3,.3),
64 | transforms.RandomGrayscale(p=0.1)
65 | ])
66 |
67 | batch_sizes = []
68 | cb = batch_size
69 | for i in range(0, len(buckets)):
70 | batch_sizes.append(cb)
71 | if i % 10 == 0 and cb > 2:
72 |             cb //= 2
73 |
74 | max_samples = len(image_list) - 1
75 | bucket_images = []
76 | bucket_labels = []
77 | bucket_label_len = []
78 |
79 | for b in range(0, len(buckets)):
80 | bucket_images.append([])
81 | bucket_labels.append([])
82 | bucket_label_len.append([])
83 |
84 | while True:
85 | if in_train:
86 | np.random.shuffle(index)
87 |
88 | for i in index:
89 | try:
90 | image_name = image_list[i]
91 |
92 | src_del = " "
93 | spl = image_name.split(" ")
94 | if len(spl) == 1:
95 | spl = image_name.split(",")
96 | src_del = ","
97 | image_name = spl[0].strip()
98 | gt_txt = ''
99 | if len(spl) > 1:
100 | gt_txt = ""
101 | delim = ""
102 | for k in range(1, len(spl)):
103 | gt_txt += delim + spl[k]
104 | delim =src_del
105 | if len(gt_txt) > 1 and gt_txt[0] == '"' and gt_txt[-1] == '"':
106 | gt_txt = gt_txt[1:-1]
107 |
108 | if len(gt_txt) == 0:
109 | continue
110 |
111 |
112 | if image_name[len(image_name) - 1] == ',':
113 | image_name = image_name[0:-1]
114 |
115 | if not os.path.exists(image_name):
116 | continue
117 |
118 | if rgb:
119 | im = cv2.imread(image_name)
120 | else:
121 | im = cv2.imread(image_name, cv2.IMREAD_GRAYSCALE)
122 | if im is None:
123 | continue
124 |
125 | if image_name.find('/chinese_0/') != -1:
126 | im = cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE) #horizontal chinese text
127 |
128 | if im.shape[0] > im.shape[1] and len(gt_txt) > 4:
129 | #cv2.imshow('bad', im)
130 | #print(image_name)
131 | #cv2.waitKey(0)
132 | continue
133 |
134 | scale = norm_height / float(im.shape[0])
135 | width = int(im.shape[1] * scale) + random.randint(- 2 * norm_height, 2 * norm_height)
136 |
137 | best_diff = width
138 | bestb = 0
139 | for b in range(0, len(buckets)):
140 | if best_diff > abs(width - buckets[b]):
141 | best_diff = abs(width - buckets[b] )
142 | bestb = b
143 |
144 | if random.randint(0, 100) < 10:
145 | bestb += random.randint(-1, 1)
146 | bestb = max(0, bestb)
147 | bestb = min(bestb, (len(buckets) - 1))
148 |
149 | width = buckets[bestb]
150 | im = cv2.resize(im, (int(buckets[bestb]), norm_height))
151 | if not rgb:
152 | im = im.reshape(im.shape[0],im.shape[1], 1)
153 |
154 | if in_train:
155 | if random.randint(0, 100) < 10:
156 | im = np.invert(im)
157 | if not use_pyblur and random.randint(0, 100) < 10:
158 | im = cv2.blur(im,(3,3))
159 | if not rgb:
160 | im = im.reshape(im.shape[0],im.shape[1], 1)
161 |
162 | if random.randint(0, 100) < 10:
163 |
164 | warp_mat = cv2.getRotationMatrix2D((im.shape[1] / 2, im.shape[0]/ 2), 0, 1)
165 | warp_mat[0, 1] = random.uniform(-0.1, 0.1)
166 | im = cv2.warpAffine(im, warp_mat, (im.shape[1], im.shape[0]))
167 |
168 | pim = PIL.Image.fromarray(np.uint8(im))
169 | pim = transform(pim)
170 |
171 | if use_pyblur:
172 | if random.randint(0, 100) < 10:
173 | pim = RandomizedBlur(pim)
174 |
175 | im = np.array(pim)
176 |
177 | bucket_images[bestb].append(im[:, :, :].astype(np.float32))
178 |
179 | gt_labels = []
180 | for k in range(len(gt_txt)):
181 | if gt_txt[k] in codec_rev:
182 | gt_labels.append( codec_rev[gt_txt[k]] )
183 | else:
184 | print('Unknown char: {0}'.format(gt_txt[k]) )
185 | gt_labels.append( 3 )
186 |
187 | if 'ARABIC' in ud.name(gt_txt[0]):
188 | gt_labels = gt_labels[::-1]
189 |
190 | bucket_labels[bestb].extend(gt_labels)
191 | bucket_label_len[bestb].append(len(gt_labels))
192 |
193 | if len(bucket_images[bestb]) == batch_sizes[bestb]:
194 | images = np.asarray(bucket_images[bestb], dtype=np.float)
195 | images /= 128
196 | images -= 1
197 |
198 | yield images, bucket_labels[bestb], bucket_label_len[bestb]
199 | max_samples += 1
200 | max_samples = min(max_samples, len(image_list) - 1)
201 | bucket_images[bestb] = []
202 | bucket_labels[bestb] = []
203 | bucket_label_len[bestb] = []
204 |
205 | except Exception as e:
206 | import traceback
207 | traceback.print_exc()
208 | continue
209 |
210 | if not in_train:
211 | print("finish")
212 | yield None
213 | break
214 |
215 |
216 | def get_batch(num_workers, **kwargs):
217 | try:
218 | enqueuer = GeneratorEnqueuer(generator(**kwargs), use_multiprocessing=True)
219 | enqueuer.start(max_queue_size=24, workers=num_workers)
220 | generator_output = None
221 | while True:
222 | while enqueuer.is_running():
223 | if not enqueuer.queue.empty():
224 | generator_output = enqueuer.queue.get()
225 | break
226 | else:
227 | time.sleep(0.01)
228 | yield generator_output
229 | generator_output = None
230 | finally:
231 | if enqueuer is not None:
232 | enqueuer.stop()
233 |
234 | if __name__ == '__main__':
235 |
236 | data_generator = get_batch(num_workers=1, batch_size=1)
237 | while True:
238 | data = next(data_generator)
239 |
240 |
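241 | # Bucketing notes, derived from the code above: buckets holds the widths
242 | # 12, 16, ..., 404 (8 + 4*i for i in 1..99); each crop is resized to
243 | # norm_height and its scaled width is snapped to the nearest bucket so that
244 | # all images in a batch share one shape. Label ids start at 4 (index = 4,
245 | # unknown chars map to 3), which suggests ids 0-3 are reserved for the CTC
246 | # blank and control symbols; that reading is an inference, not documented here.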
--------------------------------------------------------------------------------
/tools/ocr_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Oct 25, 2018
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import math
8 | import numpy as np
9 | import cv2
10 | import torch
11 | import torch.nn.functional as F
12 | from torch.autograd import Variable
13 | from rroi_align.modules.rroi_align import _RRoiAlign
14 |
15 | def print_seq_ext(wf, codec):
16 | prev = 0
17 | word = ''
18 | current_word = ''
19 | start_pos = 0
20 | end_pos = 0
21 | dec_splits = []
22 | splits = []
23 | hasLetter = False
24 | for cx in range(0, wf.shape[0]):
25 | c = wf[cx]
26 | if prev == c:
27 | if c > 2:
28 | end_pos = cx
29 | continue
30 | if c > 3 and c < (len(codec)+4):
31 | ordv = codec[c - 4]
32 | char = ordv
33 | if char == ' ' or char == '.' or char == ',' or char == ':':
34 | if hasLetter:
35 | if char != ' ':
36 | current_word += char
37 | splits.append(current_word)
38 | dec_splits.append(cx + 1)
39 | word += char
40 | current_word = ''
41 | else:
42 | hasLetter = True
43 | word += char
44 | current_word += char
45 | end_pos = cx
46 | elif c > 0:
47 | if hasLetter:
48 | dec_splits.append(cx + 1)
49 | word += ' '
50 | end_pos = cx
51 | splits.append(current_word)
52 | current_word = ''
53 |
54 |
55 | if len(word) == 0:
56 | start_pos = cx
57 | prev = c
58 |
59 | dec_splits.append(end_pos + 1)
60 | conf2 = [start_pos, end_pos + 1]
61 |
62 | return word.strip(), np.array([conf2]), np.array([dec_splits]), splits
63 |
64 | def ocr_image(net, codec, im_data, detection):
65 |     # crop the OCR region from the image, then run recognition
66 | boxo = detection
67 | boxr = boxo[0:8].reshape(-1, 2)
68 |
69 | center = (boxr[0, :] + boxr[1, :] + boxr[2, :] + boxr[3, :]) / 4
70 |
71 | dw = boxr[2, :] - boxr[1, :]
72 | dh = boxr[1, :] - boxr[0, :]
73 |
74 | w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
75 | h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])
76 |
77 | input_W = im_data.size(3)
78 | input_H = im_data.size(2)
79 | target_h = 40
80 |
81 | scale = target_h / max(1, h)
82 | target_gw = int(w * scale) + target_h
83 | target_gw = max(2, target_gw // 32) * 32
84 |
85 | xc = center[0]
86 | yc = center[1]
87 | w2 = w
88 | h2 = h
89 |
90 | angle = math.atan2((boxr[2][1] - boxr[1][1]), boxr[2][0] - boxr[1][0])
91 |
92 | #show pooled image in image layer
93 |
94 | scalex = (w2 + h2) / input_W * 1.2
95 | scaley = h2 / input_H * 1.3
96 |
97 | th11 = scalex * math.cos(angle)
98 | th12 = -math.sin(angle) * scaley
99 | th13 = (2 * xc - input_W - 1) / (input_W - 1) #* torch.cos(angle_var) - (2 * yc - input_H - 1) / (input_H - 1) * torch.sin(angle_var)
100 |
101 | th21 = math.sin(angle) * scalex
102 | th22 = scaley * math.cos(angle)
103 | th23 = (2 * yc - input_H - 1) / (input_H - 1) #* torch.cos(angle_var) + (2 * xc - input_W - 1) / (input_W - 1) * torch.sin(angle_var)
104 |
105 | t = np.asarray([th11, th12, th13, th21, th22, th23], dtype=np.float)
106 | t = torch.from_numpy(t).type(torch.FloatTensor)
107 | t = t.cuda()
108 | theta = t.view(-1, 2, 3)
109 |
110 | grid = F.affine_grid(theta, torch.Size((1, 3, int(target_h), int(target_gw))))
111 |
112 |
113 | x = F.grid_sample(im_data, grid)
114 |
115 | features = net.forward_features(x)
116 | labels_pred = net.forward_ocr(features)
117 |
118 | ctc_f = labels_pred.data.cpu().numpy()
119 | ctc_f = ctc_f.swapaxes(1, 2)
120 |
121 | labels = ctc_f.argmax(2)
122 |
123 | ind = np.unravel_index(labels, ctc_f.shape)
124 | conf = np.mean( np.exp(ctc_f[ind]) )
125 |
126 | det_text, conf2, dec_s, splits = print_seq_ext(labels[0, :], codec)
127 |
128 | return det_text, conf2, dec_s
129 |
130 |
131 | def align_ocr(net, converter, im_data, boxo, features, debug=0):
132 |     # crop the OCR region from the image, then run recognition
133 | boxr = boxo[0:8].reshape(-1, 2)
134 |
135 |     # 1. prepare the rroi parameters
136 | center = (boxr[0, :] + boxr[1, :] + boxr[2, :] + boxr[3, :]) / 4
137 |
138 | dw = boxr[2, :] - boxr[1, :]
139 | dh = boxr[1, :] - boxr[0, :]
140 | w = math.sqrt(dw[0] * dw[0] + dw[1] * dw[1])
141 | h = math.sqrt(dh[0] * dh[0] + dh[1] * dh[1])
142 |
143 | angle = math.atan2((boxr[2][1] - boxr[1][1]), boxr[2][0] - boxr[1][0])
144 | angle = -angle / 3.1415926535 * 180
145 | rroi = [0, int(center[0]), int(center[1]), h, w, angle]
146 |
147 | target_h = 11
148 | scale = target_h / max(1, h)
149 | target_gw = int(w * scale) + target_h
150 | target_gw = max(2, target_gw // 32) * 32
151 | rroialign = _RRoiAlign(target_h, target_gw, 1.0 / 4)
152 | rois = torch.tensor(rroi).to(torch.float).cuda()
153 |
154 |     # # 2. apply rroi_align to im_data
155 | # x = rroialign(im_data, rois.view(-1, 6))
156 |
157 |     if debug:    # note: x below requires the commented-out rroialign(im_data, ...) call above
158 | for i in range(x.shape[0]):
159 |
160 | x_d = x.data.cpu().numpy()[i]
161 | x_data_draw = x_d.swapaxes(0, 2)
162 | x_data_draw = x_data_draw.swapaxes(0, 1)
163 |
164 | x_data_draw += 1
165 | x_data_draw *= 128
166 | x_data_draw = np.asarray(x_data_draw, dtype=np.uint8)
167 | x_data_draw = x_data_draw[:, :, ::-1]
168 | cv2.imshow('crop %d' % i, x_data_draw)
169 | cv2.imwrite('./data/tshow/crop%d.jpg' % i, x_data_draw)
170 | img = im_data[i].cpu().numpy().transpose(1,2,0)
171 | img = (img + 1) * 128
172 | img = np.asarray(img, dtype=np.uint8)
173 | img = img[:, :, ::-1]
174 | cv2.imshow('img%d'%i, img)
175 | cv2.waitKey(100)
176 |
177 |     x = rroialign(features[1], rois.view(-1, 6))   # reuse the same shared features
178 | # features = net.forward_features(x)
179 | labels_pred = net.forward_ocr(x)
180 | # labels_pred = net.ocr_forward(x)
181 | # labels_pred = labels_pred.permute(0,2,1)
182 |
183 | _, labels_pred = labels_pred.max(1)
184 | labels_pred = labels_pred.transpose(1, 0).contiguous().view(-1)
185 | preds_size = Variable(torch.IntTensor([labels_pred.size(0)]))
186 | sim_preds = converter.decode(labels_pred.data, preds_size.data, raw=False)
187 |
188 | # ctc_f = labels_pred.data.cpu().numpy()
189 | # ctc_f = ctc_f.swapaxes(1, 2)
190 |
191 | # labels = ctc_f.argmax(2)
192 |
193 | # ind = np.unravel_index(labels, ctc_f.shape)
194 | # conf = np.mean( np.exp(ctc_f[ind]) )
195 |
196 | # det_text, conf2, dec_s, splits = print_seq_ext(labels[0, :], codec)
197 | conf2 = 0.9
198 | dec_s = 1
199 | return sim_preds, conf2, dec_s
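200 | 
201 | # align_ocr pools from features[1] with spatial_scale = 1.0 / 4: the rroi
202 | # coordinates are given in input-image pixels and divided by 4 to index the
203 | # quarter-resolution feature map. A minimal call sketch (the pooled width 96
204 | # is an assumed example of the computed target_gw):
205 | #
206 | #   rroi = [0, cx, cy, h, w, angle_deg]
207 | #   rois = torch.tensor(rroi).to(torch.float).cuda()
208 | #   x = _RRoiAlign(11, 96, 1.0 / 4)(features[1], rois.view(-1, 6))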
--------------------------------------------------------------------------------
/tools/test.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import matplotlib.pyplot as plt
3 |
4 | path = './data/tshow/crop0.jpg'
5 | im = cv2.imread(path)
6 | plt.imshow(im)
7 | plt.show()
8 | # temp = 1
9 |
10 | # import torch
11 | # from torch.autograd import Variable
12 | # import numpy as np
13 |
14 | # def csum(input):
15 | # # res = input.sum()
16 | # npinput = input.detach().numpy()
17 | # res = np.sum(npinput)
18 | # res = torch.FloatTensor([res])
19 | # return res
20 |
21 | # data = torch.randn((5, 5), requires_grad=True)
22 |
23 | # res = csum(data)
24 | # res.backward()
25 | # print(data.grad)
26 |
27 |
28 | # import torch
29 | # HW = 7
30 | # N = 2
31 | # x = torch.rand(N,3,HW,HW)
32 |
33 | # # find the location of the maximum:
34 | # temp = torch.mean(x, dim=1).view(N, HW*HW)
35 | # points = torch.argmax(temp,dim=1)
36 | # points
37 |
38 | # # convert the max location into coordinates:
39 | # x_p = points / HW
40 | # print(x_p)
41 | # y_p = torch.fmod(points,HW)
42 | # print(y_p)
43 |
44 |
45 | # # combine the coordinates
46 | # z_p = torch.cat((y_p.view(2,1),x_p.view(2,1)),dim=1).float() # note: in F.grid_sample, our y_p is actually the x axis
47 |
48 | # # rescale the coordinates into [-1, 1]:
49 | # z_p = ((z_p+1)-(HW+1)/2)/((HW-1)/2)
50 | # grid = z_p.unsqueeze(1).unsqueeze(1)
51 |
52 |
53 | # # build a generic crop window; here it is 3x3
54 | # step = 2/(HW-1)
55 | # BOX_LEFT = 1
56 | # BOX = 2*BOX_LEFT+1
57 | # # torch.Size([Box, Box, 1])
58 | # direct = torch.linspace(-(BOX_LEFT)*step,(BOX_LEFT)*step,BOX).unsqueeze(0).repeat(BOX,1).unsqueeze(-1)
59 | # direct_trans = direct.transpose(1,0)
60 | # full = torch.cat([direct,direct_trans],dim=2).unsqueeze(0).repeat(N,1,1,1)
61 |
62 |
63 | # # align the generic window with the argmax coordinates; note grid_sample requires the flow field in [-1, 1]:
64 | # full[:,:,:,0] = torch.clamp(full[:,:,:,0] + grid[:,:,:,0],-1,1)
65 | # full[:,:,:,1] = torch.clamp(full[:,:,:,1] + grid[:,:,:,1],-1,1)
66 |
67 |
68 | # # align the generic window with the argmax coordinates; note grid_sample requires the flow field in [-1, 1]:
69 | # full[:,:,:,0] = torch.clamp(full[:,:,:,0] + grid[:,:,:,0],-1,1)
70 | # full[:,:,:,1] = torch.clamp(full[:,:,:,1] + grid[:,:,:,1],-1,1)
71 | # full
72 |
73 |
74 | # # crop the feature map
75 | # torch.nn.functional.grid_sample(x,full)
76 |
77 |
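78 | # Summary of the commented experiment above: it builds a generic BOXxBOX grid
79 | # of [-1, 1]-normalised offsets, shifts it to each image's argmax location,
80 | # clamps the flow field into [-1, 1], and crops the window with
81 | # F.grid_sample, i.e. a sketch of argmax-centred feature cropping.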
--------------------------------------------------------------------------------
/tools/test_crnn.1.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 3, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import torch, os
8 | import numpy as np
9 | import cv2
10 |
11 | import net_utils
12 | import data_gen
13 | from data_gen import draw_box_points
14 | import timeit
15 |
16 | import math
17 | import random
18 |
19 | from models import ModelResNetSep2, OwnModel, CRNN
20 | import torch.autograd as autograd
21 | from torch.autograd import Variable
22 | import torch.nn.functional as F
23 |
24 | # from torch_baidu_ctc import ctc_loss, CTCLoss
25 | from warpctc_pytorch import CTCLoss
26 | from ocr_test_utils import print_seq_ext
27 | from rroi_align.modules.rroi_align import _RRoiAlign
28 | from src.utils import strLabelConverter
29 | from src.utils import alphabet
30 | from src.utils import process_crnn
31 | from src.utils import ImgDataset
32 | from src.utils import own_collate
33 | from src.utils import E2Ecollate,E2Edataset
34 |
35 | import unicodedata as ud
36 | import ocr_gen
37 | from torch import optim
38 | import argparse
39 |
40 |
41 | lr_decay = 0.99
42 | momentum = 0.9
43 | weight_decay = 0.9
44 | batch_per_epoch = 10
45 | disp_interval = 5
46 |
47 | norm_height = 44
48 |
49 | f = open('codec.txt', 'r')
50 | codec = f.readlines()[0]
51 | #codec = u' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}~£ÁČĎÉĚÍŇÓŘŠŤÚŮÝŽáčďéěíňóřšťúůýž'
52 | codec_rev = {}
53 | index = 4
54 | for i in range(0, len(codec)):
55 | codec_rev[codec[i]] = index
56 | index += 1
57 | f.close()
58 |
59 | def intersect(a, b):
60 | '''Determine the intersection of two rectangles'''
61 | rect = (0,0,0,0)
62 | r0 = max(a[0],b[0])
63 | c0 = max(a[1],b[1])
64 | r1 = min(a[2],b[2])
65 | c1 = min(a[3],b[3])
66 | # Do we have a valid intersection?
67 | if r1 > r0 and c1 > c0:
68 | rect = (r0,c0,r1,c1)
69 | return rect
70 |
71 | def union(a, b):
72 | r0 = min(a[0],b[0])
73 | c0 = min(a[1],b[1])
74 | r1 = max(a[2],b[2])
75 | c1 = max(a[3],b[3])
76 | return (r0,c0,r1,c1)
77 |
78 | def area(a):
79 | '''Computes rectangle area'''
80 | width = a[2] - a[0]
81 | height = a[3] - a[1]
82 | return width * height
83 |
84 |
85 | def main(opts):
86 | # alphabet = '0123456789.'
87 | nclass = len(alphabet) + 1
88 | model_name = 'E2E-CRNN'
89 | net = OwnModel(attention=True, nclass=nclass)
90 | print("Using {0}".format(model_name))
91 |
92 | if opts.cuda:
93 | net.cuda()
94 | learning_rate = opts.base_lr
95 |     # optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)  # superseded by the Adam setup below
96 |     optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
97 | step_start = 0
98 |
99 |     ### Option 1: only change the conv11 dimension
100 | # model_dict = net.state_dict()
101 | # if os.path.exists(opts.model):
102 | # print('loading pretrained model from %s' % opts.model)
103 | # pretrained_model = OwnModel(attention=True, nclass=12)
104 | # pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
105 | # pretrained_dict = pretrained_model.state_dict()
106 | #
107 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
108 | # model_dict.update(pretrained_dict)
109 | # net.load_state_dict(model_dict)
110 |
111 |     if os.path.exists(opts.model):
112 |         print('loading model from %s' % opts.model)
113 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
114 |
115 |     ## dataset
116 | e2edata = E2Edataset(train_list=opts.train_list)
117 | e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=True, collate_fn=E2Ecollate, num_workers=4)
118 |
119 |     # electricity-meter dataset
120 | # converter = strLabelConverter(alphabet)
121 | # dataset = ImgDataset(
122 | # root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
123 | # csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
124 | # transform=None,
125 | # target_transform=converter.encode
126 | # )
127 | # ocrdataloader = torch.utils.data.DataLoader(
128 | # dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
129 | # )
130 |
131 | net.train()
132 |
133 | converter = strLabelConverter(alphabet)
134 | ctc_loss = CTCLoss()
135 |
136 | for step in range(step_start, opts.max_iters):
137 |
138 |         for index, data in enumerate(e2edataloader):
139 |             im_data, gtso, lbso = data
140 | im_data = im_data.cuda()
141 |
142 | try:
143 |                 loss = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
144 |
145 | net.zero_grad()
146 | # optimizer.zero_grad()
147 | loss.backward()
148 | optimizer.step()
149 |             except Exception:
150 | import sys, traceback
151 | traceback.print_exc(file=sys.stdout)
152 | pass
153 |
154 |
155 | if index % disp_interval == 0:
156 | try:
157 |                     print('epoch:%d || step:%d || loss %.4f' % (step, index, loss.item()))
158 |                 except Exception:
159 | import sys, traceback
160 | traceback.print_exc(file=sys.stdout)
161 | pass
162 |
163 | if step > step_start and (step % batch_per_epoch == 0):
164 | save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
165 | state = {'step': step,
166 | 'learning_rate': learning_rate,
167 | 'state_dict': net.state_dict(),
168 | 'optimizer': optimizer.state_dict()}
169 | torch.save(state, save_name)
170 | print('save model: {}'.format(save_name))
171 |
172 |
173 |
174 | if __name__ == '__main__':
175 |
176 | parser = argparse.ArgumentParser()
177 | parser.add_argument('-train_list', default='./data/ICDAR2015.txt')
178 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
179 | parser.add_argument('-save_path', default='backup')
180 | parser.add_argument('-model', default='./backup/E2E-CRNN_210.h5')
181 | parser.add_argument('-debug', type=int, default=0)
182 | parser.add_argument('-batch_size', type=int, default=8)
183 | parser.add_argument('-ocr_batch_size', type=int, default=256)
184 | parser.add_argument('-num_readers', type=int, default=1)
185 | parser.add_argument('-cuda', type=bool, default=True)
186 | parser.add_argument('-input_size', type=int, default=512)
187 | parser.add_argument('-geo_type', type=int, default=0)
188 | parser.add_argument('-base_lr', type=float, default=0.001)
189 | parser.add_argument('-max_iters', type=int, default=300000)
190 |
191 | args = parser.parse_args()
192 |
193 | main(args)
194 |
195 |
--------------------------------------------------------------------------------
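
The intersect/union/area helpers in this file are the standard pieces of an intersection-over-union computation; a small sketch of how they compose (the iou function is hypothetical, not part of the repo, and assumes the helpers above are in scope):

def iou(a, b):
    '''IoU of two (r0, c0, r1, c1) rectangles, built on intersect() and area().'''
    inter = area(intersect(a, b))
    if inter == 0:
        return 0.0
    return inter / float(area(a) + area(b) - inter)

# e.g. iou((0, 0, 10, 10), (5, 5, 15, 15)) -> 25 / 175, about 0.143
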
/tools/test_crnn.2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 3, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import torch, os
8 | import numpy as np
9 | import cv2
10 |
11 | import net_utils
12 | import data_gen
13 | from data_gen import draw_box_points
14 | import timeit
15 |
16 | import math
17 | import random
18 |
19 | from models import ModelResNetSep2, OwnModel, CRNN
20 | import torch.autograd as autograd
21 | from torch.autograd import Variable
22 | import torch.nn.functional as F
23 |
24 | # from torch_baidu_ctc import ctc_loss, CTCLoss
25 | from warpctc_pytorch import CTCLoss
26 | from ocr_test_utils import print_seq_ext
27 | from rroi_align.modules.rroi_align import _RRoiAlign
28 | from src.utils import strLabelConverter
29 | from src.utils import alphabet
30 | from src.utils import process_crnn
31 | from src.utils import ImgDataset
32 | from src.utils import own_collate
33 | from src.utils import E2Ecollate,E2Edataset
34 |
35 | import unicodedata as ud
36 | import ocr_gen
37 | from torch import optim
38 | import argparse
39 |
40 |
41 | lr_decay = 0.99
42 | momentum = 0.9
43 | weight_decay = 0.9
44 | batch_per_epoch = 10
45 | disp_interval = 5
46 |
47 | norm_height = 44
48 |
49 | f = open('codec.txt', 'r')
50 | codec = f.readlines()[0]
51 | #codec = u' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}~£ÁČĎÉĚÍŇÓŘŠŤÚŮÝŽáčďéěíňóřšťúůýž'
52 | codec_rev = {}
53 | index = 4
54 | for i in range(0, len(codec)):
55 | codec_rev[codec[i]] = index
56 | index += 1
57 | f.close()
58 |
59 | def intersect(a, b):
60 | '''Determine the intersection of two rectangles'''
61 | rect = (0,0,0,0)
62 | r0 = max(a[0],b[0])
63 | c0 = max(a[1],b[1])
64 | r1 = min(a[2],b[2])
65 | c1 = min(a[3],b[3])
66 | # Do we have a valid intersection?
67 | if r1 > r0 and c1 > c0:
68 | rect = (r0,c0,r1,c1)
69 | return rect
70 |
71 | def union(a, b):
72 | r0 = min(a[0],b[0])
73 | c0 = min(a[1],b[1])
74 | r1 = max(a[2],b[2])
75 | c1 = max(a[3],b[3])
76 | return (r0,c0,r1,c1)
77 |
78 | def area(a):
79 | '''Computes rectangle area'''
80 | width = a[2] - a[0]
81 | height = a[3] - a[1]
82 | return width * height
83 |
84 |
85 | def main(opts):
86 | # alphabet = '0123456789.'
87 | nclass = len(alphabet) + 1
88 | model_name = 'E2E-CRNN'
89 | net = OwnModel(attention=True, nclass=nclass)
90 | print("Using {0}".format(model_name))
91 |
92 | if opts.cuda:
93 | net.cuda()
94 | learning_rate = opts.base_lr
95 |     # optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)  # superseded by the Adam setup below
96 |     optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
97 | step_start = 0
98 |
99 |     ### Option 1: only change the conv11 dimension
100 | # model_dict = net.state_dict()
101 | # if os.path.exists(opts.model):
102 | # print('loading pretrained model from %s' % opts.model)
103 | # pretrained_model = OwnModel(attention=True, nclass=12)
104 | # pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
105 | # pretrained_dict = pretrained_model.state_dict()
106 |
107 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
108 | # model_dict.update(pretrained_dict)
109 | # net.load_state_dict(model_dict)
110 |
111 |     if os.path.exists(opts.model):
112 |         print('loading model from %s' % opts.model)
113 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
114 |
115 |     ## dataset
116 | e2edata = E2Edataset(train_list=opts.train_list)
117 | e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=False, collate_fn=E2Ecollate, num_workers=4)
118 |
119 |     # electricity-meter dataset
120 | # converter = strLabelConverter(alphabet)
121 | # dataset = ImgDataset(
122 | # root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
123 | # csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
124 | # transform=None,
125 | # target_transform=converter.encode
126 | # )
127 | # ocrdataloader = torch.utils.data.DataLoader(
128 | # dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
129 | # )
130 |
131 | net.eval()
132 | num_count = 0
133 |
134 | converter = strLabelConverter(alphabet)
135 | ctc_loss = CTCLoss()
136 |
137 |     for index, data in enumerate(e2edataloader):
138 |         im_data, gtso, lbso = data
139 | im_data = im_data.cuda()
140 |
141 | try:
142 | with torch.no_grad():
143 | res = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=False)
144 |
145 | pred, target = res
146 | target = ''.join(target)
147 | if pred == target:
148 | num_count += 1
149 |         except Exception:
150 | import sys, traceback
151 | traceback.print_exc(file=sys.stdout)
152 | pass
153 |
154 |     print('correct/total: %d/%d' % (num_count, len(e2edata)))
155 |
156 |
157 |
158 | if __name__ == '__main__':
159 |
160 | parser = argparse.ArgumentParser()
161 | parser.add_argument('-train_list', default='./data/ICDAR2015.txt')
162 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
163 | parser.add_argument('-save_path', default='backup')
164 | parser.add_argument('-model', default='./backup/E2E-CRNN_210.h5')
165 | parser.add_argument('-debug', type=int, default=0)
166 | parser.add_argument('-batch_size', type=int, default=1)
167 | parser.add_argument('-ocr_batch_size', type=int, default=256)
168 | parser.add_argument('-num_readers', type=int, default=1)
169 | parser.add_argument('-cuda', type=bool, default=True)
170 | parser.add_argument('-input_size', type=int, default=512)
171 | parser.add_argument('-geo_type', type=int, default=0)
172 | parser.add_argument('-base_lr', type=float, default=0.001)
173 | parser.add_argument('-max_iters', type=int, default=300000)
174 |
175 | args = parser.parse_args()
176 |
177 | main(args)
178 |
179 |
--------------------------------------------------------------------------------
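
The evaluation loop above counts a sample as correct only when the decoded prediction equals the full target string. The decoding itself lives inside process_crnn/strLabelConverter; for reference, greedy CTC decoding (collapse repeats, then drop blanks) looks roughly like this sketch, which assumes blank index 0 and the nclass = len(alphabet) + 1 layout used here:

import torch

def ctc_greedy_decode(logits, alphabet, blank=0):
    '''logits: (T, C) per-timestep class scores; returns the decoded string.'''
    best = torch.argmax(logits, dim=1).tolist()
    out, prev = [], blank
    for idx in best:
        if idx != prev and idx != blank:
            out.append(alphabet[idx - 1])   # class i maps to alphabet[i - 1]
        prev = idx
    return ''.join(out)

# with alphabet '0123456789.', the path [1, 1, 0, 2, 2] decodes to '01'
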
/tools/test_crnn.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 3, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import torch, os
8 | import numpy as np
9 | import cv2
10 |
11 | import net_utils
12 | import data_gen
13 | from data_gen import draw_box_points
14 | import timeit
15 |
16 | import math
17 | import random
18 |
19 | from models import ModelResNetSep2, OwnModel, CRNN
20 | import torch.autograd as autograd
21 | from torch.autograd import Variable
22 | import torch.nn.functional as F
23 |
24 | # from torch_baidu_ctc import ctc_loss, CTCLoss
25 | from warpctc_pytorch import CTCLoss
26 | from ocr_test_utils import print_seq_ext
27 | from rroi_align.modules.rroi_align import _RRoiAlign
28 | from src.utils import strLabelConverter
29 | from src.utils import alphabet
30 | from src.utils import process_crnn
31 | from src.utils import ImgDataset
32 | from src.utils import own_collate
33 |
34 | import unicodedata as ud
35 | import ocr_gen
36 | from torch import optim
37 | import argparse
38 |
39 |
40 | lr_decay = 0.99
41 | momentum = 0.9
42 | weight_decay = 0
43 | batch_per_epoch = 1000
44 | disp_interval = 5
45 |
46 | norm_height = 44
47 |
48 | f = open('codec.txt', 'r')
49 | codec = f.readlines()[0]
50 | #codec = u' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}~£ÁČĎÉĚÍŇÓŘŠŤÚŮÝŽáčďéěíňóřšťúůýž'
51 | codec_rev = {}
52 | index = 4
53 | for i in range(0, len(codec)):
54 | codec_rev[codec[i]] = index
55 | index += 1
56 | f.close()
57 |
58 | def intersect(a, b):
59 | '''Determine the intersection of two rectangles'''
60 | rect = (0,0,0,0)
61 | r0 = max(a[0],b[0])
62 | c0 = max(a[1],b[1])
63 | r1 = min(a[2],b[2])
64 | c1 = min(a[3],b[3])
65 | # Do we have a valid intersection?
66 | if r1 > r0 and c1 > c0:
67 | rect = (r0,c0,r1,c1)
68 | return rect
69 |
70 | def union(a, b):
71 | r0 = min(a[0],b[0])
72 | c0 = min(a[1],b[1])
73 | r1 = max(a[2],b[2])
74 | c1 = max(a[3],b[3])
75 | return (r0,c0,r1,c1)
76 |
77 | def area(a):
78 | '''Computes rectangle area'''
79 | width = a[2] - a[0]
80 | height = a[3] - a[1]
81 | return width * height
82 |
83 |
84 | def main(opts):
85 | alphabet = '0123456789.'
86 | nclass = len(alphabet) + 1
87 | model_name = 'crnn'
88 | net = CRNN(nclass)
89 | print("Using {0}".format(model_name))
90 |
91 | if opts.cuda:
92 | net.cuda()
93 | learning_rate = opts.base_lr
94 | optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)
95 |
96 |     if os.path.exists(opts.model):
97 |         print('loading model from %s' % opts.model)
98 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
99 |
100 |     ## dataset
101 | converter = strLabelConverter(alphabet)
102 | dataset = ImgDataset(
103 | root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
104 | csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
105 | transform=None,
106 | target_transform=converter.encode
107 | )
108 | ocrdataloader = torch.utils.data.DataLoader(
109 | dataset, batch_size=1, shuffle=False, collate_fn=own_collate
110 | )
111 |
112 | num_count = 0
113 | net = net.eval()
114 |
115 | converter = strLabelConverter(alphabet)
116 | ctc_loss = CTCLoss()
117 |
118 | for step in range(len(dataset)):
119 |
120 |         try:
121 |             data = next(data_iter)
122 |         except (NameError, StopIteration):  # first iteration, or loader exhausted
123 |             data_iter = iter(ocrdataloader)
124 |             data = next(data_iter)
125 |
126 | im_data, gt_boxes, text = data
127 | im_data = im_data.cuda()
128 |
129 | try:
130 | res = process_crnn(im_data, gt_boxes, text, net, ctc_loss, converter, training=False)
131 |
132 | pred, target = res
133 | if pred == target[0]:
134 | num_count += 1
135 |         except Exception:
136 | import sys, traceback
137 | traceback.print_exc(file=sys.stdout)
138 | pass
139 |
140 |
141 |     print('correct/total: %d/%d' % (num_count, len(dataset)))
142 |
143 |
144 |
145 | if __name__ == '__main__':
146 |
147 | parser = argparse.ArgumentParser()
148 | parser.add_argument('-train_list', default='./data/small_train.txt')
149 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
150 | parser.add_argument('-save_path', default='backup')
151 | parser.add_argument('-model', default='./backup/crnn_2000.h5')
152 | parser.add_argument('-debug', type=int, default=0)
153 | parser.add_argument('-batch_size', type=int, default=1)
154 | parser.add_argument('-ocr_batch_size', type=int, default=256)
155 | parser.add_argument('-num_readers', type=int, default=1)
156 | parser.add_argument('-cuda', type=bool, default=True)
157 | parser.add_argument('-input_size', type=int, default=512)
158 | parser.add_argument('-geo_type', type=int, default=0)
159 | parser.add_argument('-base_lr', type=float, default=0.001)
160 | parser.add_argument('-max_iters', type=int, default=300000)
161 |
162 | args = parser.parse_args()
163 |
164 | main(args)
165 |
166 |
--------------------------------------------------------------------------------
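
Both test_crnn.py above and train_crnn.py below restart an exhausted DataLoader inside a try/except. An equivalent, slightly tidier pattern is a tiny generator that cycles the loader forever (a sketch, not from the repo; unlike itertools.cycle it re-runs the loader each epoch instead of caching the first epoch's batches):

def cycle(loader):
    '''Yield batches forever, starting a fresh epoch whenever the loader ends.'''
    while True:
        for batch in loader:
            yield batch

# usage:
#   data_iter = cycle(ocrdataloader)
#   im_data, gt_boxes, text = next(data_iter)
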
/tools/train_crnn.1.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 3, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import torch, os
8 | import numpy as np
9 | import cv2
10 |
11 | import net_utils
12 | import data_gen
13 | from data_gen import draw_box_points
14 | import timeit
15 |
16 | import math
17 | import random
18 |
19 | from models import ModelResNetSep2, OwnModel, CRNN
20 | import torch.autograd as autograd
21 | from torch.autograd import Variable
22 | import torch.nn.functional as F
23 |
24 | # from torch_baidu_ctc import ctc_loss, CTCLoss
25 | from warpctc_pytorch import CTCLoss
26 | from ocr_test_utils import print_seq_ext
27 | from rroi_align.modules.rroi_align import _RRoiAlign
28 | from src.utils import strLabelConverter
29 | from src.utils import alphabet
30 | from src.utils import process_crnn
31 | from src.utils import ImgDataset
32 | from src.utils import own_collate
33 | from src.utils import E2Ecollate,E2Edataset
34 |
35 | import unicodedata as ud
36 | import ocr_gen
37 | from torch import optim
38 | import argparse
39 |
40 |
41 | lr_decay = 0.99
42 | momentum = 0.9
43 | weight_decay = 0.9
44 | batch_per_epoch = 10
45 | disp_interval = 5
46 |
47 | norm_height = 44
48 |
49 | f = open('codec.txt', 'r')
50 | codec = f.readlines()[0]
51 | #codec = u' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}~£ÁČĎÉĚÍŇÓŘŠŤÚŮÝŽáčďéěíňóřšťúůýž'
52 | codec_rev = {}
53 | index = 4
54 | for i in range(0, len(codec)):
55 | codec_rev[codec[i]] = index
56 | index += 1
57 | f.close()
58 |
59 | def intersect(a, b):
60 | '''Determine the intersection of two rectangles'''
61 | rect = (0,0,0,0)
62 | r0 = max(a[0],b[0])
63 | c0 = max(a[1],b[1])
64 | r1 = min(a[2],b[2])
65 | c1 = min(a[3],b[3])
66 | # Do we have a valid intersection?
67 | if r1 > r0 and c1 > c0:
68 | rect = (r0,c0,r1,c1)
69 | return rect
70 |
71 | def union(a, b):
72 | r0 = min(a[0],b[0])
73 | c0 = min(a[1],b[1])
74 | r1 = max(a[2],b[2])
75 | c1 = max(a[3],b[3])
76 | return (r0,c0,r1,c1)
77 |
78 | def area(a):
79 | '''Computes rectangle area'''
80 | width = a[2] - a[0]
81 | height = a[3] - a[1]
82 | return width * height
83 |
84 |
85 | def main(opts):
86 | # alphabet = '0123456789.'
87 | nclass = len(alphabet) + 1
88 | model_name = 'E2E-CRNN'
89 | net = OwnModel(attention=True, nclass=nclass)
90 | print("Using {0}".format(model_name))
91 |
92 | if opts.cuda:
93 | net.cuda()
94 | learning_rate = opts.base_lr
95 |     # optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)  # superseded by the Adam setup below
96 |     optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
97 | step_start = 0
98 |
99 |     ### Option 1: only change the conv11 dimension
100 | # model_dict = net.state_dict()
101 | # if os.path.exists(opts.model):
102 | # print('loading pretrained model from %s' % opts.model)
103 | # pretrained_model = OwnModel(attention=True, nclass=12)
104 | # pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
105 | # pretrained_dict = pretrained_model.state_dict()
106 | #
107 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
108 | # model_dict.update(pretrained_dict)
109 | # net.load_state_dict(model_dict)
110 |
111 |     if os.path.exists(opts.model):
112 |         print('loading model from %s' % opts.model)
113 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
114 |
115 |     ## ICDAR2015 dataset
116 | e2edata = E2Edataset(train_list=opts.train_list)
117 | e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=True, collate_fn=E2Ecollate, num_workers=4)
118 |
119 | net.train()
120 |
121 | converter = strLabelConverter(alphabet)
122 | ctc_loss = CTCLoss()
123 |
124 | for step in range(step_start, opts.max_iters):
125 |
126 |         for index, data in enumerate(e2edataloader):
127 |             im_data, gtso, lbso = data
128 | im_data = im_data.cuda()
129 |
130 | try:
131 |                 loss = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
132 |
133 | net.zero_grad()
134 | # optimizer.zero_grad()
135 | loss.backward()
136 | optimizer.step()
137 |             except Exception:
138 | import sys, traceback
139 | traceback.print_exc(file=sys.stdout)
140 | pass
141 |
142 |
143 | if index % disp_interval == 0:
144 | try:
145 |                     print('epoch:%d || step:%d || loss %.4f' % (step, index, loss.item()))
146 |                 except Exception:
147 | import sys, traceback
148 | traceback.print_exc(file=sys.stdout)
149 | pass
150 |
151 | if step > step_start and (step % batch_per_epoch == 0):
152 | save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
153 | state = {'step': step,
154 | 'learning_rate': learning_rate,
155 | 'state_dict': net.state_dict(),
156 | 'optimizer': optimizer.state_dict()}
157 | torch.save(state, save_name)
158 | print('save model: {}'.format(save_name))
159 |
160 |
161 |
162 | if __name__ == '__main__':
163 |
164 | parser = argparse.ArgumentParser()
165 | parser.add_argument('-train_list', default='./data/ICDAR2015.txt')
166 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
167 | parser.add_argument('-save_path', default='backup')
168 | parser.add_argument('-model', default='./backup/E2E-CRNN_210.h5')
169 | parser.add_argument('-debug', type=int, default=0)
170 | parser.add_argument('-batch_size', type=int, default=8)
171 | parser.add_argument('-ocr_batch_size', type=int, default=256)
172 | parser.add_argument('-num_readers', type=int, default=1)
173 | parser.add_argument('-cuda', type=bool, default=True)
174 | parser.add_argument('-input_size', type=int, default=512)
175 | parser.add_argument('-geo_type', type=int, default=0)
176 | parser.add_argument('-base_lr', type=float, default=0.001)
177 | parser.add_argument('-max_iters', type=int, default=300000)
178 |
179 | args = parser.parse_args()
180 |
181 | main(args)
182 |
183 |
--------------------------------------------------------------------------------
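
The checkpoints written above are plain torch.save dicts with step, learning_rate, state_dict and optimizer keys. A minimal loading counterpart consistent with that format (a sketch of what net_utils.load_net presumably does; the real helper may differ):

import torch

def load_checkpoint(path, net, optimizer=None):
    state = torch.load(path, map_location='cpu')
    net.load_state_dict(state['state_dict'])
    if optimizer is not None and 'optimizer' in state:
        optimizer.load_state_dict(state['optimizer'])
    return state['step'], state['learning_rate']

# step_start, learning_rate = load_checkpoint('./backup/E2E-CRNN_210.h5', net, optimizer)
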
/tools/train_crnn.2.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 3, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import torch, os
8 | import numpy as np
9 | import cv2
10 |
11 | import net_utils
12 | import data_gen
13 | from data_gen import draw_box_points
14 | import timeit
15 |
16 | import math
17 | import random
18 |
19 | from models import ModelResNetSep2, OwnModel, CRNN
20 | import torch.autograd as autograd
21 | from torch.autograd import Variable
22 | import torch.nn.functional as F
23 |
24 | # from torch_baidu_ctc import ctc_loss, CTCLoss
25 | from warpctc_pytorch import CTCLoss
26 | from ocr_test_utils import print_seq_ext
27 | from rroi_align.modules.rroi_align import _RRoiAlign
28 | from src.utils import strLabelConverter
29 | from src.utils import alphabet
30 | from src.utils import process_crnn
31 | from src.utils import ImgDataset
32 | from src.utils import own_collate
33 | from src.utils import E2Ecollate,E2Edataset
34 |
35 | import unicodedata as ud
36 | import ocr_gen
37 | from torch import optim
38 | import argparse
39 |
40 |
41 | lr_decay = 0.99
42 | momentum = 0.9
43 | weight_decay = 0.9
44 | batch_per_epoch = 10
45 | disp_interval = 5
46 |
47 | norm_height = 44
48 |
49 | f = open('codec.txt', 'r')
50 | codec = f.readlines()[0]
51 | #codec = u' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}~£ÁČĎÉĚÍŇÓŘŠŤÚŮÝŽáčďéěíňóřšťúůýž'
52 | codec_rev = {}
53 | index = 4
54 | for i in range(0, len(codec)):
55 | codec_rev[codec[i]] = index
56 | index += 1
57 | f.close()
58 |
59 | def intersect(a, b):
60 | '''Determine the intersection of two rectangles'''
61 | rect = (0,0,0,0)
62 | r0 = max(a[0],b[0])
63 | c0 = max(a[1],b[1])
64 | r1 = min(a[2],b[2])
65 | c1 = min(a[3],b[3])
66 | # Do we have a valid intersection?
67 | if r1 > r0 and c1 > c0:
68 | rect = (r0,c0,r1,c1)
69 | return rect
70 |
71 | def union(a, b):
72 | r0 = min(a[0],b[0])
73 | c0 = min(a[1],b[1])
74 | r1 = max(a[2],b[2])
75 | c1 = max(a[3],b[3])
76 | return (r0,c0,r1,c1)
77 |
78 | def area(a):
79 | '''Computes rectangle area'''
80 | width = a[2] - a[0]
81 | height = a[3] - a[1]
82 | return width * height
83 |
84 |
85 | def main(opts):
86 | # alphabet = '0123456789.'
87 | nclass = len(alphabet) + 1
88 | model_name = 'E2E-CRNN'
89 | net = OwnModel(attention=True, nclass=nclass)
90 | print("Using {0}".format(model_name))
91 |
92 | if opts.cuda:
93 | net.cuda()
94 | learning_rate = opts.base_lr
95 |     # optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)  # superseded by the Adam setup below
96 |     optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
97 | step_start = 0
98 |
99 |     ### Option 1: only change the conv11 dimension
100 | # model_dict = net.state_dict()
101 | # if os.path.exists(opts.model):
102 | # print('loading pretrained model from %s' % opts.model)
103 | # pretrained_model = OwnModel(attention=True, nclass=12)
104 | # pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
105 | # pretrained_dict = pretrained_model.state_dict()
106 | #
107 | # pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'rnn' not in k and 'conv11' not in k}
108 | # model_dict.update(pretrained_dict)
109 | # net.load_state_dict(model_dict)
110 |
111 |     if os.path.exists(opts.model):
112 |         print('loading model from %s' % opts.model)
113 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
114 |
115 |     ## dataset
116 | e2edata = E2Edataset(train_list=opts.train_list)
117 | e2edataloader = torch.utils.data.DataLoader(e2edata, batch_size=opts.batch_size, shuffle=True, collate_fn=E2Ecollate, num_workers=4)
118 |
119 |     # electricity-meter dataset
120 | # converter = strLabelConverter(alphabet)
121 | # dataset = ImgDataset(
122 | # root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
123 | # csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
124 | # transform=None,
125 | # target_transform=converter.encode
126 | # )
127 | # ocrdataloader = torch.utils.data.DataLoader(
128 | # dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
129 | # )
130 |
131 | net.train()
132 |
133 | converter = strLabelConverter(alphabet)
134 | ctc_loss = CTCLoss()
135 |
136 | for step in range(step_start, opts.max_iters):
137 |
138 |         for index, data in enumerate(e2edataloader):
139 |             im_data, gtso, lbso = data
140 | im_data = im_data.cuda()
141 |
142 | try:
143 |                 loss = process_crnn(im_data, gtso, lbso, net, ctc_loss, converter, training=True)
144 |
145 | net.zero_grad()
146 | # optimizer.zero_grad()
147 | loss.backward()
148 | optimizer.step()
149 |             except Exception:
150 | import sys, traceback
151 | traceback.print_exc(file=sys.stdout)
152 | pass
153 |
154 |
155 | if index % disp_interval == 0:
156 | try:
157 |                     print('epoch:%d || step:%d || loss %.4f' % (step, index, loss.item()))
158 |                 except Exception:
159 | import sys, traceback
160 | traceback.print_exc(file=sys.stdout)
161 | pass
162 |
163 | if step > step_start and (step % batch_per_epoch == 0):
164 | save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
165 | state = {'step': step,
166 | 'learning_rate': learning_rate,
167 | 'state_dict': net.state_dict(),
168 | 'optimizer': optimizer.state_dict()}
169 | torch.save(state, save_name)
170 | print('save model: {}'.format(save_name))
171 |
172 |
173 |
174 | if __name__ == '__main__':
175 |
176 | parser = argparse.ArgumentParser()
177 | parser.add_argument('-train_list', default='./data/ICDAR2015.txt')
178 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
179 | parser.add_argument('-save_path', default='backup')
180 | parser.add_argument('-model', default='./backup/E2E-CRNN_210.h5')
181 | parser.add_argument('-debug', type=int, default=0)
182 | parser.add_argument('-batch_size', type=int, default=8)
183 | parser.add_argument('-ocr_batch_size', type=int, default=256)
184 | parser.add_argument('-num_readers', type=int, default=1)
185 | parser.add_argument('-cuda', type=bool, default=True)
186 | parser.add_argument('-input_size', type=int, default=512)
187 | parser.add_argument('-geo_type', type=int, default=0)
188 | parser.add_argument('-base_lr', type=float, default=0.001)
189 | parser.add_argument('-max_iters', type=int, default=300000)
190 |
191 | args = parser.parse_args()
192 |
193 | main(args)
194 |
195 |
--------------------------------------------------------------------------------
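
The commented 'Option 1' block in these scripts reloads a pretrained state_dict while skipping the rnn and conv11 entries, whose shapes change with the class count. The same idea can be made name-agnostic by checking shapes directly; a sketch (load_matching_weights is hypothetical, not part of the repo):

def load_matching_weights(net, pretrained_state):
    '''Copy only parameters whose name and shape match the target network.'''
    model_dict = net.state_dict()
    compatible = {k: v for k, v in pretrained_state.items()
                  if k in model_dict and v.shape == model_dict[k].shape}
    model_dict.update(compatible)
    net.load_state_dict(model_dict)
    return sorted(set(pretrained_state) - set(compatible))   # names that were skipped

# skipped = load_matching_weights(net, torch.load(opts.model)['state_dict'])
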
/tools/train_crnn.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 3, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 |
7 | import torch, os
8 | import numpy as np
9 | import cv2
10 |
11 | import net_utils
12 | import data_gen
13 | from data_gen import draw_box_points
14 | import timeit
15 |
16 | import math
17 | import random
18 |
19 | from models import ModelResNetSep2, OwnModel, CRNN
20 | import torch.autograd as autograd
21 | from torch.autograd import Variable
22 | import torch.nn.functional as F
23 |
24 | # from torch_baidu_ctc import ctc_loss, CTCLoss
25 | from warpctc_pytorch import CTCLoss
26 | from ocr_test_utils import print_seq_ext
27 | from rroi_align.modules.rroi_align import _RRoiAlign
28 | from src.utils import strLabelConverter
29 | from src.utils import alphabet
30 | from src.utils import process_crnn
31 | from src.utils import ImgDataset
32 | from src.utils import own_collate
33 |
34 | import unicodedata as ud
35 | import ocr_gen
36 | from torch import optim
37 | import argparse
38 |
39 |
40 | lr_decay = 0.99
41 | momentum = 0.9
42 | weight_decay = 0
43 | batch_per_epoch = 1000
44 | disp_interval = 5
45 |
46 | norm_height = 44
47 |
48 | f = open('codec.txt', 'r')
49 | codec = f.readlines()[0]
50 | #codec = u' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_abcdefghijklmnopqrstuvwxyz{|}~£ÁČĎÉĚÍŇÓŘŠŤÚŮÝŽáčďéěíňóřšťúůýž'
51 | codec_rev = {}
52 | index = 4
53 | for i in range(0, len(codec)):
54 | codec_rev[codec[i]] = index
55 | index += 1
56 | f.close()
57 |
58 | def intersect(a, b):
59 | '''Determine the intersection of two rectangles'''
60 | rect = (0,0,0,0)
61 | r0 = max(a[0],b[0])
62 | c0 = max(a[1],b[1])
63 | r1 = min(a[2],b[2])
64 | c1 = min(a[3],b[3])
65 | # Do we have a valid intersection?
66 | if r1 > r0 and c1 > c0:
67 | rect = (r0,c0,r1,c1)
68 | return rect
69 |
70 | def union(a, b):
71 | r0 = min(a[0],b[0])
72 | c0 = min(a[1],b[1])
73 | r1 = max(a[2],b[2])
74 | c1 = max(a[3],b[3])
75 | return (r0,c0,r1,c1)
76 |
77 | def area(a):
78 | '''Computes rectangle area'''
79 | width = a[2] - a[0]
80 | height = a[3] - a[1]
81 | return width * height
82 |
83 |
84 | def main(opts):
85 | alphabet = '0123456789.'
86 | nclass = len(alphabet) + 1
87 | model_name = 'crnn'
88 | net = CRNN(nclass)
89 | print("Using {0}".format(model_name))
90 |
91 | if opts.cuda:
92 | net.cuda()
93 | learning_rate = opts.base_lr
94 | optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)
95 |     step_start = 0
96 |     if os.path.exists(opts.model):
97 |         print('loading model from %s' % opts.model)
98 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
99 |
100 |     ## dataset
101 | converter = strLabelConverter(alphabet)
102 | dataset = ImgDataset(
103 | root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/image',
104 | csv_root='/home/yangna/deepblue/OCR/mech_demo2/dataset/imgs/train_list.txt',
105 | transform=None,
106 | target_transform=converter.encode
107 | )
108 | ocrdataloader = torch.utils.data.DataLoader(
109 | dataset, batch_size=opts.batch_size, shuffle=True, collate_fn=own_collate
110 | )
111 |
112 |     # step_start was initialized above (and restored from the checkpoint when resuming)
113 | net.train()
114 |
115 | converter = strLabelConverter(alphabet)
116 | ctc_loss = CTCLoss()
117 |
118 | for step in range(step_start, opts.max_iters):
119 |
120 |         try:
121 |             data = next(data_iter)
122 |         except (NameError, StopIteration):  # first iteration, or loader exhausted
123 |             data_iter = iter(ocrdataloader)
124 |             data = next(data_iter)
125 |
126 | im_data, gt_boxes, text = data
127 | im_data = im_data.cuda()
128 |
129 | try:
130 |             loss = process_crnn(im_data, gt_boxes, text, net, ctc_loss, converter, training=True)
131 |
132 | net.zero_grad()
133 | optimizer.zero_grad()
134 | loss.backward()
135 | optimizer.step()
136 |         except Exception:
137 | import sys, traceback
138 | traceback.print_exc(file=sys.stdout)
139 | pass
140 |
141 |
142 | if step % disp_interval == 0:
143 | try:
144 |                 print('step:%d || loss %.4f' % (step, loss.item()))
145 |             except Exception:
146 | import sys, traceback
147 | traceback.print_exc(file=sys.stdout)
148 | pass
149 |
150 | if step > step_start and (step % batch_per_epoch == 0):
151 | save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
152 | state = {'step': step,
153 | 'learning_rate': learning_rate,
154 | 'state_dict': net.state_dict(),
155 | 'optimizer': optimizer.state_dict()}
156 | torch.save(state, save_name)
157 | print('save model: {}'.format(save_name))
158 |
159 |
160 |
161 | if __name__ == '__main__':
162 |
163 | parser = argparse.ArgumentParser()
164 | parser.add_argument('-train_list', default='./data/small_train.txt')
165 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
166 | parser.add_argument('-save_path', default='backup')
167 | parser.add_argument('-model', default='./backup/crnn_10000.h5')
168 | parser.add_argument('-debug', type=int, default=0)
169 | parser.add_argument('-batch_size', type=int, default=8)
170 | parser.add_argument('-ocr_batch_size', type=int, default=256)
171 | parser.add_argument('-num_readers', type=int, default=1)
172 | parser.add_argument('-cuda', type=bool, default=True)
173 | parser.add_argument('-input_size', type=int, default=512)
174 | parser.add_argument('-geo_type', type=int, default=0)
175 | parser.add_argument('-base_lr', type=float, default=0.001)
176 | parser.add_argument('-max_iters', type=int, default=300000)
177 |
178 | args = parser.parse_args()
179 |
180 | main(args)
181 |
182 |
--------------------------------------------------------------------------------
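
lr_decay = 0.99 is declared at the top of every script in this family but never applied. If one wanted the decay, the usual approach is to scale each parameter group once per epoch; a sketch (not present in the original code):

def decay_learning_rate(optimizer, factor=0.99):
    '''Multiply every param group's learning rate by factor; return the new lr.'''
    for group in optimizer.param_groups:
        group['lr'] *= factor
    return optimizer.param_groups[0]['lr']

# e.g. once per epoch boundary:
#   if step > 0 and step % batch_per_epoch == 0:
#       learning_rate = decay_learning_rate(optimizer, lr_decay)
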
/tools/train_ocr.py:
--------------------------------------------------------------------------------
1 | '''
2 | Created on Sep 29, 2017
3 |
4 | @author: Michal.Busta at gmail.com
5 | '''
6 | import numpy as np
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | import os
11 |
12 | f = open('codec.txt', 'r')
13 | codec = f.readlines()[0]
14 | f.close()
15 | print(len(codec))
16 |
17 | import torch
18 | import net_utils
19 | import argparse
20 |
21 | import ocr_gen
22 |
23 | from warpctc_pytorch import CTCLoss
24 | from torch.autograd import Variable
25 |
26 | from models import ModelResNetSep2
27 | from ocr_test_utils import print_seq_ext
28 | import random
29 |
30 | import cv2
31 |
32 |
33 | base_lr = 0.0001
34 | lr_decay = 0.99
35 | momentum = 0.9
36 | weight_decay = 0.0005
37 | batch_per_epoch = 5000
38 | disp_interval = 500
39 |
40 |
41 | def main(opts):
42 |
43 | model_name = 'E2E'
44 | net = ModelResNetSep2(attention=True)
45 | acc = []
46 |
47 | if opts.cuda:
48 | net.cuda()
49 |
50 | optimizer = torch.optim.Adam(net.parameters(), lr=base_lr, weight_decay=weight_decay)
51 | step_start = 0
52 | if os.path.exists(opts.model):
53 |         print('loading model from %s' % opts.model)
54 |         step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
55 | else:
56 | learning_rate = base_lr
57 |
58 |     # note: step_start already holds the restored step when resuming; resetting it here would restart the count
59 |
60 | net.train()
61 |
62 | #acc_test = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
63 | #acc.append([0, acc_test])
64 |
65 | ctc_loss = CTCLoss()
66 |
67 | data_generator = ocr_gen.get_batch(num_workers=opts.num_readers,
68 | batch_size=opts.batch_size,
69 | train_list=opts.train_list, in_train=True, norm_height=opts.norm_height, rgb = True)
70 |
71 | train_loss = 0
72 | cnt = 0
73 |
74 | for step in range(step_start, 300000):
75 | # batch
76 | images, labels, label_length = next(data_generator)
77 | im_data = net_utils.np_to_variable(images, is_cuda=opts.cuda).permute(0, 3, 1, 2)
78 | features = net.forward_features(im_data)
79 | labels_pred = net.forward_ocr(features)
80 |
81 | # backward
82 | '''
83 | acts: Tensor of (seqLength x batch x outputDim) containing output from network
84 | labels: 1 dimensional Tensor containing all the targets of the batch in one sequence
85 | act_lens: Tensor of size (batch) containing size of each output sequence from the network
86 |     label_lens: Tensor of (batch) containing label length of each example
87 | '''
88 |
89 | probs_sizes = torch.IntTensor( [(labels_pred.permute(2,0,1).size()[0])] * (labels_pred.permute(2,0,1).size()[1]) )
90 |     label_sizes = torch.from_numpy(np.array(label_length)).int()
91 |     labels = torch.from_numpy(np.array(labels)).int()
92 | loss = ctc_loss(labels_pred.permute(2,0,1), labels, probs_sizes, label_sizes) / im_data.size(0) # change 1.9.
93 | optimizer.zero_grad()
94 | loss.backward()
95 | optimizer.step()
96 | if not np.isinf(loss.data.cpu().numpy()):
97 | train_loss += loss.data.cpu().numpy()[0] #net.bbox_loss.data.cpu().numpy()[0]
98 | cnt += 1
99 |
100 | if opts.debug:
101 | dbg = labels_pred.data.cpu().numpy()
102 | ctc_f = dbg.swapaxes(1, 2)
103 | labels = ctc_f.argmax(2)
104 | det_text, conf, dec_s = print_seq_ext(labels[0, :], codec)
105 |
106 | print('{0} \t'.format(det_text))
107 |
108 |
109 |
110 | if step % disp_interval == 0:
111 |
112 | train_loss /= cnt
113 | print('epoch %d[%d], loss: %.3f, lr: %.5f ' % (
114 | step / batch_per_epoch, step, train_loss, learning_rate))
115 |
116 | train_loss = 0
117 | cnt = 0
118 |
119 | if step > step_start and (step % batch_per_epoch == 0):
120 | save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
121 | state = {'step': step,
122 | 'learning_rate': learning_rate,
123 | 'state_dict': net.state_dict(),
124 | 'optimizer': optimizer.state_dict()}
125 | torch.save(state, save_name)
126 | print('save model: {}'.format(save_name))
127 |
128 | #acc_test, ted = test(net, codec, opts, list_file=opts.valid_list, norm_height=opts.norm_height)
129 | #acc.append([0, acc_test, ted])
130 | np.savez('train_acc_{0}'.format(model_name), acc=acc)
131 |
132 | if __name__ == '__main__':
133 | parser = argparse.ArgumentParser()
134 |
135 | parser.add_argument('-train_list', default='/home/busta/data/90kDICT32px/train_mlt.txt')
136 | parser.add_argument('-valid_list', default='/home/busta/data/icdar_ch8_validation/ocr_valid.txt')
137 | parser.add_argument('-save_path', default='backup2')
138 | parser.add_argument('-model', default='/mnt/textspotter/tmp/DS_CVPR/backup2/ModelResNetSep2_25000.h5')
139 | parser.add_argument('-debug', type=int, default=0)
140 | parser.add_argument('-batch_size', type=int, default=4)
141 | parser.add_argument('-num_readers', type=int, default=1)
142 | parser.add_argument('-cuda', type=bool, default=True)
143 | parser.add_argument('-norm_height', type=int, default=40)
144 |
145 | args = parser.parse_args()
146 | main(args)
147 |
148 |
--------------------------------------------------------------------------------
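
The docstring in train_ocr.py spells out the tensor shapes warp-ctc expects. A toy call with those shapes, mirroring the ctc_loss invocation above (assumes warpctc_pytorch is installed; the sizes are illustrative only):

import torch
from warpctc_pytorch import CTCLoss

ctc_loss = CTCLoss()
T, B, C = 10, 2, 12                       # seq length, batch, classes (index 0 = blank)
acts = torch.randn(T, B, C)               # (seqLength x batch x outputDim), raw scores
labels = torch.IntTensor([1, 2, 3, 4])    # both targets concatenated: [1, 2] and [3, 4]
act_lens = torch.IntTensor([T, T])        # network output length per batch item
label_lens = torch.IntTensor([2, 2])      # target length per batch item
loss = ctc_loss(acts, labels, act_lens, label_lens)
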
/train.py:
--------------------------------------------------------------------------------
1 | import torch, os
2 | import numpy as np
3 | import cv2
4 | import math
5 | import random
6 | import time
7 |
8 | import tools.net_utils as net_utils
9 | import tools.data_gen as data_gen
10 | from tools.data_gen import draw_box_points
11 |
12 | from tools.models import ModelResNetSep2, OwnModel
13 | import torch.autograd as autograd
14 | from torch.autograd import Variable
15 | import torch.nn.functional as F
16 | from warpctc_pytorch import CTCLoss
17 | from tools.ocr_test_utils import print_seq_ext
18 | from rroi_align.modules.rroi_align import _RRoiAlign
19 | from src.utils import strLabelConverter
20 | from src.utils import alphabet
21 | from src.utils import averager
22 | from src.ocr_process import process_boxes
23 |
24 | import unicodedata as ud
25 | import tools.ocr_gen
26 | from torch import optim
27 | import argparse
28 |
29 |
30 | def main(opts):
31 |
32 |     ## 1. Initialize the model
33 |     nclass = len(alphabet) + 1    # training on ICDAR2015
34 | model_name = 'E2E-MLT'
35 | net = ModelResNetSep2(attention=True, nclass=nclass)
36 | print("Using {0}".format(model_name))
37 |
38 | learning_rate = opts.base_lr
39 | # optimizer = torch.optim.Adam(net.parameters(), lr=opts.base_lr, weight_decay=weight_decay)
40 | optimizer = optim.Adam(net.parameters(), lr=opts.base_lr, betas=(0.5, 0.999))
41 | step_start = 0
42 |
43 |     ### Pretrained-model initialization, option 1: only change the conv11 dimension
44 | model_dict = net.state_dict()
45 | if os.path.exists(opts.model):
46 | print('loading pretrained model from %s' % opts.model)
47 | pretrained_model = ModelResNetSep2(attention=True, nclass=7500) # pretrained model from:https://github.com/MichalBusta/E2E-MLT
48 | pretrained_model.load_state_dict(torch.load(opts.model)['state_dict'])
49 | pretrained_dict = pretrained_model.state_dict()
50 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and 'conv11' not in k and 'rnn' not in k}
51 | model_dict.update(pretrained_dict)
52 | net.load_state_dict(model_dict)
53 |     ### Option 2: resume directly from a previous checkpoint
54 |     # if os.path.exists(opts.model):
55 |     #     print('loading model from %s' % opts.model)
56 |     #     step_start, learning_rate = net_utils.load_net(opts.model, net, optimizer)
57 | ###
58 | if opts.cuda:
59 | net.cuda()
60 | net.train()
61 |
62 |
63 |     ## 2. Define the dataset
64 | converter = strLabelConverter(alphabet)
65 | ctc_loss = CTCLoss()
66 | data_generator = data_gen.get_batch(num_workers=opts.num_readers,
67 | input_size=opts.input_size, batch_size=opts.batch_size,
68 | train_list=opts.train_list, geo_type=opts.geo_type)
69 | # dg_ocr = ocr_gen.get_batch(num_workers=2,
70 | # batch_size=opts.ocr_batch_size,
71 |     # train_list=opts.ocr_feed_list, in_train=True, norm_height=norm_height, rgb=True) # dataset for training the OCR recognizer
72 |
73 |     ## 3. Initialize the loss averagers
74 | bbox_loss = averager(); seg_loss = averager(); angle_loss = averager()
75 | loss_ctc = averager(); train_loss = averager()
76 |
77 |
78 |     ## 4. Training loop
79 | for step in range(step_start, opts.max_iters):
80 |
81 |         # fetch a batch
82 | images, image_fns, score_maps, geo_maps, training_masks, gtso, lbso, gt_idxs = next(data_generator)
83 | im_data = net_utils.np_to_variable(images.transpose(0, 3, 1, 2), is_cuda=opts.cuda)
84 | start = time.time()
85 | try:
86 | seg_pred, roi_pred, angle_pred, features = net(im_data)
87 | except:
88 | import sys, traceback
89 | traceback.print_exc(file=sys.stdout)
90 | continue
91 |
92 | # for EAST loss
93 | smaps_var = net_utils.np_to_variable(score_maps, is_cuda=opts.cuda)
94 | training_mask_var = net_utils.np_to_variable(training_masks, is_cuda=opts.cuda)
95 | angle_gt = net_utils.np_to_variable(geo_maps[:, :, :, 4], is_cuda=opts.cuda)
96 | geo_gt = net_utils.np_to_variable(geo_maps[:, :, :, [0, 1, 2, 3]], is_cuda=opts.cuda)
97 | try:
98 | loss = net.loss(seg_pred, smaps_var, training_mask_var, angle_pred, angle_gt, roi_pred, geo_gt)
99 | except:
100 | import sys, traceback
101 | traceback.print_exc(file=sys.stdout)
102 | continue
103 |
104 | bbox_loss.add(net.box_loss_value.item()); seg_loss.add(net.segm_loss_value.item()); angle_loss.add(net.angle_loss_value.item())
105 |
106 |
107 | # 训练ocr的部分
108 | try:
109 |             # before step 10000 the recognizer is trained on the ground-truth text regions (the strategy used in E2E-MLT)
110 |             if step > 10000 or True: # this is just an extra augmentation step ... in the early stage it only slows down training
111 | ctcl, gt_target , gt_proc = process_boxes(images, im_data, seg_pred[0], roi_pred[0], angle_pred[0], score_maps, gt_idxs, gtso, lbso, features, net, ctc_loss, opts, converter, debug=opts.debug)
112 | loss_ctc.add(ctcl)
113 | loss = loss + ctcl.cuda()
114 | train_loss.add(loss.item())
115 |
116 | net.zero_grad()
117 | optimizer.zero_grad()
118 | loss.backward()
119 | optimizer.step()
120 | except:
121 | import sys, traceback
122 | traceback.print_exc(file=sys.stdout)
123 | pass
124 |
125 | if step % opts.disp_interval == 0:
126 |             end = time.time() # measure elapsed time
127 | ctc_loss_val2 = 0.0
128 | print('epoch %d[%d], loss: %.3f, bbox_loss: %.3f, seg_loss: %.3f, ang_loss: %.3f, ctc_loss: %.3f, time %.3f' % (
129 | step / 1000 * opts.batch_size, step, train_loss.val(), bbox_loss.val(), seg_loss.val(), angle_loss.val(), loss_ctc.val(), end-start))
130 |
131 | # for save mode
132 |         if step > step_start and (step % ((1000 / opts.batch_size)*20) == 0): # save once every 20 epochs
133 | save_name = os.path.join(opts.save_path, '{}_{}.h5'.format(model_name, step))
134 | state = {'step': step,
135 | 'learning_rate': learning_rate,
136 | 'state_dict': net.state_dict(),
137 | 'optimizer': optimizer.state_dict()}
138 | torch.save(state, save_name)
139 | print('save model: {}'.format(save_name))
140 |             train_loss.reset(); bbox_loss.reset(); seg_loss.reset(); angle_loss.reset(); loss_ctc.reset() # reset the averagers so they do not accumulate indefinitely
141 |
142 |
143 |
144 | if __name__ == '__main__':
145 |
146 | parser = argparse.ArgumentParser()
147 | parser.add_argument('-train_list', default='./data/ICDAR2015.txt')
148 | parser.add_argument('-ocr_feed_list', default='sample_train_data/MLT_CROPS/gt.txt')
149 | parser.add_argument('-save_path', default='backup')
150 | parser.add_argument('-model', default='./weights/e2e-mlt.h5')
151 | parser.add_argument('-debug', type=int, default=0)
152 | parser.add_argument('-batch_size', type=int, default=2)
153 | parser.add_argument('-ocr_batch_size', type=int, default=256)
154 | parser.add_argument('-num_readers', type=int, default=4, help='it is faster')
155 | parser.add_argument('-cuda', type=bool, default=True)
156 | parser.add_argument('-input_size', type=int, default=512)
157 | parser.add_argument('-geo_type', type=int, default=0)
158 | parser.add_argument('-base_lr', type=float, default=0.001)
159 | parser.add_argument('-max_iters', type=int, default=300000)
160 | parser.add_argument('-disp_interval', type=int, default=5)
161 |
162 | args = parser.parse_args()
163 |
164 | main(args)
165 |
166 |
--------------------------------------------------------------------------------
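
train.py tracks its losses with averager() objects from src.utils, touching only add, val and reset. A minimal class with that interface (a sketch inferred from the usage here, not the repo's actual implementation):

class Averager:
    '''Running mean with the add / val / reset interface used in train.py.'''
    def __init__(self):
        self.reset()

    def add(self, v):
        self.sum += float(v)      # float() also accepts one-element tensors
        self.count += 1

    def val(self):
        return self.sum / self.count if self.count else 0.0

    def reset(self):
        self.sum, self.count = 0.0, 0
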
/weights/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chenjun2hao/FOTS.pytorch/a2d1bf71d66197ed6b20c4cfcbe60d4735a20e0c/weights/__init__.py
--------------------------------------------------------------------------------