├── .gitattributes
├── README.md
├── data
├── .DS_Store
└── model
│ ├── .DS_Store
│ ├── encoder.pkl
│ └── model.pkl
├── scrips
├── .DS_Store
├── evaluation
│ ├── .DS_Store
│ ├── Cal_B4.py
│ ├── Cal_R4.py
│ ├── CalculateMatch-WS.py
│ └── CalculateMatch.py
└── preprocessing
│ ├── .DS_Store
│ ├── preprocess_test.py
│ ├── preprocess_train.py
│ └── preprocess_val.py
└── src
├── .DS_Store
├── __init__.py
├── model
├── __init__.py
└── transformers.py
├── test.py
├── train.py
└── utils
├── .DS_Store
├── ImgCandidate.py
└── __init__.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # EDSL
2 | EDSL code
3 |
4 | ## Prerequisites
5 | 1. Pytorch
6 | 2. Numpy, Pillow
7 | 3. Pdflatex
8 | 4. ImageMagick
9 |
10 | ## Dataset Download
11 | ME-98K dataset: https://pan.baidu.com/s/1itQEjPUMve3A3dXezJDedw password: axqh
12 |
13 | ME-20K dataset: https://pan.baidu.com/s/1ti1xfCdj6c36yvy_sHld5Q password: nwks
14 |
15 |
16 | ## Quick Start
17 | ### 1. preprocess
18 |
19 | Preprocessing of training set.
20 |
21 | `python scrips/preprocessing/preprocess_train.py --formulas data/sample/formulas.txt --train data/sample/train.txt --vocab data/sample/latex_vocab.txt --img data/sample/image_processed/`
22 |
23 | Preprocessing of validation set
24 |
25 | `python scrips/preprocessing/preprocess_val.py --formulas data/sample/formulas.txt --val data/sample/val.txt --vocab data/sample/latex_vocab.txt --img data/sample/image_processed/`
26 |
27 | Preprocessing of test set
28 |
29 | `python scrips/preprocessing/preprocess_test.py --formulas data/sample/formulas.txt --test data/sample/test.txt --vocab data/sample/latex_vocab.txt --img data/sample/image_processed/`
30 |
31 | ### 2. training model
32 | `python src/train.py --formulas data/sample/formulas.txt --train data/sample/train.txt --val data/sample/val.txt --vocab data/sample/latex_vocab.txt`
33 |
34 |
35 | ### 3. testing
36 | `python src/test.py --formulas data/sample/formulas.txt --test data/sample/test.txt --vocab data/sample/latex_vocab.txt`
37 |
38 |
39 | ### 4. evaluation
40 | BLEU-4 Calculation:
41 |
42 | `python scrips/evaluation/Cal_B4.py --formulas data/sample/formulas.txt`
43 |
44 | Rouge-4 Calculation:
45 |
46 | `python scrips/evaluation/Cal_R4.py --formulas data/sample/formulas.txt`
47 |
48 | Match Calculation:
49 |
50 | `python scrips/evaluation/CalculateMatch.py --formulas data/sample/formulas.txt`
51 |
52 | Match-ws Calculation:
53 |
54 | `python scrips/evaluation/CalculateMatch-WS.py --formulas data/sample/formulas.txt`
55 |
--------------------------------------------------------------------------------
/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/.DS_Store
--------------------------------------------------------------------------------
/data/model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/model/.DS_Store
--------------------------------------------------------------------------------
/data/model/encoder.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/model/encoder.pkl
--------------------------------------------------------------------------------
/data/model/model.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/data/model/model.pkl
--------------------------------------------------------------------------------
/scrips/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/scrips/.DS_Store
--------------------------------------------------------------------------------
/scrips/evaluation/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/scrips/evaluation/.DS_Store
--------------------------------------------------------------------------------
/scrips/evaluation/Cal_B4.py:
--------------------------------------------------------------------------------
1 | from nltk.translate.bleu_score import sentence_bleu
2 | from tqdm import tqdm
3 | import argparse
4 |
5 | import sys
6 | import os
7 |
8 | root_path = os.path.abspath(__file__)
9 | root_path = '/'.join(root_path.split('/')[:-3])
10 | sys.path.append(root_path)
11 |
12 |
def formula_to_index(formula, dic):
    """Map each token of a space-separated formula to its index in *dic*.

    Tokens not present in *dic* fall back to the index of 'UNK'; empty
    tokens (produced by runs of spaces) are skipped.
    """
    indices = []
    for raw_token in formula.strip().split(' '):
        token = raw_token.strip()
        if not token:
            continue
        indices.append(dic[token] if token in dic else dic['UNK'])
    return indices
22 |
23 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with `formulas_file_path` (required).
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument(
        '--formulas',
        dest='formulas_file_path',
        type=str,
        required=True,
        help='Input formulas.txt path',
    )
    return parser.parse_args()
33 |
34 |
if __name__ == '__main__':

    parameters = process_args()

    # formulas.txt: one "<id>\t<formula>" pair per line.
    formula_lines = open(parameters.formulas_file_path,
                         encoding='utf-8').readlines()

    labelIndexDic = {}
    for line in formula_lines:
        fields = line.strip().split('\t')
        labelIndexDic[fields[0].strip()] = fields[1].strip()

    # predict.txt: one "<id>\t<predicted formula>" pair per line.
    predictDic = {}
    for line in open(root_path + '/data/result/predict.txt', encoding='utf-8').readlines():
        fields = line.strip().split('\t')
        predictDic[fields[0]] = fields[1]

    bleuList = []

    for key in tqdm(predictDic):
        predict = predictDic[key].strip().split(' ')
        label = labelIndexDic[key].strip().split(' ')

        # Only references with at least 4 tokens are scored; a prediction
        # shorter than 4 tokens scores 0 outright.
        if len(label) >= 4:
            if len(predict) < 4:
                bleuList.append(0)
            else:
                # weights=(0, 0, 0, 1) scores 4-gram precision only
                # (NOTE(review): not the cumulative BLEU-4 weighting —
                # confirm this is the intended metric).
                bleuList.append(sentence_bleu([label], predict, weights=(0, 0, 0, 1)))

    print("BLEU-4:")
    print(sum(bleuList) / len(bleuList))
70 |
--------------------------------------------------------------------------------
/scrips/evaluation/Cal_R4.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 |
3 | import argparse
4 |
5 | import sys
6 | import os
7 |
8 | root_path = os.path.abspath(__file__)
9 | root_path = '/'.join(root_path.split('/')[:-3])
10 | sys.path.append(root_path)
11 |
12 |
def formula_to_index(formula, dic):
    """Map each token of a space-separated formula to its index in *dic*.

    Tokens not present in *dic* fall back to the index of 'UNK'; empty
    tokens (produced by runs of spaces) are skipped.
    """
    indices = []
    for raw_token in formula.strip().split(' '):
        token = raw_token.strip()
        if not token:
            continue
        indices.append(dic[token] if token in dic else dic['UNK'])
    return indices
22 |
23 |
def getRouge_N(predict, label, n):
    """Compute ROUGE-N recall between two token lists.

    Counts how many of the label's n-grams are matched by the prediction's
    n-grams, with multiplicity (each label n-gram can be consumed at most
    once), divided by the total number of label n-grams.

    Args:
        predict: predicted token list.
        label: reference token list.
        n: n-gram order.

    Returns:
        float in [0, 1], or None when the label has fewer than n tokens
        (no reference n-grams to recall).
    """
    # range(len(seq) - n + 1) is empty when len(seq) < n, so short inputs
    # naturally yield no n-grams.
    predict_ngrams = [predict[i:i + n] for i in range(len(predict) - n + 1)]
    label_ngrams = [label[i:i + n] for i in range(len(label) - n + 1)]

    total = len(label_ngrams)
    if total == 0:
        return None

    # Clipped matching: remove each matched reference n-gram so it cannot
    # be counted twice.  (Replaces the original bare `except:` around
    # list.index, which silently swallowed every exception type.)
    remaining = list(label_ngrams)
    matched = 0
    for gram in predict_ngrams:
        if gram in remaining:
            remaining.remove(gram)
            matched += 1

    return matched / total
57 |
58 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with `formulas_file_path` (required).
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument(
        '--formulas',
        dest='formulas_file_path',
        type=str,
        required=True,
        help='Input formulas.txt path',
    )
    return parser.parse_args()
68 |
69 |
if __name__ == '__main__':

    parameters = process_args()

    # formulas.txt: one "<id>\t<formula>" pair per line.
    labelIndexDic = {}
    for line in open(parameters.formulas_file_path,
                     encoding='utf-8').readlines():
        fields = line.strip().split('\t')
        labelIndexDic[fields[0].strip()] = fields[1].strip()

    # predict.txt: one "<id>\t<predicted formula>" pair per line.
    predictDic = {}
    for line in open(root_path + '/data/result/predict.txt', encoding='utf-8').readlines():
        fields = line.strip().split('\t')
        predictDic[fields[0]] = fields[1]

    roughList = []

    for key in tqdm(predictDic):
        predict = predictDic[key].strip().split(' ')
        label = labelIndexDic[key].strip().split(' ')

        # None means the reference had fewer than 4 tokens -> not scored.
        rougeN = getRouge_N(predict, label, 4)
        if rougeN is not None:
            roughList.append(rougeN)

    print("ROUGE-4:")
    print(sum(roughList) / len(roughList))
105 |
--------------------------------------------------------------------------------
/scrips/evaluation/CalculateMatch-WS.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from PIL import Image
3 | from tqdm import tqdm
4 |
5 | import argparse
6 | from src.utils import ImgCandidate
7 |
8 | import sys
9 | import os
10 |
11 | root_path = os.path.abspath(__file__)
12 | root_path = '/'.join(root_path.split('/')[:-3])
13 | sys.path.append(root_path)
14 |
15 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with `formulas_file_path` (required).
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument(
        '--formulas',
        dest='formulas_file_path',
        type=str,
        required=True,
        help='Input formulas.txt path',
    )
    return parser.parse_args()
25 |
26 |
if __name__ == '__main__':

    # numpy is used below (np.array) but was never imported by this script:
    # without this line the image comparison raised NameError at runtime.
    # (NOTE(review): the top-level `from src.utils import ImgCandidate` also
    # runs before root_path is appended to sys.path, so the script only
    # works when launched from the repository root — confirm.)
    import numpy as np

    parameters = process_args()

    # predict.txt: "<id>\t<predicted formula>" per line.
    f = open(root_path + '/data/result/predict.txt').readlines()
    # formulas.txt: "<id>\t<ground-truth formula>" per line.
    f2 = open(parameters.formulas_file_path, encoding='utf-8').readlines()

    formulaDic = {}
    for item_f2 in f2:
        fields = item_f2.strip().split('\t')
        formulaDic[fields[0]] = fields[1]

    accList = []

    for item_f in tqdm(f):
        fields = item_f.strip().split('\t')
        index = fields[0]
        formula = fields[1]
        labelFormula = formulaDic[index].strip()

        if formula == labelFormula:
            # Exact token-string match: counts without rendering.
            accList.append(1)
        else:
            # Render the prediction with pdflatex.  (Raw strings throughout;
            # the original used '\pagestyle'/'\end{document}' in non-raw
            # literals, which are invalid escape sequences.)
            pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + r'\pagestyle{empty}' + '\n' + \
                      r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + formula + \
                      r'\end{equation*}' + '\n' + r'\end{document}'
            f3 = open('predict.tex', mode='w')
            f3.write(pdfText)
            f3.close()
            sub = subprocess.Popen("pdflatex -halt-on-error " + "predict.tex", shell=True, stdout=subprocess.PIPE)
            sub.wait()

            # List the working directory to see whether predict.pdf was produced.
            pdfFiles = []
            for _, _, pf in os.walk(os.getcwd()):
                pdfFiles = pf
                break

            if 'predict.pdf' in pdfFiles:
                try:
                    # Render the ground truth the same way.
                    pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + r'\pagestyle{empty}' + '\n' + \
                              r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + labelFormula + \
                              r'\end{equation*}' + '\n' + r'\end{document}'
                    f3 = open('label.tex', mode='w')
                    f3.write(pdfText)
                    f3.close()
                    sub = subprocess.Popen("pdflatex -halt-on-error " + "label.tex", shell=True, stdout=subprocess.PIPE)
                    sub.wait()

                    # Rasterize both PDFs (ImageMagick) and compare the pixel
                    # matrices after trimming padding and blank rows/columns
                    # ("match-ws" = whitespace-insensitive match).
                    os.system(
                        'convert -background white -density 200 -quality 100 -strip ' + 'label.pdf ' + 'label.png')
                    os.system(
                        'convert -background white -density 200 -quality 100 -strip ' + 'predict.pdf ' + 'predict.png')
                    label = ImgCandidate.deleSpace(
                        ImgCandidate.deletePadding(np.array(Image.open('label.png').convert('L')))).tolist()
                    predict = ImgCandidate.deleSpace(
                        ImgCandidate.deletePadding(np.array(Image.open('predict.png').convert('L')))).tolist()

                    accList.append(1 if label == predict else 0)
                except Exception:
                    # Rendering/IO failure for the ground truth: count as a
                    # miss.  (Was a bare `except:`, which also caught
                    # KeyboardInterrupt/SystemExit.)
                    accList.append(0)
            else:
                # Prediction failed to compile under pdflatex: a miss.
                accList.append(0)
        print(sum(accList) / len(accList))  # running accuracy

    # Clean up LaTeX/ImageMagick artifacts from the working directory.
    os.system('rm -rf *.aux')
    os.system('rm -rf *.log')
    os.system('rm -rf *.tex')
    os.system('rm -rf *.pdf')
    os.system('rm -rf *.png')

    print(sum(accList) / len(accList))
101 |
--------------------------------------------------------------------------------
/scrips/evaluation/CalculateMatch.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from PIL import Image
3 | from tqdm import tqdm
4 |
5 | import argparse
6 |
7 | import sys
8 | import os
9 |
10 | root_path = os.path.abspath(__file__)
11 | root_path = '/'.join(root_path.split('/')[:-3])
12 | sys.path.append(root_path)
13 |
14 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with `formulas_file_path` (required).
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument(
        '--formulas',
        dest='formulas_file_path',
        type=str,
        required=True,
        help='Input formulas.txt path',
    )
    return parser.parse_args()
24 |
25 |
if __name__ == '__main__':
    parameters = process_args()

    # predict.txt: "<id>\t<predicted formula>" per line.
    f = open(root_path + '/data/result/predict.txt').readlines()
    # formulas.txt: "<id>\t<ground-truth formula>" per line.
    f2 = open(parameters.formulas_file_path, encoding='utf-8').readlines()

    formulaDic = {}
    for item_f2 in f2:
        fields = item_f2.strip().split('\t')
        formulaDic[fields[0]] = fields[1]

    accList = []

    for item_f in tqdm(f):
        fields = item_f.strip().split('\t')
        index = fields[0]
        formula = fields[1]
        labelFormula = formulaDic[index].strip()

        if formula == labelFormula:
            # Exact token-string match: counts without rendering.
            accList.append(1)
        else:
            # Render the prediction with pdflatex.  (Raw strings throughout;
            # the original used '\pagestyle'/'\end{document}' in non-raw
            # literals, which are invalid escape sequences.)
            pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + r'\pagestyle{empty}' + '\n' + \
                      r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + formula + \
                      r'\end{equation*}' + '\n' + r'\end{document}'
            f3 = open('predict.tex', mode='w')
            f3.write(pdfText)
            f3.close()
            sub = subprocess.Popen("pdflatex -halt-on-error " + "predict.tex", shell=True, stdout=subprocess.PIPE)
            sub.wait()

            # List the working directory to see whether predict.pdf was produced.
            pdfFiles = []
            for _, _, pf in os.walk(os.getcwd()):
                pdfFiles = pf
                break

            if 'predict.pdf' in pdfFiles:
                try:
                    # Render the ground truth the same way.
                    pdfText = r'\documentclass{article}' + '\n' + r'\usepackage{amsmath,amssymb}' + '\n' + r'\pagestyle{empty}' + '\n' + \
                              r'\thispagestyle{empty}' + '\n' + r'\begin{document}' + '\n' + r'\begin{equation*}' + '\n' + labelFormula + \
                              r'\end{equation*}' + '\n' + r'\end{document}'
                    f3 = open('label.tex', mode='w')
                    f3.write(pdfText)
                    f3.close()
                    sub = subprocess.Popen("pdflatex -halt-on-error " + "label.tex", shell=True, stdout=subprocess.PIPE)
                    sub.wait()

                    # Rasterize both PDFs and compare the images directly
                    # (strict match, unlike the whitespace-insensitive
                    # CalculateMatch-WS variant).
                    os.system('convert -strip ' + 'label.pdf ' + 'label.png')
                    os.system('convert -strip ' + 'predict.pdf ' + 'predict.png')
                    # NOTE(review): relies on Pillow's Image equality
                    # semantics (mode/size/pixel comparison) — confirm with
                    # the Pillow version in use.
                    label = Image.open('label.png').convert('L')
                    predict = Image.open('predict.png').convert('L')

                    accList.append(1 if label == predict else 0)
                except Exception:
                    # Rendering/IO failure for the ground truth: count as a
                    # miss.  (Was a bare `except:`, which also caught
                    # KeyboardInterrupt/SystemExit.)
                    accList.append(0)
            else:
                # Prediction failed to compile under pdflatex: a miss.
                accList.append(0)
        print(sum(accList) / len(accList))  # running accuracy

    # Clean up LaTeX/ImageMagick artifacts from the working directory.
    os.system('rm -rf *.aux')
    os.system('rm -rf *.log')
    os.system('rm -rf *.tex')
    os.system('rm -rf *.pdf')
    os.system('rm -rf *.png')

    print('Match:')
    print(sum(accList) / len(accList))
96 |
--------------------------------------------------------------------------------
/scrips/preprocessing/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/scrips/preprocessing/.DS_Store
--------------------------------------------------------------------------------
/scrips/preprocessing/preprocess_test.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import warnings
4 | from PIL import Image
5 |
6 | import argparse
7 | import sys
8 | import os
9 |
10 | root_path = os.path.abspath(__file__)
11 | root_path = '/'.join(root_path.split('/')[:-3])
12 | sys.path.append(root_path)
13 |
14 | from src.utils import ImgCandidate
15 |
16 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with the four required paths:
    formulas_file_path, test_file_path, vocab_file_path, img_path.
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument('--formulas', dest='formulas_file_path',
                        type=str, required=True,
                        help='Input formulas.txt path')
    parser.add_argument('--test', dest='test_file_path',
                        type=str, required=True,
                        help='Input test.txt path')
    parser.add_argument('--vocab', dest='vocab_file_path',
                        type=str, required=True,
                        help='Input latex_vocab.txt path')
    parser.add_argument('--img', dest='img_path',
                        type=str, required=True,
                        help='Input image path')
    return parser.parse_args()
38 |
39 |
def get_position_vec(positionList):
    """Normalize symbol bounding boxes [x1, x2, y1, y2] by the page extent.

    Each box becomes [x1/xMax, x2/xMax, y1/yMax, y2/yMax, xMax/yMax],
    where xMax/yMax are the largest x2/y2 over all boxes; the last entry
    encodes the overall aspect ratio and is the same for every box.
    """
    xMax = max(box[1] for box in positionList)
    yMax = max(box[3] for box in positionList)
    return [
        [box[0] / xMax, box[1] / xMax, box[2] / yMax, box[3] / yMax, xMax / yMax]
        for box in positionList
    ]
54 |
55 |
if __name__ == '__main__':

    parameters = process_args()

    warnings.filterwarnings('ignore')

    # formulas.txt: "<id>\t<formula>" per line -> id -> formula lookup.
    labelIndexDic = {}
    for line in open(parameters.formulas_file_path, encoding='utf-8').readlines():
        fields = line.strip().split('\t')
        labelIndexDic[fields[0]] = fields[1]

    # test.txt: one image/formula id per line.
    f3 = open(parameters.test_file_path,
              encoding='utf-8').readlines()

    testLabelList = [labelIndexDic[line.strip()] for line in f3 if len(line) > 0]

    MAXLENGTH = 150  # NOTE(review): defined but unused in this script

    vocab_lines = open(parameters.vocab_file_path, encoding='utf-8').readlines()

    # Reserved ids: 0 = PAD, 1 = START, 2 = END; real vocabulary starts at 3.
    PAD = 0
    START = 1
    END = 2

    index_label_dic = {}
    label_index_dic = {}

    i = 3
    for vline in vocab_lines:
        word = vline.strip()
        if len(word) > 0:
            label_index_dic[word] = i
            index_label_dic[i] = word
            i += 1
    label_index_dic['unk'] = i  # out-of-vocabulary token
    index_label_dic[i] = 'unk'
    i += 1

    # Teacher-forcing input starts with START (1); the prediction target
    # ends with END (2).
    labelEmbed_teaching_test = []
    labelEmbed_predict_test = []

    for item_l in testLabelList:
        token_ids = [label_index_dic.get(w, label_index_dic['unk'])
                     for w in item_l.strip().split(' ') if len(w) > 0]
        labelEmbed_teaching_test.append([1] + token_ids)
        labelEmbed_predict_test.append(token_ids + [2])

    # NOTE(review): these arrays are built but never saved by this script.
    labelEmbed_teachingArray_test = np.array(labelEmbed_teaching_test)
    labelEmbed_predictArray_test = np.array(labelEmbed_predict_test)

    # Segment every test image into symbol candidates at threshold 160,
    # deduplicating candidates that share a bounding box.
    testData = []
    testPosition = []

    for item_f3 in tqdm(f3):
        img = Image.open(parameters.img_path +
                         item_f3.strip() + ".png").convert('L')
        img = Image.fromarray(ImgCandidate.deletePadding(np.array(img)))
        imgInfo = []
        positionInfo = []
        for t in [160]:
            for cand in ImgCandidate.getAllCandidate(img, t):
                if cand[1] not in positionInfo:
                    imgInfo.append(cand[0])
                    positionInfo.append(cand[1])
        testData.append(imgInfo)
        testPosition.append(get_position_vec(positionInfo))

    np.save(root_path + '/data/preprocess_data/testData_160', np.array(testData))
    np.save(root_path + '/data/preprocess_data/testPosition_160', np.array(testPosition))
149 |
150 | #
151 |
--------------------------------------------------------------------------------
/scrips/preprocessing/preprocess_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import warnings
4 | from PIL import Image
5 |
6 | import argparse
7 | import sys
8 | import os
9 |
10 | root_path = os.path.abspath(__file__)
11 | root_path = '/'.join(root_path.split('/')[:-3])
12 | sys.path.append(root_path)
13 |
14 | from src.utils import ImgCandidate
15 |
16 | warnings.filterwarnings('ignore')
17 |
18 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with the four required paths:
    formulas_file_path, train_file_path, vocab_file_path, img_path.
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument('--formulas', dest='formulas_file_path',
                        type=str, required=True,
                        help='Input formulas.txt path')
    parser.add_argument('--train', dest='train_file_path',
                        type=str, required=True,
                        help='Input train.txt path')
    parser.add_argument('--vocab', dest='vocab_file_path',
                        type=str, required=True,
                        help='Input latex_vocab.txt path')
    parser.add_argument('--img', dest='img_path',
                        type=str, required=True,
                        help='Input image path')
    return parser.parse_args()
40 |
41 |
42 |
43 |
def get_position_vec(positionList):
    """Normalize symbol bounding boxes [x1, x2, y1, y2] by the page extent.

    Each box becomes [x1/xMax, x2/xMax, y1/yMax, y2/yMax, xMax/yMax],
    where xMax/yMax are the largest x2/y2 over all boxes; the last entry
    encodes the overall aspect ratio and is the same for every box.
    """
    xMax = max(box[1] for box in positionList)
    yMax = max(box[3] for box in positionList)
    return [
        [box[0] / xMax, box[1] / xMax, box[2] / yMax, box[3] / yMax, xMax / yMax]
        for box in positionList
    ]
61 |
62 |
63 |
if __name__ == '__main__':

    parameters = process_args()

    # formulas.txt: "<id>\t<formula>" per line -> id -> formula lookup.
    labelIndexDic = {}
    for line in open(parameters.formulas_file_path, encoding='utf-8').readlines():
        fields = line.strip().split('\t')
        labelIndexDic[fields[0]] = fields[1]

    # train.txt: one image/formula id per line.
    f2 = open(parameters.train_file_path,
              encoding='utf-8').readlines()

    trainLabelList = [labelIndexDic[line.strip()] for line in f2 if len(line) > 0]

    MAXLENGTH = 150  # NOTE(review): defined but unused in this script

    vocab_lines = open(parameters.vocab_file_path, encoding='utf-8').readlines()

    # Reserved ids: 0 = PAD, 1 = START, 2 = END; real vocabulary starts at 3.
    PAD = 0
    START = 1
    END = 2

    index_label_dic = {}
    label_index_dic = {}

    i = 3
    for vline in vocab_lines:
        word = vline.strip()
        if len(word) > 0:
            label_index_dic[word] = i
            index_label_dic[i] = word
            i += 1
    label_index_dic['unk'] = i  # out-of-vocabulary token
    index_label_dic[i] = 'unk'
    i += 1

    # Teacher-forcing input starts with START (1); the prediction target
    # ends with END (2).
    labelEmbed_teaching_train = []
    labelEmbed_predict_train = []
    for item_l in trainLabelList:
        token_ids = [label_index_dic.get(w, label_index_dic['unk'])
                     for w in item_l.strip().split(' ') if len(w) > 0]
        labelEmbed_teaching_train.append([1] + token_ids)
        labelEmbed_predict_train.append(token_ids + [2])

    # Global symbol-patch vocabulary: identical patches (compared by their
    # pixel content, via str(patch.tolist())) are stored once in imgList
    # and referenced by index in trainData.
    imgDic = {}
    imgList = []

    trainData = []
    trainPosition = []
    trainIndexList = []

    for i in tqdm(range(len(f2))):
        item_f2 = f2[i]
        img = Image.open(parameters.img_path +
                         item_f2.strip() + ".png").convert('L')
        img = Image.fromarray(ImgCandidate.deletePadding(np.array(img)))

        # Threshold 160 is always kept; 180/200 act as segmentation-based
        # augmentation and are kept only when they yield a candidate layout
        # not already recorded.
        for t in [160, 180, 200]:
            candidates = ImgCandidate.getAllCandidate(img, t)
            positionInfo = [cand[1] for cand in candidates]
            positionVec = get_position_vec(positionInfo)

            if t != 160 and positionVec in trainPosition:
                continue

            trainIndexList.append(i)
            trainPosition.append(positionVec)

            imgInfo = []
            for cand in candidates:
                key = str(cand[0].tolist())
                if key not in imgDic:
                    imgDic[key] = len(imgDic)
                    imgList.append(cand[0])
                imgInfo.append(imgDic[key])
            trainData.append(imgInfo)

    trainData = np.array(trainData)
    trainPosition = np.array(trainPosition)
    trainIndexList = np.array(trainIndexList)
    imgList = np.array(imgList)

    np.save(root_path + '/data/preprocess_data/trainData', np.array(trainData))
    np.save(root_path + '/data/preprocess_data/trainPosition', np.array(trainPosition))
    np.save(root_path + '/data/preprocess_data/trainIndexList', np.array(trainIndexList))
    np.save(root_path + '/data/preprocess_data/trainImgList', np.array(imgList))
195 |
196 |
--------------------------------------------------------------------------------
/scrips/preprocessing/preprocess_val.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | import warnings
4 | from PIL import Image
5 |
6 | import argparse
7 | import sys
8 | import os
9 |
10 | root_path = os.path.abspath(__file__)
11 | root_path = '/'.join(root_path.split('/')[:-3])
12 | sys.path.append(root_path)
13 |
14 | from src.utils import ImgCandidate
15 |
16 |
def process_args():
    """Parse the command line.

    Returns an argparse.Namespace with the four required paths:
    formulas_file_path, val_file_path, vocab_file_path, img_path.
    """
    parser = argparse.ArgumentParser(description='Get parameters')
    parser.add_argument('--formulas', dest='formulas_file_path',
                        type=str, required=True,
                        help='Input formulas.txt path')
    parser.add_argument('--val', dest='val_file_path',
                        type=str, required=True,
                        help='Input val.txt path')
    parser.add_argument('--vocab', dest='vocab_file_path',
                        type=str, required=True,
                        help='Input latex_vocab.txt path')
    parser.add_argument('--img', dest='img_path',
                        type=str, required=True,
                        help='Input image path')
    return parser.parse_args()
38 |
39 |
def get_position_vec(positionList):
    """Normalize symbol bounding boxes [x1, x2, y1, y2] by the page extent.

    Each box becomes [x1/xMax, x2/xMax, y1/yMax, y2/yMax, xMax/yMax],
    where xMax/yMax are the largest x2/y2 over all boxes; the last entry
    encodes the overall aspect ratio and is the same for every box.
    """
    xMax = max(box[1] for box in positionList)
    yMax = max(box[3] for box in positionList)
    return [
        [box[0] / xMax, box[1] / xMax, box[2] / yMax, box[3] / yMax, xMax / yMax]
        for box in positionList
    ]
54 |
55 |
if __name__ == '__main__':

    parameters = process_args()

    warnings.filterwarnings('ignore')

    # formulas.txt: "<id>\t<formula>" per line -> id -> formula lookup.
    labelIndexDic = {}
    for line in open(parameters.formulas_file_path, encoding='utf-8').readlines():
        fields = line.strip().split('\t')
        labelIndexDic[fields[0]] = fields[1]

    # val.txt: one image/formula id per line.
    f3 = open(parameters.val_file_path,
              encoding='utf-8').readlines()

    valLabelList = [labelIndexDic[line.strip()] for line in f3 if len(line) > 0]

    MAXLENGTH = 150  # NOTE(review): defined but unused in this script

    vocab_lines = open(parameters.vocab_file_path, encoding='utf-8').readlines()

    # Reserved ids: 0 = PAD, 1 = START, 2 = END; real vocabulary starts at 3.
    PAD = 0
    START = 1
    END = 2

    index_label_dic = {}
    label_index_dic = {}

    i = 3
    for vline in vocab_lines:
        word = vline.strip()
        if len(word) > 0:
            label_index_dic[word] = i
            index_label_dic[i] = word
            i += 1
    label_index_dic['unk'] = i  # out-of-vocabulary token
    index_label_dic[i] = 'unk'
    i += 1

    # Teacher-forcing input starts with START (1); the prediction target
    # ends with END (2).
    labelEmbed_teaching_val = []
    labelEmbed_predict_val = []

    for item_l in valLabelList:
        token_ids = [label_index_dic.get(w, label_index_dic['unk'])
                     for w in item_l.strip().split(' ') if len(w) > 0]
        labelEmbed_teaching_val.append([1] + token_ids)
        labelEmbed_predict_val.append(token_ids + [2])

    # NOTE(review): these arrays are built but never saved by this script.
    labelEmbed_teachingArray_val = np.array(labelEmbed_teaching_val)
    labelEmbed_predictArray_val = np.array(labelEmbed_predict_val)

    # Segment every validation image into symbol candidates at threshold
    # 160, deduplicating candidates that share a bounding box.
    valData = []
    valPosition = []

    for item_f3 in tqdm(f3):
        img = Image.open(parameters.img_path +
                         item_f3.strip() + ".png").convert('L')
        img = Image.fromarray(ImgCandidate.deletePadding(np.array(img)))
        imgInfo = []
        positionInfo = []
        for t in [160]:
            for cand in ImgCandidate.getAllCandidate(img, t):
                if cand[1] not in positionInfo:
                    imgInfo.append(cand[0])
                    positionInfo.append(cand[1])
        valData.append(imgInfo)
        valPosition.append(get_position_vec(positionInfo))

    np.save(root_path + '/data/preprocess_data/valData_160', np.array(valData))
    np.save(root_path + '/data/preprocess_data/valPosition_160', np.array(valPosition))
149 |
150 | #
151 |
--------------------------------------------------------------------------------
/src/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/.DS_Store
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/__init__.py
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/transformers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import math, copy
6 | import seaborn
7 |
8 | seaborn.set_context(context="talk")
9 |
10 |
class EncoderDecoder(nn.Module):
    """
    A standard Encoder-Decoder architecture: projected symbol features are
    fused with position embeddings, encoded into a memory, then decoded.
    """

    def __init__(self, encoder2, decoder, src_position_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()

        self.encoder2 = encoder2
        self.decoder = decoder
        self.src_position_embed = src_position_embed
        self.tgt_embed = tgt_embed
        self.generator = generator
        # Project raw 576-dim source features down to the 256-dim model width.
        self.src_proj = nn.Linear(576, 256)

    def forward(self, src, src_position, tgt, src_mask, tgt_mask):
        # Fuse projected symbol features with their position embedding,
        # encode, then decode against the resulting memory.
        projected = F.relu(self.src_proj(src))
        pos_embed = self.src_position_embed(src_position)
        memory = self.encode2(pos_embed + projected, src_mask, pos_embed)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode1(self, src, src_mask):
        # NOTE(review): self.encoder1 is never assigned in __init__, so
        # calling this raises AttributeError — appears to be dead code.
        return self.encoder1(src, src_mask)

    def encode2(self, src, src_mask, pos):
        return self.encoder2(src, src_mask, pos)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        return self.decoder(self.tgt_embed(tgt), memory, src_mask, tgt_mask)
46 |
47 |
class Generator(nn.Module):
    """Final linear projection followed by log-softmax over the vocabulary."""

    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        # Maps decoder states (width d_model) to vocabulary logits.
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        logits = self.proj(x)
        return F.log_softmax(logits, dim=-1)
57 |
58 |
def clones(module, N):
    """Return a ModuleList holding N independent deep copies of *module*."""
    layers = nn.ModuleList()
    for _ in range(N):
        layers.append(copy.deepcopy(module))
    return layers
62 |
63 |
class Encoder(nn.Module):
    """Stack of N identical encoder layers with a pairwise positional bias.

    Before the layer stack runs, a scalar bias ``posr`` is computed for every
    pair of source positions from their position embeddings; each layer adds
    this bias to its attention scores and returns (output, attention map).
    """

    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        # N independent deep copies of the prototype layer.
        self.layers = nn.ModuleList(copy.deepcopy(layer) for _ in range(N))
        self.norm = nn.LayerNorm(layer.size)

        # Two-layer MLP producing one scalar per position pair (512 = 2 * 256).
        self.pj1 = nn.Linear(512, 256)
        self.pj2 = nn.Linear(256, 1)

    def forward(self, x, mask, pos):
        """Encode x and return the layer-normalized output."""
        seq = pos.size(1)
        # Build all (i, j) pairs of position embeddings: (B, seq, seq, 512).
        pos_a = pos.unsqueeze(1).repeat(1, seq, 1, 1)
        pos_b = pos.unsqueeze(2).repeat(1, 1, seq, 1)
        pair = torch.cat((pos_a, pos_b), dim=-1)
        posr = self.pj2(F.relu(self.pj1(pair))).squeeze(-1)

        attn = None
        for layer in self.layers:
            x, attn = layer(x, mask, posr)

        return self.norm(x)
87 |
88 |
class SublayerConnection(nn.Module):
    """Pre-norm residual wrapper: x + dropout(sublayer(norm(x))).

    Note the layer norm is applied before the sublayer (pre-LN), not after.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        """Apply *sublayer* to the normalized input and add the residual."""
        normed = self.norm(x)
        residual = self.dropout(sublayer(normed))
        return x + residual
103 |
104 |
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention (with positional bias) then FFN.

    Unlike the canonical layer, forward also returns the attention map
    captured from the self-attention module.
    """

    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 2)
        self.size = size

    def forward(self, x, mask, posr):
        """Return (layer output, attention weights)."""
        attend = lambda h: self.self_attn(h, h, h, mask, posr)
        x = self.sublayer[0](x, attend)
        # MultiHeadedAttention stashes its latest weights on .attn.
        attn = self.self_attn.attn
        out = self.sublayer[1](x, self.feed_forward)
        return out, attn
121 |
122 |
class Decoder(nn.Module):
    """Generic N-layer decoder with masking and learned position embeddings.

    Target position ids (0..seq_len-1) are embedded via ``self.count`` and
    added to the input before the layer stack; sequences up to 200 tokens
    are supported (the embedding table size).
    """

    def __init__(self, layer, N):
        super(Decoder, self).__init__()
        # N independent deep copies of the prototype decoder layer.
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])
        self.norm = nn.LayerNorm(layer.size)

        # Learned absolute position embedding for target positions (max 200).
        self.count = nn.Embedding(200, 256)

    def forward(self, x, memory, src_mask, tgt_mask):
        # Position ids 0..L-1 for every batch row, created on x's device.
        # (Previously built with torch.Tensor(range(...)).cuda(), which broke
        # CPU execution; torch.arange with device= is equivalent on GPU.)
        x_pos = torch.arange(x.size(1), device=x.device).unsqueeze(0).repeat(x.size(0), 1)
        x = x + self.count(x_pos)

        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)

        return self.norm(x)
155 |
156 |
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, source attention, then FFN."""

    def __init__(self, size, self_attn, src_attn, feed_forward, dropout):
        super(DecoderLayer, self).__init__()
        self.size = size
        self.self_attn = self_attn
        self.src_attn = src_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(SublayerConnection(size, dropout), 3)

    def forward(self, x, memory, src_mask, tgt_mask):
        """Apply the three sublayers in order and return the result."""
        mem = memory
        # 1) masked self-attention over the (shifted) target sequence
        x = self.sublayer[0](x, lambda h: self.self_attn(h, h, h, tgt_mask))
        # 2) attention over the encoder memory
        x = self.sublayer[1](x, lambda h: self.src_attn(h, mem, mem, src_mask))
        # 3) position-wise feed-forward
        return self.sublayer[2](x, self.feed_forward)
175 |
176 |
def subsequent_mask(size):
    """Boolean (1, size, size) mask that is True at or below the diagonal."""
    allowed = torch.tril(torch.ones(size, size, dtype=torch.uint8))
    return allowed.unsqueeze(0) == 1
182 |
183 |
def attention(query, key, value, mask=None, dropout=None, posr=None):
    """Scaled dot-product attention with an optional additive position bias.

    Returns (output, attention_weights).
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Add the pairwise positional bias (broadcast over heads) when provided.
    if posr is not None:
        scores = scores + posr.unsqueeze(1)

    # Masked-out positions get a large negative score before the softmax.
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    output = torch.matmul(weights, value)
    return output, weights
203 |
204 |
class MultiHeadedAttention(nn.Module):
    """Multi-head attention that also stores its last attention map in .attn."""

    def __init__(self, h, d_model, dropout):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # Per-head width; d_v is assumed equal to d_k.
        self.d_k = d_model // h
        self.h = h
        # Three input projections (q, k, v) plus the output projection.
        self.linears = clones(nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None, posr=None):
        """Project, attend per head, then concatenate heads and project back."""
        if mask is not None:
            # The same mask is broadcast across all h heads.
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # 1) Linear projections, reshaped to (batch, heads, seq, d_k).
        projected = []
        for proj, tensor in zip(self.linears, (query, key, value)):
            heads = proj(tensor).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            projected.append(heads)
        query, key, value = projected

        # 2) Scaled dot-product attention over every head at once.
        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout, posr=posr)

        # 3) Re-concatenate the heads and apply the final projection.
        x = x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)
239 |
240 |
class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward net (d_model -> d_ff -> d_model)."""

    def __init__(self, d_model, d_ff, dropout):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Note: dropout is applied between the first linear and the ReLU,
        # matching the original implementation's ordering.
        hidden = self.dropout(self.w_1(x))
        return self.w_2(F.relu(hidden))
252 |
253 |
class Embeddings(nn.Module):
    """Token embedding scaled by sqrt(d_model); accepts float or int indices."""

    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model
        # Pre-compute the scale so forward does not redo the sqrt each call.
        self.scale = math.sqrt(d_model)

    def forward(self, x):
        # Indices may arrive as floats (e.g. from padded arrays), hence .long().
        return self.lut(x.long()) * self.scale
263 |
264 |
265 |
class EncoderPositionalEmbedding(nn.Module):
    """MLP embedding of 5-dim symbol position features into d_model vectors.

    Inputs are the normalized box features [x1/xMax, x2/xMax, y1/yMax,
    y2/yMax, xMax/yMax] built in the preprocessing/evaluation scripts.
    The output width now honours *dmodel*: the previous implementation
    ignored the argument and always emitted 256, which only worked because
    the model is built with d_model=256.
    """

    def __init__(self, dmodel):
        super(EncoderPositionalEmbedding, self).__init__()

        self.fc1 = nn.Linear(5, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, dmodel)  # was hard-coded to 256

    def forward(self, encoder_position):
        """Return a (..., dmodel) non-negative embedding of position features."""
        e = F.relu(self.fc1(encoder_position))
        e = F.relu(self.fc2(e))
        e = F.relu(self.fc3(e))

        return e
280 |
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | import argparse
7 | import os
8 | import random
9 | from torch.autograd import Variable
10 | from tqdm import tqdm
11 | import warnings
12 | import sys
13 |
# Repository root = two directory levels above this file (src/test.py).
# NOTE(review): splitting on '/' assumes POSIX paths; os.path.dirname would be
# portable — confirm Windows is out of scope before changing.
root_path = os.path.abspath(__file__)
root_path = '/'.join(root_path.split('/')[:-2])

# Make `from src...` absolute imports work when this file runs as a script.
sys.path.append(root_path)

from src.model import transformers

# Silence library deprecation noise during evaluation runs.
warnings.filterwarnings('ignore')
22 |
23 |
24 |
def getPositionVec(positionList):
    """Normalize [x1, x2, y1, y2] symbol boxes by the page extents.

    Each entry becomes [x1/xMax, x2/xMax, y1/yMax, y2/yMax, xMax/yMax],
    where xMax/yMax are the maxima of x2/y2 over the whole list.
    """
    xMax = max(box[1] for box in positionList)
    yMax = max(box[3] for box in positionList)
    aspect = xMax / yMax

    return [
        [box[0] / xMax, box[1] / xMax, box[2] / yMax, box[3] / yMax, aspect]
        for box in positionList
    ]
39 |
40 |
def getBatchIndex(length, batchSize, shuffle=True):
    """Split indices 0..length-1 into batches of *batchSize*.

    The final batch may be smaller; with shuffle=True the index order is
    randomized first (via random.shuffle).
    """
    order = list(range(length))
    if shuffle:
        random.shuffle(order)

    return [order[i:i + batchSize] for i in range(0, length, batchSize)]
55 |
56 |
def make_tgt_mask(tgt, pad):
    """Combine the padding mask with the causal mask for teacher forcing."""
    pad_mask = (tgt != pad).unsqueeze(-2)
    causal = subsequent_mask(tgt.size(-1)).type_as(pad_mask.data)
    return pad_mask & Variable(causal)


def subsequent_mask(size):
    """Boolean (1, size, size) mask, True at or below the diagonal."""
    lower = torch.tril(torch.ones(size, size, dtype=torch.uint8))
    return lower.unsqueeze(0) == 1
70 |
71 |
72 | ####Transformer####
73 |
74 |
def make_model(tgt_vocab, encoderN=6, decoderN=6,
               d_model=256, d_ff=1024, h=8, dropout=0.0):
    """Construct the EDSL encoder-decoder from hyperparameters.

    Bug fix: the previous version called transformers.PositionalEncoding,
    which does not exist in src/model/transformers.py, so make_model raised
    AttributeError as soon as it ran; the result was never used either.
    The line is removed, matching the working copy in src/train.py.
    """
    c = copy.deepcopy
    attn = transformers.MultiHeadedAttention(h, d_model, dropout)
    ff = transformers.PositionwiseFeedForward(d_model, d_ff, dropout)
    model = transformers.EncoderDecoder(
        transformers.Encoder(transformers.EncoderLayer(d_model, c(attn), c(ff), dropout),
                             encoderN),
        transformers.Decoder(transformers.DecoderLayer(d_model, c(attn), c(attn),
                                                       c(ff), dropout), decoderN),
        transformers.EncoderPositionalEmbedding(d_model),
        transformers.Embeddings(d_model, tgt_vocab),
        transformers.Generator(d_model, tgt_vocab),
    )

    # Xavier-initialize all matrix-shaped parameters; vectors keep defaults.
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    return model
96 |
97 |
def greedy_decode(model, src, src_position, src_mask, max_len):
    # Greedy (argmax) decoding: encode the source once, then extend the target
    # one token at a time, always picking the highest-probability word.
    # Returns the raw vocabulary logits (generator.proj output, pre-softmax)
    # for the final decoder states.
    src = F.relu(model.src_proj(src))
    src_position_embed = model.src_position_embed(src_position)
    # src_position_embed = model.encode1(src_position_embed, src_mask)

    src_embed = src_position_embed + src

    memory = model.encode2(src_embed, src_mask, src_position_embed)

    # Every sequence starts with token id 1 (the START symbol).
    # NOTE(review): requires CUDA — .cuda() is hard-coded here.
    lastWord = torch.ones(len(src), 1).cuda().long()
    for i in range(max_len):
        # Causal mask sized to the tokens generated so far.
        tgt_mask = Variable(subsequent_mask(lastWord.size(1)).type_as(src.data))
        tgt_mask = tgt_mask.repeat(src.size(0), 1, 1)
        out = model.decode(memory, src_mask, Variable(lastWord), tgt_mask)
        # Only the newest position's distribution is needed for the next token.
        prob = model.generator(out[:, -1, :].squeeze(0)).unsqueeze(1)
        _, predictTmp = prob.max(dim=-1)
        lastWord = torch.cat((lastWord, predictTmp), dim=-1)
    # No early stop at the END token: the loop always runs max_len steps.
    prob = model.generator.proj(out)

    return prob
118 |
119 |
class SimpleLossCompute:
    """Applies the generator to decoder states and computes a normalized loss."""

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        """Return criterion(generator(x), y) divided by *norm*."""
        logits = self.generator(x)
        flat_logits = logits.contiguous().view(-1, logits.size(-1))
        flat_targets = y.contiguous().view(-1)
        return self.criterion(flat_logits, flat_targets) / norm
133 |
134 |
class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        # size: vocabulary size (width of the model's log-prob output).
        # padding_idx: class id that receives no probability mass.
        # smoothing: total mass spread over non-target, non-pad classes.
        super(LabelSmoothing, self).__init__()
        # Sum (not mean) KL divergence so the caller can normalize externally.
        self.criterion = nn.KLDivLoss(size_average=False)
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        # x: (tokens, size) log-probabilities; target: (tokens,) class ids.
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        # Spread `smoothing` uniformly over the size-2 classes that are
        # neither the true label nor the padding index.
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero out entire rows whose target is padding so they add no loss.
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist  # kept around for inspection/debugging
        return self.criterion(x, Variable(true_dist, requires_grad=False))
158 |
159 |
class VGGModel(nn.Module):
    """VGG-style CNN over single-channel symbol patches.

    Three conv-conv-pool stages (1->32->32, 32->64->64, 64->64->64); for a
    30x30 input the flattened output is 64*3*3 = 576 features per patch.
    Attribute names are kept exactly as before so saved state dicts
    (data/model/encoder.pkl) still load.
    """

    def __init__(self):
        super(VGGModel, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(32)
        self.mp1 = nn.MaxPool2d(2, 2)

        self.conv2_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(64)
        self.mp2 = nn.MaxPool2d(2, 2)

        self.conv3_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(64)
        self.conv3_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(64)
        self.mp3 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return flattened (batch, features) activations for (batch, 1, H, W) input."""
        stages = (
            (self.conv1_1, self.bn1_1, self.conv1_2, self.bn1_2, self.mp1),
            (self.conv2_1, self.bn2_1, self.conv2_2, self.bn2_2, self.mp2),
            (self.conv3_1, self.bn3_1, self.conv3_2, self.bn3_2, self.mp3),
        )
        for conv_a, bn_a, conv_b, bn_b, pool in stages:
            x = F.relu(bn_a(conv_a(x)))
            x = F.relu(bn_b(conv_b(x)))
            x = pool(x)

        return x.view(x.size(0), -1)
199 |
200 |
def process_args():
    """Parse the command-line arguments required by the test script."""
    parser = argparse.ArgumentParser(description='Get parameters')

    required = (
        ('--formulas', 'formulas_file_path', 'Input formulas.txt path'),
        ('--test', 'test_file_path', 'Input test.txt path'),
        ('--vocab', 'vocab_file_path', 'Input latex_vocab.txt path'),
    )
    for flag, dest, helptext in required:
        parser.add_argument(flag, dest=dest, type=str, required=True,
                            help=helptext)

    return parser.parse_args()
218 |
219 |
220 |
if __name__ == '__main__':

    parameters = process_args()

    # formulas.txt maps "<image id>\t<latex string>" per line.
    f = open(parameters.formulas_file_path, encoding='utf-8').readlines()
    labelIndexDic = {}
    for item_f in f:
        labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1]

    f3 = open(parameters.test_file_path, encoding='utf-8').readlines()  # [:100]

    # Ground-truth latex strings for every test image id.
    valLabelList = []
    for item_f3 in f3:
        if len(item_f3) > 0:
            valLabelList.append(labelIndexDic[item_f3.strip()])

    MAXLENGTH = 150

    f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines()

    # Reserved token ids; vocabulary words start at 3, then one 'unk' id.
    PAD = 0
    START = 1
    END = 2

    index_label_dic = {}
    label_index_dic = {}

    i = 3
    for item_f5 in f5:
        word = item_f5.strip()
        if len(word) > 0:
            label_index_dic[word] = i
            index_label_dic[i] = word
            i += 1
    label_index_dic['unk'] = i
    index_label_dic[i] = 'unk'
    i += 1

    # Teacher-forcing targets are START + tokens; prediction targets are
    # tokens + END.
    labelEmbed_teaching_val = []
    labelEmbed_predict_val = []

    for item_l in valLabelList:
        tmp = [1]
        words = item_l.strip().split(' ')
        for item_w in words:
            if len(item_w) > 0:
                if item_w in label_index_dic:
                    tmp.append(label_index_dic[item_w])
                else:
                    tmp.append(label_index_dic['unk'])

        labelEmbed_teaching_val.append(tmp)

        tmp = []
        words = item_l.strip().split(' ')
        for item_w in words:
            if len(item_w) > 0:
                if item_w in label_index_dic:
                    tmp.append(label_index_dic[item_w])
                else:
                    tmp.append(label_index_dic['unk'])

        tmp.append(2)
        labelEmbed_predict_val.append(tmp)

    labelEmbed_teachingArray_val = np.array(labelEmbed_teaching_val)
    labelEmbed_predictArray_val = np.array(labelEmbed_predict_val)

    # Preprocessed symbol patches and position vectors (see scrips/preprocessing).
    valDataArray = np.load(root_path + '/data/preprocess_data/testData_160.npy')  # .tolist()
    valPositionArray = np.load(root_path + '/data/preprocess_data/testPosition_160.npy')  # .tolist()

    # Count non-pad tokens per target, then sort all arrays by target length
    # so batches contain similarly-sized sequences.
    labelLenListVal = []
    for item_pv in labelEmbed_predictArray_val:
        count = 0
        for item in item_pv:
            if item != 0:
                count += 1
        labelLenListVal.append(count)

    valLabelIndexOrderByLen = np.argsort(np.array(labelLenListVal)).tolist()

    valDataArray = valDataArray[valLabelIndexOrderByLen].tolist()
    valPositionArray = valPositionArray[valLabelIndexOrderByLen].tolist()

    labelEmbed_teachingArray_val = labelEmbed_teachingArray_val[valLabelIndexOrderByLen]

    labelEmbed_predictArray_val = labelEmbed_predictArray_val[valLabelIndexOrderByLen]

    BATCH_SIZE = 10

    #####Regularization Parameters####
    dropout = 0.2
    l2 = 1e-4
    #################


    # Vocabulary size = words + 'unk' is already in index_label_dic; +3 for
    # PAD/START/END.
    model = make_model(len(index_label_dic) + 3, encoderN=8, decoderN=8,
                       d_model=256, d_ff=1024, dropout=dropout).cuda()
    vgg = VGGModel().cuda()

    # Restore trained weights for both the transformer and the CNN encoder.
    model.load_state_dict(torch.load(root_path + '/data/model/model.pkl'))
    vgg.load_state_dict(torch.load(root_path + '/data/model/encoder.pkl'))

    param = list(model.parameters()) + list(vgg.parameters())

    criterion = LabelSmoothing(size=len(label_index_dic) + 3, padding_idx=0, smoothing=0.1)
    lossComput = SimpleLossCompute(model.generator, criterion)

    # NOTE(review): learningRate/totalCount/exit_count/bestVal/bestTrainList
    # are never used below — leftovers from the training script.
    learningRate = 0.001

    totalCount = 0

    exit_count = 0

    bestVal = 0
    bestTrainList = 0

    f6 = open(root_path + '/data/result/predict.txt', mode='w')

    # Single evaluation pass (the while loop ends in exit()).
    while True:

        model.eval()
        vgg.eval()

        latex_batch_index = getBatchIndex(len(valLabelList), BATCH_SIZE, shuffle=False)

        latexAccListVal = []

        with torch.no_grad():
            for batch_i in tqdm(range(len(latex_batch_index))):

                latex_batch = latex_batch_index[batch_i]

                # Pad every sample in the batch to the longest symbol sequence
                # with blank 30x30 patches, zero positions, and mask 0.
                sourceDataTmp = [copy.copy(valDataArray[item]) for item in latex_batch]
                sourcePositionTmp = [copy.copy(valPositionArray[item]) for item in latex_batch]
                sourceLengthList = [len(item) for item in sourceDataTmp]
                sourceMaskTmp = [item * [1] for item in sourceLengthList]

                sourceLengthMax = max(sourceLengthList)

                for i in range(len(sourceDataTmp)):
                    if len(sourceDataTmp[i]) < sourceLengthMax:
                        while len(sourceDataTmp[i]) < sourceLengthMax:
                            sourceDataTmp[i].append(np.zeros((30, 30)))
                            sourcePositionTmp[i].append([0, 0, 0, 0, 0])
                            sourceMaskTmp[i].append(0)

                sourceDataTmp = np.array(sourceDataTmp)
                sourcePositionTmp = np.array(sourcePositionTmp)
                sourceMaskTmp = np.array(sourceMaskTmp)

                tgt_teaching = labelEmbed_teachingArray_val[latex_batch].tolist()
                tgt_predict = labelEmbed_predictArray_val[latex_batch].tolist()

                tgt_teaching_copy = copy.deepcopy(tgt_teaching)
                tgt_predict_copy = copy.deepcopy(tgt_predict)

                # Pad targets to the longest teacher sequence in the batch.
                tgtMaxBatch = 0
                for item_tgt in tgt_teaching_copy:
                    if len(item_tgt) >= tgtMaxBatch:
                        tgtMaxBatch = len(item_tgt)

                for i in range(len(tgt_teaching_copy)):
                    while len(tgt_teaching_copy[i]) < tgtMaxBatch:
                        tgt_teaching_copy[i].append(0)

                for i in range(len(tgt_predict_copy)):
                    while len(tgt_predict_copy[i]) < tgtMaxBatch:
                        tgt_predict_copy[i].append(0)

                sourceDataTmpArray = torch.from_numpy(sourceDataTmp).cuda().float()
                sourceMaskTmpArray = torch.from_numpy(sourceMaskTmp).cuda().float().unsqueeze(1)
                sourcePositionTmpArray = torch.from_numpy(sourcePositionTmp).cuda().float()

                tgt_teachingArray = torch.from_numpy(np.array(tgt_teaching_copy)).cuda().float()

                tgt_teachingMask = make_tgt_mask(tgt_teachingArray, 0)

                tgt_predictArray = torch.from_numpy(np.array(tgt_predict_copy)).cuda().long()

                # Pixel values scaled to [0, 1]; run the CNN per patch, then
                # reshape back to (batch, seq, 576) symbol features.
                sourceDataTmpArray_input = sourceDataTmpArray.view(-1, 1, 30, 30) / 255
                sourceDataTmpArray = vgg(sourceDataTmpArray_input).view(sourceDataTmpArray.size(0),
                                                                        sourceDataTmpArray.size(1), -1)

                out = greedy_decode(model, sourceDataTmpArray, sourcePositionTmpArray, sourceMaskTmpArray, 200)

                _, latexPredict = out.max(dim=-1)


                for i in range(len(latexPredict)):

                    fileIndex = f3[valLabelIndexOrderByLen[latex_batch[i]]].strip()

                    # Truncate at the first END token (id 2), else at MAXLENGTH.
                    if 2 in latexPredict[i].tolist():
                        endIndex = latexPredict[i].tolist().index(2)
                    else:
                        endIndex = MAXLENGTH

                    predictTmp = latexPredict[i].tolist()[:endIndex]
                    labelTmp = tgt_predictArray[i].tolist()[:endIndex]

                    # Write "<image id>\t<predicted latex>" for the eval scripts.
                    predictStr = ' '.join([index_label_dic[item] for item in predictTmp]).strip()
                    f6.write(fileIndex + '\t' + predictStr + '\n')

                    # Exact-match accuracy on the truncated token sequences.
                    if predictTmp == labelTmp:
                        latexAccListVal.append(1)
                    else:
                        latexAccListVal.append(0)

        valAcc = sum(latexAccListVal) / len(latexAccListVal)
        print(valAcc)
        exit()
433 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | import argparse
7 | import os
8 | import random
9 | from torch.autograd import Variable
10 | from tqdm import tqdm
11 | from torch import optim
12 | from nltk.translate.bleu_score import sentence_bleu
13 | import warnings
14 | import sys
15 |
# Repository root = two directory levels above this file (src/train.py).
# NOTE(review): splitting on '/' assumes POSIX paths; os.path.dirname would be
# portable — confirm Windows is out of scope before changing.
root_path = os.path.abspath(__file__)
root_path = '/'.join(root_path.split('/')[:-2])

# Make `from src...` absolute imports work when this file runs as a script.
sys.path.append(root_path)

from src.model import transformers

# Silence library deprecation noise during training runs.
warnings.filterwarnings('ignore')
24 |
25 |
def setup_seed(seed):
    """Seed every RNG in use (python, numpy, torch CPU + CUDA) for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # safe no-op when no CUDA device exists
    # Force deterministic cuDNN kernels (at some speed cost).
    torch.backends.cudnn.deterministic = True
32 |
33 |
def getBatchIndex(length, batchSize, shuffle=True):
    """Chunk indices 0..length-1 into lists of *batchSize* (last may be short).

    With shuffle=True the indices are randomly permuted before chunking.
    """
    order = list(range(length))
    if shuffle:
        random.shuffle(order)

    return [order[start:start + batchSize]
            for start in range(0, length, batchSize)]
48 |
49 |
def make_tgt_mask(tgt, pad):
    """Mask hiding both padding positions and future (not-yet-decoded) words."""
    not_pad = (tgt != pad).unsqueeze(-2)
    causal = subsequent_mask(tgt.size(-1)).type_as(not_pad.data)
    return not_pad & Variable(causal)


def subsequent_mask(size):
    """Boolean (1, size, size) causal mask, True at or below the diagonal."""
    lower_triangle = torch.tril(torch.ones(size, size, dtype=torch.uint8))
    return lower_triangle.unsqueeze(0) == 1
63 |
64 |
65 | ####Transformer####
def make_model(tgt_vocab, encoderN=6, decoderN=6,
               d_model=256, d_ff=1024, h=8, dropout=0.0):
    """Assemble the EDSL encoder-decoder from hyperparameters.

    tgt_vocab: output vocabulary size (incl. PAD/START/END/unk).
    encoderN/decoderN: number of encoder/decoder layers.
    """
    dup = copy.deepcopy
    attn = transformers.MultiHeadedAttention(h, d_model, dropout)
    ff = transformers.PositionwiseFeedForward(d_model, d_ff, dropout)

    encoder = transformers.Encoder(
        transformers.EncoderLayer(d_model, dup(attn), dup(ff), dropout), encoderN)
    decoder = transformers.Decoder(
        transformers.DecoderLayer(d_model, dup(attn), dup(attn), dup(ff), dropout),
        decoderN)
    model = transformers.EncoderDecoder(
        encoder,
        decoder,
        transformers.EncoderPositionalEmbedding(d_model),
        transformers.Embeddings(d_model, tgt_vocab),
        transformers.Generator(d_model, tgt_vocab),
    )

    # Xavier-initialize every matrix-shaped parameter; vectors keep defaults.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    return model
86 |
87 |
def greedy_decode(model, src, src_position, src_mask, max_len):
    # Greedy (argmax) decoding: encode the source once, then extend the target
    # one token at a time, always picking the highest-probability word.
    # Returns the raw vocabulary logits (generator.proj output, pre-softmax)
    # for the final decoder states.
    src = F.relu(model.src_proj(src))
    src_position_embed = model.src_position_embed(src_position)

    src_embed = src_position_embed + src

    memory = model.encode2(src_embed, src_mask, src_position_embed)

    # Every sequence starts with token id 1 (the START symbol).
    # NOTE(review): requires CUDA — .cuda() is hard-coded here.
    lastWord = torch.ones(len(src), 1).cuda().long()
    for i in range(max_len):
        # Causal mask sized to the tokens generated so far.
        tgt_mask = Variable(subsequent_mask(lastWord.size(1)).type_as(src.data))
        tgt_mask = tgt_mask.repeat(src.size(0), 1, 1)
        out = model.decode(memory, src_mask, Variable(lastWord), tgt_mask)
        # Only the newest position's distribution matters for the next token.
        prob = model.generator(out[:, -1, :].squeeze(0)).unsqueeze(1)
        _, predictTmp = prob.max(dim=-1)
        lastWord = torch.cat((lastWord, predictTmp), dim=-1)
    # No early stop at the END token: the loop always runs max_len steps.
    prob = model.generator.proj(out)

    return prob
107 |
108 |
class SimpleLossCompute:
    """Runs the generator over decoder states and returns a normalized loss."""

    def __init__(self, generator, criterion):
        self.generator = generator
        self.criterion = criterion

    def __call__(self, x, y, norm):
        """Return criterion over flattened (tokens, vocab) logits, divided by *norm*."""
        scores = self.generator(x)
        flattened = scores.contiguous().view(-1, scores.size(-1))
        targets = y.contiguous().view(-1)
        return self.criterion(flattened, targets) / norm
121 |
122 |
class LabelSmoothing(nn.Module):
    "Implement label smoothing."

    def __init__(self, size, padding_idx, smoothing=0.0):
        # size: vocabulary size (width of the model's log-prob output).
        # padding_idx: class id that receives no probability mass.
        # smoothing: total mass spread over non-target, non-pad classes.
        super(LabelSmoothing, self).__init__()
        # Sum (not mean) KL divergence so the caller can normalize externally.
        self.criterion = nn.KLDivLoss(size_average=False)
        self.padding_idx = padding_idx
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None

    def forward(self, x, target):
        # x: (tokens, size) log-probabilities; target: (tokens,) class ids.
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        # Spread `smoothing` uniformly over the size-2 classes that are
        # neither the true label nor the padding index.
        true_dist.fill_(self.smoothing / (self.size - 2))
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        true_dist[:, self.padding_idx] = 0
        # Zero out entire rows whose target is padding so they add no loss.
        mask = torch.nonzero(target.data == self.padding_idx)
        if mask.dim() > 0:
            true_dist.index_fill_(0, mask.squeeze(), 0.0)
        self.true_dist = true_dist  # kept around for inspection/debugging
        return self.criterion(x, Variable(true_dist, requires_grad=False))
146 |
147 |
class Encoder(nn.Module):
    """VGG-style CNN over single-channel symbol patches.

    Three conv-conv-pool stages (1->32->32, 32->64->64, 64->64->64); for a
    30x30 input the flattened output is 64*3*3 = 576 features per patch.
    Attribute names are kept exactly as before so the saved state dict
    (data/model/encoder.pkl) still loads.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(32)
        self.conv1_2 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(32)
        self.mp1 = nn.MaxPool2d(2, 2)

        self.conv2_1 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(64)
        self.conv2_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(64)
        self.mp2 = nn.MaxPool2d(2, 2)

        self.conv3_1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn3_1 = nn.BatchNorm2d(64)
        self.conv3_2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.bn3_2 = nn.BatchNorm2d(64)
        self.mp3 = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return flattened (batch, features) activations for (batch, 1, H, W) input."""
        stages = (
            (self.conv1_1, self.bn1_1, self.conv1_2, self.bn1_2, self.mp1),
            (self.conv2_1, self.bn2_1, self.conv2_2, self.bn2_2, self.mp2),
            (self.conv3_1, self.bn3_1, self.conv3_2, self.bn3_2, self.mp3),
        )
        for conv_a, bn_a, conv_b, bn_b, pool in stages:
            x = F.relu(bn_a(conv_a(x)))
            x = F.relu(bn_b(conv_b(x)))
            x = pool(x)

        return x.view(x.size(0), -1)
187 |
188 |
def process_args():
    """Parse the command-line arguments required by the training script."""
    parser = argparse.ArgumentParser(description='Get parameters')

    required = (
        ('--formulas', 'formulas_file_path', 'Input formulas.txt path'),
        ('--train', 'train_file_path', 'Input train.txt path'),
        ('--val', 'val_file_path', 'Input val.txt path'),
        ('--vocab', 'vocab_file_path', 'Input latex_vocab.txt path'),
    )
    for flag, dest, helptext in required:
        parser.add_argument(flag, dest=dest, type=str, required=True,
                            help=helptext)

    return parser.parse_args()
210 |
211 |
if __name__ == '__main__':

    # ------------------------------------------------------------------
    # Training driver: load preprocessed data, build token encodings,
    # then alternate full training epochs with greedy-decode validation,
    # checkpointing the model with the best validation exact-match ACC.
    # Token-id convention used throughout: PAD=0, START=1, END=2,
    # vocabulary words from 3 upward, 'unk' appended after the vocab.
    # ------------------------------------------------------------------

    parameters = process_args()

    # formulas file: one "<key>\t<formula>" line per formula.
    # NOTE(review): file handles below are never closed — consider `with`.
    f = open(parameters.formulas_file_path, encoding='utf-8').readlines()
    labelIndexDic = {}
    for item_f in f:
        labelIndexDic[item_f.strip().split('\t')[0]] = item_f.strip().split('\t')[1]

    # train/val split files: one formula key per line.
    f2 = open(parameters.train_file_path, encoding='utf-8').readlines()
    f3 = open(parameters.val_file_path, encoding='utf-8').readlines()

    # Resolve keys to raw formula strings for each split.
    trainLabelList = []
    for item_f2 in f2:
        if len(item_f2) > 0:
            trainLabelList.append(labelIndexDic[item_f2.strip()])

    valLabelList = []
    for item_f3 in f3:
        if len(item_f3) > 0:
            valLabelList.append(labelIndexDic[item_f3.strip()])

    # Fallback sequence cutoff when no END token is produced/present.
    MAXLENGTH = 150

    # vocabulary file: one LaTeX token per line.
    f5 = open(parameters.vocab_file_path, encoding='utf-8').readlines()

    # Reserved special token ids (PAD/START/END are assumed by the loops below).
    PAD = 0
    START = 1
    END = 2

    index_label_dic = {}  # id -> token
    label_index_dic = {}  # token -> id

    # Vocabulary ids start at 3 (after the specials); 'unk' gets the last id.
    i = 3
    for item_f5 in f5:
        word = item_f5.strip()
        if len(word) > 0:
            label_index_dic[word] = i
            index_label_dic[i] = word
            i += 1
    label_index_dic['unk'] = i
    index_label_dic[i] = 'unk'
    i += 1

    # Encode training labels twice:
    #  - "teaching" sequence: START + token ids (decoder input, teacher forcing)
    #  - "predict"  sequence: token ids + END   (decoder target)
    labelEmbed_teaching_train = []
    labelEmbed_predict_train = []
    for item_l in trainLabelList:
        tmp = [1]
        words = item_l.strip().split(' ')
        for item_w in words:
            if len(item_w) > 0:
                if item_w in label_index_dic:
                    tmp.append(label_index_dic[item_w])
                else:
                    tmp.append(label_index_dic['unk'])

        labelEmbed_teaching_train.append(tmp)

        tmp = []
        words = item_l.strip().split(' ')
        for item_w in words:
            if len(item_w) > 0:
                if item_w in label_index_dic:
                    tmp.append(label_index_dic[item_w])
                else:
                    tmp.append(label_index_dic['unk'])

        tmp.append(2)
        labelEmbed_predict_train.append(tmp)

    # Same two encodings for the validation labels.
    labelEmbed_teaching_val = []
    labelEmbed_predict_val = []

    for item_l in valLabelList:
        tmp = [1]
        words = item_l.strip().split(' ')
        for item_w in words:
            if len(item_w) > 0:
                if item_w in label_index_dic:
                    tmp.append(label_index_dic[item_w])
                else:
                    tmp.append(label_index_dic['unk'])

        labelEmbed_teaching_val.append(tmp)

        tmp = []
        words = item_l.strip().split(' ')
        for item_w in words:
            if len(item_w) > 0:
                if item_w in label_index_dic:
                    tmp.append(label_index_dic[item_w])
                else:
                    tmp.append(label_index_dic['unk'])

        tmp.append(2)
        labelEmbed_predict_val.append(tmp)

    # Sequences have different lengths, so these become ragged (object) arrays;
    # they are only used for fancy indexing by batch below.
    labelEmbed_teachingArray_val = np.array(labelEmbed_teaching_val)
    labelEmbed_predictArray_val = np.array(labelEmbed_predict_val)

    # Preprocessed training tensors (produced by scrips/preprocessing/*).
    # NOTE(review): exact array shapes/dtypes are defined by the preprocessing
    # scripts, not visible here — presumably per-sample symbol-index lists plus
    # per-symbol position features; confirm against preprocess_train.py.
    trainDataArray = np.load(root_path + '/data/preprocess_data/trainData')
    trainPositionArray = np.load(root_path + '/data/preprocess_data/trainPosition')

    trainIndexList = np.load(root_path + '/data/preprocess_data/trainIndexList').tolist()
    trainImgList = np.load(root_path + '/data/preprocess_data/trainImgList')

    # Re-order the label encodings to match the sample order in trainIndexList.
    labelEmbed_teaching_train_copy = []
    labelEmbed_predict_train_copy = []
    for item_ti in trainIndexList:
        labelEmbed_teaching_train_copy.append(labelEmbed_teaching_train[item_ti])
        labelEmbed_predict_train_copy.append(labelEmbed_predict_train[item_ti])

    # Count non-PAD tokens per training label (effective label length).
    labelLenListTrain = []
    for item_pv in labelEmbed_teaching_train_copy:
        count = 0
        for item in item_pv:
            if item != 0:
                count += 1
        labelLenListTrain.append(count)

    # Sort all training data by label length so batches contain
    # similar-length targets (less padding per batch).
    trainLabelIndexOrderByLen = np.argsort(np.array(labelLenListTrain)).tolist()

    trainDataArray = trainDataArray[trainLabelIndexOrderByLen].tolist()
    trainPositionArray = trainPositionArray[trainLabelIndexOrderByLen].tolist()

    labelEmbed_teachingArray_train = np.array(labelEmbed_teaching_train_copy)[trainLabelIndexOrderByLen]
    labelEmbed_predictArray_train = np.array(labelEmbed_predict_train_copy)[trainLabelIndexOrderByLen]

    valDataArray = np.load(root_path + '/data/preprocess_data/valData_160')  # .tolist()
    valPositionArray = np.load(root_path + '/data/preprocess_data/valPosition_160')  # .tolist()

    # Same length bookkeeping/sorting for the validation split.
    labelLenListVal = []
    for item_pv in labelEmbed_predictArray_val:
        count = 0
        for item in item_pv:
            if item != 0:
                count += 1
        labelLenListVal.append(count)

    valLabelIndexOrderByLen = np.argsort(np.array(labelLenListVal)).tolist()

    valDataArray = valDataArray[valLabelIndexOrderByLen].tolist()
    valPositionArray = valPositionArray[valLabelIndexOrderByLen].tolist()

    labelEmbed_teachingArray_val = labelEmbed_teachingArray_val[valLabelIndexOrderByLen]

    labelEmbed_predictArray_val = labelEmbed_predictArray_val[valLabelIndexOrderByLen]

    BATCH_SIZE = 16

    ##### Regularization parameters #####
    dropout = 0.2
    l2 = 1e-4  # Adam weight decay
    #####################################

    # Transformer (symbol-sequence -> LaTeX tokens); vocab size = words + 3 specials.
    model = make_model(len(index_label_dic) + 3, encoderN=8, decoderN=8,
                       d_model=256, d_ff=1024, dropout=dropout).cuda()
    # CNN that embeds each 30x30 symbol patch before the transformer.
    encoder = Encoder().cuda()

    # Both networks are optimized jointly.
    param = list(model.parameters()) + list(encoder.parameters())

    criterion = LabelSmoothing(size=len(label_index_dic) + 3, padding_idx=0, smoothing=0.1)
    lossComput = SimpleLossCompute(model.generator, criterion)

    learningRate = 3e-4

    totalCount = 0  # epoch counter (printed only)

    exit_count = 0  # epochs since the last val-accuracy improvement

    bestVal = 0
    bestTrainList = 0  # NOTE(review): written once, never read — dead variable?

    # Validation loss ignores PAD positions.
    # NOTE(review): size_average= is deprecated in recent PyTorch;
    # reduction='mean' is the modern equivalent.
    criterionVal = nn.CrossEntropyLoss(ignore_index=0, size_average=True).cuda()

    # One iteration == one full epoch (train + validate). Loop runs until
    # early stopping triggers exit() below.
    while True:

        model.train()
        encoder.train()

        # Rebuilt every epoch so the decayed learningRate takes effect.
        # NOTE(review): this also resets Adam's moment estimates each epoch.
        optimizer = optim.Adam(param, lr=learningRate, weight_decay=l2)

        # Batches over length-sorted data, then shuffle the batch order.
        latex_batch_index = getBatchIndex(len(trainIndexList), BATCH_SIZE, shuffle=False)
        random.shuffle(latex_batch_index)

        lossListTrain = []
        latexAccListTrain = []
        bleuListTrain = []

        for batch_i in tqdm(range(len(latex_batch_index))):

            latex_batch = latex_batch_index[batch_i]

            # Per-sample symbol-index lists, their positions, and a 1-mask
            # over the real (unpadded) symbols.
            sourceDataTmp = [copy.copy(trainDataArray[item]) for item in latex_batch]
            sourcePositionTmp = [copy.copy(trainPositionArray[item]) for item in latex_batch]
            sourceLengthList = [len(item) for item in sourceDataTmp]
            sourceMaskTmp = [item * [1] for item in sourceLengthList]

            # Replace symbol indices with the actual (n, 30, 30) patch stacks.
            sourceDataTmp = [trainImgList[item] for item in sourceDataTmp]

            sourceLengthMax = max(sourceLengthList)

            # Pad every sample to the batch max with blank patches,
            # zero position vectors and mask 0.
            for i in range(len(sourceDataTmp)):
                if len(sourceDataTmp[i]) < sourceLengthMax:
                    while len(sourceDataTmp[i]) < sourceLengthMax:
                        sourceDataTmp[i] = np.concatenate((sourceDataTmp[i], np.zeros((1, 30, 30))), axis=0)
                        sourcePositionTmp[i].append([0, 0, 0, 0, 0])
                        sourceMaskTmp[i].append(0)

            sourceDataTmp = np.array(sourceDataTmp)
            sourcePositionTmp = np.array(sourcePositionTmp)
            sourceMaskTmp = np.array(sourceMaskTmp)

            # Decoder input (START + tokens) and target (tokens + END).
            tgt_teaching = labelEmbed_teachingArray_train[latex_batch].tolist()
            tgt_predict = labelEmbed_predictArray_train[latex_batch].tolist()

            # Deep copies: the padding below must not leak back into the
            # shared label arrays.
            tgt_teaching_copy = copy.deepcopy(tgt_teaching)
            tgt_predict_copy = copy.deepcopy(tgt_predict)

            # Pad both target variants with PAD (0) to the batch max length.
            tgtMaxBatch = 0
            for item_tgt in tgt_teaching_copy:
                if len(item_tgt) >= tgtMaxBatch:
                    tgtMaxBatch = len(item_tgt)

            for i in range(len(tgt_teaching_copy)):
                while len(tgt_teaching_copy[i]) < tgtMaxBatch:
                    tgt_teaching_copy[i].append(0)

            for i in range(len(tgt_predict_copy)):
                while len(tgt_predict_copy[i]) < tgtMaxBatch:
                    tgt_predict_copy[i].append(0)

            sourceDataTmpArray = torch.from_numpy(sourceDataTmp).cuda().float()
            sourceMaskTmpArray = torch.from_numpy(sourceMaskTmp).cuda().float().unsqueeze(1)
            sourcePositionTmpArray = torch.from_numpy(sourcePositionTmp).cuda().float()

            tgt_teachingArray = torch.from_numpy(np.array(tgt_teaching_copy)).cuda().float()

            # Combined padding + subsequent (causal) mask; 0 is the pad id.
            tgt_teachingMask = make_tgt_mask(tgt_teachingArray, 0)

            tgt_predictArray = torch.from_numpy(np.array(tgt_predict_copy)).cuda().long()

            # Normalize pixel values to [0, 1] and flatten the batch/symbol
            # dims so the CNN sees (B*S, 1, 30, 30).
            sourceDataTmpArray_input = sourceDataTmpArray.view(-1, 1, 30, 30) / 255

            # Back to (B, S, feat) symbol embeddings for the transformer.
            sourceDataTmpArray = encoder(sourceDataTmpArray_input).view(sourceDataTmpArray.size(0),
                                                                        sourceDataTmpArray.size(1), -1)

            out = model.forward(sourceDataTmpArray, sourcePositionTmpArray, tgt_teachingArray,
                                sourceMaskTmpArray, tgt_teachingMask)

            # Greedy token choice per position, only for the train metrics.
            _, latexPredict = model.generator(out).max(dim=-1)

            for i in range(len(latexPredict)):

                # Compare up to and including the END (2) token of the target.
                if 2 in tgt_predictArray[i].tolist():
                    endIndex = tgt_predictArray[i].tolist().index(2)
                else:
                    endIndex = MAXLENGTH - 1

                predictTmp = latexPredict[i].tolist()[:endIndex + 1]
                labelTmp = tgt_predictArray[i].tolist()[:endIndex + 1]

                # Exact-match accuracy on the truncated sequences.
                if predictTmp == labelTmp:
                    latexAccListTrain.append(1)
                else:
                    latexAccListTrain.append(0)

                bleuScore = sentence_bleu([labelTmp], predictTmp)
                bleuListTrain.append(bleuScore)

            # Label-smoothed loss, normalized by batch size inside lossComput.
            loss = lossComput(out, tgt_predictArray, len(latex_batch))

            lossListTrain.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Epoch-level training metrics (simple means over batches/samples).
        trainLoss = sum(lossListTrain) / len(lossListTrain)
        trainAcc = sum(latexAccListTrain) / len(latexAccListTrain)
        trainBleu = sum(bleuListTrain) / len(bleuListTrain)

        # ---------------- validation (greedy decoding) ----------------
        model.eval()
        encoder.eval()

        latex_batch_index = getBatchIndex(len(valLabelList), BATCH_SIZE, shuffle=False)

        latexAccListVal = []
        latexLossListVal = []

        with torch.no_grad():
            for batch_i in tqdm(range(len(latex_batch_index))):

                latex_batch = latex_batch_index[batch_i]

                sourceDataTmp = [copy.copy(valDataArray[item]) for item in latex_batch]
                sourcePositionTmp = [copy.copy(valPositionArray[item]) for item in latex_batch]
                sourceLengthList = [len(item) for item in sourceDataTmp]
                sourceMaskTmp = [item * [1] for item in sourceLengthList]

                sourceLengthMax = max(sourceLengthList)

                # Val samples already hold patch arrays (lists), so padding
                # appends blank 30x30 patches directly.
                for i in range(len(sourceDataTmp)):
                    if len(sourceDataTmp[i]) < sourceLengthMax:
                        while len(sourceDataTmp[i]) < sourceLengthMax:
                            sourceDataTmp[i].append(np.zeros((30, 30)))
                            sourcePositionTmp[i].append([0, 0, 0, 0, 0])
                            sourceMaskTmp[i].append(0)

                sourceDataTmp = np.array(sourceDataTmp)
                sourcePositionTmp = np.array(sourcePositionTmp)
                sourceMaskTmp = np.array(sourceMaskTmp)

                tgt_teaching = labelEmbed_teachingArray_val[latex_batch].tolist()
                tgt_predict = labelEmbed_predictArray_val[latex_batch].tolist()

                tgt_teaching_copy = copy.deepcopy(tgt_teaching)
                tgt_predict_copy = copy.deepcopy(tgt_predict)

                # PAD both target variants to the batch max length.
                tgtMaxBatch = 0
                for item_tgt in tgt_teaching_copy:
                    if len(item_tgt) >= tgtMaxBatch:
                        tgtMaxBatch = len(item_tgt)

                for i in range(len(tgt_teaching_copy)):
                    while len(tgt_teaching_copy[i]) < tgtMaxBatch:
                        tgt_teaching_copy[i].append(0)

                for i in range(len(tgt_predict_copy)):
                    while len(tgt_predict_copy[i]) < tgtMaxBatch:
                        tgt_predict_copy[i].append(0)

                sourceDataTmpArray = torch.from_numpy(sourceDataTmp).cuda().float()
                sourceMaskTmpArray = torch.from_numpy(sourceMaskTmp).cuda().float().unsqueeze(1)
                sourcePositionTmpArray = torch.from_numpy(sourcePositionTmp).cuda().float()

                tgt_teachingArray = torch.from_numpy(np.array(tgt_teaching_copy)).cuda().float()

                tgt_teachingMask = make_tgt_mask(tgt_teachingArray, 0)

                tgt_predictArray = torch.from_numpy(np.array(tgt_predict_copy)).cuda().long()

                sourceDataTmpArray_input = sourceDataTmpArray.view(-1, 1, 30, 30) / 255
                sourceDataTmpArray = encoder(sourceDataTmpArray_input).view(sourceDataTmpArray.size(0),
                                                                            sourceDataTmpArray.size(1), -1)

                # No teacher forcing: decode step-by-step up to tgtMaxBatch tokens.
                out = greedy_decode(model, sourceDataTmpArray, sourcePositionTmpArray, sourceMaskTmpArray, tgtMaxBatch)

                _, latexPredict = out.max(dim=-1)

                for i in range(len(latexPredict)):
                    # Truncate at the first END the *prediction* emits.
                    # NOTE(review): unlike training, the slices here exclude
                    # the END position (no +1) — confirm this asymmetry is
                    # intentional.
                    if 2 in latexPredict[i].tolist():
                        endIndex = latexPredict[i].tolist().index(2)
                    else:
                        endIndex = MAXLENGTH

                    predictTmp = latexPredict[i].tolist()[:endIndex]
                    labelTmp = tgt_predictArray[i].tolist()[:endIndex]

                    if predictTmp == labelTmp:
                        latexAccListVal.append(1)
                    else:
                        latexAccListVal.append(0)

                # Flatten (B, T, V) logits and (B, T) targets for CE loss.
                out = out.contiguous().view(-1, out.size(-1))

                targets = tgt_predictArray.view(-1)
                loss = criterionVal(out, targets)

                latexLossListVal.append(loss.item())

        valAcc = sum(latexAccListVal) / len(latexAccListVal)
        valLoss = sum(latexLossListVal) / len(latexLossListVal)

        # Checkpoint on improved validation exact-match accuracy.
        if valAcc > bestVal:
            torch.save(model.state_dict(), root_path + '/data/model/model.pkl')
            torch.save(encoder.state_dict(), root_path + '/data/model/encoder.pkl')
            bestVal = valAcc
            exit_count = 0
        else:
            exit_count += 1

        # LR schedule: halve every 3 consecutive non-improving epochs;
        # stop entirely after 10.
        if exit_count > 0 and exit_count % 3 == 0:
            learningRate *= 0.5

        if exit_count == 10:
            exit()

        print("Epoch:" + str(totalCount) + "\t TrainingSet:" + '\t loss:' + str(trainLoss) + "\t ACC:" + str(
            trainAcc) + "\t BLEU:"
              + str(trainBleu) +
              "\t ValSet:" + "\t ACC:" + str(valAcc) + "\t bestAcc:" + str(
            bestVal) + "\t learningRate:"
              + str(learningRate) + "\t exit_count:"
              + str(exit_count))

        totalCount += 1
--------------------------------------------------------------------------------
/src/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/utils/.DS_Store
--------------------------------------------------------------------------------
/src/utils/ImgCandidate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import copy
4 | import math
5 | from skimage import measure
6 |
7 |
8 |
def deletePadding(imgNp):
    """Crop the white border off a grayscale image.

    Pixels brighter than 220 are snapped to pure white (255) *in place* —
    that thresholding is part of the function's contract for callers that
    reuse ``imgNp``.  Returns the tight sub-image (a view) containing every
    remaining non-white pixel.

    Fix over the original: an all-white input previously raised
    ``ValueError`` (min/max of an empty list); it now returns the
    (thresholded) image unchanged.
    """
    # Normalize near-white noise to pure white, mutating the caller's array.
    imgNp[imgNp > 220] = 255

    binary = imgNp != 255
    # Column/row indices that contain at least one non-white pixel.
    cols = np.flatnonzero(np.sum(binary, axis=0))
    rows = np.flatnonzero(np.sum(binary, axis=1))

    if cols.size == 0 or rows.size == 0:
        # Nothing but white: no bounding box exists, so there is nothing to crop.
        return imgNp

    # Tight bounding box, end indices inclusive.
    return imgNp[rows[0]:rows[-1] + 1, cols[0]:cols[-1] + 1]
24 |
def deleSpace(imgNp):
    """Remove every all-white column from a grayscale image.

    A column is kept iff at least one of its pixels differs from 255.
    Rows are left untouched; the surviving columns keep their order.
    """
    col_hits = np.sum(imgNp != 255, axis=0).tolist()
    kept_cols = [c for c, hits in enumerate(col_hits) if hits != 0]
    return imgNp[:, kept_cols]
37 |
38 |
def imgResize(arr, square=30):
    """Resize a grayscale patch to a `square` x `square` image.

    The patch is scaled (aspect ratio preserved) so its longer side equals
    ``square``, then the shorter side is padded symmetrically with white
    (255) rows/columns.

    NOTE(review): the width-padding branch and the ``return`` were destroyed
    in the source dump (an unescaped ``<``/``>`` span); they are reconstructed
    here as the exact mirror of the surviving height-padding branch — verify
    against the upstream repository.
    """
    a_height = arr.shape[0]
    a_weight = arr.shape[1]

    # Scale factor that maps the longer edge onto `square`.
    ratio = square / max(a_height, a_weight)
    # Image.LANCZOS is the documented identical replacement for the removed
    # Image.ANTIALIAS alias (Pillow >= 2.7; ANTIALIAS dropped in Pillow 10).
    arr = np.array(Image.fromarray(arr).resize(
        (math.ceil(a_weight * ratio), math.ceil(a_height * ratio)), Image.LANCZOS))

    a_height = arr.shape[0]
    a_weight = arr.shape[1]

    if a_height < square:
        # Split the vertical deficit between top (h_1) and bottom (h_2) white bands.
        h_1 = int((square - a_height) / 2)
        h_2 = square - a_height - h_1
        if h_2 != 0:
            arr = np.vstack((np.ones((h_1, a_weight)) * 255, arr, np.ones((h_2, a_weight)) * 255))
        else:
            arr = np.vstack((np.ones((h_1, a_weight)) * 255, arr))
    if a_weight < square:
        # Reconstructed mirror of the height branch: left (w_1) / right (w_2)
        # white bands bring the width up to `square`.
        w_1 = int((square - a_weight) / 2)
        w_2 = square - a_weight - w_1
        if w_2 != 0:
            arr = np.hstack((np.ones((arr.shape[0], w_1)) * 255, arr, np.ones((arr.shape[0], w_2)) * 255))
        else:
            arr = np.hstack((np.ones((arr.shape[0], w_1)) * 255, arr))

    return arr
76 |
77 |
78 | binary = binary * (splitPoint == False)
79 | label = (measure.label(binary, connectivity=2))
80 |
81 |
82 | label_add_split = copy.copy(label)
83 | times = 0
84 | while True in splitPoint and times < 10 :
85 | for item in np.argwhere(splitPoint == True):
86 | top = item[0] - 1 if item[0] - 1 > 0 else 0
87 | down = item[0] + 2 if item[0] + 2 < label.shape[0] else label.shape[0]
88 | left = item[1] - 1 if item[1] - 1 > 0 else 0
89 | right = item[1] + 2 if item[1] + 2 < label.shape[1] else label.shape[1]
90 |
91 | area = label[top:down,left:right].reshape(-1)
92 | count = np.bincount(area)
93 | if len(count) > 1:
94 | count[0] = 0
95 | label_add_split[item[0],item[1]] = np.argmax(count)
96 | splitPoint[item[0],item[1]] = False
97 | label = label_add_split
98 | times += 1
99 |
100 | if True in splitPoint:
101 | for item in np.argwhere(splitPoint == True):
102 | label_add_split[item[0], item[1]] = 0
103 |
104 |
105 | labelCount = np.max(label_add_split)
106 |
107 |
108 | labelPosition = []
109 |
110 | for i in range(labelCount):
111 |
112 | tmpLeft = np.min(np.where(label_add_split == i + 1)[1])
113 | tmpRight = np.max(np.where(label_add_split == i + 1)[1])
114 | tmpTop = np.min(np.where(label_add_split == i + 1)[0])
115 | tmpDown = np.max(np.where(label_add_split == i + 1)[0])
116 |
117 | imgTmp = copy.copy(img)
118 | imgTmp[label_add_split != i+1] = 255
119 | imgTmp = deletePadding(imgTmp)
120 | imgTmp = imgResize(imgTmp,24)
121 |
122 | pad1 = np.ones((3, imgTmp.shape[1])) * 255
123 | imgTmp = np.concatenate((pad1,imgTmp),axis=0)
124 | imgTmp = np.concatenate((imgTmp,pad1),axis=0)
125 |
126 | pad2 = np.ones((imgTmp.shape[0],3)) * 255
127 | imgTmp = np.concatenate((pad2, imgTmp), axis=1)
128 | imgTmp = np.concatenate((imgTmp, pad2), axis=1)
129 |
130 |
131 | labelPosition.append((imgTmp,(tmpTop, tmpDown+1, tmpLeft, tmpRight+1)))
132 |
133 |
134 | return labelPosition
135 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abcAnonymous/EDSL/e8b4d0e597aa7ff335c070e80853f33af52837f4/src/utils/__init__.py
--------------------------------------------------------------------------------