├── .gitattributes ├── README.md ├── data ├── .DS_Store └── model │ ├── .DS_Store │ ├── encoder.pkl │ └── model.pkl ├── scrips ├── .DS_Store ├── evaluation │ ├── .DS_Store │ ├── Cal_B4.py │ ├── Cal_R4.py │ ├── CalculateMatch-WS.py │ └── CalculateMatch.py └── preprocessing │ ├── .DS_Store │ ├── preprocess_test.py │ ├── preprocess_train.py │ └── preprocess_val.py └── src ├── .DS_Store ├── __init__.py ├── model ├── __init__.py └── transformers.py ├── test.py ├── train.py └── utils ├── .DS_Store ├── ImgCandidate.py └── __init__.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EDSL 2 | EDSL code 3 | 4 | ## Prerequisites 5 | 1. PyTorch 6 | 2. NumPy, Pillow 7 | 3. pdflatex 8 | 4. ImageMagick 9 | 10 | ## Dataset Download 11 | ME-98K dataset: https://pan.baidu.com/s/1itQEjPUMve3A3dXezJDedw password: axqh 12 | 13 | ME-20K dataset: https://pan.baidu.com/s/1ti1xfCdj6c36yvy_sHld5Q password: nwks 14 | 15 | 16 | ## Quick Start 17 | ### 1. preprocess 18 | 19 | Preprocessing of the training set. 20 |
21 | `python scrips/preprocessing/preprocess_train.py --formulas data/sample/formulas.txt --train data/sample/train.txt --vocab data/sample/latex_vocab.txt --img data/sample/image_processed/` 22 | 23 | Preprocessing of the validation set. 24 |
25 | `python scrips/preprocessing/preprocess_val.py --formulas data/sample/formulas.txt --val data/sample/val.txt --vocab data/sample/latex_vocab.txt --img data/sample/image_processed/` 26 | 27 | Preprocessing of the test set. 28 | 29 | `python scrips/preprocessing/preprocess_test.py --formulas data/sample/formulas.txt --test data/sample/test.txt --vocab data/sample/latex_vocab.txt --img data/sample/image_processed/`
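The preprocessing scripts save their outputs as `.npy` files under `data/preprocess_data/` (`trainData.npy`, `trainPosition.npy`, `trainIndexList.npy` and `trainImgList.npy` for the training set; `valData_160.npy`/`valPosition_160.npy` and `testData_160.npy`/`testPosition_160.npy` for the validation and test sets). `src/train.py` and `src/test.py` load them from the same location. Note that `np.save` does not create directories, so create the folder first if it does not exist yet, e.g. `mkdir -p data/preprocess_data`.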
30 | 31 | ### 2. training the model 32 | `python src/train.py --formulas data/sample/formulas.txt --train data/sample/train.txt --val data/sample/val.txt --vocab data/sample/latex_vocab.txt` 33 | 34 | 35 | ### 3. testing 36 | `python src/test.py --formulas data/sample/formulas.txt --test data/sample/test.txt --vocab data/sample/latex_vocab.txt` 37 | 38 | 39 | ### 4. evaluation 40 | BLEU-4 Calculation: 41 |
42 | `python scrips/evaluation/Cal_B4.py --formulas data/sample/formulas.txt` 43 | 44 | ROUGE-4 Calculation: 45 |
46 | `python scrips/evaluation/Cal_R4.py --formulas data/sample/formulas.txt` 47 | 48 | Match Calculation: 49 |
50 | `python scrips/evaluation/CalculateMatch.py --formulas data/sample/formulas.txt` 51 | 52 | Match-ws Calculation: 53 | `python scrips/evaluation/CalculateMatch-WS.py --formulas data/sample/formulas.txt` 55 |
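All four evaluation scripts compare `data/result/predict.txt` (written by `src/test.py`) against the ground-truth `formulas.txt`. Both files contain one sample per line in the form `index<TAB>formula`, where the formula is a sequence of space-separated LaTeX tokens, for example `1023<TAB>\frac { a } { b }` (the index and formula here are only illustrative). The two Match scripts additionally require `pdflatex` and ImageMagick's `convert` on the PATH, since they render the predicted and ground-truth formulas and compare the resulting images.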
-------------------------------------------------------------------------------- /data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/.DS_Store -------------------------------------------------------------------------------- /data/model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/model/.DS_Store -------------------------------------------------------------------------------- /data/model/encoder.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/model/encoder.pkl -------------------------------------------------------------------------------- /data/model/model.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/model/model.pkl -------------------------------------------------------------------------------- /scrips/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/scrips/.DS_Store -------------------------------------------------------------------------------- /scrips/evaluation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/scrips/evaluation/.DS_Store -------------------------------------------------------------------------------- /scrips/evaluation/Cal_B4.py: -------------------------------------------------------------------------------- 1 | from nltk.translate.bleu_score import sentence_bleu 2 | from tqdm import tqdm 3 | import argparse 4 | 5 | import sys 6 | import os 7 | 8 | root_path = os.path.abspath(__file__) 9 | root_path = '/'.join(root_path.split('/')[:-3]) 10 | sys.path.append(root_path) 11 | 12 | 13 | def formula_to_index(formula, dic): 14 | tmp = [] 15 | for word in formula.strip().split(' '): 16 | if len(word.strip()) > 0: 17 | if word.strip() in dic: 18 | tmp.append(dic[word.strip()]) 19 | else: 20 | tmp.append(dic['UNK']) 21 | return tmp 22 | 23 | 24 | def process_args(): 25 | parser = argparse.ArgumentParser(description='Get parameters') 26 | 27 | parser.add_argument('--formulas', dest='formulas_file_path', 28 | type=str, required=True, 29 | help='Input formulas.txt path') 30 | 31 | parameters = parser.parse_args() 32 | return parameters 33 | 34 | 35 | if __name__ == '__main__': 36 | 37 | parameters = process_args() 38 | 39 | f = open(parameters.formulas_file_path, 40 | encoding='utf-8').readlines() 41 | 42 | labelIndexDic = {} 43 | for item_f in f: 44 | labelIndexDic[item_f.strip().split('\t')[0].strip()] = item_f.strip().split('\t')[1] \ 45 | .strip() 46 | 47 | predictDic = {} 48 | f2 = open(root_path + '/data/result/predict.txt', encoding='utf-8').readlines() 49 | 50 | for item_f2 in f2: 51 | index = item_f2.strip().split('\t')[0] 52 | formula = item_f2.strip().split('\t')[1] 53 | predictDic[index] = formula 54 | 55 | bleuList = [] 56 | 57 | for item_p in 
tqdm(predictDic): 58 | predict = predictDic[item_p].strip().split(' ') 59 | label = labelIndexDic[item_p].strip().split(' ') 60 | 61 | if len(label) >= 4: 62 | if len(predict) < 4: 63 | bleuList.append(0) 64 | else: 65 | tmpBleu1 = sentence_bleu([label], predict, weights=(0, 0, 0, 1)) 66 | bleuList.append(tmpBleu1) 67 | 68 | print("BLEU-4:") 69 | print(sum(bleuList) / len(bleuList)) 70 | -------------------------------------------------------------------------------- /scrips/evaluation/Cal_R4.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | import argparse 4 | 5 | import sys 6 | import os 7 | 8 | root_path = os.path.abspath(__file__) 9 | root_path = '/'.join(root_path.split('/')[:-3]) 10 | sys.path.append(root_path) 11 | 12 | 13 | def formula_to_index(formula, dic): 14 | tmp = [] 15 | for word in formula.strip().split(' '): 16 | if len(word.strip()) > 0: 17 | if word.strip() in dic: 18 | tmp.append(dic[word.strip()]) 19 | else: 20 | tmp.append(dic['UNK']) 21 | return tmp 22 | 23 | 24 | def getRouge_N(predict, label, n): 25 | predictCanList = [] 26 | labelCanList = [] 27 | 28 | for i in range(len(predict)): 29 | tmp = predict[i:i + n] 30 | if len(tmp) == n: 31 | predictCanList.append(tmp) 32 | 33 | for i in range(len(label)): 34 | tmp = label[i:i + n] 35 | if len(tmp) == n: 36 | labelCanList.append(tmp) 37 | 38 | len_labenCanList = len(labelCanList) 39 | 40 | if len_labenCanList == 0: 41 | return None 42 | else: 43 | countList = [] 44 | 45 | while len(predictCanList) > 0: 46 | try: 47 | index = labelCanList.index(predictCanList[0]) 48 | countList.append(predictCanList[0]) 49 | del labelCanList[index] 50 | except: 51 | pass 52 | 53 | del predictCanList[0] 54 | 55 | rouge_n = len(countList) / len_labenCanList 56 | return rouge_n 57 | 58 | 59 | def process_args(): 60 | parser = argparse.ArgumentParser(description='Get parameters') 61 | 62 | parser.add_argument('--formulas', dest='formulas_file_path', 63 | type=str, required=True, 64 | help='Input formulas.txt path') 65 | 66 | parameters = parser.parse_args() 67 | return parameters 68 | 69 | 70 | if __name__ == '__main__': 71 | 72 | parameters = process_args() 73 | 74 | f = open(parameters.formulas_file_path, 75 | encoding='utf-8').readlines() 76 | 77 | labelIndexDic = {} 78 | for item_f in f: 79 | labelIndexDic[item_f.strip().split('\t')[0].strip()] = item_f.strip().split('\t')[1] \ 80 | .strip() 81 | 82 | predictDic = {} 83 | f2 = open(root_path + '/data/result/predict.txt', encoding='utf-8').readlines() 84 | 85 | for item_f2 in f2: 86 | index = item_f2.strip().split('\t')[0] 87 | formula = item_f2.strip().split('\t')[1] 88 | predictDic[index] = formula 89 | 90 | roughList = [] 91 | 92 | for item_p in tqdm(predictDic): 93 | predict = predictDic[item_p].strip().split(' ') 94 | label = labelIndexDic[item_p].strip().split(' ') 95 | 96 | rougeN = getRouge_N(predict, label, 4) 97 | 98 | if rougeN == None: 99 | pass 100 | else: 101 | roughList.append(rougeN) 102 | 103 | print("ROUGE-4:") 104 | print(sum(roughList) / len(roughList)) 105 | -------------------------------------------------------------------------------- /scrips/evaluation/CalculateMatch-WS.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import argparse 6 | import sys 7 | import os 8 | 9 | sys.path.append('/'.join(os.path.abspath(__file__).split('/')[:-3])) 10 | from src.utils import ImgCandidate 11 | root_path = os.path.abspath(__file__) 12 | root_path = 
'/'.join(root_path.split('/')[:-3]) 13 | sys.path.append(root_path) 14 | 15 | 16 | def process_args(): 17 | parser = argparse.ArgumentParser(description='Get parameters') 18 | 19 | parser.add_argument('--formulas', dest='formulas_file_path', 20 | type=str, required=True, 21 | help='Input formulas.txt path') 22 | 23 | parameters = parser.parse_args() 24 | return parameters 25 | 26 | 27 | if __name__ == '__main__': 28 | 29 | parameters = process_args() 30 | 31 | f = open(root_path + '/data/result/predict.txt').readlines() 32 | f2 = open(parameters.formulas_file_path, encoding='utf-8').readlines() 33 | 34 | formulaDic = {} 35 | for item_f2 in f2: 36 | formulaDic[item_f2.strip().split('\t')[0]] = item_f2.strip().split('\t')[1] 37 | 38 | accList = [] 39 | 40 | for item_f in tqdm(f): 41 | index = item_f.strip().split('\t')[0] 42 | formula = item_f.strip().split('\t')[1] 43 | labelFormula = formulaDic[index].strip() 44 | 45 | if formula == labelFormula: 46 | accList.append(1) 47 | else: 48 | 49 | pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + '\pagestyle{empty}' + '\n' + \ 50 | r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + formula + \ 51 | r'\end{equation*}' + '\n' + '\end{document}' 52 | f3 = open('predict.tex', mode='w') 53 | f3.write(pdfText) 54 | f3.close() 55 | sub = subprocess.Popen("pdflatex -halt-on-error " + "predict.tex", shell=True, stdout=subprocess.PIPE) 56 | sub.wait() 57 | 58 | pdfFiles = [] 59 | for _, _, pf in os.walk(os.getcwd()): 60 | pdfFiles = pf 61 | break 62 | 63 | if 'predict.pdf' in pdfFiles: 64 | try: 65 | pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + '\pagestyle{empty}' + '\n' + \ 66 | r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + labelFormula + \ 67 | r'\end{equation*}' + '\n' + '\end{document}' 68 | f3 = open('label.tex', mode='w') 69 | f3.write(pdfText) 70 | f3.close() 71 | sub = subprocess.Popen("pdflatex -halt-on-error " + "label.tex", shell=True, stdout=subprocess.PIPE) 72 | sub.wait() 73 | 74 | os.system( 75 | 'convert -background white -density 200 -quality 100 -strip ' + 'label.pdf ' + 'label.png') 76 | os.system( 77 | 'convert -background white -density 200 -quality 100 -strip ' + 'predict.pdf ' + 'predict.png') 78 | label = ImgCandidate.deleSpace( 79 | ImgCandidate.deletePadding(np.array(Image.open('label.png').convert('L')))).tolist() 80 | predict = ImgCandidate.deleSpace( 81 | ImgCandidate.deletePadding(np.array(Image.open('predict.png').convert('L')))).tolist() 82 | 83 | if label == predict: 84 | accList.append(1) 85 | else: 86 | accList.append(0) 87 | except: 88 | accList.append(0) 89 | 90 | else: 91 | accList.append(0) 92 | print(sum(accList) / len(accList)) 93 | 94 | os.system('rm -rf *.aux') 95 | os.system('rm -rf *.log') 96 | os.system('rm -rf *.tex') 97 | os.system('rm -rf *.pdf') 98 | os.system('rm -rf *.png') 99 | 100 | print(sum(accList) / len(accList)) 101 | -------------------------------------------------------------------------------- /scrips/evaluation/CalculateMatch.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from PIL import Image 3 | from tqdm import tqdm 4 | 5 | import argparse 6 | 7 | import sys 8 | import os 9 | 10 | root_path = os.path.abspath(__file__) 11 | root_path = '/'.join(root_path.split('/')[:-3]) 12 | sys.path.append(root_path) 13 | 14 | 15 | def process_args(): 16 | parser = 
argparse.ArgumentParser(description='Get parameters') 17 | 18 | parser.add_argument('--formulas', dest='formulas_file_path', 19 | type=str, required=True, 20 | help='Input formulas.txt path') 21 | 22 | parameters = parser.parse_args() 23 | return parameters 24 | 25 | 26 | if __name__ == '__main__': 27 | parameters = process_args() 28 | 29 | f = open(root_path + '/data/result/predict.txt').readlines() 30 | f2 = open(parameters.formulas_file_path, encoding='utf-8').readlines() 31 | 32 | formulaDic = {} 33 | for item_f2 in f2: 34 | formulaDic[item_f2.strip().split('\t')[0]] = item_f2.strip().split('\t')[1] 35 | 36 | accList = [] 37 | 38 | for item_f in tqdm(f): 39 | index = item_f.strip().split('\t')[0] 40 | formula = item_f.strip().split('\t')[1] 41 | labelFormula = formulaDic[index].strip() 42 | 43 | if formula == labelFormula: 44 | accList.append(1) 45 | else: 46 | pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + '\pagestyle{empty}' + '\n' + \ 47 | r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + formula + \ 48 | r'\end{equation*}' + '\n' + '\end{document}' 49 | f3 = open('predict.tex', mode='w') 50 | f3.write(pdfText) 51 | f3.close() 52 | sub = subprocess.Popen("pdflatex -halt-on-error " + "predict.tex", shell=True, stdout=subprocess.PIPE) 53 | sub.wait() 54 | 55 | pdfFiles = [] 56 | for _, _, pf in os.walk(os.getcwd()): 57 | pdfFiles = pf 58 | break 59 | 60 | if 'predict.pdf' in pdfFiles: 61 | try: 62 | pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + '\pagestyle{empty}' + '\n' + \ 63 | r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + labelFormula + \ 64 | r'\end{equation*}' + '\n' + '\end{document}' 65 | f3 = open('label.tex', mode='w') 66 | f3.write(pdfText) 67 | f3.close() 68 | sub = subprocess.Popen("pdflatex -halt-on-error " + "label.tex", shell=True, stdout=subprocess.PIPE) 69 | sub.wait() 70 | 71 | os.system('convert -strip ' + 'label.pdf ' + 'label.png') 72 | os.system('convert -strip ' + 'predict.pdf ' + 'predict.png') 73 | label = Image.open('label.png').convert('L') 74 | 75 | predict = Image.open('predict.png').convert('L') 76 | 77 | if label == predict: 78 | accList.append(1) 79 | else: 80 | accList.append(0) 81 | except: 82 | accList.append(0) 83 | 84 | else: 85 | accList.append(0) 86 | print(sum(accList) / len(accList)) 87 | 88 | os.system('rm -rf *.aux') 89 | os.system('rm -rf *.log') 90 | os.system('rm -rf *.tex') 91 | os.system('rm -rf *.pdf') 92 | os.system('rm -rf *.png') 93 | 94 | print('Match:') 95 | print(sum(accList) / len(accList)) 96 | -------------------------------------------------------------------------------- /scrips/preprocessing/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/scrips/preprocessing/.DS_Store -------------------------------------------------------------------------------- /scrips/preprocessing/preprocess_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import warnings 4 | from PIL import Image 5 | 6 | import argparse 7 | import sys 8 | import os 9 | 10 | root_path = os.path.abspath(__file__) 11 | root_path = '/'.join(root_path.split('/')[:-3]) 12 | sys.path.append(root_path) 13 | 14 | from src.utils import ImgCandidate 15 | 16 | 17 | def process_args(): 
18 | parser = argparse.ArgumentParser(description='Get parameters') 19 | 20 | parser.add_argument('--formulas', dest='formulas_file_path', 21 | type=str, required=True, 22 | help='Input formulas.txt path') 23 | 24 | parser.add_argument('--test', dest='test_file_path', 25 | type=str, required=True, 26 | help='Input test.txt path') 27 | 28 | parser.add_argument('--vocab', dest='vocab_file_path', 29 | type=str, required=True, 30 | help='Input latex_vocab.txt path') 31 | 32 | parser.add_argument('--img', dest='img_path', 33 | type=str, required=True, 34 | help='Input image path') 35 | 36 | parameters = parser.parse_args() 37 | return parameters 38 | 39 | 40 | def get_position_vec(positionList): 41 | xMax = max([item[1] for item in positionList]) 42 | yMax = max([item[3] for item in positionList]) 43 | 44 | finalPosition = [] 45 | for item_g in positionList: 46 | x_1 = item_g[0] 47 | x_2 = item_g[1] 48 | y_1 = item_g[2] 49 | y_2 = item_g[3] 50 | 51 | tmp = [x_1 / xMax, x_2 / xMax, y_1 / yMax, y_2 / yMax, xMax / yMax] 52 | finalPosition.append(tmp) 53 | return finalPosition 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | parameters = process_args() 59 | 60 | warnings.filterwarnings('ignore') 61 | 62 | f = open(parameters.formulas_file_path, encoding='utf-8').readlines() 63 | labelIndexDic = {} 64 | for item_f in f: 65 | labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1] 66 | 67 | f3 = open(parameters.test_file_path, 68 | encoding='utf-8').readlines() 69 | 70 | testLabelList = [] 71 | for item_f3 in f3: 72 | if len(item_f3) > 0: 73 | testLabelList.append(labelIndexDic[item_f3.strip()]) 74 | 75 | MAXLENGTH = 150 76 | 77 | f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines() 78 | 79 | PAD = 0 80 | START = 1 81 | END = 2 82 | 83 | index_label_dic = {} 84 | label_index_dic = {} 85 | 86 | i = 3 87 | for item_f5 in f5: 88 | word = item_f5.strip() 89 | if len(word) > 0: 90 | label_index_dic[word] = i 91 | index_label_dic[i] = word 92 | i += 1 93 | label_index_dic['unk'] = i 94 | index_label_dic[i] = 'unk' 95 | i += 1 96 | 97 | labelEmbed_teaching_test = [] 98 | labelEmbed_predict_test = [] 99 | 100 | for item_l in testLabelList: 101 | tmp = [1] 102 | words = item_l.strip().split(' ') 103 | for item_w in words: 104 | if len(item_w) > 0: 105 | if item_w in label_index_dic: 106 | tmp.append(label_index_dic[item_w]) 107 | else: 108 | tmp.append(label_index_dic['unk']) 109 | 110 | labelEmbed_teaching_test.append(tmp) 111 | 112 | tmp = [] 113 | words = item_l.strip().split(' ') 114 | for item_w in words: 115 | if len(item_w) > 0: 116 | if item_w in label_index_dic: 117 | tmp.append(label_index_dic[item_w]) 118 | else: 119 | tmp.append(label_index_dic['unk']) 120 | 121 | tmp.append(2) 122 | labelEmbed_predict_test.append(tmp) 123 | 124 | labelEmbed_teachingArray_test = np.array(labelEmbed_teaching_test) 125 | labelEmbed_predictArray_test = np.array(labelEmbed_predict_test) 126 | 127 | # 128 | testData = [] 129 | testPosition = [] 130 | 131 | for item_f3 in tqdm(f3): 132 | img = Image.open(parameters.img_path + 133 | item_f3.strip() + ".png").convert('L') 134 | img = Image.fromarray(ImgCandidate.deletePadding(np.array(img))) 135 | imgInfo = [] 136 | positionInfo = [] 137 | for t in [160]: 138 | tmp = ImgCandidate.getAllCandidate(img, t) 139 | for item_t in tmp: 140 | if item_t[1] not in positionInfo: 141 | imgInfo.append(item_t[0]) 142 | positionInfo.append(item_t[1]) 143 | positionVec = get_position_vec(positionInfo) 144 | testData.append(imgInfo) 145 | 
testPosition.append(positionVec) 146 | 147 | np.save(root_path + '/data/preprocess_data/testData_160', np.array(testData)) 148 | np.save(root_path + '/data/preprocess_data/testPosition_160', np.array(testPosition)) 149 | 150 | # 151 | -------------------------------------------------------------------------------- /scrips/preprocessing/preprocess_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import warnings 4 | from PIL import Image 5 | 6 | import argparse 7 | import sys 8 | import os 9 | 10 | root_path = os.path.abspath(__file__) 11 | root_path = '/'.join(root_path.split('/')[:-3]) 12 | sys.path.append(root_path) 13 | 14 | from src.utils import ImgCandidate 15 | 16 | warnings.filterwarnings('ignore') 17 | 18 | 19 | def process_args(): 20 | parser = argparse.ArgumentParser(description='Get parameters') 21 | 22 | parser.add_argument('--formulas', dest='formulas_file_path', 23 | type=str, required=True, 24 | help= 'Input formulas.txt path') 25 | 26 | parser.add_argument('--train', dest='train_file_path', 27 | type=str, required=True, 28 | help='Input train.txt path') 29 | 30 | parser.add_argument('--vocab', dest='vocab_file_path', 31 | type=str, required=True, 32 | help='Input latex_vocab.txt path') 33 | 34 | parser.add_argument('--img', dest='img_path', 35 | type=str, required=True, 36 | help='Input image path') 37 | 38 | parameters = parser.parse_args() 39 | return parameters 40 | 41 | 42 | 43 | 44 | def get_position_vec(positionList): 45 | 46 | xMax = max([item[1] for item in positionList]) 47 | yMax = max([item[3] for item in positionList]) 48 | 49 | finalPosition = [] 50 | for item_g in positionList: 51 | 52 | x_1 = item_g[0] 53 | x_2 = item_g[1] 54 | y_1 = item_g[2] 55 | y_2 = item_g[3] 56 | 57 | 58 | tmp = [x_1 /xMax , x_2 /xMax , y_1 /yMax, y_2 /yMax, xMax/yMax] 59 | finalPosition.append(tmp) 60 | return finalPosition 61 | 62 | 63 | 64 | if __name__ == '__main__': 65 | 66 | 67 | parameters = process_args() 68 | 69 | f = open(parameters.formulas_file_path, encoding='utf-8').readlines() 70 | labelIndexDic = {} 71 | for item_f in f: 72 | labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1] 73 | 74 | f2 = open(parameters.train_file_path, 75 | encoding='utf-8').readlines() 76 | 77 | 78 | trainLabelList = [] 79 | for item_f2 in f2: 80 | if len(item_f2) > 0: 81 | trainLabelList.append(labelIndexDic[item_f2.strip()]) 82 | 83 | 84 | 85 | MAXLENGTH = 150 86 | 87 | f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines() 88 | 89 | PAD = 0 90 | START = 1 91 | END = 2 92 | 93 | index_label_dic = {} 94 | label_index_dic = {} 95 | 96 | i = 3 97 | for item_f5 in f5: 98 | word = item_f5.strip() 99 | if len(word) > 0: 100 | label_index_dic[word] = i 101 | index_label_dic[i] = word 102 | i += 1 103 | label_index_dic['unk'] = i 104 | index_label_dic[i] = 'unk' 105 | i += 1 106 | 107 | labelEmbed_teaching_train = [] 108 | 109 | labelEmbed_predict_train = [] 110 | for item_l in trainLabelList: 111 | tmp = [1] 112 | words = item_l.strip().split(' ') 113 | for item_w in words: 114 | if len(item_w) > 0: 115 | if item_w in label_index_dic: 116 | tmp.append(label_index_dic[item_w]) 117 | else: 118 | 119 | tmp.append(label_index_dic['unk']) 120 | 121 | labelEmbed_teaching_train.append(tmp) 122 | 123 | tmp = [] 124 | words = item_l.strip().split(' ') 125 | for item_w in words: 126 | if len(item_w) > 0: 127 | if item_w in label_index_dic: 128 | tmp.append(label_index_dic[item_w]) 
129 | else: 130 | tmp.append(label_index_dic['unk']) 131 | 132 | tmp.append(2) 133 | labelEmbed_predict_train.append(tmp) 134 | 135 | 136 | 137 | imgDic = {} 138 | imgList = [] 139 | 140 | trainData = [] 141 | trainPosition = [] 142 | trainIndexList = [] 143 | 144 | 145 | for i in tqdm(range(len(f2))): 146 | item_f2 = f2[i] ## 147 | img = Image.open(parameters.img_path + 148 | item_f2.strip() + ".png").convert('L') 149 | img = Image.fromarray(ImgCandidate.deletePadding(np.array(img))) 150 | for t in [160, 180, 200]: 151 | 152 | tmp = ImgCandidate.getAllCandidate(img, t) 153 | positionInfo = [item[1] for item in tmp] 154 | 155 | # 156 | positionVec = get_position_vec(positionInfo) 157 | # 158 | if t == 160: 159 | trainIndexList.append(0 + i) 160 | 161 | trainPosition.append(positionVec) 162 | imgInfo = [] 163 | 164 | for item_t in tmp: 165 | if str(item_t[0].tolist()) not in imgDic: 166 | imgDic[str(item_t[0].tolist())] = len(imgDic) 167 | imgList.append(item_t[0]) 168 | imgInfo.append(imgDic[str(item_t[0].tolist())]) 169 | trainData.append(imgInfo) 170 | 171 | else: 172 | if positionVec not in trainPosition: 173 | trainIndexList.append(0 + i) 174 | 175 | trainPosition.append(positionVec) 176 | imgInfo = [] 177 | for item_t in tmp: 178 | if str(item_t[0].tolist()) not in imgDic: 179 | imgDic[str(item_t[0].tolist())] = len(imgDic) 180 | imgList.append(item_t[0]) 181 | imgInfo.append(imgDic[str(item_t[0].tolist())]) 182 | trainData.append(imgInfo) 183 | 184 | 185 | 186 | trainData = np.array(trainData) 187 | trainPosition = np.array(trainPosition) 188 | trainIndexList = np.array(trainIndexList) 189 | imgList = np.array(imgList) 190 | 191 | np.save(root_path + '/data/preprocess_data/trainData', np.array(trainData)) 192 | np.save(root_path + '/data/preprocess_data/trainPosition', np.array(trainPosition)) 193 | np.save(root_path + '/data/preprocess_data/trainIndexList', np.array(trainIndexList)) 194 | np.save(root_path + '/data/preprocess_data/trainImgList',np.array(imgList)) 195 | 196 | -------------------------------------------------------------------------------- /scrips/preprocessing/preprocess_val.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | import warnings 4 | from PIL import Image 5 | 6 | import argparse 7 | import sys 8 | import os 9 | 10 | root_path = os.path.abspath(__file__) 11 | root_path = '/'.join(root_path.split('/')[:-3]) 12 | sys.path.append(root_path) 13 | 14 | from src.utils import ImgCandidate 15 | 16 | 17 | def process_args(): 18 | parser = argparse.ArgumentParser(description='Get parameters') 19 | 20 | parser.add_argument('--formulas', dest='formulas_file_path', 21 | type=str, required=True, 22 | help='Input formulas.txt path') 23 | 24 | parser.add_argument('--val', dest='val_file_path', 25 | type=str, required=True, 26 | help='Input val.txt path') 27 | 28 | parser.add_argument('--vocab', dest='vocab_file_path', 29 | type=str, required=True, 30 | help='Input latex_vocab.txt path') 31 | 32 | parser.add_argument('--img', dest='img_path', 33 | type=str, required=True, 34 | help='Input image path') 35 | 36 | parameters = parser.parse_args() 37 | return parameters 38 | 39 | 40 | def get_position_vec(positionList): 41 | xMax = max([item[1] for item in positionList]) 42 | yMax = max([item[3] for item in positionList]) 43 | 44 | finalPosition = [] 45 | for item_g in positionList: 46 | x_1 = item_g[0] 47 | x_2 = item_g[1] 48 | y_1 = item_g[2] 49 | y_2 = item_g[3] 50 | 51 | tmp = [x_1 / 
xMax, x_2 / xMax, y_1 / yMax, y_2 / yMax, xMax / yMax] 52 | finalPosition.append(tmp) 53 | return finalPosition 54 | 55 | 56 | if __name__ == '__main__': 57 | 58 | parameters = process_args() 59 | 60 | warnings.filterwarnings('ignore') 61 | 62 | f = open(parameters.formulas_file_path, encoding='utf-8').readlines() 63 | labelIndexDic = {} 64 | for item_f in f: 65 | labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1] 66 | 67 | f3 = open(parameters.val_file_path, 68 | encoding='utf-8').readlines() 69 | 70 | valLabelList = [] 71 | for item_f3 in f3: 72 | if len(item_f3) > 0: 73 | valLabelList.append(labelIndexDic[item_f3.strip()]) 74 | 75 | MAXLENGTH = 150 76 | 77 | f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines() 78 | 79 | PAD = 0 80 | START = 1 81 | END = 2 82 | 83 | index_label_dic = {} 84 | label_index_dic = {} 85 | 86 | i = 3 87 | for item_f5 in f5: 88 | word = item_f5.strip() 89 | if len(word) > 0: 90 | label_index_dic[word] = i 91 | index_label_dic[i] = word 92 | i += 1 93 | label_index_dic['unk'] = i 94 | index_label_dic[i] = 'unk' 95 | i += 1 96 | 97 | labelEmbed_teaching_val = [] 98 | labelEmbed_predict_val = [] 99 | 100 | for item_l in valLabelList: 101 | tmp = [1] 102 | words = item_l.strip().split(' ') 103 | for item_w in words: 104 | if len(item_w) > 0: 105 | if item_w in label_index_dic: 106 | tmp.append(label_index_dic[item_w]) 107 | else: 108 | tmp.append(label_index_dic['unk']) 109 | 110 | labelEmbed_teaching_val.append(tmp) 111 | 112 | tmp = [] 113 | words = item_l.strip().split(' ') 114 | for item_w in words: 115 | if len(item_w) > 0: 116 | if item_w in label_index_dic: 117 | tmp.append(label_index_dic[item_w]) 118 | else: 119 | tmp.append(label_index_dic['unk']) 120 | 121 | tmp.append(2) 122 | labelEmbed_predict_val.append(tmp) 123 | 124 | labelEmbed_teachingArray_val = np.array(labelEmbed_teaching_val) 125 | labelEmbed_predictArray_val = np.array(labelEmbed_predict_val) 126 | 127 | # 128 | valData = [] 129 | valPosition = [] 130 | 131 | for item_f3 in tqdm(f3): 132 | img = Image.open(parameters.img_path + 133 | item_f3.strip() + ".png").convert('L') 134 | img = Image.fromarray(ImgCandidate.deletePadding(np.array(img))) 135 | imgInfo = [] 136 | positionInfo = [] 137 | for t in [160]: 138 | tmp = ImgCandidate.getAllCandidate(img, t) 139 | for item_t in tmp: 140 | if item_t[1] not in positionInfo: 141 | imgInfo.append(item_t[0]) 142 | positionInfo.append(item_t[1]) 143 | positionVec = get_position_vec(positionInfo) 144 | valData.append(imgInfo) 145 | valPosition.append(positionVec) 146 | 147 | np.save(root_path + '/data/preprocess_data/valData_160', np.array(valData)) 148 | np.save(root_path + '/data/preprocess_data/valPosition_160', np.array(valPosition)) 149 | 150 | # 151 | -------------------------------------------------------------------------------- /src/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/.DS_Store -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/__init__.py -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/model/__init__.py -------------------------------------------------------------------------------- /src/model/transformers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import math, copy 6 | import seaborn 7 | 8 | seaborn.set_context(context="talk") 9 | 10 | 11 | class EncoderDecoder(nn.Module): 12 | """ 13 | A standard Encoder-Decoder architecture. Base for this and many 14 | other models. 15 | """ 16 | 17 | def __init__(self, encoder2, decoder, src_position_embed, tgt_embed, generator): 18 | super(EncoderDecoder, self).__init__() 19 | 20 | self.encoder2 = encoder2 21 | self.decoder = decoder 22 | self.src_position_embed = src_position_embed 23 | self.tgt_embed = tgt_embed 24 | self.generator = generator 25 | self.src_proj = nn.Linear(576, 256) 26 | 27 | def forward(self, src, src_position, tgt, src_mask, tgt_mask): 28 | src = F.relu(self.src_proj(src)) 29 | src_position_embed = self.src_position_embed(src_position) 30 | src_embed = src_position_embed + src 31 | 32 | memory = self.encode2(src_embed, src_mask, src_position_embed) 33 | decode_embeded = self.decode(memory, src_mask, tgt, tgt_mask) 34 | return decode_embeded 35 | 36 | def encode1(self, src, src_mask): 37 | return self.encoder1(src, src_mask) 38 | 39 | def encode2(self, src, src_mask, pos): 40 | return self.encoder2(src, src_mask, pos) 41 | 42 | def decode(self, memory, src_mask, tgt, tgt_mask): 43 | tgt_embed = self.tgt_embed(tgt) 44 | 45 | return self.decoder(tgt_embed, memory, src_mask, tgt_mask) 46 | 47 | 48 | class Generator(nn.Module): 49 | "Define standard linear + softmax generation step." 50 | 51 | def __init__(self, d_model, vocab): 52 | super(Generator, self).__init__() 53 | self.proj = nn.Linear(d_model, vocab) 54 | 55 | def forward(self, x): 56 | return F.log_softmax(self.proj(x), dim=-1) 57 | 58 | 59 | def clones(module, N): 60 | "Produce N identical layers." 61 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)]) 62 | 63 | 64 | class Encoder(nn.Module): 65 | "Core encoder is a stack of N layers" 66 | 67 | def __init__(self, layer, N): 68 | super(Encoder, self).__init__() 69 | self.layers = clones(layer, N) 70 | self.norm = nn.LayerNorm(layer.size) 71 | 72 | self.pj1 = nn.Linear(512, 256) 73 | self.pj2 = nn.Linear(256, 1) 74 | 75 | def forward(self, x, mask, pos): 76 | "Pass the input (and mask) through each layer in turn." 77 | pos1 = pos.unsqueeze(1).repeat(1, pos.size(1), 1, 1) 78 | pos2 = pos.unsqueeze(2).repeat(1, 1, pos.size(1), 1) 79 | 80 | posr = self.pj2(F.relu(self.pj1(torch.cat((pos1, pos2), dim=-1)))).squeeze(-1) 81 | 82 | for layer in self.layers: 83 | x, attn = layer(x, mask, posr) 84 | 85 | return self.norm(x) 86 | # return self.norm(x), attn 87 | 88 | 89 | class SublayerConnection(nn.Module): 90 | """ 91 | A residual connection followed by a layer norm. 92 | Note for code simplicity the norm is first as opposed to last. 93 | """ 94 | 95 | def __init__(self, size, dropout): 96 | super(SublayerConnection, self).__init__() 97 | self.norm = nn.LayerNorm(size) 98 | self.dropout = nn.Dropout(dropout) 99 | 100 | def forward(self, x, sublayer): 101 | "Apply residual connection to any sublayer with the same size." 
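# Pre-norm residual: LayerNorm the input first, run the sublayer on the
# normalized value, apply dropout, then add the original x back.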
102 | return x + self.dropout(sublayer(self.norm(x))) 103 | 104 | 105 | class EncoderLayer(nn.Module): 106 | "Encoder is made up of self-attn and feed forward (defined below)" 107 | 108 | def __init__(self, size, self_attn, feed_forward, dropout): 109 | super(EncoderLayer, self).__init__() 110 | self.self_attn = self_attn 111 | self.feed_forward = feed_forward 112 | self.sublayer = clones(SublayerConnection(size, dropout), 2) 113 | self.size = size 114 | 115 | def forward(self, x, mask, posr): 116 | "Follow Figure 1 (left) for connections." 117 | 118 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask, posr)) 119 | attn = self.self_attn.attn 120 | return self.sublayer[1](x, self.feed_forward), attn 121 | 122 | 123 | class Decoder(nn.Module): 124 | "Generic N layer decoder with masking." 125 | 126 | def __init__(self, layer, N): 127 | super(Decoder, self).__init__() 128 | self.layers = clones(layer, N) 129 | self.norm = nn.LayerNorm(layer.size) 130 | 131 | self.count = nn.Embedding(200, 256) 132 | 133 | # self.pos_pj = nn.Linear(512,256) 134 | 135 | def forward(self, x, memory, src_mask, tgt_mask): 136 | # src_count = torch.sum(src_mask,dim=-1).squeeze(-1).long().cuda() - 1 137 | # src_count = self.count(src_count).unsqueeze(1).repeat(1,x.size(1),1) 138 | 139 | x_pos = torch.Tensor(list(range(x.size(1)))).unsqueeze(0).repeat(x.size(0), 1).long().cuda() 140 | x_pos = self.count(x_pos) 141 | 142 | # x_pos = self.pos_pj(torch.cat((src_count, x_pos),dim=-1)) 143 | x = x + x_pos 144 | 145 | # x_cat = torch.empty(x.size(0), 0, x.size(1), x.size(2)).cuda() 146 | # x_cat = torch.cat((x_cat, x.unsqueeze(1)), dim=1) 147 | 148 | for layer in self.layers: 149 | x = layer(x, memory, src_mask, tgt_mask) 150 | # x_cat = torch.cat((x_cat, x.unsqueeze(1)), dim=1) 151 | 152 | # x_cat = torch.mean(x_cat, dim=1) 153 | 154 | return self.norm(x) 155 | 156 | 157 | class DecoderLayer(nn.Module): 158 | "Decoder is made of self-attn, src-attn, and feed forward (defined below)" 159 | 160 | def __init__(self, size, self_attn, src_attn, feed_forward, dropout): 161 | super(DecoderLayer, self).__init__() 162 | self.size = size 163 | self.self_attn = self_attn 164 | self.src_attn = src_attn 165 | self.feed_forward = feed_forward 166 | self.sublayer = clones(SublayerConnection(size, dropout), 3) 167 | 168 | def forward(self, x, memory, src_mask, tgt_mask): 169 | "Follow Figure 1 (right) for connections." 170 | m = memory 171 | # x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 172 | x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask)) 173 | x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask)) 174 | return self.sublayer[2](x, self.feed_forward) 175 | 176 | 177 | def subsequent_mask(size): 178 | "Mask out subsequent positions." 
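# Builds a (1, size, size) boolean mask that is True on and below the
# diagonal: np.triu(..., k=1) puts 1s on the strictly-upper (future)
# positions, and the final == 0 inverts that, so position i may only
# attend to positions j <= i.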
179 | attn_shape = (1, size, size) 180 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 181 | return torch.from_numpy(subsequent_mask) == 0 182 | 183 | 184 | def attention(query, key, value, mask=None, dropout=None, posr=None): 185 | "Compute 'Scaled Dot Product Attention'" 186 | d_k = query.size(-1) 187 | 188 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) 189 | 190 | if posr is not None: 191 | posr = posr.unsqueeze(1) 192 | 193 | scores = scores + posr 194 | 195 | if mask is not None: 196 | scores = scores.masked_fill(mask == 0, -1e9) 197 | 198 | p_attn = F.softmax(scores, dim=-1) 199 | 200 | if dropout is not None: 201 | p_attn = dropout(p_attn) 202 | return torch.matmul(p_attn, value), p_attn 203 | 204 | 205 | class MultiHeadedAttention(nn.Module): 206 | def __init__(self, h, d_model, dropout): 207 | "Take in model size and number of heads." 208 | super(MultiHeadedAttention, self).__init__() 209 | assert d_model % h == 0 210 | # We assume d_v always equals d_k 211 | self.d_k = d_model // h 212 | self.h = h 213 | self.linears = clones(nn.Linear(d_model, d_model), 4) 214 | self.attn = None 215 | self.dropout = nn.Dropout(p=dropout) 216 | 217 | def forward(self, query, key, value, mask=None, posr=None): 218 | "Implements Figure 2" 219 | if mask is not None: 220 | # Same mask applied to all h heads. 221 | mask = mask.unsqueeze(1) 222 | nbatches = query.size(0) 223 | 224 | # 1) Do all the linear projections in batch from d_model => h x d_k 225 | 226 | query, key, value = \ 227 | [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2) 228 | for l, x in zip(self.linears, (query, key, value))] 229 | 230 | # 2) Apply attention on all the projected vectors in batch. 231 | x, self.attn = attention(query, key, value, mask=mask, 232 | dropout=self.dropout, posr=posr) 233 | 234 | # 3) "Concat" using a view and apply a final linear. 235 | x = x.transpose(1, 2).contiguous() \ 236 | .view(nbatches, -1, self.h * self.d_k) 237 | 238 | return self.linears[-1](x) 239 | 240 | 241 | class PositionwiseFeedForward(nn.Module): 242 | "Implements FFN equation." 
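# Computes FFN(x) = w_2(relu(dropout(w_1(x)))): a per-position projection
# d_model -> d_ff -> d_model; in this variant the dropout sits on the first
# projection's output, before the ReLU.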
243 | 244 | def __init__(self, d_model, d_ff, dropout): 245 | super(PositionwiseFeedForward, self).__init__() 246 | self.w_1 = nn.Linear(d_model, d_ff) 247 | self.w_2 = nn.Linear(d_ff, d_model) 248 | self.dropout = nn.Dropout(dropout) 249 | 250 | def forward(self, x): 251 | return self.w_2(F.relu(self.dropout(self.w_1(x)))) 252 | 253 | 254 | class Embeddings(nn.Module): 255 | def __init__(self, d_model, vocab): 256 | super(Embeddings, self).__init__() 257 | self.lut = nn.Embedding(vocab, d_model) 258 | self.d_model = d_model 259 | 260 | def forward(self, x): 261 | embed = self.lut(x.long()) * math.sqrt(self.d_model) 262 | return embed 263 | 264 | 265 | 266 | class EncoderPositionalEmbedding(nn.Module): 267 | def __init__(self, dmodel): 268 | super(EncoderPositionalEmbedding, self).__init__() 269 | 270 | self.fc1 = nn.Linear(5, 64) 271 | self.fc2 = nn.Linear(64, 128) 272 | self.fc3 = nn.Linear(128, 256) 273 | 274 | def forward(self, encoder_position): 275 | e = F.relu((self.fc1(encoder_position))) 276 | e = F.relu((self.fc2(e))) 277 | e = F.relu((self.fc3(e))) 278 | 279 | return e 280 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | import argparse 7 | import os 8 | import random 9 | from torch.autograd import Variable 10 | from tqdm import tqdm 11 | import warnings 12 | import sys 13 | 14 | root_path = os.path.abspath(__file__) 15 | root_path = '/'.join(root_path.split('/')[:-2]) 16 | 17 | sys.path.append(root_path) 18 | 19 | from src.model import transformers 20 | 21 | warnings.filterwarnings('ignore') 22 | 23 | 24 | 25 | def getPositionVec(positionList): 26 | xMax = max([item[1] for item in positionList]) 27 | yMax = max([item[3] for item in positionList]) 28 | 29 | finalPosition = [] 30 | for item_g in positionList: 31 | x_1 = item_g[0] 32 | x_2 = item_g[1] 33 | y_1 = item_g[2] 34 | y_2 = item_g[3] 35 | 36 | tmp = [x_1 / xMax, x_2 / xMax, y_1 / yMax, y_2 / yMax, xMax / yMax] 37 | finalPosition.append(tmp) 38 | return finalPosition 39 | 40 | 41 | def getBatchIndex(length, batchSize, shuffle=True): 42 | indexList = list(range(length)) 43 | if shuffle == True: 44 | random.shuffle(indexList) 45 | batchList = [] 46 | tmp = [] 47 | for inidex in indexList: 48 | tmp.append(inidex) 49 | if len(tmp) == batchSize: 50 | batchList.append(tmp) 51 | tmp = [] 52 | if len(tmp) > 0: 53 | batchList.append(tmp) 54 | return batchList 55 | 56 | 57 | def make_tgt_mask(tgt, pad): 58 | "Create a mask to hide padding and future words." 59 | tgt_mask = (tgt != pad).unsqueeze(-2) 60 | tgt_mask = tgt_mask & Variable( 61 | subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)) 62 | return tgt_mask 63 | 64 | 65 | def subsequent_mask(size): 66 | "Mask out subsequent positions." 67 | attn_shape = (1, size, size) 68 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 69 | return torch.from_numpy(subsequent_mask) == 0 70 | 71 | 72 | ####Transformer#### 73 | 74 | 75 | def make_model(tgt_vocab, encoderN=6, decoderN=6, 76 | d_model=256, d_ff=1024, h=8, dropout=0.0): 77 | "Helper: Construct a model from hyperparameters." 
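# A single attention module and feed-forward module act as templates and are
# deep-copied (c = copy.deepcopy) into every encoder/decoder layer, so no
# weights are shared across layers; all parameters with dim > 1 are then
# Xavier-initialized.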
78 | c = copy.deepcopy 79 | attn = transformers.MultiHeadedAttention(h, d_model, dropout) 80 | ff = transformers.PositionwiseFeedForward(d_model, d_ff, dropout) 81 | 82 | model = transformers.EncoderDecoder( 83 | transformers.Encoder(transformers.EncoderLayer(d_model, c(attn), c(ff), dropout), 84 | encoderN), 85 | transformers.Decoder(transformers.DecoderLayer(d_model, c(attn), c(attn), 86 | c(ff), dropout), decoderN), 87 | transformers.EncoderPositionalEmbedding(d_model), 88 | transformers.Embeddings(d_model, tgt_vocab), 89 | transformers.Generator(d_model, tgt_vocab), 90 | ) 91 | 92 | for p in model.parameters(): 93 | if p.dim() > 1: 94 | nn.init.xavier_uniform_(p) 95 | return model 96 | 97 | 98 | def greedy_decode(model, src, src_position, src_mask, max_len): 99 | src = F.relu(model.src_proj(src)) 100 | src_position_embed = model.src_position_embed(src_position) 101 | # src_position_embed = model.encode1(src_position_embed, src_mask) 102 | 103 | src_embed = src_position_embed + src 104 | 105 | memory = model.encode2(src_embed, src_mask, src_position_embed) 106 | 107 | lastWord = torch.ones(len(src), 1).cuda().long() 108 | for i in range(max_len): 109 | tgt_mask = Variable(subsequent_mask(lastWord.size(1)).type_as(src.data)) 110 | tgt_mask = tgt_mask.repeat(src.size(0), 1, 1) 111 | out = model.decode(memory, src_mask, Variable(lastWord), tgt_mask) 112 | prob = model.generator(out[:, -1, :].squeeze(0)).unsqueeze(1) 113 | _, predictTmp = prob.max(dim=-1) 114 | lastWord = torch.cat((lastWord, predictTmp), dim=-1) 115 | prob = model.generator.proj(out) 116 | 117 | return prob 118 | 119 | 120 | class SimpleLossCompute: 121 | "A simple loss compute and train function." 122 | 123 | def __init__(self, generator, criterion): 124 | self.generator = generator 125 | self.criterion = criterion 126 | 127 | def __call__(self, x, y, norm): 128 | x = self.generator(x) 129 | loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 130 | y.contiguous().view(-1)) / norm 131 | 132 | return loss 133 | 134 | 135 | class LabelSmoothing(nn.Module): 136 | "Implement label smoothing."
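# Builds the smoothed target distribution: the gold token gets probability
# confidence = 1 - smoothing, the remaining mass is spread uniformly over the
# other size - 2 classes (PAD and the gold token are excluded), and rows whose
# target is PAD are zeroed out so padding positions contribute no loss.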
137 | 138 | def __init__(self, size, padding_idx, smoothing=0.0): 139 | super(LabelSmoothing, self).__init__() 140 | self.criterion = nn.KLDivLoss(size_average=False) 141 | self.padding_idx = padding_idx 142 | self.confidence = 1.0 - smoothing 143 | self.smoothing = smoothing 144 | self.size = size 145 | self.true_dist = None 146 | 147 | def forward(self, x, target): 148 | assert x.size(1) == self.size 149 | true_dist = x.data.clone() 150 | true_dist.fill_(self.smoothing / (self.size - 2)) 151 | true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) 152 | true_dist[:, self.padding_idx] = 0 153 | mask = torch.nonzero(target.data == self.padding_idx) 154 | if mask.dim() > 0: 155 | true_dist.index_fill_(0, mask.squeeze(), 0.0) 156 | self.true_dist = true_dist 157 | return self.criterion(x, Variable(true_dist, requires_grad=False)) 158 | 159 | 160 | class VGGModel(nn.Module): 161 | def __init__(self): 162 | super(VGGModel, self).__init__() 163 | self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1) 164 | self.bn1_1 = nn.BatchNorm2d(32) 165 | self.conv1_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1) 166 | self.bn1_2 = nn.BatchNorm2d(32) 167 | self.mp1 = nn.MaxPool2d(2, 2) 168 | 169 | self.conv2_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1) 170 | self.bn2_1 = nn.BatchNorm2d(64) 171 | self.conv2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1) 172 | self.bn2_2 = nn.BatchNorm2d(64) 173 | self.mp2 = nn.MaxPool2d(2, 2) 174 | 175 | self.conv3_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1) 176 | self.bn3_1 = nn.BatchNorm2d(64) 177 | self.conv3_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1) 178 | self.bn3_2 = nn.BatchNorm2d(64) 179 | self.mp3 = nn.MaxPool2d(2, 2) 180 | # self.linear = nn.Linear(576,256) 181 | # self.cly = nn.Linear(576, label_count) 182 | 183 | def forward(self, x): 184 | x = F.relu(self.bn1_1(self.conv1_1(x))) 185 | x = F.relu(self.bn1_2(self.conv1_2(x))) 186 | x = self.mp1(x) 187 | 188 | x = F.relu(self.bn2_1(self.conv2_1(x))) 189 | x = F.relu(self.bn2_2(self.conv2_2(x))) 190 | x = self.mp2(x) 191 | 192 | x = F.relu(self.bn3_1(self.conv3_1(x))) 193 | x = F.relu(self.bn3_2(self.conv3_2(x))) 194 | x = self.mp3(x) 195 | 196 | x_embed = x.view(x.size()[0], -1) 197 | 198 | return x_embed 199 | 200 | 201 | def process_args(): 202 | parser = argparse.ArgumentParser(description='Get parameters') 203 | 204 | parser.add_argument('--formulas', dest='formulas_file_path', 205 | type=str, required=True, 206 | help='Input formulas.txt path') 207 | 208 | parser.add_argument('--test', dest='test_file_path', 209 | type=str, required=True, 210 | help='Input test.txt path') 211 | 212 | parser.add_argument('--vocab', dest='vocab_file_path', 213 | type=str, required=True, 214 | help='Input latex_vocab.txt path') 215 | 216 | parameters = parser.parse_args() 217 | return parameters 218 | 219 | 220 | 221 | if __name__ == '__main__': 222 | 223 | parameters = process_args() 224 | 225 | f = open(parameters.formulas_file_path, encoding='utf-8').readlines() 226 | labelIndexDic = {} 227 | for item_f in f: 228 | labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1] 229 | 230 | f3 = open(parameters.test_file_path, encoding='utf-8').readlines() # [:100] 231 | 232 | valLabelList = [] 233 | for item_f3 in f3: 234 | if len(item_f3) > 0: 235 | valLabelList.append(labelIndexDic[item_f3.strip()]) 236 | 237 | MAXLENGTH = 150 238 | 239 | 
f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines() 240 | 241 | PAD = 0 242 | START = 1 243 | END = 2 244 | 245 | index_label_dic = {} 246 | label_index_dic = {} 247 | 248 | i = 3 249 | for item_f5 in f5: 250 | word = item_f5.strip() 251 | if len(word) > 0: 252 | label_index_dic[word] = i 253 | index_label_dic[i] = word 254 | i += 1 255 | label_index_dic['unk'] = i 256 | index_label_dic[i] = 'unk' 257 | i += 1 258 | 259 | labelEmbed_teaching_val = [] 260 | labelEmbed_predict_val = [] 261 | 262 | for item_l in valLabelList: 263 | tmp = [1] 264 | words = item_l.strip().split(' ') 265 | for item_w in words: 266 | if len(item_w) > 0: 267 | if item_w in label_index_dic: 268 | tmp.append(label_index_dic[item_w]) 269 | else: 270 | tmp.append(label_index_dic['unk']) 271 | 272 | labelEmbed_teaching_val.append(tmp) 273 | 274 | tmp = [] 275 | words = item_l.strip().split(' ') 276 | for item_w in words: 277 | if len(item_w) > 0: 278 | if item_w in label_index_dic: 279 | tmp.append(label_index_dic[item_w]) 280 | else: 281 | tmp.append(label_index_dic['unk']) 282 | 283 | tmp.append(2) 284 | labelEmbed_predict_val.append(tmp) 285 | 286 | labelEmbed_teachingArray_val = np.array(labelEmbed_teaching_val) 287 | labelEmbed_predictArray_val = np.array(labelEmbed_predict_val) 288 | 289 | valDataArray = np.load(root_path + '/data/preprocess_data/testData_160.npy') # .tolist() 290 | valPositionArray = np.load(root_path + '/data/preprocess_data/testPosition_160.npy') # .tolist() 291 | 292 | labelLenListVal = [] 293 | for item_pv in labelEmbed_predictArray_val: 294 | count = 0 295 | for item in item_pv: 296 | if item != 0: 297 | count += 1 298 | labelLenListVal.append(count) 299 | 300 | valLabelIndexOrderByLen = np.argsort(np.array(labelLenListVal)).tolist() 301 | 302 | valDataArray = valDataArray[valLabelIndexOrderByLen].tolist() 303 | valPositionArray = valPositionArray[valLabelIndexOrderByLen].tolist() 304 | 305 | labelEmbed_teachingArray_val = labelEmbed_teachingArray_val[valLabelIndexOrderByLen] 306 | 307 | labelEmbed_predictArray_val = labelEmbed_predictArray_val[valLabelIndexOrderByLen] 308 | 309 | BATCH_SIZE = 10 310 | 311 | #####Regularization Parameters#### 312 | dropout = 0.2 313 | l2 = 1e-4 314 | ################# 315 | 316 | 317 | model = make_model(len(index_label_dic) + 3, encoderN=8, decoderN=8, 318 | d_model=256, d_ff=1024, dropout=dropout).cuda() 319 | vgg = VGGModel().cuda() 320 | 321 | model.load_state_dict(torch.load(root_path + '/data/model/model.pkl')) 322 | vgg.load_state_dict(torch.load(root_path + '/data/model/encoder.pkl')) 323 | 324 | param = list(model.parameters()) + list(vgg.parameters()) 325 | 326 | criterion = LabelSmoothing(size=len(label_index_dic) + 3, padding_idx=0, smoothing=0.1) 327 | lossComput = SimpleLossCompute(model.generator, criterion) 328 | 329 | learningRate = 0.001 330 | 331 | totalCount = 0 332 | 333 | exit_count = 0 334 | 335 | bestVal = 0 336 | bestTrainList = 0 337 | 338 | f6 = open(root_path + '/data/result/predict.txt', mode='w') 339 | 340 | while True: 341 | 342 | model.eval() 343 | vgg.eval() 344 | 345 | latex_batch_index = getBatchIndex(len(valLabelList), BATCH_SIZE, shuffle=False) 346 | 347 | latexAccListVal = [] 348 | 349 | with torch.no_grad(): 350 | for batch_i in tqdm(range(len(latex_batch_index))): 351 | 352 | latex_batch = latex_batch_index[batch_i] 353 | 354 | sourceDataTmp = [copy.copy(valDataArray[item]) for item in latex_batch] 355 | sourcePositionTmp = [copy.copy(valPositionArray[item]) for item in latex_batch] 356 | 
sourceLengthList = [len(item) for item in sourceDataTmp] 357 | sourceMaskTmp = [item * [1] for item in sourceLengthList] 358 | 359 | sourceLengthMax = max(sourceLengthList) 360 | 361 | for i in range(len(sourceDataTmp)): 362 | if len(sourceDataTmp[i]) < sourceLengthMax: 363 | while len(sourceDataTmp[i]) < sourceLengthMax: 364 | sourceDataTmp[i].append(np.zeros((30, 30))) 365 | sourcePositionTmp[i].append([0, 0, 0, 0, 0]) 366 | sourceMaskTmp[i].append(0) 367 | 368 | sourceDataTmp = np.array(sourceDataTmp) 369 | sourcePositionTmp = np.array(sourcePositionTmp) 370 | sourceMaskTmp = np.array(sourceMaskTmp) 371 | 372 | tgt_teaching = labelEmbed_teachingArray_val[latex_batch].tolist() 373 | tgt_predict = labelEmbed_predictArray_val[latex_batch].tolist() 374 | 375 | tgt_teaching_copy = copy.deepcopy(tgt_teaching) 376 | tgt_predict_copy = copy.deepcopy(tgt_predict) 377 | 378 | tgtMaxBatch = 0 379 | for item_tgt in tgt_teaching_copy: 380 | if len(item_tgt) >= tgtMaxBatch: 381 | tgtMaxBatch = len(item_tgt) 382 | 383 | for i in range(len(tgt_teaching_copy)): 384 | while len(tgt_teaching_copy[i]) < tgtMaxBatch: 385 | tgt_teaching_copy[i].append(0) 386 | 387 | for i in range(len(tgt_predict_copy)): 388 | while len(tgt_predict_copy[i]) < tgtMaxBatch: 389 | tgt_predict_copy[i].append(0) 390 | 391 | sourceDataTmpArray = torch.from_numpy(sourceDataTmp).cuda().float() 392 | sourceMaskTmpArray = torch.from_numpy(sourceMaskTmp).cuda().float().unsqueeze(1) 393 | sourcePositionTmpArray = torch.from_numpy(sourcePositionTmp).cuda().float() 394 | 395 | tgt_teachingArray = torch.from_numpy(np.array(tgt_teaching_copy)).cuda().float() 396 | 397 | tgt_teachingMask = make_tgt_mask(tgt_teachingArray, 0) 398 | 399 | tgt_predictArray = torch.from_numpy(np.array(tgt_predict_copy)).cuda().long() 400 | 401 | sourceDataTmpArray_input = sourceDataTmpArray.view(-1, 1, 30, 30) / 255 402 | sourceDataTmpArray = vgg(sourceDataTmpArray_input).view(sourceDataTmpArray.size(0), 403 | sourceDataTmpArray.size(1), -1) 404 | 405 | out = greedy_decode(model, sourceDataTmpArray, sourcePositionTmpArray, sourceMaskTmpArray, 200) 406 | 407 | _, latexPredict = out.max(dim=-1) 408 | 409 | 410 | for i in range(len(latexPredict)): 411 | 412 | fileIndex = f3[valLabelIndexOrderByLen[latex_batch[i]]].strip() 413 | 414 | if 2 in latexPredict[i].tolist(): 415 | endIndex = latexPredict[i].tolist().index(2) 416 | else: 417 | endIndex = MAXLENGTH 418 | 419 | predictTmp = latexPredict[i].tolist()[:endIndex] 420 | labelTmp = tgt_predictArray[i].tolist()[:endIndex] 421 | 422 | predictStr = ' '.join([index_label_dic[item] for item in predictTmp]).strip() 423 | f6.write(fileIndex + '\t' + predictStr + '\n') 424 | 425 | if predictTmp == labelTmp: 426 | latexAccListVal.append(1) 427 | else: 428 | latexAccListVal.append(0) 429 | 430 | valAcc = sum(latexAccListVal) / len(latexAccListVal) 431 | print(valAcc) 432 | exit() 433 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | import argparse 7 | import os 8 | import random 9 | from torch.autograd import Variable 10 | from tqdm import tqdm 11 | from torch import optim 12 | from nltk.translate.bleu_score import sentence_bleu 13 | import warnings 14 | import sys 15 | 16 | root_path = os.path.abspath(__file__) 17 | root_path = '/'.join(root_path.split('/')[:-2]) 18 | 19 | 
sys.path.append(root_path) 20 | 21 | from src.model import transformers 22 | 23 | warnings.filterwarnings('ignore') 24 | 25 | 26 | def setup_seed(seed): 27 | torch.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | np.random.seed(seed) 30 | random.seed(seed) 31 | torch.backends.cudnn.deterministic = True 32 | 33 | 34 | def getBatchIndex(length, batchSize, shuffle=True): 35 | indexList = list(range(length)) 36 | if shuffle == True: 37 | random.shuffle(indexList) 38 | batchList = [] 39 | tmp = [] 40 | for inidex in indexList: 41 | tmp.append(inidex) 42 | if len(tmp) == batchSize: 43 | batchList.append(tmp) 44 | tmp = [] 45 | if len(tmp) > 0: 46 | batchList.append(tmp) 47 | return batchList 48 | 49 | 50 | def make_tgt_mask(tgt, pad): 51 | "Create a mask to hide padding and future words." 52 | tgt_mask = (tgt != pad).unsqueeze(-2) 53 | tgt_mask = tgt_mask & Variable( 54 | subsequent_mask(tgt.size(-1)).type_as(tgt_mask.data)) 55 | return tgt_mask 56 | 57 | 58 | def subsequent_mask(size): 59 | "Mask out subsequent positions." 60 | attn_shape = (1, size, size) 61 | subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8') 62 | return torch.from_numpy(subsequent_mask) == 0 63 | 64 | 65 | ####Transformer#### 66 | def make_model(tgt_vocab, encoderN=6, decoderN=6, 67 | d_model=256, d_ff=1024, h=8, dropout=0.0): 68 | "Helper: Construct a model from hyperparameters." 69 | c = copy.deepcopy 70 | attn = transformers.MultiHeadedAttention(h, d_model, dropout) 71 | ff = transformers.PositionwiseFeedForward(d_model, d_ff, dropout) 72 | model = transformers.EncoderDecoder( 73 | transformers.Encoder(transformers.EncoderLayer(d_model, c(attn), c(ff), dropout), 74 | encoderN), 75 | transformers.Decoder(transformers.DecoderLayer(d_model, c(attn), c(attn), 76 | c(ff), dropout), decoderN), 77 | transformers.EncoderPositionalEmbedding(d_model), 78 | transformers.Embeddings(d_model, tgt_vocab), 79 | transformers.Generator(d_model, tgt_vocab), 80 | ) 81 | 82 | for p in model.parameters(): 83 | if p.dim() > 1: 84 | nn.init.xavier_uniform_(p) 85 | return model 86 | 87 | 88 | def greedy_decode(model, src, src_position, src_mask, max_len): 89 | src = F.relu(model.src_proj(src)) 90 | src_position_embed = model.src_position_embed(src_position) 91 | 92 | src_embed = src_position_embed + src 93 | 94 | memory = model.encode2(src_embed, src_mask, src_position_embed) 95 | 96 | lastWord = torch.ones(len(src), 1).cuda().long() 97 | for i in range(max_len): 98 | tgt_mask = Variable(subsequent_mask(lastWord.size(1)).type_as(src.data)) 99 | tgt_mask = tgt_mask.repeat(src.size(0), 1, 1) 100 | out = model.decode(memory, src_mask, Variable(lastWord), tgt_mask) 101 | prob = model.generator(out[:, -1, :].squeeze(0)).unsqueeze(1) 102 | _, predictTmp = prob.max(dim=-1) 103 | lastWord = torch.cat((lastWord, predictTmp), dim=-1) 104 | prob = model.generator.proj(out) 105 | 106 | return prob 107 | 108 | 109 | class SimpleLossCompute: 110 | "A simple loss compute and train function." 111 | 112 | def __init__(self, generator, criterion): 113 | self.generator = generator 114 | self.criterion = criterion 115 | 116 | def __call__(self, x, y, norm): 117 | x = self.generator(x) 118 | loss = self.criterion(x.contiguous().view(-1, x.size(-1)), 119 | y.contiguous().view(-1)) / norm 120 | return loss 121 | 122 | 123 | class LabelSmoothing(nn.Module): 124 | "Implement label smoothing." 
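# Same smoothing scheme as in src/test.py. The criterion is KLDivLoss, which
# expects log-probabilities as input; Generator.forward produces them via
# F.log_softmax, and SimpleLossCompute applies the generator before invoking
# this criterion.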
125 | 
126 |     def __init__(self, size, padding_idx, smoothing=0.0):
127 |         super(LabelSmoothing, self).__init__()
128 |         self.criterion = nn.KLDivLoss(reduction='sum')  # summed KL divergence; x must hold log-probabilities
129 |         self.padding_idx = padding_idx
130 |         self.confidence = 1.0 - smoothing
131 |         self.smoothing = smoothing
132 |         self.size = size
133 |         self.true_dist = None
134 | 
135 |     def forward(self, x, target):
136 |         assert x.size(1) == self.size
137 |         true_dist = x.data.clone()
138 |         true_dist.fill_(self.smoothing / (self.size - 2))
139 |         true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
140 |         true_dist[:, self.padding_idx] = 0
141 |         mask = torch.nonzero(target.data == self.padding_idx)
142 |         if mask.dim() > 0:
143 |             true_dist.index_fill_(0, mask.squeeze(), 0.0)
144 |         self.true_dist = true_dist
145 |         return self.criterion(x, Variable(true_dist, requires_grad=False))
146 | 
147 | 
148 | class Encoder(nn.Module):
149 |     def __init__(self):
150 |         super(Encoder, self).__init__()
151 |         self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
152 |         self.bn1_1 = nn.BatchNorm2d(32)
153 |         self.conv1_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
154 |         self.bn1_2 = nn.BatchNorm2d(32)
155 |         self.mp1 = nn.MaxPool2d(2, 2)
156 | 
157 |         self.conv2_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
158 |         self.bn2_1 = nn.BatchNorm2d(64)
159 |         self.conv2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
160 |         self.bn2_2 = nn.BatchNorm2d(64)
161 |         self.mp2 = nn.MaxPool2d(2, 2)
162 | 
163 |         self.conv3_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
164 |         self.bn3_1 = nn.BatchNorm2d(64)
165 |         self.conv3_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
166 |         self.bn3_2 = nn.BatchNorm2d(64)
167 |         self.mp3 = nn.MaxPool2d(2, 2)
168 |         # self.linear = nn.Linear(576,256)
169 |         # self.cly = nn.Linear(576, label_count)
170 | 
171 |     def forward(self, x):
172 |         x = F.relu(self.bn1_1(self.conv1_1(x)))
173 |         x = F.relu(self.bn1_2(self.conv1_2(x)))
174 |         x = self.mp1(x)
175 | 
176 |         x = F.relu(self.bn2_1(self.conv2_1(x)))
177 |         x = F.relu(self.bn2_2(self.conv2_2(x)))
178 |         x = self.mp2(x)
179 | 
180 |         x = F.relu(self.bn3_1(self.conv3_1(x)))
181 |         x = F.relu(self.bn3_2(self.conv3_2(x)))
182 |         x = self.mp3(x)
183 | 
184 |         x_embed = x.view(x.size()[0], -1)
185 | 
186 |         return x_embed
187 | 
188 | 
189 | def process_args():
190 |     parser = argparse.ArgumentParser(description='Get parameters')
191 | 
192 |     parser.add_argument('--formulas', dest='formulas_file_path',
193 |                         type=str, required=True,
194 |                         help='Input formulas.txt path')
195 | 
196 |     parser.add_argument('--train', dest='train_file_path',
197 |                         type=str, required=True,
198 |                         help='Input train.txt path')
199 | 
200 |     parser.add_argument('--val', dest='val_file_path',
201 |                         type=str, required=True,
202 |                         help='Input val.txt path')
203 | 
204 |     parser.add_argument('--vocab', dest='vocab_file_path',
205 |                         type=str, required=True,
206 |                         help='Input latex_vocab.txt path')
207 | 
208 |     parameters = parser.parse_args()
209 |     return parameters
210 | 
211 | 
212 | if __name__ == '__main__':
213 | 
214 |     parameters = process_args()
215 | 
216 | 
217 |     f = open(parameters.formulas_file_path, encoding='utf-8').readlines()
218 |     labelIndexDic = {}
219 |     for item_f in f:
220 |         labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1]
221 | 
222 |     f2 = open(parameters.train_file_path, encoding='utf-8').readlines()
223 |     f3 = open(parameters.val_file_path, encoding='utf-8').readlines()
224 | 
225 |     trainLabelList = []
226 |     for item_f2 in f2:
227 |         if len(item_f2) > 0:
228 |             trainLabelList.append(labelIndexDic[item_f2.strip()])
229 | 
230 |     valLabelList = []
231 |     for item_f3 in f3:
232 |         if len(item_f3) > 0:
233 |             valLabelList.append(labelIndexDic[item_f3.strip()])
234 | 
235 |     MAXLENGTH = 150
236 | 
237 |     f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines()
238 | 
239 |     PAD = 0
240 |     START = 1
241 |     END = 2
242 | 
243 |     index_label_dic = {}
244 |     label_index_dic = {}
245 | 
246 |     i = 3
247 |     for item_f5 in f5:
248 |         word = item_f5.strip()
249 |         if len(word) > 0:
250 |             label_index_dic[word] = i
251 |             index_label_dic[i] = word
252 |             i += 1
253 |     label_index_dic['unk'] = i
254 |     index_label_dic[i] = 'unk'
255 |     i += 1
256 | 
257 |     labelEmbed_teaching_train = []
258 |     labelEmbed_predict_train = []
259 |     for item_l in trainLabelList:
260 |         tmp = [1]
261 |         words = item_l.strip().split(' ')
262 |         for item_w in words:
263 |             if len(item_w) > 0:
264 |                 if item_w in label_index_dic:
265 |                     tmp.append(label_index_dic[item_w])
266 |                 else:
267 |                     tmp.append(label_index_dic['unk'])
268 | 
269 |         labelEmbed_teaching_train.append(tmp)
270 | 
271 |         tmp = []
272 |         words = item_l.strip().split(' ')
273 |         for item_w in words:
274 |             if len(item_w) > 0:
275 |                 if item_w in label_index_dic:
276 |                     tmp.append(label_index_dic[item_w])
277 |                 else:
278 |                     tmp.append(label_index_dic['unk'])
279 | 
280 |         tmp.append(2)
281 |         labelEmbed_predict_train.append(tmp)
282 | 
283 |     labelEmbed_teaching_val = []
284 |     labelEmbed_predict_val = []
285 | 
286 |     for item_l in valLabelList:
287 |         tmp = [1]
288 |         words = item_l.strip().split(' ')
289 |         for item_w in words:
290 |             if len(item_w) > 0:
291 |                 if item_w in label_index_dic:
292 |                     tmp.append(label_index_dic[item_w])
293 |                 else:
294 |                     tmp.append(label_index_dic['unk'])
295 | 
296 |         labelEmbed_teaching_val.append(tmp)
297 | 
298 |         tmp = []
299 |         words = item_l.strip().split(' ')
300 |         for item_w in words:
301 |             if len(item_w) > 0:
302 |                 if item_w in label_index_dic:
303 |                     tmp.append(label_index_dic[item_w])
304 |                 else:
305 |                     tmp.append(label_index_dic['unk'])
306 | 
307 |         tmp.append(2)
308 |         labelEmbed_predict_val.append(tmp)
309 | 
310 |     labelEmbed_teachingArray_val = np.array(labelEmbed_teaching_val, dtype=object)  # rows have ragged lengths, so build an object array explicitly
311 |     labelEmbed_predictArray_val = np.array(labelEmbed_predict_val, dtype=object)
312 | 
313 |     trainDataArray = np.load(root_path + '/data/preprocess_data/trainData', allow_pickle=True)  # pickled object arrays; allow_pickle is required on NumPy >= 1.16.3
314 |     trainPositionArray = np.load(root_path + '/data/preprocess_data/trainPosition', allow_pickle=True)
315 | 
316 |     trainIndexList = np.load(root_path + '/data/preprocess_data/trainIndexList', allow_pickle=True).tolist()
317 |     trainImgList = np.load(root_path + '/data/preprocess_data/trainImgList', allow_pickle=True)
318 | 
319 |     labelEmbed_teaching_train_copy = []
320 |     labelEmbed_predict_train_copy = []
321 |     for item_ti in trainIndexList:
322 |         labelEmbed_teaching_train_copy.append(labelEmbed_teaching_train[item_ti])
323 |         labelEmbed_predict_train_copy.append(labelEmbed_predict_train[item_ti])
324 | 
325 |     labelLenListTrain = []
326 |     for item_pv in labelEmbed_teaching_train_copy:
327 |         count = 0
328 |         for item in item_pv:
329 |             if item != 0:
330 |                 count += 1
331 |         labelLenListTrain.append(count)
332 | 
333 |     trainLabelIndexOrderByLen = np.argsort(np.array(labelLenListTrain)).tolist()
334 | 
335 |     trainDataArray = trainDataArray[trainLabelIndexOrderByLen].tolist()
336 |     trainPositionArray = trainPositionArray[trainLabelIndexOrderByLen].tolist()
337 | 
338 |     labelEmbed_teachingArray_train = np.array(labelEmbed_teaching_train_copy, dtype=object)[trainLabelIndexOrderByLen]
339 |     labelEmbed_predictArray_train = np.array(labelEmbed_predict_train_copy, dtype=object)[trainLabelIndexOrderByLen]
340 | 
341 |     valDataArray = np.load(root_path + '/data/preprocess_data/valData_160', allow_pickle=True)  # .tolist()
342 |     valPositionArray = np.load(root_path + '/data/preprocess_data/valPosition_160', allow_pickle=True)  # .tolist()
343 | 
344 |     labelLenListVal = []
345 |     for item_pv in labelEmbed_predictArray_val:
346 |         count = 0
347 |         for item in item_pv:
348 |             if item != 0:
349 |                 count += 1
350 |         labelLenListVal.append(count)
351 | 
352 |     valLabelIndexOrderByLen = np.argsort(np.array(labelLenListVal)).tolist()
353 | 
354 |     valDataArray = valDataArray[valLabelIndexOrderByLen].tolist()
355 |     valPositionArray = valPositionArray[valLabelIndexOrderByLen].tolist()
356 | 
357 |     labelEmbed_teachingArray_val = labelEmbed_teachingArray_val[valLabelIndexOrderByLen]
358 | 
359 |     labelEmbed_predictArray_val = labelEmbed_predictArray_val[valLabelIndexOrderByLen]
360 | 
361 |     BATCH_SIZE = 16
362 | 
363 |     #####Regularization Parameters####
364 |     dropout = 0.2
365 |     l2 = 1e-4
366 |     #################
367 | 
368 | 
369 |     model = make_model(len(index_label_dic) + 3, encoderN=8, decoderN=8,
370 |                        d_model=256, d_ff=1024, dropout=dropout).cuda()
371 |     encoder = Encoder().cuda()
372 | 
373 |     param = list(model.parameters()) + list(encoder.parameters())
374 | 
375 |     criterion = LabelSmoothing(size=len(label_index_dic) + 3, padding_idx=0, smoothing=0.1)
376 |     lossComput = SimpleLossCompute(model.generator, criterion)
377 | 
378 |     learningRate = 3e-4
379 | 
380 |     totalCount = 0
381 | 
382 |     exit_count = 0
383 | 
384 |     bestVal = 0
385 |     bestTrainList = 0
386 | 
387 |     criterionVal = nn.CrossEntropyLoss(ignore_index=0, reduction='mean').cuda()
388 | 
389 |     while True:
390 | 
391 |         model.train()
392 |         encoder.train()
393 | 
394 |         optimizer = optim.Adam(param, lr=learningRate, weight_decay=l2)  # rebuilt each epoch so learning-rate decay takes effect
395 | 
396 |         latex_batch_index = getBatchIndex(len(trainIndexList), BATCH_SIZE, shuffle=False)
397 |         random.shuffle(latex_batch_index)
398 | 
399 |         lossListTrain = []
400 |         latexAccListTrain = []
401 |         bleuListTrain = []
402 | 
403 |         for batch_i in tqdm(range(len(latex_batch_index))):
404 | 
405 |             latex_batch = latex_batch_index[batch_i]
406 | 
407 |             sourceDataTmp = [copy.copy(trainDataArray[item]) for item in latex_batch]
408 |             sourcePositionTmp = [copy.copy(trainPositionArray[item]) for item in latex_batch]
409 |             sourceLengthList = [len(item) for item in sourceDataTmp]
410 |             sourceMaskTmp = [item * [1] for item in sourceLengthList]
411 | 
412 |             sourceDataTmp = [trainImgList[item] for item in sourceDataTmp]
413 | 
414 |             sourceLengthMax = max(sourceLengthList)
415 | 
416 |             for i in range(len(sourceDataTmp)):
417 |                 if len(sourceDataTmp[i]) < sourceLengthMax:
418 |                     while len(sourceDataTmp[i]) < sourceLengthMax:
419 |                         sourceDataTmp[i] = np.concatenate((sourceDataTmp[i], np.zeros((1, 30, 30))), axis=0)
420 |                         sourcePositionTmp[i].append([0, 0, 0, 0, 0])
421 |                         sourceMaskTmp[i].append(0)
422 | 
423 |             sourceDataTmp = np.array(sourceDataTmp)
424 |             sourcePositionTmp = np.array(sourcePositionTmp)
425 |             sourceMaskTmp = np.array(sourceMaskTmp)
426 | 
427 |             tgt_teaching = labelEmbed_teachingArray_train[latex_batch].tolist()
428 |             tgt_predict = labelEmbed_predictArray_train[latex_batch].tolist()
429 | 
430 |             tgt_teaching_copy = copy.deepcopy(tgt_teaching)
431 |             tgt_predict_copy = copy.deepcopy(tgt_predict)
432 | 
433 |             tgtMaxBatch = 0
434 |             for item_tgt in tgt_teaching_copy:
435 |                 if len(item_tgt) >= tgtMaxBatch:
436 |                     tgtMaxBatch = len(item_tgt)
437 | 
438 |             for i in range(len(tgt_teaching_copy)):
439 |                 while len(tgt_teaching_copy[i]) < tgtMaxBatch:
440 |                     tgt_teaching_copy[i].append(0)
441 | 
442 |             for i in range(len(tgt_predict_copy)):
443 |                 while len(tgt_predict_copy[i]) < tgtMaxBatch:
444 |                     tgt_predict_copy[i].append(0)
445 | 
446 |             sourceDataTmpArray = torch.from_numpy(sourceDataTmp).cuda().float()
447 |             sourceMaskTmpArray = torch.from_numpy(sourceMaskTmp).cuda().float().unsqueeze(1)
448 |             sourcePositionTmpArray = torch.from_numpy(sourcePositionTmp).cuda().float()
449 | 
450 |             tgt_teachingArray = torch.from_numpy(np.array(tgt_teaching_copy)).cuda().float()
451 | 
452 |             tgt_teachingMask = make_tgt_mask(tgt_teachingArray, 0)
453 | 
454 |             tgt_predictArray = torch.from_numpy(np.array(tgt_predict_copy)).cuda().long()
455 | 
456 |             sourceDataTmpArray_input = sourceDataTmpArray.view(-1, 1, 30, 30) / 255
457 | 
458 |             sourceDataTmpArray = encoder(sourceDataTmpArray_input).view(sourceDataTmpArray.size(0),
459 |                                                                         sourceDataTmpArray.size(1), -1)
460 | 
461 |             out = model.forward(sourceDataTmpArray, sourcePositionTmpArray, tgt_teachingArray,
462 |                                 sourceMaskTmpArray, tgt_teachingMask)
463 | 
464 |             _, latexPredict = model.generator(out).max(dim=-1)
465 | 
466 |             for i in range(len(latexPredict)):
467 | 
468 |                 if 2 in tgt_predictArray[i].tolist():
469 |                     endIndex = tgt_predictArray[i].tolist().index(2)
470 |                 else:
471 |                     endIndex = MAXLENGTH - 1
472 | 
473 |                 predictTmp = latexPredict[i].tolist()[:endIndex + 1]
474 |                 labelTmp = tgt_predictArray[i].tolist()[:endIndex + 1]
475 | 
476 |                 if predictTmp == labelTmp:
477 |                     latexAccListTrain.append(1)
478 |                 else:
479 |                     latexAccListTrain.append(0)
480 | 
481 |                 bleuScore = sentence_bleu([labelTmp], predictTmp)
482 |                 bleuListTrain.append(bleuScore)
483 | 
484 |             loss = lossComput(out, tgt_predictArray, len(latex_batch))
485 | 
486 |             lossListTrain.append(loss.item())
487 | 
488 |             optimizer.zero_grad()
489 |             loss.backward()
490 |             optimizer.step()
491 | 
492 |         trainLoss = sum(lossListTrain) / len(lossListTrain)
493 |         trainAcc = sum(latexAccListTrain) / len(latexAccListTrain)
494 |         trainBleu = sum(bleuListTrain) / len(bleuListTrain)
495 | 
496 |         model.eval()
497 |         encoder.eval()
498 | 
499 |         latex_batch_index = getBatchIndex(len(valLabelList), BATCH_SIZE, shuffle=False)
500 | 
501 |         latexAccListVal = []
502 |         latexLossListVal = []
503 | 
504 |         with torch.no_grad():
505 |             for batch_i in tqdm(range(len(latex_batch_index))):
506 | 
507 |                 latex_batch = latex_batch_index[batch_i]
508 | 
509 |                 sourceDataTmp = [copy.copy(valDataArray[item]) for item in latex_batch]
510 |                 sourcePositionTmp = [copy.copy(valPositionArray[item]) for item in latex_batch]
511 |                 sourceLengthList = [len(item) for item in sourceDataTmp]
512 |                 sourceMaskTmp = [item * [1] for item in sourceLengthList]
513 | 
514 |                 sourceLengthMax = max(sourceLengthList)
515 | 
516 |                 for i in range(len(sourceDataTmp)):
517 |                     if len(sourceDataTmp[i]) < sourceLengthMax:
518 |                         while len(sourceDataTmp[i]) < sourceLengthMax:
519 |                             sourceDataTmp[i].append(np.zeros((30, 30)))
520 |                             sourcePositionTmp[i].append([0, 0, 0, 0, 0])
521 |                             sourceMaskTmp[i].append(0)
522 | 
523 |                 sourceDataTmp = np.array(sourceDataTmp)
524 |                 sourcePositionTmp = np.array(sourcePositionTmp)
525 |                 sourceMaskTmp = np.array(sourceMaskTmp)
526 | 
527 |                 tgt_teaching = labelEmbed_teachingArray_val[latex_batch].tolist()
528 |                 tgt_predict = labelEmbed_predictArray_val[latex_batch].tolist()
529 | 
530 |                 tgt_teaching_copy = copy.deepcopy(tgt_teaching)
531 |                 tgt_predict_copy = copy.deepcopy(tgt_predict)
532 | 
533 |                 tgtMaxBatch = 0
534 |                 for item_tgt in tgt_teaching_copy:
535 |                     if len(item_tgt) >= tgtMaxBatch:
536 |                         tgtMaxBatch = len(item_tgt)
537 | 
538 |                 for i in range(len(tgt_teaching_copy)):
539 |                     while len(tgt_teaching_copy[i]) < tgtMaxBatch:
540 |                         tgt_teaching_copy[i].append(0)
541 | 
542 |                 for i in range(len(tgt_predict_copy)):
543 |                     while len(tgt_predict_copy[i]) < tgtMaxBatch:
544 |                         tgt_predict_copy[i].append(0)
545 | 
546 |                 sourceDataTmpArray = torch.from_numpy(sourceDataTmp).cuda().float()
547 |                 sourceMaskTmpArray = torch.from_numpy(sourceMaskTmp).cuda().float().unsqueeze(1)
548 |                 sourcePositionTmpArray = torch.from_numpy(sourcePositionTmp).cuda().float()
549 | 
550 |                 tgt_teachingArray = torch.from_numpy(np.array(tgt_teaching_copy)).cuda().float()
551 | 
552 |                 tgt_teachingMask = make_tgt_mask(tgt_teachingArray, 0)
553 | 
554 |                 tgt_predictArray = torch.from_numpy(np.array(tgt_predict_copy)).cuda().long()
555 | 
556 |                 sourceDataTmpArray_input = sourceDataTmpArray.view(-1, 1, 30, 30) / 255
557 |                 sourceDataTmpArray = encoder(sourceDataTmpArray_input).view(sourceDataTmpArray.size(0),
558 |                                                                             sourceDataTmpArray.size(1), -1)
559 | 
560 |                 out = greedy_decode(model, sourceDataTmpArray, sourcePositionTmpArray, sourceMaskTmpArray, tgtMaxBatch)
561 | 
562 |                 _, latexPredict = out.max(dim=-1)
563 | 
564 | 
565 |                 for i in range(len(latexPredict)):
566 |                     if 2 in latexPredict[i].tolist():
567 |                         endIndex = latexPredict[i].tolist().index(2)
568 |                     else:
569 |                         endIndex = MAXLENGTH
570 | 
571 |                     predictTmp = latexPredict[i].tolist()[:endIndex]
572 |                     labelTmp = tgt_predictArray[i].tolist()[:endIndex]
573 | 
574 |                     if predictTmp == labelTmp:
575 |                         latexAccListVal.append(1)
576 |                     else:
577 |                         latexAccListVal.append(0)
578 | 
579 |                 out = out.contiguous().view(-1, out.size(-1))
580 | 
581 |                 targets = tgt_predictArray.view(-1)
582 |                 loss = criterionVal(out, targets)
583 | 
584 |                 latexLossListVal.append(loss.item())
585 | 
586 |         valAcc = sum(latexAccListVal) / len(latexAccListVal)
587 |         valLoss = sum(latexLossListVal) / len(latexLossListVal)
588 | 
589 |         if valAcc > bestVal:
590 |             torch.save(model.state_dict(), root_path + '/data/model/model.pkl')
591 |             torch.save(encoder.state_dict(), root_path + '/data/model/encoder.pkl')
592 |             bestVal = valAcc
593 |             exit_count = 0
594 |         else:
595 |             exit_count += 1
596 | 
597 |         if exit_count > 0 and exit_count % 3 == 0:  # halve the learning rate after every third epoch without improvement
598 |             learningRate *= 0.5
599 | 
600 |         if exit_count == 10:  # early stop after 10 epochs without improvement
601 |             exit()
602 | 
603 |         print("Epoch:" + str(totalCount)
604 |               + "\t TrainingSet:" + "\t loss:" + str(trainLoss)
605 |               + "\t ACC:" + str(trainAcc) + "\t BLEU:" + str(trainBleu)
606 |               + "\t ValSet:" + "\t ACC:" + str(valAcc)
607 |               + "\t bestAcc:" + str(bestVal)
608 |               + "\t learningRate:" + str(learningRate)
609 |               + "\t exit_count:" + str(exit_count))
610 | 
611 |         totalCount += 1
612 | 
--------------------------------------------------------------------------------
/src/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/utils/.DS_Store
--------------------------------------------------------------------------------
/src/utils/ImgCandidate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import copy
4 | import math
5 | from skimage import measure
6 | 
7 | 
8 | 
9 | def deletePadding(imgNp):
10 | 
11 | 
12 |     imgNp[imgNp > 220] = 255
13 | 
14 |     binary = imgNp != 255
15 |     sum_0 = np.sum(binary, axis=0)
16 |     sum_1 = np.sum(binary, axis=1)
17 |     sum_0_left = min(np.argwhere(sum_0 != 0).tolist())[0]
18 |     sum_0_right = max(np.argwhere(sum_0 != 0).tolist())[0]
19 |     sum_1_left = min(np.argwhere(sum_1 != 0).tolist())[0]
20 |     sum_1_right = max(np.argwhere(sum_1 != 0).tolist())[0]
21 |     delImg = imgNp[sum_1_left:sum_1_right + 1, sum_0_left:sum_0_right + 1]
22 | 
23 |     return delImg
24 | 
25 | def deleSpace(imgNp):
26 |     binary = imgNp != 255
27 |     sum_0 = np.sum(binary, axis=0).tolist()
28 |     index_0 = []
29 | 
30 |     for i in range(len(sum_0)):
31 |         if sum_0[i] != 0:
32 |             index_0.append(i)
33 | 
34 |     imgNp = imgNp[:, index_0]
35 | 
36 |     return imgNp
37 | 
38 | 
39 | def imgResize(arr, square=30):
40 | 
41 |     a_height = arr.shape[0]
42 |     a_weight = arr.shape[1]
43 | 
44 | 
45 |     ratio = square / max(a_height, a_weight)
46 |     arr = np.array(Image.fromarray(arr).resize((math.ceil(a_weight * ratio), math.ceil(a_height * ratio)), Image.LANCZOS))  # LANCZOS is the current name of the old ANTIALIAS filter
47 | 
48 |     a_height = arr.shape[0]
49 |     a_weight = arr.shape[1]
50 | 
51 |     if a_height < square:
52 |         h_1 = int((square - a_height) / 2)
53 |         h_2 = square - a_height - h_1
54 |         if h_2 != 0:
55 |             arr = np.vstack((np.ones((h_1, a_weight)) * 255, arr, np.ones((h_2, a_weight)) * 255))
56 |         else:
57 |             arr = np.vstack((np.ones((h_1, a_weight)) * 255, arr))
58 |     if a_weight < square:
59 |         w_1 = int((square - a_weight) / 2)
60 |         w_2 = square - a_weight - w_1
61 |         if w_2 != 0:
62 |             arr = np.hstack((np.ones((arr.shape[0], w_1)) * 255, arr, np.ones((arr.shape[0], w_2)) * 255))
63 |         else:
64 |             arr = np.hstack((np.ones((arr.shape[0], w_1)) * 255, arr))
65 | 
66 |     return arr
67 | 
68 | 
69 | THREASHOLD = 220  # gray-level cutoff (value assumed, mirroring the 220 used in deletePadding); pixels in (THREASHOLD, 255) are treated as split points
70 | 
71 | 
72 | def splitImg(img):  # name assumed; segments a formula image into symbol candidates
73 | 
74 |     binary = img < 255
75 |     splitPoint = (img > THREASHOLD) & (img < 255)
76 | 
77 | 
78 |     binary = binary * (splitPoint == False)
79 |     label = (measure.label(binary, connectivity=2))
80 | 
81 | 
82 |     label_add_split = copy.copy(label)
83 |     times = 0
84 |     while True in splitPoint and times < 10:
85 |         for item in np.argwhere(splitPoint == True):
86 |             top = item[0] - 1 if item[0] - 1 > 0 else 0
87 |             down = item[0] + 2 if item[0] + 2 < label.shape[0] else label.shape[0]
88 |             left = item[1] - 1 if item[1] - 1 > 0 else 0
89 |             right = item[1] + 2 if item[1] + 2 < label.shape[1] else label.shape[1]
90 | 
91 |             area = label[top:down, left:right].reshape(-1)
92 |             count = np.bincount(area)
93 |             if len(count) > 1:
94 |                 count[0] = 0
95 |                 label_add_split[item[0], item[1]] = np.argmax(count)
96 |                 splitPoint[item[0], item[1]] = False
97 |         label = label_add_split
98 |         times += 1
99 | 
100 |     if True in splitPoint:
101 |         for item in np.argwhere(splitPoint == True):
102 |             label_add_split[item[0], item[1]] = 0
103 | 
104 | 
105 |     labelCount = np.max(label_add_split)
106 | 
107 | 
108 |     labelPosition = []
109 | 
110 |     for i in range(labelCount):
111 | 
112 |         tmpLeft = np.min(np.where(label_add_split == i + 1)[1])
113 |         tmpRight = np.max(np.where(label_add_split == i + 1)[1])
114 |         tmpTop = np.min(np.where(label_add_split == i + 1)[0])
115 |         tmpDown = np.max(np.where(label_add_split == i + 1)[0])
116 | 
117 |         imgTmp = copy.copy(img)
118 |         imgTmp[label_add_split != i + 1] = 255
119 |         imgTmp = deletePadding(imgTmp)
120 |         imgTmp = imgResize(imgTmp, 24)
121 | 
122 |         pad1 = np.ones((3, imgTmp.shape[1])) * 255
123 |         imgTmp = np.concatenate((pad1, imgTmp), axis=0)
124 |         imgTmp = np.concatenate((imgTmp, pad1), axis=0)
125 | 
126 |         pad2 = np.ones((imgTmp.shape[0], 3)) * 255
127 |         imgTmp = np.concatenate((pad2, imgTmp), axis=1)
128 |         imgTmp = np.concatenate((imgTmp, pad2), axis=1)
129 | 
130 | 
131 |         labelPosition.append((imgTmp, (tmpTop, tmpDown + 1, tmpLeft, tmpRight + 1)))
132 | 
133 | 
134 |     return labelPosition
135 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/utils/__init__.py
--------------------------------------------------------------------------------