├── rain ├── __init__.py ├── RFL_rain.py ├── calculate_dis.py ├── initializer.py ├── optimizer.py ├── data_aug.py ├── evaluate.py ├── metric.py ├── xconfig.py └── my_encode.py ├── RFL ├── Tool_Formula │ ├── test_latex_norm.py │ └── latex_norm │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── transcription.cpython-39.pyc │ │ └── reverse_transcription.cpython-39.pyc │ │ ├── post_line_correct.map │ │ ├── pre_word_correct.map │ │ ├── rep_dict.map │ │ ├── structure_word_correct.map │ │ └── katex_valid_symbols.map ├── RFL_gen.sh ├── RFL_vocab.py ├── chemfig2ssml.py ├── RFL_main.py ├── chemfig_ssml_struct.py ├── reverse_render_main.py ├── text_render_main.py ├── cond_render_ssml_main.py ├── cond_render_main.py ├── viz_struct.py ├── utils.py ├── text_render.py └── graph_cmp.py ├── img ├── Result.png ├── Case_study.png ├── Framework.png ├── Introduction.png └── Generalization.png ├── train.sh ├── test_organic.sh ├── LICENSE ├── post_process_chemfig.py ├── loader_profiler.py ├── refine_name_for_log.py ├── README.md ├── dict └── vocab.txt └── test_list_multi.py /rain/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /RFL/Tool_Formula/test_latex_norm.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /img/Result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/img/Result.png -------------------------------------------------------------------------------- /img/Case_study.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/img/Case_study.png -------------------------------------------------------------------------------- /img/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/img/Framework.png -------------------------------------------------------------------------------- /img/Introduction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/img/Introduction.png -------------------------------------------------------------------------------- /img/Generalization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/img/Generalization.png -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/RFL/Tool_Formula/latex_norm/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/__pycache__/transcription.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/RFL/Tool_Formula/latex_norm/__pycache__/transcription.cpython-39.pyc -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/__pycache__/reverse_transcription.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JingMog/RFL-MSD/HEAD/RFL/Tool_Formula/latex_norm/__pycache__/reverse_transcription.cpython-39.pyc -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3 2 | TRAIN_GPU_PER_NODE=2 3 | 4 | OMP_NUM_THREADS=2 torchrun --nproc_per_node $TRAIN_GPU_PER_NODE \ 5 | --master_port=12503 ce_trainer.py 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /RFL/RFL_gen.sh: -------------------------------------------------------------------------------- 1 | conda activate torch 2 | 3 | WORKERS=40 4 | # EDU-CHEMC 5 | # Test 6 | python ./RFL_main.py \ 7 | -input ./test_ssml_sd.txt \ 8 | -output ./result/test_cs_string.txt \ 9 | -error_output ./result/test_error.txt\ 10 | -num_workers ${WORKERS} 11 | 12 | # Train 13 | # python ./complex2simple_main.py \ 14 | # -input ./train_ssml_sd.txt \ 15 | # -output ./result/train_cs_string.txt \ 16 | # -error_output ./result/train_error.txt\ 17 | # -num_workers ${WORKERS} 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /rain/RFL_rain.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, "./RFL_") 3 | from RFL import cs_main, chemstem2chemfig 4 | 5 | if __name__ == '__main__': 6 | str1 = 'H B r + \chemfig { ?[a] -[:330] -[:30] -[:90] ( =[:45] ?[b] ( -[:0] ?[c] ( -:[:0] \circle ) ( -[:60] ( -[:0] -[:300] -[:240] -[:180] ?[c,{-}] ) -[:135] \Chemabove { N } { H } -[:210] ?[b,{-}] =[:150] O ) ) ) -[:165] \Chemabove { N } { H } ?[a,{-}] }' 7 | success, cs_string, branch_info, ring_branch_info, cond_data = cs_main(str1, is_show=True) 8 | print(success) 9 | print(cs_string) 10 | print(ring_branch_info) 11 | 12 | 13 | 14 | 15 | 16 | 
-------------------------------------------------------------------------------- /test_organic.sh: -------------------------------------------------------------------------------- 1 | conda activate torch 2 | 3 | 4 | TEST_LRC=ssml_valid.lrc 5 | TEST_NAME=EDU-CHEMC-v3 6 | EPOCH=45 # 45, 32 7 | USED_GPU_ID=0_1_2 8 | PROCESS_PER_GPU=3 9 | 10 | echo TEST_LRC=${TEST_LRC} 11 | echo TEST_NAME=${TEST_NAME} 12 | echo EPOCH=${EPOCH} 13 | echo PROCESS_PER_GPU=${PROCESS_PER_GPU} 14 | echo USED_GPU_ID=${USED_GPU_ID} 15 | 16 | python3 test_lrc_top1top3_log.py \ 17 | --process_per_gpu ${PROCESS_PER_GPU} \ 18 | --used_gpu_id ${USED_GPU_ID} \ 19 | --test_lrc=${TEST_LRC} \ 20 | --name=${TEST_NAME} \ 21 | --load_epoch=${EPOCH} \ 22 | --is_show=False 23 | 24 | 25 | -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/post_line_correct.map: -------------------------------------------------------------------------------- 1 | c o s C o s 2 | \because b e c a u s e \begcause 3 | # s \ace \th  4 | s i n S i n 5 | \smear < E R R > \err \seemar \semear \srmea \ssmear \semear \smeae \smevar 6 | \frac { 2 } { 3 } \fac 2 3 7 | \frac { 1 } { 2 } \farc 1 2 8 | \frac { \sqrt { 3 } } { 2 } \farc \sqrt { 3 } 2 9 | \frac { 3 } { 2 } \farc 3 2 10 | \frac { \sqrt { 1 7 } } { 2 } \farc \sqrt { 1 } 7 2 11 | \therefore \the ? 
r e f o r e t h e r e f o e r t h e r e f o r e 12 | \vartriangle \va \triangle 13 | \jump \jump \jump 14 | \jump 15 | \angle / a n g l e 16 | \triangle t r i a n g e l 17 | \textcircled { 1 } t e x t c i r c l e d 1 18 | c o s x \cox s -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/pre_word_correct.map: -------------------------------------------------------------------------------- 1 | \bcancel 2 | \enter 3 | \frac 4 | \therefore 5 | \smear 6 | \textbf 7 | \sqrt 8 | \angle 9 | \space 10 | \because 11 | \times 12 | \circ 13 | \bot 14 | \cdot 15 | \triangle 16 | \overrightarrow 17 | \textcircled 18 | \prime 19 | \begincases 20 | \endcases 21 | \theta 22 | \alpha 23 | \Rightarrow 24 | \subset 25 | \infty 26 | \unk 27 | \cap 28 | \div 29 | \rho 30 | \not 31 | \int 32 | \overline 33 | \cdots 34 | \underline 35 | \rightarrow 36 | \partial 37 | \cong 38 | \beginmatrix 39 | \endmatrix 40 | \lambda 41 | \beta 42 | \varphi 43 | \Delta 44 | \widehat 45 | \omega 46 | \sim 47 | \cup 48 | \approx 49 | \subsetneqq 50 | \odot 51 | \sum 52 | \subseteq 53 | \varnothing 54 | \pxdy 55 | \uparrow 56 | \underset 57 | \neg 58 | \downarrow 59 | \pxsbx 60 | \overset 61 | \textcelsius 62 | \frown 63 | \phi 64 | \textit -------------------------------------------------------------------------------- /rain/calculate_dis.py: -------------------------------------------------------------------------------- 1 | import Levenshtein 2 | 3 | def cal_edit_ops(str1, str2): 4 | char_idx_dict = dict() 5 | for item in str1: 6 | if item not in char_idx_dict: 7 | char_idx_dict[item] = chr(len(char_idx_dict)) #转成这样是因为同一字符长度,方便计算 8 | for item in str2: 9 | if item not in char_idx_dict: 10 | char_idx_dict[item] = chr(len(char_idx_dict)) 11 | str1 = ''.join([char_idx_dict[item] for item in str1]) 12 | str2 = ''.join([char_idx_dict[item] for item in str2]) 13 | ops = Levenshtein.editops(str1, str2) #计算如果第一个字符串变成第二个字符串需要哪些操作 14 | return ops 15 | 16 | 
17 | def count_ops(ops): # Tally Levenshtein editops (as produced by cal_edit_ops) by kind; returns (delete_nums, substitute_nums, insert_nums). 18 | insert_nums = sum([1 for op_name, *_ in ops if op_name=='delete']) # A 'delete' op removes a token from the first string, i.e. the first string has an extra token. The crossed naming presumably treats the first string as the recognition result and the second as the reference - NOTE(review): confirm at call sites. 19 | substitute_nums = sum([1 for op_name, *_ in ops if op_name=='replace']) 20 | delete_nums = sum([1 for op_name, *_ in ops if op_name=='insert']) # An 'insert' op adds a token the first string is missing (see note on insert_nums). 21 | assert delete_nums + substitute_nums + insert_nums == len(ops) # every op must be one of 'insert', 'delete', 'replace' 22 | return delete_nums, substitute_nums, insert_nums -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 JsingMog 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /RFL/RFL_vocab.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def main(args): 5 | vocab_path = args.output 6 | train_cs_string_path = args.train_input 7 | valid_cs_string_path = args.valid_input 8 | 9 | pre_define_words = ['\\unk', '', '', '\\enter', '\\jump', '\\space'] 10 | vocab_list = ['>@'] 11 | valid_cs_string = [] 12 | with open(valid_cs_string_path, 'r') as f: 13 | valid_cs_string = f.readlines() 14 | for string in valid_cs_string: 15 | label = string.strip().split('\t')[1] 16 | label = label.split('—')[0] 17 | unit = label.split(' ') 18 | for u in unit: 19 | if u not in vocab_list and u not in pre_define_words: 20 | vocab_list.append(u) 21 | 22 | 23 | with open(train_cs_string_path, 'r') as f: 24 | valid_cs_string = f.readlines() 25 | for string in valid_cs_string: 26 | label = string.strip().split('\t')[1] 27 | label = label.split('—')[0] 28 | unit = label.split(' ') 29 | for u in unit: 30 | if u not in vocab_list and u not in pre_define_words: 31 | vocab_list.append(u) 32 | 33 | vocab_list = pre_define_words + sorted(vocab_list) 34 | assert len(vocab_list) == len(set(vocab_list)) 35 | with open(vocab_path, 'w') as f: 36 | for i in range(len(vocab_list)): 37 | f.write(vocab_list[i] + '\t' + str(i) + '\n') 38 | 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser("") 43 | parser.add_argument("-train_input", type=str, default='./result/train_cs_string.txt') 44 | parser.add_argument("-valid_input", type=str, default='./result/valid_cs_string.txt') 45 | parser.add_argument("-output", type=str, default='./result/vocab.txt') 46 | args = parser.parse_args() 47 | main(args) 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /post_process_chemfig.py: 
-------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import pdb 4 | import tqdm 5 | from RFL import utils 6 | 7 | def main(args): 8 | with open(args.input, "r") as fin: 9 | lines = fin.readlines() 10 | 11 | output = args.output 12 | if output is None: 13 | prefix, ext = os.path.splitext(args.input) 14 | output = "{}_chemprocess{}".format(prefix, ext) 15 | 16 | fout = open(output, "w") 17 | 18 | for ind, line in enumerate(tqdm.tqdm(lines)): 19 | spts =line.strip().split("\t") 20 | if len(spts) == 2: 21 | spts.append("") 22 | if len(spts) !=3: 23 | print(line) 24 | continue 25 | img_key, lab, rec = spts 26 | _, rep_dict, remain_trans = utils.replace_chemfig(rec) 27 | for key, trans in rep_dict.items(): 28 | words = trans.split(" ") 29 | new_words = [] 30 | for word in words: 31 | if word in ["(", ")", "-"]: 32 | word = "{"+word+"}" 33 | word = word.replace("branch", "") 34 | new_words.append(word) 35 | rep_dict[key] = " ".join(new_words) 36 | #pdb.set_trace() 37 | remain_spts = remain_trans.split(" ") 38 | out_spts = [] 39 | for remain_spt in remain_spts: 40 | if remain_spt in rep_dict: 41 | out_spts.append(rep_dict[remain_spt]) 42 | else: 43 | out_spts.append(remain_spt) 44 | # for key, new_trans in rep_dict.items(): 45 | # remain_trans = remain_trans.replace(key, new_trans) 46 | out_rec = " ".join(out_spts) 47 | out_rec = out_rec.replace('branch', '') 48 | fout.write("{}\t{}\t{}\n".format(img_key, lab, out_rec)) 49 | fout.close() 50 | 51 | pass 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser("") 55 | parser.add_argument("-input", type=str) 56 | parser.add_argument("-output", type=str, default=None) 57 | parser.add_argument("-num_workers", type=int, default=32) 58 | args = parser.parse_args() 59 | main(args) 60 | -------------------------------------------------------------------------------- /rain/initializer.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import logging 3 | import os, sys 4 | logger = logging.getLogger() 5 | import torch 6 | 7 | def initialize_model(model, model_prefix, model_type='pytorch', epoch=None, rank=0, allow_missing=False): 8 | model_dir = "/".join(model_prefix.split("/")[:-1]) 9 | if rank == 0 and not os.path.exists(model_dir): 10 | os.mkdir(model_dir) 11 | 12 | initialize_model_from_pytorch(model, model_prefix, epoch, allow_missing) 13 | 14 | def initialize_model_from_pytorch(model, model_prefix, epoch=None, allow_missing=False): 15 | if epoch is not None: 16 | model_path = '%s-%04d.pt'%(model_prefix, epoch) 17 | logger.info("Loading params from %s" % model_path) 18 | if not os.path.exists(model_path): 19 | model_path = model_path.replace("_for_test", "") 20 | param_state_dict = torch.load(model_path, map_location='cpu') 21 | model_state_dict = model.state_dict() 22 | if allow_missing: 23 | for k, v in param_state_dict.items(): 24 | if k in model_state_dict and model_state_dict[k].shape == v.shape: 25 | model_state_dict[k] = v 26 | else: 27 | logger.info("param %s can't be loaded, which shape is %s" % (k, v.shape)) 28 | model.load_state_dict(model_state_dict) 29 | else: 30 | model.load_state_dict(param_state_dict, strict = False) 31 | 32 | def initialize_model_from_pytorch_v2(model, model_path, allow_missing=False): 33 | if model_path is not None: 34 | logger.info("Loading params from %s" % model_path) 35 | param_state_dict = torch.load(model_path, map_location='cpu') 36 | if "net" in param_state_dict: 37 | param_state_dict = param_state_dict["net"] 38 | model_state_dict = model.state_dict() 39 | if allow_missing: 40 | for k, v in param_state_dict.items(): 41 | if k.find("decoder") != -1: 42 | continue 43 | if k in model_state_dict and model_state_dict[k].shape == v.shape: 44 | model_state_dict[k] = v 45 | else: 46 | logger.info("param %s can't be loaded, which shape is %s" % (k, 
v.shape)) 47 | model.load_state_dict(model_state_dict) 48 | else: 49 | model.load_state_dict(param_state_dict, strict = False) 50 | 51 | def save_pytorch_model(model, model_prefix, epoch): 52 | torch.save(model.state_dict(), '%s-%04d.pt'%(model_prefix, epoch)) -------------------------------------------------------------------------------- /loader_profiler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | class LoaderTracer(): 5 | #code by jxwang of pydlp team 6 | def __init__(self,display_freq,drop_history,logger=None): 7 | self.time_loader_io = 0 8 | self.time_loader_others = 0 9 | self.io_count = 0 10 | self.others_count =0 11 | self.last_time_stmap_after_io = None 12 | self.display_freq = display_freq 13 | self.drop_history = drop_history 14 | self.logger = logger 15 | 16 | def update_io_time(self,io_time_start): 17 | self.last_time_stmap_after_io = time.time() 18 | self.time_loader_io +=self.last_time_stmap_after_io - io_time_start 19 | self.io_count+=1 20 | if (self.io_count) %self.display_freq == 0: 21 | if self.logger is None: 22 | print("io time per batch(%d): %.5f"%(os.getpid(),self.time_loader_io /self.io_count )) 23 | else: 24 | self.logger.info("io time per batch(%d): %.5f"%(os.getpid(),self.time_loader_io /self.io_count )) 25 | if self.drop_history: 26 | self.time_loader_io=0 27 | self.io_count =0 28 | 29 | def update_other_time(self,time_end_others): 30 | if self.last_time_stmap_after_io is not None: 31 | self.others_count+=1 32 | self.time_loader_others += time_end_others - self.last_time_stmap_after_io 33 | if (self.others_count) %self.display_freq == 0 : 34 | if self.logger is None: 35 | print("outer time per batch(%d): %.5f"%(os.getpid(),self.time_loader_others /self.others_count)) 36 | else: 37 | self.logger.info("outer time per batch(%d): %.5f"%(os.getpid(),self.time_loader_others /self.others_count)) 38 | if self.drop_history: 39 | self.time_loader_others =0 40 | 
self.others_count =0 41 | 42 | class LoaderProfiler(): 43 | #code by jxwang of pydlp team 44 | def __init__(self,loader,diplay_freq =100,drop_history=False, logger=None): 45 | self.profile = LoaderTracer(diplay_freq,drop_history, logger=logger) 46 | self.loader_iter = iter(loader) 47 | self.index = 0 48 | 49 | def __iter__(self): 50 | return self 51 | 52 | def __next__(self): 53 | time_start_io = time.time() 54 | self.profile.update_other_time(time_start_io) 55 | # item = self.loader_iter.next() 56 | item = self.loader_iter._next_data() # 这里使用next会报错,改成_next_data 57 | self.profile.update_io_time(time_start_io) 58 | index_past = self.index 59 | self.index += 1 60 | return index_past,item 61 | # only for debug and test,so weird named 62 | def _test_debug_info__(self): 63 | return self.profile 64 | -------------------------------------------------------------------------------- /refine_name_for_log.py: -------------------------------------------------------------------------------- 1 | from typing import OrderedDict 2 | import numpy as np 3 | import cv2 4 | import os 5 | from PIL import Image, ImageDraw, ImageFont 6 | import Levenshtein 7 | import sys 8 | import random 9 | import tqdm 10 | import pdb 11 | import argparse 12 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 13 | from six.moves import queue 14 | import sys 15 | from data_encapsulation import ListRecordLoader 16 | 17 | 18 | def load_lab_rec(in_log): 19 | lab_dict = OrderedDict() 20 | rec_dict = OrderedDict() 21 | print('Loading label and rec ...') 22 | with open(in_log, "r") as lff: 23 | for line in tqdm.tqdm(lff): 24 | segs = line.strip().split('\t') 25 | if len(segs) < 1: 26 | print('Error Line: %s' % line) 27 | continue 28 | name = segs[0] 29 | img_path = name 30 | lab_dict[img_path] = segs[1] if len(segs) >= 2 else "" 31 | rec_dict[img_path] = segs[2] if len(segs) >= 3 else "" 32 | print('Get Valid label %d' % len(lab_dict)) 33 | print('Get Valid pred %d' % len(rec_dict)) 34 | 
return lab_dict, rec_dict 35 | 36 | def main(args): 37 | #load lrc 38 | print("load lrc...") 39 | sdr = ListRecordLoader(args.lrc_path) 40 | 41 | #load lab and rec 42 | print("load lines") 43 | #lab_dict, rec_dict = load_lab_rec(args.input) 44 | with open(args.input, "r") as fin: 45 | lines = fin.readlines() 46 | 47 | output = args.output 48 | # pdb.set_trace() 49 | if output is None: 50 | prefix, ext = os.path.splitext(args.input) 51 | # pdb.set_trace() 52 | output = "{}_wName{}".format(prefix, ext) 53 | 54 | new_lines = [] 55 | for line in tqdm.tqdm(lines): 56 | spts = line.split("\t") 57 | key = spts[0] 58 | record_idx, *idxes = [int(item) for item in key.split('-')] 59 | record = sdr.get_record(record_idx) 60 | image_path = record[args.key] 61 | if args.addPath > 0: 62 | image_name = image_path 63 | else: 64 | image_name = os.path.basename(image_path) 65 | new_key = image_name + "-" + "-".join(["%d"%idx for idx in idxes]) 66 | new_line = "\t".join([new_key]+ spts[1:]) 67 | new_lines.append(new_line) 68 | 69 | new_lines = sorted(new_lines, key = lambda x:x.split("\t")[0]) 70 | with open(output, "w") as fout: 71 | fout.writelines(new_lines) 72 | 73 | print('All Done!') 74 | 75 | 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser("") 79 | parser.add_argument("-input", type=str, default="") 80 | parser.add_argument("-lrc_path", type=str, default="") 81 | parser.add_argument("-addPath", type=int, default=0, help="") 82 | parser.add_argument("-key", type=str, default="image_path") 83 | parser.add_argument("-output", type=str, default=None) 84 | 85 | args = parser.parse_args() 86 | main(args) 87 | 88 | -------------------------------------------------------------------------------- /rain/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | class NewAdam(torch.optim.Adam): 5 | 6 | def step(self, closure=None): 7 | 8 | loss = None 9 | if closure is not None: 10 | 
loss = closure() 11 | 12 | for group in self.param_groups: 13 | for p in group['params']: 14 | if p.grad is None: 15 | continue 16 | 17 | # size = float(dist.get_world_size()) 18 | # dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM) 19 | # p.grad.data.div_(size) 20 | 21 | grad = p.grad.data 22 | if grad.is_sparse: 23 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 24 | amsgrad = group['amsgrad'] 25 | 26 | state = self.state[p] 27 | 28 | # State initialization 29 | if len(state) == 0: 30 | state['step'] = 0 31 | # Exponential moving average of gradient values 32 | state['exp_avg'] = torch.zeros_like(p.data) 33 | # Exponential moving average of squared gradient values 34 | state['exp_avg_sq'] = torch.zeros_like(p.data) 35 | if amsgrad: 36 | # Maintains max of all exp. moving avg. of sq. grad. values 37 | state['max_exp_avg_sq'] = torch.zeros_like(p.data) 38 | 39 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 40 | 41 | beta1, beta2 = group['betas'] 42 | 43 | state['step'] += 1 44 | bias_correction1 = 1 - beta1 ** state['step'] 45 | bias_correction2 = 1 - beta2 ** state['step'] 46 | 47 | if group['weight_decay'] != 0: 48 | grad.add_(group['weight_decay'], p.data) 49 | 50 | # Decay the first and second moment running average coefficient 51 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 52 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 53 | 54 | if amsgrad: 55 | max_exp_avg_sq = state['max_exp_avg_sq'] 56 | # Maintains the maximum of all 2nd moment running avg. till now 57 | torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) 58 | # Use the max. for normalizing running avg. 
of gradient 59 | denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) 60 | else: 61 | # denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) # (sqrt(v) / sqrt(bc2)) + eps 62 | denom = exp_avg_sq.sqrt().add_(group['eps']) # sqrt(v) + eps 63 | 64 | # step_size = group['lr'] / bias_correction1 65 | step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1 # lr * math.sqrt(coef2)/coef1 66 | 67 | p.data.addcdiv_(-step_size, exp_avg, denom) 68 | 69 | return loss 70 | -------------------------------------------------------------------------------- /rain/data_aug.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy 4 | import random 5 | 6 | #img [h, w, c] in [0,255] 7 | def Generate_PadLine(img): 8 | # img = img.transpose([1, 2, 0]) 9 | h = img.shape[0] 10 | w = img.shape[1] 11 | # w_p = random.randint(0, min(w, 80)) 12 | # h_p = random.randint(0, 16) 13 | w_p = random.randint(0, min(w, 60)) 14 | h_p = random.randint(0, 12) 15 | x0 = int(random.random() * w) 16 | y0 = int(random.random() * h) 17 | if x0 + w_p < w: 18 | row_s = x0 19 | row_e = x0 + w_p 20 | else: 21 | row_s = x0 - w_p 22 | row_e = x0 23 | if y0 + h_p < h: 24 | col_s = y0 25 | col_e = y0 + h_p 26 | else: 27 | col_s = y0 - h_p 28 | col_e = y0 29 | img[0:h, row_s:row_e, :] = [0, 0, 0] 30 | img[col_s:col_e, 0:w, :] = [0, 0, 0] 31 | # img = img.transpose([2, 0, 1]) 32 | return img 33 | 34 | #img [h, w, c] 35 | def Generate_PadRow(img): 36 | # img = img.transpose([1, 2, 0]) 37 | h = img.shape[0] 38 | w = img.shape[1] 39 | # h_p = random.randint(0, 16) 40 | # h_p = random.randint(0, 30) 41 | h_p = random.randint(0, 24) 42 | y0 = int(random.random() * h) 43 | if y0 + h_p < h: 44 | col_s = y0 45 | col_e = y0 + h_p 46 | else: 47 | col_s = y0 - h_p 48 | col_e = y0 49 | img[col_s:col_e, 0:w, :] = [0, 0, 0] 50 | # img = img.transpose([2, 0, 1]) 51 | return img 52 | 53 | #img 
[h, w, c] 54 | def local_medblur(img): 55 | #img = img.transpose([1, 2, 0]) 56 | ksize = random.randint(0, 1) * 2 + 7 57 | h = img.shape[0] 58 | w = img.shape[1] 59 | wsize = int(random.randint(20, 50) * 1.0 / 100.0 * w) 60 | x0 = min(int(random.random() * w), w - wsize + 5) 61 | y0 = 0 62 | dst = cv2.medianBlur(img[:, x0:x0 + wsize, :], ksize) 63 | img[:, x0:x0 + wsize, :] = dst 64 | #img = img.transpose([2, 0, 1]) 65 | return img 66 | 67 | #img [h, w, c] 68 | def motion_blur(img): 69 | # img = img.transpose([1, 2, 0]) 70 | # h = img.shape[0] 71 | # w = img.shape[1] 72 | # wsize = int(random.randint(20, 50) * 1.0 / 100.0 * w) 73 | # x0 = min(int(random.random() * w), w - wsize + 5) 74 | # y0 = 0 75 | image = numpy.array(img) 76 | degree = random.randint(8,12) 77 | angle = random.randint(45,60) 78 | # degree = 16 79 | # angle = 60 80 | M = cv2.getRotationMatrix2D((degree / 2, degree / 2), angle, 1) 81 | motion_blur_kernel = numpy.diag(numpy.ones(degree)) 82 | motion_blur_kernel = cv2.warpAffine(motion_blur_kernel, M, (degree, degree)) 83 | motion_blur_kernel = motion_blur_kernel / degree 84 | blurred = cv2.filter2D(image, -1, motion_blur_kernel) 85 | cv2.normalize(blurred, blurred, 0.0, 255.0, cv2.NORM_MINMAX) 86 | # img = blurred.transpose([2, 0, 1]) 87 | img = blurred 88 | return img 89 | 90 | #img [h, w, c] 91 | def gaussian_blur(img): 92 | # img = img.transpose([1, 2, 0]) 93 | h = img.shape[0] 94 | w = img.shape[1] 95 | wsize = int(random.randint(20, 50) * 1.0 / 100.0 * w) 96 | x0 = min(int(random.random() * w), w - wsize + 5) 97 | y0 = 0 98 | knsize = random.randint(0, 3) * 2 + 9 99 | blurred = cv2.GaussianBlur(img[:,x0:x0 + wsize,:], ksize=(knsize, knsize), sigmaX=0, sigmaY=0) 100 | img[:,x0:x0 + wsize,:] = blurred 101 | # img = img.transpose([2, 0, 1]) 102 | return img 103 | 104 | def random_scale_downup(img, size_ratio_down=(0.4, 0.75)): 105 | random_ratio = random.uniform(size_ratio_down[0], size_ratio_down[1]) 106 | h, w, c = img.shape 107 | h_down = 
int(random_ratio * h + 0.5) 108 | w_down = int(random_ratio * w + 0.5) 109 | 110 | img = cv2.resize(img, (w_down, h_down)) 111 | img = cv2.resize(img, (w, h)) 112 | return img -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## :fire: RFL: Simplifying Chemical Structure Recognition with Ring-Free Language :fire: 2 | 3 |
4 |

5 | 6 | github follow 7 | 8 | 9 |

10 |
11 | 12 | This is the official implementation of our paper "RFL: Simplifying Chemical Structure Recognition with Ring-Free Language", accepted as an oral presentation at AAAI 2025. 13 | 14 | Paper (arXiv): [Paper](https://arxiv.org/abs/2412.07594) 15 | 16 | 17 | ## :fire: News: 18 | 19 | - 2025.01.20. Our paper was selected as an **AAAI 2025 oral**, congratulations :clap::clap::clap:. 20 | - The source code, including training and inference, has been released. 21 | 22 | TODO: 23 | - [x] Update paper link in arxiv. 24 | - [x] Update Source Code. 25 | 26 | ## :star: Overview 27 | 28 | The primary objective of Optical Chemical Structure Recognition is to convert chemical structure images into corresponding markup sequences. In this work, we propose a novel Ring-Free Language (RFL), which utilizes a divide-and-conquer strategy to describe chemical structures in a hierarchical form. RFL allows complex molecular structures to be decomposed into multiple parts. This approach significantly reduces the learning difficulty for recognition models. Leveraging RFL, we propose a universal Molecular Skeleton Decoder (MSD), which comprises a skeleton generation module that progressively predicts the molecular skeleton and individual rings, along with a branch classification module for predicting branch information. Experimental results demonstrate that the proposed RFL and MSD can be applied to various mainstream methods, achieving superior performance compared to state-of-the-art approaches in both printed and handwritten scenarios. 29 | 30 | Comparison of RFL with previous modeling languages: 31 |
32 | Introduction 33 |
34 | 35 | Our Model Architecture: 36 |
37 | model architecture 38 |
39 | 40 | 41 | ## :balloon: Datasets 42 | 43 | In our paper, we use the following two datasets: 44 | - [EDU-CHEMC](https://github.com/iFLYTEK-CV/EDU-CHEMC) : A dataset for handwritten chemical structure recognition. 45 | - [Mini-CASIA-CSDB](https://nlpr.ia.ac.cn/databases/CASIA-CSDB/index.html) : A dataset for printed chemical structure recognition. 46 | 47 | ## :memo: Ring-Free Language 48 | Our Ring-Free Language (RFL) utilizes a divide-and-conquer strategy to describe chemical structures in a hierarchical form. A molecular structure $G$ is equivalently converted into a molecular skeleton $S$, individual ring structures $R$, and branch information $F$. 49 | 50 | You can use the following command to generate the Ring-Free Language representation of a single sample. We have provided some typical examples for testing in `./RFL/RFL.py`: 51 | ```bash 52 | cd RFL 53 | python RFL.py 54 | ``` 55 | 56 | Batch generation with multiple worker processes: 57 | ```bash 58 | cd RFL 59 | bash RFL_gen.sh 60 | ``` 61 | 62 | 63 | ## :bulb: Training 64 | You can start training using the following command: 65 | 66 | ```bash 67 | bash train.sh 68 | ``` 69 | 70 | Note: The dataset path and related parameters need to be modified in `rain/xconfig.py` 71 | 72 | 73 | ## :airplane: Evaluation 74 | ```bash 75 | bash test_organic.sh 76 | ``` 77 | 78 | 79 | ## :rocket: Experimental Results 80 | Comparison with state-of-the-art methods on the handwritten dataset (EDU-CHEMC) and the printed dataset (Mini-CASIA-CSDB). 81 | 82 |
83 | Result 84 |
85 | 86 | 87 | Ablation study on the EDU-CHEMC dataset, with all systems based on MSD-DenseWAP. 88 | | System | MSD | [conn] | EM | Struct-EM | 89 | |--------|------|--------|-------|-----------| 90 | | T1 | × | × | 38.70 | 49.45 | 91 | | T2 | × | √ | 44.02 | 55.77 | 92 | | T3 | √ | × | 52.76 | 58.58 | 93 | | T4 | √ | √ | 64.96 | 73.15 | 94 | 95 | 96 | To demonstrate that RFL and MSD can simplify molecular structure recognition and enhance generalization ability, we design experiments on molecular complexity. 97 | 98 |
99 | Generalization 100 |
101 | 102 | Exact match rate (in \%) of DenseWAP and MSD-DenseWAP on test sets of different structural complexity. The left subplot shows models trained on complexity \{1,2\}, and the right subplot shows models trained on complexity \{1,2,3\}. 103 | 104 | 105 | Case Study: 106 |
107 | Case Study 108 |
109 | 110 | 111 | ## :newspaper: Citation 112 | If you find our work is useful in your research, please consider citing: 113 | 114 | ``` 115 | @inproceedings{chang2025rfl, 116 | title={RFL: Simplifying Chemical Structure Recognition with Ring-Free Language}, 117 | author={Chang, Qikai and Chen, Mingjun and Pi, Changpeng and Hu, Pengfei and Zhang, Zhenrong and Ma, Jiefeng and Du, Jun and Yin, Baocai and Hu, Jinshui}, 118 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 119 | volume={39}, 120 | number={2}, 121 | pages={2007--2015}, 122 | year={2025} 123 | } 124 | ``` 125 | 126 | 127 | If you have any question, please feel free to contact me: qkchang@mail.ustc.edu.cn 128 | 129 | 130 | 131 | -------------------------------------------------------------------------------- /RFL/chemfig2ssml.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 2 | from tqdm import tqdm 3 | from six.moves import queue 4 | import argparse 5 | 6 | from cond_render_main import process_cond_render 7 | from text_render import rend_text 8 | from chemfig_struct import * 9 | 10 | def chemfig_convert_to_ssml(input_chemfig): 11 | out_units = process_cond_render(input_chemfig) 12 | preprocess_chemfig = [] 13 | for id, unit in enumerate(out_units): 14 | if not isinstance(unit, Atom): 15 | # 非分子部分直接并入out_units 16 | preprocess_chemfig.append(unit) 17 | else: 18 | # 化学分子 19 | temp_string = rend_text(unit) 20 | temp_string = ['\chemfig', '{'] + temp_string + ['}'] 21 | preprocess_chemfig += temp_string 22 | preprocess_chemfig = " ".join(preprocess_chemfig) 23 | return preprocess_chemfig 24 | 25 | 26 | def do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 27 | line = params["line"] 28 | line_id = params["line_id"] 29 | 30 | line = line.strip().split('\t') 31 | file_path = line[0] 32 | input_chemfig = line[1] 33 | ssml_string = 
chemfig_convert_to_ssml(input_chemfig) 34 | 35 | out_cs_string = shared_params["out_cs_string"] 36 | out_cs_string_lock = shared_params["out_cs_string_lock"] 37 | with out_cs_string_lock: 38 | out_cs_string.put([file_path, ssml_string]) 39 | 40 | 41 | def try_do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 42 | try: 43 | do_single_task(params, args, records_queue, records_queue_lock, shared_params) 44 | except BaseException as e: 45 | error_lines = shared_params["error_lines"] 46 | error_lines_lock = shared_params["error_lines_lock"] 47 | with error_lines_lock: 48 | file_name = params["line"].strip().split('\t')[0] 49 | label_string = params["line"].strip().split('\t')[0] 50 | error_lines.put(file_name + '\t' + str(e) + '\t' + label_string + '\n') 51 | # print("try fail!") 52 | if records_queue_lock is not None: 53 | with records_queue_lock: 54 | records_queue.put(1) 55 | 56 | def main(args): 57 | if args.input_type == "text": 58 | with open(args.input, "r") as fin: 59 | lines = fin.readlines() 60 | else: 61 | raise NotImplementedError("unsupport input type = {}".format(args.input_type)) 62 | 63 | # init metrics 64 | manager = Manager() 65 | records_queue = manager.Queue() 66 | records_queue_lock = manager.Lock() 67 | 68 | shared_params = {} 69 | if args.input_type == "text": 70 | shared_params["error_lines"] = manager.Queue() 71 | shared_params["error_lines_lock"] = manager.Lock() 72 | shared_params["out_cs_string"] = manager.Queue() 73 | shared_params["out_cs_string_lock"] = manager.Lock() 74 | 75 | 76 | all_tasks = [] 77 | line_id = -1 78 | 79 | if args.num_workers <= 0: 80 | for line in tqdm(lines): 81 | line_id += 1 82 | params = {} 83 | params["line"] = line 84 | params["line_id"] = line_id 85 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 86 | do_single_task(*cur_task) 87 | if line_id > 20: 88 | break 89 | else: 90 | for line in tqdm(lines): 91 | line_id += 1 92 | params = {} 93 | 
params["line"] = line 94 | params["line_id"] = line_id 95 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 96 | all_tasks.append(cur_task) 97 | # if line_id > 100: 98 | # break 99 | 100 | def print_error(error): 101 | print("error:", error) 102 | 103 | poolSize = args.num_workers 104 | pool = Pool(poolSize) 105 | pool.starmap_async(try_do_single_task, all_tasks, error_callback=print_error) 106 | pool.close() 107 | tq = tqdm(total=len(all_tasks)) 108 | count = 0 109 | print("begin") 110 | #try: 111 | while count < len(all_tasks): 112 | try: 113 | c = records_queue.get_nowait() 114 | except queue.Empty: 115 | continue 116 | count += 1 117 | tq.update(1) 118 | 119 | pool.join() 120 | 121 | # 后处理保存转换错误的结果和生成的骨干字符串 122 | error_lines = shared_params["error_lines"] 123 | error_lines_lock = shared_params["error_lines_lock"] 124 | # print(error_lines) 125 | with open(args.error_output, "w") as fout: 126 | while not error_lines.empty(): 127 | line = error_lines.get() 128 | fout.write(line) 129 | 130 | out_cs_string = shared_params["out_cs_string"] 131 | out_cs_string_lock = shared_params["out_cs_string_lock"] 132 | with open(args.output, 'w') as fout: 133 | while not out_cs_string.empty(): 134 | line = out_cs_string.get() 135 | file_name = line[0] 136 | tmp_ssml_string = line[1] 137 | fout.write(file_name + '\t' + tmp_ssml_string + '\n') 138 | 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser("") 143 | parser.add_argument("-input", type=str, default='chemfig_test.txt') 144 | parser.add_argument("-output", type=str, default='ssml_test.txt') 145 | parser.add_argument("-error_output", type=str, default='ssml_test_error.txt') 146 | parser.add_argument("-input_type", type=str, default="text", help="current support text") 147 | parser.add_argument("-num_workers", type=int, default=40) 148 | args = parser.parse_args() 149 | main(args) 150 | 151 | 
-------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/rep_dict.map: -------------------------------------------------------------------------------- 1 | \UC_FF05 \% 2 | \UC_FF08 ( 3 | \UC_FF09 ) 4 | \UC_FF0B + 5 | \UC_FF0D - 6 | \UC_FF0E . 7 | \UC_FF0F / 8 | \UC_FF1A : 9 | \UC_FF0C , 10 | \UC_FF1B ; 11 | \UC_FF1F ? 12 | \UC_FF01 ! 13 | \UC_FF5B \{ 14 | \UC_FF5D \} 15 | \UC_FF5C | 16 | \UC_2460 \numone 17 | \UC_2461 \numtwo 18 | \UC_2462 \numthree 19 | \UC_2463 \numfour 20 | \UC_2464 \numfive 21 | \UC_2465 \numsix 22 | \UC_2466 \numseven 23 | \UC_2467 \numeight 24 | \UC_2468 \numnine 25 | \UC_FFE0 \not \subset 26 | \UC_FF5E \sim 27 | \UC_FF59 y 28 | \UC_FF58 x 29 | \UC_FF57 w 30 | \UC_FF55 u 31 | \UC_FF54 t 32 | \UC_FF53 s 33 | \UC_FF52 r 34 | \UC_FF51 q 35 | \UC_FF50 p 36 | \UC_FF4F o 37 | \UC_FF4E n 38 | \UC_FF4D m 39 | \UC_FF4C l 40 | \UC_FF4B k 41 | \UC_FF49 i 42 | \UC_FF48 h 43 | \UC_FF47 g 44 | \UC_FF46 f 45 | \UC_FF45 e 46 | \UC_FF44 d 47 | \UC_FF43 c 48 | \UC_FF42 b 49 | \UC_FF41 a 50 | \UC_FF3F _ 51 | \UC_FF3E ^ 52 | \UC_FF3D ] 53 | \UC_FF3C \ 54 | \UC_FF3B [ 55 | \UC_FF3A Z 56 | \UC_FF39 Y 57 | \UC_FF38 X 58 | \UC_FF34 T 59 | \UC_FF33 S 60 | \UC_FF32 R 61 | \UC_FF31 Q 62 | \UC_FF30 P 63 | \UC_FF2F O 64 | \UC_FF2E N 65 | \UC_FF2D M 66 | \UC_FF2C L 67 | \UC_FF2B K 68 | \UC_FF29 I 69 | \UC_FF28 H 70 | \UC_FF27 G 71 | \UC_FF26 F 72 | \UC_FF25 E 73 | \UC_FF24 D 74 | \UC_FF23 C 75 | \UC_FF22 B 76 | \UC_FF21 A 77 | \UC_FF1E > 78 | \UC_FF1D = 79 | \UC_FF1C < 80 | \UC_FF19 9 81 | \UC_FF18 8 82 | \UC_FF17 7 83 | \UC_FF16 6 84 | \UC_FF15 5 85 | \UC_FF14 4 86 | \UC_FF13 3 87 | \UC_FF12 2 88 | \UC_FF11 1 89 | \UC_FF10 0 90 | \UC_FF0A * 91 | \UC_FF07 ^ { \prime } 92 | \UC_FF06 & 93 | \UC_FF04 $ 94 | \UC_FF03 # 95 | \UC_FF02 " 96 | \UC_FE6A \% 97 | \UC_FE65 > 98 | \UC_FE64 < 99 | \UC_FE63 - 100 | \UC_FE62 + 101 | \UC_FE5A ) 102 | \UC_FE59 ( 103 | \UC_4E28 | 104 | \UC_33A1 m ^ { 2 } 105 | \UC_339E k m 106 | \UC_339D 
c m 107 | \UC_338F k g 108 | \UC_3221 ( \UC_4E8C ) 109 | \UC_30CB \UC_4E8C 110 | \UC_309C ^ { \circ } 111 | \UC_301E " 112 | \UC_301D " 113 | \UC_3015 ) 114 | \UC_3014 ( 115 | \UC_3009 > 116 | \UC_3008 < 117 | \UC_25CB \circ 118 | \UC_25B3 \triangle 119 | \UC_2571 / 120 | \UC_2502 | 121 | \UC_2500 - 122 | \UC_2492 1 1 . 123 | \UC_2489 2 . 124 | \UC_2488 1 . 125 | \UC_2487 ( 2 0 ) 126 | \UC_2486 ( 1 9 ) 127 | \UC_2484 ( 1 7 ) 128 | \UC_2482 ( 1 5 ) 129 | \UC_2481 ( 1 4 ) 130 | \UC_2480 ( 1 3 ) 131 | \UC_247F ( 1 2 ) 132 | \UC_247E ( 1 1 ) 133 | \UC_247D ( 1 0 ) 134 | \UC_247C ( 9 ) 135 | \UC_247A ( 7 ) 136 | \UC_2479 ( 6 ) 137 | \UC_2478 ( 5 ) 138 | \UC_2477 ( 4 ) 139 | \UC_2476 ( 3 ) 140 | \UC_2475 ( 2 ) 141 | \UC_2474 ( 1 ) 142 | \UC_2469 \textcircled { 1 0 } 143 | \UC_22A5 \bot 144 | \UC_2299 \odot 145 | \UC_2284 \not \subset 146 | \UC_2282 \subset 147 | \UC_2266 \le 148 | \UC_2265 \ge 149 | \UC_2264 \le 150 | \UC_2260 \not = 151 | \UC_224C \cong 152 | \UC_2248 \approx 153 | \UC_223D \sim 154 | \UC_2236 : 155 | \UC_2235 \because 156 | \UC_2234 \therefore 157 | \UC_222B \int 158 | \UC_222A \cup 159 | \UC_2229 \cap 160 | \UC_2228 \vee 161 | \UC_2227 \wedge 162 | \UC_2225 // 163 | \UC_2223 | 164 | \UC_2220 \angle 165 | \UC_221E \infty 166 | \UC_221A \checkmark 167 | \UC_2215 / 168 | \UC_2212 - 169 | \UC_220F \prod 170 | \UC_2209 \not \in 171 | \UC_2208 \in 172 | \UC_2205 \varnothing 173 | \UC_2202 \alpha 174 | \UC_2199 \swarrow 175 | \UC_2198 \searrow 176 | \UC_2197 \nearrow 177 | \UC_2193 \downarrow 178 | \UC_2192 \rightarrow 179 | \UC_2191 \uparrow 180 | \UC_2190 \leftarrow 181 | \UC_2174 v 182 | \UC_2103 \textcelsius 183 | \UC_2081 _ { 1 } 184 | \UC_2080 _ { 0 } 185 | \UC_203A > 186 | \UC_2039 < 187 | \UC_2033 " 188 | \UC_2032 ^ { \prime } 189 | \UC_2026 \cdots 190 | \UC_2025 \cdot \cdot 191 | \UC_201D " 192 | \UC_201C " 193 | \UC_2019 ^ { \prime } 194 | \UC_2015 - 195 | \UC_2014 - 196 | \UC_2013 - 197 | \UC_0430 a 198 | \UC_0421 C 199 | \UC_03C9 \omega 200 | 
\UC_03C7 \chi 201 | \UC_03C6 \varphi 202 | \UC_03C1 \rho 203 | \UC_03C0 \pi 204 | \UC_03BE \xi 205 | \UC_03BC \mu 206 | \UC_03BB \lambda 207 | \UC_03B8 \theta 208 | \UC_03B5 \varepsilon 209 | \UC_03B3 \gamma 210 | \UC_03B2 \beta 211 | \UC_03B1 \alpha 212 | \UC_03A9 \Omega 213 | \UC_03A6 \Phi 214 | \UC_03A3 \sum 215 | \UC_03A0 \prod 216 | \UC_0399 I 217 | \UC_02CD _ 218 | \UC_02CA ^ { \prime } 219 | \UC_0283 \int 220 | \UC_01B6 z 221 | \UC_00F7 \div 222 | \UC_00D8 \varnothing 223 | \UC_00D7 \times 224 | \UC_00BD 1 / 2 225 | \UC_00BA ^ { \circ } 226 | \UC_00B9 ^ { 1 } 227 | \UC_00B7 \cdot 228 | \UC_00B4 ^ { \prime } 229 | \UC_00B3 ^ { 3 } 230 | \UC_00B2 ^ { 2 } 231 | \UC_00B1 \pm 232 | \UC_00B0 ^ { \circ } 233 | \UC_00AD - 234 | \UC_00AA ^ { a } 235 | \UC_00AC \neg 236 | \UC_00AF ^ { - } 237 | \UC_00BB \gg 238 | \UC_0251 \alpha 239 | \UC_02C8 \prime 240 | \UC_02C9 ^ { - } 241 | \UC_0394 \Delta 242 | \UC_0395 E 243 | \UC_041D H 244 | \UC_041E O 245 | \UC_041F \prod 246 | \UC_0428 \UC_0428 247 | \UC_0435 e 248 | \UC_0440 p 249 | \UC_200B 250 | \UC_2018 \UC_2018 251 | \UC_2061 252 | \UC_2105 \UC_2105 253 | \UC_2126 \Omega 254 | \UC_2160 I 255 | \UC_2161 I I 256 | \UC_2162 I I I 257 | \UC_2163 I V 258 | \UC_2164 V 259 | \UC_2165 V I 260 | \UC_2166 V I I 261 | \UC_2167 V I I I 262 | \UC_2168 I X 263 | \UC_2169 X 264 | \UC_2170 i 265 | \UC_2171 i i 266 | \UC_2172 i i i 267 | \UC_2173 i v 268 | \UC_2179 x 269 | \UC_21C0 \rightharpoonup 270 | \UC_21C4 \rightleftarrows 271 | \UC_21CB \leftrightharpoons 272 | \UC_21CC \rightleftharpoons 273 | \UC_21D4 \Leftrightarrow 274 | \UC_2206 \triangle 275 | \UC_2218 \circ 276 | \UC_2219 \cdot 277 | \UC_221D \propto 278 | \UC_2237 : : 279 | \UC_223C \sim 280 | \UC_2259 \overset \wedge = 281 | \UC_225C \overset \triangle = 282 | \UC_2261 \equiv 283 | \UC_2267 \geqq 284 | \UC_226A \ll 285 | \UC_226B \gg 286 | \UC_22C5 \cdot 287 | \UC_22EE \vdots 288 | \UC_22EF . . . 
289 | \UC_247B ( 8 ) 290 | \UC_24BC \textcircled { G } 291 | \UC_2550 = 292 | \UC_2573 \times 293 | \UC_25CF \cdot 294 | \UC_2782 \textcircled { 3 } 295 | \UC_2A7D \le 296 | \UC_3007 \circ 297 | \UC_300A \UC_300A 298 | \UC_300B \UC_300B 299 | \UC_30ED \UC_30ED 300 | \UC_E009 301 | \UC_E0FC 302 | \UC_E100 303 | \UC_E225 304 | \UC_E226 305 | \UC_E229 306 | \UC_E4D3 307 | \UC_E607 308 | \UC_E61D 309 | \UC_FE30 : 310 | \UC_FE51 , 311 | \UC_FE52 . 312 | \UC_FE55 : -------------------------------------------------------------------------------- /rain/evaluate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger() 3 | from collections import namedtuple 4 | import time 5 | import numpy 6 | 7 | BatchEndParam = namedtuple('BatchEndParams', 8 | ['epoch', 9 | 'nbatch', 10 | 'model', 11 | 'locals']) 12 | 13 | class Speedometer(object): 14 | def __init__(self, batch_size, epoch_batch, frequent=50, opt=None, auto_reset=True): 15 | self.batch_size = batch_size 16 | self.epoch_batch = epoch_batch 17 | self.frequent = frequent 18 | self.init = False 19 | self.tic = 0 20 | self.last_count = 0 21 | self.auto_reset = auto_reset 22 | 23 | self.opt = opt 24 | 25 | def __call__(self, param): 26 | """Callback to Show speed.""" 27 | count = param.nbatch 28 | if self.last_count > count: 29 | self.init = False 30 | self.last_count = count 31 | 32 | if self.init: 33 | if count % self.frequent == 0: 34 | speed = self.frequent * self.batch_size / (time.time() - self.tic) 35 | name_values = [] 36 | msg = 'Epoch[%d]\tBatch[%d][%d]\tlr[%f]\tSpeed: %.2f samples/sec' 37 | if param.model._eval_metrics is not None: 38 | for i, eval_metric in enumerate(param.model._eval_metrics): 39 | name_value = eval_metric.get() 40 | name_values += name_value 41 | if self.auto_reset: 42 | eval_metric.reset() 43 | msg += '\t%s=%f' 44 | 45 | 46 | if self.opt is None: 47 | logging.info(msg, param.epoch, count, self.epoch_batch, 1, speed, 
*name_values) 48 | else: 49 | logging.info(msg, param.epoch, count, self.epoch_batch, self.opt.param_groups[0]['lr'], speed, *name_values) 50 | 51 | else: 52 | logging.info("Iter[%d]\tBatch[%d][%d]\tSpeed: %.2f samples/sec", 53 | param.epoch, count, self.epoch_batch, speed) 54 | self.tic = time.time() 55 | else: 56 | self.init = True 57 | self.tic = time.time() 58 | 59 | 60 | class WarmupScheduler(object): 61 | def __init__(self, optimizer, start_lr=1e-8, stop_lr=2e-4, step=10000, frequent=50): 62 | super(WarmupScheduler, self).__init__() 63 | self.optimizer = optimizer 64 | self.start_lr = start_lr 65 | self.stop_lr = stop_lr 66 | self.step = float(step) 67 | self.count = 0. 68 | self.frequent = frequent 69 | 70 | def __call__(self, param): 71 | if self.count < self.step and self.start_lr < self.stop_lr: 72 | self.count += 1 73 | next_lr = (self.count/self.step)*(self.stop_lr-self.start_lr) + self.start_lr 74 | self.optimizer.param_groups[0]['lr'] = next_lr 75 | if self.count % self.frequent == 0: 76 | logging.info('warmup[%d/%d]\tnext batch lr=%.2e' % (self.count,int(self.step),next_lr)) 77 | 78 | 79 | class KingScheduler(object): 80 | def __init__(self, optimizer, scheduler_dict): 81 | super(KingScheduler, self).__init__() 82 | self.optimizer = optimizer 83 | self.scheduler = scheduler_dict["scheduler"] 84 | self.lr_factor = scheduler_dict["lr_factor"] 85 | self.wd_factor = scheduler_dict["wd_factor"] 86 | self.eps_factor = scheduler_dict["eps_factor"] 87 | self.stop_lr = scheduler_dict["stop_lr"] 88 | self.stop_wd = scheduler_dict["stop_wd"] 89 | self.stop_eps = scheduler_dict["stop_eps"] 90 | self.decay_wd = scheduler_dict["decay_wd"] 91 | self.decay_eps = scheduler_dict["decay_eps"] 92 | self.thresh = scheduler_dict["thresh"] 93 | self.decay_step = scheduler_dict["decay_step"] 94 | self.cur_step = 0 95 | self.cur_step_ind = 0 96 | self.descent = -1.0 if scheduler_dict["valid_metric"] == 'ce' else 1.0 # ce or acc of validation 97 | self.max_val = -numpy.inf 98 
| 99 | def __call__(self, value=None): 100 | if self.scheduler == 'FixStep': 101 | self.cur_step_ind = 0 102 | self.cur_step += 1 103 | elif self.scheduler == 'AutoStep' and value is not None: 104 | self.cur_step_ind = 0 105 | value *= self.descent 106 | if (value - self.max_val) > self.thresh: 107 | self.cur_step = 0 108 | self.max_val = value 109 | else: 110 | self.cur_step += 1 111 | elif self.scheduler == 'MultiStep': 112 | self.cur_step += 1 113 | 114 | # tune lr, eps, wd 115 | if self.cur_step == self.decay_step[self.cur_step_ind]: 116 | logging.info('{} {}'.format(self.scheduler, self.decay_step[self.cur_step_ind])) 117 | next_lr = max(self.optimizer.param_groups[0]['lr']*self.lr_factor, self.stop_lr) 118 | next_wd = max(self.optimizer.param_groups[0]['weight_decay']*self.wd_factor, self.stop_wd) 119 | self.optimizer.param_groups[0]['lr'] = next_lr 120 | logging.info('next epoch lr={}'.format(next_lr)) 121 | if self.decay_wd: 122 | self.optimizer.param_groups[0]['weight_decay'] = next_wd 123 | logging.info('next epoch wd={}'.format(next_wd)) 124 | if 'eps' in self.optimizer.param_groups[0] and self.decay_eps: 125 | next_eps = max(self.optimizer.param_groups[0]['eps']*self.wd_factor, self.stop_wd) 126 | self.optimizer.param_groups[0]['eps'] = next_eps 127 | logging.info('next epoch eps={}'.format(next_eps)) 128 | 129 | self.cur_step = 0 130 | self.cur_step_ind = min(self.cur_step_ind+1, len(self.decay_step)-1) 131 | -------------------------------------------------------------------------------- /RFL/RFL_main.py: -------------------------------------------------------------------------------- 1 | # @Time : 2024/3/11 2 | 3 | from RFL.RFL import * 4 | import argparse 5 | import pickle 6 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 7 | from six.moves import queue 8 | 9 | 10 | def do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 11 | line = params["line"] 12 | line_id = params["line_id"] 13 | 14 | if 
args.input_type == "text": 15 | line = line.strip().split('\t') 16 | file_path = line[0] 17 | input_chemfig = line[1] 18 | success, cs_string, _, ring_branch_info, cond_data = cs_main(input_chemfig, is_show=False) 19 | 20 | if success: 21 | out_cs_string = shared_params["out_cs_string"] 22 | out_cs_string_lock = shared_params["out_cs_string_lock"] 23 | with out_cs_string_lock: 24 | cs_string = " ".join(cs_string) 25 | out_cs_string.put([file_path, cs_string, ring_branch_info, cond_data]) 26 | else: 27 | error_lines = shared_params["error_lines"] 28 | error_lines_lock = shared_params["error_lines_lock"] 29 | with error_lines_lock: 30 | error_lines.put(params["line"]) 31 | 32 | def try_do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 33 | try: 34 | do_single_task(params, args, records_queue, records_queue_lock, shared_params) 35 | except BaseException as e: 36 | error_lines = shared_params["error_lines"] 37 | error_lines_lock = shared_params["error_lines_lock"] 38 | with error_lines_lock: 39 | file_name = params["line"].strip().split('\t')[0] 40 | label_string = params["line"].strip().split('\t')[1] 41 | error_lines.put(file_name + '\t' + str(e) + '\t' + label_string + '\n') 42 | # print("try fail!") 43 | if records_queue_lock is not None: 44 | with records_queue_lock: 45 | records_queue.put(1) 46 | 47 | def main(args): 48 | if args.input_type == "text": 49 | with open(args.input, "r") as fin: 50 | lines = fin.readlines() 51 | else: 52 | raise NotImplementedError("unsupport input type = {}".format(args.input_type)) 53 | 54 | # init metrics 55 | manager = Manager() 56 | records_queue = manager.Queue() 57 | records_queue_lock = manager.Lock() 58 | 59 | shared_params = {} 60 | if args.input_type == "text": 61 | shared_params["error_lines"] = manager.Queue() 62 | shared_params["error_lines_lock"] = manager.Lock() 63 | shared_params["out_cs_string"] = manager.Queue() 64 | shared_params["out_cs_string_lock"] = manager.Lock() 65 | 
shared_params["out_branch_info"] = manager.Queue() 66 | shared_params["out_branch_info_lock"] = manager.Lock() 67 | 68 | all_tasks = [] 69 | line_id = -1 70 | 71 | if args.num_workers <= 0: 72 | for line in tqdm(lines): 73 | line_id += 1 74 | params = {} 75 | params["line"] = line 76 | params["line_id"] = line_id 77 | 78 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 79 | do_single_task(*cur_task) 80 | if line_id > 20: 81 | break 82 | else: 83 | for line in tqdm(lines): 84 | line_id += 1 85 | params = {} 86 | params["line"] = line 87 | params["line_id"] = line_id 88 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 89 | all_tasks.append(cur_task) 90 | # if line_id > 100: 91 | # break 92 | def print_error(error): 93 | print("error:", error) 94 | 95 | poolSize = args.num_workers 96 | pool = Pool(poolSize) 97 | pool.starmap_async(try_do_single_task, all_tasks, error_callback=print_error) 98 | pool.close() 99 | tq = tqdm(total=len(all_tasks)) 100 | count = 0 101 | print("begin") 102 | #try: 103 | while count < len(all_tasks): 104 | try: 105 | c = records_queue.get_nowait() 106 | except queue.Empty: 107 | continue 108 | count += 1 109 | tq.update(1) 110 | 111 | pool.join() 112 | 113 | # 后处理保存转换错误的结果和生成的骨干字符串 114 | error_lines = shared_params["error_lines"] 115 | error_lines_lock = shared_params["error_lines_lock"] 116 | # print(error_lines) 117 | with open(args.error_output, "w") as fout: 118 | while not error_lines.empty(): 119 | line = error_lines.get() 120 | fout.write(line) 121 | 122 | if args.input_type == "text": 123 | out_cs_string = shared_params["out_cs_string"] 124 | out_cs_string_lock = shared_params["out_cs_string_lock"] 125 | max_len = 1 126 | with open(args.output, 'w') as fout: 127 | while not out_cs_string.empty(): 128 | line = out_cs_string.get() 129 | file_name = line[0] 130 | tmp_cs_string = line[1] 131 | ring_branch_info = line[2] 132 | cond_data = line[3] 133 | cur_max_len = [len(item) for 
item in ring_branch_info if item is not None] 134 | if len(cur_max_len) > 0: 135 | max_len = max(max_len, max(cur_max_len)) 136 | 137 | fout.write(file_name + '\t' + tmp_cs_string + '\n') 138 | # print("ring_branch_info最大长度: ", max_len) 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser("") 143 | parser.add_argument("-input", type=str, default='valid_ssml_sd.txt') 144 | parser.add_argument("-output", type=str, default='./result/valid_RFL_string.txt') 145 | parser.add_argument("-error_output", type=str, default='./result/error_example.txt') 146 | parser.add_argument("-input_type", type=str, default="text", help="current support text") 147 | parser.add_argument("-num_workers", type=int, default=2) 148 | args = parser.parse_args() 149 | main(args) 150 | 151 | -------------------------------------------------------------------------------- /rain/metric.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from . import xconfig 6 | 7 | class EvalMetric(object): 8 | """Base class for all evaluation metrics. 9 | """ 10 | def __init__(self, name, output_names=None, 11 | label_names=None, **kwargs): 12 | self.name = str(name) 13 | self.output_names = output_names 14 | self.label_names = label_names 15 | self._kwargs = kwargs 16 | self.reset() 17 | 18 | def __str__(self): 19 | return "EvalMetric: {}".format(dict(self.get_name_value())) 20 | 21 | def get_config(self): 22 | """Save configurations of metric. 
Can be recreated 23 | from configs with metric.create(**config) 24 | """ 25 | config = self._kwargs.copy() 26 | config.update({ 27 | 'metric': self.__class__.__name__, 28 | 'name': self.name, 29 | 'output_names': self.output_names, 30 | 'label_names': self.label_names}) 31 | return config 32 | 33 | def update_dict(self, label, pred): 34 | """Update the internal evaluation with named label and pred 35 | """ 36 | 37 | if self.output_names is not None: 38 | pred = [pred[name] for name in self.output_names] 39 | else: 40 | pred = list(pred.values()) 41 | 42 | if self.label_names is not None: 43 | label = [label[name] for name in self.label_names] 44 | else: 45 | label = list(label.values()) 46 | 47 | self.update(label, pred) 48 | 49 | def update(self, labels, preds): 50 | """Updates the internal evaluation result. 51 | """ 52 | raise NotImplementedError() 53 | 54 | def reset(self): 55 | """Resets the internal evaluation result to initial state.""" 56 | self.num_inst = 0 57 | self.sum_metric = 0.0 58 | 59 | def get(self): 60 | """Gets the current evaluation result. 61 | """ 62 | if self.num_inst == 0: 63 | return (self.name, float('nan')) 64 | else: 65 | return (self.name, self.sum_metric / self.num_inst) 66 | 67 | def get_name_value(self): 68 | """Returns zipped name and value pairs. 
69 | """ 70 | name, value = self.get() 71 | if not isinstance(name, list): 72 | name = [name] 73 | if not isinstance(value, list): 74 | value = [value] 75 | return list(zip(name, value)) 76 | 77 | 78 | 79 | class MyCrossEntropy(EvalMetric): 80 | def __init__(self, eps=1e-8, name='train'): 81 | super(MyCrossEntropy, self).__init__(name+"-ce") 82 | self.eps = eps 83 | 84 | @torch.no_grad() 85 | def update(self, labels, preds): # labels list preds list 86 | label, target_hook, mask = (l.detach() for l in labels) 87 | pred = preds[3].detach() # 2521 88 | 89 | #label = label.T.flatten().astype('int32') 90 | #mask = mask.T.flatten() 91 | label = label.flatten().long() 92 | mask = mask.flatten() 93 | ce = pred[torch.arange(len(label)), label] 94 | ce = -torch.log(ce+self.eps) *mask 95 | #ce= ce.sum()/mask.sum() 96 | 97 | self.sum_metric += float(ce.sum().item()) 98 | self.num_inst += mask.sum().item() 99 | 100 | class MyACC(EvalMetric): 101 | def __init__(self, name="train"): 102 | super(MyACC, self).__init__(name+"-acc") 103 | 104 | @torch.no_grad() 105 | def update(self, labels, preds): 106 | label, target_hook, mask = (l.detach() for l in labels) 107 | pred = preds[3].detach() 108 | label = label.long() 109 | label = label.flatten() 110 | mask = mask.flatten() 111 | rec = torch.argmax(pred, axis=1) 112 | rec = rec.long() 113 | acc = rec== label 114 | acc= acc*mask 115 | #acc = acc.sum()/mask.sum() 116 | 117 | self.sum_metric += acc.sum().item() 118 | self.num_inst += mask.sum().item() 119 | 120 | class MyLossCand(EvalMetric): 121 | def __init__(self, eps=1e-8, name='train'): 122 | super(MyLossCand, self).__init__(name+"-cand_loss") 123 | self.eps = eps 124 | 125 | def update(self, labels, preds): # labels list preds list 126 | # label, target_cand_angle, target_hook, mask = (l.detach().cpu().numpy() for l in labels) 127 | loss = preds[1].detach().cpu().numpy() # [B, ] 128 | self.sum_metric += loss 129 | self.num_inst += 1 130 | 131 | class MyLossMem(EvalMetric): 132 | 
def __init__(self, eps=1e-8, name='train'): 133 | super(MyLossMem, self).__init__(name+"-mem_loss") 134 | self.eps = eps 135 | 136 | def update(self, labels, preds): # labels list preds list 137 | # label, target_cand_angle, target_hook, mask = (l.detach().cpu().numpy() for l in labels) 138 | if torch.isnan(preds[1]): 139 | # print("nan, skip") 140 | return 141 | loss = preds[1].detach().cpu().numpy() # [B, ] 142 | self.sum_metric += loss 143 | self.num_inst += 1 144 | 145 | class MyAccMem(EvalMetric): 146 | def __init__(self, eps=1e-8, name='train'): 147 | super(MyAccMem, self).__init__(name+"-mem_acc") 148 | self.eps = eps 149 | self.ea_index = xconfig.vocab.getID("") 150 | 151 | @torch.no_grad() 152 | def update(self, labels, preds): # labels list preds list 153 | # label, target_cand_angle, target_hook, mask = (l.detach().cpu().numpy() for l in labels) 154 | if True in torch.isnan(preds[4].detach()) or True in torch.isnan(preds[5].detach()) or True in torch.isnan(preds[5]): 155 | return 156 | pred = preds[4].detach() # [B, L, M, v] 157 | rec = torch.argmax(pred, axis=-1) #[B, L, M] 158 | mem_tgt = preds[5].detach() 159 | mask = preds[6].detach() 160 | acc = (rec == mem_tgt) 161 | acc = acc * mask 162 | self.sum_metric += acc.sum() 163 | self.num_inst += mask.sum() 164 | 165 | def GenTranMetric(name="train"): 166 | return [MyCrossEntropy(name=name), MyACC(name=name), MyLossMem(name=name), MyAccMem(name=name)] -------------------------------------------------------------------------------- /rain/xconfig.py: -------------------------------------------------------------------------------- 1 | import os 2 | from rain.utils import Vocab 3 | import logging 4 | logger = logging.getLogger() 5 | 6 | # # data path 7 | source_dim = 3 8 | level = 'line' # line column topic 9 | dst_type_dict={'ContentType': ['text'], 'LogicalType':['answer', 'subject'], 'PhysicalType': ['hand', 'print'], 'ImageType': ['normal']} 10 | 11 | # train dataset 12 | train_lrc = "ssml_train.lrc" 13 
train_key = train_lrc + ".line.cache"
train_key_inds = ()
train_ignore = None

# dev dataset
dev_lrc = "ssml_valid.lrc"
dev_key = dev_lrc + ".line.cache"

# test dataset
test_lrc = "ssml_valid.lrc"
test_lrc_cache = test_lrc + ".line.cache"  # derived from this value, before test_lrc is reset below
test_lrc_normh = 40

# amp setting
train_amp = False

# # aug setting
# rand resize
rand_resize = True
rand_resize_ratio = (0.6, 1.2)
rand_crop = True
rand_crop_pixel = 6
# cutout
do_cutout = False
min_contour_area = 100.
width_rate = 1.3
do_cut_rate = 0.5
cutout_sample_rate = 0.5
ignore_pixel = 50
# blur
do_blur = True
# rand polygon
rand_bbox_rate = 0.5
# random_scale_downup
do_random_scale_downup = False
random_scale_downup_range = (0.4, 0.75)

# # vocab params
vocab_file = "./dict/vocab.txt"
vocab = Vocab(vocab_file, unk_id=0)  # loaded when this module is imported
vocab_size = vocab.getVocSize()
sos = vocab.get_sos()
eos = vocab.get_eos()
enter = vocab.getID("\\enter")

# # model params
base_model_dir = ""  # model save path
model_prefix = base_model_dir + "/encdec"
model_type = "pytorch"
num_epochs = 100

# # train data params
max_height = 2000
max_width = 2000
max_length = 10000
max_image_size = 6000000
max_batch_size = 8
fix_batch_size = None

# # test data params
test_max_height = 1000
test_max_width = 1000
test_max_length = 10000
test_fix_batch_size = 1
test_image_list = None
test_image_normh = 40
test_load_epochs = '85'
test_key = ''
# NOTE(review): this overrides test_lrc = "ssml_valid.lrc" set above, while
# test_lrc_cache keeps the value derived from the earlier assignment; confirm
# which one is intended (callers presumably overwrite test_lrc before testing).
test_lrc = ''


img_fix_char_height = None
test_det_sections = None
# ============================ Phase Params ========================== #
# # Train phase params
learning_rate = 2e-4
weight_decay = 0
seed = 369
disp_batches = 100 93 | auto_load_epoch = False 94 | load_epoch = None # if None, model params for training will be initialized randomly 95 | load_param_path = None 96 | allow_missing = True 97 | epoch_batch = 1 98 | data_divide_num = 0.2 # 5 data as one epoch 99 | val_epoch_batch = 1000 100 | val_scheduler_dict = {'scheduler':'MultiStep', 'valid_metric': 'ce', # ce or acc 101 | 'lr_factor': 0.5, 'wd_factor':0.1, 'eps_factor': 0.1, 102 | 'stop_lr': 1e-8,'stop_wd': 1e-12, 'stop_eps': 1e-12, 103 | 'decay_wd': False, 'decay_eps': False, 'thresh': 1e-5, 104 | 'decay_step': [12, 6, 3, 2, 1, 1, 1, 1, 1], # [40, 20, 10, 5, 1, 1, 1, 1, 1] #[125, 20, 10, 5, 1, 1, 1, 1, 1] 105 | # warmup params 106 | 'use_warmup': True, 'warmup_start_lr': 1e-8, 107 | 'warmup_step': 1000, 'warmup_disp_freq': 50} # MultiStep decay_step, AutoStep decay_step[0], FixStep decay_step[0] 108 | use_bmuf = True 109 | bmuf_params = {"sync_step":50, "alpha":1, "blr":1.0, "bm": 0.875} 110 | if use_bmuf: 111 | num_gpus = 1 112 | if 'WORLD_SIZE' in os.environ: 113 | num_gpus = int(os.environ['WORLD_SIZE']) 114 | epoch_batch = int(1 / (num_gpus*data_divide_num)) # 16000 batch as one epoch 115 | #epoch_batch = 64000 // num_gpus # 16000 batch as one epoch 116 | if num_gpus == 4: 117 | bmuf_params["sync_step"] = 50 118 | bmuf_params["bm"] = 0.75 119 | elif num_gpus == 8: 120 | bmuf_params["sync_step"] = 50 121 | bmuf_params["bm"] = 0.875 122 | elif num_gpus == 12: 123 | bmuf_params["sync_step"] = 50 124 | bmuf_params["bm"] = 0.8875 125 | elif num_gpus == 16: 126 | bmuf_params["sync_step"] = 50 127 | bmuf_params["bm"] = 0.9 128 | elif num_gpus == 32: 129 | bmuf_params["sync_step"] = 25 130 | bmuf_params["bm"] = 0.9 131 | elif num_gpus == 1: 132 | logger.info("Gpu count = 1 means that single card debug mode or test mode was launched") 133 | else: 134 | bmuf_params["sync_step"] = 50 135 | bmuf_params["bm"] = 0.75 136 | #raise ValueError("Gpu count = %d error, which should be in [4,8,16,32] if use bmuf" % 
num_gpus) 137 | 138 | # # test phase params 139 | frame_per_char = 50 140 | beam = 5 141 | 142 | # ========================== Encoder Params ========================== # 143 | # # VGG16 144 | encoder_units = [3, 4, 6, 3] 145 | encoder_use_res = [1, 1, 1, 1] 146 | encoder_basic_group = [8, 16, 16, 32] 147 | encoder_filter_list = [24, 48, 96, 192, 384] 148 | encoder_stride_list = [(2, 2), (2, 2), (2, 2), (2, 2)] 149 | encode_dropout = 0.00 150 | encode_feat_dropout = 0.05 151 | 152 | # # SelfAtten structure 153 | encoder_position_dim = 384 154 | encoder_position_att = 192 155 | encoder_dim = encoder_position_dim 156 | 157 | 158 | # ========================== Decoder Params ========================== # 159 | decoder_state_dim = 256 160 | decoder_embed_dim = 128 161 | decoder_att_dim = 128 162 | decoder_merge_dim = 384 163 | decoder_chatt_dim = 384 # group channel attention 164 | decoder_max_seq_len = 1000000 165 | decoder_dropout = 0.2 166 | decoder_embed_drop = 0.15 167 | decoder_cover_kernel = (11,11) 168 | decoder_cover_padding = (5,5) 169 | 170 | decoder_angle_embed_dim = 128 171 | decoder_mem_match_dim = 256 172 | 173 | # =========================== Other Params ========================== # 174 | def get_config_str(): 175 | res = '' 176 | res += 'Config:\n' 177 | import collections 178 | hehe = collections.OrderedDict(sorted(globals().items(), key=lambda x: x[0])) 179 | for k, v in hehe.items(): 180 | if k.startswith('__'): continue 181 | if k.startswith('SEPARATOR'): continue 182 | if k.startswith('get'): continue 183 | if type(v) == (type(os)): continue 184 | if len(k) < 2: continue 185 | res += '{0}: {1}\n'.format(k, v) 186 | return res 187 | -------------------------------------------------------------------------------- /RFL/chemfig_ssml_struct.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import re 3 | import pdb 4 | import numpy as np 5 | import re 6 | import networkx as nx 7 | import 
matplotlib.pyplot as plt 8 | import cv2 9 | 10 | from utils import replace_chemfig, get_atom_group 11 | 12 | bond_types = ["-", "=", "~", ">", "<", ">:", "<:", ">|", "<|", "-:", "=_", "=^", "~/"] 13 | bond_types = sorted(bond_types, key=lambda x: -len(x)) 14 | virtual_types = ['{', '}', '(', ')', 'branch(', 'branch)'] 15 | match_virtual = {'}':'{', ')':'(', 'branch)':'branch('} 16 | 17 | 18 | class Atom: 19 | index = 0 20 | 21 | def __init__(self, text=""): 22 | self.name = "Atom_{}".format(Atom.index) 23 | Atom.index += 1 24 | self.m_text = text 25 | self.pos_x = 0 26 | self.pos_y = 0 27 | self.ring_ids = {} 28 | self.conn_bonds = [] 29 | 30 | class Bond(object): 31 | index = 0 32 | __default__ = {"m_angle": 0, "m_length": 1, "m_start": 0, "m_end": 0} 33 | 34 | def __init__(self, b_type="-"): 35 | self.name = "Bond_{}".format(Bond.index) 36 | Bond.index += 1 37 | b_type = b_type.replace("_", "").replace("^", "") 38 | 39 | self.m_type = b_type 40 | self.m_angle = None 41 | self.m_length = None 42 | self.m_start = None #Bond.__default__["m_start"] 43 | self.m_end = None #Bond.__default__["m_end"] 44 | self.m_extra_info = None 45 | 46 | self.begin_atom = None 47 | self.end_atom = None 48 | 49 | self.ring_ids = {} 50 | 51 | self.is_assigned = set() 52 | 53 | 54 | def main(ssml: str): 55 | print(ssml) 56 | item_list = ssml.split() 57 | print(item_list) 58 | for item in item_list: 59 | print(item) 60 | 61 | 62 | def judge_str_item_type(item: str): 63 | bond_types = ["-", "=", "~", ">", "<", ">:", "<:", ">|", "<|", "-:", "=_", "=^", "~/"] 64 | virtual_types = ['{', '}', '(', ')', 'branch(', 'branch)'] 65 | 66 | if '?' 
in item: 67 | begin_conn_pattern = re.compile(r'\?\[[a-zA-Z]\]') 68 | begin_result = begin_conn_pattern.findall(item) 69 | if len(begin_result) > 0: 70 | return 'reconn_begin' 71 | else: 72 | return 'reconn_end' 73 | 74 | for bond in bond_types: 75 | if bond in item: 76 | return 'bond_atom' 77 | 78 | atom_pattern = re.compile(r'[a-zA-Z]+|\\circle') 79 | atom_result = atom_pattern.findall(item) 80 | if len(atom_result) > 0: 81 | if atom_result[0] == item: 82 | return 'atom' 83 | 84 | if item == 'branch(': 85 | return 'branch_begin' 86 | if item == 'branch)': 87 | return 'branch_end' 88 | 89 | if item in virtual_types: 90 | return 'virtual' 91 | # for virtual in virtual_types: 92 | # if virtual in item: 93 | # return 'virtual' 94 | return 'atom' 95 | 96 | # print("item " + item + " is not matched !!!!") 97 | 98 | def attr_obtain(str): 99 | item_type = judge_str_item_type(str) 100 | if item_type == 'bond_atom': 101 | # bond_type bond_angle 102 | bond_type_pattern = re.compile(r'.*\[') 103 | bond_angle_pattern = re.compile(r'\d+') 104 | bond_type = bond_type_pattern.findall(str)[0][:-1] 105 | bond_angle = int(bond_angle_pattern.findall(str)[0]) 106 | return bond_type, bond_angle 107 | elif item_type == 'atom': 108 | # atom name 109 | return str 110 | elif item_type == 'reconn_end': 111 | # reconn bond type 112 | bond_type_pattern = re.compile(r'\{(.+)\}') 113 | bond_type = bond_type_pattern.findall(str)[0] 114 | assert bond_type is not None, "bond_type is null." 
115 | return bond_type 116 | else: 117 | print("Error in attr_obtain, type not defined.") 118 | sys.exit() 119 | 120 | def build_graph(input_str, is_debug = False): 121 | 122 | chemfig_text, rep_dict, rep_text = replace_chemfig(input_str) 123 | graph_list = [] 124 | for k, v in rep_dict.items(): 125 | Graph = nx.Graph() 126 | # 遍历每一个分子, v 127 | item_list = v.split()[1:] 128 | item_list = get_atom_group(item_list) 129 | # print(item_list) 130 | virtual_stack = [] # 模拟括号堆栈,用于括号匹配 131 | 132 | cur_atom = None 133 | cur_bond = None 134 | 135 | reconn_begin_atom_dict = {} # 记录回连开始的原子 136 | branch_stack = [] # 分支回溯堆栈 137 | is_reconn = False # 回连标志,因为回连原子在回连标识之后,所以需要额外一个标识 138 | is_branch_end = False 139 | cur_reconn_tag = '' 140 | 141 | node_tag = 0 142 | branch_begin_tag = 0 143 | 144 | for ssml_item in item_list: 145 | ssml_item_type = judge_str_item_type(ssml_item) 146 | if is_debug: 147 | print("cur: ", ssml_item, ssml_item_type) 148 | if ssml_item_type == 'atom': 149 | Graph.add_node(node_tag, name=ssml_item) 150 | elif ssml_item_type == 'bond_atom': 151 | Graph.add_node(node_tag) # 创建新节点 152 | Graph.add_node(node_tag + 1) 153 | bond_type, bond_angle = attr_obtain(ssml_item) 154 | if is_branch_end: 155 | Graph.add_edge(branch_begin_tag, node_tag + 1, bond_type = bond_type, angle = bond_angle) 156 | is_branch_end = False 157 | else: 158 | Graph.add_edge(node_tag, node_tag + 1, bond_type = bond_type, angle = bond_angle) 159 | node_tag += 1 160 | 161 | elif ssml_item_type == 'reconn_begin': 162 | is_reconn = True 163 | tag = ssml_item[2] 164 | cur_reconn_tag = tag 165 | reconn_begin_atom_dict[tag] = node_tag 166 | 167 | elif ssml_item_type == 'reconn_end': 168 | cur_reconn_tag = ssml_item[2] 169 | reconn_atom = reconn_begin_atom_dict[cur_reconn_tag] # 获取回连开始原子 170 | Graph.add_edge(reconn_atom, node_tag) 171 | 172 | del reconn_begin_atom_dict[cur_reconn_tag] # 从字典中删除处理完的回连记录 173 | elif ssml_item_type == 'branch_begin': 174 | branch_stack.append(node_tag) # begin branch, 
push stack 175 | elif ssml_item_type == 'branch_end': 176 | # branch_len = 1 + node_tag - branch_stack[-1] 177 | branch_begin_tag = branch_stack[-1] # get stack top 178 | branch_stack.pop() # end branch, pop stack 179 | is_branch_end = True 180 | 181 | elif ssml_item_type == 'virtual': 182 | if len(virtual_stack) == 0: 183 | virtual_stack.append(ssml_item) # 入栈 184 | else: 185 | cur_virtual = virtual_stack[-1] # 栈顶 186 | if ssml_item not in match_virtual.keys(): # 左括号,入栈 187 | virtual_stack.append(ssml_item) 188 | else: # 右括号,匹配 189 | virtual_stack.pop() # 括号匹配,出栈 190 | graph_list.append(Graph) 191 | return graph_list 192 | -------------------------------------------------------------------------------- /RFL/reverse_render_main.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 | import argparse 4 | import tqdm 5 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 6 | import multiprocessing 7 | from six.moves import queue 8 | import reverse_render 9 | import utils 10 | import shutil 11 | 12 | _support_format = set(["text", "lrc", "xml"]) 13 | 14 | def do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 15 | line = params["line"] 16 | line_id = params["line_id"] 17 | if args.input_type == "text": 18 | out_lines = shared_params["out_lines"] 19 | out_lines_lock = shared_params["out_lines_lock"] 20 | 21 | spts= line.strip().split("\t") 22 | if len(spts) != 2: 23 | return 24 | img_key, in_text = spts 25 | out_text = reverse_render.reverse_organic_trans(in_text, debug=False) 26 | out_text = " ".join(out_text) 27 | out_text = utils.process_trans_for_texlive(out_text) 28 | with out_lines_lock: 29 | out_lines.put("{}\t{}\n".format(img_key, out_text)) 30 | elif args.input_type == "xml": 31 | pass 32 | elif args.input_type == "lrc": 33 | if "lrc_parser" in params: 34 | lrc_parser = params["lrc_parser"] 35 | else: 36 | lrc_parser = _lrc_parser 37 | idx = 
line 38 | record = lrc_parser.get_record(idx) 39 | area = record["sub_areas"] 40 | top_id = 0 41 | areas_arr = [x for x in area] 42 | while top_id < len(areas_arr): 43 | cur_area = areas_arr[top_id] 44 | top_id += 1 45 | if "text" in cur_area: 46 | in_text = cur_area["text"] 47 | cur_area["text"] = text_render.text_render(in_text, debug=True) 48 | if "sub_areas" in cur_area: 49 | areas_arr += cur_area["sub_areas"] 50 | shutil.copy(record["image_path"], "./debug/origin.jpg") 51 | out_records = shared_params["out_records"] 52 | out_records_lock = shared_params["out_records_lock"] 53 | with out_records_lock: 54 | out_records.put(record) 55 | pass 56 | 57 | pass 58 | 59 | 60 | def try_do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 61 | line = params["line"] 62 | line_id = params["line_id"] 63 | try: 64 | do_single_task(params, args, records_queue, records_queue_lock, shared_params) 65 | except BaseException as e: 66 | line_content = line if type(line) is str else line 67 | print("try fail! 
line id = {} line = {} err= {}".format(line_id, line_content, e)) 68 | if records_queue_lock is not None: 69 | with records_queue_lock: 70 | records_queue.put(1) 71 | 72 | 73 | def main(args): 74 | if args.input_type == "text": 75 | with open(args.input, "r") as fin: 76 | lines = fin.readlines() 77 | elif args.input_type == "xml": 78 | lines = utils.scan_dir(args.input, "xml") 79 | else: 80 | raise NotImplementedError("unsupport input type = {}".format(args.input_type)) 81 | 82 | 83 | # init metrics 84 | manager = Manager() 85 | records_queue = manager.Queue() 86 | records_queue_lock = manager.Lock() 87 | 88 | # TODO add share params here 89 | shared_params = {} 90 | if args.input_type == "text": 91 | shared_params["out_lines"] = manager.Queue() 92 | shared_params["out_lines_lock"] = manager.Lock() 93 | elif args.input_type == "xml": 94 | common_prefix = os.path.commonpath(lines) 95 | # treat args.output as output_dir 96 | 97 | 98 | all_tasks = [] 99 | line_id = -1 100 | 101 | if args.num_workers <= 0: 102 | for line in tqdm.tqdm(lines): 103 | line_id += 1 104 | # if line.find("00186_0282") == -1: 105 | # continue 106 | params = {} 107 | params["line"] = line 108 | params["line_id"] = line_id 109 | if args.input_type == "xml": 110 | params["common_prefix"] = common_prefix 111 | if args.input_type == "lrc": 112 | params["lrc_parser"] = lrc_parser 113 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 114 | # do_single_task(*cur_task) 115 | try: 116 | do_single_task(*cur_task) 117 | except BaseException as e: 118 | print(e) 119 | # pdb.set_trace() 120 | 121 | #pdb.set_trace() 122 | else: 123 | for line in tqdm.tqdm(lines): 124 | line_id += 1 125 | params = {} 126 | params["line"] = line 127 | params["line_id"] = line_id 128 | if args.input_type == "xml": 129 | params["common_prefix"] = common_prefix 130 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 131 | all_tasks.append(cur_task) 132 | pass 133 | 134 | def 
print_error(error): 135 | print("error:", error) 136 | 137 | def init(a): 138 | global _lrc_parser 139 | _lrc_parser = a 140 | 141 | poolSize = args.num_workers 142 | if args.input_type == "lrc": 143 | pool = Pool(poolSize, initializer=init, initargs=(lrc_parser, )) 144 | else: 145 | pool = Pool(poolSize) 146 | pool.starmap_async(try_do_single_task, all_tasks, error_callback=print_error) 147 | pool.close() 148 | tq = tqdm.tqdm(total=len(all_tasks)) 149 | count = 0 150 | print("begin") 151 | #try: 152 | while count < len(all_tasks): 153 | try: 154 | c = records_queue.get_nowait() 155 | except queue.Empty: 156 | continue 157 | if args.input_type == "lrc": 158 | try: 159 | with shared_params["out_records_lock"]: 160 | cur_record = shared_params["out_records"].get() 161 | lrc_writer.add_record(cur_record) 162 | except queue.Empty: 163 | pass 164 | count += 1 165 | tq.update(1) 166 | # except: 167 | # pass 168 | pool.join() 169 | 170 | #TODO add post process 171 | if args.input_type == "text": 172 | out_lines = shared_params["out_lines"] 173 | out_lines_lock = shared_params["out_lines_lock"] 174 | with open(args.output, "w") as fout: 175 | while not out_lines.empty(): 176 | line = out_lines.get() 177 | fout.write(line) 178 | elif args.input_type == "xml": 179 | pass 180 | # treat args.output as output_dir 181 | 182 | 183 | 184 | if __name__ == "__main__": 185 | parser = argparse.ArgumentParser("") 186 | parser.add_argument("-input", type=str, default=None) 187 | parser.add_argument("-output", type=str, default=None) 188 | parser.add_argument("-input_type", type=str, default="text", help="current support {}".format(_support_format)) 189 | parser.add_argument("-num_workers", type=int, default=32) 190 | args = parser.parse_args() 191 | main(args) 192 | -------------------------------------------------------------------------------- /RFL/text_render_main.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 
| import argparse 4 | import tqdm 5 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 6 | import multiprocessing 7 | from six.moves import queue 8 | import text_render 9 | import utils 10 | import shutil 11 | 12 | _support_format = set(["text", "lrc", "xml"]) 13 | 14 | def do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 15 | line = params["line"] 16 | line_id = params["line_id"] 17 | if args.input_type == "text": 18 | out_lines = shared_params["out_lines"] 19 | out_lines_lock = shared_params["out_lines_lock"] 20 | 21 | spts= line.strip().split("\t") 22 | if len(spts) != 2: 23 | return 24 | img_key, in_text = spts 25 | out_text = text_render.text_render(in_text, debug=False, branch_de_amb=1) 26 | with out_lines_lock: 27 | out_lines.put("{}\t{}\n".format(img_key, out_text)) 28 | elif args.input_type == "xml": 29 | pass 30 | elif args.input_type == "lrc": 31 | if "lrc_parser" in params: 32 | lrc_parser = params["lrc_parser"] 33 | else: 34 | lrc_parser = _lrc_parser 35 | idx = line 36 | record = lrc_parser.get_record(idx) 37 | area = record["sub_areas"] 38 | top_id = 0 39 | areas_arr = [x for x in area] 40 | while top_id < len(areas_arr): 41 | cur_area = areas_arr[top_id] 42 | top_id += 1 43 | if "text" in cur_area: 44 | if args.use_raw_info > 0: 45 | if "raw_info" in cur_area and "text" in cur_area["raw_info"]: 46 | in_text = cur_area["raw_info"]["text"] 47 | else: 48 | raise ValueError("can not find raw info") 49 | else: 50 | in_text = cur_area["text"] 51 | cur_area["text"] = text_render.text_render(in_text, debug=False, branch_de_amb=1) 52 | if "sub_areas" in cur_area: 53 | areas_arr += cur_area["sub_areas"] 54 | #shutil.copy(record["image_path"], "./debug/origin.jpg") 55 | out_records = shared_params["out_records"] 56 | out_records_lock = shared_params["out_records_lock"] 57 | with out_records_lock: 58 | out_records.put(record) 59 | pass 60 | 61 | pass 62 | 63 | 64 | def try_do_single_task(params, args, 
records_queue, records_queue_lock, shared_params={}): 65 | line = params["line"] 66 | line_id = params["line_id"] 67 | try: 68 | do_single_task(params, args, records_queue, records_queue_lock, shared_params) 69 | except BaseException as e: 70 | line_content = line if type(line) is str else line 71 | print("try fail! line id = {} line = {} err= {}".format(line_id, line_content, e)) 72 | if records_queue_lock is not None: 73 | with records_queue_lock: 74 | records_queue.put(1) 75 | 76 | 77 | def main(args): 78 | if args.input_type == "text": 79 | with open(args.input, "r") as fin: 80 | lines = fin.readlines() 81 | elif args.input_type == "xml": 82 | lines = utils.scan_dir(args.input, "xml") 83 | else: 84 | raise NotImplementedError("unsupport input type = {}".format(args.input_type)) 85 | 86 | 87 | # init metrics 88 | manager = Manager() 89 | records_queue = manager.Queue() 90 | records_queue_lock = manager.Lock() 91 | 92 | # TODO add share params here 93 | shared_params = {} 94 | if args.input_type == "text": 95 | shared_params["out_lines"] = manager.Queue() 96 | shared_params["out_lines_lock"] = manager.Lock() 97 | elif args.input_type == "xml": 98 | common_prefix = os.path.commonpath(lines) 99 | # treat args.output as output_dir 100 | 101 | all_tasks = [] 102 | line_id = -1 103 | 104 | if args.num_workers <= 0: 105 | for line in tqdm.tqdm(lines): 106 | line_id += 1 107 | params = {} 108 | params["line"] = line 109 | params["line_id"] = line_id 110 | if args.input_type == "xml": 111 | params["common_prefix"] = common_prefix 112 | if args.input_type == "lrc": 113 | params["lrc_parser"] = lrc_parser 114 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 115 | do_single_task(*cur_task) 116 | #pdb.set_trace() 117 | else: 118 | for line in tqdm.tqdm(lines): 119 | line_id += 1 120 | params = {} 121 | params["line"] = line 122 | params["line_id"] = line_id 123 | if args.input_type == "xml": 124 | params["common_prefix"] = common_prefix 125 | 
cur_task = (params, args, records_queue, records_queue_lock, shared_params) 126 | all_tasks.append(cur_task) 127 | pass 128 | 129 | def print_error(error): 130 | print("error:", error) 131 | 132 | def init(a): 133 | global _lrc_parser 134 | _lrc_parser = a 135 | 136 | poolSize = args.num_workers 137 | if args.input_type == "lrc": 138 | pool = Pool(poolSize, initializer=init, initargs=(lrc_parser, )) 139 | else: 140 | pool = Pool(poolSize) 141 | pool.starmap_async(try_do_single_task, all_tasks, error_callback=print_error) 142 | pool.close() 143 | tq = tqdm.tqdm(total=len(all_tasks)) 144 | count = 0 145 | print("begin") 146 | #try: 147 | while count < len(all_tasks): 148 | try: 149 | c = records_queue.get_nowait() 150 | except queue.Empty: 151 | continue 152 | if args.input_type == "lrc": 153 | try: 154 | with shared_params["out_records_lock"]: 155 | cur_record = shared_params["out_records"].get_nowait() 156 | lrc_writer.add_record(cur_record) 157 | except queue.Empty: 158 | pass 159 | count += 1 160 | tq.update(1) 161 | # except: 162 | # pass 163 | pool.join() 164 | 165 | #TODO add post process 166 | if args.input_type == "text": 167 | out_lines = shared_params["out_lines"] 168 | out_lines_lock = shared_params["out_lines_lock"] 169 | with open(args.output, "w") as fout: 170 | while not out_lines.empty(): 171 | line = out_lines.get() 172 | fout.write(line) 173 | elif args.input_type == "xml": 174 | pass 175 | # treat args.output as output_dir 176 | 177 | pass 178 | 179 | 180 | if __name__ == "__main__": 181 | parser = argparse.ArgumentParser("") 182 | parser.add_argument("-input", type=str, default=None) 183 | parser.add_argument("-output", type=str, default=None) 184 | parser.add_argument("-input_type", type=str, default="text", help="current support {}".format(_support_format)) 185 | parser.add_argument("-use_raw_info", type=int, default=0, help="for lrc mode") 186 | parser.add_argument("-num_workers", type=int, default=32) 187 | args = parser.parse_args() 188 | 
main(args) 189 | -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/structure_word_correct.map: -------------------------------------------------------------------------------- 1 | \fra \frac 2 | \ove \over 3 | \underlin \underline 4 | \enbd \end 5 | \fa \frac 6 | \fac \frac 7 | \farac \frac 8 | \farc \frac 9 | \feac \frac 10 | \fr \frac 11 | \fraac \frac 12 | \frad \frac 13 | \frae \frac 14 | \frav \frac 15 | \frax \frac 16 | \frca \frac 17 | \frea \frac 18 | \frrac \frac 19 | \ftac \frac 20 | \nderline \underline 21 | \overset \overset 22 | \overlin \overline 23 | \sqer \sqrt 24 | \sqrrt \sqrt 25 | \sqrta \sqrt { a } 26 | \endmatrixegincases \endmatrix \begincases 27 | \smearINSERT \smear INSERT 28 | \spqrt \sqrt 29 | \sprt \sqrt 30 | \sprtb \sqrt 31 | \sptr \sqrt 32 | \sqart \sqrt 33 | \sqrc \sqrt 34 | \sqre \sqrt 35 | \sqret \sqrt 36 | \sqrtOF \sqrt O F 37 | \sqrtm \sqrt m 38 | \sqrtn \sqrt n 39 | \sqrtp \sqrt p 40 | \sqry \sqrt 41 | \sqt \sqrt 42 | \srtq \sqrt 43 | \tbetbf \textbf 44 | \tbextbf \textbf 45 | \tbxtbf \textbf 46 | \texttt \textbf 47 | \textsf \textbf 48 | \TEXTBF \textbf 49 | \teatbf \textbf 50 | \texrtbf \textbf 51 | \textbfbf \textbf 52 | \textfb \textbf 53 | \texytbf \textbf 54 | \texfbf \textbf 55 | \tcdot \cdot 56 | # \te \Re 57 | \tectbf \textbf 58 | \teextbf \textbf 59 | \testbf \textbf 60 | \tetbf \textbf 61 | \tettbf \textbf 62 | \tetxbf \textbf 63 | \tetxtbf \textbf 64 | \tex \text 65 | \texTbf \textbf 66 | \texbf \textbf 67 | \texbtf \textbf 68 | \texcircled \textcircled 69 | \texebf \textbf 70 | \texetbf \textbf 71 | \texgbf \textbf 72 | \texit 73 | \texrbf \textbf 74 | \textb \textbf 75 | \textbbf \textbf 76 | \textbd \textbf 77 | \textbe \textbf 78 | \textbf \textbf 79 | \textbfP \textbf P 80 | \textbfa \textbf a 81 | \textbfb \textbf b 82 | \textbfcircled \textcircled 83 | \textbfg \textbf g 84 | \textbfx \textbf x 85 | \textbg \textbf 86 | \textbgf \textbf 87 | 
\textbigcircle \textcircled 88 | \textbk \textbf 89 | \textbook \textbar 90 | \textbt \textbf 91 | \textbullet \bullet 92 | \textc c 93 | \textcelsiusC \textcelsius C 94 | \textcentoldstyle 95 | \textci \textcircled 96 | \textcicled \textcircled 97 | \textciecled \textcircled 98 | \textcir \textcircled 99 | \textcirc \textcircled 100 | \textcirccled \textcircled 101 | \textcirciled \textcircled 102 | \textcircl \textcircled 103 | \textcircld \textcircled 104 | \textcirclde \textcircled 105 | \textcirclel \textcircled 106 | \textcirclod \textcircled 107 | \textcirclted \textcircled 108 | \textcirdled \textcircled 109 | \textcireled \textcircled 110 | \textciriled \textcircled 111 | \textcirled \textcircled 112 | \textcirxled \textcircled 113 | \textciwcled \textcircled 114 | \textcled \textcircled 115 | \textclrcled \textcircled 116 | # \textdblhyphenchar \textdollar 117 | \textef \textbf 118 | # \textemdash \textemdash 119 | \textf \textbf 120 | \textfbf \textbf 121 | \texthf \textbf 122 | \texti \textit 123 | \textif \textit 124 | \textnf \textbf 125 | \textrbf \textbf 126 | \texttbf \textbf 127 | \texttcircled \textcircled 128 | \texttit \textit 129 | \textvf \textbf 130 | \texxtbf \textbf 131 | \texxtit \textit 132 | \texybf \textbf 133 | \texycircled \textcircled 134 | \teztbf \textbf 135 | \timesbcancel \times \bcancel 136 | \trxtbf \textbf 137 | \twxtbf \textbf 138 | \txetbf \textbf 139 | \txetcircled \textcircled 140 | \txtbf \textbf 141 | \txxtbf \textbf 142 | \thota \theta 143 | \uinderline \underline 144 | \undeline \underline 145 | \under \underline 146 | \undereline \underline 147 | \underlien \underline 148 | \underlineC \underline C 149 | \underlineE \underline E 150 | \underlineRR \underline R R 151 | \underlinea \underline a 152 | \underlinen \underline n 153 | \underlineq \underline q 154 | \underlinex \underline x 155 | \underliney \underline y 156 | \underlinez \underline z 157 | \underling \underline 158 | \underlinr \underline 159 | \undline 
\underline 160 | \undreline \underline 161 | \undrline \underline 162 | \undrtline \underline 163 | \unerline \underline 164 | \unferline \underline 165 | \unhderline \underline 166 | \wideha \widehat 167 | \widehatC \widehat C 168 | \widehata \widehat a 169 | \widehatb \widehat b 170 | \widehaty \widehat y 171 | \xLongequal \xlongequal 172 | \xlongleftrightarrow \xleftrightarrow 173 | \xrighttarrow \xrightarrow 174 | \yexybf \textbf 175 | \paralleltextit \parallel \textit 176 | \paralletextit \parallel \textit 177 | \scriptscriptstyle \scriptscriptstyle 178 | \scriptstyle \scriptstyle 179 | \scriptstyleA \scriptstyle A 180 | \scriptstyleABC \scriptstyle A B C 181 | \scriptstyleAD \scriptstyle A D 182 | \scriptstyleD \scriptstyle D 183 | \scriptstyleQABP \scriptstyle Q A B P 184 | \frqac \frac 185 | \FRAC \frac 186 | \faac \frac 187 | \fraC \frac 188 | \frsc \frac 189 | \fqrc \frac 190 | \frtac \frac 191 | \fraxc \frac 192 | \frc \frac 193 | \frce \frac 194 | \frqc \frac 195 | \grac \frac 196 | \ferac \frac 197 | \fras \frac 198 | \rac \frac 199 | \bcacenl \bcancel 200 | \beancel \bcancel 201 | \bacnel \bcancel 202 | \bacacel \bcancel 203 | \bancael \bcancel 204 | \ncancel \bcancel 205 | \bcanccel \bcancel 206 | \bcamcel \bcancel 207 | \bcanael \bcancel 208 | \bcanlce \bcancel 209 | \becancell \bcancel 210 | \blancel \bcancel 211 | \bcancl \bcancel 212 | \bcance \bcancel 213 | \bcanca \bcancel 214 | \bcanle \bcancel 215 | \bncancel \bcancel 216 | \bcancal \bcancel 217 | \bacancel \bcancel 218 | \bcabcel \bcancel 219 | \bcanclce \bcancel 220 | \bacncel \bcancel 221 | \bcancecl \bcancel 222 | \bcancek \bcancel 223 | \bccancel \bcancel 224 | \bcaccel \bcancel 225 | \bcandel \bcancel 226 | \bcanc \bcancel 227 | \bcencel \bcancel 228 | \bcanceL \bcancel 229 | \bcanerl \bcancel 230 | \bcaancel \bcancel 231 | \bcacle \bcancel 232 | \bca \bcancel 233 | \bcabce \bcancel 234 | \bcancer \bcancel 235 | \bcancrl \bcancel 236 | \bcanncel \bcancel 237 | \bcanxel \bcancel 238 | 
\bclancel \bcancel 239 | \bcsncel \bcancel 240 | \bczncle \bcancel 241 | \becancle \bcancel 242 | \ncancle \bcancel 243 | \ancel \bcancel 244 | \bcacncel \bcancel 245 | \bcancdl \bcancel 246 | \blance \bcancel 247 | \bcanse \bcancel 248 | \bcanecl \bcancel 249 | \bczncel \bcancel 250 | \becancel \bcancel 251 | \bcaacnel \bcancel 252 | \bcan \bcancel 253 | \bbcancel \bcancel 254 | \bcancrel \bcancel 255 | \bcanceql \bcancel 256 | \bcangel \bcancel 257 | \bcnacel \bcancel 258 | \bcancei \bcancel 259 | \baccnel \bcancel 260 | \bcances \bcancel 261 | \bcacel \bcancel 262 | \bcandcel \bcancel 263 | \bcansel \bcancel 264 | \bancle \bcancel 265 | \bcancfel \bcancel 266 | \bcancen \bcancel 267 | \bcancedl \bcancel 268 | \bcancel \bcancel 269 | \bcancle \bcancel 270 | \kbcancel \bcancel 271 | \BCANCEL \bcancel 272 | \BCANCLE \bcancel 273 | \bcacnel \bcancel 274 | \bcancec \bcancel 275 | \bcanceel \bcancel 276 | \bcanel \bcancel 277 | \bcncel \bcancel 278 | \bcanclee \bcancel 279 | \banlace \bcancel 280 | \SQRT \sqrt 281 | \fqrt \sqrt 282 | \beging \begin 283 | \bedib \begin 284 | \begim \begin 285 | \bengin \begin 286 | \brgin \begin 287 | \begian \begin 288 | \bergin \begin 289 | \bein \begin 290 | \bgein \begin 291 | \begion \begin 292 | \beqin \begin 293 | \begiin \begin 294 | \being \begin 295 | \beg \begin 296 | \begi \begin 297 | \bedin \begin 298 | \degin \begin 299 | \begun \begin 300 | \begib \begin 301 | \egin \begin 302 | \beign \begin 303 | \begoin \begin 304 | \behin \begin 305 | \besin \begin 306 | \beginP \begin 307 | \BEGIN \begin 308 | \begincass \begin{cases} 309 | \beginarray \begin{array} 310 | \doct \dot 311 | \enda \end 312 | \enbd \end 313 | \endases \end{cases} 314 | \frownA \frown { A } 315 | \frownB \frown { B } 316 | \frownE \frown { E } 317 | \fyown \frown 318 | \frowm \frown 319 | \frwon \frown 320 | \ledt \left 321 | \lefe \left 322 | \legt \left 323 | \lteft \left 324 | \letf \left 325 | \lrft \left 326 | \lunderline \underline 327 | \marhop 
\mathop 328 | \overlint \overline 329 | \ovreline \overline 330 | \ovweline \overline 331 | \ocerline \overline 332 | \overling \overline 333 | \onerline \overline 334 | \ovelrine \overline 335 | \overlinr \overline 336 | \overkine \overline 337 | \ocerset \overset 338 | \orarroe \overrightarrow 339 | \overright \overrightarrow 340 | \overrightarro \overrightarrow 341 | \overrightaggow \overrightarrow 342 | \overrgitarrow \overrightarrow 343 | \overrighta \overrightarrow 344 | \orarrozzw \overrightarrow 345 | \overrightarraw \overrightarrow 346 | \overrightraaow \overrightarrow 347 | \overrughtrarrow \overrightarrow 348 | \owerrightarrow \overrightarrow 349 | \orarro \overrightarrow 350 | \orarrow \overrightarrow 351 | \orarrw \overrightarrow 352 | \overrig \overrightarrow 353 | \overrightarroww \overrightarrow 354 | \ovarrightarrow \overrightarrow 355 | \overrightorarrow \overrightarrow 356 | \overrigtarrow \overrightarrow 357 | \overrightarrrow \overrightarrow 358 | \overrogjtarrow \overrightarrow 359 | \orattow \overrightarrow 360 | \overrightarrowow \overrightarrow 361 | \overrightar \overrightarrow 362 | \overrightrrow \overrightarrow 363 | \orarrrow \overrightarrow 364 | \orarrowghtarrow \overrightarrow 365 | \overrighttarrow \overrightarrow 366 | \vec \overrightarrow 367 | \sart \sqrt 368 | \sbcancel \bcancel 369 | \overrigharrow \overrightarrow 370 | \overrighrarrow \overrightarrow 371 | \overrightarrowc \overrightarrow c 372 | \overrightarrowm \overrightarrow m 373 | \overringhtarrow \overrightarrow 374 | \overroghtarrow \overrightarrow 375 | \overrrightarrow \overrightarrow 376 | \ovewrrightarrow \overrightarrow 377 | \angleBCDunderline \angle B C D \underline 378 | \Downarro \Downarrow 379 | \Ooverrightarrow \overrightarrow 380 | \qart \sqrt 381 | \aqrt \sqrt 382 | \sqrr \sqrt 383 | \ssqrt \sqrt 384 | \srqt \sqrt 385 | \sqtr \sqrt 386 | \sqrtt \sqrt 387 | \sqr \sqrt 388 | \qrt \sqrt 389 | \textsurd \sqrt 390 | \textcircle \textcircled 391 | \adot a \dot 
392 | \box \square 393 | \underleftarrow \underleftarrow 394 | \underleftrightarrow \underleftrightarrow 395 | \underrightarrow \underrightarrow 396 | \underset \underset 397 | \undersetn \underset n 398 | \var \bar 399 | \eng \end 400 | \overest \overset 401 | \endc \endcases 402 | \ebd \end 403 | \beginc \begincases 404 | \ens \end 405 | \ned \end 406 | \smearfrac \smear \frac 407 | \subet \subset 408 | \ffrac \frac 409 | \fzrac \frac 410 | \ed \end 411 | \timesfrac \times \frac 412 | \enq \end 413 | \Right \right 414 | \emd \end 415 | \edn \end 416 | \sqqrt \sqrt 417 | \hatb \hat b 418 | \pxdyfrac \pxdy \frac 419 | \sqr \sqrt 420 | \fare \frac 421 | \extbf \textbf 422 | \stackre \overset 423 | \lambd \lambda 424 | \overlibe \overline 425 | \cdotfrac \cdot \frac 426 | \smearbcancel \smear \bcancel 427 | \smearleft \smear \left 428 | \varph \varphi 429 | \wedeg \wedge 430 | \Downarro \Downarrow 431 | \beginca \begincases 432 | \TIMES \times 433 | \overrightrarrow \overrightarrow 434 | \overrughtrarrow \overrightarrow 435 | \overste \overset 436 | \sqert \sqrt 437 | \sqrtr \sqrt 438 | \textcirlce \textcircled 439 | \textcirced \textcircled 440 | \textcirlced \textcircled 441 | \textcricled \textcircled 442 | \textircled \textcircled 443 | \textrircled \textcircled 444 | \twxtcircled \textcircled 445 | \txtcircled \textcircled 446 | -------------------------------------------------------------------------------- /RFL/cond_render_ssml_main.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 | import argparse 4 | import tqdm 5 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 6 | import multiprocessing 7 | from six.moves import queue 8 | import text_render 9 | import utils 10 | import shutil 11 | import cond_render as chemfig_cond_render 12 | import pickle 13 | from Tool_Formula.latex_norm.transcription import parse_transcription 14 | from chemfig_struct import * 15 | 16 | 
_support_format = set(["text", "lrc", "xml"])


def process_cond_render(text):
    """Tokenize *text* and replace every embedded ``\\chemfig{...}`` with its parsed form.

    ``utils.replace_chemfig`` swaps each chemfig for a placeholder token
    (``rep_dict`` maps placeholder -> original chemfig source) and the rest of
    the text is split into units by ``parse_transcription``.  Placeholder units
    are substituted with the result of
    ``chemfig_cond_render.chemfig_random_cond_parse``; all other units are kept.

    Returns:
        list: the text units, with chemfig placeholders replaced.
    Raises:
        ValueError: if any chemfig string fails to parse (the first failure
            aborts the whole call).
    """
    _, rep_dict, rep_text = utils.replace_chemfig(text)
    rep_text_units = parse_transcription(rep_text, simple_trans=True)
    parsed = {}
    for key, value in rep_dict.items():
        # Drop the leading "\chemfig" (8 chars) and the outer braces.
        # Assumes value looks like "\chemfig{...}" -- TODO confirm with utils.replace_chemfig.
        chemfig_str = value.strip()[8:].strip()[1:-1]
        try:
            parsed[key] = chemfig_cond_render.chemfig_random_cond_parse(chemfig_str)
        except Exception as e:
            raise ValueError("Parse Error, err = {}".format(e)) from e
    return [parsed.get(unit, unit) for unit in rep_text_units]


def _area_input_text(area, args):
    """Return the text to render for *area*.

    With ``args.use_raw_info > 0`` the text is ``area["raw_info"]["text"]``
    (``ValueError`` when absent); otherwise it is ``area["text"]``.
    """
    if args.use_raw_info > 0:
        if "raw_info" in area and "text" in area["raw_info"]:
            return area["raw_info"]["text"]
        raise ValueError("can not find raw info")
    return area["text"]


def _check_renderable(text_units, max_retries=5):
    """Check that *text_units* survive ``process_text_rnd_cond_render``.

    The call is attempted up to ``max_retries + 1`` times (6 by default); the
    "rnd" in its name suggests it is randomized, so a single failure is not
    treated as fatal.  The rendered text itself is discarded.

    Raises:
        ValueError: wrapping the last error when every attempt fails.
    """
    for attempt in range(max_retries + 1):
        try:
            parsed_units = chemfig_cond_render.process_text_rnd_cond_render(text_units)
            # Joining the head of each unit also validates the output's shape.
            _ = " ".join([unit[0] for unit in parsed_units])
            return
        except Exception as e:
            if attempt == max_retries:
                raise ValueError(e) from e


def _cond_render_areas(record, args):
    """Breadth-first over ``record["sub_areas"]``, storing pickled cond-render units.

    For every (nested) area with a ``"text"`` key, ``process_cond_render`` is
    applied to its input text and the pickled unit list is written to
    ``area["raw_info"]["struct_text"]``.  With ``args.rend_check > 0`` the units
    are additionally checked with ``_check_renderable``.
    """
    import warnings  # not imported at file level
    from collections import deque

    pending = deque(record["sub_areas"])
    while pending:
        area = pending.popleft()
        if "sub_areas" in area:
            pending.extend(area["sub_areas"])
        if "text" not in area:
            continue
        text_units = process_cond_render(_area_input_text(area, args))
        unit_bytes = pickle.dumps(text_units)
        if "raw_info" not in area:
            warnings.warn("can not find raw_info")
        # Create raw_info when missing rather than failing with KeyError right after the warning.
        area.setdefault("raw_info", {})["struct_text"] = unit_bytes
        if args.rend_check > 0:
            _check_renderable(text_units)


def _text_render_areas(record, args):
    """Breadth-first over ``record["sub_areas"]``, replacing each ``"text"`` with its text_render output."""
    from collections import deque

    pending = deque(record["sub_areas"])
    while pending:
        area = pending.popleft()
        if "text" in area:
            area["text"] = text_render.text_render(
                _area_input_text(area, args), debug=False, branch_de_amb=1)
        pending.extend(area.get("sub_areas", ()))


def do_single_task(params, args, records_queue, records_queue_lock, shared_params=None):
    """Process one task according to ``args.input_type``.

    ``params["line"]`` is the unit of work:

    * ``"text"``: a ``key<TAB>text`` line.  ``"key<TAB>rendered\\n"`` is put on
      ``shared_params["out_lines"]``; lines that do not split into exactly two
      tab-separated fields are skipped.
    * ``"lrc"``: a record index for the LRC parser taken from
      ``params["lrc_parser"]`` or, failing that, the module global
      ``_lrc_parser`` (nothing in this module sets it).  The record's areas are
      cond-rendered, then text-rendered, and the record is put on
      ``shared_params["out_records"]``.
    * anything else (``"xml"``): no-op.

    ``records_queue`` and ``records_queue_lock`` are unused here; they keep the
    signature aligned with ``try_do_single_task``.
    """
    if shared_params is None:
        shared_params = {}
    line = params["line"]
    if args.input_type == "text":
        out_lines = shared_params["out_lines"]
        out_lines_lock = shared_params["out_lines_lock"]
        fields = line.strip().split("\t")
        if len(fields) != 2:
            return
        img_key, in_text = fields
        out_text = text_render.text_render(in_text, debug=False)
        with out_lines_lock:
            out_lines.put("{}\t{}\n".format(img_key, out_text))
    elif args.input_type == "lrc":
        lrc_parser = params["lrc_parser"] if "lrc_parser" in params else _lrc_parser
        record = lrc_parser.get_record(line)
        _cond_render_areas(record, args)
        _text_render_areas(record, args)
        with shared_params["out_records_lock"]:
            shared_params["out_records"].put(record)


def try_do_single_task(params, args, records_queue, records_queue_lock, shared_params=None):
    """Pool entry point: run ``do_single_task`` and always report completion.

    Any exception (``BaseException`` included) is printed instead of raised, so
    one bad input does not abort the pool.  Afterwards a ``1`` token is put on
    ``records_queue`` (when a lock is given) so the parent can count finished
    tasks.
    """
    line = params["line"]
    line_id = params["line_id"]
    try:
        do_single_task(params, args, records_queue, records_queue_lock, shared_params)
    except BaseException as e:
        print("try fail! line id = {} line = {} err= {}".format(line_id, line, e))
    if records_queue_lock is not None:
        with records_queue_lock:
            records_queue.put(1)


def _run_pool(all_tasks, records_queue, num_workers, poll_seconds=1.0):
    """Run *all_tasks* through ``try_do_single_task`` on a Pool with a progress bar.

    Progress is driven by the completion tokens workers put on *records_queue*.
    The wait blocks with a timeout instead of spinning; if the pool finishes
    (or fails, reported via ``print_error``) and no tokens remain, the loop
    stops instead of waiting forever.
    """
    def print_error(error):
        print("error:", error)

    pool = Pool(num_workers)
    async_result = pool.starmap_async(try_do_single_task, all_tasks, error_callback=print_error)
    pool.close()
    tq = tqdm.tqdm(total=len(all_tasks))
    print("begin")
    done = 0
    while done < len(all_tasks):
        try:
            records_queue.get(timeout=poll_seconds)
        except queue.Empty:
            # Workers put their token before returning, so once the pool is
            # ready an empty queue means no further tokens will arrive.
            if async_result.ready() and records_queue.empty():
                break
            continue
        done += 1
        tq.update(1)
    tq.close()
    pool.join()


def main(args):
    """Render every input unit and write the results.

    Only ``"text"`` (``key<TAB>text`` lines read from ``args.input``) and
    ``"xml"`` (files found by ``utils.scan_dir``) inputs are loaded; any other
    ``args.input_type`` raises ``NotImplementedError``.  With
    ``args.num_workers <= 0`` tasks run serially in this process (a debug path
    that stops after the first 22 lines); otherwise they run on a
    ``multiprocessing.Pool``.  For ``"text"`` the rendered lines are written to
    ``args.output``.
    """
    if args.input_type == "text":
        with open(args.input, "r", encoding="utf-8") as fin:
            lines = fin.readlines()
    elif args.input_type == "xml":
        lines = utils.scan_dir(args.input, "xml")
    else:
        # NOTE: "lrc" is in _support_format, but no LRC parser/writer is wired up
        # here, so it is rejected together with unknown types.
        raise NotImplementedError("unsupport input type = {}".format(args.input_type))

    manager = Manager()
    records_queue = manager.Queue()
    records_queue_lock = manager.Lock()

    shared_params = {}
    common_prefix = None
    if args.input_type == "text":
        shared_params["out_lines"] = manager.Queue()
        shared_params["out_lines_lock"] = manager.Lock()
    elif args.input_type == "xml":
        common_prefix = os.path.commonpath(lines)

    def make_params(line_id, line):
        params = {"line": line, "line_id": line_id}
        if args.input_type == "xml":
            params["common_prefix"] = common_prefix
        return params

    if args.num_workers <= 0:
        for line_id, line in enumerate(tqdm.tqdm(lines)):
            do_single_task(make_params(line_id, line), args,
                           records_queue, records_queue_lock, shared_params)
            # Debug cap: serial mode only processes line ids 0..21.
            if line_id > 20:
                break
    else:
        all_tasks = [
            (make_params(line_id, line), args, records_queue, records_queue_lock, shared_params)
            for line_id, line in enumerate(tqdm.tqdm(lines))
        ]
        _run_pool(all_tasks, records_queue, args.num_workers)

    if args.input_type == "text":
        out_lines = shared_params["out_lines"]
        with open(args.output, "w", encoding="utf-8") as fout:
            while not out_lines.empty():
                fout.write(out_lines.get())
    elif args.input_type == "xml":
        # treat args.output as output_dir
        pass
    #
treat args.output as output_lrc_path 261 | 262 | pass 263 | 264 | -------------------------------------------------------------------------------- /dict/vocab.txt: -------------------------------------------------------------------------------- 1 | \unk 0 2 | 1 3 | 2 4 | \enter 3 5 | \jump 4 6 | \space 5 7 | ! 6 8 | " 7 9 | ( 8 10 | ) 9 11 | + 10 12 | , 11 13 | - 12 14 | -#[:0] 13 15 | -#[:105] 14 16 | -#[:120] 15 17 | -#[:135] 16 18 | -#[:150] 17 19 | -#[:15] 18 20 | -#[:165] 19 21 | -#[:180] 20 22 | -#[:195] 21 23 | -#[:210] 22 24 | -#[:225] 23 25 | -#[:240] 24 26 | -#[:255] 25 27 | -#[:270] 26 28 | -#[:285] 27 29 | -#[:300] 28 30 | -#[:30] 29 31 | -#[:315] 30 32 | -#[:330] 31 33 | -#[:345] 32 34 | -#[:45] 33 35 | -#[:60] 34 36 | -#[:75] 35 37 | -#[:90] 36 38 | -:[:0] 37 39 | -:[:15] 38 40 | -:[:30] 39 41 | -:[:330] 40 42 | -:[:345] 41 43 | -:[:45] 42 44 | -:[:60] 43 45 | -:[:90] 44 46 | -@[:0] 45 47 | -@[:150] 46 48 | -@[:15] 47 49 | -@[:165] 48 50 | -@[:180] 49 51 | -@[:195] 50 52 | -@[:210] 51 53 | -@[:225] 52 54 | -@[:240] 53 55 | -@[:255] 54 56 | -@[:270] 55 57 | -@[:285] 56 58 | -@[:300] 57 59 | -@[:30] 58 60 | -@[:315] 59 61 | -@[:330] 60 62 | -@[:345] 61 63 | -@[:45] 62 64 | -@[:60] 63 65 | -@[:75] 64 66 | -@[:90] 65 67 | -[:0] 66 68 | -[:105] 67 69 | -[:120] 68 70 | -[:135] 69 71 | -[:150] 70 72 | -[:15] 71 73 | -[:165] 72 74 | -[:180] 73 75 | -[:195] 74 76 | -[:210] 75 77 | -[:225] 76 78 | -[:240] 77 79 | -[:255] 78 80 | -[:270] 79 81 | -[:285] 80 82 | -[:300] 81 83 | -[:30] 82 84 | -[:315] 83 85 | -[:330] 84 86 | -[:345] 85 87 | -[:45] 86 88 | -[:60] 87 89 | -[:75] 88 90 | -[:90] 89 91 | . 
90 92 | / 91 93 | // 92 94 | 0 93 95 | 1 94 96 | 2 95 97 | 3 96 98 | 4 97 99 | 5 98 100 | 6 99 101 | 7 100 102 | 8 101 103 | 9 102 104 | : 103 105 | ; 104 106 | < 105 107 | <#[:225] 106 108 | <#[:270] 107 109 | <#[:285] 108 110 | <#[:315] 109 111 | <:#[:255] 110 112 | <:@[:75] 111 113 | <:[:0] 112 114 | <:[:105] 113 115 | <:[:120] 114 116 | <:[:135] 115 117 | <:[:150] 116 118 | <:[:15] 117 119 | <:[:165] 118 120 | <:[:180] 119 121 | <:[:195] 120 122 | <:[:210] 121 123 | <:[:225] 122 124 | <:[:240] 123 125 | <:[:255] 124 126 | <:[:270] 125 127 | <:[:285] 126 128 | <:[:300] 127 129 | <:[:30] 128 130 | <:[:315] 129 131 | <:[:330] 130 132 | <:[:345] 131 133 | <:[:45] 132 134 | <:[:60] 133 135 | <:[:75] 134 136 | <:[:90] 135 137 | <@[:45] 136 138 | <@[:90] 137 139 | <[:0] 138 140 | <[:105] 139 141 | <[:120] 140 142 | <[:135] 141 143 | <[:150] 142 144 | <[:15] 143 145 | <[:165] 144 146 | <[:180] 145 147 | <[:195] 146 148 | <[:210] 147 149 | <[:225] 148 150 | <[:240] 149 151 | <[:255] 150 152 | <[:270] 151 153 | <[:285] 152 154 | <[:300] 153 155 | <[:30] 154 156 | <[:315] 155 157 | <[:330] 156 158 | <[:345] 157 159 | <[:45] 158 160 | <[:60] 159 161 | <[:75] 160 162 | <[:90] 161 163 | 162 164 | <|[:105] 163 165 | <|[:255] 164 166 | <|[:330] 165 167 | <|[:75] 166 168 | = 167 169 | =#[:0] 168 170 | =#[:15] 169 171 | =#[:180] 170 172 | =#[:195] 171 173 | =#[:210] 172 174 | =#[:225] 173 175 | =#[:240] 174 176 | =#[:255] 175 177 | =#[:270] 176 178 | =#[:285] 177 179 | =#[:300] 178 180 | =#[:30] 179 181 | =#[:315] 180 182 | =#[:330] 181 183 | =#[:345] 182 184 | =#[:45] 183 185 | =#[:60] 184 186 | =#[:75] 185 187 | =#[:90] 186 188 | =@[:0] 187 189 | =@[:15] 188 190 | =@[:180] 189 191 | =@[:195] 190 192 | =@[:210] 191 193 | =@[:30] 192 194 | =@[:315] 193 195 | =@[:330] 194 196 | =@[:345] 195 197 | =@[:45] 196 198 | =@[:60] 197 199 | =@[:75] 198 200 | =@[:90] 199 201 | =[:0] 200 202 | =[:105] 201 203 | =[:120] 202 204 | =[:135] 203 205 | =[:150] 204 206 | =[:15] 205 207 | =[:165] 
206 208 | =[:180] 207 209 | =[:195] 208 210 | =[:210] 209 211 | =[:225] 210 212 | =[:240] 211 213 | =[:255] 212 214 | =[:270] 213 215 | =[:285] 214 216 | =[:300] 215 217 | =[:30] 216 218 | =[:315] 217 219 | =[:330] 218 220 | =[:345] 219 221 | =[:45] 220 222 | =[:60] 221 223 | =[:75] 222 224 | =[:90] 223 225 | > 224 226 | >#[:285] 225 227 | >#[:330] 226 228 | >:#[:240] 227 229 | >:@[:60] 228 230 | >:[:0] 229 231 | >:[:105] 230 232 | >:[:150] 231 233 | >:[:15] 232 234 | >:[:165] 233 235 | >:[:180] 234 236 | >:[:195] 235 237 | >:[:210] 236 238 | >:[:225] 237 239 | >:[:240] 238 240 | >:[:255] 239 241 | >:[:270] 240 242 | >:[:285] 241 243 | >:[:300] 242 244 | >:[:30] 243 245 | >:[:315] 244 246 | >:[:330] 245 247 | >:[:345] 246 248 | >:[:45] 247 249 | >:[:60] 248 250 | >:[:75] 249 251 | >:[:90] 250 252 | >@ 251 253 | >[:0] 252 254 | >[:105] 253 255 | >[:120] 254 256 | >[:150] 255 257 | >[:15] 256 258 | >[:180] 257 259 | >[:195] 258 260 | >[:210] 259 261 | >[:225] 260 262 | >[:240] 261 263 | >[:255] 262 264 | >[:270] 263 265 | >[:285] 264 266 | >[:300] 265 267 | >[:30] 266 268 | >[:315] 267 269 | >[:330] 268 270 | >[:345] 269 271 | >[:45] 270 272 | >[:60] 271 273 | >[:75] 272 274 | >[:90] 273 275 | >|[:0] 274 276 | >|[:60] 275 277 | ?[a,{-#}] 276 278 | ?[a,{-@}] 277 279 | ?[a,{-}] 278 280 | ?[a,{<:}] 279 281 | ?[a,{<@}] 280 282 | ?[a,{<}] 281 283 | ?[a,{=#}] 282 284 | ?[a,{=@}] 283 285 | ?[a,{=}] 284 286 | ?[a,{>:}] 285 287 | ?[a,{>@}] 286 288 | ?[a,{>}] 287 289 | ?[a,{~}] 288 290 | ?[a] 289 291 | ?[b,{-}] 290 292 | ?[b,{<}] 291 293 | ?[b,{=}] 292 294 | ?[b,{>}] 293 295 | ?[b] 294 296 | ?[c,{-}] 295 297 | ?[c] 296 298 | A 297 299 | B 298 300 | C 299 301 | D 300 302 | E 301 303 | F 302 304 | G 303 305 | H 304 306 | I 305 307 | J 306 308 | K 307 309 | L 308 310 | M 309 311 | N 310 312 | O 311 313 | P 312 314 | Q 313 315 | R 314 316 | S 315 317 | T 316 318 | U 317 319 | V 318 320 | W 319 321 | X 320 322 | Y 321 323 | Z 322 324 | [ 323 325 | \% 324 326 | \Chemabove 325 327 | 
\Chemfig 326 328 | \Leftrightarrow 327 329 | \Phi 328 330 | \Psi 329 331 | \Rightarrow 330 332 | \Superatom 331 333 | \UC_3001 332 334 | \UC_3010 333 335 | \UC_3011 334 336 | \UC_356D 335 337 | \UC_4E00 336 338 | \UC_4E09 337 339 | \UC_4E0B 338 340 | \UC_4E0D 339 341 | \UC_4E1A 340 342 | \UC_4E24 341 343 | \UC_4E59 342 344 | \UC_4E8C 343 345 | \UC_4EE3 344 346 | \UC_4EF6 345 347 | \UC_4F18 346 348 | \UC_4F53 347 349 | \UC_50AC 348 350 | \UC_5145 349 351 | \UC_5149 350 352 | \UC_5165 351 353 | \UC_5185 352 354 | \UC_51B0 353 355 | \UC_51B7 354 356 | \UC_5206 355 357 | \UC_5236 356 358 | \UC_5242 357 359 | \UC_52A0 358 360 | \UC_5316 359 361 | \UC_533A 360 362 | \UC_538B 361 363 | \UC_539F 362 364 | \UC_53BB 363 365 | \UC_53CD 364 366 | \UC_53D1 365 367 | \UC_53D6 366 368 | \UC_53EF 367 369 | \UC_5408 368 370 | \UC_540C 369 371 | \UC_5421 370 372 | \UC_542A 371 373 | \UC_544B 372 374 | \UC_5472 373 375 | \UC_548C 374 376 | \UC_54CC 375 377 | \UC_5576 376 378 | \UC_5583 377 379 | \UC_55C5 378 380 | \UC_5627 379 381 | \UC_56DB 380 382 | \UC_56DE 381 383 | \UC_56FA 382 384 | \UC_57DF 383 385 | \UC_57FA 384 386 | \UC_5883 385 387 | \UC_589E 386 388 | \UC_5931 387 389 | \UC_5B9A 388 390 | \UC_5BB9 389 391 | \UC_5BF9 390 392 | \UC_5C04 391 393 | \UC_5C0F 392 394 | \UC_5C3C 393 395 | \UC_5C51 394 396 | \UC_5C5E 395 397 | \UC_5E72 396 398 | \UC_5E94 397 399 | \UC_5EA6 398 400 | \UC_5F00 399 401 | \UC_5F3A 400 402 | \UC_5F53 401 403 | \UC_5FAE 402 404 | \UC_6027 403 405 | \UC_6052 404 406 | \UC_60AC 405 407 | \UC_6210 406 408 | \UC_6216 407 409 | \UC_624B 408 410 | \UC_6291 409 411 | \UC_62B9 410 412 | \UC_63A7 411 413 | \UC_6570 412 414 | \UC_6590 413 415 | \UC_65AF 414 416 | \UC_65B0 415 417 | \UC_65E0 416 418 | \UC_6613 417 419 | \UC_66FE 418 420 | \UC_6761 419 421 | \UC_6797 420 422 | \UC_6B21 421 423 | \UC_6C14 422 424 | \UC_6C22 423 425 | \UC_6C27 424 426 | \UC_6C28 425 427 | \UC_6C2E 426 428 | \UC_6C2F 427 429 | \UC_6C34 428 430 | \UC_6C7D 429 431 | \UC_6C99 430 432 | 
\UC_6CB9 431 433 | \UC_6CE2 432 434 | \UC_6CE8 433 435 | \UC_6D1B 434 436 | \UC_6D3B 435 437 | \UC_6D41 436 438 | \UC_6D4A 437 439 | \UC_6D53 438 440 | \UC_6D74 439 441 | \UC_6D82 440 442 | \UC_6D88 441 443 | \UC_6DB2 442 444 | \UC_6DE1 443 445 | \UC_6E05 444 446 | \UC_6E29 445 447 | \UC_6EB4 446 448 | \UC_6EB6 447 449 | \UC_6EE4 448 450 | \UC_6FB3 449 451 | \UC_6FC0 450 452 | \UC_70DF 451 453 | \UC_70ED 452 454 | \UC_70EF 453 455 | \UC_70F7 454 456 | \UC_7167 455 457 | \UC_7194 456 458 | \UC_71E5 457 459 | \UC_7269 458 460 | \UC_73AF 459 461 | \UC_7532 460 462 | \UC_7535 461 463 | \UC_758F 462 464 | \UC_7684 463 465 | \UC_76D0 464 466 | \UC_770B 465 467 | \UC_7845 466 468 | \UC_785D 467 469 | \UC_786B 468 470 | \UC_787C 469 471 | \UC_78B1 470 472 | \UC_78B3 471 473 | \UC_78F7 472 474 | \UC_78FA 473 475 | \UC_7A00 474 476 | \UC_7A7A 475 477 | \UC_7B28 476 478 | \UC_7C89 477 479 | \UC_7EA2 478 480 | \UC_7EDF 479 481 | \UC_7F29 480 482 | \UC_7F9F 481 483 | \UC_7FA7 482 484 | \UC_8010 483 485 | \UC_80A2 484 486 | \UC_80FA 485 487 | \UC_8131 486 488 | \UC_82EF 487 489 | \UC_84B8 488 490 | \UC_878D 489 491 | \UC_897F 490 492 | \UC_89E3 491 493 | \UC_8BD5 492 494 | \UC_8D28 493 495 | \UC_8DB3 494 496 | \UC_8DEF 495 497 | \UC_8F83 496 498 | \UC_8F90 497 499 | \UC_8FC7 498 500 | \UC_8FD8 499 501 | \UC_9002 500 502 | \UC_90BB 501 503 | \UC_9149 502 504 | \UC_914D 503 505 | \UC_9150 504 506 | \UC_915A 505 507 | \UC_916F 506 508 | \UC_9170 507 509 | \UC_9175 508 510 | \UC_9176 509 511 | \UC_9178 510 512 | \UC_9187 511 513 | \UC_918B 512 514 | \UC_919A 513 515 | \UC_919B 514 516 | \UC_91CF 515 517 | \UC_91D1 516 518 | \UC_94A0 517 519 | \UC_94A8 518 520 | \UC_94AF 519 521 | \UC_94BE 520 522 | \UC_94C1 521 523 | \UC_94DC 522 524 | \UC_94F6 523 525 | \UC_9530 524 526 | \UC_954D 525 527 | \UC_96F7 526 528 | \UC_9AD8 527 529 | \UC_9EC4 528 530 | \_ 529 531 | \bigtriangledown 530 532 | \bigtriangleup 531 533 | \boxed 532 534 | \cdot 533 535 | \cdots 534 536 | \cembelow 535 537 | 
\checkmark 536 538 | \chemfig 537 539 | \chenbelow 538 540 | \chhembelow 539 541 | \circ 540 542 | \circle 541 543 | \connbranch 542 544 | \ddot 543 545 | \diagdown 544 546 | \downarrow 545 547 | \equiv 546 548 | \eta 547 549 | \frac 548 550 | \frown 549 551 | \ghemabove 550 552 | \hembelow 551 553 | \leftarrow 552 554 | \leftrightarrow 553 555 | \leftrightarrows 554 556 | \leftrightharpoons 555 557 | \limits 556 558 | \nu 557 559 | \nwarrow 558 560 | \overline 559 561 | \overrightarrow 560 562 | \overset 561 563 | \phi 562 564 | \pm 563 565 | \prime 564 566 | \psi 565 567 | \rightarrow 566 568 | \rightharpoonup 567 569 | \rightleftarrows 568 570 | \rightleftharpoons 569 571 | \sim 570 572 | \sum 571 573 | \swarrow 572 574 | \textcelsius 573 575 | \textcircled 574 576 | \therefore 575 577 | \times 576 578 | \triangle 577 579 | \underset 578 580 | \uparrow 579 581 | \varepsilon 580 582 | \varnothing 581 583 | \vdots 582 584 | \wedge 583 585 | \xcacel 584 586 | \xcamcel 585 587 | \xmark 586 588 | ] 587 589 | ^ 588 590 | _ 589 591 | ` 590 592 | a 591 593 | b 592 594 | c 593 595 | d 594 596 | e 595 597 | f 596 598 | g 597 599 | h 598 600 | i 599 601 | j 600 602 | k 601 603 | l 602 604 | m 603 605 | n 604 606 | o 605 607 | p 606 608 | q 607 609 | r 608 610 | s 609 611 | t 610 612 | u 611 613 | v 612 614 | w 613 615 | x 614 616 | y 615 617 | z 616 618 | { 617 619 | | 618 620 | } 619 621 | ~ 620 622 | ~[:0] 621 623 | ~[:105] 622 624 | ~[:120] 623 625 | ~[:135] 624 626 | ~[:150] 625 627 | ~[:15] 626 628 | ~[:165] 627 629 | ~[:180] 628 630 | ~[:195] 629 631 | ~[:210] 630 632 | ~[:225] 631 633 | ~[:240] 632 634 | ~[:255] 633 635 | ~[:270] 634 636 | ~[:285] 635 637 | ~[:300] 636 638 | ~[:30] 637 639 | ~[:315] 638 640 | ~[:330] 639 641 | ~[:345] 640 642 | ~[:45] 641 643 | ~[:60] 642 644 | ~[:75] 643 645 | ~[:90] 644 646 | -------------------------------------------------------------------------------- /RFL/cond_render_main.py: 
-------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 | import argparse 4 | import tqdm 5 | from multiprocessing import Process, synchronize, Lock, Manager, Pool 6 | import multiprocessing 7 | from six.moves import queue 8 | import text_render 9 | import utils 10 | import shutil 11 | import cond_render as chemfig_cond_render 12 | import pickle 13 | from Tool_Formula.latex_norm.transcription import parse_transcription 14 | from chemfig_struct import * 15 | 16 | _support_format = set(["text", "lrc", "xml"]) 17 | 18 | def process_cond_render(text, preprocess=True): 19 | chemfig_text, rep_dict, rep_text = utils.replace_chemfig(text) 20 | rep_text_units = parse_transcription(rep_text, simple_trans=True) # 汉字->UC 21 | new_rep_dict = {} 22 | for key, value in rep_dict.items(): 23 | # chemfig_str = "".join(list(filter(lambda x:x!=" ", value))) 24 | chemfig_str = value.strip()[8:].strip()[1:-1] 25 | try: 26 | rep_atom = chemfig_cond_render.chemfig_random_cond_parse(chemfig_str, preprocess=preprocess, bond_dict=[]) 27 | new_rep_dict[key] = rep_atom 28 | except Exception as e: 29 | new_rep_dict[key] = value 30 | raise ValueError("Parse Error, err = {}".format(e)) 31 | out_units = [] 32 | for text_unit in rep_text_units: 33 | if text_unit in new_rep_dict: 34 | out_units.append(new_rep_dict[text_unit]) 35 | else: 36 | out_units.append(text_unit) 37 | return out_units 38 | 39 | def do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 40 | line = params["line"] 41 | line_id = params["line_id"] 42 | if args.input_type == "text": 43 | out_lines = shared_params["out_lines"] 44 | out_lines_lock = shared_params["out_lines_lock"] 45 | 46 | spts= line.strip().split("\t") 47 | if len(spts) != 2: 48 | return 49 | img_key, in_text = spts 50 | out_text = text_render.text_render(in_text, debug=False) 51 | with out_lines_lock: 52 | out_lines.put("{}\t{}\n".format(img_key, out_text)) 53 | elif 
args.input_type == "xml": 54 | pass 55 | elif args.input_type == "lrc": 56 | if "lrc_parser" in params: 57 | lrc_parser = params["lrc_parser"] 58 | else: 59 | lrc_parser = _lrc_parser 60 | idx = line 61 | record = lrc_parser.get_record(idx) 62 | 63 | # ============= for cond render =============== 64 | area_queue = [(x, []) for x in record["sub_areas"]] 65 | success = True 66 | while len(area_queue): 67 | cur_area, prefix_arr = area_queue.pop(0) 68 | cur_prefix_arr = prefix_arr + [cur_area["idx"]] 69 | cur_idx = "-".join(cur_prefix_arr) 70 | if "sub_areas" in cur_area: 71 | area_queue += [(x, cur_prefix_arr) for x in cur_area["sub_areas"]] 72 | if "text" in cur_area: 73 | if args.use_raw_info > 0: 74 | if "raw_info" in cur_area and "text" in cur_area["raw_info"]: 75 | input_text = cur_area["raw_info"]["text"] 76 | else: 77 | raise ValueError("can not find raw info") 78 | else: 79 | input_text = cur_area["text"] 80 | out_text_units = process_cond_render(input_text) ###### 81 | unit_bytes = pickle.dumps(out_text_units) 82 | if "raw_info" not in cur_area: 83 | warnings.warn("can not find raw_info") 84 | cur_area["raw_info"]["struct_text"] = unit_bytes 85 | #for test dump 86 | if args.rend_check > 0: 87 | # out_text = cur_area["text"] 88 | try_cnt = 0 89 | while True: 90 | try: 91 | out_parsed_units = chemfig_cond_render.process_text_rnd_cond_render(out_text_units) 92 | out_text = " ".join([x[0] for x in out_parsed_units]) 93 | break 94 | except BaseException as e: 95 | if try_cnt < 5: 96 | try_cnt += 1 97 | continue 98 | else: 99 | raise ValueError(e) 100 | 101 | 102 | 103 | #cur_area["text"] = out_text 104 | 105 | # except BaseException as e: 106 | # success = False 107 | # print("can not write, check record={} err={}".format(idx, e)) 108 | # ============= for cond render =============== 109 | #pdb.set_trace() 110 | # ==============for text render =============== 111 | if success: 112 | top_id = 0 113 | areas_arr = [x for x in record["sub_areas"]] 114 | while top_id 
< len(areas_arr): 115 | cur_area = areas_arr[top_id] 116 | top_id += 1 117 | if "text" in cur_area: 118 | if args.use_raw_info > 0: 119 | if "raw_info" in cur_area and "text" in cur_area["raw_info"]: 120 | in_text = cur_area["raw_info"]["text"] 121 | else: 122 | raise ValueError("can not find raw info") 123 | else: 124 | in_text = cur_area["text"] 125 | cur_area["text"] = text_render.text_render(in_text, debug=False, branch_de_amb=1) 126 | if "sub_areas" in cur_area: 127 | areas_arr += cur_area["sub_areas"] 128 | # ==============for text render =============== 129 | 130 | # shutil.copy(record["image_path"], "./debug/origin.jpg") 131 | if success: 132 | out_records = shared_params["out_records"] 133 | out_records_lock = shared_params["out_records_lock"] 134 | with out_records_lock: 135 | out_records.put(record) 136 | # pdb.set_trace() 137 | pass 138 | 139 | pass 140 | 141 | 142 | def try_do_single_task(params, args, records_queue, records_queue_lock, shared_params={}): 143 | line = params["line"] 144 | line_id = params["line_id"] 145 | try: 146 | do_single_task(params, args, records_queue, records_queue_lock, shared_params) 147 | except BaseException as e: 148 | line_content = line if type(line) is str else line 149 | print("try fail! 
line id = {} line = {} err= {}".format(line_id, line_content, e)) 150 | if records_queue_lock is not None: 151 | with records_queue_lock: 152 | records_queue.put(1) 153 | 154 | 155 | def main(args): 156 | if args.input_type == "text": 157 | with open(args.input, "r") as fin: 158 | lines = fin.readlines() 159 | elif args.input_type == "xml": 160 | lines = utils.scan_dir(args.input, "xml") 161 | else: 162 | raise NotImplementedError("unsupport input type = {}".format(args.input_type)) 163 | 164 | 165 | # init metrics 166 | manager = Manager() 167 | records_queue = manager.Queue() 168 | records_queue_lock = manager.Lock() 169 | 170 | # TODO add share params here 171 | shared_params = {} 172 | if args.input_type == "text": 173 | shared_params["out_lines"] = manager.Queue() 174 | shared_params["out_lines_lock"] = manager.Lock() 175 | elif args.input_type == "xml": 176 | common_prefix = os.path.commonpath(lines) 177 | # shared_params["err_lines"] = manager.Queue() 178 | # shared_params["err_lines_lock"] = manager.Lock() 179 | 180 | all_tasks = [] 181 | line_id = -1 182 | 183 | if args.num_workers <= 0: 184 | for line in tqdm.tqdm(lines): 185 | line_id += 1 186 | params = {} 187 | params["line"] = line 188 | params["line_id"] = line_id 189 | if args.input_type == "xml": 190 | params["common_prefix"] = common_prefix 191 | if args.input_type == "lrc": 192 | params["lrc_parser"] = lrc_parser 193 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 194 | do_single_task(*cur_task) 195 | if line_id > 20: 196 | break 197 | #pdb.set_trace() 198 | else: 199 | for line in tqdm.tqdm(lines): 200 | line_id += 1 201 | params = {} 202 | params["line"] = line 203 | params["line_id"] = line_id 204 | if args.input_type == "xml": 205 | params["common_prefix"] = common_prefix 206 | cur_task = (params, args, records_queue, records_queue_lock, shared_params) 207 | all_tasks.append(cur_task) 208 | pass 209 | 210 | def print_error(error): 211 | print("error:", error) 212 
| 213 | def init(a): 214 | global _lrc_parser 215 | _lrc_parser = a 216 | 217 | poolSize = args.num_workers 218 | if args.input_type == "lrc": 219 | pool = Pool(poolSize, initializer=init, initargs=(lrc_parser, )) 220 | else: 221 | pool = Pool(poolSize) 222 | pool.starmap_async(try_do_single_task, all_tasks, error_callback=print_error) 223 | pool.close() 224 | tq = tqdm.tqdm(total=len(all_tasks)) 225 | count = 0 226 | print("begin") 227 | #try: 228 | while count < len(all_tasks): 229 | try: 230 | c = records_queue.get_nowait() 231 | except queue.Empty: 232 | continue 233 | if args.input_type == "lrc": 234 | try: 235 | with shared_params["out_records_lock"]: 236 | cur_record = shared_params["out_records"].get_nowait() 237 | lrc_writer.add_record(cur_record) 238 | except queue.Empty: 239 | pass 240 | count += 1 241 | tq.update(1) 242 | # except: 243 | # pass 244 | pool.join() 245 | 246 | #TODO add post process 247 | if args.input_type == "text": 248 | out_lines = shared_params["out_lines"] 249 | out_lines_lock = shared_params["out_lines_lock"] 250 | with open(args.output, "w") as fout: 251 | while not out_lines.empty(): 252 | line = out_lines.get() 253 | fout.write(line) 254 | elif args.input_type == "xml": 255 | pass 256 | # treat args.output as output_dir 257 | elif args.input_type == "lrc": 258 | while not shared_params["out_records"].empty(): 259 | record = shared_params["out_records"].get() 260 | lrc_writer.add_record(record) 261 | #pdb.set_trace() 262 | lrc_writer.close() 263 | 264 | 265 | -------------------------------------------------------------------------------- /RFL/viz_struct.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import argparse 3 | import pdb 4 | from PIL import Image, ImageDraw, ImageFont 5 | import cv2 6 | import text_render 7 | import image_render 8 | import numpy as np 9 | 10 | font_path = '' # your font path 11 | def adjuct_font_and_width(in_text, fixed_height, ref_width, 
min_font_size=15, line_space=2): 12 | min_font_size = 15 13 | line_space = 2 14 | cur_text_width = ref_width 15 | cur_lab_font_size = 25 16 | font = ImageFont.truetype(font_path, cur_lab_font_size, encoding="utf-8") 17 | while True: 18 | success = True 19 | cur_x = 0 20 | cur_y = 0 21 | cur_h = 0 22 | for ch in in_text: 23 | w, h = font.getsize(ch) 24 | if cur_x + w >= cur_text_width: 25 | cur_x = 0 26 | cur_y += (cur_h + line_space) 27 | cur_h = 0 28 | cur_x += w 29 | cur_h = max(cur_h, h) 30 | if cur_y + cur_h >= fixed_height: 31 | success = False 32 | break 33 | if success is True: 34 | break 35 | else: 36 | if cur_lab_font_size > min_font_size: 37 | cur_lab_font_size -= 1 38 | font = ImageFont.truetype(font_path, cur_lab_font_size, encoding="utf-8") 39 | else: 40 | cur_text_width += 100 41 | return cur_lab_font_size, cur_text_width 42 | 43 | def drawText(text, height, width, font, line_space=2): 44 | max_w = 0 45 | max_h = 0 46 | cur_x = 0 47 | cur_y = 0 48 | cur_h = 0 49 | img_np = np.ones((height, width, 3), dtype=np.uint8)*255 50 | #img = Image.new('RGB', (width, height), "#FFFFFF") 51 | img = Image.fromarray(img_np) 52 | imgDraw = ImageDraw.Draw(img) 53 | for ch in text: 54 | w, h = font.getsize(ch) 55 | if cur_x + w >= width: 56 | cur_x = 0 57 | cur_y += (cur_h + line_space) 58 | cur_h = 0 59 | imgDraw.text((cur_x, cur_y), ch, fill=(0,0,0), font=font) 60 | cur_x += w 61 | cur_h = max(cur_h, h) 62 | max_w = max([max_w, cur_x]) 63 | max_h = max([max_h, cur_y + cur_h]) 64 | # pdb.set_trace() 65 | img_np = np.array(img) 66 | img_out = np.ones((height, width, 3), dtype=np.uint8)*255 67 | #img_np[:max_h, :max_w, :] 68 | start_x = (width - max_w) // 2 69 | start_y = (height - max_h) // 2 70 | img_out[start_y:start_y+max_h, start_x:start_x+max_w, :] = img_np[:max_h, :max_w, :] 71 | return img_out 72 | 73 | def adjuct_font_and_height(in_text, fixed_width, ref_height, min_font_size=15, line_space=2): 74 | min_font_size = 15 75 | line_space = 2 76 | cur_text_height = 
ref_height 77 | cur_lab_font_size = 25 78 | font = ImageFont.truetype(font_path, cur_lab_font_size, encoding="utf-8") 79 | while True: 80 | success = True 81 | cur_x = 0 82 | cur_y = 0 83 | cur_h = 0 84 | max_height = 0 85 | for ch in in_text: 86 | w, h = font.getsize(ch) 87 | if cur_x + w >= fixed_width: 88 | cur_x = 0 89 | cur_y += (cur_h + line_space) 90 | cur_h = 0 91 | cur_x += w 92 | cur_h = max(cur_h, h) 93 | if cur_y + cur_h >= cur_text_height: 94 | success = False 95 | break 96 | else: 97 | max_height = max([max_height, cur_y + cur_h]) 98 | if success is True: 99 | break 100 | else: 101 | if cur_lab_font_size > min_font_size: 102 | cur_lab_font_size -= 1 103 | font = ImageFont.truetype(font_path, cur_lab_font_size, encoding="utf-8") 104 | else: 105 | cur_text_height += (cur_lab_font_size + line_space) 106 | # pdb.set_trace() 107 | return cur_lab_font_size, max_height 108 | 109 | 110 | 111 | def viz_struct_res(img_path, rec_rep_dict, trans_text=None): 112 | # print(img_path) 113 | src_img = cv2.imread(img_path) 114 | src_h, src_w, src_c = src_img.shape 115 | 116 | img_pair_dict = {} 117 | labs_height = 0 118 | recs_height = 0 119 | tmp_width_sum = 0 120 | tmp_width_cnt = 0 121 | temp_dir = "/ps3/cv9/haowu16/Texlive/tex_temp/" 122 | image_name = os.path.splitext(os.path.basename(img_path))[0] 123 | for key in rec_rep_dict: 124 | if "res_graph" in rec_rep_dict[key]: 125 | img_pair_dict[key] = {} 126 | lab_atom = rec_rep_dict[key]["lab_atom"] 127 | rec_atom = rec_rep_dict[key]["rec_atom"] 128 | 129 | lab_atom_img = image_render.rend(lab_atom, scale=100) 130 | rec_atom_img = image_render.rend(rec_atom, scale=100) 131 | # lab_reverse_trans = test_reverse_trans.process_trans("".join(chemfig_render.reverse_rend_text(lab_atom, scale=1, remove_dup=0, norm_circle=0, connect_distant=0))) 132 | # lab_atom_img = test_reverse_trans.texlive_rend(lab_reverse_trans, os.path.join(temp_dir, "{}_lab.pdf".format(image_name))) 133 | # rec_reverse_trans = 
test_reverse_trans.process_trans("".join(chemfig_render.reverse_rend_text(rec_atom, scale=1, remove_dup=0, norm_circle=0, connect_distant=0))) 134 | # rec_atom_img = test_reverse_trans.texlive_rend(rec_reverse_trans, os.path.join(temp_dir, "{}_rec.pdf".format(image_name))) 135 | # pdb.set_trace() 136 | img_pair_dict[key]["lab_img"] = lab_atom_img 137 | img_pair_dict[key]["rec_img"] = rec_atom_img 138 | cur_width = max([lab_atom_img.shape[1], rec_atom_img.shape[1]]) 139 | img_pair_dict[key]["cur_width"] = cur_width 140 | labs_height = max([labs_height, lab_atom_img.shape[0]]) 141 | recs_height = max([recs_height, rec_atom_img.shape[0]]) 142 | tmp_width_sum += cur_width 143 | tmp_width_cnt += 1 144 | if tmp_width_cnt > 0: 145 | ref_width = int(float(tmp_width_sum)/tmp_width_cnt) 146 | else: 147 | ref_width = int(float(src_w)/len(rec_rep_dict)) 148 | if labs_height == 0: 149 | labs_height = 50 150 | if recs_height == 0: 151 | recs_height = 50 152 | 153 | for key in rec_rep_dict: 154 | if "res_text" in rec_rep_dict[key]: 155 | img_pair_dict[key] = {} 156 | lab_text = rec_rep_dict[key]["lab_text"] 157 | rec_text = rec_rep_dict[key]["rec_text"] 158 | 159 | min_font_size = 15 160 | line_space = 2 161 | 162 | lab_font_size, lab_text_width = adjuct_font_and_width(lab_text, labs_height, ref_width) 163 | rec_font_size, rec_text_width = adjuct_font_and_width(rec_text, recs_height, ref_width, min_font_size=lab_font_size) 164 | common_font_size = min([lab_font_size, rec_font_size]) 165 | common_text_width = max([lab_text_width, rec_text_width]) 166 | 167 | font = ImageFont.truetype(font_path, common_font_size, encoding="utf-8") 168 | lab_text_img = drawText(lab_text, labs_height, common_text_width, font) 169 | rec_text_img = drawText(rec_text, recs_height, common_text_width, font) 170 | img_pair_dict[key]["lab_img"] = lab_text_img 171 | img_pair_dict[key]["rec_img"] = rec_text_img 172 | img_pair_dict[key]["cur_width"] = common_text_width 173 | 174 | 175 | 176 | img_line_space = 
10 177 | img_col_space = 10 178 | result_bbox_width = 3 179 | out_width = 0 180 | for key in img_pair_dict: 181 | out_width += img_pair_dict[key]["cur_width"] 182 | out_width += 2 * result_bbox_width 183 | out_width += img_col_space 184 | out_width = max([out_width, src_w]) 185 | # put src lab 186 | if trans_text is not None and trans_text != "": 187 | font_size, area_height = adjuct_font_and_height(trans_text, out_width, 30) 188 | font = ImageFont.truetype(font_path, font_size, encoding="utf-8") 189 | 190 | trans_img = drawText(trans_text, area_height, out_width, font) 191 | # pdb.set_trace() 192 | new_src_img = 255*np.ones((src_h + area_height, out_width, 3), dtype=np.uint8) 193 | new_src_img[:src_h, :src_w, :] = src_img 194 | new_src_img[src_h:, :, :] = trans_img 195 | #src_img = np.vstack([src_img, trans_img]) 196 | src_img = new_src_img 197 | src_h, src_w, src_c = src_img.shape 198 | 199 | out_height = labs_height + recs_height + src_h + img_line_space*2 + result_bbox_width * 4 200 | outImg = 255*np.ones((out_height, out_width, 3), dtype=np.uint8) 201 | # pdb.set_trace() 202 | #put src img 203 | src_pos_y = labs_height +img_line_space+ result_bbox_width * 2 204 | outImg[src_pos_y:src_pos_y+src_h, :src_w, :] = src_img 205 | #put recs 206 | cur_pos_x = 0 207 | lab_pos_y = 0 208 | rec_pos_y = src_pos_y + src_h + img_line_space 209 | 210 | for key in img_pair_dict: 211 | lab_img = img_pair_dict[key]["lab_img"] 212 | lab_img_h, lab_img_w, _ = lab_img.shape 213 | container_w = img_pair_dict[key]["cur_width"] 214 | container_h = labs_height 215 | if rec_rep_dict[key]["res_cmp"] == 0: #correct 216 | outImg[lab_pos_y:lab_pos_y+container_h+2*result_bbox_width, cur_pos_x:cur_pos_x+container_w+2*result_bbox_width, :] = 0 217 | else: 218 | outImg[lab_pos_y:lab_pos_y+container_h+2*result_bbox_width, cur_pos_x:cur_pos_x+container_w+2*result_bbox_width, :2] = 0 219 | lab_img_expand = 255*np.ones((container_h, container_w, 3), dtype=np.uint8) 220 | t_x = (container_w - 
lab_img_w) // 2 221 | t_y = (container_h - lab_img_h) // 2 222 | lab_img_expand[t_y:t_y+lab_img_h, t_x:t_x+lab_img_w, :] = lab_img 223 | outImg[lab_pos_y+result_bbox_width:lab_pos_y+result_bbox_width+container_h, cur_pos_x+result_bbox_width:cur_pos_x+result_bbox_width+container_w, :] = lab_img_expand 224 | 225 | rec_img = img_pair_dict[key]["rec_img"] 226 | rec_img_h, rec_img_w, _ = rec_img.shape 227 | container_w = img_pair_dict[key]["cur_width"] 228 | container_h = recs_height 229 | if rec_rep_dict[key]["res_cmp"] == 0: #correct 230 | outImg[rec_pos_y:rec_pos_y+container_h+2*result_bbox_width, cur_pos_x:cur_pos_x+container_w+2*result_bbox_width, :] = 0 231 | else: 232 | outImg[rec_pos_y:rec_pos_y+container_h+2*result_bbox_width, cur_pos_x:cur_pos_x+container_w+2*result_bbox_width, :2] = 0 233 | rec_img_expand = 255*np.ones((container_h, container_w, 3), dtype=np.uint8) 234 | t_x = (container_w - rec_img_w) // 2 235 | t_y = (container_h - rec_img_h) // 2 236 | rec_img_expand[t_y:t_y+rec_img_h, t_x:t_x+rec_img_w, :] = rec_img 237 | outImg[rec_pos_y+result_bbox_width:rec_pos_y+result_bbox_width+container_h, cur_pos_x+result_bbox_width:cur_pos_x+result_bbox_width+container_w, :] = rec_img_expand 238 | 239 | cur_pos_x += (container_w + img_col_space + result_bbox_width * 2) 240 | return outImg 241 | 242 | def main(args): 243 | pass 244 | 245 | if __name__ == "__main__": 246 | parser = argparse.ArgumentParser("") 247 | parser.add_argument("-input", type=str, default="") 248 | args = parser.parse_args() 249 | main(args) 250 | -------------------------------------------------------------------------------- /rain/my_encode.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Model Architecture Definition for Encoder 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.nn.init as init 8 | from torch.nn import Parameter 9 | import os 10 | import sys 11 | 12 | 
from . import xconfig 13 | import logging 14 | import math 15 | logger = logging.getLogger() 16 | 17 | class ConvBnRelu(nn.Module): 18 | """(convolution => [BN] => ReLU)""" 19 | 20 | def __init__(self, in_channels, out_channels, kernel_size = (3, 3), stride = (1, 1), dilation = (1, 1)): 21 | super().__init__() 22 | 23 | my_padding = ( int((kernel_size[0] - 1)/2) + dilation[0] - 1, int((kernel_size[1] - 1)/2) + dilation[1] - 1 ) 24 | 25 | self.conv = nn.Sequential( 26 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=my_padding, dilation = dilation, bias=False), 27 | nn.BatchNorm2d(out_channels), 28 | nn.ReLU(inplace=True), 29 | ) 30 | 31 | def forward(self, x): 32 | return self.conv(x) 33 | 34 | 35 | 36 | class SELayerMask(nn.Module): 37 | def __init__(self, channel, reduction=16): 38 | super(SELayerMask, self).__init__() 39 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 40 | self.fc = nn.Sequential( 41 | nn.Linear(channel, channel // reduction, bias=False), 42 | nn.ReLU(inplace=True), 43 | nn.Linear(channel // reduction, channel, bias=False), 44 | nn.Sigmoid() 45 | ) 46 | 47 | def forward(self, x, mask = None): 48 | y = (x * mask).sum([2, 3]) / (mask.sum([2, 3]) + 1e-6) 49 | y = self.fc(y).unsqueeze(-1).unsqueeze(-1) 50 | return x * y.expand_as(x) 51 | 52 | class ResBasicBlockSE(nn.Module): 53 | 54 | def __init__(self, in_channels, out_channels, basic_groups = 4, kernel_size = (3, 3), stride = (1, 1), dilation = (1, 1)): 55 | super().__init__() 56 | 57 | my_padding = ( int((kernel_size[0] - 1)/2) + dilation[0] - 1, int((kernel_size[1] - 1)/2) + dilation[1] - 1 ) 58 | self.out_channels = out_channels 59 | 60 | self.conv1 = nn.Sequential( 61 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=my_padding, dilation = dilation, stride = stride, bias=False), 62 | nn.BatchNorm2d(out_channels), 63 | nn.ReLU(inplace=True), 64 | ) 65 | 66 | self.conv2 = nn.Sequential( 67 | nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, 
padding=my_padding, dilation = dilation, stride = 1, bias=False), 68 | nn.BatchNorm2d(out_channels), 69 | ) 70 | 71 | self.relu = nn.ReLU(inplace=True) 72 | 73 | if stride[0] > 1 or stride[1] > 1: 74 | self.res_conv = True 75 | else: 76 | self.res_conv = False 77 | 78 | if self.res_conv: 79 | self.res_conv = nn.Sequential( 80 | nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1), padding=(0, 0), dilation = (1, 1), stride = stride, bias=False), 81 | nn.BatchNorm2d(out_channels), 82 | ) 83 | 84 | self.se_layer = SELayerMask(out_channels) 85 | 86 | def forward(self, x, mask = None): 87 | 88 | x_feat = self.conv1(x) * mask 89 | x_feat = self.conv2(x_feat) * mask 90 | 91 | x_feat = self.se_layer(x_feat, mask) 92 | 93 | if self.res_conv: 94 | return self.relu(x_feat + self.res_conv(x)) 95 | else: 96 | return self.relu(x_feat + x) 97 | 98 | 99 | 100 | class PosSelfAtten(nn.Module): 101 | def __init__(self, psa_dim_in, psa_dim_att, psa_dim_out, name = 'encoder_pos_selfAtten'): 102 | super(PosSelfAtten, self).__init__() 103 | self.name = name 104 | self._psa_dim_out = psa_dim_out 105 | self._psa_dim_att = psa_dim_att 106 | self._psa_dim_in = psa_dim_in 107 | 108 | self.reduc_conv = ConvBnRelu(psa_dim_in, psa_dim_out, kernel_size = (1, 1)) 109 | self.atten_conv = ConvBnRelu(psa_dim_in, psa_dim_att, kernel_size = (1, 1)) 110 | 111 | def forward(self, x, mask = None): 112 | 113 | x_feat = self.reduc_conv(x) 114 | 115 | 116 | x_atten = self.atten_conv(x) # N C H W 117 | n, c, h, w = x_atten.shape 118 | x_atten_ = x_atten.view(n, c, -1).contiguous() 119 | x_atten_trans = x_atten_.transpose(1, 2).contiguous() 120 | 121 | energy = torch.matmul(x_atten_trans, x_atten_) / (1e-8 + math.sqrt(c)) 122 | 123 | if mask is not None: 124 | energy = energy + (mask.view(n, 1, -1).contiguous() - 1) * 1e8 125 | 126 | weight = energy.softmax(2) 127 | x_feat_trans = x_feat.view(n, self._psa_dim_out, -1).contiguous().transpose(1, 2).contiguous() 128 | out_feat = torch.matmul(weight, 
x_feat_trans) 129 | out_feat = out_feat.transpose(1, 2).contiguous().view(n, self._psa_dim_out, h, w).contiguous() 130 | 131 | if mask is not None: 132 | out_feat = out_feat * mask 133 | return out_feat 134 | 135 | class ChanSelfAtten(nn.Module): 136 | def __init__(self, csa_dim_in, csa_dim_att, csa_dim_out, name = 'encoder_chan_selfAtten'): 137 | super(ChanSelfAtten, self).__init__() 138 | self.name = name 139 | self._csa_dim_out = csa_dim_out 140 | self._csa_dim_att = csa_dim_att 141 | self._csa_dim_in = csa_dim_in 142 | 143 | self.reduc_conv = ConvBnRelu(csa_dim_in, csa_dim_out, kernel_size = (1, 1)) 144 | self.atten_conv = ConvBnRelu(csa_dim_in, csa_dim_out, kernel_size = (1, 1)) 145 | 146 | 147 | def forward(self, x, mask = None): 148 | 149 | x_feat = self.reduc_conv(x) 150 | 151 | 152 | x_atten = self.atten_conv(x) # N C H W 153 | 154 | n, c, h, w = x_atten.shape 155 | x_atten_ = x_atten.view(n, c, -1).contiguous() 156 | x_atten_trans = x_atten_.transpose(1, 2).contiguous() 157 | 158 | energy = torch.matmul(x_atten_, x_atten_trans) / (1e-8 + math.sqrt(h * w)) # N C C 159 | weight = energy.softmax(2) 160 | x_feat_trans = x_feat.view(n, self._csa_dim_out, -1).contiguous() 161 | 162 | out_feat = torch.matmul(weight, x_feat_trans) 163 | out_feat = out_feat.view(n, self._csa_dim_out, h, w).contiguous() 164 | 165 | 166 | return out_feat 167 | 168 | class EncodeSelfAtten(nn.Module): 169 | def __init__(self, dim_in, dim_att, dim_out, name = 'encoder_selfAtten'): 170 | super(EncodeSelfAtten, self).__init__() 171 | self.name = name 172 | self._dim_out = dim_in 173 | self._dim_att = dim_att 174 | self._dim_in = dim_out 175 | 176 | self.pos_selfAtten = PosSelfAtten(dim_in, dim_att, dim_out) 177 | self.chan_selfAtten = ChanSelfAtten(dim_in, dim_att, dim_out) 178 | 179 | self.residual_conv = ConvBnRelu(dim_in, dim_out, kernel_size = (3, 3)) 180 | def forward(self, x, mask = None): 181 | 182 | out_feat = self.pos_selfAtten(x, mask) + self.chan_selfAtten(x, mask) + 
self.residual_conv(x) 183 | 184 | if mask is not None: 185 | out_feat = out_feat * mask 186 | 187 | return out_feat 188 | 189 | class Backbone(nn.Module): 190 | def __init__(self, in_channels, num_level=4, num_block=[4, 4, 4, 4], num_filters_arr=[32, 64, 64, 128], num_stride_arr = [(2, 2), (2, 2), (2, 2), (2, 2)], num_groups = [4, 4, 4, 4], encoder_use_res = [1, 1, 1, 1], 191 | residual=True, dropout=0, name='cnet'): 192 | super(Backbone, self).__init__() 193 | self.num_level = num_level 194 | self.num_block = num_block 195 | self.name = name 196 | num_channels = 1 197 | stem_channels = num_filters_arr[0] 198 | self.stem_conv = ConvBnRelu(in_channels, stem_channels) 199 | num_filters_arr = num_filters_arr[1:] 200 | for level in range(num_level): 201 | dr = 0 202 | if level == 4: 203 | dr = dropout 204 | residual_conv = False 205 | if num_filters_arr[level] != num_channels: 206 | residual_conv = True 207 | 208 | if level < 2 or residual is False: 209 | if level == 0: 210 | self.make_conv_block(level, num_block[level], stem_channels, num_filters_arr[level], num_stride_arr[level], dr, encoder_use_res[level], residual_conv, num_groups[level]) 211 | else: 212 | self.make_conv_block(level, num_block[level], num_filters_arr[level-1], num_filters_arr[level], num_stride_arr[level], dr, encoder_use_res[level], residual_conv, num_groups[level]) 213 | num_channels = num_filters_arr[level] 214 | 215 | 216 | def make_conv_block(self, level, num_block, in_channels, out_channels, stride, dropout, residual, residual_conv, groups = 2): 217 | for block in range(num_block): 218 | 219 | if block == 0: 220 | self.add_module('{}_maskpool{}'.format(self.name, level), nn.MaxPool2d(kernel_size=stride, stride=stride, padding=0, dilation=1, ceil_mode=True)) 221 | if residual > 0: 222 | if block == 0: 223 | self.add_module('{}_conv_l{}_b{}'.format(self.name, level, block), ResBasicBlockSE(in_channels, out_channels, basic_groups = groups, stride=stride)) 224 | else: 225 | 
self.add_module('{}_conv_l{}_b{}'.format(self.name, level, block), ResBasicBlockSE(out_channels, out_channels, basic_groups = groups)) 226 | 227 | 228 | 229 | def forward(self, x, source_mask): 230 | x = self.stem_conv(x) * source_mask 231 | for level in range(self.num_level): 232 | for block in range(self.num_block[level]): 233 | if block == 0: 234 | source_mask = self._modules['{}_maskpool{}'.format(self.name, level)](source_mask) 235 | 236 | x = self._modules['{}_conv_l{}_b{}'.format(self.name, level, block)](x, source_mask) 237 | return x, source_mask 238 | 239 | 240 | 241 | class Encoder(nn.Module): 242 | def __init__(self): 243 | super(Encoder, self).__init__() 244 | self.backbone = Backbone(in_channels = xconfig.source_dim, 245 | num_level = len(xconfig.encoder_units), 246 | num_block = xconfig.encoder_units, 247 | num_filters_arr = xconfig.encoder_filter_list, 248 | num_stride_arr = xconfig.encoder_stride_list, 249 | num_groups = xconfig.encoder_basic_group, 250 | encoder_use_res = xconfig.encoder_use_res, 251 | dropout = xconfig.encode_dropout, 252 | residual = False, 253 | name = 'encoder') 254 | 255 | self.feat_drop = nn.Dropout(xconfig.encode_feat_dropout) 256 | self.selfAtten = EncodeSelfAtten(xconfig.encoder_filter_list[-1], xconfig.encoder_position_att, xconfig.encoder_position_dim) 257 | 258 | def forward(self, source, source_mask): 259 | encode_out, encode_mask = self.backbone(source, source_mask) 260 | encode_out = self.feat_drop(encode_out) 261 | encode_out = self.selfAtten(encode_out, encode_mask) 262 | 263 | return encode_out, encode_mask -------------------------------------------------------------------------------- /RFL/utils.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import cv2 3 | import math 4 | import numpy 5 | import pdb 6 | from collections import OrderedDict 7 | import re 8 | import Levenshtein 9 | 10 | 11 | pair_dict = {"}":"{", "]":"["} 12 | def replace_chemfig(text): 
13 | replace_dict = OrderedDict() 14 | ind = 0 15 | new_text = "" 16 | while True: 17 | pos = text.find("\\chemfig") 18 | if pos == -1: 19 | break 20 | cur_pos = pos + 8 21 | cur_left_pair = None 22 | cur_left_pos = None 23 | curLevel = 0 24 | range_cnt = {"[":0, "{":0} 25 | while cur_pos < len(text): 26 | ch = text[cur_pos] 27 | if ch == "[" or ch == "{": 28 | if cur_left_pair is None: 29 | cur_left_pair = ch 30 | cur_left_pos = cur_pos 31 | curLevel = 1 32 | elif cur_left_pair == ch: 33 | curLevel += 1 34 | elif ch == "}" or ch == "]": 35 | if cur_left_pair == pair_dict[ch]: 36 | curLevel -= 1 37 | if curLevel == 0: 38 | range_cnt[cur_left_pair] += 1 39 | if range_cnt["["] > 1: 40 | raise ValueError("multiple attr range") 41 | if range_cnt["{"] >= 1: 42 | # pdb.set_trace() 43 | break 44 | cur_left_pair = None 45 | cur_left_pos = None 46 | 47 | elif cur_left_pair is None: 48 | if ch != " ": 49 | raise ValueError("format err, input = {}".format(text)) 50 | else: 51 | pass 52 | cur_pos += 1 53 | beginPos = cur_left_pos 54 | endPos = cur_pos + 1 55 | rep_key = "\\chem{}".format(chr(ord('a') + ind)) 56 | ind += 1 57 | replace_dict[rep_key] = "\\chemfig "+text[beginPos:endPos] 58 | text = text[0:pos] + " " + rep_key + " " + text[endPos:] 59 | new_text += replace_dict[rep_key] + " " 60 | 61 | pos = cur_pos + 1 62 | return new_text, replace_dict, text 63 | 64 | 65 | def get_atom_group(item_list): 66 | '''检测原子团,并将item_list中分开的原子团替换为合并之后的原子团''' 67 | lengths = [len(item) for item in item_list if isinstance(item, str)] 68 | # print(lengths) 69 | # print(item_list) 70 | consecutive_ones = [] 71 | start_index = None 72 | end_index = None 73 | for i, num in enumerate(lengths): 74 | if i == 0 or i == len(item_list)-1: 75 | continue # 跳过第一个和最后一个 { } 76 | if num == 1: 77 | if start_index is None: 78 | start_index = i 79 | else: 80 | if start_index is not None: 81 | end_index = i - 1 82 | if end_index - start_index + 1 >= 2: 83 | consecutive_ones.append((start_index, end_index)) 84 
| start_index = None 85 | 86 | if start_index is not None: 87 | end_index = len(lengths) - 2 88 | if end_index - start_index + 1 >= 2: 89 | consecutive_ones.append((start_index, end_index)) 90 | 91 | for item in reversed(consecutive_ones): 92 | # replace 93 | # print(item[0], item[1]) 94 | atom_group = "".join(item_list[item[0]: item[1]+1]) 95 | print(atom_group) 96 | del item_list[item[0]: item[1] + 1] 97 | # item_list[item[0]: item[1]] = atom_group 98 | item_list.insert(item[0], atom_group) 99 | 100 | # print(item_list) 101 | return item_list 102 | # print(lengths[consecutive_ones]) 103 | 104 | def scan_dir(in_dir, ext, rescan=False): 105 | cache_path = os.path.join(in_dir, "{}.cache.txt".format(ext)) 106 | if not os.path.exists(cache_path) or rescan is True: 107 | cmd = "find {} -name \*.{}".format(in_dir, ext) 108 | in_lines = os.popen(cmd).readlines() 109 | with open(cache_path, "w") as fout: 110 | fout.writelines(in_lines) 111 | else: 112 | with open(cache_path, "r") as fin: 113 | in_lines = fin.readlines() 114 | return in_lines 115 | 116 | def cal_edit_ops(str1, str2): 117 | char_idx_dict = dict() 118 | for item in str1: 119 | if item not in char_idx_dict: 120 | char_idx_dict[item] = chr(len(char_idx_dict)) 121 | for item in str2: 122 | if item not in char_idx_dict: 123 | char_idx_dict[item] = chr(len(char_idx_dict)) 124 | str1 = ''.join([char_idx_dict[item] for item in str1]) 125 | str2 = ''.join([char_idx_dict[item] for item in str2]) 126 | ops = Levenshtein.editops(str1, str2) 127 | return ops 128 | 129 | def norm_text(text_arr, do_norm_sub=True): 130 | tmp_text_arr = [] 131 | for text in text_arr: 132 | if text.startswith("[:") and text.endswith("]"): 133 | continue 134 | if text.find("[:") != -1 and text.endswith("]"): 135 | pos = text.find("[:") 136 | text = text[:pos] 137 | if text == "\\:" or text == ":": 138 | text = ":" 139 | tmp_text_arr.append(text) 140 | 141 | out_text_arr = tmp_text_arr 142 | return out_text_arr 143 | 144 | def 
removeAngle(inArr): 145 | outArr = [] 146 | for unit in inArr: 147 | # pos = unit.find("[") 148 | # if pos != -1 and unit.endswith("]"): 149 | # pdb.set_trace() 150 | # outArr.append(unit[:pos]) 151 | # else: 152 | # outArr.append(unit) 153 | tgt_unit = re.sub(r"\[:[0-9]*\]","", unit) #fixed at 2022.03.04 for bug in ?[a,{=}] 154 | outArr.append(unit) 155 | return outArr 156 | 157 | #====================== text process =======================# 158 | 159 | def norm_sub(in_text_arr): 160 | out_text_arr = [] 161 | # for i, text in enumerate(tmp_text_arr): 162 | ind = 0 163 | left_start = False 164 | while ind < len(in_text_arr): 165 | if in_text_arr[ind] == "_" and ind + 1 < len(in_text_arr) and in_text_arr[ind + 1] == "{": 166 | left_start = True 167 | ind += 2 168 | continue 169 | if left_start and in_text_arr[ind] == "}": 170 | left_start = False 171 | ind += 1 172 | continue 173 | out_text_arr.append(in_text_arr[ind]) 174 | ind += 1 175 | return out_text_arr 176 | 177 | def norm_text(text_arr, do_norm_sub=True): 178 | tmp_text_arr = [] 179 | for text in text_arr: 180 | if text.startswith("[:") and text.endswith("]"): 181 | continue 182 | if text.find("[:") != -1 and text.endswith("]"): 183 | pos = text.find("[:") 184 | text = text[:pos] 185 | if text == "\\:" or text == ":": 186 | text = ":" 187 | tmp_text_arr.append(text) 188 | 189 | if do_norm_sub > 0: 190 | out_text_arr = norm_sub(tmp_text_arr) 191 | else: 192 | out_text_arr = tmp_text_arr 193 | return out_text_arr 194 | 195 | def rm_bracket(label_list, rm_list = ['\\underline']): 196 | if(len(label_list) == 0): 197 | return 0, [] 198 | if(label_list[0] != '{'): 199 | return 0, [] 200 | left_bracket_num = 0 201 | right_bracket_num = 0 202 | for i, temp in enumerate(label_list): 203 | if(temp == '{'): 204 | left_bracket_num += 1 205 | elif (temp == '}'): 206 | right_bracket_num += 1 207 | if(left_bracket_num == right_bracket_num): 208 | temp_list = rm_underline_textbf(label_list[1:i], rm_list) 209 | return i + 1, 
temp_list 210 | return 0, [] 211 | 212 | def rm_underline_textbf(label_list, rm_list = ['\\underline']): 213 | # print(label_list) 214 | new_label_list = [] 215 | idx = 0 216 | while idx < len(label_list): 217 | if label_list[idx] != rm_list[0]: 218 | new_label_list.append(label_list[idx]) 219 | idx += 1 220 | else: 221 | new_idx, temp_list = rm_bracket(label_list[idx+1:]) 222 | new_label_list = new_label_list + temp_list 223 | idx += 1 224 | idx += new_idx 225 | return new_label_list 226 | 227 | 228 | def post_process(input_arr): 229 | output_arr = input_arr 230 | output_arr = [temp for temp in output_arr if temp != "\\smear" and temp != "\\space"] 231 | output_arr = rm_underline_textbf(output_arr) 232 | return output_arr 233 | 234 | #====================process for reverser texlive render==============# 235 | 236 | 237 | null_str=set(["\t", " ", "\u00A0", "\u3000"]) 238 | def rm_bracket_v2(inStr, prefixs = ["\\textit"]): 239 | curStr = inStr 240 | # curStr = list(filter(None, curStr)) 241 | for cur_prefix in prefixs: 242 | pos = -1 243 | while True: 244 | pos = curStr.find(cur_prefix, pos+1) 245 | if pos == -1: 246 | break 247 | prefix_end = pos + len(cur_prefix) 248 | pos_left = curStr.find("{", prefix_end) 249 | if pos_left == -1: 250 | continue 251 | valid = True 252 | for ind in range(prefix_end, pos_left): 253 | if curStr[ind] not in null_str: 254 | valid = False 255 | break 256 | if not valid: 257 | continue 258 | ind = pos_left + 1 259 | pos_right = -1 260 | curLevel =1 261 | while ind < len(curStr): 262 | if curStr[ind] == "{": 263 | curLevel += 1 264 | elif curStr[ind] == "}": 265 | curLevel -= 1 266 | if curLevel == 0: 267 | pos_right = ind 268 | break 269 | ind += 1 270 | curStr = curStr[0:pos] + curStr[pos_left+1:pos_right] + curStr[pos_right+1:] 271 | return curStr 272 | 273 | def IsChinese(uni_num): 274 | if uni_num >= 0x4E00 and uni_num <= 0x9FBF: 275 | return True 276 | elif uni_num>=0xF900 and uni_num <= 0xFAFF: 277 | return True 278 | else: 279 
| return False 280 | 281 | def process_trans_for_texlive(trans): 282 | # rm textit 283 | trans = rm_bracket_v2(trans) 284 | # process chinese 285 | new_trans = "" 286 | is_chn_range = False 287 | for ch in trans: 288 | if len(ch) != 1: 289 | continue 290 | uni_num = ord(ch) 291 | if IsChinese(uni_num) and not is_chn_range: 292 | new_trans += "\\text{" 293 | is_chn_range = True 294 | if not IsChinese(uni_num) and is_chn_range: 295 | new_trans += "}" 296 | is_chn_range = False 297 | new_trans += ch 298 | if is_chn_range: 299 | new_trans += "}" 300 | is_chn_range = False 301 | # new_trans = "$" + new_trans + "$" 302 | 303 | new_trans = new_trans.replace("\r\n", "\\\\") 304 | new_trans = new_trans.replace("\r", "\\\\") 305 | new_trans = new_trans.replace("\n", "\\\\") 306 | new_trans = new_trans.replace("\\smear", "") 307 | new_trans = new_trans.replace("\\enter", "\\\\") 308 | new_trans = new_trans.replace("\\space", "\\quad") 309 | new_trans = new_trans.replace("\\unk", "") 310 | 311 | # process "\\" 312 | spts = new_trans.split("\\\\") 313 | new_trans = "\\\\".join(["$ \\rm "+spt+"$" for spt in spts if len(spt)>0]) 314 | 315 | return new_trans 316 | -------------------------------------------------------------------------------- /RFL/text_render.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import cv2 3 | from chemfig_struct import * 4 | from chemfig_ops import * 5 | from chemfig_parser import * 6 | import image_render 7 | import utils 8 | import math 9 | import numpy 10 | import pdb 11 | from Tool_Formula.latex_norm.transcription import parse_transcription 12 | import Tool_Formula.latex_norm.transcription as trans 13 | import argparse 14 | 15 | def dfs_visited(cur_atom, last_bond=None, angle=0, visited=[], out_text=[], info_arr=[], withAngle=1, angle_step=15, branch_de_amb=0, branch_arr=[], bond_dict=[]): 16 | if last_bond in visited: 17 | return 18 | 19 | if last_bond is not None: 20 | if not cur_atom 
in visited: 21 | if last_bond.m_type is not None and last_bond.m_type != "" and (last_bond.m_type != "-:" or isinstance(cur_atom, CircleAtom) or cur_atom.m_text == "\\circle"): 22 | out_text.append(last_bond.m_type) 23 | if withAngle > 0: 24 | if angle_step is not None: 25 | normed_angle = int((angle+angle_step/2.0) / angle_step) * angle_step 26 | normed_angle = normed_angle % 360 27 | else: 28 | normed_angle = float("{:.2f}".format(angle)) 29 | normed_angle = normed_angle % 360 30 | out_text[-1] += "[:{}]".format(normed_angle) 31 | bond_dict.append([last_bond, len(out_text)-1]) 32 | # bond_dict[last_bond] = len(out_text)-1 # 维护bond_dictc 33 | branch_arr.append(last_bond.branch_info) # 获取分支键与环相连的相对位置 34 | else: # create hook 35 | start_atom = cur_atom 36 | if cur_atom == last_bond.end_atom: 37 | end_atom = last_bond.begin_atom 38 | else: 39 | end_atom = last_bond.end_atom 40 | if len(start_atom.start_hooks) == 0: 41 | hook_uni = 97 #97~122 42 | while hook_uni in visited: 43 | hook_uni += 1 44 | if hook_uni > 122: 45 | raise ValueError("File {}, line {} :not enough hook name".format(__file__, sys._getframe().f_lineno)) 46 | visited.append(hook_uni) 47 | hook_name = chr(hook_uni) 48 | new_start_hook = DistantHook("?[{}]".format(hook_name)) 49 | start_atom.start_hooks.append(new_start_hook) 50 | if len(start_atom.start_hooks) > 0: #already exists 51 | new_end_hook = DistantHook(start_atom.start_hooks[0].attr_str) 52 | new_end_hook.m_bondtype = last_bond.m_type 53 | end_atom.end_hooks.append(new_end_hook) 54 | 55 | bond_dict.append([last_bond, ]) 56 | 57 | 58 | visited.append(last_bond) 59 | 60 | if cur_atom in visited: 61 | return 62 | 63 | if cur_atom is not None: 64 | out_text.append(cur_atom) 65 | visited.append(cur_atom) 66 | branch_arr.append(None) 67 | # ring_branch_arr.append(None) 68 | 69 | all_bonds = [] 70 | for bond, angle in cur_atom.all_bonds: 71 | if cur_atom == bond.end_atom: 72 | all_bonds.append((bond, angle, bond.begin_atom)) 73 | elif cur_atom == 
bond.begin_atom: 74 | all_bonds.append((bond, angle, bond.end_atom)) 75 | else: 76 | raise ValueError("atom bond not connect") 77 | 78 | pairs = [] 79 | last_ind = -1 80 | 81 | for child_id, (bond, angle, next_atom) in enumerate(all_bonds): 82 | if branch_de_amb > 0: 83 | out_text.append("branch(") 84 | branch_arr.append(None) 85 | # ring_branch_arr.append(None) 86 | else: 87 | out_text.append("(") 88 | branch_arr.append(None) 89 | # ring_branch_arr.append(None) 90 | begin_ind = len(out_text) - 1 91 | dfs_visited(next_atom, bond, angle, visited, out_text, info_arr, withAngle, angle_step, branch_de_amb=branch_de_amb, branch_arr=branch_arr, bond_dict=bond_dict) 92 | if branch_de_amb > 0: 93 | out_text.append("branch)") 94 | branch_arr.append(None) 95 | # ring_branch_arr.append(None) 96 | else: 97 | out_text.append(")") 98 | branch_arr.append(None) 99 | # ring_branch_arr.append(None) 100 | end_ind = len(out_text) - 1 101 | pairs.append((begin_ind, end_ind)) 102 | if end_ind == begin_ind + 1: 103 | info_arr += [begin_ind, end_ind] 104 | last_ind = len(pairs) - 2 105 | else: 106 | last_ind = len(pairs) - 1 107 | 108 | if len(pairs) > 0 and last_ind >= 0: 109 | info_arr += list(pairs[last_ind]) 110 | 111 | 112 | 113 | def pre_process(rootAtom, scale = 1, preprocess=True, bond_dict=[]): 114 | all_atoms = SimulateCoord(rootAtom, scale=scale) 115 | if preprocess: 116 | all_atoms = RemoveDupAtoms(all_atoms[0], th=scale*0.01) 117 | RemoveDupBonds(all_atoms[0]) 118 | ConnectDistantAtoms(all_atoms[0], bond_dict=bond_dict) # 这里会处理DistanceHook,创建回连的Bond 119 | all_atoms = NormAllCircleAtom(all_atoms[0]) 120 | return all_atoms 121 | 122 | def rend_text(rootAtom, withAngle=1, branch_de_amb=0): 123 | # all_atoms = SimulateCoord(rootAtom, scale=scale) 124 | # all_atoms = RemoveDupAtoms(all_atoms[0], th=scale*0.01) 125 | # RemoveDupBonds(all_atoms[0]) 126 | # all_atoms = NormAllCircleAtom(all_atoms[0]) 127 | all_atoms = GetAllAtoms(rootAtom) 128 | min_x, min_y, max_x, max_y = 
GetCoordRange(all_atoms) 129 | 130 | #select start 131 | anchor_x = min_x 132 | anchor_y = (min_y+max_y) / 2.0 133 | min_dis = 1e10 134 | min_atom = None 135 | for atom in all_atoms: 136 | cur_dis = math.sqrt(math.pow(anchor_x - atom.pos_x, 2) + math.pow(anchor_y - atom.pos_y, 2)) 137 | cur_dis = atom.pos_x * 10000 + (-atom.pos_y) 138 | if cur_dis < min_dis: 139 | min_dis = cur_dis 140 | min_atom = atom 141 | atom.all_bonds = [] 142 | for bond in atom.out_bonds: 143 | atom.all_bonds.append((bond, bond.m_angle)) 144 | for bond in atom.in_bonds: 145 | atom.all_bonds.append((bond, (bond.m_angle - 180) % 360)) 146 | atom.all_bonds = sorted(atom.all_bonds, key=lambda x: x[1]) 147 | start_atom = min_atom 148 | 149 | 150 | visited = [] 151 | out_seq = [] 152 | info_arr = [] 153 | dfs_visited(start_atom, None, 0, visited, out_seq, info_arr, withAngle, branch_de_amb=branch_de_amb) 154 | 155 | info_arr = set(info_arr) 156 | #new_out_text = [] 157 | new_out_seq = [] 158 | for ind, element in enumerate(out_seq): 159 | if ind in info_arr: 160 | continue #remove rebundant bracket 161 | # new_out_text.append(text) 162 | new_out_seq.append(element) 163 | out_seq = new_out_seq 164 | 165 | #adjust hooks 166 | cur_uni = 97 167 | rep_dict = {} 168 | for element in new_out_seq: 169 | if not isinstance(element, Atom): 170 | continue 171 | if len(element.start_hooks) <= 0: 172 | continue 173 | assert len(element.start_hooks) == 1 174 | hook_name = element.start_hooks[0].m_hookname 175 | rep_dict[hook_name] = chr(cur_uni) 176 | cur_uni += 1 177 | for element in new_out_seq: 178 | if not isinstance(element, Atom): 179 | continue 180 | for hook in element.start_hooks + element.end_hooks: 181 | hook.m_hookname = rep_dict[hook.m_hookname] 182 | 183 | out_text = [] 184 | for element in new_out_seq: 185 | if isinstance(element, str): 186 | out_text.append(element) 187 | elif isinstance(element, Atom): 188 | # if element.name=="Atom_0": 189 | # pdb.set_trace() 190 | if isinstance(element, 
CircleAtom): 191 | out_text += ["\\circle"] 192 | else: 193 | out_text += element.normed_text() 194 | element.start_hooks = [] # must clear hooks here, other wise it will affect following operation!!! 195 | element.end_hooks = [] # 196 | elif isinstance(element, Bond): 197 | last_bond = element 198 | cur_atom = last_bond.end_atom 199 | angle = math.atan2(last_bond.end_atom.pos_y - last_bond.begin_atom.pos_y, last_bond.end_atom.pos_x - last_bond.begin_atom.pos_x) 200 | if last_bond.m_type is not None and last_bond.m_type != "" and (last_bond.m_type != "-:" or isinstance(cur_atom, CircleAtom) or cur_atom.m_text == "\\circle"): 201 | out_text.append(last_bond.m_type) 202 | if withAngle > 0: 203 | normed_angle = int((angle+7.5) / 15.0) * 15 204 | normed_angle = normed_angle % 360 205 | out_text.append("[:{}]".format(normed_angle)) 206 | 207 | 208 | 209 | atom_count = 0 210 | bond_count = 0 211 | other_count = 0 212 | for ele in visited: 213 | if isinstance(ele, Atom): 214 | atom_count += 1 215 | elif isinstance(ele, Bond): 216 | bond_count += 1 217 | else: 218 | other_count += 1 219 | # pdb.set_trace() 220 | # assert atom_count == len(all_atoms), "{} vs {}".format(atom_count, len(all_atoms)) 221 | # if atom_count != len(all_atoms): 222 | # print("------------------{} vs {} ---------------------".format(len(all_atoms), atom_count)) 223 | # print(sorted(all_atoms, key=lambda x:x.name)) 224 | # visited_atoms = [] 225 | # for element in visited: 226 | # if isinstance(element, Atom): 227 | # visited_atoms.append(element) 228 | # print(sorted(visited_atoms, key = lambda x:x.name)) 229 | # print("-------------------------------------------------------") 230 | return out_text 231 | 232 | # interface for text render 233 | def text_render(in_text, debug=False, branch_de_amb=0): 234 | _, rep_dict, text = utils.replace_chemfig(in_text) 235 | new_text = parse_transcription(text, simple_trans=True) 236 | 237 | new_rep_dict = {} 238 | for key, value in rep_dict.items(): 239 | 
chemfig_str = value.strip()[8:].strip()[1:-1] 240 | rootAtom, _ = chemfig_parse(chemfig_str, echo=debug) 241 | all_atoms = pre_process(rootAtom, scale=1) 242 | chemfig_units = rend_text(all_atoms[0], branch_de_amb=branch_de_amb) 243 | chemfig_units = ["\chemfig", "{"] + chemfig_units + ["}"] 244 | new_rep_dict[key] = chemfig_units 245 | if debug: 246 | 247 | all_atoms = SimulateCoord(all_atoms[0], scale=100) 248 | img_key = key.replace("\\chem", "") 249 | cv2.imwrite("./debug/rend_{}.jpg".format(img_key), image_render.rend_atoms(all_atoms, scale=100, rend_name=0)) 250 | cv2.imwrite("./debug/rend_name_{}.jpg".format(img_key), image_render.rend_atoms(all_atoms, scale=100, rend_name=1)) 251 | 252 | out_text = [] 253 | for unit in new_text: 254 | if unit in new_rep_dict: 255 | out_text += new_rep_dict[unit] 256 | else: 257 | out_text.append(unit) 258 | out_text = " ".join(out_text) 259 | return out_text 260 | 261 | 262 | def main(args): 263 | inputLines = [] 264 | if os.path.exists(args.input): 265 | with open(args.input, "r") as fin: 266 | lines = fin.readlines() 267 | for line in lines: 268 | inputLines.append(line.strip()) 269 | elif len(args.input)>1: 270 | inputLines.append(args.input) 271 | else: 272 | s1 = "\chemfig{**6(-=---(-OH)-)}" 273 | inputLines.append(s1) 274 | 275 | import texlive_rend 276 | for _id, inputStr in enumerate(inputLines): 277 | if _id < args.start: 278 | continue 279 | if inputStr.startswith("#") or len(inputStr)<2: 280 | continue 281 | out_text = text_render(inputStr, debug=True) 282 | rend_img = texlive_rend.texlive_rend("$"+inputStr+"$") 283 | cv2.imwrite("./debug/demo.jpg", rend_img) 284 | print(inputStr) 285 | print(out_text) 286 | pdb.set_trace() 287 | 288 | 289 | 290 | if __name__ == "__main__": 291 | parser = argparse.ArgumentParser("") 292 | parser.add_argument("-input", type=str, default="\chemfig{**6(-=---(-OH)-)}") 293 | parser.add_argument("-start", type=int, default=0) 294 | args = parser.parse_args() 295 | main(args) 296 | 297 
| -------------------------------------------------------------------------------- /RFL/graph_cmp.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import pdb 3 | import argparse 4 | import Levenshtein 5 | import chemfig_struct 6 | import chemfig_ops 7 | import chemfig_parser 8 | import text_render 9 | from chemfig_struct import * 10 | import utils 11 | import tqdm 12 | import cv2 13 | import math 14 | import numpy 15 | 16 | # interface 17 | def compare_graph(rootA: Atom, rootB: Atom): 18 | # pdb.set_trace() 19 | res_text, textA, textB = match_text(rootA, rootB) 20 | res_graph = match_graph(rootA, rootB) 21 | if res_text == 1 and res_graph == 1 or True: 22 | # print("res_text={}".format(res_text)) 23 | # print("res_graph={}".format(res_graph)) 24 | # import image_render 25 | # cv2.imwrite("a.jpg", image_render.rend(rootA, scale=100)) 26 | # cv2.imwrite("a2.jpg", image_render.rend(rootA, scale=100, rend_name=1)) 27 | 28 | # cv2.imwrite("b.jpg", image_render.rend(rootB, scale=100)) 29 | # cv2.imwrite("b2.jpg", image_render.rend(rootB, scale=100, rend_name=1)) 30 | # print("textA: {}".format(textA)) 31 | # print("textB: {}".format(textB)) 32 | # pdb.set_trace() 33 | pass 34 | if res_text == 0 or res_graph == 0: 35 | return 0 36 | else: 37 | return 1 38 | 39 | def match_text(rootA: Atom, rootB: Atom): 40 | arr_A = text_render.rend_text(rootA, 1) 41 | arr_B = text_render.rend_text(rootB, 1) 42 | 43 | textA = " ".join(arr_A) 44 | textB = " ".join(arr_B) 45 | textA_noAngle = " ".join(utils.removeAngle(arr_A)) 46 | textB_noAngle = " ".join(utils.removeAngle(arr_B)) 47 | if textA_noAngle == textB_noAngle: 48 | # cv2.imwrite("a.jpg", chemfig_render.rend(rootA, scale=100)) 49 | # cv2.imwrite("a2.jpg", chemfig_render.rend(rootA, scale=100, rend_name=1)) 50 | # cv2.imwrite("b.jpg", chemfig_render.rend(rootB, scale=100)) 51 | # print("txtA: {}".format(txtA)) 52 | # print("txtB: {}".format(txtB)) 53 | # print("textA: 
{}".format(textA)) 54 | # print("textB: {}".format(textB)) 55 | # # print(os.popen("cat show.txt").readlines()[0]) 56 | # pdb.set_trace() 57 | return 0, textA, textB 58 | else: 59 | # with open("debug.txt", "a") as fout: 60 | # fout.write("txtA: {}\n".format(txtA)) 61 | # fout.write("txtA: {}\n".format(txtB)) 62 | # fout.write("textA: {}\n".format(textA)) 63 | # fout.write("textB: {}\n".format(textB)) 64 | # fout.write("\n") 65 | # cv2.imwrite("a.jpg", chemfig_render.rend(rootA, scale=100)) 66 | # cv2.imwrite("a2.jpg", chemfig_render.rend(rootA, scale=100, rend_name=1)) 67 | # cv2.imwrite("b.jpg", chemfig_render.rend(rootB, scale=100)) 68 | # cv2.imwrite("b2.jpg", chemfig_render.rend(rootB, scale=100, rend_name=1)) 69 | # print("txtA: {}".format(txtA)) 70 | # print("txtB: {}".format(txtB)) 71 | # print("textA: {}".format(textA)) 72 | # print("textB: {}".format(textB)) 73 | # print(os.popen("cat show.txt").readlines()[0]) 74 | # pdb.set_trace() 75 | 76 | # print("textA: {}".format(textA)) 77 | # print("textB: {}".format(textB)) 78 | # print("-----------------------") 79 | return 1, textA, textB 80 | 81 | def cmp_bond_type(bondA:Bond, bondB:Bond): 82 | bond_same = False 83 | if bondA.m_type in chemfig_struct.directed_bond_types and bondB.m_type in chemfig_struct.directed_bond_types: 84 | bond_delta_angle = math.fabs(bondA.m_angle - bondB.m_angle) 85 | if bond_delta_angle > 180: 86 | bond_delta_angle = 360 - bond_delta_angle 87 | # if bond_delta_angle > 90: 88 | # bond_same = (bondA.m_type == chemfig_struct.directed_bond_types[bondB.m_type]) 89 | # else: 90 | # bond_same = (bondA.m_type == bondB.m_type) 91 | bond_same = ((bondA.m_type == chemfig_struct.directed_bond_types[bondB.m_type]) or (bondA.m_type == bondB.m_type)) 92 | else: 93 | bond_same = (bondA.m_type == bondB.m_type) 94 | return bond_same 95 | 96 | def compare_atom_dist(atomA: Atom, atomB: Atom): 97 | #compare text 98 | ed_ops = utils.cal_edit_ops(atomA.normed_text(), atomB.normed_text()) 99 | ed_dist = 
len(ed_ops) 100 | #compare degree 101 | degree_dist = math.fabs(atomA.degree - atomB.degree) 102 | #compare content 103 | if degree_dist > 0: 104 | content_dist = 1e10 105 | contentA = atomA.content_arr 106 | contentB = atomB.content_arr 107 | if atomA.degree < atomB.degree: 108 | shift_content = contentB 109 | cmp_content = contentA 110 | else: 111 | shift_content = contentA 112 | cmp_content = contentB 113 | min_shift_content = None 114 | min_content_dist = math.inf 115 | for shift in range(0, len(shift_content)): 116 | ref_content = shift_content[shift:] + shift_content[:shift] 117 | cur_content_dist = 0 118 | for ind in range(len(cmp_content)): 119 | ref_item = ref_content[ind] #angle, bond_type, tgt_text, tgt_degree 120 | cmp_item = cmp_content[ind] 121 | # cur_content_dist += math.fabs(ref_item[0] - cmp_item[0]) #angle 122 | delta_angle = (ref_item[0] - cmp_item[0])%360 123 | if delta_angle > 180: 124 | delta_angle = 360 - delta_angle 125 | cur_content_dist += delta_angle #angle 126 | cur_content_dist += (not cmp_bond_type(ref_item[1], cmp_item[1])) * 10.0 # update @2022.07.20 haowu16 127 | cur_content_dist += math.fabs(ref_item[2].degree - cmp_item[2].degree) * 30 # degree 128 | if cur_content_dist < min_content_dist: 129 | min_content_dist = cur_content_dist 130 | min_shift_content = ref_content 131 | content_dist = min_content_dist 132 | 133 | total_distance = ed_dist*10 + degree_dist*30 + content_dist 134 | if atomA.degree < atomB.degree: 135 | match_result = (contentA, min_shift_content) 136 | else: 137 | match_result = (min_shift_content, contentB) 138 | return total_distance, match_result 139 | 140 | def GuidedWalk(start_atomA, start_atomB, debug=False): 141 | atom_stackA = [(None, start_atomA)] 142 | atom_stackB = [(None, start_atomB)] 143 | visitedA = set() 144 | visitedB = set() 145 | while len(atom_stackA) > 0: 146 | cur_BondA, cur_atomA = atom_stackA.pop() 147 | cur_BondB, cur_atomB = atom_stackB.pop() 148 | if cur_atomA in visitedA: 149 | if 
cur_atomB not in visitedB: 150 | if debug: 151 | pdb.set_trace() 152 | return False 153 | else: 154 | continue 155 | if cur_atomB in visitedB: 156 | return False 157 | if debug: 158 | print("A={} B={}".format(cur_atomA.name, cur_atomB.name)) 159 | atomA_text = cur_atomA.normed_text() if not isinstance(cur_atomA, CircleAtom) else ["\\circle"] 160 | atomB_text = cur_atomB.normed_text() if not isinstance(cur_atomB, CircleAtom) else ["\\circle"] 161 | atomA_text = utils.post_process(atomA_text) 162 | atomB_text = utils.post_process(atomB_text) 163 | 164 | if atomA_text != atomB_text: 165 | if debug: 166 | pdb.set_trace() 167 | return False 168 | visitedA.add(cur_atomA) 169 | visitedB.add(cur_atomB) 170 | contentA = cur_atomA.content_arr 171 | contentB = cur_atomB.content_arr 172 | if len(contentA) != len(contentB): 173 | if debug: 174 | pdb.set_trace() 175 | return False 176 | 177 | min_shift_contentB = [] 178 | min_content_dist = math.inf 179 | for shift in range(0, len(contentB)): 180 | ref_content = contentB[shift:] + contentB[:shift] 181 | cur_content_dist = 0 182 | for ind in range(len(contentA)): 183 | ref_item = ref_content[ind] #angle, bond_type, tgt_text, tgt_degree 184 | cmp_item = contentA[ind] 185 | # cur_content_dist += math.fabs(ref_item[0] - cmp_item[0]) #angle 186 | delta_angle = (ref_item[0] - cmp_item[0])%360 187 | if delta_angle > 180: 188 | delta_angle = 360 - delta_angle 189 | cur_content_dist += delta_angle #angle 190 | cur_content_dist += (not cmp_bond_type(ref_item[1], cmp_item[1])) * 10.0 #bond_type 191 | cur_content_dist += math.fabs(ref_item[2].degree - cmp_item[2].degree) * 30 # degree 192 | if cur_content_dist < min_content_dist: 193 | min_content_dist = cur_content_dist 194 | min_shift_contentB = ref_content 195 | for itemA, itemB in zip(contentA, min_shift_contentB): 196 | #if itemA[1].m_type != itemB[1].m_type: 197 | if not cmp_bond_type(itemA[1], itemB[1]): 198 | if debug: 199 | pdb.set_trace() 200 | return False 201 | 
atom_stackA.append((itemA[1], itemA[2])) 202 | atom_stackB.append((itemB[1], itemB[2])) 203 | return True 204 | 205 | def pre_process_for_cmp(rootAtom:Atom): 206 | all_atoms = chemfig_ops.SimulateCoord(rootAtom, scale=1) 207 | all_atoms = chemfig_ops.RemoveDupAtoms(all_atoms[0]) 208 | chemfig_ops.RemoveDupBonds(all_atoms[0]) 209 | chemfig_ops.ConnectDistantAtoms(all_atoms[0]) 210 | all_atoms = chemfig_ops.NormAllCircleAtom(all_atoms[0]) 211 | return all_atoms 212 | 213 | def match_graph(rootA: Atom, rootB: Atom): 214 | a_atoms = chemfig_ops.NormAllCircleAtom(rootA, all_connect=1) 215 | b_atoms = chemfig_ops.NormAllCircleAtom(rootB, all_connect=1) 216 | 217 | a_name2idx = dict([(atom.name, ind) for ind, atom in enumerate(a_atoms)]) 218 | b_name2idx = dict([(atom.name, ind) for ind, atom in enumerate(b_atoms)]) 219 | 220 | #calc atom attribute [degree, content] content_arr 221 | for child_atom in a_atoms + b_atoms: 222 | child_atom.degree = len(child_atom.in_bonds + child_atom.out_bonds) 223 | child_atom.content_arr = [] #angle, bond_type, tgt_text, tgt_degree 224 | for in_bond in child_atom.in_bonds + child_atom.out_bonds: 225 | tgt_atom = in_bond.begin_atom if in_bond.begin_atom != child_atom else in_bond.end_atom 226 | angle = math.atan2(-tgt_atom.pos_y + child_atom.pos_y, tgt_atom.pos_x - child_atom.pos_x) * 180.0 / math.pi 227 | angle = angle % 360 228 | child_atom.content_arr.append((angle, in_bond, tgt_atom)) 229 | child_atom.content_arr = sorted(child_atom.content_arr, key=lambda x: x[0]) 230 | 231 | #compare atom 232 | dist_mat = numpy.zeros((len(a_atoms), len(b_atoms)), dtype=numpy.float) #lenA, lenB 233 | atom_match_results = [] 234 | for child_idA, child_atomA in enumerate(a_atoms): 235 | cur_dist_arr = [] 236 | perfect_num = 0 237 | for child_idB, child_atomB in enumerate(b_atoms): 238 | total_distance, match_result = compare_atom_dist(child_atomA, child_atomB) 239 | cur_dist_arr.append((child_idB, total_distance, match_result)) 240 | dist_mat[child_idA, 
child_idB] = total_distance 241 | if total_distance < 1e-6: 242 | perfect_num += 1 243 | cur_dist_arr = sorted(cur_dist_arr, key=lambda x: x[1]) 244 | atom_match_results.append((child_idA, cur_dist_arr, perfect_num)) 245 | 246 | # select pairs 247 | a_perfect_match_num = (dist_mat==0).sum(1, keepdims=True) 248 | b_perfect_match_num = (dist_mat==0).sum(0, keepdims=True) 249 | # a_match_dist = (a_perfect_match_num == 1)*0 + (a_perfect_match_num > 1)*a_perfect_match_num + (a_perfect_match_num == 0)*1000 250 | # b_match_dist = (b_perfect_match_num == 1)*0 + (b_perfect_match_num > 1)*b_perfect_match_num + (b_perfect_match_num == 0)*1000 251 | match_dist = ~((a_perfect_match_num == 1)&(b_perfect_match_num == 1)) 252 | match_dist = match_dist.astype("float") 253 | 254 | cond_match_dist = match_dist * (dist_mat == 0) * 1000 + (dist_mat > 0) * 2000 255 | fused_dist = dist_mat + cond_match_dist 256 | 257 | pairs = [(i,j) for j in range(len(b_atoms)) for i in range(len(a_atoms))] 258 | pairs = sorted(pairs, key = lambda x:fused_dist[x[0],x[1]]) 259 | # for a_ind, b_ind in pairs[0:10]: 260 | # print("{}-{}: {} + {}".format(a_atoms[a_ind].name, b_atoms[b_ind].name, dist_mat[a_ind, b_ind], cond_match_dist[a_ind, b_ind])) 261 | #pdb.set_trace() 262 | 263 | max_try_count = 3 264 | cur_try_count = 0 265 | for indA, indB in pairs: 266 | start_atomA = a_atoms[indA] #reference 267 | start_atomB = b_atoms[indB] 268 | cmp_result = GuidedWalk(start_atomA, start_atomB, False) 269 | cur_try_count += 1 270 | if cmp_result is False: 271 | if cur_try_count > max_try_count: 272 | return 1 273 | else: 274 | continue 275 | else: 276 | return 0 277 | return 1 278 | 279 | -------------------------------------------------------------------------------- /test_list_multi.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | import cv2 5 | import logging 6 | 
logging.basicConfig(level=logging.INFO, 7 | format='%(asctime)s[%(levelname)s] %(name)s -%(message)s', 8 | ) 9 | logger = logging.getLogger(__name__) 10 | import torch 11 | import torch.nn as nn 12 | from multiprocessing import Process, synchronize, Lock 13 | 14 | from rain import xconfig 15 | from rain.bucket_io import build_data, DataPartioner, build_data_test 16 | from rain.beam_search import BeamSearcher 17 | from rain.initializer import initialize_model 18 | from rain.utils import AccPerformance, collate_torch 19 | from rain.utils import read_image, load_det_sections, parse_title 20 | import argparse 21 | from tqdm import tqdm 22 | import numpy as np 23 | 24 | def filter_list(in_str_list, min_dul_num = 4): 25 | 26 | len_str = len(in_str_list) 27 | max_dul = len_str // min_dul_num 28 | is_del = False 29 | 30 | 31 | for n_dul in range(1, max_dul, 1): 32 | cur_dul_list = in_str_list[-n_dul:] 33 | cur_n_dul = n_dul + n_dul 34 | cur_count = 0 35 | while cur_n_dul <= len_str: 36 | if in_str_list[-cur_n_dul:(-cur_n_dul + n_dul)] == cur_dul_list: 37 | cur_count += 1 38 | cur_n_dul += n_dul 39 | else: 40 | break 41 | 42 | if cur_count >= min_dul_num: 43 | is_del = True 44 | 45 | in_str_list = in_str_list[:(-cur_n_dul + n_dul)] 46 | break 47 | 48 | 49 | if is_del: 50 | return filter_list(in_str_list) 51 | else: 52 | return in_str_list 53 | 54 | def post_filter(in_file, out_file): 55 | 56 | with open(in_file, 'r') as F: 57 | in_str = F.readlines() 58 | 59 | out_str = '' 60 | postfix = 200 61 | times = 5 62 | for in_line in in_str: 63 | cur_line = in_line.strip().split('\t') 64 | rst_str_list = cur_line[1].split(' ') 65 | 66 | for idx, sstr in enumerate(rst_str_list): 67 | if sstr == '' or sstr == '': 68 | rst_str_list = rst_str_list[:idx] 69 | break 70 | 71 | for _ in range(times): 72 | rst_str_list = filter_list(rst_str_list) 73 | 74 | c_postfix = min(postfix, len(rst_str_list)) 75 | 76 | 77 | for xt in range(1, c_postfix): 78 | rst_str_list = 
filter_list(rst_str_list[:-xt]) + rst_str_list[-xt:] 79 | 80 | cur_line[1] = ' '.join(rst_str_list) 81 | out_str += '%s\t%s\t%s\n' % (cur_line[0], cur_line[1], cur_line[2]) 82 | 83 | with open(out_file, 'w') as F: 84 | F.write(out_str) 85 | 86 | class DataPartionerTest(object): 87 | def __init__(self, data, size=1, rank=0, seed = 0): 88 | self.data = data 89 | self.partitions = [] 90 | self.part_len = len(data) // size 91 | self.rank = rank 92 | 93 | import random 94 | random.seed(seed) 95 | random.shuffle(self.data._batchs) 96 | begin = 0 97 | for i in range(size - 1): 98 | 99 | self.partitions.append(self.data._batchs[begin:begin + self.part_len]) 100 | begin += self.part_len 101 | 102 | self.partitions.append(self.data._batchs[begin:]) 103 | 104 | 105 | assert len(self.partitions) == size 106 | 107 | def __len__(self, ): 108 | return len(self.partitions[self.rank]) 109 | 110 | def __getitem__(self, index): 111 | #index = (index+1) % len(self) 112 | return self.data.get_item(self.partitions[self.rank], index) 113 | 114 | 115 | def test(load_epoch, gpu, rank_id, logname, lock=None): 116 | 117 | logger.info("Start testing epoch %d on gpu %d" % (load_epoch, gpu)) 118 | data_set = build_data_test( 119 | max_h = xconfig.test_max_height, 120 | max_w = xconfig.test_max_width, 121 | max_l = xconfig.test_max_length, 122 | fix_batch_size = xconfig.test_fix_batch_size, 123 | max_batch_size = xconfig.max_batch_size, 124 | max_image_size = xconfig.max_image_size, 125 | seed = xconfig.seed, 126 | do_shuffle = False, 127 | use_all = True, 128 | last_method = 'fill', 129 | one_key = False, 130 | return_name = True, 131 | image_list_file = xconfig.test_image_list_path, 132 | det_sections_file= xconfig.test_det_sections, 133 | normh = xconfig.test_image_normh, 134 | do_test = True 135 | ) 136 | data_partition = DataPartionerTest(data_set, size = test_world_size, rank=rank_id) 137 | data_loader = torch.utils.data.DataLoader(dataset=data_partition, batch_size=1, num_workers=2, 138 | 
collate_fn=collate_torch, shuffle = False) 139 | 140 | # prepare 141 | torch.cuda.set_device(gpu) 142 | tester_model = BeamSearcher(vocab_size=xconfig.vocab_size, sos=xconfig.sos, eos=xconfig.eos, 143 | beam=xconfig.beam, frame_per_char=xconfig.frame_per_char) 144 | initialize_model(tester_model, xconfig.model_prefix, xconfig.model_type, load_epoch) 145 | tester_model.cuda() 146 | tester_model.eval() 147 | 148 | 149 | logfile = open(logname + '_nofilter', 'w') 150 | acc_metric = AccPerformance(ignores={xconfig.sos,xconfig.eos,xconfig.enter}) 151 | names_set = set() 152 | avg_cost = 0.0 153 | with torch.no_grad(): 154 | if rank_id == 0: 155 | data_loader = tqdm(data_loader) 156 | for data, data_mask, target, target_mask, names_list in data_loader: 157 | data = data.cuda() 158 | data_mask = data_mask.cuda() 159 | batch_size = len(names_list) 160 | # print(data.shape, data_mask.shape) 161 | preds_batch, costs_batch = tester_model.search_gpu(data, data_mask) 162 | 163 | #import pdb; pdb.set_trace() 164 | 165 | for i in range(batch_size): 166 | if names_list[i] in names_set: 167 | continue 168 | else: 169 | names_set.add(names_list[i]) 170 | #label = target[i, :int(target_mask[i].sum())].tolist() 171 | pred = preds_batch[i][0] 172 | cost = costs_batch[i][0] 173 | #acc_metric.evaluate(label, pred) 174 | #lab_str = ' '.join([xconfig.vocab.getWord(wid) for wid in label]) 175 | pred = pred[1:-1] 176 | pre_str = ' '.join([xconfig.vocab.getWord(wid) for wid in pred]) 177 | 178 | logfile.write('{}\t{}\t{}\n'.format(names_list[i], pre_str, cost)) 179 | avg_cost += cost 180 | logfile.flush() 181 | #avg_cost /= len(names_set) 182 | #logger.info('decode on epoch {}, acc={}, cost={}'.format(load_epoch, acc_metric.get_performance(), avg_cost)) 183 | logfile.close() 184 | logger.info("End testing epoch %d on gpu %d" % (load_epoch, gpu)) 185 | 186 | # post_filter(logname, logname + '_filter') 187 | post_filter(logname + '_nofilter', logname) 188 | 189 | 190 | 191 | def 
norm_check(keyfile, det_sections_file=None, key_info_file = None): 192 | if det_sections_file is not None: 193 | det_sections = load_det_sections(det_sections_file) 194 | else: 195 | det_sections = None 196 | 197 | if key_info_file is None: 198 | key_info_file = keyfile + '_info.txt' 199 | # if os.path.exists(key_info_file): 200 | # return 201 | fout = open(key_info_file, 'w') 202 | with open(keyfile, 'r') as fin: 203 | all_line = fin.readlines() 204 | for line in tqdm(all_line): 205 | line_all = line.strip() 206 | imgfile = line_all 207 | img_title = parse_title(imgfile) 208 | if det_sections is not None: 209 | img_segmentation = det_sections[img_title] 210 | else: 211 | img_segmentation = None 212 | label = ' ' 213 | img = read_image(imgfile, xconfig.test_image_normh, img_segmentation, xconfig.img_fix_char_height) 214 | if img is None: 215 | logger.info('{} is None'.format(imgfile)) 216 | continue 217 | vocab = xconfig.vocab 218 | trans_label = [vocab.getID(l) for l in label.split(' ')] 219 | item = (imgfile, int(img.shape[1]), int(img.shape[2]), int(len(trans_label)), 0, 0) 220 | fout.write('{} {} {} {} {} {}\n'.format(imgfile, int(img.shape[1]), int(img.shape[2]), 221 | int(len(trans_label)), 0, 0)) 222 | fout.close() 223 | 224 | 225 | 226 | 227 | if __name__ == '__main__': 228 | 229 | parser = argparse.ArgumentParser("OCR MultiProcess MultiCPU Single Epoch Testing") 230 | parser.add_argument('--process_per_gpu', type=int) 231 | parser.add_argument('--test_epoch', type=int) 232 | parser.add_argument('--used_gpu_id', type=str) 233 | parser.add_argument('--test_image_list_path', type=str) 234 | parser.add_argument('--img_fix_char_height', type=int, default=None) 235 | parser.add_argument('--test_det_sections', type=str, default=None) 236 | 237 | parser.add_argument('--norm_check_cpus', type=int, default = 16) 238 | parser.add_argument('--do_norm_check', type=int, default = 1) 239 | args = parser.parse_args() 240 | #import pdb; pdb.set_trace() 241 | 242 | 
xconfig.test_image_list_path = args.test_image_list_path 243 | xconfig.test_log_out_path = os.path.join( os.path.dirname(args.test_image_list_path), 'pred.trans') 244 | 245 | xconfig.img_fix_char_height = args.img_fix_char_height 246 | xconfig.test_det_sections = args.test_det_sections 247 | 248 | # norm_check(xconfig.test_image_list_path, args.test_det_sections) 249 | 250 | if args.do_norm_check > 0: 251 | norm_check_name_list = list() 252 | norm_check_pre_list = list() 253 | total_check_file = xconfig.test_image_list_path + '_info.txt' 254 | 255 | with open(xconfig.test_image_list_path, 'r') as F: 256 | all_lines = F.readlines() 257 | len_images = len(all_lines) 258 | sep_list_id = list(np.linspace(0, len_images, args.norm_check_cpus + 1, dtype = int)) 259 | 260 | for idx in range(args.norm_check_cpus): 261 | norm_check_name_list.append(total_check_file + '_part%d.txt'%(idx)) 262 | cur_name = xconfig.test_image_list_path + '_part%d.txt'%(idx) 263 | with open(cur_name, 'w') as F: 264 | for jdx in range(sep_list_id[idx], sep_list_id[idx + 1], 1): 265 | F.write(all_lines[jdx]) 266 | norm_check_pre_list.append(cur_name) 267 | 268 | records = [] 269 | for f_in, f_out in zip(norm_check_pre_list, norm_check_name_list): 270 | 271 | p = Process(target=norm_check, args=(f_in, args.test_det_sections, f_out)) 272 | p.start() 273 | records.append(p) 274 | 275 | for p in records: 276 | p.join() 277 | cat_str = 'cat' 278 | for xt in norm_check_name_list: 279 | cat_str += ' %s'%(xt) 280 | cat_str += ' > %s'%(total_check_file) 281 | os.system(cat_str) 282 | for xt in norm_check_pre_list: 283 | rm_str = 'rm %s' %(xt) 284 | os.system(rm_str) 285 | for xt in norm_check_name_list: 286 | rm_str = 'rm %s' %(xt) 287 | os.system(rm_str) 288 | 289 | # import pdb; pdb.set_trace() 290 | test_epoch = args.test_epoch 291 | 292 | uesd_gpus = [int(g) for g in args.used_gpu_id.split(",")] 293 | process_per_gpu = args.process_per_gpu 294 | 295 | 296 | resource_list = [] 297 | log_name_list = [] 
298 | part_id = 0 299 | for igpu in uesd_gpus: 300 | for icpu in range(process_per_gpu): 301 | resource_list.append(igpu) 302 | log_name_list.append('{}_part_{}.log'.format(xconfig.test_log_out_path, part_id)) 303 | part_id += 1 304 | 305 | test_world_size = len(resource_list) 306 | 307 | 308 | # start test 309 | lock = Lock() 310 | records = [] 311 | 312 | #test(test_epoch, resource_list[0], 0, log_name_list[0], lock) 313 | 314 | for i, use_gpu in enumerate(resource_list): 315 | 316 | p = Process(target=test, args=(test_epoch, use_gpu, i, log_name_list[i], lock)) 317 | p.start() 318 | records.append(p) 319 | 320 | for p in records: 321 | p.join() 322 | 323 | cat_str = 'cat' 324 | logname_all = xconfig.test_log_out_path 325 | for xt in log_name_list: 326 | cat_str += ' %s'%(xt) 327 | cat_str += ' > %s'%(logname_all) 328 | os.system(cat_str) 329 | 330 | for xt in log_name_list: 331 | rm_str = 'rm %s' %(xt) 332 | os.system(rm_str) 333 | 334 | -------------------------------------------------------------------------------- /RFL/Tool_Formula/latex_norm/katex_valid_symbols.map: -------------------------------------------------------------------------------- 1 | \equiv 2 | \prec 3 | \succ 4 | \sim 5 | \perp 6 | \preceq 7 | \succeq 8 | \simeq 9 | \mid 10 | \asymp 11 | \parallel 12 | \bowtie 13 | \smile 14 | \sqsubseteq 15 | \sqsupseteq 16 | \doteq 17 | \frown 18 | \propto 19 | \vdash 20 | \dashv 21 | \owns 22 | \ldotp 23 | \cdotp 24 | \aleph 25 | \forall 26 | \hbar 27 | \exists 28 | \nabla 29 | \flat 30 | \ell 31 | \natural 32 | \clubsuit 33 | \sharp 34 | \diamondsuit 35 | \heartsuit 36 | \spadesuit 37 | \dag 38 | \ddag 39 | \rmoustache 40 | \lmoustache 41 | \rgroup 42 | \lgroup 43 | \ominus 44 | \uplus 45 | \sqcap 46 | \ast 47 | \sqcup 48 | \bigcirc 49 | \bullet 50 | \ddagger 51 | \amalg 52 | \And 53 | \longleftarrow 54 | \Leftarrow 55 | \Longleftarrow 56 | \longrightarrow 57 | \Rightarrow 58 | \Longrightarrow 59 | \leftrightarrow 60 | \longleftrightarrow 61 | 
\Leftrightarrow 62 | \Longleftrightarrow 63 | \mapsto 64 | \longmapsto 65 | \nearrow 66 | \hookleftarrow 67 | \hookrightarrow 68 | \searrow 69 | \leftharpoonup 70 | \rightharpoonup 71 | \swarrow 72 | \leftharpoondown 73 | \rightharpoondown 74 | \nwarrow 75 | \rightleftharpoons 76 | \nless 77 | \lneq 78 | \lneqq 79 | \lnsim 80 | \lnapprox 81 | \nprec 82 | \npreceq 83 | \precnsim 84 | \precnapprox 85 | \nsim 86 | \nmid 87 | \nvdash 88 | \nvDash 89 | \ntriangleleft 90 | \ntrianglelefteq 91 | \subsetneq 92 | \subsetneqq 93 | \ngtr 94 | \gneq 95 | \gneqq 96 | \gnsim 97 | \gnapprox 98 | \nsucc 99 | \nsucceq 100 | \succnsim 101 | \succnapprox 102 | \ncong 103 | \nparallel 104 | \nVDash 105 | \ntriangleright 106 | \ntrianglerighteq 107 | \supsetneq 108 | \supsetneqq 109 | \nVdash 110 | \precneqq 111 | \succneqq 112 | \unlhd 113 | \unrhd 114 | \nleftarrow 115 | \nrightarrow 116 | \nLeftarrow 117 | \nRightarrow 118 | \nleftrightarrow 119 | \nLeftrightarrow 120 | \vartriangle 121 | \hslash 122 | \triangledown 123 | \lozenge 124 | \circledS 125 | \circledR 126 | \measuredangle 127 | \nexists 128 | \mho 129 | \Finv 130 | \Game 131 | \backprime 132 | \blacktriangle 133 | \blacktriangledown 134 | \blacksquare 135 | \blacklozenge 136 | \bigstar 137 | \sphericalangle 138 | \complement 139 | \eth 140 | \diagup 141 | \diagdown 142 | \square 143 | \Box 144 | \Diamond 145 | \yen 146 | \checkmark 147 | \beth 148 | \daleth 149 | \gimel 150 | \digamma 151 | \varkappa 152 | \ulcorner 153 | \urcorner 154 | \llcorner 155 | \lrcorner 156 | \leqq 157 | \leqslant 158 | \eqslantless 159 | \lesssim 160 | \lessapprox 161 | \approxeq 162 | \lessdot 163 | \lll 164 | \lessgtr 165 | \lesseqgtr 166 | \lesseqqgtr 167 | \doteqdot 168 | \risingdotseq 169 | \fallingdotseq 170 | \backsim 171 | \backsimeq 172 | \subseteqq 173 | \Subset 174 | \sqsubset 175 | \preccurlyeq 176 | \curlyeqprec 177 | \precsim 178 | \precapprox 179 | \vartriangleleft 180 | \trianglelefteq 181 | \vDash 182 | \Vvdash 183 | 
\smallsmile 184 | \smallfrown 185 | \bumpeq 186 | \Bumpeq 187 | \geqq 188 | \geqslant 189 | \eqslantgtr 190 | \gtrsim 191 | \gtrapprox 192 | \gtrdot 193 | \ggg 194 | \gtrless 195 | \gtreqless 196 | \gtreqqless 197 | \eqcirc 198 | \circeq 199 | \triangleq 200 | \thicksim 201 | \thickapprox 202 | \supseteqq 203 | \Supset 204 | \sqsupset 205 | \succcurlyeq 206 | \curlyeqsucc 207 | \succsim 208 | \succapprox 209 | \vartriangleright 210 | \trianglerighteq 211 | \Vdash 212 | \shortmid 213 | \shortparallel 214 | \between 215 | \pitchfork 216 | \varpropto 217 | \blacktriangleleft 218 | \therefore 219 | \backepsilon 220 | \blacktriangleright 221 | \because 222 | \llless 223 | \gggtr 224 | \lhd 225 | \rhd 226 | \eqsim 227 | \Join 228 | \Doteq 229 | \dotplus 230 | \smallsetminus 231 | \Cap 232 | \Cup 233 | \doublebarwedge 234 | \boxminus 235 | \boxplus 236 | \divideontimes 237 | \ltimes 238 | \rtimes 239 | \leftthreetimes 240 | \rightthreetimes 241 | \curlywedge 242 | \curlyvee 243 | \circleddash 244 | \circledast 245 | \centerdot 246 | \intercal 247 | \doublecap 248 | \doublecup 249 | \boxtimes 250 | \dashrightarrow 251 | \dashleftarrow 252 | \leftleftarrows 253 | \leftrightarrows 254 | \Lleftarrow 255 | \twoheadleftarrow 256 | \leftarrowtail 257 | \looparrowleft 258 | \leftrightharpoons 259 | \curvearrowleft 260 | \circlearrowleft 261 | \Lsh 262 | \upuparrows 263 | \upharpoonleft 264 | \downharpoonleft 265 | \multimap 266 | \leftrightsquigarrow 267 | \rightrightarrows 268 | \rightleftarrows 269 | \twoheadrightarrow 270 | \rightarrowtail 271 | \looparrowright 272 | \curvearrowright 273 | \circlearrowright 274 | \Rsh 275 | \downdownarrows 276 | \upharpoonright 277 | \downharpoonright 278 | \rightsquigarrow 279 | \leadsto 280 | \Rrightarrow 281 | \restriction 282 | \angle 283 | \infty 284 | \prime 285 | \triangle 286 | \Gamma 287 | \Delta 288 | \Theta 289 | \Lambda 290 | \Sigma 291 | \Upsilon 292 | \Phi 293 | \Psi 294 | \Omega 295 | \neg 296 | \lnot 297 | \top 298 | \bot 299 | 
\emptyset 300 | \varnothing 301 | \alpha 302 | \beta 303 | \gamma 304 | \delta 305 | \epsilon 306 | \zeta 307 | \eta 308 | \theta 309 | \iota 310 | \kappa 311 | \lambda 312 | \omicron 313 | \rho 314 | \sigma 315 | \tau 316 | \upsilon 317 | \phi 318 | \chi 319 | \psi 320 | \omega 321 | \varepsilon 322 | \vartheta 323 | \varpi 324 | \varrho 325 | \varsigma 326 | \varphi 327 | \cdot 328 | \circ 329 | \div 330 | \times 331 | \cap 332 | \cup 333 | \setminus 334 | \land 335 | \lor 336 | \wedge 337 | \vee 338 | \surd 339 | \langle 340 | \lvert 341 | \lVert 342 | \rangle 343 | \rvert 344 | \rVert 345 | \approx 346 | \cong 347 | \geq 348 | \gets 349 | \subset 350 | \supset 351 | \subseteq 352 | \supseteq 353 | \nsubseteq 354 | \nsupseteq 355 | \models 356 | \leftarrow 357 | \leq 358 | \rightarrow 359 | \ngeq 360 | \nleq 361 | \space 362 | \nobreakspace 363 | \nobreak 364 | \allowbreak 365 | \barwedge 366 | \veebar 367 | \odot 368 | \oplus 369 | \otimes 370 | \partial 371 | \oslash 372 | \circledcirc 373 | \boxdot 374 | \bigtriangleup 375 | \bigtriangledown 376 | \dagger 377 | \diamond 378 | \star 379 | \triangleleft 380 | \triangleright 381 | \lbrace 382 | \rbrace 383 | \lbrack 384 | \rbrack 385 | \lparen 386 | \rparen 387 | \lfloor 388 | \rfloor 389 | \lceil 390 | \rceil 391 | \backslash 392 | \vert 393 | \Vert 394 | \uparrow 395 | \Uparrow 396 | \downarrow 397 | \Downarrow 398 | \updownarrow 399 | \Updownarrow 400 | \coprod 401 | \bigvee 402 | \bigwedge 403 | \biguplus 404 | \bigcap 405 | \bigcup 406 | \int 407 | \intop 408 | \iint 409 | \iiint 410 | \prod 411 | \sum 412 | \bigotimes 413 | \bigoplus 414 | \bigodot 415 | \oint 416 | \oiint 417 | \oiiint 418 | \bigsqcup 419 | \smallint 420 | \mathellipsis 421 | \ldots 422 | \ddots 423 | \varvdots 424 | \acute 425 | \grave 426 | \ddot 427 | \tilde 428 | \bar 429 | \breve 430 | \check 431 | \hat 432 | \vec 433 | \dot 434 | \mathring 435 | \imath 436 | \jmath 437 | \degree 438 | \pounds 439 | \mathsterling 440 | \maltese 441 | 
\textdagger 442 | \textdaggerdbl 443 | \textdollar 444 | \textunderscore 445 | \textbraceleft 446 | \textbraceright 447 | \textless 448 | \textgreater 449 | \textbar 450 | \textbardbl 451 | \textasciitilde 452 | \textbackslash 453 | \textasciicircum 454 | \textellipsis 455 | \textcircled 456 | \textendash 457 | \textemdash 458 | \textquoteleft 459 | \textquoteright 460 | \textquotedblleft 461 | \textquotedblright 462 | \textdegree 463 | \textsterling 464 | \widecheck 465 | \widehat 466 | \widetilde 467 | \overrightarrow 468 | \overleftarrow 469 | \Overrightarrow 470 | \overleftrightarrow 471 | \overgroup 472 | \overlinesegment 473 | \overleftharpoon 474 | \overrightharpoon 475 | \underleftarrow 476 | \underrightarrow 477 | \underleftrightarrow 478 | \undergroup 479 | \underlinesegment 480 | \utilde 481 | \xleftarrow 482 | \xrightarrow 483 | \xLeftarrow 484 | \xRightarrow 485 | \xleftrightarrow 486 | \xLeftrightarrow 487 | \xhookleftarrow 488 | \xhookrightarrow 489 | \xmapsto 490 | \xrightharpoondown 491 | \xrightharpoonup 492 | \xleftharpoondown 493 | \xleftharpoonup 494 | \xrightleftharpoons 495 | \xleftrightharpoons 496 | \xlongequal 497 | \xtwoheadrightarrow 498 | \xtwoheadleftarrow 499 | \xtofrom 500 | \xrightleftarrows 501 | \xrightequilibrium 502 | \xleftequilibrium 503 | \textcolor 504 | \color 505 | \newline 506 | \bigl 507 | \Bigl 508 | \biggl 509 | \Biggl 510 | \bigr 511 | \Bigr 512 | \biggr 513 | \Biggr 514 | \bigm 515 | \Bigm 516 | \biggm 517 | \Biggm 518 | \big 519 | \Big 520 | \bigg 521 | \Bigg 522 | \right 523 | \left 524 | \middle 525 | \colorbox 526 | \fcolorbox 527 | \fbox 528 | \cancel 529 | \bcancel 530 | \xcancel 531 | \sout 532 | \hline 533 | \hdashline 534 | \begin 535 | \end 536 | \mathord 537 | \mathbin 538 | \mathrel 539 | \mathopen 540 | \mathclose 541 | \mathpunct 542 | \mathinner 543 | \stackrel 544 | \overset 545 | \underset 546 | \mathrm 547 | \mathit 548 | \mathbf 549 | \mathnormal 550 | \mathbb 551 | \mathcal 552 | \mathfrak 553 | 
\mathscr 554 | \mathsf 555 | \mathtt 556 | \Bbb 557 | \bold 558 | \frak 559 | \boldsymbol 560 | \cfrac 561 | \dfrac 562 | \frac 563 | \tfrac 564 | \dbinom 565 | \binom 566 | \tbinom 567 | \over 568 | \choose 569 | \atop 570 | \brace 571 | \brack 572 | \genfrac 573 | \above 574 | \overbrace 575 | \underbrace 576 | \href 577 | \url 578 | \includegraphics 579 | \kern 580 | \mkern 581 | \hskip 582 | \mskip 583 | \mathllap 584 | \mathrlap 585 | \mathclap 586 | \mathchoice 587 | \mathop 588 | \arcsin 589 | \arccos 590 | \arctan 591 | \arctg 592 | \arcctg 593 | \arg 594 | \cos 595 | \cosec 596 | \cosh 597 | \cot 598 | \cotg 599 | \coth 600 | \csc 601 | \ctg 602 | \cth 603 | \deg 604 | \dim 605 | \exp 606 | \hom 607 | \ker 608 | \log 609 | \sec 610 | \sin 611 | \sinh 612 | \tan 613 | \tanh 614 | \det 615 | \gcd 616 | \inf 617 | \lim 618 | \max 619 | \min 620 | \sup 621 | \operatorname 622 | \overline 623 | \phantom 624 | \hphantom 625 | \vphantom 626 | \raisebox 627 | \rule 628 | \tiny 629 | \sixptsize 630 | \scriptsize 631 | \footnotesize 632 | \small 633 | \normalsize 634 | \large 635 | \Large 636 | \LARGE 637 | \huge 638 | \Huge 639 | \smash 640 | \sqrt 641 | \displaystyle 642 | \textstyle 643 | \scriptstyle 644 | \scriptscriptstyle 645 | \text 646 | \textrm 647 | \textsf 648 | \texttt 649 | \textnormal 650 | \textbf 651 | \textmd 652 | \textit 653 | \textup 654 | \underline 655 | \verb 656 | \TextOrMath 657 | \char 658 | \gdef 659 | \def 660 | \global 661 | \newcommand 662 | \renewcommand 663 | \providecommand 664 | \bgroup 665 | \egroup 666 | \textcopyright 667 | \copyright 668 | \textregistered 669 | \Bbbk 670 | \llap 671 | \rlap 672 | \clap 673 | \not 674 | \neq 675 | \notin 676 | \vdots 677 | \varGamma 678 | \varDelta 679 | \varTheta 680 | \varLambda 681 | \varXi 682 | \varPi 683 | \varSigma 684 | \varUpsilon 685 | \varPhi 686 | \varPsi 687 | \varOmega 688 | \substack 689 | \colon 690 | \boxed 691 | \iff 692 | \implies 693 | \impliedby 694 | \dots 695 | \dotso 696 
| \dotsc 697 | \cdots 698 | \dotsb 699 | \dotsm 700 | \dotsi 701 | \dotsx 702 | \DOTSI 703 | \DOTSB 704 | \DOTSX 705 | \tmspace 706 | \thinspace 707 | \medspace 708 | \thickspace 709 | \negthinspace 710 | \negmedspace 711 | \negthickspace 712 | \enspace 713 | \enskip 714 | \quad 715 | \qquad 716 | \tag 717 | \bmod 718 | \pod 719 | \pmod 720 | \mod 721 | \pmb 722 | \TeX 723 | \LaTeX 724 | \KaTeX 725 | \hspace 726 | \ordinarycolon 727 | \vcentcolon 728 | \dblcolon 729 | \coloneqq 730 | \Coloneqq 731 | \coloneq 732 | \Coloneq 733 | \eqqcolon 734 | \Eqqcolon 735 | \eqcolon 736 | \Eqcolon 737 | \colonapprox 738 | \Colonapprox 739 | \colonsim 740 | \Colonsim 741 | \ratio 742 | \coloncolon 743 | \colonequals 744 | \coloncolonequals 745 | \equalscolon 746 | \equalscoloncolon 747 | \colonminus 748 | \coloncolonminus 749 | \minuscolon 750 | \minuscoloncolon 751 | \coloncolonapprox 752 | \coloncolonsim 753 | \simcolon 754 | \simcoloncolon 755 | \approxcolon 756 | \approxcoloncolon 757 | \notni 758 | \limsup 759 | \liminf 760 | \gvertneqq 761 | \lvertneqq 762 | \ngeqq 763 | \ngeqslant 764 | \nleqq 765 | \nleqslant 766 | \nshortmid 767 | \nshortparallel 768 | \nsubseteqq 769 | \nsupseteqq 770 | \varsubsetneq 771 | \varsubsetneqq 772 | \varsupsetneq 773 | \varsupsetneqq 774 | \llbracket 775 | \rrbracket 776 | \lBrace 777 | \rBrace 778 | \darr 779 | \dArr 780 | \Darr 781 | \lang 782 | \rang 783 | \uarr 784 | \uArr 785 | \Uarr 786 | \alef 787 | \alefsym 788 | \Alpha 789 | \Beta 790 | \bull 791 | \Chi 792 | \clubs 793 | \cnums 794 | \Complex 795 | \Dagger 796 | \diamonds 797 | \empty 798 | \Epsilon 799 | \Eta 800 | \exist 801 | \harr 802 | \hArr 803 | \Harr 804 | \hearts 805 | \image 806 | \infin 807 | \Iota 808 | \isin 809 | \Kappa 810 | \larr 811 | \lArr 812 | \Larr 813 | \lrarr 814 | \lrArr 815 | \Lrarr 816 | \natnums 817 | \Omicron 818 | \plusmn 819 | \rarr 820 | \rArr 821 | \Rarr 822 | \real 823 | \reals 824 | \Reals 825 | \Rho 826 | \sdot 827 | \sect 828 | \spades 829 | \sub 
830 | \sube 831 | \supe 832 | \Tau 833 | \thetasym 834 | \weierp 835 | \Zeta 836 | \argmin 837 | \argmax 838 | \plim 839 | \blue 840 | \orange 841 | \pink 842 | \red 843 | \green 844 | \gray 845 | \purple 846 | \blueA 847 | \blueB 848 | \blueC 849 | \blueD 850 | \blueE 851 | \tealA 852 | \tealB 853 | \tealC 854 | \tealD 855 | \tealE 856 | \greenA 857 | \greenB 858 | \greenC 859 | \greenD 860 | \greenE 861 | \goldA 862 | \goldB 863 | \goldC 864 | \goldD 865 | \goldE 866 | \redA 867 | \redB 868 | \redC 869 | \redD 870 | \redE 871 | \maroonA 872 | \maroonB 873 | \maroonC 874 | \maroonD 875 | \maroonE 876 | \purpleA 877 | \purpleB 878 | \purpleC 879 | \purpleD 880 | \purpleE 881 | \mintA 882 | \mintB 883 | \mintC 884 | \grayA 885 | \grayB 886 | \grayC 887 | \grayD 888 | \grayE 889 | \grayF 890 | \grayG 891 | \grayH 892 | \grayI 893 | \kaBlue 894 | \kaGreen 895 | \smear 896 | \pxdy 897 | \jump 898 | \enter 899 | \endmatrix 900 | \beginmatrix 901 | \endcases 902 | \begincases 903 | \unk 904 | \pxsbx 905 | \textcelsius --------------------------------------------------------------------------------