├── models
│   ├── __init__.py
│   ├── adamw_schedulefree.py
│   └── radam_schedulefree.py
├── data
├── dataset
│   ├── __init__.py
│   ├── memo.txt
│   ├── downloader_src
│   │   ├── Makefile
│   │   └── main.cpp
│   ├── filter_fixdata.py
│   ├── data_fixdata.py
│   ├── data_detector.py
│   └── multi.py
├── train_data3
│   ├── __init__.py
│   ├── get_aozora.py
│   ├── get_wikipedia.py
│   ├── check_code.py
│   └── make_data.py
├── make_traindata
│   ├── models
│   ├── model.pt
│   ├── render_font
│   │   ├── __init__.py
│   │   ├── Makefile
│   │   ├── get_wikipedia.py
│   │   └── get_aozora.py
│   ├── util_func.py
│   ├── TextDetector.onnx
│   ├── TextDetector.mlpackage
│   ├── data
│   │   └── other_list.txt
│   ├── memo.txt
│   ├── save_feature.py
│   ├── merge_data.py
│   ├── make_traindata1.py
│   └── make_traindata3.py
├── img
│   ├── test1.png
│   ├── test2.png
│   ├── test2_code1.png
│   ├── test2_code2.png
│   ├── test1_keymap.png
│   ├── test1_result.png
│   ├── test2_keymap.png
│   ├── test2_result.png
│   ├── fix_image_json1.png
│   ├── fix_image_json2.png
│   ├── fix_image_line1.png
│   ├── fix_image_line2.png
│   ├── test1_separator.png
│   ├── test1_textline.png
│   └── test2_textline.png
├── textline_detect
│   ├── src
│   │   ├── space_check.h
│   │   ├── make_block.h
│   │   ├── minpack
│   │   │   ├── minpack.hpp
│   │   │   ├── enorm.cpp
│   │   │   ├── fdjac2.cpp
│   │   │   ├── qrfac.cpp
│   │   │   ├── qrsolv.cpp
│   │   │   └── lmpar.cpp
│   │   ├── number_unbind.h
│   │   ├── hough_linefind.h
│   │   ├── after_search.h
│   │   ├── ruby_search.h
│   │   ├── prepare.h
│   │   ├── split_doubleline.h
│   │   ├── process.h
│   │   ├── line_detect.h
│   │   ├── search_loop.h
│   │   ├── process.cpp
│   │   ├── prepare.cpp
│   │   ├── main.cpp
│   │   └── after_search.cpp
│   ├── Makefile.mak
│   └── Makefile
├── const.py
├── convert_fp16_onnx.py
├── fine_image
│   ├── empty_image1.py
│   ├── plot_image1.py
│   ├── fix_line_image1.py
│   ├── process_image4_coreml.py
│   └── process_image4_torch.py
├── LICENSE
├── run_ocr.py
├── process_ocr_torch.py
├── .gitignore
├── process_ocr_coreml.py
├── process_ocr_onnx.py
├── util_func.py
├── plot_json.py
├── quantize1_onnx.py
├── convert1_onnx.py
├── convert1_coreml.py
├── convert3_onnx.py
└── loss_func.py

/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/data:
--------------------------------------------------------------------------------
make_traindata/data
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/train_data3/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/make_traindata/models:
--------------------------------------------------------------------------------
../models
--------------------------------------------------------------------------------
/make_traindata/model.pt:
--------------------------------------------------------------------------------
../model.pt
--------------------------------------------------------------------------------
/make_traindata/render_font/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/make_traindata/util_func.py:
--------------------------------------------------------------------------------
../util_func.py
--------------------------------------------------------------------------------
/make_traindata/TextDetector.onnx:
--------------------------------------------------------------------------------
../TextDetector.onnx
--------------------------------------------------------------------------------
/make_traindata/TextDetector.mlpackage:
--------------------------------------------------------------------------------
../TextDetector.mlpackage
--------------------------------------------------------------------------------
/train_data3/get_aozora.py:
--------------------------------------------------------------------------------
../make_traindata/render_font/get_aozora.py
--------------------------------------------------------------------------------
/train_data3/get_wikipedia.py:
--------------------------------------------------------------------------------
../make_traindata/render_font/get_wikipedia.py
--------------------------------------------------------------------------------
/img/test1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test1.png
--------------------------------------------------------------------------------
/img/test2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test2.png
--------------------------------------------------------------------------------
/img/test2_code1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test2_code1.png
--------------------------------------------------------------------------------
/img/test2_code2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test2_code2.png
--------------------------------------------------------------------------------
/img/test1_keymap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test1_keymap.png
--------------------------------------------------------------------------------
/img/test1_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test1_result.png
--------------------------------------------------------------------------------
/img/test2_keymap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test2_keymap.png
--------------------------------------------------------------------------------
/img/test2_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test2_result.png
--------------------------------------------------------------------------------
/img/fix_image_json1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/fix_image_json1.png
--------------------------------------------------------------------------------
/img/fix_image_json2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/fix_image_json2.png
--------------------------------------------------------------------------------
/img/fix_image_line1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/fix_image_line1.png
--------------------------------------------------------------------------------
/img/fix_image_line2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/fix_image_line2.png
--------------------------------------------------------------------------------
/img/test1_separator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test1_separator.png
--------------------------------------------------------------------------------
/img/test1_textline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test1_textline.png
--------------------------------------------------------------------------------
/img/test2_textline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lithium0003/findtextCenterNet/HEAD/img/test2_textline.png
--------------------------------------------------------------------------------
/dataset/memo.txt:
--------------------------------------------------------------------------------
CPLUS_INCLUDE_PATH=$(python -c 'import numpy; print(numpy.get_include())') cythonize -i dataset/processer.pyx
--------------------------------------------------------------------------------
/textline_detect/src/space_check.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void space_chack(std::vector<charbox> &boxes);
--------------------------------------------------------------------------------
/make_traindata/data/other_list.txt:
--------------------------------------------------------------------------------
!#$%&()*+,-./:;<=>?@[]^{}~"'§¶_′″‘’“”«»
。、・゠!?⁉⁈"※
「」『』(){}〈〉《》⦅⦆[]〚〛〔〕〘〙【】〖〗
ー〜〰…‥
々ゝヽゞヾ〻
♡♥♤♠♢♦♧♣
♪♫♬♩♯♭♮
→←↑↓⇒⇔↔↗↘↖↙⇄⇨⇦⇧⇩⤴⤵⏎
--------------------------------------------------------------------------------
/textline_detect/src/make_block.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void make_block(
    std::vector<charbox> &boxes,
    const std::vector<bool> &lineblocker);
--------------------------------------------------------------------------------
/textline_detect/Makefile.mak:
--------------------------------------------------------------------------------
TARGET = linedetect.exe

CFLAGS = /std:c++20 /O2 /utf-8

$(TARGET): src/*.cpp src/minpack/*.cpp
	$(CXX) -o $@ $** $(CFLAGS)

clean:
	del /F *.obj $(TARGET)
--------------------------------------------------------------------------------
/textline_detect/src/minpack/minpack.hpp:
--------------------------------------------------------------------------------
#pragma once

int lmdif1(
    int (*fcn)(int,int,double *,double *),
    int m,
    int n,
    double *x,
    double *fvec,
    double tol = 1e-10);
--------------------------------------------------------------------------------
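A minimal usage sketch for lmdif1 (editor's illustration, not part of the repository): the callback fills fvec with the m residuals at the current parameter vector and returns 0 on success, as described in the fdjac2.cpp comments further below. The sample data and parameter names here are hypothetical.

#include <cstdio>
#include "minpack.hpp"

// Residuals for fitting y = a*x + b to four sample points.
static const double xs[4] = {0, 1, 2, 3};
static const double ys[4] = {1.1, 2.9, 5.2, 6.8};

static int residual(int m, int n, double *p, double *fvec) {
    for (int i = 0; i < m; i++)
        fvec[i] = ys[i] - (p[0] * xs[i] + p[1]);
    return 0;  // nonzero would abort the fit
}

int main() {
    double p[2] = {0.0, 0.0};  // initial guess for (a, b)
    double fvec[4];
    int info = lmdif1(residual, 4, 2, p, fvec);
    printf("info=%d a=%f b=%f\n", info, p[0], p[1]);
}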
/dataset/downloader_src/Makefile:
--------------------------------------------------------------------------------
all: downloader

downloader: main.cpp
	g++ -O3 -march=native main.cpp -o downloader -std=c++17 `pkg-config --cflags --libs libcurl`

.PHONY : clean
clean:
	@rm -rf downloader
--------------------------------------------------------------------------------
/const.py:
--------------------------------------------------------------------------------
encoder_add_dim = 6
# 1 vertical
# 2 ruby (base)
# 3 ruby (text)
# 4 space
# 5 emphasis
# 6 newline

max_decoderlen = 400
max_encoderlen = 400

decoder_PAD = 0
decoder_SOT = 1
decoder_EOT = 2
decoder_MSK = 3
--------------------------------------------------------------------------------
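Illustration only (not part of the repository): how the six additional encoder dimensions listed in const.py might be populated for a single character, with slots following the comment above (1 vertical, 2 ruby base, 3 ruby text, 4 space, 5 emphasis, 6 newline). The variable names are hypothetical.

import numpy as np

encoder_add_dim = 6
# Flag vector for a character that sits in vertical text and is a ruby base.
flags = np.zeros(encoder_add_dim, dtype=np.float32)
flags[0] = 1.0  # vertical
flags[1] = 1.0  # ruby (base)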
/textline_detect/src/number_unbind.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

int number_unbind(
    std::vector<charbox> &boxes,
    const std::vector<bool> &lineblocker,
    const std::vector<int> &idimage,
    int next_id);
--------------------------------------------------------------------------------
/textline_detect/src/hough_linefind.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

std::vector<std::vector<int>> linefind(
    std::vector<charbox> &boxes,
    const std::vector<float> &lineimage,
    const std::vector<bool> &lineblocker);
--------------------------------------------------------------------------------
/textline_detect/src/after_search.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void after_search(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain,
    const std::vector<bool> &lineblocker,
    const std::vector<int> &idimage);
--------------------------------------------------------------------------------
/textline_detect/src/ruby_search.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void search_ruby(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain,
    const std::vector<bool> &lineblocker,
    const std::vector<int> &idimage);
--------------------------------------------------------------------------------
/textline_detect/src/prepare.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void prepare_id_image(
    std::vector<int> &idimage,
    std::vector<int> &idimage_main,
    std::vector<charbox> &boxes);

void make_lineblocker(
    std::vector<bool> &lineblocker,
    const std::vector<float> &sepimage);
--------------------------------------------------------------------------------
/make_traindata/memo.txt:
--------------------------------------------------------------------------------
sudo apt install build-essential pkg-config libfreetype-dev python3-venv
python3 -m venv venv/mkdata
. venv/mkdata/bin/activate
cd make_traindata
make -C render_font
pip install -U webdataset numpy pillow
./make_traindata1.py 64 1024


CPLUS_INCLUDE_PATH=$(python3 -c 'import numpy; print(numpy.get_include())') cythonize -i make_traindata/processer3.pyx
--------------------------------------------------------------------------------
/make_traindata/render_font/Makefile:
--------------------------------------------------------------------------------
all: render_font test_font

render_font: render_font.cpp
	g++ -O2 -march=native render_font.cpp -o render_font -std=c++17 `pkg-config --cflags --libs freetype2`

test_font: test_font.cpp
	g++ -O2 -march=native test_font.cpp -o test_font -std=c++17 `pkg-config --cflags --libs freetype2`

.PHONY : clean
clean:
	@rm -rf render_font test_font
--------------------------------------------------------------------------------
/textline_detect/src/split_doubleline.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void split_doubleline1(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain);

void split_doubleline2(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain);

void split_doubleline3(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain);
--------------------------------------------------------------------------------
/textline_detect/src/process.h:
--------------------------------------------------------------------------------
//
//  process.h
//  linedetector
//
//  Created by rei9 on 2025/03/03.
//

#ifndef process_h
#define process_h

#include <vector>

void process(
    const std::vector<float> &lineimage,
    const std::vector<float> &sepimage,
    std::vector<charbox> &boxes);

void print_chaininfo(
    const std::vector<charbox> &boxes,
    const std::vector<std::vector<int>> &line_box_chain);

#endif /* process_h */
--------------------------------------------------------------------------------
/convert_fp16_onnx.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
import onnx
from onnxconverter_common import float16

def convert_fp16(filename, outfilename):
    model = onnx.load(filename)
    model_fp16 = float16.convert_float_to_float16(model)
    onnx.save(model_fp16, outfilename)

convert_fp16("TextDetector.onnx","TextDetector.fp16.onnx")
convert_fp16("CodeDecoder.onnx","CodeDecoder.fp16.onnx")
convert_fp16("TransformerDecoder.onnx","TransformerDecoder.fp16.onnx")
convert_fp16("TransformerEncoder.onnx","TransformerEncoder.fp16.onnx")
--------------------------------------------------------------------------------
/textline_detect/Makefile:
--------------------------------------------------------------------------------
TARGET = linedetect
SRCDIR = ./src

CXXFLAGS += -std=c++20 -O3 -MMD -MP
LDFLAGS +=
LIBS +=

SOURCES  = $(wildcard $(SRCDIR)/minpack/*.cpp)
SOURCES += $(wildcard $(SRCDIR)/*.cpp)
OBJECTS = $(SOURCES:.cpp=.o)
DEPENDS = $(OBJECTS:.o=.d)

$(TARGET): $(OBJECTS) $(LIBS)
	$(CXX) -o $@ $^ $(LDFLAGS)

%.o: %.cpp
	$(CXX) $(CXXFLAGS) -o $@ -c $<

clean:
	-rm -f $(OBJECTS) $(DEPENDS)

distclean:
	-rm -f $(TARGET)

-include $(DEPENDS)
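# Usage note (editor's addition): running 'make' inside textline_detect/ builds
# the ./linedetect binary with the system C++ compiler; 'make clean' removes the
# object and dependency files, and 'make distclean' also removes the binary.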
--------------------------------------------------------------------------------
/make_traindata/save_feature.py:
--------------------------------------------------------------------------------
import numpy as np
import glob
import os

def load():
    path = 'code_features'
    data = {}
    for filename in sorted(glob.glob(os.path.join(path,'*.npy'))):
        print(filename)
        codestr = os.path.splitext(os.path.basename(filename))[0]
        horizontal = codestr[0] == 'h'
        code = int(codestr[1:], 16)

        value = np.load(filename).astype(np.float16)
        if horizontal:
            data['hori_%d'%code] = value
        else:
            data['vert_%d'%code] = value
    np.savez('features', **data)

if __name__=='__main__':
    load()
--------------------------------------------------------------------------------
/make_traindata/merge_data.py:
--------------------------------------------------------------------------------
import numpy as np
import sys
import glob
import os

def load(path):
    base_path = 'code_features'
    os.makedirs(base_path, exist_ok=True)
    for filename in sorted(glob.glob(os.path.join(path,'*.npy'))):
        print(filename)
        value = np.load(filename)
        basename = os.path.basename(filename)

        filename2 = os.path.join(base_path, basename)
        if os.path.exists(filename2):
            data2 = np.load(filename2)
            value = np.concatenate([value,data2], axis=0)
        np.save(filename2, value)

if __name__=='__main__':
    for path in sys.argv[1:]:
        load(path)
--------------------------------------------------------------------------------
/dataset/filter_fixdata.py:
--------------------------------------------------------------------------------
import json
import glob
import os

data_path = 'train_data2'

jsonfiles = sorted(glob.glob(os.path.join(data_path, '*.json')))

for jsonfile in jsonfiles:
    with open(jsonfile, 'r', encoding='utf-8') as file:
        data = json.load(file)

    for i, pos in enumerate(data['textbox']):
        text = pos['text']
        if text is None:
            continue

        if len(text) > 1:
            c = text.encode('utf-32-be')
            t = c[:4].decode('utf-32-be')
            data['textbox'][i]['text'] = t

    with open(jsonfile, 'w', encoding='utf-8') as file:
        json.dump(data, file, indent=2, ensure_ascii=False)
--------------------------------------------------------------------------------
/fine_image/empty_image1.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import sys
from PIL import Image, ImageDraw
import json

if len(sys.argv) < 2:
    print(sys.argv[0],'target.png')
    exit(1)

target_file = sys.argv[1]

with open(target_file+'.json', 'r', encoding='utf-8') as file:
    out_dict = json.load(file)

out_dict['textbox'] = []

with open(target_file+'.json', 'w', encoding='utf-8') as file:
    json.dump(out_dict, file, indent=2, ensure_ascii=False)

linesfile = target_file + '.lines.png'
sepsfile = target_file + '.seps.png'

lines_all = Image.open(linesfile)
draw = ImageDraw.Draw(lines_all)
draw.rectangle((0,0,lines_all.width,lines_all.height), fill=0)
lines_all.save(linesfile)

lines_all = Image.open(sepsfile)
draw = ImageDraw.Draw(lines_all)
draw.rectangle((0,0,lines_all.width,lines_all.height), fill=0)
lines_all.save(sepsfile)
--------------------------------------------------------------------------------
/textline_detect/src/line_detect.h:
--------------------------------------------------------------------------------
#pragma once
#define _USE_MATH_DEFINES

struct charbox {
    int id;
    int block;
    int idx;
    int subtype;   // 1: vert, 2,4: (10, rubybase, 11, ruby), 8: sp, 16: emphasis / 32: alone ruby 512: tab split
    int subidx;
    int double_line;
    int page;
    int section;
    double direction;
    float cx;
    float cy;
    float w;
    float h;
    float code1;
    float code2;
    float code4;
    float code8;
};

extern double ruby_cutoff;
extern double rubybase_cutoff;
extern double emphasis_cutoff;
extern double space_cutoff;
extern float line_valueth;
extern float sep_valueth;
extern float sep_valueth2;
extern const float sep_clusterth;
extern const int linearea_th;
extern double allowwidth_next_block;
extern double allow_sizediff;
extern double chain_line_ratio;
extern int page_divide;
extern int scale;

extern int run_mode;
extern int width;
extern int height;
--------------------------------------------------------------------------------
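A short sketch of how the charbox::subtype bit flags above appear to combine (editor's reading of the comment: bit values 2 and 4 form a two-bit field where 10b marks a ruby base and 11b marks ruby text). The helper below is illustrative, not part of the source:

#include <cstdio>

// Illustrative decoder for charbox::subtype bit flags.
static void print_subtype(int subtype) {
    if (subtype & 1)        printf("vertical ");
    if ((subtype & 6) == 2) printf("rubybase ");   // bits 2+4 == 10b
    if ((subtype & 6) == 6) printf("ruby ");       // bits 2+4 == 11b
    if (subtype & 8)        printf("space ");
    if (subtype & 16)       printf("emphasis ");
    if (subtype & 32)       printf("alone-ruby ");
    if (subtype & 512)      printf("tab-split ");
    printf("\n");
}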
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 lithium0003

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/train_data3/check_code.py:
--------------------------------------------------------------------------------
import os
import glob
import numpy as np

from dataset.data_transformer import UNICODE_WHITESPACE_CHARACTERS

train_data3 = 'train_data3'

def process():
    txtfiles = sorted(glob.glob(os.path.join(train_data3,'*','*.txt')))
    with np.load(os.path.join(train_data3, 'features.npz')) as data:
        codes = set([chr(c) for c in set([int(s.split('_')[1]) for s in data.files])])
    pass_char = set(['\n','\uFFF9','\uFFFA','\uFFFB'] + UNICODE_WHITESPACE_CHARACTERS)

    all_remain = set()
    for filename in txtfiles:
        print(filename)
        with open(filename) as rf:
            lines = [s for s in rf.read().splitlines() if s.strip()]
            txt = '\n'.join(lines)
        remain = set(txt) - codes - pass_char
        for c in remain:
            print(c, hex(ord(c)), ord(c), 'not found in', filename)
        all_remain |= remain
    print('--------not-found--------')
    for c in sorted([ord(c) for c in all_remain]):
        print(chr(c), hex(c), c)

if __name__=='__main__':
    process()
--------------------------------------------------------------------------------
/textline_detect/src/search_loop.h:
--------------------------------------------------------------------------------
#pragma once

#include "line_detect.h"
#include <vector>

void sort_chain(
    std::vector<int> &chain,
    const std::vector<charbox> &boxes);

void fix_chain_info(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain);

void search_chain(
    const std::vector<int> &chain,
    const std::vector<charbox> &boxes,
    float &direction,
    double &w, double &h,
    float &start_cx, float &start_cy,
    float &end_cx, float &end_cy);

void make_track_line(
    std::vector<float> &x,
    std::vector<float> &y,
    float &direction,
    double &w, double &h,
    const std::vector<charbox> &boxes,
    const std::vector<std::vector<int>> &line_box_chain,
    const std::vector<bool> &lineblocker,
    int chainid,
    int extra_len = 0);

std::vector<int> create_chainid_map(
    const std::vector<charbox> &boxes,
    const std::vector<std::vector<int>> &line_box_chain,
    const std::vector<bool> &lineblocker,
    double ratio = 1.0,
    int extra_len = 0);

void search_loop(
    std::vector<charbox> &boxes,
    std::vector<std::vector<int>> &line_box_chain,
    const std::vector<bool> &lineblocker,
    const std::vector<int> &idimage,
    const std::vector<float> &sepimage);
--------------------------------------------------------------------------------
/run_ocr.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

import os

models = []
if os.path.exists('TextDetector.mlpackage') and os.path.exists('TransformerEncoder.mlpackage') and os.path.exists('TransformerDecoder.mlpackage'):
    models.append('coreml')
if (os.path.exists("TextDetector.quant.onnx") or os.path.exists("TextDetector.onnx")) and os.path.exists('TransformerEncoder.onnx') and os.path.exists('TransformerDecoder.onnx'):
    models.append('onnx')
if os.path.exists('model.pt') and os.path.exists('model3.pt'):
    models.append('torch')

if not models:
    print('no model files found')
    exit(1)

if models[0] == 'coreml':
    print('coreml')
    from process_ocr_coreml import OCR_coreml_Processer as OCR_Processer
elif models[0] == 'onnx':
    print('onnx')
    from process_ocr_onnx import OCR_onnx_Processer as OCR_Processer
elif models[0] == 'torch':
    print('torch')
    from process_ocr_torch import OCR_torch_Processer as OCR_Processer

processer = OCR_Processer()

if __name__=='__main__':
    import sys
    import glob

    if len(sys.argv) < 2:
        print(sys.argv[0], 'target_image')
        exit(1)

    target_files = []
    resize = 1.0
    for arg in sys.argv[1:]:
        if arg.startswith('--resize='):
            resize = float(arg.split('=')[1])
        else:
            target_files += glob.glob(arg)
    target_files = sorted(target_files)

    for target_file in target_files:
        print(target_file)
        processer.call_OCR(target_file, resize)
--------------------------------------------------------------------------------
/make_traindata/render_font/get_wikipedia.py:
--------------------------------------------------------------------------------
import json
import urllib.parse
import urllib.request

# Wikipedia API
WIKI_URL = "https://%s.wikipedia.org/w/api.php?"

# Build the query parameters to fetch one (or more) random articles
def set_url_random(count=1):
    params = {
        'action': 'query',
        'format': 'json',
        'list': 'random',     # fetch random pages
        'rnnamespace': 0,     # restrict to the main namespace
        'rnlimit': count,     # maximum number of results
    }
    return params

# Build the query parameters to fetch the contents of the given article
def set_url_extract(pageid):
    params = {
        'action': 'query',
        'format': 'json',
        'prop': 'extracts',
        'pageids': pageid,    # article page ID
        'explaintext': '',
    }
    return params

# Fetch random article IDs
def get_random_wordid(lang='ja', count=1):
    request_url = WIKI_URL%lang
    request_url += urllib.parse.urlencode(set_url_random(count))
    html = urllib.request.urlopen(request_url, timeout=10)
    html_json = json.loads(html.read().decode('utf-8'))
    pageid = [page['id'] for page in html_json['query']['random']]
    return pageid

def get_word_content(pageid, lang='ja'):
    request_url = WIKI_URL%lang
    request_url += urllib.parse.urlencode(set_url_extract(pageid))
    html = urllib.request.urlopen(request_url, timeout=10)
    html_json = json.loads(html.read().decode('utf-8'))
    explaintext = html_json['query']['pages'][str(pageid)]['extract']
    return explaintext

if __name__ == '__main__':
    pageid = get_random_wordid(count=1)
    extract = get_word_content(pageid[0])
    print(extract)
--------------------------------------------------------------------------------
/train_data3/make_data.py:
--------------------------------------------------------------------------------
from . import get_aozora
from . import get_wikipedia
import os

def aozora():
    urls = get_aozora.get_aozora_urls()
    os.makedirs(os.path.join('train_data3','aozora'), exist_ok=True)

    count = 0
    for url in urls:
        print(count, '/', len(urls), url)
        while True:
            try:
                txt = get_aozora.get_contents(url)
            except Exception:
                # retry on network or parse errors
                continue
            break
        if not txt.strip():
            continue
        filename = os.path.join('train_data3','aozora','aozora_%08d.txt'%count)
        with open(filename, 'w') as wf:
            wf.write(txt)
        count += 1

def wikipedia(lang='ja', rep=40):
    os.makedirs(os.path.join('train_data3','wikipedia_%s'%lang), exist_ok=True)

    pageids = set()
    for i in range(rep):
        print(i, '/', rep)
        pageids |= set(get_wikipedia.get_random_wordid(lang=lang, count=500))
    count = 0
    for pageid in pageids:
        print(count, '/', len(pageids), pageid)
        while True:
            try:
                txt = get_wikipedia.get_word_content(pageid, lang=lang)
            except Exception:
                # retry on network errors
                continue
            break
        if not txt.strip():
            continue
        filename = os.path.join('train_data3','wikipedia_%s'%lang,'wikipedia_%s_%08d.txt'%(lang,count))
        with open(filename, 'w') as wf:
            wf.write(txt)
        count += 1

def process():
    aozora()
    for lang in ['en','ko','fr','de','it']:
        wikipedia(lang)
    for lang in ['ja']:
        wikipedia(lang, rep=40*4)

if __name__=='__main__':
    process()
--------------------------------------------------------------------------------
/process_ocr_torch.py:
--------------------------------------------------------------------------------
import torch
import os

from process_ocr_base import OCR_Processer

class OCR_torch_Processer(OCR_Processer):
    def __init__(self, model_size='xl'):
        super().__init__()
        from models.detector import TextDetectorModel, CenterNetDetector
        from models.transformer import ModelDimensions, Transformer, TransformerPredictor

        model = TextDetectorModel(model_size=model_size)
        if os.path.exists('model.pt'):
            data = torch.load('model.pt', map_location="cpu", weights_only=True)
            model.load_state_dict(data['model_state_dict'])

        detector = CenterNetDetector(model.detector)
        if torch.cuda.is_available():
            device = 'cuda'
        elif torch.backends.mps.is_available():
            device = 'mps'
        else:
            device = 'cpu'
        device = torch.device(device)
        detector.to(device=device)
        detector.eval()
        self.detector = detector
        self.device = device

        if os.path.exists('model3.pt'):
            data = torch.load('model3.pt', map_location="cpu", weights_only=True)
            config = ModelDimensions(**data['config'])
            model = Transformer(**config.__dict__)
            model.load_state_dict(data['model_state_dict'])
        else:
            config = ModelDimensions()
            model = Transformer(**config.__dict__)
        model2 = TransformerPredictor(model.encoder, model.decoder)
        model2.to(device)
        model2.eval()
        self.transformer = model2

    def call_detector(self, image_input):
        images = torch.from_numpy(image_input / 255.).permute(0,3,1,2).to(device=self.device)
        with torch.no_grad():
            heatmap, features = self.detector(images)
        heatmap = heatmap.cpu().numpy()
        features = features.cpu().numpy()
        return heatmap, features

    def call_transformer(self, encoder_input):
        encoder_input = torch.tensor(encoder_input, device=self.device)
        pred = self.transformer(encoder_input).squeeze(0).cpu().numpy()
        return pred
--------------------------------------------------------------------------------
/dataset/downloader_src/main.cpp:
--------------------------------------------------------------------------------
#include <iostream>
#include <sstream>
#include <thread>
#include <chrono>
#include <csignal>
#include <cstdint>
#include <cstdio>
#include <curl/curl.h>

bool isLive = true;

size_t onReceive(char* ptr, size_t size, size_t nmemb, void* stream) {
    int64_t *count = (int64_t *)stream;
    const size_t sizes = size * nmemb;
    std::cout.write(ptr, sizes);
    *count += sizes;
    return sizes;
}

void sigpipe_handler(int unused)
{
    isLive = false;
}

int main(int argc, const char * argv[]) {
    if(argc < 2) {
        return 0;
    }

    signal(SIGPIPE, sigpipe_handler);

    CURL *curl = curl_easy_init();
    if (curl == nullptr) {
        curl_easy_cleanup(curl);
        return 1;
    }
    int64_t count = 0;
    curl_easy_setopt(curl, CURLOPT_URL, argv[1]);
    curl_easy_setopt(curl, CURLOPT_LOW_SPEED_TIME, 60L);
    curl_easy_setopt(curl, CURLOPT_LOW_SPEED_LIMIT, 30L);
    curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1);
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, onReceive);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &count);

    // perform the transfer
    CURLcode res = curl_easy_perform(curl);
    if (res != CURLE_OK) {
        fprintf(stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
        int retry = 1000;
        while(isLive && res != CURLE_OK && (count > 0 || retry-- > 0)) {
            if(count > 0) {
                std::stringstream ss;
                ss << count << "-";
                std::cerr << "range:" << ss.str() << std::endl;
                curl_easy_setopt(curl, CURLOPT_RANGE, ss.str().c_str());
                res = curl_easy_perform(curl);
                if(res != CURLE_OK) {
                    fprintf(stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
                    fprintf(stderr, "retry remain %d\n", retry);
                    std::this_thread::sleep_for(std::chrono::milliseconds(500));
                }
            }
            else {
                res = curl_easy_perform(curl);
                if(res != CURLE_OK) {
                    fprintf(stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res));
                    fprintf(stderr, "retry remain %d\n", retry);
                    std::this_thread::sleep_for(std::chrono::milliseconds(500));
                }
            }
        }
    }

    // std::cerr << "downloaded:" << count << std::endl;

    curl_easy_cleanup(curl);
    return 0;
}
--------------------------------------------------------------------------------
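Usage note (editor's addition): the downloader writes the fetched body to stdout and retries with an HTTP Range request after a partial transfer, so a typical invocation would redirect stdout to a file, e.g. (URL hypothetical): ./downloader 'https://example.com/big.tar' > big.tar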
/textline_detect/src/process.cpp:
--------------------------------------------------------------------------------
#include "line_detect.h"
#include "process.h"
#include "prepare.h"
#include "hough_linefind.h"
#include "search_loop.h"
#include "after_search.h"
#include "space_check.h"

#include <cstdio>
#include <cmath>
#include <vector>
#include <iostream>
#include <iterator>
#include <algorithm>

void print_chaininfo(
    const std::vector<charbox> &boxes,
    const std::vector<std::vector<int>> &line_box_chain)
{
    fprintf(stderr, "print_chaininfo\n");
    fprintf(stderr, "****************\n");
    for(int i = 0; i < line_box_chain.size(); i++) {
        fprintf(stderr, " chain %d len %lu\n", i, line_box_chain[i].size());
        int boxid1 = line_box_chain[i].front();
        int boxid2 = line_box_chain[i].back();

        fprintf(stderr, " %f %d x %f y %f w %f h %f - %d x %f y %f w %f h %f\n",
            boxes[boxid1].direction / M_PI * 180,
            boxid1, boxes[boxid1].cx, boxes[boxid1].cy, boxes[boxid1].w, boxes[boxid1].h,
            boxid2, boxes[boxid2].cx, boxes[boxid2].cy, boxes[boxid2].w, boxes[boxid2].h);

        std::copy(line_box_chain[i].begin(), line_box_chain[i].end(), std::ostream_iterator<int>(std::cerr, ","));
        std::cerr << std::endl;
        fprintf(stderr, "=================\n");
        for(int j = 0; j < line_box_chain[i].size(); j++) {
            int boxid = line_box_chain[i][j];
            fprintf(stderr, " %d %d %d, %d %d %d, %f x %f y %f w %f h %f t %d\n",
                i, j, boxid,
                boxes[boxid].idx, boxes[boxid].subidx, boxes[boxid].subtype,
                boxes[boxid].direction / M_PI * 180,
                boxes[boxid].cx, boxes[boxid].cy, boxes[boxid].w, boxes[boxid].h,
                boxes[boxid].subtype);
        }
    }
    fprintf(stderr, "****************\n");
}

void process(
    const std::vector<float> &lineimage,
    const std::vector<float> &sepimage,
    std::vector<charbox> &boxes)
{
    std::vector<int> idimage;
    std::vector<int> idimage_main;
    prepare_id_image(idimage, idimage_main, boxes);

    std::vector<bool> lineblocker;
    make_lineblocker(lineblocker, sepimage);

    auto line_box_chain = linefind(boxes, lineimage, lineblocker);
    // print_chaininfo(boxes, line_box_chain);

    search_loop(boxes, line_box_chain, lineblocker, idimage_main, sepimage);
    // print_chaininfo(boxes, line_box_chain);

    after_search(boxes, line_box_chain, lineblocker, idimage);

    space_chack(boxes);
}
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*.obj
*.exe

make_traindata/data/*font/
make_traindata/data/handwritten/
data/background/
result/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
#   However, in case of collaboration, if having platform-specific dependencies or dependencies
#   having no cross-platform support, pipenv may install dependencies that don't work, or not
#   install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
data/load_font/load_font.obj
--------------------------------------------------------------------------------
/make_traindata/make_traindata1.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3

# pip install -U pillow webdataset matplotlib
import numpy as np
from PIL import Image
import webdataset as wds
from multiprocessing import Pool, Manager
from pathlib import Path
from functools import partial
import os

from render_font.generate_random_txt import get_random_text
samples_per_file = 100

data_path = Path('train_data1')
data_path.mkdir(exist_ok=True)

def get_filepath(train=True):
    if train:
        return str(data_path / 'train%08d.tar')
    else:
        return str(data_path / 'test%08d.tar')

def process(i, semaphore):
    rng = np.random.default_rng()
    semaphore.acquire()
    while True:
        try:
            d = get_random_text(rng)
            if np.count_nonzero(d['image']) == 0:
                continue
            if d['image'].shape[0] >= (1 << 27) or d['image'].shape[1] >= (1 << 27) or d['image'].shape[0] * d['image'].shape[1] >= (1 << 29):
                continue
            d['i'] = i
            w = d['image'].shape[1] // 2 * 2
            h = d['image'].shape[0] // 2 * 2
            d['image'] = d['image'][:h,:w]
            d['sep_image'] = np.asarray(Image.fromarray(d['sep_image']).resize((w // 2, h // 2)))
            d['textline_image'] = np.asarray(Image.fromarray(d['textline_image']).resize((w // 2, h // 2)))
            d['position'] = d['position'].astype(np.float32)
            d['code_list'] = d['code_list'].astype(np.int32)
            return d
        except Exception as e:
            print(e,i,'error')
            continue

def create_data(train=True, count=1):
    if count < 1:
        return
    with Manager() as manager:
        semaphore = manager.Semaphore(1000)
        with wds.ShardWriter(get_filepath(train=train), maxcount=samples_per_file) as sink:
            with Pool(processes=os.cpu_count()*2) as p:
                for d in p.imap_unordered(partial(process, semaphore=semaphore), range(samples_per_file * count)):
                    print(d['i'],samples_per_file * count)
                    sink.write({
                        "__key__": '%014d'%d['i'],
                        "txt": d['str'],
                        "image.png": d['image'],
                        "sepline.png": d['sep_image'],
                        "textline.png": d['textline_image'],
                        "position.npy": d['position'],
                        "code_list.npy": d['code_list'],
                    })
                    semaphore.release()

if __name__=="__main__":
    import sys
    import multiprocessing
    multiprocessing.set_start_method('spawn')

    if len(sys.argv) < 3:
        test_count = 1
        train_count = 1
    else:
        test_count = int(sys.argv[1])
        train_count = int(sys.argv[2])

    create_data(train=False, count=test_count)
    create_data(train=True, count=train_count)
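# A minimal sketch (editor's addition, assuming webdataset's standard reading
# API) of how the shards written above could be read back; the shard name
# below is hypothetical.
#
#   import webdataset as wds
#   ds = wds.WebDataset('train_data1/train00000000.tar').decode('pil')
#   for sample in ds:
#       img = sample['image.png']   # PIL image
#       txt = sample['txt']         # rendered text
#       break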
--------------------------------------------------------------------------------
/textline_detect/src/minpack/enorm.cpp:
--------------------------------------------------------------------------------
#include <cmath>

#define one 1.0
#define zero 0.0
#define rdwarf 3.834e-20
#define rgiant 1.304e19

double enorm(int n, double *x)
{
/*
c     **********
c
c     function enorm
c
c     given an n-vector x, this function calculates the
c     euclidean norm of x.
c
c     the euclidean norm is computed by accumulating the sum of
c     squares in three different sums. the sums of squares for the
c     small and large components are scaled so that no overflows
c     occur. non-destructive underflows are permitted. underflows
c     and overflows do not occur in the computation of the unscaled
c     sum of squares for the intermediate components.
c     the definitions of small, intermediate and large components
c     depend on two constants, rdwarf and rgiant. the main
c     restrictions on these constants are that rdwarf**2 not
c     underflow and rgiant**2 not overflow. the constants
c     given here are suitable for every known computer.
c
c     the function statement is
c
c       double precision function enorm(n,x)
c
c     where
c
c       n is a positive integer input variable.
c
c       x is an input array of length n.
c
c     subprograms called
c
c       fortran-supplied ... dabs,dsqrt
c
c     argonne national laboratory. minpack project. march 1980.
c     burton s. garbow, kenneth e. hillstrom, jorge j. more
c
c     **********
*/
    double s1 = zero;
    double s2 = zero;
    double s3 = zero;
    double x1max = zero;
    double x3max = zero;

    double agiant = rgiant/(double)n;

    for(int i = 0; i < n; i++) {
        double xabs = fabs(x[i]);
        if (xabs <= rdwarf || xabs >= agiant) {
            if (xabs > rdwarf) {
                /*
c
c              sum for large components.
c
                */
                if (xabs > x1max) {
                    s1 = one + s1*(x1max/xabs)*(x1max/xabs);
                    x1max = xabs;
                } else {
                    s1 += (xabs/x1max)*(xabs/x1max);
                }
            }
            else {
                /*
c
c              sum for small components.
c
                */
                if (xabs > x3max) {
                    s3 = one + s3*(x3max/xabs)*(x3max/xabs);
                    x3max = xabs;
                }
                else {
                    if (xabs != zero) {
                        s3 += (xabs/x3max)*(xabs/x3max);
                    }
                }
            }
        }
        else {
            /*
c
c           sum for intermediate components.
c
            */
            s2 += xabs*xabs;
        }
    }

/*
c
c     calculation of norm.
c
*/
    if (s1 != zero) {
        return x1max*sqrt(s1+(s2/x1max)/x1max);
    }
    else {
        if (s2 != zero) {
            if (s2 >= x3max) return sqrt(s2*(one+(x3max/s2)*(x3max*s3)));
            return sqrt(x3max*((s2/x3max)+(x3max*s3)));
        }
        return x3max*sqrt(s3);
    }
}
--------------------------------------------------------------------------------
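Worked example (editor's addition): for x = (3, 4), enorm returns 5. The three-accumulator scheme only matters at extreme magnitudes, e.g. x = (1e200, 1e200), where naively summing squares would overflow to infinity, while the scaled large-component sum gives s1 = 2 with x1max = 1e200 and the correct result 1e200 * sqrt(2) ≈ 1.414e200.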
/process_ocr_coreml.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np
import itertools

from process_ocr_base import OCR_Processer, max_encoderlen, max_decoderlen, decoder_SOT, decoder_EOT, decoder_MSK, modulo_list, calc_predid

class OCR_coreml_Processer(OCR_Processer):
    def __init__(self):
        super().__init__()
        import coremltools as ct

        print('load')
        self.mlmodel_detector = ct.models.MLModel('TextDetector.mlpackage')

        self.mlmodel_transformer_encoder = ct.models.MLModel('TransformerEncoder.mlpackage')
        self.mlmodel_transformer_decoder = ct.models.MLModel('TransformerDecoder.mlpackage')

    def call_detector(self, image_input):
        input_image = Image.fromarray(image_input.squeeze(0).astype(np.uint8), mode="RGB")

        output = self.mlmodel_detector.predict({'image': input_image})
        heatmap = output['heatmap']
        features = output['feature']
        return heatmap, features

    def call_transformer(self, encoder_input):
        key_mask = np.where((encoder_input == 0).all(axis=-1)[:,None,None,:], float("-inf"), 0).astype(np.float32)
        encoder_output = self.mlmodel_transformer_encoder.predict({
            'encoder_input': encoder_input,
            'key_mask': key_mask,
        })['encoder_output']

        decoder_input = np.zeros(shape=(1, max_decoderlen), dtype=np.int32)
        decoder_input[0,:] = decoder_MSK
        rep_count = 8
        for k in range(rep_count):
            output = self.mlmodel_transformer_decoder.predict({
                'encoder_output': encoder_output,
                'decoder_input': decoder_input,
                'key_mask': key_mask,
            })

            listp = []
            listi = []
            for m in modulo_list:
                pred_p1 = output['modulo_%d'%m]
                topi = np.argpartition(-pred_p1, 5, axis=-1)[...,:5]
                topp = np.take_along_axis(pred_p1, topi, axis=-1)
                listp.append(np.transpose(topp, (2,0,1)))
                listi.append(np.transpose(topi, (2,0,1)))

            pred_ids = np.stack([np.stack(x) for x in itertools.product(*listi)])
            pred_p = np.stack([np.stack(x) for x in itertools.product(*listp)])
            pred_ids = np.transpose(pred_ids, (1,0,2,3))
            pred_p = np.transpose(pred_p, (1,0,2,3))
            pred_p = np.exp(np.mean(np.log(np.maximum(pred_p, 1e-10)), axis=0))
            decoder_output = calc_predid(*pred_ids)
            pred_p[decoder_output > 0x3FFFF] = 0
            maxi = np.argmax(pred_p, axis=0)
            decoder_output = np.take_along_axis(decoder_output, maxi[None,...], axis=0)[0]
            pred_p = np.take_along_axis(pred_p, maxi[None,...], axis=0)[0]
            if np.all(pred_p[decoder_output > 0] > 0.99):
                print(f'[{k} early stop]')
                break

            remask = decoder_output > 0x3FFFF
            remask = np.logical_or(remask, pred_p < 0.9)
            if not np.any(remask):
                print(f'---[{k} early stop]---')
                break

            decoder_input[:,:] = np.where(remask, decoder_MSK, decoder_output)

        pred = decoder_output[0]
        return pred
--------------------------------------------------------------------------------
/process_ocr_onnx.py:
--------------------------------------------------------------------------------
import numpy as np
import itertools

from process_ocr_base import OCR_Processer, max_encoderlen, max_decoderlen, decoder_SOT, decoder_EOT, decoder_MSK, modulo_list, calc_predid

class OCR_onnx_Processer(OCR_Processer):
    def __init__(self):
        super().__init__()
        import onnxruntime
        import os

        print('load')
        if os.path.exists("TextDetector.quant.onnx"):
            print('quant')
            onnx_detector = onnxruntime.InferenceSession("TextDetector.quant.onnx")
        else:
            onnx_detector = onnxruntime.InferenceSession("TextDetector.onnx")
        self.onnx_detector = onnx_detector
        self.onnx_transformer_encoder = onnxruntime.InferenceSession("TransformerEncoder.onnx")
        self.onnx_transformer_decoder = onnxruntime.InferenceSession("TransformerDecoder.onnx")

    def call_detector(self, image_input):
        images = (image_input / 255.).transpose(0,3,1,2).astype(np.float32)
        heatmap, features = self.onnx_detector.run(['heatmap','feature'], {'image': images})
        return heatmap, features

    def call_transformer(self, encoder_input):
        key_mask = np.where((encoder_input == 0).all(axis=-1)[:,None,None,:], float("-inf"), 0).astype(np.float32)
        encoder_output, = self.onnx_transformer_encoder.run(['encoder_output'], {'encoder_input': encoder_input.astype(np.float32), 'key_mask': key_mask.astype(np.float32)})

        decoder_input = np.zeros(shape=(1, max_decoderlen), dtype=np.int64)
        decoder_input[0,:] = decoder_MSK
        rep_count = 8
        for k in range(rep_count):
            output = self.onnx_transformer_decoder.run(['modulo_%d'%m for m in modulo_list], {
                'encoder_output': encoder_output,
                'decoder_input': decoder_input,
                'key_mask': key_mask,
            })

            listp = []
            listi = []
            for pred_p1 in output:
                topi = np.argpartition(-pred_p1, 4, axis=-1)[...,:4]
                topp = np.take_along_axis(pred_p1, topi, axis=-1)
                listp.append(np.transpose(topp, (2,0,1)))
                listi.append(np.transpose(topi, (2,0,1)))

            pred_ids = np.stack([np.stack(x) for x in itertools.product(*listi)])
            pred_p = np.stack([np.stack(x) for x in itertools.product(*listp)])
            pred_ids = np.transpose(pred_ids, (1,0,2,3))
            pred_p = np.transpose(pred_p, (1,0,2,3))
            pred_p = np.exp(np.mean(np.log(np.maximum(pred_p, 1e-10)), axis=0))
            decoder_output = calc_predid(*pred_ids)
            pred_p[decoder_output > 0x3FFFF] = 0
            maxi = np.argmax(pred_p, axis=0)
            decoder_output = np.take_along_axis(decoder_output, maxi[None,...], axis=0)[0]
            pred_p = np.take_along_axis(pred_p, maxi[None,...], axis=0)[0]
            if np.all(pred_p[decoder_output > 0] > 0.99):
                print(f'[{k} early stop]')
                break

            remask = decoder_output > 0x3FFFF
            remask = np.logical_or(remask, pred_p < 0.9)
            if not np.any(remask):
                print(f'---[{k} early stop]---')
                break

            decoder_input[:,:] = np.where(remask, decoder_MSK, decoder_output)

        pred = decoder_output[0]
        return pred
--------------------------------------------------------------------------------
/textline_detect/src/minpack/fdjac2.cpp:
--------------------------------------------------------------------------------
#include <cmath>

#define zero 0.0
#define mcheps 2.2204460492503131e-16

#define MAX(a,b) ((a) > (b) ? (a) : (b))

int fdjac2(int (*fcn)(int,int,double *,double *),int m,int n,double *x,double *fvec,double *fjac,int ldfjac,double epsfcn)
{
/*
c     **********
c
c     subroutine fdjac2
c
c     this subroutine computes a forward-difference approximation
c     to the m by n jacobian matrix associated with a specified
c     problem of m functions in n variables.
c
c     the subroutine statement is
c
c       subroutine fdjac2(fcn,m,n,x,fvec,fjac,ldfjac,iflag,epsfcn,wa)
c
c     where
c
c       fcn is the name of the user-supplied subroutine which
c         calculates the functions. fcn must be declared
c         in an external statement in the user calling
c         program, and should be written as follows.
c
c         subroutine fcn(m,n,x,fvec,iflag)
c         integer m,n,iflag
c         double precision x(n),fvec(m)
c         ----------
c         calculate the functions at x and
c         return this vector in fvec.
c         ----------
c         return
c         end
c
c         the value of iflag should not be changed by fcn unless
c         the user wants to terminate execution of fdjac2.
c         in this case set iflag to a negative integer.
c
c       m is a positive integer input variable set to the number
c         of functions.
c
c       n is a positive integer input variable set to the number
c         of variables. n must not exceed m.
c
c       x is an input array of length n.
c
c       fvec is an input array of length m which must contain the
c         functions evaluated at x.
c
c       fjac is an output m by n array which contains the
c         approximation to the jacobian matrix evaluated at x.
c
c       ldfjac is a positive integer input variable not less than m
c         which specifies the leading dimension of the array fjac.
c
c       iflag is an integer variable which can be used to terminate
c         the execution of fdjac2. see description of fcn.
c
c       epsfcn is an input variable used in determining a suitable
c         step length for the forward-difference approximation. this
c         approximation assumes that the relative errors in the
c         functions are of the order of epsfcn. if epsfcn is less
c         than the machine precision, it is assumed that the relative
c         errors in the functions are of the order of the machine
c         precision.
c
c       wa is a work array of length m.
c
c     subprograms called
c
c       user-supplied ...... fcn
c
c       minpack-supplied ... dpmpar
c
c       fortran-supplied ... dabs,dmax1,dsqrt
c
c     argonne national laboratory. minpack project. march 1980.
c     burton s. garbow, kenneth e. hillstrom, jorge j. more
c
c     **********
*/
    double *wa = new double[m];
    double eps = sqrt(MAX(epsfcn,mcheps));
    for(int j = 0; j < n; j++) {
        double temp = x[j];
        double h = eps*fabs(temp);
        if (h == zero) h = eps;
        x[j]= temp + h;
        int iflag;
        if((iflag = fcn(m,n,x,wa)) == 0) {
            for(int i = 0; i < m; i++) {
                fjac[i+j*ldfjac] = (wa[i] - fvec[i])/h;
            }
        }
        x[j] = temp;
        if(iflag != 0) {
            delete[] wa;
            return iflag;
        }
    }
    delete[] wa;
    return 0;
}
--------------------------------------------------------------------------------
97 | points = np.array(points) 98 | plt.plot(points[:,0], points[:,1],color='magenta') 99 | if pos['p_code4'] > 0.5: 100 | points = [ 101 | [cx - w / 2 + 2, cy - h / 2 + 2], 102 | [cx + w / 2 - 2, cy - h / 2 + 2], 103 | [cx + w / 2 - 2, cy + h / 2 - 2], 104 | [cx - w / 2 + 2, cy + h / 2 - 2], 105 | [cx - w / 2 + 2, cy - h / 2 + 2], 106 | ] 107 | points = np.array(points) 108 | plt.plot(points[:,0], points[:,1],color='blue') 109 | 110 | if pos['text']: 111 | if pos['p_code1'] > 0.5: 112 | c = 'green' 113 | else: 114 | c = 'blue' 115 | plt.gca().text(cx, cy, pos['text'], fontsize=28, color=c, fontproperties=fprop) 116 | 117 | plt.show() 118 | -------------------------------------------------------------------------------- /make_traindata/render_font/get_aozora.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import urllib.parse 4 | import urllib.request 5 | import os 6 | import zipfile 7 | import io 8 | import csv 9 | import re 10 | from html.parser import HTMLParser 11 | 12 | code_list = {} 13 | with open('data/codepoints.csv') as f: 14 | reader = csv.reader(f) 15 | for row in reader: 16 | d1,d2,d3 = row[0].split('-') 17 | d1 = int(d1) 18 | d2 = int(d2) 19 | d3 = int(d3) 20 | c = row[1] 21 | c = int(c, 16) 22 | if c > 0x10FFFF: 23 | txt = chr((c & 0xFFFF0000) >> 16) + chr((c & 0xFFFF)) 24 | else: 25 | txt = chr(c) 26 | code_list['%d-%02d-%02d'%(d1,d2,d3)] = txt 27 | 28 | def get_aozora_urls(): 29 | aozora_csv_url = 'https://www.aozora.gr.jp/index_pages/list_person_all_extended_utf8.zip' 30 | 31 | xhtml_urls = [] 32 | html = urllib.request.urlopen(aozora_csv_url) 33 | with zipfile.ZipFile(io.BytesIO(html.read())) as myzip: 34 | with myzip.open('list_person_all_extended_utf8.csv') as myfile: 35 | reader = csv.reader(io.TextIOWrapper(myfile)) 36 | idx = -1 37 | for row in reader: 38 | if idx < 0: 39 | idx = [i for i, x in enumerate(row) if 'URL' in x] 40 | idx = [i for i in idx if 'HTML' in row[i]] 41 | if len(idx) == 0: 42 | exit() 43 | idx = idx[0] 44 | continue 45 | if row[idx].startswith('https://www.aozora.gr.jp/cards/'): 46 | xhtml_urls.append(row[idx]) 47 | return sorted(set(xhtml_urls)) 48 | 49 | class MyHTMLParser(HTMLParser): 50 | def __init__(self, *args, **kwargs): 51 | super().__init__(*args, **kwargs) 52 | self.main = False 53 | self.count = 0 54 | self.startpos = (-1,-1) 55 | self.endpos = (-1,-1) 56 | 57 | def handle_starttag(self, tag, attrs): 58 | if tag == 'div': 59 | if self.main: 60 | self.count += 1 61 | elif ('class', 'main_text') in attrs: 62 | self.main = True 63 | self.startpos = self.getpos() 64 | 65 | def handle_endtag(self, tag): 66 | if tag == 'div': 67 | if self.main: 68 | if self.count == 0: 69 | self.endpos = self.getpos() 70 | else: 71 | self.count -= 1 72 | 73 | def get_contents(url): 74 | html = urllib.request.urlopen(url, timeout=10) 75 | contents = html.read().decode('cp932') 76 | parser = MyHTMLParser() 77 | parser.feed(contents) 78 | maintext = [] 79 | for lineno, line in enumerate(contents.splitlines()): 80 | if parser.startpos[0] == lineno + 1: 81 | maintext.append(line[parser.startpos[1]:]) 82 | elif parser.startpos[0] < lineno + 1 <= parser.endpos[0]: 83 | if parser.endpos[0] == lineno + 1: 84 | if parser.endpos[1] == 0: 85 | pass 86 | else: 87 | maintext.append(line[:parser.endpos[1]]) 88 | else: 89 | maintext.append(line) 90 | maintext = '\n'.join(maintext) 91 | maintext = re.sub(r'／″＼', '〴〵', maintext) 92 | maintext = re.sub(r'／＼', '〳〵', maintext) 93 | maintext = re.sub(r'<ruby><rb>(.*?)</rb>.*?<rt>(.*?)</rt>.*?</ruby>', '\uFFF9\\1\uFFFA\\2\uFFFB', maintext)
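# U+FFF9/U+FFFA/U+FFFB are the Unicode interlinear annotation anchor/separator/terminator characters; ruby text is carried inline in this form and converted back to Aozora or HTML notation by decode_ruby() in util_func.py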
94 | m = True 95 | while m: 96 | m = re.search(r'<img[^>]*?gaiji/\d+-\d+/(\d+-\d+-\d+)\.png[^>]*?>', maintext) 97 | if m: 98 | maintext = maintext[:m.start()] + code_list[m.group(1)] + maintext[m.end():] 99 | maintext = re.sub(r'<span class="notes">.*?</span>', r'', maintext) 100 | maintext = re.sub(r'<[^>]*?>', r'', maintext) 101 | return maintext 102 | 103 | if __name__ == '__main__': 104 | from util_func import decode_ruby 105 | 106 | urls = get_aozora_urls() 107 | for u in urls: 108 | print(u) 109 | print(decode_ruby(get_contents(u))) -------------------------------------------------------------------------------- /dataset/data_fixdata.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import glob 3 | import os 4 | import json 5 | from PIL import Image 6 | import numpy as np 7 | 8 | try: 9 | from pillow_heif import register_heif_opener 10 | register_heif_opener() 11 | except ImportError: 12 | pass 13 | 14 | from .processer import process2 15 | Image.MAX_IMAGE_PIXELS = 1000000000 16 | 17 | rng = np.random.default_rng() 18 | 19 | class FixDataDataset(torch.utils.data.Dataset): 20 | def __init__(self, data_path, repeat_count): 21 | super().__init__() 22 | self.data_path = data_path 23 | self.repeat_count = repeat_count 24 | 25 | self.jsonfiles = sorted(glob.glob(os.path.join(data_path, '*.json'))) 26 | self.imagefiles = [os.path.splitext(f)[0] for f in self.jsonfiles] 27 | self.sepsfiles = [f+'.seps.png' for f in self.imagefiles] 28 | self.linesfiles = [f+'.lines.png' for f in self.imagefiles] 29 | 30 | self.jsons = [] 31 | for jsonfile in self.jsonfiles: 32 | with open(jsonfile, 'r', encoding='utf-8') as file: 33 | self.jsons.append(json.load(file)) 34 | 35 | self.positions = [] 36 | self.codelists = [] 37 | for data in self.jsons: 38 | position = np.zeros(shape=(len(data['textbox']), 4), dtype=np.float32) 39 | codelist = np.zeros(shape=(len(data['textbox']), 2), dtype=np.int32) 40 | for i, pos in enumerate(data['textbox']): 41 | cx = pos['cx'] 42 | cy = pos['cy'] 43 | w = pos['w'] 44 | h = pos['h'] 45 | position[i,0] = cx 46 | position[i,1] = cy 47 | position[i,2] = w 48 | position[i,3] = h 49 | text = pos['text'] 50 | if text is not None: 51 | c = int.from_bytes(text.encode("utf-32-le"), byteorder='little') 52 | else: 53 | c = 0 54 | code1 = 1 if pos['p_code1'] > 0.5 else 0 55 | code2 = 2 if pos['p_code2'] > 0.5 else 0 56 | code4 = 4 if pos['p_code4'] > 0.5 else 0 57 | code8 = 8 if pos['p_code8'] > 0.5 else 0 58 | code = code1 + code2 + code4 + code8 59 | assert text is None or len(text) == 1, f"{text}{c}" 60 | codelist[i,0] = c 61 | codelist[i,1] = code 62 | self.positions.append(position) 63 | self.codelists.append(codelist) 64 | 65 | def __len__(self): 66 | return len(self.jsonfiles) * self.repeat_count 67 | 68 | def __getitem__(self, idx): 69 | idx = idx % len(self.jsonfiles) 70 | 71 | im0 = np.asarray(Image.open(self.imagefiles[idx]).convert('RGB')) 72 | seps = np.asarray(Image.open(self.sepsfiles[idx])) 73 | lines = np.asarray(Image.open(self.linesfiles[idx])) 74 | position = self.positions[idx] 75 | codelist = self.codelists[idx] 76 | 77 | image, mapimage, indexmap = process2(im0, lines, seps, position, codelist) 78 | return image, mapimage, indexmap 79 | 80 | if __name__ == '__main__': 81 | import matplotlib.pyplot as plt 82 | from torchvision.transforms import ColorJitter 83 | 84 | transform = ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5) 85 | 86 | dataset = FixDataDataset('train_data2',100) 87 | dataloader =
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True) 88 | 89 | for sample in dataloader: 90 | image, labelmap, idmap = sample 91 | image = transform(image) 92 | 93 | plt.figure() 94 | if len(image[0].shape) > 2: 95 | plt.imshow(image[0].permute(1,2,0)) 96 | else: 97 | plt.imshow(image[0]) 98 | 99 | plt.figure() 100 | plt.subplot(2,4,1) 101 | if len(image[0].shape) > 2: 102 | plt.imshow(image[0].permute(1,2,0)) 103 | else: 104 | plt.imshow(image[0]) 105 | for i in range(5): 106 | plt.subplot(2,4,2+i) 107 | plt.imshow(labelmap[0,i]) 108 | plt.subplot(2,4,7) 109 | plt.imshow(idmap[0,0]) 110 | plt.subplot(2,4,8) 111 | plt.imshow(idmap[0,1]) 112 | plt.show() 113 | -------------------------------------------------------------------------------- /make_traindata/make_traindata3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import glob 5 | from scipy.ndimage import gaussian_filter 6 | from PIL import Image 7 | import os 8 | 9 | from render_font.generate_random_txt import get_random_text, get_random_text2 10 | from processer3 import random_background, random_mono, random_single, random_double, process3 11 | 12 | if os.path.exists('TextDetector.mlpackage'): 13 | print('coreml') 14 | from process_coreml import call_model 15 | elif os.path.exists('TextDetector.onnx') or os.path.exists('TextDetector.quant.onnx') or os.path.exists('TextDetector.infer.onnx'): 16 | print('onnx') 17 | from process_onnx import call_model 18 | else: 19 | print('torch') 20 | from process_torch import call_model 21 | 22 | Image.MAX_IMAGE_PIXELS = 1000000000 23 | 24 | imagelist = glob.glob('data/background/*', recursive=True) 25 | rng = np.random.default_rng() 26 | 27 | def random_salt(x, s, prob=0.05): 28 | sizex = x.shape[1] 29 | sizey = x.shape[0] 30 | s = min(max(1, int(s / 4)), rng.integers(1, 16)) 31 | shape = ((sizey + s)//s, (sizex + s)//s) 32 | noise = rng.choice([0,1,np.nan], p=[prob / 2, 1 - prob, prob / 2], size=shape).astype(x.dtype) 33 | noise = np.repeat(noise, s, axis=0) 34 | noise = np.repeat(noise, s, axis=1) 35 | noise = noise[:sizey, :sizex] 36 | return np.nan_to_num(x * noise, nan=1) 37 | 38 | def random_distortion(im, s): 39 | if rng.random() < 0.3: 40 | alpha = min(0.4 * rng.random(), 20 / max(1,s)) 41 | im += alpha * rng.normal(size=im.shape) 42 | im = np.clip(im, 0, 1) 43 | if rng.random() < 0.3: 44 | sigma = min(s / 8, 1.5*rng.random()) 45 | im = gaussian_filter(im, sigma=sigma) 46 | im = np.clip(im, 0, 1) 47 | elif rng.random() < 0.3: 48 | blurred = gaussian_filter(im, sigma=5.) 49 | im = im + 10. 
* rng.random() * (im - blurred) 50 | im = np.clip(im, 0, 1) 51 | return im 52 | 53 | def transforms3(x1,minsize): 54 | if rng.random() < 0.2: 55 | x1 = random_salt(x1, minsize, prob=0.2 * rng.random()) 56 | 57 | if rng.random() < 0.3: 58 | bgimage = rng.choice(imagelist) 59 | bgimg = np.asarray(Image.open(bgimage).convert("RGBA"))[:,:,:3] 60 | im = random_background(x1, bgimg) 61 | elif rng.random() < 0.5: 62 | im = random_mono(x1) 63 | elif rng.random() < 0.5: 64 | im = random_single(x1) 65 | else: 66 | im = random_double(x1) 67 | return random_distortion(im, minsize) 68 | 69 | def save_value(code, value, vert): 70 | value = np.expand_dims(value.astype(np.float16), axis=0) 71 | base_path = 'code_features' 72 | os.makedirs(base_path, exist_ok=True) 73 | if vert == 0: 74 | filename = os.path.join(base_path,'h%08x.npy'%code) 75 | else: 76 | filename = os.path.join(base_path,'v%08x.npy'%code) 77 | if os.path.exists(filename): 78 | data = np.load(filename) 79 | value = np.concatenate([data,value], axis=0) 80 | np.save(filename, value) 81 | 82 | def proc(): 83 | while True: 84 | try: 85 | if rng.uniform() < 0.2: 86 | d = get_random_text(rng) 87 | else: 88 | d = get_random_text2(rng) 89 | except Exception as e: 90 | print(e,'error') 91 | continue 92 | if np.count_nonzero(d['image']) == 0: 93 | continue 94 | if d['image'].shape[0] >= (1 << 27) or d['image'].shape[1] >= (1 << 27) or d['image'].shape[0] * d['image'].shape[1] >= (1 << 29): 95 | continue 96 | 97 | image, position, minsize = process3(d['image'], d['position'].astype(np.float32)) 98 | image = transforms3(image, minsize) 99 | 100 | locations, glyphfeatures, vert = call_model(image) 101 | 102 | for i, loc in enumerate(locations): 103 | cx = loc[1] 104 | cy = loc[2] 105 | w = loc[3] 106 | h = loc[4] 107 | 108 | for j, (pcx, pcy, pw, ph) in enumerate(position): 109 | dist = np.sqrt((cx - pcx) ** 2 + (cy - pcy) ** 2) 110 | if dist < min(w/2, h/2): 111 | code = d['code_list'][j,0] 112 | v = vert[i] 113 | print(code, cx, cy, w, h, v) 114 | save_value(code, glyphfeatures[i], v) 115 | break 116 | 117 | if __name__=='__main__': 118 | proc() -------------------------------------------------------------------------------- /util_func.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import torch 4 | 5 | modulo_list = [1091,1093,1097] 6 | width = 768 7 | height = 768 8 | scale = 4 9 | feature_dim = 100 10 | 11 | def gaussian(x,a,x0,sigma): 12 | return a*np.exp(-(x-x0)**2/(2*sigma**2)) 13 | 14 | def sigmoid(x): 15 | return (np.tanh(x / 2) + 1) / 2 16 | 17 | def softmax(x): 18 | mx = np.max(x, axis=-1, keepdims=True) 19 | numerator = np.exp(x - mx) 20 | denominator = np.sum(numerator, axis=-1, keepdims=True) 21 | return numerator/denominator 22 | 23 | def calcHist(im): 24 | agg = 1 25 | rHist, bins = np.histogram(im[...,0], 256 // agg, (0.,255.)) 26 | gHist, bins = np.histogram(im[...,1], 256 // agg, (0.,255.)) 27 | bHist, bins = np.histogram(im[...,2], 256 // agg, (0.,255.)) 28 | 29 | maxPeakDiff = -1 30 | for hist in [rHist, gHist, bHist]: 31 | y = np.array(hist) 32 | x = np.linspace(0.,255.,len(y)) 33 | 34 | if np.sum(y) == 0: 35 | continue 36 | 37 | idx = np.argsort(-y) 38 | mu_y = x[idx[0]] 39 | mean_y = np.sum(x * y) / np.sum(y) 40 | 41 | if mu_y > mean_y: 42 | peak1 = y[idx[0]:] 43 | x1 = x[idx[0]:] 44 | peak1 = np.concatenate([peak1[::-1],peak1[1:]], axis=0) 45 | x1 = np.concatenate([(2 * x1[0] - x1[::-1]),x1[1:]], axis=0) 46 | else: 47 | peak1 = y[:idx[0]+1] 48 | x1 =
x[:idx[0]+1] 49 | peak1 = np.concatenate([peak1[:-1],peak1[::-1]], axis=0) 50 | x1 = np.concatenate([x1[:-1],(x1 + x1[-1])], axis=0) 51 | 52 | mu = np.sum(x1 * peak1) / np.sum(peak1) 53 | sigma = np.sqrt(np.sum((x1 - mu)**2 * peak1) / np.sum(peak1)) 54 | fixmax = np.max(y[np.bitwise_and(mu + 10 > x, x > mu - 10)]) 55 | 56 | neg_peak = gaussian(x, fixmax, mu, sigma + 10) 57 | fixy = y - neg_peak 58 | fixy[fixy < 0] = 0 59 | 60 | if np.sum(fixy) == 0: 61 | continue 62 | 63 | fix_diff = np.sum(np.abs(x - mu) * fixy) / np.sum(fixy) 64 | idx = np.argsort(-fixy) 65 | fix_maxx = np.abs(x[idx[0]] - mu) 66 | 67 | maxPeakDiff = max(maxPeakDiff, fix_diff, fix_maxx) 68 | 69 | if False: 70 | import matplotlib.pyplot as plt 71 | plt.subplot(2,1,1) 72 | plt.plot(x,y) 73 | plt.plot(x,gaussian(x, fixmax, mu, sigma + 10)) 74 | plt.subplot(2,1,2) 75 | plt.plot(x,fixy) 76 | plt.vlines(mu, *plt.ylim(), 'r') 77 | plt.vlines(np.sum(x * fixy) / np.sum(fixy), *plt.ylim(), 'g') 78 | plt.show() 79 | 80 | return maxPeakDiff 81 | 82 | def pow_mod(a, b, n): 83 | x = 1 84 | y = a 85 | while b > 0: 86 | if b % 2 == 1: 87 | x = (x * y) % n 88 | y = (y * y) % n 89 | b //= 2 90 | return x % n 91 | 92 | def calc_predid(*args): 93 | m = modulo_list 94 | b = args 95 | assert(len(m) == len(b)) 96 | t = [] 97 | 98 | for k in range(len(m)): 99 | u = 0 100 | for j in range(k): 101 | if torch.is_tensor(t[j]): 102 | w = t[j].clone() 103 | elif isinstance(t[j],np.ndarray): 104 | w = np.array(t[j]) 105 | else: 106 | w = t[j] 107 | for i in range(j): 108 | w *= m[i] 109 | u += w 110 | tk = (b[k] - u) % m[k] 111 | for j in range(k): 112 | tk *= pow(m[j], m[k]-2, m[k]) 113 | #tk *= pow(m[j], -1, m[k]) 114 | tk = tk % m[k] 115 | t.append(tk) 116 | x = 0 117 | for k in range(len(t)): 118 | w = t[k] 119 | for i in range(k): 120 | w *= m[i] 121 | x += w 122 | mk = 1 123 | for k in range(len(m)): 124 | mk *= m[k] 125 | x = x % mk 126 | return x 127 |
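# quick sanity check of calc_predid's Chinese-remainder reconstruction over the three prime moduli: calc_predid(12354 % 1091, 12354 % 1093, 12354 % 1097) == calc_predid(353, 331, 287) == 12354 == ord('あ')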
128 | def decode_ruby(text, outtype='aozora'): 129 | if outtype == 'aozora': 130 | text = re.sub('\uFFF9(.*?)\uFFFA(.*?)\uFFFB',r'｜\1《\2》', text) 131 | elif outtype == 'html': 132 | text = re.sub('\uFFF9(.*?)\uFFFA(.*?)\uFFFB',r'<ruby>\1<rp>(</rp><rt>\2</rt><rp>)</rp></ruby>', text) 133 | elif outtype == 'noruby': 134 | text = re.sub('\uFFF9(.*?)\uFFFA(.*?)\uFFFB',r'\1', text) 135 | return text 136 | 137 | def encode_rubyhtml(text): 138 | text = re.sub('<ruby>(.*?)<rp>\\(</rp><rt>(.*?)</rt><rp>\\)</rp></ruby>', '\uFFF9\\1\uFFFA\\2\uFFFB', text) 139 | return text 140 | -------------------------------------------------------------------------------- /plot_json.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import matplotlib 4 | matplotlib.use("Agg") 5 | from matplotlib.font_manager import FontProperties 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from PIL import Image 9 | import json 10 | 11 | try: 12 | from pillow_heif import register_heif_opener 13 | register_heif_opener() 14 | except ImportError: 15 | pass 16 | 17 | def plot1(target_file): 18 | im0 = Image.open(target_file).convert('RGB') 19 | with open(target_file+'.json', 'r', encoding='utf-8') as file: 20 | data = json.load(file) 21 | 22 | fig = plt.figure(dpi=100, figsize=(im0.width / 100, im0.height / 100)) 23 | plt.imshow(im0) 24 | fig.subplots_adjust(left=0, right=1, bottom=0, top=1) 25 | 26 | fprop = FontProperties(fname='data/jpfont/NotoSerifJP-Regular.otf') 27 | for box in data['box']: 28 | cx = box['cx'] 29 | cy = box['cy'] 30 | w = box['w'] 31 | h = box['h'] 32 | text = box['text'] 33 | blockidx = box['blockidx'] 34 | lineidx = box['lineidx'] 35 | subidx = box['subidx'] 36 | ruby = box['ruby'] 37 | rubybase = box['rubybase'] 38 | emphasis = box['emphasis'] 39 | vertical = box['vertical'] 40 | 41 | points = [ 42 | [cx - w / 2, cy - h / 2], 43 | [cx + w / 2, cy - h / 2], 44 | [cx + w / 2, cy + h / 2], 45 | [cx - w / 2, cy + h / 2], 46 | [cx - w / 2, cy - h / 2], 47 | ] 48 | points = np.array(points) 49 | if vertical == 0: 50 | plt.plot(points[:,0], points[:,1],color='cyan') 51 | else: 52 | plt.plot(points[:,0], points[:,1],color='magenta') 53 | 54 | points = [ 55 | [cx - w / 2 - 1, cy - h / 2 - 1], 56 | [cx + w / 2 + 1, cy - h / 2 - 1], 57 | [cx + w / 2 + 1, cy + h / 2 + 1], 58 | [cx - w / 2 - 1, cy + h / 2 + 1], 59 | [cx - w / 2 - 1, cy - h / 2 - 1], 60 | ] 61 | points = np.array(points) 62 | if ruby == 1: 63 | plt.plot(points[:,0], points[:,1],color='green') 64 | elif rubybase == 1: 65 | plt.plot(points[:,0], points[:,1],color='yellow') 66 | 67 | points = [ 68 | [cx - w / 2 + 1, cy - h / 2 + 1], 69 | [cx + w / 2 - 1, cy - h / 2 + 1], 70 | [cx + w / 2 - 1, cy + h / 2 - 1], 71 | [cx - w / 2 + 1, cy + h / 2 - 1], 72 | [cx - w / 2 + 1, cy - h / 2 + 1], 73 | ] 74 | points = np.array(points) 75 | if emphasis == 1: 76 | plt.plot(points[:,0], points[:,1],color='blue') 77 | 78 | plt.gca().text(cx - w/2, cy - h/2, text, fontsize=max(w,h)*0.5, color='blue', fontproperties=fprop, ha='left', va='top') 79 | 80 | plt.savefig(target_file+'.boxplot.png') 81 | plt.close('all') 82 | 83 | def plot2(target_file): 84 | im0 = Image.open(target_file).convert('RGB') 85 | with open(target_file+'.json', 'r', encoding='utf-8') as file: 86 | data = json.load(file) 87 | 88 | fig = plt.figure(dpi=100, figsize=(im0.width / 100, im0.height / 100)) 89 | plt.imshow(im0) 90 | fig.subplots_adjust(left=0, right=1, bottom=0, top=1) 91 | 92 | fprop = FontProperties(fname='data/jpfont/NotoSerifJP-Regular.otf') 93 | for line in data['line']: 94 | x1 = line['x1'] 95 | y1 = line['y1'] 96 | x2 = line['x2'] 97 | y2 = line['y2'] 98 | text = line['text'] 99 | blockidx = line['blockidx'] 100 | lineidx = line['lineidx'] 101 | 102 | size, vertical = 0, 0 103 | for box in data['box']: 104 | if blockidx == box['blockidx'] and lineidx == box['lineidx']: 105 | vertical = box['vertical'] 106 | size = max(size, max(box['w'], box['h'])*0.5) 107 | 108 | points = [ 109 | [x1, y1], 110 | [x2, y1], 111 | [x2, y2], 112 | [x1, y2], 113 | [x1, y1], 114 | ] 115 | points = np.array(points) 116 | if vertical == 0: 117 | plt.plot(points[:,0], points[:,1],color='cyan') 118 | rotation = 0 119 | plt.gca().text(x1, y2, text, fontsize=size, color='blue', fontproperties=fprop, rotation=rotation, ha='left', va='top') 120 | else: 121 | plt.plot(points[:,0], points[:,1],color='magenta') 122 | rotation = 270 123 | plt.gca().text(x1, y1, text, fontsize=size, color='blue', fontproperties=fprop, rotation=rotation, ha='right', va='top') 124 | 125 | 126 | plt.savefig(target_file+'.lineplot.png') 127 | plt.close('all') 128 | 129 | if __name__=='__main__': 130 | import sys 131 | 132 | plot1(sys.argv[1]) 133 | plot2(sys.argv[1]) -------------------------------------------------------------------------------- /quantize1_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from onnxruntime.quantization import quantize, StaticQuantConfig, CalibrationDataReader, QuantType, QuantFormat 3 | from torch.utils.data import DataLoader 4 | from dataset.data_detector import get_dataset 5 | 6 | class QuantizationDataReader(CalibrationDataReader):
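# feeds batches from the evaluation split to ONNX Runtime's static quantizer so it can observe activation ranges; get_next() stops after 200 batches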
7 | def __init__(self): 8 | 9 | dataset = get_dataset(train=False) 10 | # dataloader = DataLoader(dataset, batch_size=8, num_workers=8) 11 | self.torch_dl = DataLoader(dataset, batch_size=1) 12 | 13 | self.enum_data = iter(self.torch_dl) 14 | self.count = 0 15 | 16 | def to_numpy(self, pt_tensor): 17 | return pt_tensor.detach().cpu().numpy() if pt_tensor.requires_grad else pt_tensor.cpu().numpy() 18 | 19 | def get_next(self): 20 | print(self.count) 21 | self.count += 1 22 | if self.count > 200: 23 | return None 24 | batch = next(self.enum_data, None) 25 | if batch is not None: 26 | return {'image': self.to_numpy(batch[0].float())} 27 | else: 28 | return None 29 | 30 | def rewind(self): 31 | self.enum_data = iter(self.torch_dl) 32 | 33 | 34 | def optimize1(nodes_to_exclude=None): 35 | qdr = QuantizationDataReader() 36 | 37 | config = StaticQuantConfig(qdr, 38 | quant_format=QuantFormat.QOperator, 39 | activation_type=QuantType.QUInt8, 40 | weight_type=QuantType.QInt8, 41 | nodes_to_exclude=nodes_to_exclude, 42 | extra_options={ 43 | 'CalibMovingAverage': True, 44 | }) 45 | quantize('TextDetector.pre.onnx', 46 | 'TextDetector.quant.onnx', 47 | config) 48 | 49 | def convert2(): 50 | import onnx 51 | model = onnx.load("TextDetector.quant.onnx") 52 | 53 | model.graph.input[0].type.tensor_type.elem_type = 10 54 | model.graph.output[0].type.tensor_type.elem_type = 10 55 | model.graph.output[1].type.tensor_type.elem_type = 10 56 | 57 | cast_node = onnx.helper.make_node(op_type='Cast', name='cast_'+model.graph.input[0].name, inputs=[model.graph.input[0].name], outputs=['cast_'+model.graph.input[0].name], to=1) 58 | 59 | node = [node for node in model.graph.node if model.graph.input[0].name in node.input][0] 60 | node.input[node.input.index(model.graph.input[0].name)] = 'cast_'+model.graph.input[0].name 61 | model.graph.node.insert(0, cast_node) 62 | 63 | cast_node = onnx.helper.make_node(op_type='Cast', name='cast_'+model.graph.output[0].name, inputs=['cast_'+model.graph.output[0].name], outputs=[model.graph.output[0].name], to=10) 64 | 65 | node = [node for node in model.graph.node if model.graph.output[0].name in node.output][0] 66 | node.output[0] = 'cast_'+model.graph.output[0].name 67 | model.graph.node.insert(model.graph.node.index(node)+1, cast_node) 68 | 69 | cast_node = onnx.helper.make_node(op_type='Cast', name='cast_'+model.graph.output[1].name, inputs=['cast_'+model.graph.output[1].name], outputs=[model.graph.output[1].name], to=10) 70 | 71 | node = [node for node in model.graph.node if model.graph.output[1].name in node.output][0] 72 | node.output[0] = 'cast_'+model.graph.output[1].name 73 | model.graph.node.insert(model.graph.node.index(node)+1, cast_node) 74 | 75 | graph = onnx.helper.make_graph(model.graph.node, model.graph.name, model.graph.input, model.graph.output, model.graph.initializer) 76 | info_model = onnx.helper.make_model(graph, opset_imports=model.opset_import) 77 | model_fixed = onnx.shape_inference.infer_shapes(info_model) 78 | 79 | onnx.checker.check_model(model_fixed) 80 | onnx.save(model_fixed, 'TextDetector.quant.fp16.onnx') 81 | 82 | if __name__ == "__main__": 83 | from onnxruntime.quantization.shape_inference import quant_pre_process 84 | from onnx import shape_inference 85 | import onnx 86 | import os 87 | 88 | if os.path.exists('TextDetector.quant.onnx'): 89 | os.remove('TextDetector.quant.onnx') 90 | if os.path.exists('TextDetector.pre.onnx'): 91 | os.remove('TextDetector.pre.onnx') 92 | 93 | quant_pre_process('TextDetector.onnx', 'TextDetector.pre.onnx', skip_symbolic_shape=True)
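# the walk below collects the nodes feeding the 'heatmap' output back to the nearest Conv (inclusive), plus any node that writes 'feature', and passes them as nodes_to_exclude so both output heads stay unquantized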
94 | 95 | model = onnx.load('TextDetector.pre.onnx') 96 | model = shape_inference.infer_shapes(model) 97 | outputs = [o.name for o in model.graph.output] 98 | nodes_to_exclude = [] 99 | for node in model.graph.node: 100 | if 'feature' in node.output: 101 | nodes_to_exclude.append(node.name) 102 | 103 | outputs = ['heatmap'] 104 | while outputs: 105 | next_input = [] 106 | for output in outputs: 107 | for node in model.graph.node: 108 | if output in node.output: 109 | nodes_to_exclude.append(node.name) 110 | if node.op_type != 'Conv': 111 | next_input += node.input 112 | outputs = list(set(next_input)) 113 | 114 | nodes_to_exclude = list(set(nodes_to_exclude)) 115 | print(nodes_to_exclude) 116 | 117 | optimize1(nodes_to_exclude) 118 | 119 | convert2() 120 | 121 | -------------------------------------------------------------------------------- /dataset/data_detector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import webdataset as wds 3 | from torch.utils.data import DataLoader 4 | import glob 5 | import os 6 | import numpy as np 7 | from PIL import Image 8 | from scipy.ndimage import gaussian_filter 9 | 10 | from .processer import random_background, random_mono, random_single, random_double, process 11 | 12 | Image.MAX_IMAGE_PIXELS = 1000000000 13 | 14 | rng = np.random.default_rng() 15 | imagelist = glob.glob('data/background/*', recursive=True) 16 | 17 | def random_salt(x, s, prob=0.05): 18 | sizex = x.shape[1] 19 | sizey = x.shape[0] 20 | s = min(max(1, int(s / 4)), rng.integers(1, 16)) 21 | shape = ((sizey + s)//s, (sizex + s)//s) 22 | noise = rng.choice([0,1,np.nan], p=[prob / 2, 1 - prob, prob / 2], size=shape).astype(x.dtype) 23 | noise = np.repeat(noise, s, axis=0) 24 | noise = np.repeat(noise, s, axis=1) 25 | noise = noise[:sizey, :sizex] 26 | return np.nan_to_num(x * noise, nan=1) 27 | 28 | def random_distortion(im, s): 29 | if rng.random() < 0.3: 30 | alpha = min(0.4 * rng.random(), 20 / max(1,s)) 31 | im += alpha * rng.normal(size=im.shape) 32 | im = np.clip(im, 0, 1) 33 | if rng.random() < 0.3: 34 | sigma = min(s / 8, 1.5*rng.random()) 35 | im = gaussian_filter(im, sigma=sigma) 36 | im = np.clip(im, 0, 1) 37 | elif rng.random() < 0.3: 38 | blurred = gaussian_filter(im, sigma=5.) 39 | im = im + 10.
* rng.random() * (im - blurred) 40 | im = np.clip(im, 0, 1) 41 | return im 42 | 43 | def transforms3(x): 44 | x1,x2,x3,x4 = x 45 | if rng.random() < 0.2: 46 | x1 = random_salt(x1, x4, prob=0.2 * rng.random()) 47 | 48 | if rng.random() < 0.3: 49 | bgimage = rng.choice(imagelist) 50 | bgimg = np.asarray(Image.open(bgimage).convert("RGBA"))[:,:,:3] 51 | im = random_background(x1, bgimg) 52 | elif rng.random() < 0.5: 53 | im = random_mono(x1) 54 | elif rng.random() < 0.5: 55 | im = random_single(x1) 56 | else: 57 | im = random_double(x1) 58 | return random_distortion(im, x4), x2, x3 59 | 60 | 61 | def get_dataset(train=True, calib=False): 62 | local_disk = False 63 | downloader = os.path.join(os.path.dirname(__file__), 'downloader') 64 | if calib: 65 | if local_disk: 66 | shard_pattern = 'train_data1/test00000000.tar' 67 | else: 68 | shard_pattern = 'pipe:%s https://huggingface.co/datasets/lithium0003/findtextCenterNet_dataset/resolve/main/train_data1/test00000000.tar' 69 | shard_pattern = shard_pattern%(downloader) 70 | else: 71 | if train: 72 | if local_disk: 73 | shard_pattern = 'train_data1/train{00000000..00001023}.tar' 74 | else: 75 | shard_pattern = 'pipe:%s https://huggingface.co/datasets/lithium0003/findtextCenterNet_dataset/resolve/main/train_data1/train{00000000..00001023}.tar' 76 | shard_pattern = shard_pattern%(downloader) 77 | else: 78 | if local_disk: 79 | shard_pattern = 'train_data1/test{00000000..00000063}.tar' 80 | else: 81 | shard_pattern = 'pipe:%s https://huggingface.co/datasets/lithium0003/findtextCenterNet_dataset/resolve/main/train_data1/test{00000000..00000063}.tar' 82 | shard_pattern = shard_pattern%(downloader) 83 | dataset = ( 84 | wds.WebDataset(shard_pattern, shardshuffle=100) 85 | .shuffle(1000) 86 | .decode('l8') 87 | .rename( 88 | image='image.png', 89 | position='position.npy', 90 | textline='textline.png', 91 | sepline='sepline.png', 92 | codelist='code_list.npy', 93 | ) 94 | .to_tuple('image','textline','sepline','position','codelist') 95 | .map(process) 96 | .map(transforms3) 97 | ) 98 | return dataset 99 | 100 | if __name__=='__main__': 101 | import matplotlib.pylab as plt 102 | import time 103 | from dataset.multi import MultiLoader 104 | 105 | dataset = get_dataset(train=False) 106 | # dataloader = DataLoader(dataset, batch_size=1, num_workers=4) 107 | dataloader = MultiLoader(dataset.batched(1)) 108 | 109 | st = time.time() 110 | for sample in dataloader: 111 | print((time.time() - st) * 1000) 112 | image, labelmap, idmap = sample 113 | image = torch.tensor(image, dtype=torch.float) 114 | labelmap = torch.tensor(labelmap, dtype=torch.float) 115 | idmap = torch.tensor(idmap, dtype=torch.long) 116 | 117 | st = time.time() 118 | # continue 119 | 120 | plt.figure() 121 | if len(image[0].shape) > 2: 122 | plt.imshow(image[0].permute(1,2,0)) 123 | else: 124 | plt.imshow(image[0]) 125 | 126 | plt.figure() 127 | plt.subplot(2,4,1) 128 | if len(image[0].shape) > 2: 129 | plt.imshow(image[0].permute(1,2,0)) 130 | else: 131 | plt.imshow(image[0]) 132 | for i in range(5): 133 | plt.subplot(2,4,2+i) 134 | plt.imshow(labelmap[0,i]) 135 | plt.subplot(2,4,7) 136 | plt.imshow(idmap[0,0]) 137 | plt.subplot(2,4,8) 138 | plt.imshow(idmap[0,1]) 139 | plt.show() 140 | st = time.time() 141 | -------------------------------------------------------------------------------- /convert1_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import onnx 3 | import onnxruntime 4 | import torch 5 | import numpy as np 6 
| from PIL import Image 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | import io 11 | import itertools 12 | 13 | from models.detector import TextDetectorModel, CenterNetDetector, CodeDecoder 14 | from util_func import calc_predid, width, height, feature_dim, sigmoid, modulo_list 15 | 16 | def convert1(): 17 | model = TextDetectorModel(pre_weights=False) 18 | data = torch.load('model.pt', map_location="cpu", weights_only=True) 19 | model.load_state_dict(data['model_state_dict']) 20 | detector = CenterNetDetector(model.detector) 21 | decoder = CodeDecoder(model.decoder) 22 | detector.eval() 23 | decoder.eval() 24 | 25 | ######################################################################### 26 | print('detector') 27 | 28 | example_input = torch.rand(1, 3, height, width) 29 | torch.onnx.export(detector, 30 | example_input, 31 | "TextDetector.onnx", 32 | input_names=['image'], 33 | output_names=['heatmap','feature'], 34 | dynamo=True, 35 | external_data=False, 36 | optimize=True, 37 | verify=True, 38 | opset_version=20) 39 | onnx.checker.check_model('TextDetector.onnx') 40 | 41 | ############################################################################ 42 | print('decoder') 43 | 44 | example_input = torch.rand(1, feature_dim) 45 | torch.onnx.export(decoder, 46 | example_input, 47 | "CodeDecoder.onnx", 48 | input_names=['feature_input'], 49 | output_names=['modulo_1091','modulo_1093','modulo_1097'], 50 | dynamo=True, 51 | external_data=False, 52 | optimize=True, 53 | verify=True, 54 | opset_version=20) 55 | onnx.checker.check_model('CodeDecoder.onnx') 56 | 57 | def cos_sim(v1, v2): 58 | return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) 59 | 60 | def test_model(): 61 | print('test') 62 | plt.figure() 63 | plt.text(0.1,0.9,'test', fontsize=32) 64 | plt.axis('off') 65 | plt.tight_layout() 66 | 67 | buf = io.BytesIO() 68 | plt.savefig(buf, format='png') 69 | buf.seek(0) 70 | im = np.array(Image.open(buf).convert('RGB')) 71 | buf.close() 72 | 73 | im = im[:height,:width,:] 74 | im = np.pad(im, [[0,height-im.shape[0]], [0,width-im.shape[1]], [0,0]], 'constant', constant_values=((255,255),(255,255),(255,255))) 75 | 76 | image_input = im.astype(np.float32) 77 | image_input = np.expand_dims(image_input, 0).transpose(0,3,1,2) / 255 78 | print(image_input.shape) 79 | 80 | print('load') 81 | onnx_detector = onnxruntime.InferenceSession("TextDetector.onnx") 82 | onnx_decoder = onnxruntime.InferenceSession("CodeDecoder.onnx") 83 | 84 | print(' [ detector ] ') 85 | print('input:') 86 | for session_input in onnx_detector.get_inputs(): 87 | print(session_input.name, session_input.shape) 88 | print('output:') 89 | for session_output in onnx_detector.get_outputs(): 90 | print(session_output.name, session_output.shape) 91 | 92 | print(' [ decoder ] ') 93 | print('input:') 94 | for session_input in onnx_decoder.get_inputs(): 95 | print(session_input.name, session_input.shape) 96 | print('output:') 97 | for session_output in onnx_decoder.get_outputs(): 98 | print(session_output.name, session_output.shape) 99 | 100 | maps, feature = onnx_detector.run(['heatmap','feature'], {'image': image_input}) 101 | peakmap = maps[0,1,:,:] 102 | idxy, idxx = np.unravel_index(np.argsort(-peakmap.ravel()), peakmap.shape) 103 | results_dict = [] 104 | for y, x in zip(idxy, idxx): 105 | print(x,y,sigmoid(peakmap[y,x])) 106 | if sigmoid(peakmap[y,x]) < 0.5: 107 | break 108 | outnames = ['modulo_%d'%m for m in modulo_list] 109 | decode_outputs = onnx_decoder.run(outnames, 
{'feature_input': feature[:,:,y,x]}) 110 | 111 | p = [] 112 | id = [] 113 | for k,prob in enumerate(decode_outputs): 114 | prob = prob[0] 115 | idx = np.where(prob > 0.01)[0] 116 | if len(idx) == 0: 117 | idx = [np.argmax(prob)] 118 | if k == 0: 119 | for i in idx[:3]: 120 | id.append([i]) 121 | p.append([prob[i]]) 122 | else: 123 | id = [i1 + [i2] for i1, i2 in itertools.product(id, idx[:3])] 124 | p = [i1 + [prob[i2]] for i1, i2 in itertools.product(p, idx[:3])] 125 | p = [np.exp(np.mean([np.log(prob) for prob in probs])) for probs in p] 126 | i = [calc_predid(*ids) for ids in id] 127 | g = sorted([(prob, id) for prob,id in zip(p,i)], key=lambda x: x[0] if x[1] <= 0x10FFFF else 0, reverse=True) 128 | prob,idx = g[0] 129 | if idx <= 0x10FFFF: 130 | c = chr(idx) 131 | else: 132 | c = None 133 | print(prob, idx, c) 134 | print(feature[0,:,y,x].max(), feature[0,:,y,x].min()) 135 | results_dict.append((feature[0,:,y,x], idx, c)) 136 | print() 137 | 138 | 139 | for i in range(len(results_dict)): 140 | for j in range(i+1, len(results_dict)): 141 | s = cos_sim(results_dict[i][0], results_dict[j][0]) 142 | d = np.linalg.norm(results_dict[i][0] - results_dict[j][0]) 143 | print(s,d, i,j,results_dict[i][1:],results_dict[j][1:]) 144 | 145 | if __name__ == '__main__': 146 | convert1() 147 | test_model() 148 | -------------------------------------------------------------------------------- /textline_detect/src/prepare.cpp: -------------------------------------------------------------------------------- 1 | #include "prepare.h" 2 | #include <cstdio> 3 | #include <vector> 4 | #include <algorithm> 5 | 6 | int search_connection(const std::vector<bool> &immap, std::vector<int> &idmap) 7 | { 8 | idmap.resize(width*height, -1); 9 | std::vector<bool> visitmap(width*height, false); 10 | int remain_count = width*height; 11 | 12 | int cluster_idx = 0; 13 | while(remain_count > 0) { 14 | int xi = -1; 15 | int yi = -1; 16 | for(int y = 0; y < height; y++) { 17 | for(int x = 0; x < width; x++) { 18 | if(visitmap[y*width + x]) continue; 19 | 20 | if(!immap[y*width + x]) { 21 | visitmap[y*width + x] = true; 22 | remain_count--; 23 | continue; 24 | } 25 | 26 | xi = x; 27 | yi = y; 28 | goto find_loop1; 29 | } 30 | } 31 | find_loop1: 32 | if(xi < 0 || yi < 0) break; 33 | 34 | std::vector<std::pair<int,int>> stack; 35 | stack.emplace_back(xi,yi); 36 | 37 | while(!stack.empty()) { 38 | xi = stack.back().first; 39 | yi = stack.back().second; 40 | stack.pop_back(); 41 | 42 | if(visitmap[yi*width + xi]) continue; 43 | 44 | visitmap[yi*width + xi] = true; 45 | remain_count--; 46 | 47 | if(immap[yi*width + xi]) { 48 | idmap[yi*width + xi] = cluster_idx; 49 | if(xi - 1 >= 0) { 50 | stack.emplace_back(xi-1,yi); 51 | } 52 | if(yi - 1 >= 0) { 53 | stack.emplace_back(xi,yi-1); 54 | } 55 | if(xi + 1 < width) { 56 | stack.emplace_back(xi+1,yi); 57 | } 58 | if(yi + 1 < height) { 59 | stack.emplace_back(xi,yi+1); 60 | } 61 | } 62 | } 63 | cluster_idx++; 64 | } 65 | return cluster_idx; 66 | } 67 | 68 | void prepare_id_image( 69 | std::vector<int> &idimage, 70 | std::vector<int> &idimage_main, 71 | std::vector<charbox> &boxes) 72 | { 73 | fprintf(stderr, "prepare_id_image\n"); 74 | idimage.resize(width*height, -1); 75 | idimage_main.resize(width*height, -1); 76 | for(const auto &box: boxes) { 77 | //fprintf(stderr, "box %d cx %f cy %f w %f h %f c1 %f c2 %f c4 %f c8 %f\n", box.id, box.cx, box.cy, box.w, box.h, box.code1, box.code2, box.code4, box.code8); 78 | int left = (box.cx - box.w / 2) / 4; 79 | int right = (box.cx + box.w / 2) / 4 + 1; 80 | int top = (box.cy - box.h / 2) / 4; 81 | int bottom = (box.cy + box.h / 2)
/ 4 + 1; 82 | if(left < 0 || right >= width) continue; 83 | if(top < 0 || bottom >= height) continue; 84 | if((box.subtype & (2+4)) != 2+4) { 85 | for(int y = top; y < bottom; y++) { 86 | for(int x = left; x < right; x++) { 87 | idimage_main[y * width + x] = box.id; 88 | } 89 | } 90 | } 91 | for(int y = top; y < bottom; y++) { 92 | for(int x = left; x < right; x++) { 93 | idimage[y * width + x] = box.id; 94 | } 95 | } 96 | } 97 | } 98 | 99 | void make_lineblocker( 100 | std::vector<bool> &lineblocker, 101 | const std::vector<float> &sepimage) 102 | { 103 | fprintf(stderr, "make_lineblocker\n"); 104 | lineblocker.resize(width*height, false); 105 | 106 | for(int y = 0; y < height; y++) { 107 | for(int x = 0; x < width; x++) { 108 | float value = sepimage[width * y + x]; 109 | if (value > sep_valueth) { 110 | lineblocker[width * y + x] = true; 111 | } 112 | } 113 | } 114 | std::vector<int> blocker_cluster; 115 | int cluster_count = search_connection(lineblocker, blocker_cluster); 116 | std::vector<float> cluster_weight(cluster_count, 0); 117 | for(int y = 0; y < height; y++) { 118 | for(int x = 0; x < width; x++) { 119 | int id = blocker_cluster[width * y + x]; 120 | if(id < 0) continue; 121 | float value = sepimage[width * y + x]; 122 | cluster_weight[id] += value; 123 | } 124 | } 125 | for(int y = 0; y < height; y++) { 126 | for(int x = 0; x < width; x++) { 127 | int id = blocker_cluster[width * y + x]; 128 | if(id < 0) continue; 129 | if(cluster_weight[id] < sep_clusterth) { 130 | lineblocker[width * y + x] = false; 131 | } 132 | } 133 | } 134 | 135 | std::vector<int> search_idx; 136 | for(int y = 0; y < height; y++) { 137 | for(int x = 0; x < width; x++) { 138 | if(lineblocker[width * y + x]) { 139 | search_idx.push_back(width * y + x); 140 | } 141 | } 142 | } 143 | for(auto i: search_idx) { 144 | float value_th = sepimage[i] * 0.1; 145 | std::vector<int> stack; 146 | stack.push_back(i); 147 | while(!stack.empty()) { 148 | int i2 = stack.back(); 149 | stack.pop_back(); 150 | 151 | if(sepimage[i2] < value_th) continue; 152 | lineblocker[i2] = true; 153 | 154 | int x0 = i2 % width; 155 | int y0 = i2 / width; 156 | 157 | std::vector<int> tmp; 158 | for(int y = y0-1; y <= y0+1; y++) { 159 | for(int x = x0-1; x <= x0+1; x++) { 160 | if(x < 0 || x >= width || y < 0 || y >= height) continue; 161 | int i3 = y * width + x; 162 | if(lineblocker[i3]) continue; 163 | if(sepimage[i3] < value_th) continue; 164 | tmp.push_back(i3); 165 | } 166 | } 167 | std::copy(tmp.begin(), tmp.end(), std::back_inserter(stack)); 168 | } 169 | } 170 | } 171 | -------------------------------------------------------------------------------- /convert1_coreml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import coremltools as ct 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import matplotlib 7 | matplotlib.use('Agg') 8 | import matplotlib.pyplot as plt 9 | import io 10 | from datetime import datetime 11 | import itertools 12 | 13 | from models.detector import TextDetectorModel, CenterNetDetector, CodeDecoder 14 | from util_func import calc_predid, width, height, feature_dim, sigmoid, modulo_list 15 | 16 | def convert1(model_size='xl'): 17 | # import logging 18 | # logging.basicConfig(filename='debug.log', level=logging.DEBUG) 19 | 20 | model = TextDetectorModel(model_size=model_size) 21 | data = torch.load('model.pt', map_location="cpu", weights_only=True) 22 | model.load_state_dict(data['model_state_dict']) 23 | 24 | # with torch.no_grad(): 25 | #
model.detector.code2.top_conv[-1].bias.copy_(model.detector.code2.top_conv[-1].bias+4) 26 | # model.detector.code8.top_conv[-1].bias.copy_(model.detector.code8.top_conv[-1].bias-2) 27 | 28 | detector = CenterNetDetector(model.detector) 29 | decoder = CodeDecoder(model.decoder) 30 | detector.eval() 31 | decoder.eval() 32 | 33 | ######################################################################### 34 | print('detector') 35 | 36 | example_input = torch.rand(1, 3, height, width) 37 | traced_model = torch.jit.trace(detector, example_input) 38 | 39 | mlmodel_detector = ct.convert(traced_model, 40 | inputs=[ 41 | ct.ImageType(name='image', shape=(1, 3, height, width), scale=1/255) 42 | ], 43 | outputs=[ 44 | ct.TensorType(name='heatmap'), 45 | ct.TensorType(name='feature'), 46 | ], 47 | convert_to="mlprogram", 48 | minimum_deployment_target=ct.target.iOS18) 49 | mlmodel_detector.version = datetime.now().strftime("%Y%m%d%H%M%S") 50 | mlmodel_detector.save("TextDetector.mlpackage") 51 | 52 | ############################################################################ 53 | print('decoder') 54 | 55 | example_input = torch.rand(1, feature_dim) 56 | traced_model = torch.jit.trace(decoder, example_input) 57 | 58 | mlmodel_decoder = ct.convert(traced_model, 59 | convert_to="mlprogram", 60 | inputs=[ 61 | ct.TensorType(name='feature_input', shape=(1, feature_dim)) 62 | ], 63 | outputs=[ 64 | ct.TensorType(name='modulo_1091'), 65 | ct.TensorType(name='modulo_1093'), 66 | ct.TensorType(name='modulo_1097'), 67 | ], 68 | minimum_deployment_target=ct.target.iOS18) 69 | mlmodel_decoder.version = datetime.now().strftime("%Y%m%d%H%M%S") 70 | mlmodel_decoder.save("CodeDecoder.mlpackage") 71 | 72 | 73 | def cos_sim(v1, v2): 74 | return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)) 75 | 76 | def test_model(): 77 | 78 | plt.figure() 79 | plt.text(0.1,0.9,'test', fontsize=32) 80 | plt.axis('off') 81 | plt.tight_layout() 82 | 83 | buf = io.BytesIO() 84 | plt.savefig(buf, format='png') 85 | plt.close() 86 | buf.seek(0) 87 | im = np.array(Image.open(buf).convert("RGB")) 88 | buf.close() 89 | 90 | im = im[:height,:width,:] 91 | im = np.pad(im, [[0,height-im.shape[0]], [0,width-im.shape[1]], [0,0]], 'constant', constant_values=((255,255),(255,255),(255,255))) 92 | 93 | print('test') 94 | input_image = Image.fromarray(im, mode="RGB") 95 | 96 | print('load') 97 | mlmodel_detector = ct.models.MLModel('TextDetector.mlpackage') 98 | mlmodel_decoder = ct.models.MLModel('CodeDecoder.mlpackage') 99 | 100 | output = mlmodel_detector.predict({'image': input_image}) 101 | peakmap = output['heatmap'][0,1,:,:] 102 | 103 | idxy, idxx = np.unravel_index(np.argsort(-peakmap.ravel()), peakmap.shape) 104 | results_dict = [] 105 | for y, x in zip(idxy, idxx): 106 | p1 = sigmoid(peakmap[y,x]) 107 | print(x,y,p1) 108 | if p1 < 0.5: 109 | break 110 | feature = output['feature'][:,:,y,x] 111 | decode_output = mlmodel_decoder.predict({'feature_input': feature}) 112 | p = [] 113 | id = [] 114 | for k,m in enumerate(modulo_list): 115 | prob = decode_output['modulo_%d'%m][0] 116 | idx = np.where(prob > 0.01)[0] 117 | if len(idx) == 0: 118 | idx = [np.argmax(prob)] 119 | if k == 0: 120 | for i in idx[:3]: 121 | id.append([i]) 122 | p.append([prob[i]]) 123 | else: 124 | id = [i1 + [i2] for i1, i2 in itertools.product(id, idx[:3])] 125 | p = [i1 + [prob[i2]] for i1, i2 in itertools.product(p, idx[:3])] 126 | p = [np.exp(np.mean([np.log(prob) for prob in probs])) for probs in p] 127 | i = [calc_predid(*ids) for ids in id] 128 | g = 
sorted([(prob, id) for prob,id in zip(p,i)], key=lambda x: x[0] if x[1] <= 0x10FFFF else 0, reverse=True) 129 | print(g) 130 | prob,idx = g[0] 131 | if idx <= 0x10FFFF: 132 | c = chr(idx) 133 | else: 134 | c = None 135 | print(prob, idx, c) 136 | print(feature.max(), feature.min()) 137 | results_dict.append((feature[0], idx, c)) 138 | print() 139 | 140 | for i in range(len(results_dict)): 141 | for j in range(i+1, len(results_dict)): 142 | s = cos_sim(results_dict[i][0], results_dict[j][0]) 143 | d = np.linalg.norm(results_dict[i][0] - results_dict[j][0]) 144 | print(s,d, i,j,results_dict[i][1:],results_dict[j][1:]) 145 | 146 | if __name__ == '__main__': 147 | import sys 148 | model_size = 'xl' 149 | if len(sys.argv) > 1: 150 | if sys.argv[1] == 's': 151 | model_size = 's' 152 | if sys.argv[1] == 'm': 153 | model_size = 'm' 154 | if sys.argv[1] == 'l': 155 | model_size = 'l' 156 | convert1(model_size) 157 | test_model() 158 | -------------------------------------------------------------------------------- /textline_detect/src/minpack/qrfac.cpp: -------------------------------------------------------------------------------- 1 | #include <cmath> 2 | #include <utility> 3 | 4 | #define mcheps 2.2204460492503131e-16 5 | #define one 1.0 6 | #define p05 5.0e-2 7 | #define zero 0.0 8 | 9 | #define MAX(a,b) ((a) > (b) ? (a) : (b)) 10 | #define MIN(a,b) ((a) < (b) ? (a) : (b)) 11 | 12 | double enorm(int n, double *x); 13 | 14 | void qrfac(int m,int n,double *a,int lda,bool pivot,int *ipvt,int lipvt,double *rdiag,double *acnorm) 15 | { 16 | /* 17 | c ********** 18 | c 19 | c subroutine qrfac 20 | c 21 | c this subroutine uses householder transformations with column 22 | c pivoting (optional) to compute a qr factorization of the 23 | c m by n matrix a. that is, qrfac determines an orthogonal 24 | c matrix q, a permutation matrix p, and an upper trapezoidal 25 | c matrix r with diagonal elements of nonincreasing magnitude, 26 | c such that a*p = q*r. the householder transformation for 27 | c column k, k = 1,2,...,min(m,n), is of the form 28 | c 29 | c t 30 | c i - (1/u(k))*u*u 31 | c 32 | c where u has zeros in the first k-1 positions. the form of 33 | c this transformation and the method of pivoting first 34 | c appeared in the corresponding linpack subroutine. 35 | c 36 | c the subroutine statement is 37 | c 38 | c subroutine qrfac(m,n,a,lda,pivot,ipvt,lipvt,rdiag,acnorm,wa) 39 | c 40 | c where 41 | c 42 | c m is a positive integer input variable set to the number 43 | c of rows of a. 44 | c 45 | c n is a positive integer input variable set to the number 46 | c of columns of a. 47 | c 48 | c a is an m by n array. on input a contains the matrix for 49 | c which the qr factorization is to be computed. on output 50 | c the strict upper trapezoidal part of a contains the strict 51 | c upper trapezoidal part of r, and the lower trapezoidal 52 | c part of a contains a factored form of q (the non-trivial 53 | c elements of the u vectors described above). 54 | c 55 | c lda is a positive integer input variable not less than m 56 | c which specifies the leading dimension of the array a. 57 | c 58 | c pivot is a logical input variable. if pivot is set true, 59 | c then column pivoting is enforced. if pivot is set false, 60 | c then no column pivoting is done. 61 | c 62 | c ipvt is an integer output array of length lipvt. ipvt 63 | c defines the permutation matrix p such that a*p = q*r. 64 | c column j of p is column ipvt(j) of the identity matrix. 65 | c if pivot is false, ipvt is not referenced.
66 | c 67 | c lipvt is a positive integer input variable. if pivot is false, 68 | c then lipvt may be as small as 1. if pivot is true, then 69 | c lipvt must be at least n. 70 | c 71 | c rdiag is an output array of length n which contains the 72 | c diagonal elements of r. 73 | c 74 | c acnorm is an output array of length n which contains the 75 | c norms of the corresponding columns of the input matrix a. 76 | c if this information is not needed, then acnorm can coincide 77 | c with rdiag. 78 | c 79 | c wa is a work array of length n. if pivot is false, then wa 80 | c can coincide with rdiag. 81 | c 82 | c subprograms called 83 | c 84 | c minpack-supplied ... dpmpar,enorm 85 | c 86 | c fortran-supplied ... dmax1,dsqrt,min0 87 | c 88 | c argonne national laboratory. minpack project. march 1980. 89 | c burton s. garbow, kenneth e. hillstrom, jorge j. more 90 | c 91 | c ********** 92 | */ 93 | double *wa = new double[n]; 94 | /* 95 | c 96 | c epsmch is the machine precision. 97 | c 98 | */ 99 | double epsmch = mcheps; 100 | /* 101 | c 102 | c compute the initial column norms and initialize several arrays. 103 | c 104 | */ 105 | for(int j = 0; j < n; j++) { 106 | acnorm[j] = enorm(m, &a[j*lda]); 107 | rdiag[j] = acnorm[j]; 108 | wa[j] = rdiag[j]; 109 | if (pivot) ipvt[j] = j; 110 | } 111 | /* 112 | c 113 | c reduce a to r with householder transformations. 114 | c 115 | */ 116 | int minmn = MIN(m,n); 117 | for(int j = 0; j < minmn; j++) { 118 | if(pivot) { 119 | /* 120 | c 121 | c bring the column of largest norm into the pivot position. 122 | c 123 | */ 124 | int kmax = j; 125 | for(int k = j; k < n; k++) { 126 | if (rdiag[k] > rdiag[kmax]) kmax = k; 127 | } 128 | if (kmax != j) { 129 | for(int i = 0; i < m; i++) { 130 | std::swap(a[i+j*lda], a[i+kmax*lda]); 131 | } 132 | rdiag[kmax] = rdiag[j]; 133 | wa[kmax] = wa[j]; 134 | std::swap(ipvt[j], ipvt[kmax]); 135 | } 136 | } 137 | /* 138 | c 139 | c compute the householder transformation to reduce the 140 | c j-th column of a to a multiple of the j-th unit vector. 141 | c 142 | */ 143 | double ajnorm = enorm(m-j, &a[j+j*lda]); 144 | if (ajnorm != zero) { 145 | if (a[j+j*lda] < zero) ajnorm = -ajnorm; 146 | for(int i = j; i < m; i++) { 147 | a[i+j*lda] /= ajnorm; 148 | } 149 | a[j+j*lda] += one; 150 | /* 151 | c 152 | c apply the transformation to the remaining columns 153 | c and update the norms. 154 | c 155 | */ 156 | for(int k = j+1; k < n; k++) { 157 | double sum = zero; 158 | for(int i = j; i < m; i++) { 159 | sum += a[i+j*lda]*a[i+k*lda]; 160 | } 161 | double temp = sum/a[j+j*lda]; 162 | for(int i = j; i < m; i++) { 163 | a[i+k*lda] -= temp*a[i+j*lda]; 164 | } 165 | if (pivot && rdiag[k] != zero) { 166 | temp = a[j+k*lda]/rdiag[k]; 167 | rdiag[k] *= sqrt(MAX(zero,one-temp*temp)); 168 | if(p05*(rdiag[k]/wa[k])*(rdiag[k]/wa[k]) <= epsmch) { 169 | rdiag[k] = enorm(m-j-1,&a[j+1+k*lda]); 170 | wa[k] = rdiag[k]; 171 | } 172 | } 173 | } 174 | } 175 | rdiag[j] = -ajnorm; 176 | } 177 | 178 | delete[] wa; 179 | } -------------------------------------------------------------------------------- /dataset/multi.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved. 3 | # This file is part of the WebDataset library. 4 | # See the LICENSE file for licensing terms (BSD-style). 5 | # 6 | 7 | """An alternative to DataLoader using ZMQ. 8 | 9 | This implements MultiLoader, an alternative to DataLoader when torch 10 | is not available. 
Subprocesses communicate with the loader through 11 | ZMQ, provided for high performance multithreaded queueing. 12 | """ 13 | 14 | import multiprocessing as mp 15 | import pickle 16 | import os 17 | import uuid 18 | import weakref 19 | import time 20 | import signal 21 | 22 | import zmq 23 | 24 | the_protocol = pickle.HIGHEST_PROTOCOL 25 | 26 | all_pids = weakref.WeakSet() 27 | 28 | 29 | class EOF: 30 | """A class that indicates that a data stream is finished.""" 31 | 32 | def __init__(self, **kw): 33 | """Initialize the class with the kw as instance variables.""" 34 | self.__dict__.update(kw) 35 | 36 | 37 | isalive = True 38 | 39 | def signal_handler(signal_number, frame): 40 | global isalive 41 | isalive = False 42 | 43 | def reader(dataset, sockname1, sockname2, index, num_workers): 44 | """Read samples from the dataset and send them over the socket. 45 | 46 | :param dataset: source dataset 47 | :param sockname: name for the socket to send data to 48 | :param index: index for this reader, using to indicate EOF 49 | """ 50 | global isalive, the_protocol 51 | os.environ["WORKER"] = str(index) 52 | os.environ["NUM_WORKERS"] = str(num_workers) 53 | signal.signal(signal.SIGTERM, signal_handler) 54 | ctx = zmq.Context.instance() 55 | sock1 = ctx.socket(zmq.PUSH) 56 | sock1.connect(sockname1) 57 | sock2 = ctx.socket(zmq.SUB) 58 | sock2.connect(sockname2) 59 | sock2.setsockopt(zmq.SUBSCRIBE, b'') 60 | poller = zmq.Poller() 61 | poller.register(sock2, zmq.POLLIN) 62 | rcount = 0 63 | for i, sample in enumerate(dataset): 64 | data = pickle.dumps(sample, protocol=the_protocol) 65 | sock1.send(data) 66 | while isalive: 67 | socks = dict(poller.poll(50)) 68 | if sock2 in socks and socks[sock2] == zmq.POLLIN: 69 | rcount = sock2.recv_pyobj() 70 | if i > rcount / num_workers + 2: 71 | time.sleep(0.05) 72 | else: 73 | break 74 | sock1.send(pickle.dumps(EOF(index=index))) 75 | while isalive: 76 | socks = dict(poller.poll(50)) 77 | if sock2 in socks and socks[sock2] == zmq.POLLIN: 78 | sample = sock2.recv_pyobj() 79 | if isinstance(sample, EOF) and sample.index == index: 80 | break 81 | else: 82 | time.sleep(0.05) 83 | 84 | sock1.close() 85 | sock2.close() 86 | ctx.destroy() 87 | 88 | class MultiLoader: 89 | """Alternative to PyTorch DataLoader based on ZMQ.""" 90 | 91 | def __init__( 92 | self, dataset, workers=4, verbose=False, nokill=False, prefix="/tmp/_multi-" 93 | ): 94 | """Create a MultiLoader for a dataset. 95 | 96 | This creates ZMQ sockets, spawns `workers` subprocesses, and has them send data 97 | to the socket. 
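Each reader pushes pickled samples to the loader over a PUSH/PULL socket pair, while the loader broadcasts its consumed-sample count on a PUB socket; readers compare that count against their own position and throttle themselves when they run more than a couple of samples ahead (see reader above).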
98 | 99 | :param dataset: source dataset 100 | :param workers: number of workers 101 | :param verbose: report progress verbosely 102 | :param nokill: don't kill old processes when restarting (allows multiple loaders) 103 | :param prefix: directory prefix for the ZMQ socket 104 | """ 105 | self.dataset = dataset 106 | self.workers = workers 107 | self.verbose = verbose 108 | self.pids = [] 109 | self.socket1 = None 110 | self.socket2 = None 111 | self.ctx = zmq.Context.instance() 112 | self.nokill = nokill 113 | self.prefix = prefix 114 | 115 | def __del__(self): 116 | global isalive 117 | isalive = False 118 | self.kill() 119 | self.ctx.destroy() 120 | 121 | def kill(self): 122 | """kill.""" 123 | for pid in self.pids: 124 | if pid is None: 125 | continue 126 | if self.verbose: 127 | print("killing", pid) 128 | pid.kill() 129 | pid.join(1.0) 130 | self.pids = [] 131 | if self.socket1 is not None: 132 | if self.verbose: 133 | print("closing", self.socket1) 134 | self.socket1.close() 135 | self.socket1 = None 136 | if self.socket2 is not None: 137 | if self.verbose: 138 | print("closing", self.socket2) 139 | self.socket2.close() 140 | self.socket2 = None 141 | 142 | def __iter__(self): 143 | """Return an iterator over this dataloader.""" 144 | if not self.nokill: 145 | self.kill() 146 | self.sockname1 = "ipc://" + self.prefix + str(uuid.uuid4()) 147 | self.sockname2 = "ipc://" + self.prefix + str(uuid.uuid4()) 148 | self.socket1 = self.ctx.socket(zmq.PULL) 149 | self.socket1.bind(self.sockname1) 150 | if self.verbose: 151 | print("#", self.sockname1) 152 | self.socket2 = self.ctx.socket(zmq.PUB) 153 | self.socket2.bind(self.sockname2) 154 | if self.verbose: 155 | print("#", self.sockname2) 156 | self.pids = [None] * self.workers 157 | for index in range(self.workers): 158 | args = (self.dataset, self.sockname1, self.sockname2, index, self.workers) 159 | self.pids[index] = mp.Process(target=reader, args=args) 160 | all_pids.update(self.pids) 161 | for pid in self.pids: 162 | pid.start() 163 | count = 0 164 | self.socket2.send_pyobj(count) 165 | while self.pids.count(None) < len(self.pids): 166 | data = self.socket1.recv() 167 | sample = pickle.loads(data) 168 | if isinstance(sample, EOF): 169 | self.socket2.send_pyobj(sample) 170 | if self.verbose: 171 | print("# subprocess finished", sample.index) 172 | self.pids[sample.index].join(1.0) 173 | self.pids[sample.index] = None 174 | else: 175 | yield sample 176 | count += 1 177 | self.socket2.send_pyobj(count) 178 | -------------------------------------------------------------------------------- /textline_detect/src/minpack/qrsolv.cpp: -------------------------------------------------------------------------------- 1 | #include <cmath> 2 | 3 | #define p5 5.0e-1 4 | #define p25 2.5e-1 5 | #define zero 0.0 6 | 7 | void qrsolv(int n,double *r,int ldr,int *ipvt,double *diag,double *qtb,double *x,double *sdiag) 8 | { 9 | /* 10 | c ********** 11 | c 12 | c subroutine qrsolv 13 | c 14 | c given an m by n matrix a, an n by n diagonal matrix d, 15 | c and an m-vector b, the problem is to determine an x which 16 | c solves the system 17 | c 18 | c a*x = b , d*x = 0 , 19 | c 20 | c in the least squares sense. 21 | c 22 | c this subroutine completes the solution of the problem 23 | c if it is provided with the necessary information from the 24 | c qr factorization, with column pivoting, of a.
that is, if 25 | c a*p = q*r, where p is a permutation matrix, q has orthogonal 26 | c columns, and r is an upper triangular matrix with diagonal 27 | c elements of nonincreasing magnitude, then qrsolv expects 28 | c the full upper triangle of r, the permutation matrix p, 29 | c and the first n components of (q transpose)*b. the system 30 | c a*x = b, d*x = 0, is then equivalent to 31 | c 32 | c t t 33 | c r*z = q *b , p *d*p*z = 0 , 34 | c 35 | c where x = p*z. if this system does not have full rank, 36 | c then a least squares solution is obtained. on output qrsolv 37 | c also provides an upper triangular matrix s such that 38 | c 39 | c t t t 40 | c p *(a *a + d*d)*p = s *s . 41 | c 42 | c s is computed within qrsolv and may be of separate interest. 43 | c 44 | c the subroutine statement is 45 | c 46 | c subroutine qrsolv(n,r,ldr,ipvt,diag,qtb,x,sdiag,wa) 47 | c 48 | c where 49 | c 50 | c n is a positive integer input variable set to the order of r. 51 | c 52 | c r is an n by n array. on input the full upper triangle 53 | c must contain the full upper triangle of the matrix r. 54 | c on output the full upper triangle is unaltered, and the 55 | c strict lower triangle contains the strict upper triangle 56 | c (transposed) of the upper triangular matrix s. 57 | c 58 | c ldr is a positive integer input variable not less than n 59 | c which specifies the leading dimension of the array r. 60 | c 61 | c ipvt is an integer input array of length n which defines the 62 | c permutation matrix p such that a*p = q*r. column j of p 63 | c is column ipvt(j) of the identity matrix. 64 | c 65 | c diag is an input array of length n which must contain the 66 | c diagonal elements of the matrix d. 67 | c 68 | c qtb is an input array of length n which must contain the first 69 | c n elements of the vector (q transpose)*b. 70 | c 71 | c x is an output array of length n which contains the least 72 | c squares solution of the system a*x = b, d*x = 0. 73 | c 74 | c sdiag is an output array of length n which contains the 75 | c diagonal elements of the upper triangular matrix s. 76 | c 77 | c wa is a work array of length n. 78 | c 79 | c subprograms called 80 | c 81 | c fortran-supplied ... dabs,dsqrt 82 | c 83 | c argonne national laboratory. minpack project. march 1980. 84 | c burton s. garbow, kenneth e. hillstrom, jorge j. more 85 | c 86 | c ********** 87 | */ 88 | double *wa = new double[n]; 89 | /* 90 | c 91 | c copy r and (q transpose)*b to preserve input and initialize s. 92 | c in particular, save the diagonal elements of r in x. 93 | c 94 | */ 95 | for(int j = 0; j < n; j++) { 96 | for(int i = j; i < n; i++) { 97 | r[i+j*ldr] = r[j+i*ldr]; 98 | } 99 | x[j] = r[j+j*ldr]; 100 | wa[j] = qtb[j]; 101 | } 102 | /* 103 | c 104 | c eliminate the diagonal matrix d using a givens rotation. 105 | c 106 | */ 107 | for(int j = 0; j < n; j++) { 108 | /* 109 | c 110 | c prepare the row of d to be eliminated, locating the 111 | c diagonal element using p from the qr factorization. 112 | c 113 | */ 114 | if (diag[ipvt[j]] != zero) { 115 | for(int k = j; k < n; k++) { 116 | sdiag[k] = zero; 117 | } 118 | sdiag[j] = diag[ipvt[j]]; 119 | /* 120 | c 121 | c the transformations to eliminate the row of d 122 | c modify only a single element of (q transpose)*b 123 | c beyond the first n, which is initially zero. 124 | c 125 | */ 126 | double qtbpj = zero; 127 | for(int k = j; k < n; k++) { 128 | /* 129 | c 130 | c determine a givens rotation which eliminates the 131 | c appropriate element in the current row of d. 
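c
c note: the branch below forms whichever of tan or cotan has
c magnitude at most one, to avoid overflow. with
c t = sdiag(k)/r(k,k) and |t| .le. 1, the code uses
c cos = p5/sqrt(p25+p25*t*t) = 1/sqrt(1+t*t) and sin = cos*t;
c when |r(k,k)| < |sdiag(k)| the cotangent form is used
c symmetrically (sin = 1/sqrt(1+cotan*cotan), cos = sin*cotan).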
132 | c 133 | */ 134 | if (sdiag[k] == zero) continue; 135 | double cos,cotan,sin,tan; 136 | if (fabs(r[k+k*ldr]) < fabs(sdiag[k])) { 137 | cotan = r[k+k*ldr]/sdiag[k]; 138 | sin = p5/sqrt(p25+p25*cotan*cotan); 139 | cos = sin*cotan; 140 | } 141 | else { 142 | tan = sdiag[k]/r[k+k*ldr]; 143 | cos = p5/sqrt(p25+p25*tan*tan); 144 | sin = cos*tan; 145 | } 146 | /* 147 | c 148 | c compute the modified diagonal element of r and 149 | c the modified element of ((q transpose)*b,0). 150 | c 151 | */ 152 | r[k+k*ldr] = cos*r[k+k*ldr] + sin*sdiag[k]; 153 | double temp = cos*wa[k] + sin*qtbpj; 154 | qtbpj = -sin*wa[k] + cos*qtbpj; 155 | wa[k] = temp; 156 | /* 157 | c 158 | c accumulate the tranformation in the row of s. 159 | c 160 | */ 161 | for(int i = k+1; i < n; i++) { 162 | temp = cos*r[i+k*ldr] + sin*sdiag[i]; 163 | sdiag[i] = -sin*r[i+k*ldr] + cos*sdiag[i]; 164 | r[i+k*ldr] = temp; 165 | } 166 | } 167 | } 168 | /* 169 | c 170 | c store the diagonal element of s and restore 171 | c the corresponding diagonal element of r. 172 | c 173 | */ 174 | sdiag[j] = r[j+j*ldr]; 175 | r[j+j*ldr] = x[j]; 176 | } 177 | /* 178 | c 179 | c solve the triangular system for z. if the system is 180 | c singular, then obtain a least squares solution. 181 | c 182 | */ 183 | int nsing = n; 184 | for(int j = 0; j < n; j++) { 185 | if (sdiag[j] == zero && nsing == n) nsing = j - 1; 186 | if (nsing < n) wa[j] = zero; 187 | } 188 | for(int k = 1; k <= nsing; k++) { 189 | int j = nsing - k; 190 | double sum = zero; 191 | for(int i = j+1; i < nsing; i++) { 192 | sum += r[i+j*ldr]*wa[i]; 193 | } 194 | wa[j] = (wa[j] - sum)/sdiag[j]; 195 | } 196 | /* 197 | c 198 | c permute the components of z back to components of x. 199 | c 200 | */ 201 | for(int j = 0; j < n; j++) { 202 | x[ipvt[j]] = wa[j]; 203 | } 204 | 205 | delete[] wa; 206 | } -------------------------------------------------------------------------------- /textline_detect/src/main.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #ifdef _WIN64 7 | #include 8 | #include 9 | #endif 10 | 11 | #include 12 | #include 13 | 14 | #include "line_detect.h" 15 | #include "process.h" 16 | 17 | double ruby_cutoff = 0.25; 18 | double rubybase_cutoff = 0.75; 19 | double space_cutoff = 0.5; 20 | double emphasis_cutoff = 0.5; 21 | float line_valueth = 0.4; 22 | float sep_valueth = 0.1; 23 | float sep_valueth2 = 0.15; 24 | const float sep_clusterth = 10.0; 25 | const int linearea_th = 20; 26 | double allowwidth_next_block = 1.5; 27 | double allow_sizediff = 0.5; 28 | double chain_line_ratio = 0.0; 29 | int page_divide = 0; 30 | int scale = 4; 31 | 32 | int run_mode = 0; 33 | int width = 0; 34 | int height = 0; 35 | 36 | int main(int argc, char **argv) 37 | { 38 | for(int i = 1; i < argc; i++) { 39 | std::string arg(argv[i]); 40 | if(arg.find("--ruby_cutoff=") != std::string::npos) { 41 | std::string vstr = arg.substr(arg.find('=')+1); 42 | std::stringstream(vstr) >> ruby_cutoff; 43 | std::cerr << "ruby_cutoff=" << ruby_cutoff << std::endl; 44 | } 45 | if(arg.find("--rubybase_cutoff=") != std::string::npos) { 46 | std::string vstr = arg.substr(arg.find('=')+1); 47 | std::stringstream(vstr) >> rubybase_cutoff; 48 | std::cerr << "rubybase_cutoff=" << rubybase_cutoff << std::endl; 49 | } 50 | if(arg.find("--space_cutoff=") != std::string::npos) { 51 | std::string vstr = arg.substr(arg.find('=')+1); 52 | std::stringstream(vstr) >> space_cutoff; 53 | std::cerr << "space_cutoff=" << space_cutoff 
<< std::endl; 54 | } 55 | if(arg.find("--emphasis_cutoff=") != std::string::npos) { 56 | std::string vstr = arg.substr(arg.find('=')+1); 57 | std::stringstream(vstr) >> emphasis_cutoff; 58 | std::cerr << "emphasis_cutoff=" << emphasis_cutoff << std::endl; 59 | } 60 | if(arg.find("--line_valueth=") != std::string::npos) { 61 | std::string vstr = arg.substr(arg.find('=')+1); 62 | std::stringstream(vstr) >> line_valueth; 63 | std::cerr << "line_valueth=" << line_valueth << std::endl; 64 | } 65 | if(arg.find("--sep_valueth=") != std::string::npos) { 66 | std::string vstr = arg.substr(arg.find('=')+1); 67 | std::stringstream(vstr) >> sep_valueth; 68 | std::cerr << "sep_valueth=" << sep_valueth << std::endl; 69 | } 70 | if(arg.find("--sep_valueth2=") != std::string::npos) { 71 | std::string vstr = arg.substr(arg.find('=')+1); 72 | std::stringstream(vstr) >> sep_valueth2; 73 | std::cerr << "sep_valueth2=" << sep_valueth2 << std::endl; 74 | } 75 | if(arg.find("--allowwidth_next_block=") != std::string::npos) { 76 | std::string vstr = arg.substr(arg.find('=')+1); 77 | std::stringstream(vstr) >> allowwidth_next_block; 78 | std::cerr << "allowwidth_next_block=" << allowwidth_next_block << std::endl; 79 | } 80 | if(arg.find("--allow_sizediff=") != std::string::npos) { 81 | std::string vstr = arg.substr(arg.find('=')+1); 82 | std::stringstream(vstr) >> allow_sizediff; 83 | std::cerr << "allow_sizediff=" << allow_sizediff << std::endl; 84 | } 85 | if(arg.find("--page_divide=") != std::string::npos) { 86 | std::string vstr = arg.substr(arg.find('=')+1); 87 | std::stringstream(vstr) >> page_divide; 88 | std::cerr << "page_divide=" << page_divide << std::endl; 89 | } 90 | } 91 | 92 | #ifdef _WIN64 93 | _setmode(_fileno(stdin), _O_BINARY); 94 | _setmode(_fileno(stdout), _O_BINARY); 95 | #else 96 | freopen(NULL, "rb", stdin); 97 | freopen(NULL, "wb", stdout); 98 | #endif 99 | 100 | fread(&run_mode, sizeof(uint32_t), 1, stdin); 101 | 102 | fread(&width, sizeof(uint32_t), 1, stdin); 103 | fread(&height, sizeof(uint32_t), 1, stdin); 104 | 105 | std::vector lineimage(width*height); 106 | fread(lineimage.data(), sizeof(float), width*height, stdin); 107 | std::vector sepimage(width*height); 108 | fread(sepimage.data(), sizeof(float), width*height, stdin); 109 | 110 | int boxcount = 0; 111 | fread(&boxcount, sizeof(uint32_t), 1, stdin); 112 | 113 | std::cerr << boxcount << std::endl; 114 | 115 | std::vector boxes(boxcount); 116 | for(int i = 0; i < boxcount; i++) { 117 | boxes[i].id = i; 118 | boxes[i].block = -1; 119 | boxes[i].idx = -1; 120 | boxes[i].subidx = -1; 121 | boxes[i].subtype = 0; 122 | boxes[i].direction = 0; 123 | boxes[i].double_line = 0; 124 | fread(&boxes[i].cx, sizeof(float), 1, stdin); 125 | fread(&boxes[i].cy, sizeof(float), 1, stdin); 126 | fread(&boxes[i].w, sizeof(float), 1, stdin); 127 | fread(&boxes[i].h, sizeof(float), 1, stdin); 128 | fread(&boxes[i].code1, sizeof(float), 1, stdin); 129 | fread(&boxes[i].code2, sizeof(float), 1, stdin); 130 | fread(&boxes[i].code4, sizeof(float), 1, stdin); 131 | fread(&boxes[i].code8, sizeof(float), 1, stdin); 132 | // ルビ親文字 133 | if(boxes[i].code2 > rubybase_cutoff) { 134 | boxes[i].subtype |= 2; 135 | } 136 | // ルビの文字 137 | if(boxes[i].code1 > ruby_cutoff) { 138 | boxes[i].subtype |= 2+4; 139 | } 140 | // 空白 141 | if(boxes[i].code8 > space_cutoff) { 142 | boxes[i].subtype |= 8; 143 | } 144 | // 圏点 145 | if(boxes[i].code4 > emphasis_cutoff) { 146 | boxes[i].subtype |= 16; 147 | } 148 | // fprintf(stderr, "box %d cx %f cy %f w %f h %f c1 %f c2 %f c4 %f 
c8 %f t %d\n", 149 | // boxes[i].id, boxes[i].cx, boxes[i].cy, boxes[i].w, boxes[i].h, 150 | // boxes[i].code1, boxes[i].code2, boxes[i].code4, boxes[i].code8, 151 | // boxes[i].subtype); 152 | } 153 | 154 | process(lineimage, sepimage, boxes); 155 | 156 | std::sort(boxes.begin(), boxes.end(), [](auto a, auto b) { 157 | if(a.block != b.block) return a.block < b.block; 158 | if(a.idx != b.idx) return a.idx < b.idx; 159 | if(a.subidx != b.subidx) return a.subidx < b.subidx; 160 | return a.id < b.id; 161 | }); 162 | 163 | uint32_t count = boxes.size(); 164 | fwrite(&count, sizeof(int32_t), 1, stdout); 165 | 166 | for(int i = 0; i < boxes.size(); i++) { 167 | // fprintf(stderr, "box %d cx %f cy %f w %f h %f block %d idx %d sidx %d stype %d c1 %f c2 %f c4 %f c8 %f d %d\n", 168 | // boxes[i].id, boxes[i].cx, boxes[i].cy, boxes[i].w, boxes[i].h, 169 | // boxes[i].block, boxes[i].idx, boxes[i].subidx, boxes[i].subtype, 170 | // boxes[i].code1, boxes[i].code2, boxes[i].code4, boxes[i].code8, 171 | // boxes[i].double_line); 172 | 173 | fwrite(&boxes[i].id, sizeof(int32_t), 1, stdout); 174 | fwrite(&boxes[i].block, sizeof(int32_t), 1, stdout); 175 | fwrite(&boxes[i].idx, sizeof(int32_t), 1, stdout); 176 | fwrite(&boxes[i].subidx, sizeof(int32_t), 1, stdout); 177 | fwrite(&boxes[i].subtype, sizeof(int32_t), 1, stdout); 178 | fwrite(&boxes[i].page, sizeof(int32_t), 1, stdout); 179 | fwrite(&boxes[i].section, sizeof(int32_t), 1, stdout); 180 | } 181 | 182 | return 0; 183 | } -------------------------------------------------------------------------------- /convert3_onnx.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import onnxruntime 3 | import torch 4 | import numpy as np 5 | import os 6 | import itertools 7 | 8 | from models.transformer import ModelDimensions, Transformer, TransformerEncoderPredictor, TransformerDecoderPredictor 9 | from util_func import feature_dim, modulo_list, calc_predid 10 | from const import encoder_add_dim, max_decoderlen, max_encoderlen, decoder_SOT, decoder_EOT, decoder_MSK 11 | 12 | def convert3(): 13 | # import logging 14 | # logging.basicConfig(filename='debug.log', level=logging.DEBUG) 15 | 16 | if os.path.exists('model3.pt'): 17 | data = torch.load('model3.pt', map_location="cpu", weights_only=True) 18 | config = ModelDimensions(**data['config']) 19 | model = Transformer(**config.__dict__) 20 | model.load_state_dict(data['model_state_dict']) 21 | print('loaded') 22 | else: 23 | config = ModelDimensions() 24 | model = Transformer(**config.__dict__) 25 | print('empty model') 26 | model.eval() 27 | encoder = TransformerEncoderPredictor(model.encoder) 28 | decoder = TransformerDecoderPredictor(model.decoder) 29 | encoder.eval() 30 | decoder.eval() 31 | 32 | ######################################################################### 33 | print('encoder') 34 | 35 | encoder_dim = feature_dim+encoder_add_dim 36 | encoder_input = torch.rand(1, max_encoderlen, encoder_dim) 37 | key_mask = torch.all(encoder_input == 0, dim=-1) 38 | key_mask = torch.where(key_mask[:,None,None,:], float("-inf"), 0) 39 | torch.onnx.export(encoder, 40 | (encoder_input, key_mask), 41 | "TransformerEncoder.onnx", 42 | verbose=True, 43 | input_names=['encoder_input', 'key_mask'], 44 | output_names=['encoder_output']) 45 | 46 | ############################################################################ 47 | print('decoder') 48 | 49 | encoder_output = torch.rand(1, max_encoderlen, config.embed_dim) 50 | decoder_input = torch.randint(0, 
1000, size=(1, max_decoderlen), dtype=torch.long) 51 | torch.onnx.export(decoder, 52 | (encoder_output, decoder_input, key_mask), 53 | "TransformerDecoder.onnx", 54 | verbose=True, 55 | input_names=['encoder_output', 'decoder_input', 'key_mask'], 56 | output_names=['modulo_%d'%m for m in modulo_list]) 57 | 58 | def test3(): 59 | print('load') 60 | onnx_encoder = onnxruntime.InferenceSession("TransformerEncoder.onnx") 61 | onnx_decoder = onnxruntime.InferenceSession("TransformerDecoder.onnx") 62 | 63 | print(' [ encoder ] ') 64 | print('input:') 65 | for session_input in onnx_encoder.get_inputs(): 66 | print(session_input.name, session_input.shape) 67 | print('output:') 68 | for session_output in onnx_encoder.get_outputs(): 69 | print(session_output.name, session_output.shape) 70 | 71 | print(' [ decoder ] ') 72 | print('input:') 73 | for session_input in onnx_decoder.get_inputs(): 74 | print(session_input.name, session_input.shape) 75 | print('output:') 76 | for session_output in onnx_decoder.get_outputs(): 77 | print(session_output.name, session_output.shape) 78 | 79 | rng = np.random.default_rng() 80 | train_data3 = 'train_data3' 81 | 82 | encoder_dim = feature_dim+encoder_add_dim 83 | encoder_input = np.zeros(shape=(1, max_encoderlen, encoder_dim), dtype=np.float32) 84 | SP_token = np.zeros([encoder_dim], dtype=np.float32) 85 | SP_token[0:feature_dim:2] = 5 86 | SP_token[1:feature_dim:2] = -5 87 | encoder_input[0,0,:] = SP_token 88 | with np.load(os.path.join(train_data3, 'features.npz')) as data: 89 | for i,c in enumerate('test'): 90 | code = ord(c) 91 | value = data['hori_%d'%code] 92 | feat = rng.choice(value, replace=False) 93 | encoder_input[0,i+1,:feature_dim] = feat 94 | encoder_input[0,i+2,:] = -SP_token 95 | 96 | key_mask = np.where((encoder_input == 0).all(axis=-1)[:,None,None,:], float("-inf"), 0).astype(np.float32) 97 | print('encoder') 98 | encoder_output, = onnx_encoder.run(['encoder_output'], {'encoder_input': encoder_input, 'key_mask': key_mask}) 99 | 100 | print('decoder') 101 | decoder_input = np.zeros(shape=(1, max_decoderlen), dtype=np.int64) 102 | decoder_input[0,0] = decoder_SOT 103 | decoder_input[0,1:] = decoder_MSK 104 | rep_count = 8 105 | for k in range(rep_count): 106 | output = onnx_decoder.run(['modulo_%d'%m for m in modulo_list], { 107 | 'encoder_output': encoder_output, 108 | 'decoder_input': decoder_input, 109 | 'key_mask': key_mask, 110 | }) 111 | 112 | listp = [] 113 | listi = [] 114 | for pred_p1 in output: 115 | topi = np.argpartition(-pred_p1, 4, axis=-1)[...,:4] 116 | topp = np.take_along_axis(pred_p1, topi, axis=-1) 117 | listp.append(np.transpose(topp, (2,0,1))) 118 | listi.append(np.transpose(topi, (2,0,1))) 119 | 120 | pred_ids = np.stack([np.stack(x) for x in itertools.product(*listi)]) 121 | pred_p = np.stack([np.stack(x) for x in itertools.product(*listp)]) 122 | pred_ids = np.transpose(pred_ids, (1,0,2,3)) 123 | pred_p = np.transpose(pred_p, (1,0,2,3)) 124 | pred_p = np.exp(np.mean(np.log(np.maximum(pred_p, 1e-10)), axis=0)) 125 | decoder_output = calc_predid(*pred_ids) 126 | pred_p[decoder_output > 0x3FFFF] = 0 127 | maxi = np.argmax(pred_p, axis=0) 128 | decoder_output = np.take_along_axis(decoder_output, maxi[None,...], axis=0)[0] 129 | pred_p = np.take_along_axis(pred_p, maxi[None,...], axis=0)[0] 130 | if k > 0 and np.all(pred_p[decoder_output > 0] > 0.99): 131 | print(f'---[{k} early stop]---') 132 | break 133 | if k < rep_count-1: 134 | decoder_input[:,1:] = np.where(pred_p < 1/rep_count*k, decoder_MSK, decoder_output)[:,:-1] 135 | 
print(decoder_output[0]) 136 | predstr = '' 137 | for p in decoder_output[0]: 138 | if p == 0 or p == decoder_EOT: 139 | break 140 | if p < 0x3FFFF: 141 | predstr += chr(p) 142 | else: 143 | predstr += '\uFFFD' 144 | try: 145 | print(predstr) 146 | except UnicodeEncodeError: 147 | pass 148 | 149 | def test32(): 150 | from models.transformer import TransformerPredictor 151 | 152 | if torch.cuda.is_available(): 153 | device = 'cuda' 154 | elif torch.backends.mps.is_available(): 155 | device = 'mps' 156 | else: 157 | device = 'cpu' 158 | device = torch.device(device) 159 | rng = np.random.default_rng() 160 | 161 | if os.path.exists('model3.pt'): 162 | data = torch.load('model3.pt', map_location="cpu", weights_only=True) 163 | config = ModelDimensions(**data['config']) 164 | model = Transformer(**config.__dict__) 165 | model.load_state_dict(data['model_state_dict']) 166 | else: 167 | config = ModelDimensions() 168 | model = Transformer(**config.__dict__) 169 | model2 = TransformerPredictor(model.encoder, model.decoder) 170 | model2.to(device) 171 | model2.eval() 172 | 173 | rng = np.random.default_rng() 174 | train_data3 = 'train_data3' 175 | 176 | encoder_dim = feature_dim+encoder_add_dim 177 | encoder_input = np.zeros(shape=(1, max_encoderlen, encoder_dim), dtype=np.float32) 178 | SP_token = np.zeros([encoder_dim], dtype=np.float32) 179 | SP_token[0:feature_dim:2] = 5 180 | SP_token[1:feature_dim:2] = -5 181 | with np.load(os.path.join(train_data3, 'features.npz')) as data: 182 | encoder_input[0,0,:] = SP_token 183 | for i,c in enumerate('test'): 184 | code = ord(c) 185 | value = data['hori_%d'%code] 186 | feat = rng.choice(value, replace=False) 187 | encoder_input[0,i+1,:feature_dim] = feat 188 | encoder_input[0,i+2,:] = -SP_token 189 | 190 | encoder_input = torch.tensor(encoder_input).to(device) 191 | pred = model2(encoder_input).squeeze(0).cpu().numpy() 192 | predstr = '' 193 | for p in pred: 194 | if p == 0 or p == 2: 195 | break 196 | if p < 0x3FFFF: 197 | predstr += chr(p) 198 | else: 199 | predstr += '\uFFFD' 200 | print('------------------') 201 | try: 202 | print(predstr) 203 | except UnicodeEncodeError: 204 | pass 205 | print('==================') 206 | print(pred) 207 | 208 | if __name__ == '__main__': 209 | convert3() 210 | test3() 211 | # test32() 212 | -------------------------------------------------------------------------------- /fine_image/fix_line_image1.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | import sys 5 | from PIL import Image, ImageDraw 6 | 7 | try: 8 | from pillow_heif import register_heif_opener 9 | register_heif_opener() 10 | except ImportError: 11 | pass 12 | 13 | import matplotlib.pyplot as plt 14 | from matplotlib.font_manager import FontProperties 15 | from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg 16 | import tkinter as tk 17 | 18 | linetype = 'line' 19 | if len(sys.argv) < 2: 20 | print(sys.argv[0],'target.png') 21 | exit(1) 22 | 23 | dpi = 72 24 | target_file = sys.argv[1] 25 | if len(sys.argv) > 2: 26 | for arg in sys.argv[2:]: 27 | if arg.startswith('dpi'): 28 | dpi = int(arg[3:]) 29 | print('dpi:', dpi) 30 | else: 31 | linetype = arg 32 | 33 | class Application(tk.Frame): 34 | def __init__(self, root, target_file): 35 | super().__init__(root) 36 | self.target_file = target_file 37 | root.title(target_file) 38 | im0 = Image.open(target_file).convert('RGB') 39 | self.im0 = im0.resize((im0.width // 4, im0.height // 4), 
resample=Image.Resampling.BILINEAR) 40 | 41 | if linetype == 'line': 42 | linesfile = target_file + '.lines.png' 43 | elif linetype == 'seps': 44 | linesfile = target_file + '.seps.png' 45 | self.linesfile = linesfile 46 | self.lines_all = Image.open(linesfile) 47 | self.lines_draw = ImageDraw.Draw(self.lines_all) 48 | self.v_line = False 49 | self.h_line = False 50 | 51 | self.gen_mpl_graph(root) 52 | 53 | self.grid_rowconfigure(0, weight=1) 54 | self.grid_columnconfigure(0, weight=1) 55 | 56 | def save(self): 57 | self.lines_all.save(self.linesfile) 58 | 59 | def gen_mpl_graph(self, root): 60 | self.cid = None 61 | self.newpoints = [] 62 | frame1 = tk.Frame(root) 63 | frame2 = tk.Frame(root) 64 | 65 | def onclick1(event): 66 | x = event.xdata 67 | y = event.ydata 68 | if x is None or y is None: 69 | return 70 | self.newpoints.append((x,y)) 71 | self.ax_im.plot(x, y, 'r.') 72 | 73 | if len(self.newpoints) >= 2: 74 | self.fig.canvas.mpl_disconnect(self.cid) 75 | self.cid = None 76 | self.btn0.configure(text='new line') 77 | self.btn2.configure(text='new vline') 78 | self.btn3.configure(text='new hline') 79 | 80 | if self.h_line: 81 | self.lines_draw.line((self.newpoints[0][0],self.newpoints[0][1],self.newpoints[1][0],self.newpoints[0][1]), fill=255, width=3) 82 | elif self.v_line: 83 | self.lines_draw.line((self.newpoints[0][0],self.newpoints[0][1],self.newpoints[0][0],self.newpoints[1][1]), fill=255, width=3) 84 | else: 85 | self.lines_draw.line(self.newpoints, fill=255, width=3) 86 | self.plot_image() 87 | else: 88 | self.fig.canvas.draw_idle() 89 | 90 | def onclick2(event): 91 | x = event.xdata 92 | y = event.ydata 93 | if x is None or y is None: 94 | return 95 | self.newpoints.append((x,y)) 96 | self.ax_im.plot(x, y, 'y.') 97 | 98 | if len(self.newpoints) >= 2: 99 | self.fig.canvas.mpl_disconnect(self.cid) 100 | self.cid = None 101 | self.btn1.configure(text='remove area') 102 | 103 | self.lines_draw.rectangle(self.newpoints, fill=0) 104 | self.plot_image() 105 | else: 106 | self.fig.canvas.draw_idle() 107 | 108 | def btn_click0(): 109 | if self.cid is None: 110 | self.newpoints = [] 111 | self.v_line = False 112 | self.h_line = False 113 | self.cid = self.fig.canvas.mpl_connect('button_press_event', onclick1) 114 | self.btn0.configure(text='') 115 | else: 116 | self.fig.canvas.mpl_disconnect(self.cid) 117 | self.v_line = False 118 | self.h_line = False 119 | self.cid = None 120 | self.btn0.configure(text='new line') 121 | 122 | def btn_click1(): 123 | if self.cid is None: 124 | self.newpoints = [] 125 | self.cid = self.fig.canvas.mpl_connect('button_press_event', onclick2) 126 | self.btn1.configure(text='') 127 | else: 128 | self.fig.canvas.mpl_disconnect(self.cid) 129 | self.cid = None 130 | self.btn1.configure(text='remove area') 131 | 132 | def btn_click2(): 133 | if self.cid is None: 134 | self.newpoints = [] 135 | self.v_line = True 136 | self.h_line = False 137 | self.cid = self.fig.canvas.mpl_connect('button_press_event', onclick1) 138 | self.btn2.configure(text='') 139 | else: 140 | self.fig.canvas.mpl_disconnect(self.cid) 141 | self.v_line = False 142 | self.h_line = False 143 | self.cid = None 144 | self.btn2.configure(text='new vline') 145 | 146 | def btn_click3(): 147 | if self.cid is None: 148 | self.newpoints = [] 149 | self.v_line = False 150 | self.h_line = True 151 | self.cid = self.fig.canvas.mpl_connect('button_press_event', onclick1) 152 | self.btn3.configure(text='') 153 | else: 154 | self.fig.canvas.mpl_disconnect(self.cid) 155 | self.v_line = False 156 | 
self.h_line = False 157 | self.cid = None 158 | self.btn3.configure(text='new hline') 159 | 160 | self.btn0 = tk.Button(frame2, text='new line', command=btn_click0) 161 | self.btn0.pack(side=tk.LEFT) 162 | 163 | self.btn1 = tk.Button(frame2, text='remove area', command=btn_click1) 164 | self.btn1.pack(side=tk.LEFT) 165 | 166 | self.btn2 = tk.Button(frame2, text='new vline', command=btn_click2) 167 | self.btn2.pack(side=tk.LEFT) 168 | 169 | self.btn3 = tk.Button(frame2, text='new hline', command=btn_click3) 170 | self.btn3.pack(side=tk.LEFT) 171 | 172 | frame2.pack(side=tk.BOTTOM, fill=tk.X) 173 | frame1.pack(expand=True, fill=tk.BOTH) 174 | self.canvas = tk.Canvas(frame1) 175 | frame=tk.Frame(self.canvas) 176 | 177 | self.vbar = tk.Scrollbar(self.canvas, orient=tk.VERTICAL, command=self.canvas.yview) 178 | self.hbar = tk.Scrollbar(self.canvas, orient=tk.HORIZONTAL, command=self.canvas.xview) 179 | 180 | self.canvas.create_window((0, 0), window=frame, anchor="nw") 181 | self.canvas.configure(xscrollcommand=self.hbar.set, yscrollcommand=self.vbar.set) 182 | self.canvas.configure(xscrollincrement='1p',yscrollincrement='1p') 183 | frame.bind( 184 | "", 185 | lambda e: self.canvas.configure(scrollregion=self.canvas.bbox("all"))) 186 | root.bind('', lambda e: self.canvas.yview_scroll(-e.delta, 'units')) 187 | root.bind('', lambda e: self.canvas.xview_scroll(-e.delta, 'units')) 188 | # root.bind("", self.move_start) 189 | # root.bind("", self.move_move) 190 | 191 | self.fig = plt.figure(figsize=(self.im0.width/dpi, self.im0.height/dpi)) 192 | self.fig.subplots_adjust(left=0, right=1, bottom=0, top=1) 193 | self.ax_im = self.fig.add_subplot(111) 194 | self.im1 = FigureCanvasTkAgg(self.fig, frame) 195 | 196 | self.plot_image() 197 | 198 | self.vbar.pack(side=tk.RIGHT, fill=tk.Y) 199 | self.hbar.pack(side=tk.BOTTOM, fill=tk.X) 200 | self.canvas.pack(expand=True, fill=tk.BOTH) 201 | self.im1.get_tk_widget().pack(expand=1) 202 | 203 | # def move_start(self, event): 204 | # self.canvas.scan_mark(event.x, event.y) 205 | 206 | # def move_move(self, event): 207 | # self.canvas.scan_dragto(event.x, event.y, gain=1) 208 | 209 | def plot_image(self): 210 | self.ax_im.cla() 211 | self.ax_im.imshow(self.im0) 212 | self.ax_im.imshow(self.lines_all, cmap='gray', alpha=0.5) 213 | self.im1.draw_idle() 214 | 215 | root = tk.Tk() 216 | root.geometry('1400x800') 217 | app = Application(root, target_file) 218 | app.mainloop() 219 | 220 | app.save() 221 | -------------------------------------------------------------------------------- /fine_image/process_image4_coreml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import coremltools as ct 4 | 5 | import numpy as np 6 | import sys 7 | from PIL import Image 8 | import json 9 | import glob 10 | import subprocess 11 | 12 | import matplotlib.pyplot as plt 13 | 14 | try: 15 | from pillow_heif import register_heif_opener 16 | register_heif_opener() 17 | except ImportError: 18 | pass 19 | 20 | from util_func import width, height, scale, feature_dim, modulo_list, sigmoid 21 | 22 | if len(sys.argv) < 2: 23 | print(sys.argv[0],'target.png') 24 | exit(1) 25 | 26 | target_files = [] 27 | model_size = 'xl' 28 | resize = 1.0 29 | cutoff = 0.4 30 | for arg in sys.argv[1:]: 31 | target_files += glob.glob(arg) 32 | 33 | target_files = sorted(target_files) 34 | 35 | print('load') 36 | mlmodel_detector = ct.models.MLModel('TextDetector.mlpackage') 37 | 38 | def eval(ds, org_img, centers): 39 | print(org_img.shape) 40 
| print("test") 41 | 42 | glyphfeatures = np.zeros([centers.shape[0], feature_dim], dtype=np.float32) 43 | 44 | for n, inputs in enumerate(ds): 45 | print(n, '/', len(ds)) 46 | x_i = inputs['offsetx'] 47 | y_i = inputs['offsety'] 48 | x_s = width // scale 49 | y_s = height // scale 50 | 51 | input_image = Image.fromarray(inputs['input'], mode="RGB") 52 | 53 | output = mlmodel_detector.predict({'image': input_image}) 54 | features = output['feature'] 55 | 56 | x_min = int(x_s * 1 / 8) if x_i > 0 else 0 57 | x_max = int(x_s * 7 / 8) + 1 if x_i + width < org_img.shape[1] else x_s 58 | y_min = int(y_s * 1 / 8) if y_i > 0 else 0 59 | y_max = int(y_s * 7 / 8) + 1 if y_i + height < org_img.shape[0] else y_s 60 | 61 | target = np.where(np.logical_and(np.logical_and(x_i + x_min * scale < centers[:,0], centers[:,0] < x_i + x_max * scale), 62 | np.logical_and(y_i + y_min * scale < centers[:,1], centers[:,1] < y_i + y_max * scale)))[0] 63 | for i in target: 64 | xi = int((centers[i,0] - x_i) / scale) 65 | yi = int((centers[i,1] - y_i) / scale) 66 | glyphfeatures[i,:] = features[0,:,yi,xi] 67 | 68 | return glyphfeatures.astype(np.float16) 69 | 70 | stepx = width * 3 // 4 71 | stepy = height * 3 // 4 72 | 73 | for target_file in target_files: 74 | print(target_file) 75 | 76 | lines = np.asarray(Image.open(target_file+'.lines.png')).astype(np.float32) / 255 77 | seps = np.asarray(Image.open(target_file+'.seps.png')).astype(np.float32) / 255 78 | 79 | with open(target_file+'.json', 'r', encoding='utf-8') as file: 80 | data = json.load(file) 81 | textbox = data['textbox'] 82 | if len(textbox) == 0: 83 | print('empty') 84 | continue 85 | 86 | locations = [] 87 | for box in textbox: 88 | cx = box['cx'] 89 | cy = box['cy'] 90 | w = box['w'] 91 | h = box['h'] 92 | code1 = box['p_code1'] 93 | code2 = box['p_code2'] 94 | code4 = box['p_code4'] 95 | code8 = box['p_code8'] 96 | locations.append([cx,cy,w,h,code1,code2,code4,code8]) 97 | locations = np.array(locations, dtype=np.float32) 98 | 99 | print('construct data') 100 | h, w = lines.shape 101 | input_binary = int(0).to_bytes(4, 'little') 102 | input_binary += int(w).to_bytes(4, 'little') 103 | input_binary += int(h).to_bytes(4, 'little') 104 | input_binary += lines.tobytes() 105 | input_binary += seps.tobytes() 106 | input_binary += int(locations.shape[0]).to_bytes(4, 'little') 107 | input_binary += locations.tobytes() 108 | 109 | print('run') 110 | result = subprocess.run('textline_detect/linedetect', input=input_binary, stdout=subprocess.PIPE).stdout 111 | detected_boxes = [] 112 | p = 0 113 | max_block = 0 114 | count = int.from_bytes(result[p:p+4], byteorder='little') 115 | p += 4 116 | for i in range(count): 117 | id = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 118 | p += 4 119 | block = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 120 | max_block = max(max_block, block) 121 | p += 4 122 | idx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 123 | p += 4 124 | subidx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 125 | p += 4 126 | subtype = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 127 | p += 4 128 | pageidx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 129 | p += 4 130 | sectionidx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 131 | p += 4 132 | detected_boxes.append((id,block,idx,subidx,subtype,pageidx,sectionidx)) 133 | 134 | # im = Image.open(target_file).convert('RGB') 135 | 136 | # fig = plt.figure() 137 | # 
plt.imshow(im) 138 | # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) 139 | 140 | # cmap = plt.get_cmap('rainbow', max_block+1) 141 | # for id, block, idx, subidx, subtype in detected_boxes: 142 | # if id < 0: 143 | # continue 144 | # cx = locations[id, 0] 145 | # cy = locations[id, 1] 146 | # w = locations[id, 2] 147 | # h = locations[id, 3] 148 | 149 | # points = [ 150 | # [cx - w / 2, cy - h / 2], 151 | # [cx + w / 2, cy - h / 2], 152 | # [cx + w / 2, cy + h / 2], 153 | # [cx - w / 2, cy + h / 2], 154 | # [cx - w / 2, cy - h / 2], 155 | # ] 156 | # points = np.array(points) 157 | # plt.plot(points[:,0], points[:,1], color=cmap(block)) 158 | # if idx < 0: 159 | # t = '*' 160 | # else: 161 | # if subtype & 2+4 == 2+4: 162 | # points = [ 163 | # [cx - w / 2 + 1, cy - h / 2 + 1], 164 | # [cx + w / 2 - 1, cy - h / 2 + 1], 165 | # [cx + w / 2 - 1, cy + h / 2 - 1], 166 | # [cx - w / 2 + 1, cy + h / 2 - 1], 167 | # [cx - w / 2 + 1, cy - h / 2 + 1], 168 | # ] 169 | # points = np.array(points) 170 | # plt.plot(points[:,0], points[:,1], color='yellow') 171 | # t = '%d-r%d-%d'%(block, idx, subidx) 172 | # elif subtype & 2+4 == 2: 173 | # points = [ 174 | # [cx - w / 2 + 1, cy - h / 2 + 1], 175 | # [cx + w / 2 - 1, cy - h / 2 + 1], 176 | # [cx + w / 2 - 1, cy + h / 2 - 1], 177 | # [cx - w / 2 + 1, cy + h / 2 - 1], 178 | # [cx - w / 2 + 1, cy - h / 2 + 1], 179 | # ] 180 | # points = np.array(points) 181 | # plt.plot(points[:,0], points[:,1], color='blue') 182 | # t = '%d-b%d-%d'%(block, idx, subidx) 183 | # else: 184 | # t = '%d-%d-%d'%(block, idx, subidx) 185 | # if subtype & 8 == 8: 186 | # t += '+' 187 | # plt.text(cx - w/2, cy - h/2, t, color='black') 188 | # plt.show() 189 | # continue 190 | 191 | centers = [] 192 | boxlist = [] 193 | for id, block, idx, subidx, subtype, pageidx, sectionidx in detected_boxes: 194 | if id < 0: 195 | continue 196 | boxlist.append({ 197 | 'boxid': len(centers), 198 | 'blockid': block, 199 | 'lineid': idx, 200 | 'subidx': subidx, 201 | 'subtype': subtype, 202 | 'text': textbox[id].get('text', None), 203 | }) 204 | centers.append([locations[id,0], locations[id,1]]) 205 | centers = np.array(centers, dtype=np.float32) 206 | 207 | im0 = Image.open(target_file).convert('RGB') 208 | if resize != 1.0: 209 | im0 = im0.resize((int(im0.width * resize), int(im0.height * resize)), resample=Image.Resampling.BILINEAR) 210 | im0 = np.asarray(im0) 211 | 212 | padx = max(0, (width - im0.shape[1]) % stepx, width - im0.shape[1]) 213 | pady = max(0, (height - im0.shape[0]) % stepy, height - im0.shape[0]) 214 | im0 = np.pad(im0, [[0,pady],[0,padx],[0,0]], 'constant', constant_values=((255,255),(255,255),(255,255))) 215 | 216 | im = im0 217 | 218 | ds0 = [] 219 | for y in range(0, im0.shape[0] - height + 1, stepy): 220 | for x in range(0, im0.shape[1] - width + 1, stepx): 221 | ds0.append({ 222 | 'input': im[y:y+height,x:x+width,:], 223 | 'offsetx': x, 224 | 'offsety': y, 225 | }) 226 | 227 | glyph = eval(ds0, im, centers) 228 | np.save(target_file+'.npy', glyph) 229 | 230 | data['boxlist'] = boxlist 231 | with open(target_file+'.json', 'w', encoding='utf-8') as file: 232 | json.dump(data, file, indent=2, ensure_ascii=False) 233 | -------------------------------------------------------------------------------- /textline_detect/src/after_search.cpp: -------------------------------------------------------------------------------- 1 | #include "after_search.h" 2 | #include "ruby_search.h" 3 | #include "number_unbind.h" 4 | #include "make_block.h" 5 | #include "search_loop.h" 6 | 7 | 
#include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // 短いチェーンは方向を修正しておく 17 | void fix_shortchain( 18 | std::vector &boxes, 19 | const std::vector> &line_box_chain) 20 | { 21 | for(int chainid = 0; chainid < line_box_chain.size(); chainid++) { 22 | if(line_box_chain[chainid].size() < 3) { 23 | int id1 = line_box_chain[chainid].front(); 24 | int id2 = line_box_chain[chainid].back(); 25 | float diffx = fabs(boxes[id1].cx - boxes[id2].cx); 26 | float diffy = fabs(boxes[id1].cy - boxes[id2].cy); 27 | if(diffx > diffy) { 28 | // 横書き 29 | for(auto boxid: line_box_chain[chainid]) { 30 | boxes[boxid].direction = 0; 31 | } 32 | } 33 | else { 34 | // 縦書き 35 | for(auto boxid: line_box_chain[chainid]) { 36 | boxes[boxid].direction = M_PI_2; 37 | } 38 | } 39 | } 40 | } 41 | } 42 | 43 | // chain id を登録する 44 | void register_chainid( 45 | std::vector &boxes, 46 | const std::vector> &line_box_chain) 47 | { 48 | for(int chainid = 0; chainid < line_box_chain.size(); chainid++) { 49 | for(auto boxid: line_box_chain[chainid]) { 50 | boxes[boxid].idx = chainid; 51 | if (fabs(boxes[boxid].direction) < M_PI_4) { 52 | boxes[boxid].subtype &= ~1; 53 | } 54 | else { 55 | boxes[boxid].subtype |= 1; 56 | } 57 | } 58 | } 59 | } 60 | 61 | // 飛んでる番号があるので振り直す 62 | int renumber_chain( 63 | std::vector &boxes) 64 | { 65 | std::vector chain_remap; 66 | for(const auto &box: boxes) { 67 | if(box.idx < 0) continue; 68 | if(std::find(chain_remap.begin(), chain_remap.end(), box.idx) == chain_remap.end()) { 69 | chain_remap.push_back(box.idx); 70 | } 71 | } 72 | std::sort(chain_remap.begin(), chain_remap.end()); 73 | for(auto &box: boxes) { 74 | if(box.idx < 0) continue; 75 | int id = (int)std::distance(chain_remap.begin(), std::find(chain_remap.begin(), chain_remap.end(), box.idx)); 76 | box.idx = id; 77 | } 78 | return int(chain_remap.size()); 79 | } 80 | 81 | int chain_line_force( 82 | int id_max, 83 | std::vector &boxes) 84 | { 85 | if(chain_line_ratio <= 0) return id_max; 86 | 87 | std::vector> line_box_chain(id_max); 88 | for(const auto &box: boxes) { 89 | if(box.idx < 0) continue; 90 | line_box_chain[box.idx].push_back(-1); 91 | } 92 | for(const auto &box: boxes) { 93 | if(box.idx < 0) continue; 94 | line_box_chain[box.idx][box.subidx] = box.id; 95 | } 96 | 97 | for(auto it = line_box_chain.begin(); it != line_box_chain.end();) { 98 | float direction = boxes[it->front()].direction; 99 | float ax1 = boxes[it->front()].cx; 100 | float ay1 = boxes[it->front()].cy; 101 | float ax2 = boxes[it->back()].cx; 102 | float ay2 = boxes[it->back()].cy; 103 | for(auto bit = it->rbegin(); bit != it->rend(); ++bit) { 104 | if((boxes[*bit].subtype & (2+4)) == 2+4) { 105 | continue; 106 | } 107 | ax2 = boxes[*bit].cx; 108 | ay2 = boxes[*bit].cy; 109 | break; 110 | } 111 | float s1 = 0; 112 | for(auto i: *it) { 113 | s1 = std::max(s1, std::max(boxes[i].w, boxes[i].h)); 114 | } 115 | std::vector>::iterator, float>> dist_map; 116 | for(auto it2 = line_box_chain.begin(); it2 != line_box_chain.end(); ++it2) { 117 | if(it == it2) continue; 118 | if(it2->size() > 2) { 119 | float direction2 = boxes[it2->front()].direction; 120 | if(fabs(direction) < M_PI_4 && fabs(direction2) > M_PI_4) continue; 121 | if(fabs(direction) > M_PI_4 && fabs(direction2) < M_PI_4) continue; 122 | } 123 | else if(it2->size() > 1) { 124 | // 横書きの2文字は、縦中横の可能性があるので通す 125 | float direction2 = boxes[it2->front()].direction; 126 | if(fabs(direction) < M_PI_4 && fabs(direction2) > M_PI_4) continue; 127 | } 128 | 129 | float 
bx1 = boxes[it2->front()].cx; 130 | float by1 = boxes[it2->front()].cy; 131 | float bx2 = boxes[it2->back()].cx; 132 | float by2 = boxes[it2->back()].cy; 133 | for(auto bit = it2->rbegin(); bit != it2->rend(); ++bit) { 134 | if((boxes[*bit].subtype & (2+4)) == 2+4) { 135 | continue; 136 | } 137 | bx2 = boxes[*bit].cx; 138 | by2 = boxes[*bit].cy; 139 | break; 140 | } 141 | 142 | if(fabs(direction) < M_PI_4) { 143 | // 横書き 144 | if(abs(ay1 - by2) < s1 && ax1 > bx2 && ax1 - bx2 < s1 * chain_line_ratio) { 145 | // b -> a 146 | dist_map.emplace_back(it2, ax1 - bx2); 147 | } 148 | if(abs(ay2 - by1) < s1 && ax2 > bx1 && ax2 - bx1 < s1 * chain_line_ratio) { 149 | // a -> b 150 | dist_map.emplace_back(it2, bx1 - ax2); 151 | } 152 | } 153 | else { 154 | // 縦書き 155 | if(abs(ax1 - bx2) < s1 && ay1 > by2 && ay1 - by2 < s1 * chain_line_ratio) { 156 | // b -> a 157 | dist_map.emplace_back(it2, ay1 - by2); 158 | } 159 | if(abs(ax2 - bx1) < s1 && ay2 > by1 && ay2 - by1 < s1 * chain_line_ratio) { 160 | // a -> b 161 | dist_map.emplace_back(it2, by1 - ay2); 162 | } 163 | } 164 | } 165 | std::sort(dist_map.begin(), dist_map.end(), [](const auto a, const auto b){ 166 | return fabs(a.second) < fabs(b.second); 167 | }); 168 | if(dist_map.empty()) { 169 | ++it; 170 | continue; 171 | } 172 | auto it2 = dist_map.front().first; 173 | auto d = dist_map.front().second; 174 | if(d < 0) { 175 | // a -> b 176 | std::copy(it2->begin(), it2->end(), std::back_inserter(*it)); 177 | boxes[it2->front()].subtype |= 8 + 512; 178 | if(fabs(direction) < M_PI_4) { 179 | for(auto i: *it) { 180 | boxes[i].subtype &= ~1; 181 | } 182 | } 183 | else { 184 | for(auto i: *it) { 185 | boxes[i].subtype |= 1; 186 | } 187 | } 188 | auto idx1 = std::distance(line_box_chain.begin(), it); 189 | auto idx2 = std::distance(line_box_chain.begin(), it2); 190 | line_box_chain.erase(it2); 191 | if(idx1 < idx2) { 192 | it = line_box_chain.begin() + idx1 + 1; 193 | } 194 | else { 195 | it = line_box_chain.begin() + idx1; 196 | } 197 | } 198 | else { 199 | // b -> a 200 | std::copy(it->begin(), it->end(), std::back_inserter(*it2)); 201 | boxes[it->front()].subtype |= 8 + 512; 202 | if(fabs(direction) < M_PI_4) { 203 | for(auto i: *it2) { 204 | boxes[i].subtype &= ~1; 205 | } 206 | } 207 | else { 208 | for(auto i: *it2) { 209 | boxes[i].subtype |= 1; 210 | } 211 | } 212 | it = line_box_chain.erase(it); 213 | } 214 | } 215 | 216 | id_max = (int)line_box_chain.size(); 217 | for(int lineid = 0; lineid < id_max; lineid++) { 218 | for(int subid = 0; subid < line_box_chain[lineid].size(); subid++) { 219 | int boxid = line_box_chain[lineid][subid]; 220 | boxes[boxid].idx = lineid; 221 | boxes[boxid].subidx = subid; 222 | } 223 | } 224 | return id_max; 225 | } 226 | 227 | void after_search( 228 | std::vector &boxes, 229 | std::vector> &line_box_chain, 230 | const std::vector &lineblocker, 231 | const std::vector &idimage) 232 | { 233 | fprintf(stderr, "after_search\n"); 234 | 235 | fix_shortchain(boxes, line_box_chain); 236 | register_chainid(boxes, line_box_chain); 237 | 238 | // ルビの検索 239 | search_ruby(boxes, line_box_chain, lineblocker, idimage); 240 | 241 | int id_max = renumber_chain(boxes); 242 | 243 | id_max = number_unbind(boxes, lineblocker, idimage, id_max); 244 | 245 | id_max = chain_line_force(id_max, boxes); 246 | std::cerr << "id max " << id_max << std::endl; 247 | 248 | make_block(boxes, lineblocker); 249 | 250 | fprintf(stderr, "after_search done\n"); 251 | } 252 | -------------------------------------------------------------------------------- 
/textline_detect/src/minpack/lmpar.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #define p1 1.0e-1 4 | #define p001 1.0e-3 5 | #define zero 0.0 6 | #define minmag 2.2250738585072014e-308 7 | 8 | #define MAX(a,b) ((a) > (b) ? (a) : (b)) 9 | #define MIN(a,b) ((a) < (b) ? (a) : (b)) 10 | 11 | double enorm(int n, double *x); 12 | void qrsolv(int n,double *r,int ldr,int *ipvt,double *diag,double *qtb,double *x,double *sdiag); 13 | 14 | void lmpar(int n,double *r,int ldr,int *ipvt,double *diag,double *qtb,double delta,double &par,double *x,double *sdiag) 15 | { 16 | /* 17 | c ********** 18 | c 19 | c subroutine lmpar 20 | c 21 | c given an m by n matrix a, an n by n nonsingular diagonal 22 | c matrix d, an m-vector b, and a positive number delta, 23 | c the problem is to determine a value for the parameter 24 | c par such that if x solves the system 25 | c 26 | c a*x = b , sqrt(par)*d*x = 0 , 27 | c 28 | c in the least squares sense, and dxnorm is the euclidean 29 | c norm of d*x, then either par is zero and 30 | c 31 | c (dxnorm-delta) .le. 0.1*delta , 32 | c 33 | c or par is positive and 34 | c 35 | c abs(dxnorm-delta) .le. 0.1*delta . 36 | c 37 | c this subroutine completes the solution of the problem 38 | c if it is provided with the necessary information from the 39 | c qr factorization, with column pivoting, of a. that is, if 40 | c a*p = q*r, where p is a permutation matrix, q has orthogonal 41 | c columns, and r is an upper triangular matrix with diagonal 42 | c elements of nonincreasing magnitude, then lmpar expects 43 | c the full upper triangle of r, the permutation matrix p, 44 | c and the first n components of (q transpose)*b. on output 45 | c lmpar also provides an upper triangular matrix s such that 46 | c 47 | c t t t 48 | c p *(a *a + par*d*d)*p = s *s . 49 | c 50 | c s is employed within lmpar and may be of separate interest. 51 | c 52 | c only a few iterations are generally needed for convergence 53 | c of the algorithm. if, however, the limit of 10 iterations 54 | c is reached, then the output par will contain the best 55 | c value obtained so far. 56 | c 57 | c the subroutine statement is 58 | c 59 | c subroutine lmpar(n,r,ldr,ipvt,diag,qtb,delta,par,x,sdiag, 60 | c wa1,wa2) 61 | c 62 | c where 63 | c 64 | c n is a positive integer input variable set to the order of r. 65 | c 66 | c r is an n by n array. on input the full upper triangle 67 | c must contain the full upper triangle of the matrix r. 68 | c on output the full upper triangle is unaltered, and the 69 | c strict lower triangle contains the strict upper triangle 70 | c (transposed) of the upper triangular matrix s. 71 | c 72 | c ldr is a positive integer input variable not less than n 73 | c which specifies the leading dimension of the array r. 74 | c 75 | c ipvt is an integer input array of length n which defines the 76 | c permutation matrix p such that a*p = q*r. column j of p 77 | c is column ipvt(j) of the identity matrix. 78 | c 79 | c diag is an input array of length n which must contain the 80 | c diagonal elements of the matrix d. 81 | c 82 | c qtb is an input array of length n which must contain the first 83 | c n elements of the vector (q transpose)*b. 84 | c 85 | c delta is a positive input variable which specifies an upper 86 | c bound on the euclidean norm of d*x. 87 | c 88 | c par is a nonnegative variable. on input par contains an 89 | c initial estimate of the levenberg-marquardt parameter. 90 | c on output par contains the final estimate. 
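c
c note: the iteration below is a safeguarded newton-type root
c search for the scalar equation dxnorm(par) - delta = 0,
c where dxnorm(par) is the euclidean norm of d*x(par); the
c bounds parl and paru bracket the root throughout.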
91 | c 92 | c x is an output array of length n which contains the least 93 | c squares solution of the system a*x = b, sqrt(par)*d*x = 0, 94 | c for the output par. 95 | c 96 | c sdiag is an output array of length n which contains the 97 | c diagonal elements of the upper triangular matrix s. 98 | c 99 | c wa1 and wa2 are work arrays of length n. 100 | c 101 | c subprograms called 102 | c 103 | c minpack-supplied ... dpmpar,enorm,qrsolv 104 | c 105 | c fortran-supplied ... dabs,dmax1,dmin1,dsqrt 106 | c 107 | c argonne national laboratory. minpack project. march 1980. 108 | c burton s. garbow, kenneth e. hillstrom, jorge j. more 109 | c 110 | c ********** 111 | */ 112 | double *wa1 = new double[n]; 113 | double *wa2 = new double[n]; 114 | /* 115 | c 116 | c dwarf is the smallest positive magnitude. 117 | c 118 | */ 119 | double dwarf = minmag; 120 | /* 121 | c 122 | c compute and store in x the gauss-newton direction. if the 123 | c jacobian is rank-deficient, obtain a least squares solution. 124 | c 125 | */ 126 | int nsing = n; 127 | for(int j = 0; j < n; j++) { 128 | wa1[j] = qtb[j]; 129 | if (r[j+j*ldr] == zero && nsing == n) nsing = j - 1; 130 | if (nsing < n) wa1[j] = zero; 131 | } 132 | for(int k = 1; k <= nsing; k++) { 133 | int j = nsing - k; 134 | wa1[j] /= r[j+j*ldr]; 135 | 136 | double temp = wa1[j]; 137 | for(int i = 0; i < j; i++) { 138 | wa1[i] -= r[i+j*ldr]*temp; 139 | } 140 | } 141 | for(int j = 0; j < n; j++) { 142 | x[ipvt[j]] = wa1[j]; 143 | } 144 | /* 145 | c 146 | c initialize the iteration counter. 147 | c evaluate the function at the origin, and test 148 | c for acceptance of the gauss-newton direction. 149 | c 150 | */ 151 | int iter = 0; 152 | for(int j = 0; j < n; j++) { 153 | wa2[j] = diag[j]*x[j]; 154 | } 155 | double dxnorm = enorm(n,wa2); 156 | double fp = dxnorm - delta; 157 | if (fp <= p1*delta) { 158 | /* 159 | c 160 | c termination. 161 | c 162 | */ 163 | if (iter == 0) par = zero; 164 | delete[] wa1; 165 | delete[] wa2; 166 | return; 167 | } 168 | 169 | /* 170 | c 171 | c if the jacobian is not rank deficient, the newton 172 | c step provides a lower bound, parl, for the zero of 173 | c the function. otherwise set this bound to zero. 174 | c 175 | */ 176 | double parl = zero; 177 | if (nsing >= n) { 178 | for(int j = 0; j < n; j++) { 179 | wa1[j] = diag[ipvt[j]]*(wa2[ipvt[j]]/dxnorm); 180 | } 181 | for(int j = 0; j < n; j++) { 182 | double sum = zero; 183 | for(int i = 0; i < j; i++) { 184 | sum += r[i+j*ldr]*wa1[i]; 185 | } 186 | wa1[j] = (wa1[j] - sum)/r[j+j*ldr]; 187 | } 188 | double temp = enorm(n,wa1); 189 | parl = ((fp/delta)/temp)/temp; 190 | } 191 | /* 192 | c 193 | c calculate an upper bound, paru, for the zero of the function. 194 | c 195 | */ 196 | for(int j = 0; j < n; j++) { 197 | double sum = zero; 198 | for(int i = 0; i <= j; i++) { 199 | sum += r[i+j*ldr]*qtb[i]; 200 | } 201 | wa1[j] = sum/diag[ipvt[j]]; 202 | } 203 | double gnorm = enorm(n,wa1); 204 | double paru = gnorm/delta; 205 | if (paru == zero) paru = dwarf/MIN(delta,p1); 206 | /* 207 | c 208 | c if the input par lies outside of the interval (parl,paru), 209 | c set par to the closer endpoint. 210 | c 211 | */ 212 | par = MAX(par,parl); 213 | par = MIN(par,paru); 214 | if (par == zero) par = gnorm/dxnorm; 215 | /* 216 | c 217 | c beginning of an iteration. 218 | c 219 | */ 220 | while(true) { 221 | iter++; 222 | /* 223 | c 224 | c evaluate the function at the current value of par. 
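c
c (the function in question is fp = dxnorm - delta; the
c correction parc computed further below is a newton step for
c it, using the triangular factor s returned by qrsolv.)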
225 | c
226 | */
227 | if (par == zero) par = MAX(dwarf,p001*paru);
228 | double temp = sqrt(par);
229 | for(int j = 0; j < n; j++) {
230 | wa1[j] = temp*diag[j];
231 | }
232 | qrsolv(n,r,ldr,ipvt,wa1,qtb,x,sdiag);
233 | for(int j = 0; j < n; j++) {
234 | wa2[j] = diag[j]*x[j];
235 | }
236 | dxnorm = enorm(n,wa2);
237 | temp = fp;
238 | fp = dxnorm - delta;
239 | /*
240 | c
241 | c if the function is small enough, accept the current value
242 | c of par. also test for the exceptional cases where parl
243 | c is zero or the number of iterations has reached 10.
244 | c
245 | */
246 | if(fabs(fp) <= p1*delta || (parl == zero && fp <= temp && temp < zero) || iter == 10) {
247 | delete[] wa1;
248 | delete[] wa2;
249 | return;
250 | }
251 | /*
252 | c
253 | c compute the newton correction.
254 | c
255 | */
256 | for(int j = 0; j < n; j++) {
257 | wa1[j] = diag[ipvt[j]]*(wa2[ipvt[j]]/dxnorm);
258 | }
259 | for(int j = 0; j < n; j++) {
260 | wa1[j] /= sdiag[j];
261 | temp = wa1[j];
262 | for(int i = j+1; i < n; i++) {
263 | wa1[i] -= r[i+j*ldr]*temp;
264 | }
265 | }
266 | temp = enorm(n,wa1);
267 | double parc = ((fp/delta)/temp)/temp;
268 | /*
269 | c
270 | c depending on the sign of the function, update parl or paru.
271 | c
272 | */
273 | if (fp > zero) parl = MAX(parl,par);
274 | if (fp < zero) paru = MIN(paru,par);
275 | /*
276 | c
277 | c compute an improved estimate for par.
278 | c
279 | */
280 | par = MAX(parl,par+parc);
281 | /*
282 | c
283 | c end of an iteration.
284 | c
285 | */
286 | }
287 | }
--------------------------------------------------------------------------------
/loss_func.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from util_func import modulo_list
4 |
5 | # Multi-Loss Weighting with Coefficient of Variations
6 | # https://arxiv.org/abs/2009.01717
7 | # https://github.com/rickgroen/cov-weighting
8 | class CoVWeightingLoss(torch.nn.modules.Module):
9 | def __init__(self, *args, **kwargs) -> None:
10 | self.device = kwargs.pop('device', 'cpu')
11 | self.losses = kwargs.pop('losses', [])
12 | self.num_losses = len(self.losses)
13 | super().__init__(*args, **kwargs)
14 |
15 | self.current_iter = -1
16 | self.alphas = torch.zeros((self.num_losses,), requires_grad=False, dtype=torch.float32, device=self.device)
17 |
18 | # Initialize all running statistics at 0.
19 | self.running_mean_L = torch.zeros((self.num_losses,), requires_grad=False, dtype=torch.float32, device=self.device)
20 | self.running_mean_l = torch.zeros((self.num_losses,), requires_grad=False, dtype=torch.float32, device=self.device)
21 | self.running_S_l = torch.zeros((self.num_losses,), requires_grad=False, dtype=torch.float32, device=self.device)
22 | self.running_std_l = None
23 |
24 | def forward(self, losses):
25 | L = torch.stack([losses[key].detach().clone().requires_grad_(False).to(torch.float32) for key in self.losses])
26 |
27 | # If we are doing validation, we would like to return an unweighted loss to be
28 | # able to see whether we overfit on the training set.
29 | if not self.training:
30 | return torch.sum(L)
31 |
32 | # Increase the current iteration parameter.
33 | self.current_iter += 1
34 | # If we are at the zero-th iteration, set L0 to L. Else use the running mean.
35 | L0 = L.clone() if self.current_iter == 0 else self.running_mean_L
36 | # Compute the loss ratios for the current iteration given the current loss L.
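# The ratio l_i = L_i / E[L_i] below rescales each loss so the losses are
# comparable across tasks; the weights alpha_i are then the normalized
# coefficients of variation sigma(l_i) / mu(l_i), so losses whose ratio
# still fluctuates receive more weight (see the CoV-Weighting paper above).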
37 | l = L / L0
38 |
39 | # If we are in the first iteration, set alphas to all 1/num_losses
40 | if self.current_iter <= 1:
41 | self.alphas = torch.ones((self.num_losses,), requires_grad=False, dtype=torch.float32, device=self.device) / self.num_losses
42 | # Else, apply the loss weighting method.
43 | else:
44 | ls = self.running_std_l / self.running_mean_l
45 | self.alphas = ls / torch.sum(ls)
46 |
47 | # Apply Welford's algorithm to keep running means and variances of L and l,
48 | # but only while training the model.
49 | # 1. Compute the decay parameter for computing the running mean.
50 | if self.current_iter == 0:
51 | mean_param = 0.0
52 | else:
53 | mean_param = (1. - 1 / (self.current_iter + 1))
54 |
55 | # 2. Update the statistics for l
56 | x_l = l.detach().clone()
57 | new_mean_l = mean_param * self.running_mean_l + (1 - mean_param) * x_l
58 | self.running_S_l += (x_l - self.running_mean_l) * (x_l - new_mean_l)
59 | self.running_mean_l = new_mean_l
60 |
61 | # The variance is S / (t - 1), but we have current_iter = t - 1
62 | running_variance_l = self.running_S_l / (self.current_iter + 1)
63 | self.running_std_l = torch.sqrt(running_variance_l.clamp_min(1e-16))
64 |
65 | # 3. Update the statistics for L
66 | x_L = L.detach().clone()
67 | self.running_mean_L = mean_param * self.running_mean_L + (1 - mean_param) * x_L
68 |
69 | # Get the weighted losses and perform a standard back-pass.
70 | weighted_losses = [self.alphas[i] * losses[key].to(torch.float32) for i,key in enumerate(self.losses)]
71 | loss = sum(weighted_losses)
72 | return loss
73 |
74 | def heatmap_loss(true, logits):
75 | alpha = 2
76 | beta = 4
77 | pos_th = 1.0
78 |
79 | logits32 = logits.to(torch.float32)
80 | predict = torch.sigmoid(logits32)
81 |
82 | pos_mask = (true >= pos_th).to(torch.float32)
83 | neg_mask = (true < pos_th).to(torch.float32)
84 |
85 | neg_weights = torch.pow(1. - true, beta)
86 |
87 | pos_loss = - torch.nn.functional.logsigmoid(logits32) * torch.pow(1 - predict, alpha) * pos_mask
88 | neg_loss = (logits32 + torch.nn.functional.softplus(-logits32)) * torch.pow(predict, alpha) * neg_weights * neg_mask
89 |
90 | loss = (pos_loss + neg_loss).mean()
91 |
92 | return loss
93 |
94 | def loss_function(fmask, labelmap, idmap, heatmap, decoder_outputs):
95 | key_th1 = 0.85
96 | key_th2 = 0.85
97 | key_th3 = 0.99
98 |
99 | keylabel = labelmap[:,0,:,:]
100 | mask1 = keylabel > key_th1
101 | mask3 = torch.logical_and(keylabel.flatten()[fmask] > key_th3, idmap[:,0,:,:].flatten()[fmask] > 0)
102 | mask4 = torch.logical_and(keylabel.flatten()[fmask] == 1, idmap[:,0,:,:].flatten()[fmask] > 0)
103 |
104 | weight1 = torch.maximum(keylabel - key_th1, torch.tensor(0.)) / (1 - key_th1)
105 | weight1 = torch.masked_select(weight1, mask1)
106 | weight1_count = torch.maximum(torch.tensor(1.), weight1.sum())
107 | weight2 = torch.maximum(keylabel - key_th2, torch.tensor(0.)) / (1 - key_th2)
108 | weight3 = torch.maximum(keylabel.flatten()[fmask] - key_th3, torch.tensor(0.)) / (1 - key_th3)
109 | weight3 = torch.masked_select(weight3, mask3)
110 | weight3_count = torch.maximum(torch.tensor(1.), weight3.sum())
111 |
112 | keymap_loss = heatmap_loss(true=keylabel, logits=heatmap[:,0,:,:]) * 10.
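# (heatmap_loss above is a CenterNet-style penalty-reduced focal loss; its
# negative-class term uses logits + softplus(-logits), which equals
# -log(1 - sigmoid(logits)) in numerically stable form.)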
113 | 114 | huber = torch.nn.HuberLoss(reduction='none') 115 | xsize_loss = huber(torch.masked_select(heatmap[:,1,:,:], mask1), torch.masked_select(labelmap[:,1,:,:], mask1)) 116 | ysize_loss = huber(torch.masked_select(heatmap[:,2,:,:], mask1), torch.masked_select(labelmap[:,2,:,:], mask1)) 117 | size_loss = (xsize_loss + ysize_loss) * weight1 118 | size_loss = size_loss.sum() / weight1_count 119 | 120 | textline_loss = torch.nn.functional.binary_cross_entropy_with_logits(heatmap[:,3,:,:], labelmap[:,3,:,:]) 121 | separator_loss = torch.nn.functional.binary_cross_entropy_with_logits(heatmap[:,4,:,:], labelmap[:,4,:,:]) 122 | 123 | code_losses = {} 124 | for i in range(4): 125 | label_map = ((idmap[:,1,:,:] & 2**(i)) > 0).to(torch.float32) 126 | predict_map = heatmap[:,5+i,:,:] 127 | weight = torch.ones_like(label_map) + label_map * weight2 + weight2 128 | code_loss = torch.nn.functional.binary_cross_entropy_with_logits(predict_map, label_map, weight=weight) 129 | code_losses['code%d_loss'%2**(i)] = code_loss 130 | 131 | target_id = idmap[:,0,:,:].flatten()[fmask] 132 | target_ids = [] 133 | for modulo in modulo_list: 134 | target_id1 = target_id % modulo 135 | target_ids.append(target_id1) 136 | 137 | id_loss = 0. 138 | for target_id1, decoder_id1 in zip(target_ids, decoder_outputs): 139 | target_id1 = torch.masked_select(target_id1, mask3) 140 | decoder_id1 = decoder_id1[mask3,:] 141 | id1_loss = torch.nn.functional.cross_entropy(decoder_id1, target_id1, reduction='none') 142 | id1_loss = (id1_loss * weight3).sum() / weight3_count 143 | id_loss += id1_loss 144 | 145 | pred_ids = [] 146 | for decoder_id1 in decoder_outputs: 147 | pred_id1 = torch.argmax(decoder_id1[mask4,:], dim=-1) 148 | pred_ids.append(pred_id1) 149 | 150 | target_id = torch.masked_select(target_id, mask4) 151 | target_ids = [] 152 | for modulo in modulo_list: 153 | target_id1 = target_id % modulo 154 | target_ids.append(target_id1) 155 | 156 | correct = torch.zeros_like(pred_ids[0]) 157 | for p,t in zip(pred_ids,target_ids): 158 | correct += p == t 159 | 160 | total = torch.ones_like(correct).sum() 161 | correct = (correct == 3).sum() 162 | 163 | loss = keymap_loss + size_loss + textline_loss + separator_loss + id_loss 164 | for c_loss in code_losses.values(): 165 | loss += c_loss 166 | 167 | return { 168 | 'loss': loss, 169 | 'keymap_loss': keymap_loss, 170 | 'size_loss': size_loss, 171 | 'textline_loss': textline_loss, 172 | 'separator_loss': separator_loss, 173 | 'id_loss': id_loss, 174 | **code_losses, 175 | 'correct': correct, 176 | 'total': total, 177 | } 178 | 179 | def loss_function3(outputs, labelcode, mask): 180 | target_ids = [] 181 | for modulo in modulo_list: 182 | target_id1 = labelcode % modulo 183 | target_ids.append(target_id1) 184 | 185 | loss = 0. 
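# Each label code is supervised through its residues modulo the values in
# modulo_list; at decode time the residue predictions are recombined
# (see calc_predid in util_func) to recover the original id.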
186 | for target_id1, decoder_id1 in zip(target_ids, outputs): 187 | id1_loss = torch.nn.functional.cross_entropy(decoder_id1.permute(0,2,1), target_id1, reduction='none') 188 | loss += torch.masked_select(id1_loss, mask).mean() 189 | 190 | pred_ids = [] 191 | for decoder_id1 in outputs: 192 | pred_id1 = torch.argmax(decoder_id1, dim=-1) 193 | pred_id1 = torch.masked_select(pred_id1, mask) 194 | pred_ids.append(pred_id1) 195 | 196 | target_ids = [] 197 | for modulo in modulo_list: 198 | target_id1 = labelcode % modulo 199 | target_id1 = torch.masked_select(target_id1, mask) 200 | target_ids.append(target_id1) 201 | 202 | correct = torch.zeros_like(pred_ids[0]) 203 | for p,t in zip(pred_ids,target_ids): 204 | correct += p == t 205 | 206 | total = torch.ones_like(correct).sum() 207 | correct = (correct == 3).sum() 208 | 209 | return { 210 | 'loss': loss, 211 | 'correct': correct, 212 | 'total': total, 213 | } 214 | -------------------------------------------------------------------------------- /models/adamw_schedulefree.py: -------------------------------------------------------------------------------- 1 | # https://github.com/facebookresearch/schedule_free 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | from typing import Tuple, Union, Optional, Iterable, Dict, Callable, Any 8 | from typing_extensions import TypeAlias 9 | import torch 10 | import torch.optim 11 | try: 12 | from torch.optim.optimizer import ParamsT 13 | except ImportError: 14 | ParamsT : TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] 15 | import math 16 | 17 | class AdamWScheduleFree(torch.optim.Optimizer): 18 | r""" 19 | Schedule-Free AdamW 20 | As the name suggests, no scheduler is needed with this optimizer. 21 | To add warmup, rather than using a learning rate schedule you can just 22 | set the warmup_steps parameter. 23 | 24 | This optimizer requires that .train() and .eval() be called before the 25 | beginning of training and evaluation respectively. The optimizer should 26 | also be placed in eval mode when saving checkpoints. 27 | 28 | Arguments: 29 | params (iterable): 30 | Iterable of parameters to optimize or dicts defining 31 | parameter groups. 32 | lr (float): 33 | Learning rate parameter (default 0.0025) 34 | betas (Tuple[float, float], optional): coefficients used for computing 35 | running averages of gradient and its square (default: (0.9, 0.999)). 36 | eps (float): 37 | Term added to the denominator outside of the root operation to 38 | improve numerical stability. (default: 1e-8). 39 | weight_decay (float): 40 | Weight decay, i.e. a L2 penalty (default: 0). 41 | warmup_steps (int): Enables a linear learning rate warmup (default 0). 42 | r (float): Use polynomial weighting in the average 43 | with power r (default 0). 44 | weight_lr_power (float): During warmup, the weights in the average will 45 | be equal to lr raised to this power. Set to 0 for no weighting 46 | (default 2.0). 47 | foreach (bool): Use a foreach-backed implementation of the optimizer. 48 | Should be significantly faster, but will have higher peak memory 49 | usage (default True if supported in your PyTorch version). 
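    Example (illustrative sketch only; `model`, `criterion`, `loader`,
    `num_epochs`, and `validate` are placeholders, not part of this module):

        optimizer = AdamWScheduleFree(model.parameters(), lr=0.0025, warmup_steps=1000)
        for epoch in range(num_epochs):
            optimizer.train()   # switch to the y (training) weights
            for inputs, targets in loader:
                optimizer.zero_grad()
                loss = criterion(model(inputs), targets)
                loss.backward()
                optimizer.step()
            optimizer.eval()    # switch to the x (averaged) weights for validation/checkpointing
            validate(model)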
50 | """ 51 | def __init__(self, 52 | params: ParamsT, 53 | lr: Union[float, torch.Tensor] = 0.0025, 54 | betas: Tuple[float, float] = (0.9, 0.999), 55 | eps: float = 1e-8, 56 | weight_decay: float = 0, 57 | warmup_steps: int = 0, 58 | r: float = 0.0, 59 | weight_lr_power: float = 2.0, 60 | foreach: Optional[bool] = hasattr(torch, "_foreach_mul_") 61 | ): 62 | 63 | defaults = dict(lr=lr, 64 | betas=betas, 65 | eps=eps, 66 | r=r, 67 | k=0, 68 | warmup_steps=warmup_steps, 69 | train_mode=False, 70 | weight_sum=0.0, 71 | lr_max=-1.0, 72 | scheduled_lr=0.0, 73 | weight_lr_power=weight_lr_power, 74 | weight_decay=weight_decay, 75 | foreach=foreach) 76 | super().__init__(params, defaults) 77 | 78 | @torch.no_grad() 79 | def eval(self): 80 | for group in self.param_groups: 81 | train_mode = group['train_mode'] 82 | beta1, _ = group['betas'] 83 | if train_mode: 84 | for p in group['params']: 85 | state = self.state[p] 86 | if 'z' in state: 87 | # Set p to x 88 | p.lerp_(end=state['z'].to(p.device), weight=1-1/beta1) 89 | group['train_mode'] = False 90 | 91 | @torch.no_grad() 92 | def train(self): 93 | for group in self.param_groups: 94 | train_mode = group['train_mode'] 95 | beta1, _ = group['betas'] 96 | if not train_mode: 97 | for p in group['params']: 98 | state = self.state[p] 99 | if 'z' in state: 100 | # Set p to y 101 | p.lerp_(end=state['z'].to(p.device), weight=1-beta1) 102 | group['train_mode'] = True 103 | 104 | @torch.no_grad() 105 | def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: 106 | """Performs a single optimization step. 107 | 108 | Arguments: 109 | closure (callable, optional): A closure that reevaluates the model 110 | and returns the loss. 111 | """ 112 | if not self.param_groups[0]['train_mode']: 113 | raise Exception("Optimizer was not in train mode when step is called. " 114 | "Please insert .train() and .eval() calls on the " 115 | "optimizer. 
See documentation for details.") 116 | 117 | loss = None 118 | if closure is not None: 119 | with torch.enable_grad(): 120 | loss = closure() 121 | 122 | for group in self.param_groups: 123 | eps = group['eps'] 124 | beta1, beta2 = group['betas'] 125 | decay = group['weight_decay'] 126 | k = group['k'] 127 | r = group['r'] 128 | warmup_steps = group['warmup_steps'] 129 | weight_lr_power = group['weight_lr_power'] 130 | 131 | if k < warmup_steps: 132 | sched = (k+1) / warmup_steps 133 | else: 134 | sched = 1.0 135 | 136 | bias_correction2 = 1 - beta2 ** (k+1) 137 | lr = group['lr']*sched 138 | group['scheduled_lr'] = lr # For logging purposes 139 | 140 | lr_max = group['lr_max'] = max(lr, group['lr_max']) 141 | 142 | weight = ((k+1)**r) * (lr_max**weight_lr_power) 143 | weight_sum = group['weight_sum'] = group['weight_sum'] + weight 144 | 145 | try: 146 | ckp1 = weight/weight_sum 147 | except ZeroDivisionError: 148 | ckp1 = 0 149 | 150 | active_p = [p for p in group['params'] if p.grad is not None] 151 | 152 | for p in active_p: 153 | if 'z' not in self.state[p]: 154 | self.state[p]['z'] = torch.clone(p, memory_format=torch.preserve_format) 155 | self.state[p]['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format) 156 | 157 | if group['foreach'] and len(active_p) > 0: 158 | y, grad, exp_avg_sq, z = zip(*[(p, 159 | p.grad, 160 | self.state[p]['exp_avg_sq'], 161 | self.state[p]['z']) 162 | for p in active_p]) 163 | 164 | # Decay the first and second moment running average coefficient 165 | torch._foreach_mul_(exp_avg_sq, beta2) 166 | torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1-beta2) 167 | denom = torch._foreach_div(exp_avg_sq, bias_correction2) 168 | torch._foreach_sqrt_(denom) 169 | torch._foreach_add_(denom, eps) 170 | 171 | # Normalize grad in-place for memory efficiency 172 | torch._foreach_div_(grad, denom) 173 | 174 | # Weight decay calculated at y 175 | if decay != 0: 176 | torch._foreach_add_(grad, y, alpha=decay) 177 | 178 | # These operations update y in-place, 179 | # without computing x explicitly. 180 | torch._foreach_lerp_(y, z, weight=ckp1) 181 | torch._foreach_add_(y, grad, alpha=lr*(beta1*(1-ckp1)-1)) 182 | 183 | # z step 184 | torch._foreach_sub_(z, grad, alpha=lr) 185 | else: 186 | for p in active_p: 187 | y = p # Notation to match theory 188 | grad = p.grad 189 | 190 | state = self.state[p] 191 | 192 | z = state['z'] 193 | exp_avg_sq = state['exp_avg_sq'] 194 | 195 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1-beta2) 196 | denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps) 197 | 198 | # Reuse grad buffer for memory efficiency 199 | grad_normalized = grad.div_(denom) 200 | 201 | # Weight decay calculated at y 202 | if decay != 0: 203 | grad_normalized.add_(y, alpha=decay) 204 | 205 | # These operations update y in-place, 206 | # without computing x explicitly. 
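                    # Concretely (schedule-free form):
                    #   y <- (1 - ckp1) * y + ckp1 * z
                    #   y <- y + lr * (beta1 * (1 - ckp1) - 1) * grad_normalized
                    # which keeps y == beta1 * x + (1 - beta1) * z, where x is the
                    # implicit Polyak-style average restored by .eval().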
207 | y.lerp_(end=z, weight=ckp1) 208 | y.add_(grad_normalized, alpha=lr*(beta1*(1-ckp1)-1)) 209 | 210 | # z step 211 | z.sub_(grad_normalized, alpha=lr) 212 | 213 | group['k'] = k+1 214 | return loss -------------------------------------------------------------------------------- /fine_image/process_image4_torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | import numpy as np 5 | import sys 6 | from PIL import Image 7 | import json 8 | import glob 9 | import subprocess 10 | 11 | import matplotlib.pyplot as plt 12 | 13 | try: 14 | from pillow_heif import register_heif_opener 15 | register_heif_opener() 16 | except ImportError: 17 | pass 18 | 19 | from util_func import width, height, scale, feature_dim, modulo_list, sigmoid 20 | from models.detector import TextDetectorModel, CenterNetDetector 21 | 22 | if len(sys.argv) < 2: 23 | print(sys.argv[0],'target.png') 24 | exit(1) 25 | 26 | target_files = [] 27 | model_size = 'xl' 28 | resize = 1.0 29 | cutoff = 0.4 30 | for arg in sys.argv[1:]: 31 | if arg.startswith('--cutoff='): 32 | cutoff = float(arg.split('=')[1]) 33 | print('cutoff: ', cutoff) 34 | elif arg.startswith('--resize='): 35 | resize = float(arg.split('=')[1]) 36 | print('resize: ', resize) 37 | elif arg.startswith('--model='): 38 | model_size = arg.split('=')[1] 39 | print('model_size: ', model_size) 40 | if model_size == 's': 41 | print('model s') 42 | elif model_size == 'm': 43 | print('model m') 44 | elif model_size == 'l': 45 | print('model l') 46 | elif model_size == 'xl': 47 | print('model xl') 48 | else: 49 | exit(1) 50 | else: 51 | target_files += glob.glob(arg) 52 | 53 | target_files = sorted(target_files) 54 | 55 | print('load') 56 | model = TextDetectorModel(model_size=model_size) 57 | data = torch.load('model.pt', map_location="cpu", weights_only=True) 58 | model.load_state_dict(data['model_state_dict']) 59 | detector = CenterNetDetector(model.detector) 60 | if torch.cuda.is_available(): 61 | device = 'cuda' 62 | elif torch.backends.mps.is_available(): 63 | device = 'mps' 64 | else: 65 | device = 'cpu' 66 | device = torch.device(device) 67 | detector.to(device=device) 68 | detector.eval() 69 | 70 | def eval(ds, org_img, centers): 71 | print(org_img.shape) 72 | print("test") 73 | 74 | glyphfeatures = np.zeros([centers.shape[0], feature_dim], dtype=np.float32) 75 | 76 | for n, inputs in enumerate(ds): 77 | print(n, '/', len(ds)) 78 | x_i = inputs['offsetx'] 79 | y_i = inputs['offsety'] 80 | x_s = width // scale 81 | y_s = height // scale 82 | 83 | images = torch.from_numpy(inputs['input'] / 255.).permute(0,3,1,2).to(device=device) 84 | with torch.no_grad(): 85 | heatmap, features = detector(images) 86 | features = features.cpu().numpy() 87 | 88 | x_min = int(x_s * 1 / 8) if x_i > 0 else 0 89 | x_max = int(x_s * 7 / 8) + 1 if x_i + width < org_img.shape[1] else x_s 90 | y_min = int(y_s * 1 / 8) if y_i > 0 else 0 91 | y_max = int(y_s * 7 / 8) + 1 if y_i + height < org_img.shape[0] else y_s 92 | 93 | target = np.where(np.logical_and(np.logical_and(x_i + x_min * scale < centers[:,0], centers[:,0] < x_i + x_max * scale), 94 | np.logical_and(y_i + y_min * scale < centers[:,1], centers[:,1] < y_i + y_max * scale)))[0] 95 | for i in target: 96 | xi = int((centers[i,0] - x_i) / scale) 97 | yi = int((centers[i,1] - y_i) / scale) 98 | glyphfeatures[i,:] = features[0,:,yi,xi] 99 | 100 | return glyphfeatures.astype(np.float16) 101 | 102 | stepx = width * 3 // 4 103 | stepy = height * 3 // 4 104 
| 105 | for target_file in target_files: 106 | print(target_file) 107 | 108 | lines = np.asarray(Image.open(target_file+'.lines.png')).astype(np.float32) / 255 109 | seps = np.asarray(Image.open(target_file+'.seps.png')).astype(np.float32) / 255 110 | 111 | with open(target_file+'.json', 'r', encoding='utf-8') as file: 112 | data = json.load(file) 113 | textbox = data['textbox'] 114 | if len(textbox) == 0: 115 | print('empty') 116 | continue 117 | 118 | locations = [] 119 | for box in textbox: 120 | cx = box['cx'] 121 | cy = box['cy'] 122 | w = box['w'] 123 | h = box['h'] 124 | code1 = box['p_code1'] 125 | code2 = box['p_code2'] 126 | code4 = box['p_code4'] 127 | code8 = box['p_code8'] 128 | locations.append([cx,cy,w,h,code1,code2,code4,code8]) 129 | locations = np.array(locations, dtype=np.float32) 130 | 131 | print('construct data') 132 | h, w = lines.shape 133 | input_binary = int(0).to_bytes(4, 'little') 134 | input_binary += int(w).to_bytes(4, 'little') 135 | input_binary += int(h).to_bytes(4, 'little') 136 | input_binary += lines.tobytes() 137 | input_binary += seps.tobytes() 138 | input_binary += int(locations.shape[0]).to_bytes(4, 'little') 139 | input_binary += locations.tobytes() 140 | 141 | print('run') 142 | result = subprocess.run('textline_detect/linedetect', input=input_binary, stdout=subprocess.PIPE).stdout 143 | detected_boxes = [] 144 | p = 0 145 | max_block = 0 146 | count = int.from_bytes(result[p:p+4], byteorder='little') 147 | p += 4 148 | for i in range(count): 149 | id = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 150 | p += 4 151 | block = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 152 | max_block = max(max_block, block) 153 | p += 4 154 | idx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 155 | p += 4 156 | subidx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 157 | p += 4 158 | subtype = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 159 | p += 4 160 | pageidx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 161 | p += 4 162 | sectionidx = int.from_bytes(result[p:p+4], byteorder='little', signed=True) 163 | p += 4 164 | detected_boxes.append((id,block,idx,subidx,subtype,pageidx,sectionidx)) 165 | 166 | # im = Image.open(target_file).convert('RGB') 167 | 168 | # fig = plt.figure() 169 | # plt.imshow(im) 170 | # fig.subplots_adjust(left=0, right=1, bottom=0, top=1) 171 | 172 | # cmap = plt.get_cmap('rainbow', max_block+1) 173 | # for id, block, idx, subidx, subtype in detected_boxes: 174 | # if id < 0: 175 | # continue 176 | # cx = locations[id, 0] 177 | # cy = locations[id, 1] 178 | # w = locations[id, 2] 179 | # h = locations[id, 3] 180 | 181 | # points = [ 182 | # [cx - w / 2, cy - h / 2], 183 | # [cx + w / 2, cy - h / 2], 184 | # [cx + w / 2, cy + h / 2], 185 | # [cx - w / 2, cy + h / 2], 186 | # [cx - w / 2, cy - h / 2], 187 | # ] 188 | # points = np.array(points) 189 | # plt.plot(points[:,0], points[:,1], color=cmap(block)) 190 | # if idx < 0: 191 | # t = '*' 192 | # else: 193 | # if subtype & 2+4 == 2+4: 194 | # points = [ 195 | # [cx - w / 2 + 1, cy - h / 2 + 1], 196 | # [cx + w / 2 - 1, cy - h / 2 + 1], 197 | # [cx + w / 2 - 1, cy + h / 2 - 1], 198 | # [cx - w / 2 + 1, cy + h / 2 - 1], 199 | # [cx - w / 2 + 1, cy - h / 2 + 1], 200 | # ] 201 | # points = np.array(points) 202 | # plt.plot(points[:,0], points[:,1], color='yellow') 203 | # t = '%d-r%d-%d'%(block, idx, subidx) 204 | # elif subtype & 2+4 == 2: 205 | # points = [ 206 | # [cx - w / 2 + 1, cy 
- h / 2 + 1], 207 | # [cx + w / 2 - 1, cy - h / 2 + 1], 208 | # [cx + w / 2 - 1, cy + h / 2 - 1], 209 | # [cx - w / 2 + 1, cy + h / 2 - 1], 210 | # [cx - w / 2 + 1, cy - h / 2 + 1], 211 | # ] 212 | # points = np.array(points) 213 | # plt.plot(points[:,0], points[:,1], color='blue') 214 | # t = '%d-b%d-%d'%(block, idx, subidx) 215 | # else: 216 | # t = '%d-%d-%d'%(block, idx, subidx) 217 | # if subtype & 8 == 8: 218 | # t += '+' 219 | # plt.text(cx - w/2, cy - h/2, t, color='black') 220 | # plt.show() 221 | # continue 222 | 223 | centers = [] 224 | boxlist = [] 225 | for id, block, idx, subidx, subtype, pageidx, sectionidx in detected_boxes: 226 | if id < 0: 227 | continue 228 | boxlist.append({ 229 | 'boxid': len(centers), 230 | 'blockid': block, 231 | 'lineid': idx, 232 | 'subidx': subidx, 233 | 'subtype': subtype, 234 | 'text': textbox[id].get('text', None), 235 | }) 236 | centers.append([locations[id,0], locations[id,1]]) 237 | centers = np.array(centers, dtype=np.float32) 238 | 239 | im0 = Image.open(target_file).convert('RGB') 240 | if resize != 1.0: 241 | im0 = im0.resize((int(im0.width * resize), int(im0.height * resize)), resample=Image.Resampling.BILINEAR) 242 | im0 = np.asarray(im0) 243 | 244 | padx = max(0, (width - im0.shape[1]) % stepx, width - im0.shape[1]) 245 | pady = max(0, (height - im0.shape[0]) % stepy, height - im0.shape[0]) 246 | im0 = np.pad(im0, [[0,pady],[0,padx],[0,0]], 'constant', constant_values=((255,255),(255,255),(255,255))) 247 | 248 | im = im0 249 | 250 | ds0 = [] 251 | for y in range(0, im0.shape[0] - height + 1, stepy): 252 | for x in range(0, im0.shape[1] - width + 1, stepx): 253 | ds0.append({ 254 | 'input': im[y:y+height,x:x+width,:], 255 | 'offsetx': x, 256 | 'offsety': y, 257 | }) 258 | 259 | glyph = eval(ds0, im, centers) 260 | np.save(target_file+'.npy', glyph) 261 | 262 | data['boxlist'] = boxlist 263 | with open(target_file+'.json', 'w', encoding='utf-8') as file: 264 | json.dump(data, file, indent=2, ensure_ascii=False) 265 | -------------------------------------------------------------------------------- /models/radam_schedulefree.py: -------------------------------------------------------------------------------- 1 | # https://github.com/facebookresearch/schedule_free 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | from typing import Tuple, Union, Optional, Iterable, Dict, Callable, Any 8 | from typing_extensions import TypeAlias 9 | import torch 10 | import torch.optim 11 | try: 12 | from torch.optim.optimizer import ParamsT 13 | except ImportError: 14 | ParamsT : TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] 15 | import math 16 | 17 | class RAdamScheduleFree(torch.optim.Optimizer): 18 | r""" 19 | Schedule-Free RAdam 20 | Neither warmup hyperparameter nor scheduler is needed with this optimizer. 21 | 22 | This optimizer requires that .train() and .eval() be called before the 23 | beginning of training and evaluation respectively. The optimizer should 24 | also be placed in eval mode when saving checkpoints. 25 | 26 | Arguments: 27 | params (iterable): 28 | Iterable of parameters to optimize or dicts defining 29 | parameter groups. 30 | lr (float): 31 | Learning rate parameter (default 0.0025) 32 | betas (Tuple[float, float], optional): coefficients used for computing 33 | running averages of gradient and its square (default: (0.9, 0.999)). 
34 | eps (float): 35 | Term added to the denominator outside of the root operation to 36 | improve numerical stability. (default: 1e-8). 37 | weight_decay (float): 38 | Weight decay, i.e. a L2 penalty (default: 0). 39 | r (float): Use polynomial weighting in the average 40 | with power r (default 0). 41 | weight_lr_power (float): During warmup, the weights in the average will 42 | be equal to lr raised to this power. Set to 0 for no weighting 43 | (default 2.0). 44 | foreach (bool): Use a foreach-backed implementation of the optimizer. 45 | Should be significantly faster, but will have higher peak memory 46 | usage (default True if supported in your PyTorch version). 47 | silent_sgd_phase (bool): If True, the optimizer will not use the first SGD phase of RAdam. 48 | This means that the optimizer will not update model parameters during the early training 49 | steps (e.g., < 5 when β_2 = 0.999), but just update the momentum values of the optimizer. 50 | This helps stabilize training by ensuring smoother warmup behavior and more reliable 51 | calculation of the moving average coefficient (`ckp1`). Recommended to set to True 52 | (default True). 53 | """ 54 | 55 | def __init__(self, 56 | params: ParamsT, 57 | lr: Union[float, torch.Tensor] = 0.0025, 58 | betas: Tuple[float, float] = (0.9, 0.999), 59 | eps: float = 1e-8, 60 | weight_decay: float = 0, 61 | r: float = 0.0, 62 | weight_lr_power: float = 2.0, 63 | foreach: Optional[bool] = hasattr(torch, "_foreach_mul_"), 64 | silent_sgd_phase: bool = True 65 | ): 66 | 67 | defaults = dict(lr=lr, 68 | betas=betas, 69 | eps=eps, 70 | r=r, 71 | k=0, 72 | train_mode=False, 73 | weight_sum=0.0, 74 | lr_max=-1.0, 75 | scheduled_lr=0.0, 76 | weight_lr_power=weight_lr_power, 77 | weight_decay=weight_decay, 78 | foreach=foreach, 79 | silent_sgd_phase=silent_sgd_phase) 80 | super().__init__(params, defaults) 81 | 82 | @torch.no_grad() 83 | def eval(self): 84 | for group in self.param_groups: 85 | train_mode = group["train_mode"] 86 | beta1, _ = group["betas"] 87 | if train_mode: 88 | for p in group["params"]: 89 | state = self.state[p] 90 | if "z" in state: 91 | # Set p to x 92 | p.lerp_(end=state["z"].to(p.device), weight=1 - 1 / beta1) 93 | group["train_mode"] = False 94 | 95 | @torch.no_grad() 96 | def train(self): 97 | for group in self.param_groups: 98 | train_mode = group["train_mode"] 99 | beta1, _ = group["betas"] 100 | if not train_mode: 101 | for p in group["params"]: 102 | state = self.state[p] 103 | if "z" in state: 104 | # Set p to y 105 | p.lerp_(end=state["z"].to(p.device), weight=1 - beta1) 106 | group["train_mode"] = True 107 | 108 | @torch.no_grad() 109 | def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: 110 | """Performs a single optimization step. 111 | 112 | Arguments: 113 | closure (callable, optional): A closure that reevaluates the model 114 | and returns the loss. 115 | """ 116 | if not self.param_groups[0]["train_mode"]: 117 | raise Exception( 118 | "Optimizer was not in train mode when step is called. " 119 | "Please insert .train() and .eval() calls on the " 120 | "optimizer. See documentation for details." 
121 | ) 122 | 123 | loss = None 124 | if closure is not None: 125 | with torch.enable_grad(): 126 | loss = closure() 127 | 128 | for group in self.param_groups: 129 | eps = group["eps"] 130 | beta1, beta2 = group["betas"] 131 | decay = group["weight_decay"] 132 | silent_sgd_phase = group["silent_sgd_phase"] 133 | k = group["k"] # current steps 134 | step = k + 1 135 | r = group['r'] 136 | weight_lr_power = group['weight_lr_power'] 137 | 138 | beta2_t = beta2**step 139 | bias_correction2 = 1 - beta2_t 140 | 141 | # maximum length of the approximated SMA 142 | rho_inf = 2 / (1 - beta2) - 1 143 | # compute the length of the approximated SMA 144 | rho_t = rho_inf - 2 * step * beta2_t / bias_correction2 145 | rect = ( 146 | ((rho_t - 4) * (rho_t - 2) * rho_inf / ((rho_inf - 4) * (rho_inf - 2) * rho_t)) ** 0.5 147 | if rho_t > 4.0 148 | else float(not silent_sgd_phase) 149 | ) 150 | 151 | lr = group["lr"] * rect 152 | group["scheduled_lr"] = lr # For logging purposes 153 | 154 | lr_max = group["lr_max"] = max(lr, group["lr_max"]) 155 | 156 | weight = (step**r) * (lr_max**weight_lr_power) 157 | weight_sum = group["weight_sum"] = group["weight_sum"] + weight 158 | 159 | try: 160 | ckp1 = weight / weight_sum 161 | except ZeroDivisionError: 162 | ckp1 = 0 163 | 164 | adaptive_y_lr = lr * (beta1 * (1 - ckp1) - 1) 165 | active_p = [p for p in group["params"] if p.grad is not None] 166 | 167 | for p in active_p: 168 | if "z" not in self.state[p]: 169 | self.state[p]["z"] = torch.clone(p, memory_format=torch.preserve_format) 170 | self.state[p]["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) 171 | 172 | if group["foreach"] and len(active_p) > 0: 173 | y, grad, exp_avg_sq, z = zip( 174 | *[(p, p.grad, self.state[p]["exp_avg_sq"], self.state[p]["z"]) for p in active_p] 175 | ) 176 | 177 | # Decay the first and second moment running average coefficient 178 | torch._foreach_mul_(exp_avg_sq, beta2) 179 | torch._foreach_addcmul_(exp_avg_sq, grad, grad, value=1 - beta2) 180 | 181 | if rho_t > 4.0: 182 | # Adam step 183 | denom = torch._foreach_div(exp_avg_sq, bias_correction2) 184 | torch._foreach_sqrt_(denom) 185 | torch._foreach_add_(denom, eps) 186 | 187 | # Normalize grad in-place for memory efficiency 188 | torch._foreach_div_(grad, denom) 189 | 190 | # Weight decay calculated at y 191 | if decay != 0: 192 | torch._foreach_add_(grad, y, alpha=decay) 193 | 194 | # These operations update y in-place, 195 | # without computing x explicitly. 196 | torch._foreach_lerp_(y, z, weight=ckp1) 197 | torch._foreach_add_(y, grad, alpha=adaptive_y_lr) 198 | 199 | # z step 200 | torch._foreach_sub_(z, grad, alpha=lr) 201 | else: 202 | for p in active_p: 203 | y = p # Notation to match theory 204 | grad = p.grad 205 | 206 | state = self.state[p] 207 | 208 | z = state["z"] 209 | exp_avg_sq = state["exp_avg_sq"] 210 | 211 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 212 | 213 | if rho_t > 4.0: 214 | # Adam step 215 | denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(eps) 216 | 217 | # Reuse grad buffer for memory efficiency 218 | grad_normalized = grad.div_(denom) 219 | else: 220 | # Fall back to SGD (or nothing) 221 | grad_normalized = grad 222 | 223 | # Weight decay calculated at y 224 | if decay != 0: 225 | grad_normalized.add_(y, alpha=decay) 226 | 227 | # These operations update y in-place, 228 | # without computing x explicitly. 
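                    # As in the foreach branch: y moves toward z with weight ckp1,
                    # then steps by adaptive_y_lr = lr * (beta1 * (1 - ckp1) - 1).
                    # During the early rectification phase (rho_t <= 4) with
                    # silent_sgd_phase=True, lr == 0, so only the second-moment
                    # statistics advance.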
229 | y.lerp_(end=z, weight=ckp1) 230 | y.add_(grad_normalized, alpha=adaptive_y_lr) 231 | 232 | # z step 233 | z.sub_(grad_normalized, alpha=lr) 234 | 235 | group["k"] = k + 1 236 | return loss 237 | --------------------------------------------------------------------------------