├── 1_WebScraping_Dolphin
│   ├── CleanVersionExtractor.py
│   └── OriginalDataExtractor.py
├── 2_Data_cleaning
│   ├── Initial_Data_Cleaning.ipynb
│   ├── MWP_DataCleaning.py
│   └── cleaned_data_examples
│       ├── filtered_cleaned_dolphin_data.json
│       └── uncleaned_dolphin_data.csv
├── 3_T-RNN_&_baselines
│   ├── output
│   │   ├── pred_dolphin_rep_tfidf.txt
│   │   ├── pred_dolphin_tfidf.txt
│   │   ├── pred_math23k_tfidf.txt
│   │   ├── run_dolphin.txt
│   │   ├── run_dolphin_rep.txt
│   │   └── run_math23k.txt
│   └── src
│       ├── data_loader.py
│       ├── main.py
│       ├── model.py
│       ├── replicate.py
│       ├── retrieval_model.py
│       ├── trainer.py
│       └── utils.py
├── LICENSE.md
└── readme.md
/1_WebScraping_Dolphin/CleanVersionExtractor.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import json
3 | import sys
4 | import getopt
5 |
6 |
7 | def generate_clean_question_file(diff_file, raw_file, out_clean_question_file):
8 |     """Rebuild the cleaned question text by replaying recorded diff
9 |     opcodes over the raw scraped text."""
10 |     with open(diff_file, 'rb') as pkl_diff:
11 |         records = pickle.load(pkl_diff)
12 |     word_prob_groups = []
13 |     with open(raw_file, encoding='utf8') as fp:
14 |         word_probs_datas = json.load(fp)
15 |     for word_prob_data in word_probs_datas:
16 |         qid = word_prob_data["id"]
17 |         text = word_prob_data["original_text"]
18 |         # Apply the opcodes back-to-front so earlier offsets stay valid.
19 |         for tag, i1, i2, j1, j2, ins_text in reversed(records[qid]):
20 |             if tag == 'delete':
21 |                 text = text[0:i1] + text[i2:]
22 |             elif tag in ('insert', 'replace'):
23 |                 text = text[0:i1] + ins_text + text[i2:]
24 |         word_prob_data["text"] = text
25 |         word_prob_groups.append(word_prob_data)
26 |     with open(out_clean_question_file, 'w', encoding='utf8') as fp:
27 |         json.dump(word_prob_groups, fp, indent=2, ensure_ascii=False)
28 |
29 |
30 | if __name__ == "__main__":
31 |     # Defaults, overridable from the command line with -i / -d / -o.
32 |     original_file = "eval_original.json"
33 |     diff_file = "eval_diff.pkl"
34 |     out_file = "eval_cleaned.json"
35 |     try:
36 |         opts, args = getopt.getopt(sys.argv[1:], 'i:d:o:')
37 |     except getopt.GetoptError:
38 |         print('invalid input format')
39 |         sys.exit(2)
40 |     for opt, arg in opts:
41 |         if opt == "-i":
42 |             original_file = arg
43 |         elif opt == "-d":
44 |             diff_file = arg
45 |         elif opt == "-o":
46 |             out_file = arg
47 |     generate_clean_question_file(diff_file, original_file, out_file)
--------------------------------------------------------------------------------
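Note: generate_clean_question_file replays difflib opcodes (recorded by generate_diff_file in OriginalDataExtractor.py) over the raw scraped text, walking them back-to-front so that earlier offsets stay valid. A minimal, self-contained sketch of that round trip, using made-up strings rather than real Dolphin records:

import difflib

original = "What is teh sum of 1 and 1?"
cleaned = "what is the sum of 1 and 1 ?"

# Record the non-equal opcodes, attaching the replacement text for
# 'insert'/'replace' (the same 6-tuple layout the pickle files use).
diff_oprs = []
for tag, i1, i2, j1, j2 in difflib.SequenceMatcher(None, original, cleaned).get_opcodes():
    if tag != 'equal':
        ins_text = cleaned[j1:j2] if tag in ('insert', 'replace') else ""
        diff_oprs.append((tag, i1, i2, j1, j2, ins_text))

# Replay back-to-front, exactly as generate_clean_question_file does.
text = original
for tag, i1, i2, j1, j2, ins_text in reversed(diff_oprs):
    if tag == 'delete':
        text = text[0:i1] + text[i2:]
    else:  # 'insert' or 'replace'
        text = text[0:i1] + ins_text + text[i2:]

assert text == cleaned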
/1_WebScraping_Dolphin/OriginalDataExtractor.py:
--------------------------------------------------------------------------------
1 | import urllib.request
2 | import difflib
3 | import pickle
4 | import json
5 | import threading
6 | import sys
7 | import getopt
8 |
9 |
10 | def getHtml(url):
11 |     page = urllib.request.urlopen(url)
12 |     return page.read().decode('utf8', errors='ignore')
13 |
14 |
15 | def extract_raw_problem(str_html):
16 |     # NOTE: the three marker strings were lost when this file was exported
17 |     # (their raw HTML was stripped); the values below are assumptions about
18 |     # the Yahoo Answers meta-tag markup this scraper targeted.
19 |     meta_problem_title = '<meta name="title" content="'
20 |     meta_problem_text = '<meta name="description" content="'
21 |     problem_text_end = '">'
22 |     str_content = ""
23 |     i_problem_title_start = str_html.find(meta_problem_title)
24 |     if i_problem_title_start != -1:
25 |         i_problem_title_end = str_html.find(problem_text_end, i_problem_title_start)
26 |         str_problem_title = str_html[i_problem_title_start + len(meta_problem_title):i_problem_title_end]
27 |         if str_problem_title != "":
28 |             str_content += str_problem_title + "\n"
29 |     i_problem_text_start = str_html.find(meta_problem_text)
30 |     if i_problem_text_start != -1:
31 |         i_problem_text_end = str_html.find(problem_text_end, i_problem_text_start)
32 |         str_problem_text = str_html[i_problem_text_start + len(meta_problem_text):i_problem_text_end]
33 |         if str_problem_text != "":
34 |             str_content += str_problem_text
35 |     # Unescape the HTML entities that appear inside meta-tag content.
36 |     str_content = str_content.replace("&#39;", "'").replace("&#92;", "\\").replace("&lt;", "<").replace(
37 |         "&gt;", ">").replace("&quot;", "\"")
38 |     str_content = str_content.replace("&amp;", "&")
39 |     return str_content
40 |
41 |
42 | def is_page_not_found(str_content):
43 |     return str_content.find("This question does not exist or is under review.") != -1
44 |
45 |
46 | '''step 1: TODO multi-thread'''
47 |
48 |
49 | def generate_raw_question_file(input_file, out_raw_question_file):
50 |     word_prob_groups = []
51 |     fp = open(input_file, 'r', encoding='utf8')
52 |     word_probs_datas = json.load(fp)
53 |     for word_prob_data in word_probs_datas:
54 |         url = word_prob_data["id"].replace("yahoo.answers.", "")
55 |         print(url)
56 |         url = "https://answers.yahoo.com/question/index?qid=" + url
57 |         html = getHtml(url)
58 |         word_prob_data["original_text"] = extract_raw_problem(html)
59 |         word_prob_groups.append(word_prob_data)
60 |     fp.close()
61 |     fp = open(out_raw_question_file, 'w', encoding='utf8')
62 |     json.dump(word_prob_groups, fp, indent=2, ensure_ascii=False)
63 |     fp.close()
64 |
65 |
66 | def assign_original_text(qids, qid2index, word_probs_datas):
67 |     for qid in qids:
68 |         print(qid)
69 |         url = qid.replace("yahoo.answers.", "")
70 |         url = "https://answers.yahoo.com/question/index?qid=" + url
71 |         html = getHtml(url)
72 |         word_probs_datas[qid2index[qid]]["original_text"] = extract_raw_problem(html)
73 |
74 |
75 | def generate_raw_question_file_multi_thread(input_file, out_raw_question_file, ithread=10):
76 |     qid2index = {}
77 |     fp = open(input_file, 'r', encoding='utf8')
78 |     word_probs_datas = json.load(fp)
79 |     fp.close()
80 |     # Round-robin the question ids across ithread worker lists.
81 |     split_qids = [[] for _ in range(ithread)]
82 |     for index, word_prob_data in enumerate(word_probs_datas):
83 |         qid = word_prob_data["id"]
84 |         qid2index[qid] = index
85 |         split_qids[index % ithread].append(qid)
86 |     threads = []
87 |     for i in range(ithread):
88 |         try:
89 |             tid = threading.Thread(target=assign_original_text,
90 |                                    args=(split_qids[i], qid2index, word_probs_datas))
91 |             tid.start()
92 |             threads.append(tid)
93 |         except Exception:
94 |             print("Error: unable to start thread")
95 |     for tid in threads:
96 |         tid.join()
97 |     fp = open(out_raw_question_file, 'w', encoding='utf8')
98 |     json.dump(word_probs_datas, fp, indent=2, ensure_ascii=False)
99 |     fp.close()
100 |
101 |
102 | '''temp, to be deleted'''
103 |
104 |
105 | def generate_diff_file(json_file, out_diff_file):
106 |     id2text = {}
107 |     fp = open("dev_dataset_full.json", encoding='utf8')
108 |     word_probs_datas = json.load(fp)
109 |     for word_prob_data in word_probs_datas:
110 |         id2text[word_prob_data["id"]] = word_prob_data["text"]
111 |     fp.close()
112 |     fp = open(json_file, encoding='utf8')
113 |     word_probs_datas = json.load(fp)
114 |     records = {}
115 |     for word_prob_data in word_probs_datas:
116 |         qid = word_prob_data["id"]
117 |         text = id2text[qid]
118 |         original_text = word_prob_data["original_text"]
119 |         # Keep only the non-equal opcodes; 'insert'/'replace' also carry the
120 |         # replacement text so the edit can be replayed later.
121 |         matcher = difflib.SequenceMatcher(None, original_text, text)
122 |         new_record = []
123 |         for rec in matcher.get_opcodes():
124 |             if rec[0] == 'equal':
125 |                 continue
126 |             if rec[0] == 'insert' or rec[0] == 'replace':
127 |                 rec = rec + (text[rec[3]:rec[4]],)
128 |             else:
129 |                 rec = rec + ("",)
130 |             new_record.append(rec)
131 |         records[qid] = new_record
132 |     fp.close()
133 |     output = open(out_diff_file, 'wb')
134 |     pickle.dump(records, output)
135 |     output.close()
136 |
137 |
138 | def validate(gold_file, merge_file):
139 |     fp = open(gold_file, encoding='utf8')
140 |     word_probs_datas = json.load(fp)
141 |     id2text, id2ori = {}, {}
142 |     for word_prob_data in word_probs_datas:
143 |         qid = word_prob_data["id"]
144 |         if "original_text" in word_prob_data:
145 |             id2ori[qid] = word_prob_data["original_text"]
146 |         id2text[qid] = word_prob_data["text"]
147 |     fp.close()
148 |     fp1 = open(merge_file, encoding='utf8')
149 |     word_probs_data_merge = json.load(fp1)
150 |     for word_prob_data in word_probs_data_merge:
151 |         qid = word_prob_data["id"]
152 |         text = word_prob_data["text"]
153 |         if "original_text" in word_prob_data:
154 |             if qid in id2ori and word_prob_data["original_text"] != id2ori[qid]:
155 |                 print(qid)
156 |         if text != id2text[qid]:
157 |             print(qid)
158 |     fp1.close()
159 |
160 |
161 | if __name__ == "__main__":
162 |     '''
163 |     One-off cleanup kept for reference: strip the "type" field and re-dump.
164 |     fp = open("eval_dataset_full.json")
165 |     word_probs_datas = json.load(fp)
166 |     for word_prob_data in word_probs_datas:
167 |         if "type" in word_prob_data:
168 |             del word_prob_data["type"]
169 |     fp = open("eval_urls1.json", 'w')
170 |     json.dump(word_probs_datas, fp, indent=2, ensure_ascii=False)
171 |     fp.close()
172 |     '''
173 |
174 |     # Defaults, overridable from the command line with -i / -t / -o.
175 |     url_file = "eval_urls.json"
176 |     out_file = "eval_original.json"
177 |     i_thread = 10
178 |     try:
179 |         opts, args = getopt.getopt(sys.argv[1:], 'i:t:o:')
180 |     except getopt.GetoptError:
181 |         print('invalid input format')
182 |         sys.exit(2)
183 |     for opt, arg in opts:
184 |         if opt == "-i":
185 |             url_file = arg
186 |         elif opt == "-t":
187 |             i_thread = int(arg)
188 |         elif opt == "-o":
189 |             out_file = arg
190 |     generate_raw_question_file_multi_thread(url_file, out_file, i_thread)
191 |     # config = "dev"
192 |     # generate_diff_file(config + "_original.json", config + "_diff.pkl")
--------------------------------------------------------------------------------
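Note: the marker strings in extract_raw_problem are reconstructions (see the comment in the code), so here is a toy check of the slicing logic against fabricated HTML; the standard-library html.unescape would be a sturdier replacement for the manual entity chain:

import html

meta_problem_text = '<meta name="description" content="'  # assumed marker
problem_text_end = '">'                                    # assumed marker

page = '<head><meta name="description" content="What is 1 &amp; 1?"></head>'
start = page.find(meta_problem_text) + len(meta_problem_text)
end = page.find(problem_text_end, start)

print(html.unescape(page[start:end]))  # -> What is 1 & 1?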
/2_Data_cleaning/Initial_Data_Cleaning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "nbformat": 4,
3 | "nbformat_minor": 0,
4 | "metadata": {
5 | "colab": {
6 | "name": "Initial Data Cleaning.ipynb",
7 | "provenance": [],
8 | "collapsed_sections": [],
9 | "toc_visible": true,
10 | "machine_shape": "hm"
11 | },
12 | "kernelspec": {
13 | "name": "python3",
14 | "display_name": "Python 3"
15 | }
16 | },
17 | "cells": [
18 | {
19 | "cell_type": "markdown",
20 | "metadata": {
21 | "id": "ihfXrF0zE-5H",
22 | "colab_type": "text"
23 | },
24 | "source": [
25 | "Import Libraries\n",
26 | "--"
27 | ]
28 | },
29 | {
30 | "cell_type": "code",
31 | "metadata": {
32 | "id": "ivCIhm4zyRyI",
33 | "colab_type": "code",
34 | "outputId": "69d1f301-789c-4c09-8059-bb9a72a53c1a",
35 | "colab": {
36 | "base_uri": "https://localhost:8080/",
37 | "height": 35
38 | }
39 | },
40 | "source": [
41 | "import json, csv\n",
42 | "import numpy as np\n",
43 | "import pandas as pd\n",
44 | "import re\n",
45 | "import warnings\n",
46 | "\n",
47 | "warnings.filterwarnings('ignore')\n",
48 | "\n",
49 | "print(\"Libraries Imported\")"
50 | ],
51 | "execution_count": 1,
52 | "outputs": [
53 | {
54 | "output_type": "stream",
55 | "text": [
56 | "Libraries Imported\n"
57 | ],
58 | "name": "stdout"
59 | }
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "metadata": {
65 | "id": "mvxifhcnAx6h",
66 | "colab_type": "code",
67 | "colab": {}
68 | },
69 | "source": [
70 | "def getText(doc):\n",
71 | " doc = str(doc)\n",
72 | " doc = doc.lower().strip()\n",
73 | " doc = re.sub('\\n', ' ', doc)\n",
74 | " doc = re.sub(r'\\s+', ' ', doc)\n",
75 | " m = re.search(r'',doc)\n",
76 | " m1 = re.search(r'',doc)\n",
77 | "\n",
78 | " if m != None and m1!= None:\n",
79 | " text = str(m.group(1)) + ' ' + str(m1.group(1))\n",
80 | " else:\n",
81 | " text = \"No match\"\n",
82 | "\n",
83 | " return text"
84 | ],
85 | "execution_count": 0,
86 | "outputs": []
87 | },
88 | {
89 | "cell_type": "markdown",
90 | "metadata": {
91 | "id": "8EXob7wzFNcz",
92 | "colab_type": "text"
93 | },
94 | "source": [
95 | "Preparing Datasets\n",
96 | "--"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "metadata": {
102 | "id": "YgcH2x0NY6Rz",
103 | "colab_type": "code",
104 | "colab": {}
105 | },
106 | "source": [
107 | "data = pd.read_json('eval_cleaned.json')\n",
108 | "\n",
109 | "for i, row in data.iterrows():\n",
110 | " if re.match(r\"^\\\"/>\n",
163 | "Int64Index: 6625 entries, 19 to 14722\n",
164 | "Data columns (total 4 columns):\n",
165 | "text 6625 non-null object\n",
166 | "ans 6625 non-null object\n",
167 | "equations 6625 non-null object\n",
168 | "unknowns 6625 non-null object\n",
169 | "dtypes: object(4)\n",
170 | "memory usage: 258.8+ KB\n"
171 | ],
172 | "name": "stdout"
173 | }
174 | ]
175 | },
176 | {
177 | "cell_type": "markdown",
178 | "metadata": {
179 | "id": "l_j4sf8UFRAO",
180 | "colab_type": "text"
181 | },
182 | "source": [
183 | "Data Cleaning\n",
184 | "--"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "metadata": {
190 | "id": "PdhRZ_R1FTCI",
191 | "colab_type": "code",
192 | "outputId": "136792b6-a0ff-4cc7-b1d7-5ae7cf3f4d62",
193 | "colab": {
194 | "base_uri": "https://localhost:8080/",
195 | "height": 69
196 | }
197 | },
198 | "source": [
199 | "import nltk\n",
200 | "nltk.download('stopwords')\n",
201 | "\n",
202 | "from nltk.corpus import stopwords\n",
203 | "\n",
204 | "import spacy\n",
205 | "\n",
206 | "# import string\n",
207 | "\n",
208 | "import re\n",
209 | "\n",
210 | "nlp = spacy.load(\"en\")\n",
211 | "\n",
212 | "nltk_stopwords = set(stopwords.words('english'))\n",
213 | "\n",
214 | "spacy_stopwords = nlp.Defaults.stop_words\n",
215 | "\n",
216 | "stopset = nltk_stopwords.union(spacy_stopwords)\n",
217 | "\n",
218 | "stopset.difference_update([\"a\",\"more\",\"less\",\"than\",\"one\",\"two\",\"three\",\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\",\"ten\",\"eleven\",\"twelve\",\"fifteen\",\"twenty\",\"forty\",\"sixty\",\"fifty\",\"hundred\",\"once\",\"first\",\"second\",\"third\"])\n",
219 | "\n",
220 | "punctuation = \"!\\\"#$&',;?@\\_`{|}~\"\n",
221 | "\n",
222 | "def getText(doc):\n",
223 | " doc = str(doc)\n",
224 | " doc = doc.lower().strip()\n",
225 | " doc = re.sub('\\n', ' ', doc)\n",
226 | " doc = re.sub(r'\\s+', ' ', doc)\n",
227 | " m = re.search(r'',doc)\n",
228 | " m1 = re.search(r'',doc)\n",
229 | "\n",
230 | " if m != None and m1!= None:\n",
231 | " text = str(m.group(1)) + ' ' + str(m1.group(1))\n",
232 | " else:\n",
233 | " text = \"No match\"\n",
234 | "\n",
235 | " return text\n",
236 | "\n",
237 | "\n",
238 | "def cleanData(doc):\n",
239 | " doc = str(doc)\n",
240 | " doc = doc.lower().strip()\n",
241 | " doc = re.sub('\\n', ' ', doc)\n",
242 | " doc = re.sub(r'\\s+', ' ', doc)\n",
243 | " pattern = '\"/>\n",
330 | "Int64Index: 6487 entries, 19 to 14722\n",
331 | "Data columns (total 5 columns):\n",
332 | "text 6487 non-null object\n",
333 | "ans 6487 non-null object\n",
334 | "equations 6487 non-null object\n",
335 | "unknowns 6487 non-null object\n",
336 | "cleaned_text 6487 non-null object\n",
337 | "dtypes: object(5)\n",
338 | "memory usage: 304.1+ KB\n"
339 | ],
340 | "name": "stdout"
341 | }
342 | ]
343 | },
344 | {
345 | "cell_type": "markdown",
346 | "metadata": {
347 | "id": "OpkqG1aVO2Z5",
348 | "colab_type": "text"
349 | },
350 | "source": [
351 | "Data Modelling (Archieve)\n",
352 | "--"
353 | ]
354 | },
355 | {
356 | "cell_type": "code",
357 | "metadata": {
358 | "id": "WvbmU8_8XLBX",
359 | "colab_type": "code",
360 | "outputId": "c7fd0123-7402-4bdd-e6ff-7d2fcdb47475",
361 | "colab": {
362 | "base_uri": "https://localhost:8080/",
363 | "height": 399
364 | }
365 | },
366 | "source": [
367 | "import pandas as pd\n",
368 | "\n",
369 | "data = pd.read_csv('new_cleaned_data.csv')\n",
370 | "\n",
371 | "from sklearn.model_selection import train_test_split\n",
372 | "\n",
373 | "trainData, testData = train_test_split(data, test_size = 0.2)\n",
374 | "\n",
375 | "trainData.rename(columns={'cleaned text': 'cleaned_text'}, inplace=True)\n",
376 | "\n",
377 | "testData.rename(columns={'cleaned text': 'cleaned_text'}, inplace=True)\n",
378 | "\n",
379 | "trainData = trainData.reset_index(drop=True)\n",
380 | "\n",
381 | "testData = testData.reset_index(drop=True)\n",
382 | "\n",
383 | "print(trainData.info())\n",
384 | "\n",
385 | "print(testData.info())"
386 | ],
387 | "execution_count": 0,
388 | "outputs": [
389 | {
390 | "output_type": "stream",
391 | "text": [
392 | "\n",
393 | "Int64Index: 1896 entries, 1302 to 771\n",
394 | "Data columns (total 5 columns):\n",
395 | "ans 1896 non-null object\n",
396 | "cleaned_text 1896 non-null object\n",
397 | "equations 1896 non-null object\n",
398 | "text 1896 non-null object\n",
399 | "unknowns 1895 non-null object\n",
400 | "dtypes: object(5)\n",
401 | "memory usage: 88.9+ KB\n",
402 | "None\n",
403 | "\n",
404 | "Int64Index: 474 entries, 1928 to 1444\n",
405 | "Data columns (total 5 columns):\n",
406 | "ans 474 non-null object\n",
407 | "cleaned_text 474 non-null object\n",
408 | "equations 474 non-null object\n",
409 | "text 474 non-null object\n",
410 | "unknowns 471 non-null object\n",
411 | "dtypes: object(5)\n",
412 | "memory usage: 22.2+ KB\n",
413 | "None\n"
414 | ],
415 | "name": "stdout"
416 | }
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "metadata": {
422 | "id": "52SBxzo3HKMV",
423 | "colab_type": "code",
424 | "colab": {}
425 | },
426 | "source": [
427 | "from sklearn.feature_extraction.text import TfidfVectorizer\n",
428 | "\n",
429 | "# print(data.info())\n",
430 | "\n",
431 | "tfidf = TfidfVectorizer(sublinear_tf=True, min_df=1, norm='l2', encoding='latin-1', ngram_range=(1, 2), stop_words= None)\n",
432 | "\n",
433 | "tfidf.fit(trainData['cleaned_text'])\n",
434 | "\n",
435 | "features = tfidf.transform(trainData['cleaned_text']).toarray()\n",
436 | "\n",
437 | "# test = \"Three times the first of three consecutive odd integers is 3 more than twice the third . What is the third integer ?\"\n",
438 | "\n",
439 | "# testClean = cleanData(test)\n",
440 | "\n",
441 | "# print(testClean)\n",
442 | "\n",
443 | "# test_feature = tfidf.transform([testClean]).toarray()\n",
444 | "\n",
445 | "test_features = tfidf.transform(testData['cleaned_text']).toarray()\n",
446 | "\n",
447 | "# print(test_features)\n",
448 | "\n",
449 | "# print(features)"
450 | ],
451 | "execution_count": 0,
452 | "outputs": []
453 | },
454 | {
455 | "cell_type": "code",
456 | "metadata": {
457 | "id": "LzubaLyqhBJl",
458 | "colab_type": "code",
459 | "outputId": "625bbd87-2908-4c73-e70d-5d999793212c",
460 | "colab": {
461 | "base_uri": "https://localhost:8080/",
462 | "height": 433
463 | }
464 | },
465 | "source": [
466 | "testData['matchedQuestion'] = ''\n",
467 | "testData['matchedEq'] = ''\n",
468 | "\n",
469 | "print(trainData.info())\n",
470 | "\n",
471 | "print(testData.info())"
472 | ],
473 | "execution_count": 0,
474 | "outputs": [
475 | {
476 | "output_type": "stream",
477 | "text": [
478 | "\n",
479 | "RangeIndex: 1896 entries, 0 to 1895\n",
480 | "Data columns (total 5 columns):\n",
481 | "ans 1896 non-null object\n",
482 | "cleaned_text 1896 non-null object\n",
483 | "equations 1896 non-null object\n",
484 | "text 1896 non-null object\n",
485 | "unknowns 1895 non-null object\n",
486 | "dtypes: object(5)\n",
487 | "memory usage: 74.2+ KB\n",
488 | "None\n",
489 | "\n",
490 | "RangeIndex: 474 entries, 0 to 473\n",
491 | "Data columns (total 7 columns):\n",
492 | "ans 474 non-null object\n",
493 | "cleaned_text 474 non-null object\n",
494 | "equations 474 non-null object\n",
495 | "text 474 non-null object\n",
496 | "unknowns 471 non-null object\n",
497 | "matchedQuestion 474 non-null object\n",
498 | "matchedEq 474 non-null object\n",
499 | "dtypes: object(7)\n",
500 | "memory usage: 26.0+ KB\n",
501 | "None\n"
502 | ],
503 | "name": "stdout"
504 | }
505 | ]
506 | },
507 | {
508 | "cell_type": "code",
509 | "metadata": {
510 | "id": "T8xvBdI6gmKo",
511 | "colab_type": "code",
512 | "colab": {}
513 | },
514 | "source": [
515 | "from scipy import spatial\n",
516 | "\n",
517 | "score = 0\n",
518 | "index = 0\n",
519 | "\n",
520 | "\n",
521 | "for i, row1 in testData.iterrows():\n",
522 | " score = 0\n",
523 | " for j, row2 in trainData.iterrows():\n",
524 | " similarity = 1 - spatial.distance.cosine(test_features[i], features[j])\n",
525 | " if similarity > score:\n",
526 | " score = similarity\n",
527 | " testData.at[i,'matchedQuestion'] = row2['cleaned_text']\n",
528 | " testData.at[i, 'matchedEq'] = row2['equations']"
529 | ],
530 | "execution_count": 0,
531 | "outputs": []
532 | },
533 | {
534 | "cell_type": "code",
535 | "metadata": {
536 | "id": "BS9iUcSVNYpD",
537 | "colab_type": "code",
538 | "colab": {}
539 | },
540 | "source": [
541 | "testData.info()\n",
542 | "\n",
543 | "testData.to_csv(\"cosineSimilarity.csv\", index = False)"
544 | ],
545 | "execution_count": 0,
546 | "outputs": []
547 | },
548 | {
549 | "cell_type": "code",
550 | "metadata": {
551 | "id": "VIrqPv3czZEw",
552 | "colab_type": "code",
553 | "colab": {}
554 | },
555 | "source": [
556 | "from scipy import spatial\n",
557 | "\n",
558 | "score = 0\n",
559 | "index = 0\n",
560 | "\n",
561 | "def similarity(sen1, sen2):\n",
562 | " score = np.dot(sen1, sen2)/(np.linalg.norm(sen1)*np.linalg.norm(sen2))\n",
563 | " return score\n",
564 | "\n",
565 | "for i, row1 in testData.iterrows():\n",
566 | " score = 0\n",
567 | " for j, row2 in trainData.iterrows():\n",
568 | " similarity = 1 - similarity(test_features[i], features[j])\n",
569 | " if similarity > score:\n",
570 | " score = similarity\n",
571 | " testData.at[i,'matchedQuestion'] = row2['cleaned_text']\n",
572 | " testData.at[i, 'matchedEq'] = row2['equations']"
573 | ],
574 | "execution_count": 0,
575 | "outputs": []
576 | },
577 | {
578 | "cell_type": "code",
579 | "metadata": {
580 | "id": "18jlqxG3DJa3",
581 | "colab_type": "code",
582 | "outputId": "106fb3d7-d495-4b08-eeea-90864e0bd389",
583 | "colab": {
584 | "base_uri": "https://localhost:8080/",
585 | "height": 225
586 | }
587 | },
588 | "source": [
589 | "testData.info()\n",
590 | "\n",
591 | "testData.to_csv(\"generalSimilarity.csv\", index = False)"
592 | ],
593 | "execution_count": 0,
594 | "outputs": [
595 | {
596 | "output_type": "stream",
597 | "text": [
598 | "\n",
599 | "RangeIndex: 474 entries, 0 to 473\n",
600 | "Data columns (total 7 columns):\n",
601 | "ans 474 non-null object\n",
602 | "cleaned_text 474 non-null object\n",
603 | "equations 474 non-null object\n",
604 | "text 474 non-null object\n",
605 | "unknowns 471 non-null object\n",
606 | "matchedQuestion 474 non-null object\n",
607 | "matchedEq 474 non-null object\n",
608 | "dtypes: object(7)\n",
609 | "memory usage: 26.0+ KB\n"
610 | ],
611 | "name": "stdout"
612 | }
613 | ]
614 | },
615 | {
616 | "cell_type": "code",
617 | "metadata": {
618 | "id": "smFQYt8gHHxR",
619 | "colab_type": "code",
620 | "colab": {}
621 | },
622 | "source": [
623 | "from math import *\n",
624 | "\n",
625 | "score = 0\n",
626 | "index = 0\n",
627 | "\n",
628 | "def jaccard_similarity(sen1, sen2):\n",
629 | " intersection = len(set.intersection(*[set(sen1), set(sen2)]))\n",
630 | "\n",
631 | " union = len(set.union(*[set(sen1), set(set2)]))\n",
632 | "\n",
633 | " score = intersection/float(union)\n",
634 | "\n",
635 | " return score\n",
636 | "\n",
637 | "for i, row1 in testData.iterrows():\n",
638 | " score = 0\n",
639 | " for j, row2 in trainData.iterrows():\n",
640 | " similarity = jaccard_similarity(test_features[i], features[j])\n",
641 | " if similarity > score:\n",
642 | " score = similarity\n",
643 | " testData.at[i,'matchedQuestion'] = row2['cleaned_text']\n",
644 | " testData.at[i, 'matchedEq'] = row2['equations']"
645 | ],
646 | "execution_count": 0,
647 | "outputs": []
648 | },
649 | {
650 | "cell_type": "code",
651 | "metadata": {
652 | "id": "hwLHfc5gNbYF",
653 | "colab_type": "code",
654 | "colab": {}
655 | },
656 | "source": [
657 | "testData.info()\n",
658 | "\n",
659 | "testData.to_csv(\"jaccardSimilarity.csv\", index = False)"
660 | ],
661 | "execution_count": 0,
662 | "outputs": []
663 | }
664 | ]
665 | }
--------------------------------------------------------------------------------
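Note: the archive cells above rank training questions for each test question with per-row Python loops. Because the TfidfVectorizer rows are L2-normalized (norm='l2'), cosine similarity reduces to a matrix product, so the same retrieval can be vectorized; a sketch using the notebook's own file and column names:

import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel
from sklearn.model_selection import train_test_split

data = pd.read_csv('new_cleaned_data.csv').rename(columns={'cleaned text': 'cleaned_text'})
trainData, testData = train_test_split(data, test_size=0.2, random_state=0)

tfidf = TfidfVectorizer(sublinear_tf=True, min_df=1, norm='l2', ngram_range=(1, 2))
train_vecs = tfidf.fit_transform(trainData['cleaned_text'])
test_vecs = tfidf.transform(testData['cleaned_text'])

# (n_test, n_train) cosine matrix; argmax picks the nearest training question.
sims = linear_kernel(test_vecs, train_vecs)
best = sims.argmax(axis=1)

testData = testData.assign(
    matchedQuestion=trainData['cleaned_text'].to_numpy()[best],
    matchedEq=trainData['equations'].to_numpy()[best],
)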
/2_Data_cleaning/MWP_DataCleaning.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import re
4 | import json
5 | import nltk
6 | from word2number import w2n
7 | from nltk.stem.snowball import SnowballStemmer
8 |
9 | df = pd.read_csv('./Cleaned Data/trainData_univariable.csv')
10 |
11 | df = df[df['text'].notna()]
12 | numMap = {"twice": 2, "double": 2, "thrice": 3, "half": "1/2", "tenth": "1/10", "quarter": "1/4", "fifth": "1/5"}
13 | fraction = {"third": "/3", "half": "/2", "fourth": "/4", "sixth": "/6", "fifth": "/5", "seventh": "/7", "eighth": "/8",
14 |             "ninth": "/9", "tenth": "/10", "eleventh": "/11", "twelfth": "/12", "thirteenth": "/13",
15 |             "fourteenth": "/14", "fifteenth": "/15", "sixteenth": "/16", "seventeenth": "/17", "eighteenth": "/18",
16 |             "nineteenth": "/19", "twentieth": "/20", "%": "/100"}
17 |
18 | stemmer = SnowballStemmer(language='english')
19 |
20 |
21 | # Convert fraction strings such as "1 3/4" or "-5/2" to floating point numbers
22 | def convert_to_float(frac_str):
23 |     try:
24 |         return float(frac_str)
25 |     except ValueError:
26 |         num, denom = frac_str.split('/')
27 |         try:
28 |             leading, num = num.split(' ')
29 |             whole = float(leading)
30 |         except ValueError:
31 |             whole = 0
32 |         frac = float(num) / float(denom)
33 |         return whole - frac if whole < 0 else whole + frac
34 |
35 |
36 | # Convert an infix token list to postfix notation (shunting-yard)
37 | def postfix_equation(equ_list):
38 |     stack = []
39 |     post_equ = []
40 |     op_list = ['+', '-', '*', '/', '^']
41 |     priori = {'^': 3, '*': 2, '/': 2, '+': 1, '-': 1}
42 |     for elem in equ_list:
43 |         if elem == '(':
44 |             stack.append('(')
45 |         elif elem == ')':
46 |             while 1:
47 |                 op = stack.pop()
48 |                 if op == '(':
49 |                     break
50 |                 else:
51 |                     post_equ.append(op)
52 |         elif elem in op_list:
53 |             while 1:
54 |                 if not stack:
55 |                     break
56 |                 elif stack[-1] == '(':
57 |                     break
58 |                 elif priori[elem] > priori[stack[-1]]:
59 |                     break
60 |                 else:
61 |                     op = stack.pop()
62 |                     post_equ.append(op)
63 |             stack.append(elem)
64 |         else:
65 |             post_equ.append(elem)
66 |     while stack:
67 |         post_equ.append(stack.pop())
68 |     return post_equ
69 |
70 |
71 | # Identification and filtering of univariable equations
72 | char_set = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '(', ')', '+', '-', '/', '*'}
73 | cleaned_rows = []
74 | for id_, row in df.iterrows():
75 |     l, r = row['equations'].split("=", 1)
76 |     lSet, rSet = set(l.replace(" ", "")), set(r.replace(" ", ""))
77 |     flagl = (len(l.strip()) == 1 and not l.strip().isdigit() and len(rSet - char_set) == 0)
78 |     flagr = (len(r.strip()) == 1 and not r.strip().isdigit() and len(lSet - char_set) == 0)
79 |     if flagl or flagr:
80 |         if flagr:
81 |             # Normalize to the form "<variable> = <expression>"
82 |             row['equations'] = r + '=' + l
83 |         cleaned_rows.append(row)
84 | df_cleaned = pd.DataFrame(cleaned_rows)
85 |
86 | numLists = {}
87 | numLists_idx = {}
88 | eqLists_idx = {}
89 | eqLists = {}
90 | texts = {}
91 | equations_List = {}
92 | final_ans = {}
93 | numListMAP = {}
94 | final_replaced_text = {}
95 | final_replaced_eq = {}
96 | final_replaced_eq_post = {}
97 | final_number_list = {}
98 | final_num_postn_list = {}
99 | numtemp_order = {}
100 |
101 | for id_, row in df_cleaned.iterrows():
102 |     # Convert written-out numbers, fractions and ordinals to digits
103 |     if not bool(re.search(r'[\d]+', row['text'])):
104 |         continue
105 |     sb = ""
106 |     val = 0
107 |     for tokens in nltk.word_tokenize(row['text']):
108 |         try:
109 |             # Accumulate multi-word numbers ("twenty five" -> 25)
110 |             val += w2n.word_to_num(tokens)
111 |         except ValueError:
112 |             if val > 0:
113 |                 sb = sb + " " + str(val)
114 |                 if tokens in fraction:
115 |                     sb = sb + fraction[tokens]
116 |                 elif stemmer.stem(tokens) in fraction:
117 |                     sb = sb + fraction[stemmer.stem(tokens)]
118 |                 else:
119 |                     sb = sb + " " + tokens
120 |                 val = 0
121 |             else:
122 |                 if tokens in numMap.keys():
123 |                     sb = sb + " " + str(numMap[tokens])
124 |                 else:
125 |                     sb = sb + " " + tokens
126 |
127 |     sb = re.sub(' +', ' ', sb)
128 |     sb = sb.strip()
129 |
130 |     # Re-structure equation anomalies for normalization: make implicit
131 |     # multiplication around parentheses explicit and read "a-b" as "a+-b"
132 |     row['equations'] = row['equations'].replace(' ', '')
133 |     eqs = row['equations'][0]
134 |     for i in range(1, len(row['equations']) - 1):
135 |         if row['equations'][i] == '(':
136 |             if row['equations'][i - 1] == ')' or row['equations'][i - 1].isdigit():
137 |                 eqs += '*' + row['equations'][i]
138 |             else:
139 |                 eqs += row['equations'][i]
140 |         elif row['equations'][i + 1].isdigit() and row['equations'][i] == ')':
141 |             eqs += row['equations'][i] + "*"
142 |         elif row['equations'][i] == '-' and (row['equations'][i - 1].isdigit() or row['equations'][i - 1] == ')') and \
143 |                 row['equations'][i + 1].isdigit():
144 |             eqs += "+" + row['equations'][i]
145 |         else:
146 |             eqs += row['equations'][i]
147 |     eqs += row['equations'][-1]
148 |
149 |     # Extract the numbers (integers, decimals, fractions) from the text in order
150 |     numList = re.findall(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', sb)
151 |     if len(numList) == 0:
152 |         continue
153 |     # Extract the numbers from the equation in order
154 |     eqNumList = re.findall(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', eqs)
155 |     # Get the positions of the numbers in the text
156 |     id_pattern = [m.span() for m in re.finditer(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', sb)]
157 |     # Get the positions of the numbers in the equation
158 |     eqid_pattern = [m.span() for m in re.finditer(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', eqs)]
159 |     # Skip problems whose equation mentions numbers that never occur in the text
160 |     if len(set(numList)) < len(set(eqNumList)) or len(set(eqNumList) - set(numList)) != 0:
161 |         continue
162 |
163 |     numLists[id_] = numList
164 |     numLists_idx[id_] = id_pattern
165 |     eqLists_idx[id_] = eqid_pattern
166 |     texts[id_] = sb
167 |     eqLists[id_] = eqNumList
168 |     equations_List[id_] = eqs
169 |     final_ans[id_] = str(row['ans'])
170 |
171 | # Add spaces around numbers for ease of identification
172 | for id, text in texts.items():
173 |     modified_text = ''
174 |     prev = 0
175 |     for (start, end) in numLists_idx[id]:
176 |         modified_text += text[prev:start] + ' ' + text[start:end] + ' '
177 |         prev = end
178 |     modified_text += text[prev:]
179 |
180 |     # Re-evaluate the number positions in the modified text
181 |     numList = re.findall(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', modified_text)
182 |     if len(numList) == 0:
183 |         continue
184 |     id_pattern = [m.span() for m in re.finditer(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', modified_text)]
185 |
186 |     numLists[id] = numList
187 |     numLists_idx[id] = id_pattern
188 |     texts[id] = modified_text
189 |
190 | # Replace the question text with template variables (temp_0, temp_1, ...)
191 | for id, vals in numLists.items():
192 |     numListMap = {}
193 |     num_list = []
194 |     num_postion = []
195 |     k = 0
196 |     for i in range(len(vals)):
197 |         tmp = convert_to_float(vals[i]) if vals[i].find("/") != -1 else vals[i]
198 |         if tmp not in numListMap.keys():
199 |             num_list.append(str(tmp))
200 |             num_postion.append(numLists_idx[id][i][0])
201 |             numListMap[tmp] = 'temp_' + str(k)
202 |             k += 1
203 |     final_number_list[id] = num_list
204 |     final_num_postn_list[id] = num_postion
205 |
206 |     numListMAP[id] = numListMap
207 |     replaced_text = ''
208 |     prev_index = 0
209 |     num_temps = []
210 |     for i in range(len(vals)):
211 |         tmp = convert_to_float(vals[i]) if vals[i].find("/") != -1 else vals[i]
212 |         start = numLists_idx[id][i][0]
213 |         end = numLists_idx[id][i][1]
214 |         replaced_text = replaced_text + texts[id][prev_index:start] + numListMap[tmp]
215 |         num_temps.append(numListMap[tmp])
216 |         prev_index = end
217 |     replaced_text += texts[id][prev_index:]
218 |     final_replaced_text[id] = replaced_text
219 |     numtemp_order[id] = num_temps
220 |
221 | # Replace the equations with template variables
222 | for id, vals in eqLists.items():
223 |     equation = equations_List[id]
224 |     replaced_eq = ''
225 |     prev_index = 0
226 |     for i in range(len(vals)):
227 |         tmp = convert_to_float(vals[i]) if vals[i].find("/") != -1 else vals[i]
228 |         start = eqLists_idx[id][i][0]
229 |         end = eqLists_idx[id][i][1]
230 |         replaced_eq = replaced_eq + " " + equation[prev_index:start] + " " + numListMAP[id][tmp]
231 |         prev_index = end
232 |     replaced_eq += " " + equation[prev_index:]
233 |
234 |     # Split any character runs that are still glued together into single tokens
235 |     list_eqs = replaced_eq.split(" ")
236 |     equation_normalized = []
237 |     for i in range(len(list_eqs)):
238 |         if list_eqs[i] == '':
239 |             continue
240 |         elif list_eqs[i].find('temp') == -1 and len(list_eqs[i]) > 1:
241 |             equation_normalized.extend(list(list_eqs[i]))
242 |         else:
243 |             equation_normalized.append(list_eqs[i])
244 |
245 |     final_replaced_eq[id] = equation_normalized
246 |     final_replaced_eq_post[id] = postfix_equation(equation_normalized)
247 |
248 | # Calculate the word position of each number in the text
249 | for id, text in texts.items():
250 |     words = [i.replace(' ', '') for i in text.split()]
251 |     numpos = []
252 |     min_ind = {i.replace(' ', ''): float('inf') for i in final_number_list[id]}
253 |     for i in range(len(words)):
254 |         if words[i] in [j.replace(' ', '') for j in numLists[id]]:
255 |             tmp = convert_to_float(words[i]) if words[i].find("/") != -1 else words[i]
256 |             min_ind[str(tmp)] = min(min_ind[str(tmp)], i)
257 |     for val in final_number_list[id]:
258 |         numpos.append(min_ind[str(val).replace(' ', '')])
259 |     final_num_postn_list[id] = numpos
260 |
261 | # Create the json dump from the collected dictionaries
262 | data = {}
263 | for id in final_replaced_text.keys():
264 |     data_template = {}
265 |     data_template["template_text"] = final_replaced_text[id]
266 |     data_template["expression"] = equations_List[id]
267 |     data_template["mid_template"] = final_replaced_eq[id]
268 |     data_template["num_list"] = final_number_list[id]
269 |     data_template["index"] = str(id)
270 |     data_template["numtemp_order"] = numtemp_order[id]
271 |     # Drop the leading "<var> =" tokens so only the evaluable postfix remains
272 |     data_template["post_template"] = final_replaced_eq_post[id][2:]
273 |     if final_ans[id].find("|") != -1:
274 |         data_template["ans"] = final_ans[id].split("|")[1].strip()
275 |     elif final_ans[id].find("/") != -1:
276 |         data_template["ans"] = str(convert_to_float(final_ans[id]))
277 |     else:
278 |         data_template["ans"] = final_ans[id]
279 |     data_template["text"] = texts[id]
280 |     data_template["num_position"] = final_num_postn_list[id]
281 |     data[id] = data_template
282 |
283 | with open('final_dolphin_data.json', 'w') as f:
284 |     json.dump(data, f)
--------------------------------------------------------------------------------
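Note: postfix_equation is a standard shunting-yard pass over the token list, and the post_template field it feeds is evaluable once each temp_i is bound to its num_list entry. A small sketch of that pairing, with postfix_equation from the script above in scope (eval_postfix is a hypothetical helper, not part of the repo):

def eval_postfix(post_equ, num_list):
    """Evaluate a postfix template such as ['temp_0', 'temp_1', '+']."""
    ops = {'+': lambda a, b: a + b, '-': lambda a, b: a - b,
           '*': lambda a, b: a * b, '/': lambda a, b: a / b,
           '^': lambda a, b: a ** b}
    stack = []
    for tok in post_equ:
        if tok in ops:
            b, a = stack.pop(), stack.pop()
            stack.append(ops[tok](a, b))
        else:  # temp_<k> looks up the k-th extracted number
            stack.append(float(num_list[int(tok.split('_')[1])]))
    return stack[0]

template = ['x', '=', 'temp_0', '+', 'temp_1', '*', 'temp_2']
post = postfix_equation(template)  # ['x', '=', 'temp_0', 'temp_1', 'temp_2', '*', '+']
print(eval_postfix(post[2:], ['1', '2', '3']))  # -> 7.0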
/2_Data_cleaning/cleaned_data_examples/filtered_cleaned_dolphin_data.json:
--------------------------------------------------------------------------------
1 | {
2 |   "1": {
3 |     "template_text": "what is the sum of temp_0 and temp_0 ?",
4 |     "expression": "x=1+1",
5 |     "mid_template": [
6 |       "x",
7 |       "=",
8 |       "temp_0",
9 |       "+",
10 |       "temp_0"
11 |     ],
12 |     "num_list": [
13 |       "1"
14 |     ],
15 |     "index": "1",
16 |     "numtemp_order": [
17 |       "temp_0",
18 |       "temp_0"
19 |     ],
20 |     "post_template": [
21 |       "temp_0",
22 |       "temp_0",
23 |       "+"
24 |     ],
25 |     "ans": "2",
26 |     "text": "what is the sum of 1 and 1 ?",
27 |     "num_position": [
28 |       5
29 |     ]
30 |   },
31 |   "65": {
32 |     "template_text": "how many temp_0 digit numbers can you make using the digits temp_1 , temp_2 , temp_0 , and temp_3 if the hundreds digit is prime and repetition of a digit is not permitted ?",
33 |     "expression": "n=2*3*2",
34 |     "mid_template": [
35 |       "n",
36 |       "=",
37 |       "temp_2",
38 |       "*",
39 |       "temp_0",
40 |       "*",
41 |       "temp_2"
42 |     ],
43 |     "num_list": [
44 |       "3",
45 |       "1",
46 |       "2",
47 |       "4"
48 |     ],
49 |     "index": "65",
50 |     "numtemp_order": [
51 |       "temp_0",
52 |       "temp_1",
53 |       "temp_2",
54 |       "temp_0",
55 |       "temp_3"
56 |     ],
57 |     "post_template": [
58 |       "temp_2",
59 |       "temp_0",
60 |       "*",
61 |       "temp_2",
62 |       "*"
63 |     ],
64 |     "ans": "12",
65 |     "text": "how many 3 digit numbers can you make using the digits 1 , 2 , 3 , and 4 if the hundreds digit is prime and repetition of a digit is not permitted ?",
66 |     "num_position": [
67 |       2,
68 |       11,
69 |       13,
70 |       18
71 |     ]
72 |   }
73 | }
--------------------------------------------------------------------------------
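Note: each record above carries enough to recover its answer; with the eval_postfix sketch from the previous section, entry "65" checks out as 2 * 3 * 2 = 12:

import json

with open('filtered_cleaned_dolphin_data.json') as f:
    data = json.load(f)

rec = data["65"]
assert eval_postfix(rec["post_template"], rec["num_list"]) == float(rec["ans"])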
/2_Data_cleaning/cleaned_data_examples/uncleaned_dolphin_data.csv:
--------------------------------------------------------------------------------
1 | text,ans,equations,unknowns
2 | a number subtracted from 17 gives the quotient of 48 and -8. find the number.,23,17 - x =48 / -8,x
3 | what is the sum of 1 and 1?,2,x = 1 + 1,x
4 | which value is a solution to 60+2x=12x+10?,5,60+2*x=12*x+10,x
5 | "phil found that the sum of twice a number and -21 is 129 greater than the opposite of the number, what is the number?",50,(2*n + (-21)) = 129 + (-1 * n),n
6 | forty-two more than a number divided by four is equal to fifty-two more than the same number divided by fourteen. what is the number?,56,42 + n / 4 = 52 + n / 14,n
7 | "a number when multiplied by 7/18 instead of 7/8 and got the result 770 less than the actual result , find the original number?",1584,7/18*n = 7/8 * n - 770,n
8 | "if one- third of a number is subtracted from three-fourths of that number, the difference is 15. what is the number?",36,3/4*n - 1/3*n = 15,n
9 | one third of a number is 5 less than half of the same number. what is the number?,30,1/3 * n = 1/2 * n - 5,n
10 | 3 times a number is 7 more than twice the number? what is the number?,7,3*n = 7 + 2*n,n
11 | four times a number is 36 greater than the product of the number and -2. what is the number?,6,4*n = 36 + (-2)*n,n
12 | five times a number subtracted from seven times the number is a result of 18?,9,7*n - 5*n = 18,n
13 | "six times a number equals 3 times the number , increased by 24. find the number",8,6*m = 3*m + 24,m
14 | what fraction when added to its reciprocal is equal to 13/6?,3/2 or 2/3,f + 1/f = 13/6,f
15 | if a number is added to its reciprocal the sum is 25/12. find the number.,4/3 or 3/4 | 1.333 or 0.75,n + 1/n = 25/12,n
16 | "when 60% of a number is added to the number, the result is 192. what is the number?",120,60/100*n + n = 192,n
17 | the sum of an integer and its square is 30. find the number?,5 or -6,x + x^2 = 30,x
18 | "if f(x) = x + 2, find the value of x if f(x) = 12.",10,x + 2 = 12,x
19 | "if .4% of a quantity is 2.4, find the total quantity.",600,0.4/100*q = 2.4,q
20 | "what percentage of 306,000 equals 8,200?",0.0268,x * 306000 = 8200,x
21 | 4% of what number is 34?,850,4/100*n = 34,n
22 | 30 is 75% of what number?,40,30 = 75/100*n,n
23 | what percent of 54 is 135?,2.5,x * 54 = 135,x
24 | 56 is 75% of what number?,224/3 | 74.667,56 = 75/100*x,x
25 | if 18% of a number z is 54 find the value of z.,300,18/100*z = 54,z
26 | 35% of what number is 70?,200,35/100*n = 70,n
27 | i have the number 640 and the number 80. i need to figure out what percentage 80 is of 640.,0.125,p*640=80,p
28 | what is the value of x?here's the equation: 3/5(x+2)=12.,18,3/5(x+2)=12,x
29 | "if 5 is decreased by 3 times a number, the result is -4",3,5 - 3*n = -4,n
30 | "if a number is doubled and the result decreased by 35, the result is 25. what is the number?",30,2*n - 35 = 25,n
31 | five times the sum of a number and eight is twenty-five. find the number.,-3,5(n + 8) = 25,n
32 | "if 6 times a number is added to 8, the result is 56. what is the number",8,8 + 6*n = 56,n
33 | the difference of 10 and the product of 6 and n is 21.,-11/6 | -1.833,10 - 6*n = 21,n
34 | three more than the product of a number and 4 is 15. find the number.,3,4*x + 3 = 15,x
35 | three less than twice what number is -7?,-2,2*x-3 = -7,x
36 | the sum of 3 times a number and 17 is 5. what is the number?,-4,3*x + 17 = 5,x
37 | "if you double a number and then add 32,the result is 80. what is the number?",24,2*n + 32 = 80,n
38 | two less than the product of a number and -5 is -42? what is the number?,8,-5 * n - 2 = -42,n
39 | two more than twice a number is equal to the number itself. find the number.,-2,2 + 2*n = n,n
40 | "when the square of an integer is added to ten times the integer, the sum is zero. what is the integer?",-10 or 0,10*n + n^2 = 0,n
41 | two consecutive odd integers have a sum of 48. what is the largest of the two integers?,25,(2*n+1)+(2*n+3) = 48,2*n+3
42 | four times an integer plus 5 is equal to the square of the integer. what is the integer?,-1 or 5,4*n + 5 = n^2,n
43 | "robert thinks of a number. when he multiples his number by 5 and subtracts 16 from the result, he gets the same answer as to when he adds 10 to his number and multiples the result by 3. find the number robert is thinking of.",23,5*x - 16 = (x + 10) * 3,x
44 | the sum of two times a number and eight is equal to three times the difference between the number and four. what's the number?,20,2*n + 8 = 3(n - 4),n
45 | twice the sum of a number and sixteen is five less than three times the number?find the number?,37,2(n + 16) = 3*n - 5,n
46 | find a number such that subtracting its reciprocal from the number gives 16/15.,5/3 or -3/5 | 1.667 or -0.6,n - 1/n = 16/15,n
47 | "if 240 : 3 = x : 5, what's the value of x?",400,240/3 = x/5,x
48 | "given that 7 : 2 = x : 8,find the value of x",28,7/2 = x/8,x
49 | how many eight-digit telephone numbers are possible if the first digit must be nonzero?,90000000 | 9E7,n = 9*10*10*10*10*10*10*10,n
50 | 20% of the number is 3. find the number?,15,0.2*x = 3,x
51 | 5 times a number equals 20what is the number?,4,5*n = 20,n
52 | 15 is 25% of what number?,60,15 = 0.25*x,x
53 | 36 is 60% of what number?,60,36 = 0.6*x,x
54 | four times the difference of a number and one is equal to six times the sum of the number and three. find the number.,-11,4(x - 1) = 6(x + 3),x
55 | 16 more than a number is 20 more than twice the number.,-4,16 + n = 20 + 2*n,n
56 | "when 4 is added to 5 times a number,the number increases by 50. find the number.",23/2 | 11.5,5*n + 4 = n + 50,n
57 | the sum of a number and its reciprocal is 6.41. what is the number?,6.25 or 0.16 | 25/4 or 4/25,n + 1/n = 6.41,n
58 | the sum of a number and its reciprocal is 13. find this number.,12.923 or 0.0774,n + 1/n = 13,n
59 | the sum of a number and twice its reciprocal is 3. what is the number?,2 or 1,x + 2/x = 3,x
60 | "when the number x is increased by x%, it becomes 24. find x.",20 or -120,x*(1+0.01*x) = 24,x
61 | "if a number is multiplied by 5 more than twice that number, the result is 12. find the number?",3/2 or -4 | 1.5 or -4,x * (2*x + 5) = 12,x
62 | a 104 degree angle with 2 congruent sides is called? a 104 degree angle with 2 congruent sides is called?,4 or 2,x^2 + 8 = 6*x,x
63 | the sum of a number and seven times the number is 112. what is the number?,14,x + 7*x = 112,x
64 | "how many 3-digit positive integers are odd and do not contain the digit ""5""?",288,n = 8 * 9 * 4,n
65 | "how many 3 digit even numbers can be made using the digits 1, 2, 3, 4, 6, 7, if no digit is repeated?",60,n=3*5*4,n
66 | "how many 3 digit numbers can you make using the digits 1, 2, 3, and 4 if the hundreds digit is prime and repetition of a digit is not permitted?",12,n = 2 * 3 * 2,n
67 | a number is 110 less than its square. find all such numbers.,-10 or 11,n = n^2 - 110,n
68 | the sum of one third of a number and its reciprocal is the same as 49 divided by the number. find the number,12 or -12,n/3 + 1/n = 49/n,n
69 | "what number, when added to the number three or multiplied to the three, gives the same result?",3/2 | 1.5,n + 3 = 3*n,n
70 | if the average of 20 different positive numbers is 20 then what is the greatest possible number among these 20 numbers?,210,(1+(20-1))*(20-1)/2 + x = 20*20,x
71 | "if 9n^2 - 30n + c is a perfect square for all integers n, what is the value of c?",25,30^2 = 4 * 9 * c,c
72 | what is the least value of k if 3x^2 + 6x + k is never negative?,3,6^2 = 4*3*k,k
73 | find the value of c that makes x^2+18x+c a perfect square trinomial.,81,18^2 = 4 * 1 * c,c
74 | find this quantity: one half of the quantity multiplied by one third of the quantity equals 24.,12 or -12,(1/2)q * (1/3)q = 24,q
75 | "if the expression x^2 - kx - 12 is equal to zero when x=4, what is the value of k?",1,4^2 - 4*k - 12 = 0,k
76 | a 3- digit number in which one digit is the average of the other two digits is called an average number. 456 is an average number because 5 is the average of 4 and 6. how many three digit average numbers are there?,45,n = 9 * 5,n
77 | "what is 8% of 2,000.",160,x = 2000 * 0.08,x
78 | what's 8% of 5.00?,0.4,x = 0.08 * 5,x
79 | one fourth of one third is the same as one half of what fraction?,1/6,(1/4) * (1/3) = (1/2) * x,x
80 | "when the number n is multiplied by 4, the result is the same as when 4 is added to n. what is the value of 3n?",4,4*n = 4 + n,3*n
81 | "when the number w is multiplied by 4, the result is the same as when 4 is added to w. what is the value of 3w?",4,4*w = w + 4,3*w
82 | how many whole numbers less than 500 have seven as the sum of their digits?,30,n = 1 + 7 + 7 + 6 + 5 + 4,n
83 | "if 4 less than 3 times a certain number is 2 more than the number, what is the number",3,3*x - 4 = 2 + x,x
84 | "when twice a number is decreased by 8, the result is the number increased by 7. find the number?",15,2*x - 8 = x + 7,x
85 | 3 less than 5 times a number is 11 more than the number? find the number.,7/2 | 3.5,5*n - 3 = 11 + n,n
86 | "the first three terms of an arithmetic sequence are (12-p), 2p, (4p-5) respectively where p is a constant.find the value of p",7,2*p - (12-p) = (4*p-5) - 2*p,p
87 | "if the product of -3 and the opposite of a number is decreased by 7, the result is 1 greater than the number. what is the number?",4,-3*(-1 * n) - 7 = 1 + n,n
88 | what is the difference between the sum of the first 2004 positive integers and the sum of the next 2004 positive?,4016016,d = 2004 * (1 + 2004 - 1),d
89 | the sum of 2/3 and four times a number is equal to 5/6 subtracted from five times the number. find the number.,3/2 | 1.5,2/3 + 4*n = 5*n - 5/6,n
90 | "find the difference between the product of 26/22 and 3.09 and the sum of 3,507.00, 2.08, 11.50 and 16,712.00",-22251821/1100 | -20228.928,"x = 26/22 * 3.09 - (3507.00 + 2.08 + 11.50 + 16,712.00)",x
91 | what is 5/9 times 36?,20,x = 5/9 * 36,x
92 | what number is 5 sixths of 100?,250/3 | 83.333,x = 5/6 * 100,x
93 | 5/6 of a number is 3/4. what is the number?,9/10 | 0.9,5/6 * n = 3/4,n
94 | what is the value of x in the following equation?3x - 4(x + 1) + 10 = 0,6,3*x - 4(x + 1) + 10 = 0,x
95 | "a whole number between 100 and 1000 is to be formed so that one of the digits is 6, and all the digits are different. how many numbers are possible.",200,x = 9*8 + 8*8 + 8*8,x
96 | how many permutations of digits in the number 12345 will make an even number?,48,n = 2 * (4 * 3 * 2 * 1),n
97 | how many 1/4 ounces are in 56 pounds?,3584,x*(1/4) = 56 * 16,x
98 | what number minus one-half is equal to negative one-half ?,0,x - 1/2 = -1/2,x
99 | 16 over 24 minus 18 over 24?,-1/12 | -0.0833,x = 16/24 - 18/24,x
100 | "the quotient of -54 and 9, subtracted from 8 times a number, is -18. what's the number?",-3,8*n - (-54)/9 = -18,n
101 | "if 30% of 40% of a positive # is equal to 20% of w% of the same #, what is the value of w?",60,30/100*40/100*n = 20/100*w/100*n,w
102 | square root six over square root 7 multiplied by square root 14 over square root 3.,2,x = (sqrt(6)/sqrt(7)) * (sqrt(14)/sqrt(3)),x
103 | "if 4 is added to the square of a composite integer, the result is 14 less than 9 times that integer. find the integer.",6,4 + n^2 = 9*n - 14,n
104 | "how many four-digit numbers can be formed using the digits 0, 1, 2, 3, 4, 5, 6, 7, 8, and 9 if the first digit cannot be 0? repeated digits are allowed.",9000,n = 9 * 10 * 10 * 10,n
105 | the equation is 10x - 19 = 9 - 4x. what's the value of x?,2,10*x - 19 = 9 - 4*x,x
106 | twelve less than 2 times a number is equal to 15 minus 7 times the number.,3,2*n - 12 = 15 - 7*n,n
107 |
--------------------------------------------------------------------------------
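Note: a hedged way to sanity-check rows of this CSV is to hand the equations column to sympy and confirm the stated answer solves it (sympy is an outside tool choice here, not something the repo uses; rows with "or" / "|" answer alternatives need extra parsing):

import pandas as pd
import sympy

df = pd.read_csv('uncleaned_dolphin_data.csv')
row = df.iloc[0]  # "a number subtracted from 17 gives ..."

lhs, rhs = row['equations'].split('=', 1)
solutions = sympy.solve(sympy.Eq(sympy.sympify(lhs), sympy.sympify(rhs)),
                        sympy.Symbol(row['unknowns']))
print(solutions)  # -> [23], matching the "ans" column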
/3_T-RNN_&_baselines/output/pred_dolphin_rep_tfidf.txt:
--------------------------------------------------------------------------------
1 | 954 216216.0 216216.0
2 | 956 299484.0 299484.0
3 | 2015 1924.0 1924.0
4 | 976 -0.71875 -0.71875
5 | 982 5893157296026.0 5893157296026.0
6 | 985 78706.0 78706.0
7 | 999 223404480.0 223404480.0
8 | 1040 5184000.0 5184000.0
9 | 1004 274.0 274.0
10 | 600 3000.0 3000
11 | 1011 117740.0 117740.0
12 | 1012 10080.0 10080.0
13 | 686 6750.0 6750
14 | 1046 26.543513101659002 26.543513101659002
15 | 1066 625.0 625.0
16 | 1067 560.0 560
17 | 1069 66030.0 66030.0
18 | 41 0.30000000000000004 0.3
19 | 1088 87280691400.0 87280691400.0
20 | 1092 2976.0 2976.0
21 | 1095 376.0 376.0
22 | 1099 56.0 56.0
23 | 1104 3873860100.0 3873860100.0
24 | 1105 453.0 453
25 | 1118 52000.0 52000
26 | 1122 1026.0 1026.0
27 | 1133 11020.0 11020
28 | 1147 5382013074.0 5382013074.0
29 | 1157 385.0 385.0
30 | 1164 1026.0 1026.0
31 | 1173 1682.0 1682.0
32 | 1179 91931.0 91931.0
33 | 1186 418200.0 418200.0
34 | 1194 48440106.0 48440106.0
35 | 1198 6774144.0 6774144.0
36 | 1200 27440.0 27440.0
37 | 1203 112390880.0 112390880.0
38 | 1217 3562120.0 3562120.0
39 | 1230 788112.0 788112.0
40 | 1232 3162.0 3162.0
41 | 1247 225.0 225
42 | 1251 -0.5428571428571428 -0.5428571428571428
43 | 1254 254562.0 254562.0
44 | 617 360.0 360
45 | 110 2205.0 2205
46 | 1295 7851720.0 7851720.0
47 | 1296 218400.0 218400.0
48 | 1315 3146.0 3146.0
49 | 1320 103247.0 103247.0
50 | 1325 4968.0 4968.0
51 | 1331 2675.0 2675.0
52 | 1338 53.0 53.0
53 | 1340 1.4308943089430894 1.4308943089430894
54 | 1352 4567612665.0 4567612665.0
55 | 1359 5285280.0 5285280.0
56 | 1367 475716.0 475716.0
57 | 1371 702.0 702.0
58 | 1372 5985.0 5985.0
59 | 1681 8576.0 8576.0
60 | 1392 6873.0 6873.0
61 | 1396 240.0 240
62 | 1397 48668.0 48668.0
63 | 1402 104580.0 104580.0
64 | 1422 8528085.0 8528085.0
65 | 1438 106.5 106.5
66 | 1444 4758.0 4758.0
67 | 1448 204.0 204.0
68 | 2217 13720224.0 13720224.0
69 | 1449 0.1363636363636 0.136363636364
70 | 1455 0.26954863318499683 0.26954863318499683
71 | 1458 0.8073981542854229 0.8073981542854229
72 | 1473 180.0 180
73 | 1475 10528848.0 10528848.0
74 | 1476 10142808.0 10142808.0
75 | 1488 144.0 144
76 | 1503 39900.0 39900.0
77 | 1514 280.0 280.0
78 | 1518 0.4444444444444444 0.444444444444
79 | 1526 80410.0 80410.0
80 | 1541 13910400.0 13910400.0
81 | 1548 67704.0 67704.0
82 | 1568 23.525339373123398 23.525339373123398
83 | 1576 687.0 687.0
84 | 1577 864.0 864
85 | 1592 4592.0 4592.0
86 | 1602 700.0 700.0
87 | 292 31536000.0 31536000
88 | 1632 -18.0 -18.0
89 | 1648 1936.0 1936.0
90 | 1653 2744.0 2744.0
91 | 1762 155.88888888888889 155.88888888888889
92 | 1661 228.0 228.0
93 | 1674 134.0 134.0
94 | 1687 486.0 486
95 | 1692 1232.0 1232.0
96 | 1711 22950000.0 22950000.0
97 | 1730 167.0 167.0
98 | 1731 94.0 94.0
99 | 1752 0.84 0.84
100 | 1779 -0.75 -0.75
101 | 1784 0.9037037037037037 0.9037037037037037
102 | 2316 1369.0 1369.0
103 | 1900 82.0 82.0
104 | 1919 31000486944.0 31000486944.0
105 | 1920 1890108.0 1890108.0
106 | 1922 6.0 6
107 | 1929 46746.0 46746.0
108 | 1942 205.0 205.0
109 | 1943 133.0 133.0
110 | 1950 1031510.0 1031510.0
111 | 1952 86700.0 86700.0
112 | 1955 1440.0 1440
113 | 1957 291456.0 291456.0
114 | 1965 132.0 132.0
115 | 1981 5750352.0 5750352.0
116 | 1995 131424.0 131424.0
117 | 1997 1082004.0 1082004.0
118 | 2002 100.0 100.0
119 | 2026 462384.0 462384.0
120 | 2033 37004.0 37004.0
121 | 2053 22.053571428571427 22.053571428571427
122 | 2067 1560896.0 1560896.0
123 | 385 80.0 80
124 | 2093 1649061.0 1649061.0
125 | 2099 1331.0 1331.0
126 | 2100 64.0 64.0
127 | 2103 59319.0 59319.0
128 | 2109 8.35483870967742 8.35483870967742
129 | 2118 67811.0 67811.0
130 | 2137 7605.0 7605.0
131 | 2146 3996.0 3996.0
132 | 2167 1800.0 1800.0
133 | 2168 4004.0 4004.0
134 | 2183 4224.0 4224.0
135 | 2209 2.6382978723404253 2.6382978723404253
136 | 2242 867510.0 867510.0
137 | 2252 75548.0 75548.0
138 | 2259 1600.0 1600.0
139 | 2272 39.0 39.0
140 | 2282 17280000.0 17280000.0
141 | 455 11.0 11
142 | 2356 4.8 4.8
143 | 473 210.0 210
144 | 2360 551448.0 551448.0
145 | 2364 429975.0 429975.0
146 | 2367 213759.0 213759.0
147 | 2393 6862.0 6862.0
148 | 2394 44100.0 44100.0
149 | 2413 5720.0 5720.0
150 | 2419 0.008927575047427742 0.008927575047427742
151 | 497 3.0 3
152 | 2424 7826.0 7826.0
153 | 2427 6014.0 6014.0
154 | 2428 1479.0 1479.0
155 | 2435 33284498950.0 33284498950.0
156 | 2439 289304.0 289304.0
157 | 2465 86.11981566820276 86.11981566820276
158 |
--------------------------------------------------------------------------------
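Note: each line of these output files reads as "<question id> <predicted value> <gold value>"; a sketch for scoring such a file with a relative tolerance that absorbs float-formatting noise like 0.30000000000000004 vs 0.3:

def accuracy(path, tol=1e-4):
    correct = total = 0
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) != 3:
                continue  # skip blank or malformed lines
            pred, gold = float(parts[1]), float(parts[2])
            total += 1
            correct += abs(pred - gold) <= tol * max(1.0, abs(gold))
    return correct / total

print(accuracy('pred_dolphin_rep_tfidf.txt'))  # -> 1.0 for the file above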
/3_T-RNN_&_baselines/output/pred_dolphin_tfidf.txt:
--------------------------------------------------------------------------------
1 | 1536 3150.0 60
2 | 1561 20.0 4752
3 | 37 5.0 6
4 | 2123 -6.0 2
5 | 1110 0.0 105
6 | 1551 78.0 78
7 | 2152 25088.0 0.875
8 | 1652 105.0 315
9 | 1301 24.0 120
10 | 717 1944.0 216
11 | 2278 168.0 18
12 | 247 16.0 0.033
13 | 253 12.0 80
14 | 267 75.0 0.615384615385
15 | 1826 1660.0 3160
16 | 1670 72.0 52
17 | 299 5.0 7
18 | 2140 80.0 20
19 | 1690 88.0 408
20 | 1863 -8.0 8
21 | 1370 144.0 144
22 | 455 11.0 11
23 | 1754 1.0 0.833333333333
24 | 1492 180.0 180
25 | 1621 8.0 3.42857142857
26 |
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/output/pred_math23k_tfidf.txt:
--------------------------------------------------------------------------------
1 | 7534 1100.0 60
2 | 7652 -0.8333333333299999 0.266666666667
3 | 23054 -124.0 43
4 | 14980 56.35673624288425 7
5 | 12837 -18.0 21
6 | 978 -178.833333333333 30
7 | 3513 4.0 4
8 | 15534 120.75 68.5714285714
9 | 12831 180.0 36
10 | 21881 799.4 240
11 | 18638 1500.0 1500
12 | 14552 0.3 0.3
13 | 21887 840.2 945
14 | 1818 60.0 60
15 | 5902 400.0 400
16 | 17497 15.0 315
17 | 20458 16.666666666666668 25
18 | 12053 -28.125 15
19 | 11062 796000.0 190
20 | 7828 4.8 3.072
21 | 13693 1203.0 750
22 | 18174 9.0 9
23 | 6679 21.0 21
24 | 6181 250.0 250
25 | 923 4000.0 4000
26 | 21017 0.8 81920
27 | 7686 0.05714285714285714 72
28 | 10874 2369.0 2369
29 | 12148 0.7291666666643749 1
30 | 5726 23005.0 47
31 | 13009 1152.0 1152
32 | 8509 500.0 320
33 | 2339 8.2 9
34 | 12042 0.3333333333333333 6
35 | 15159 0.34594594594594597 12.75
36 | 18333 425.0 42
37 | 7395 49.999999999999986 200
38 | 6061 88014.70999999999 1449.3
39 | 7729 27.041666666666668 2174
40 | 19245 3.761904761904762 5
41 | 316 20.65853658536585 2.9
42 | 6682 3120.0 3120
43 | 15563 96.0 150
44 | 11959 20.0 20
45 | 1120 16.0 16
46 | 233 54.666666666720666 12
47 | 15688 40.0 40
48 | 17147 481.0 481
49 | 21323 245.0 320
50 | 3456 57.292775000000006 0.47
51 | 20787 0.12 0.12
52 | 14915 12.0 12
53 | 1970 15000.0 15000
54 | 12906 0.5384615384615384 0.538461538462
55 | 13206 2160.0 30
56 | 15150 -1.44 936
57 | 581 8.0 8
58 | 3416 8.53981128 8.53981128
59 | 19230 6248.0 6336
60 | 14974 4.5 11
61 | 16206 5100.0 5100
62 | 9130 -5.75 0.125
63 | 3843 49.999999999950006 50
64 | 10464 160.000000000064 160
65 | 16208 7.0 6000
66 | 5364 78500.0 10
67 | 22153 -0.0030303030303030303 320
68 | 22057 30.000000000000007 13.5
69 | 9820 100.0 100
70 | 23017 0.47619047619064286 112
71 | 14404 7.9866888519134775 180
72 | 5413 0.65 0.6
73 | 17285 485.833333333245 318
74 | 21946 755.3325 126.75
75 | 6951 135.375 150
76 | 6051 25.0 20
77 | 15757 2.9899999999999998 3.05
78 | 20445 64.0 440
79 | 12348 80.0 80
80 | 18952 13.333333333333334 12
81 | 4050 -46.199999999999996 46.2
82 | 5899 0.9090909090909091 111
83 | 4412 200.0 200
84 | 15464 0.6351791530944625 61
85 | 12813 0.16666666666675 4.29166666667
86 | 5701 19.7 15.5
87 | 16257 160.0 14
88 | 1830 61350.0 61350
89 | 17773 5.0 40
90 | 10045 36.0 9600
91 | 4640 0.42857142857099995 0.428571428571
92 | 1688 0.14285714285714285 62
93 | 9962 90.60166666666667 120
94 | 20647 0.08256880733944955 0.09
95 | 15678 0.5 0.5
96 | 19222 -1.9 13.6
97 | 1888 1.804 6.76
98 | 18362 672.0 5600
99 | 21301 -277.21774193548384 1.2
100 | 11417 0.0 0.333333333333
101 | 13907 -34.0 65
102 | 22852 384.0 7
103 | 11276 2242.0 1229
104 | 16422 -17567.0 49
105 | 7257 77.999999999976 4
106 | 20524 3320.0 470
107 | 21203 0.6 100
108 | 6155 1.0125 81
109 | 9322 1.5 450
110 | 4129 52.881844380403464 6
111 | 12922 27.0 27
112 | 4234 36.1 36.1
113 | 7459 4200.0 170
114 | 18453 31.999999999992 18
115 | 12472 0.42857142857200003 0.428571428571
116 | 1329 160.0 160
117 | 14341 32.000000000016 32
118 | 4084 9230.0 19520
119 | 22574 1230.0 15
120 | 11350 7.0 7
121 | 22726 1008.0 28
122 | 14706 1.4 1.4
123 | 20302 28.0 28
124 | 10437 0.35714285714285715 715
125 | 18415 649.999999998 600
126 | 22015 89.28571428574999 50
127 | 22195 15.135135135135135 11
128 | 16997 13499775.0 75
129 | 17108 -0.007035175879396985 280
130 | 4289 600.0 600
131 | 21848 75.0 75
132 | 18360 576.0 400
133 | 2350 -6.442953020134228 58
134 | 3744 0.0016638935108153079 54
135 | 10838 0.09090909090909091 0.1
136 | 12421 320.0 130
137 | 22255 0.20930232558139536 80
138 | 13213 -0.15384615384615385 0.2
139 | 4716 1049.76 1.2
140 | 23146 70.0 70
141 | 18585 10.05 0.5
142 | 4504 -0.041666666666666664 28
143 | 17301 1850.0 1850
144 | 22294 23.999999999999996 24
145 | 15835 15.0 15
146 | 4118 -34.399999999999196 46
147 | 17109 280.0 96
148 | 9080 0.92 4600
149 | 23085 10.714285714286 10.7142857143
150 | 16001 2.6666666666666665 48
151 | 21186 -450.0 1200
152 | 2267 104.0 564
153 | 13141 810.00000000081 81
154 | 21632 12.5 6.5
155 | 5135 22.0 22
156 | 3589 530.0 530
157 | 4468 1920.0 4
158 | 6744 2.0996052742063494e-05 63
159 | 10394 0.4 0.4
160 | 7123 100.0 100
161 | 5205 -2.9 0.3
162 | 5089 1.1 11
163 | 10529 76.0 76
164 | 7098 672.0 672
165 | 11573 4.3 4.3
166 | 15765 45.0 225
167 | 7198 250.00000000000003 250
168 | 16277 8.5 1.5
169 | 13466 100000.0 10
170 | 6910 27.0 120
171 | 21695 8.0 8
172 | 15334 336.000000000336 168
173 | 16628 -15.172413793103448 40
174 | 18685 0.0033333333333333335 50
175 | 3757 -1780.0 10
176 | 5144 1.1904761904761905 930
177 | 3379 1.6666666666666667 21
178 | 1536 1.111111111111111 200
179 | 6725 7.0 7
180 | 2479 9.999999999989999 80
181 | 9143 0.14 0.14
182 | 3724 -0.7629999999999999 3.88
183 | 742 -3.1578947368421053 0.25
184 | 4079 0.49 6825
185 | 22672 0.7484375 1120
186 | 19291 10.0 10
187 | 17403 1160.0 12
188 | 11231 0.4444444444444444 0.444444444444
189 | 15583 108.4 123
190 | 15263 2000.0 200
191 | 20131 -27.999999999984 20
192 | 8945 450.0 60
193 | 6827 0.14285714285714285 0.142857142857
194 | 8040 0.26785714285725 0.267857142857
195 | 22913 -100.0 35
196 | 7496 5400.0 5400
197 | 6776 409.0 169
198 | 12186 27.47222222222222 15
199 | 1675 -56.25 1350
200 | 12474 0.4444444444444444 0.222222222222
201 | 1779 60.4 360
202 | 14732 75.0 55
203 | 2332 1000.0 90
204 | 3514 33.84615384605621 300
205 | 8444 126.0 126
206 | 19806 360.0 240
207 | 3906 2810.0 6
208 | 7555 3.75 3.75
209 | 15320 -34.980000000000004 2000
210 | 16903 14112.0 392
211 | 12994 5306.0 5306
212 | 21128 30.0 30
213 | 3946 50.000000000025004 120
214 | 13680 400.0 400
215 | 15377 11.523809523809524 55
216 | 14874 -36.0 40
217 | 19586 64.0 64
218 | 22011 -0.0234375 258
219 | 13229 288.0 6
220 | 4179 -0.25 0.75
221 | 19377 0.8235294117647058 58
222 | 8453 711.0 1935
223 | 6470 5.25 200
224 | 118 3.0 3
225 | 5853 20.0 20
226 | 21398 8001.6 1250
227 | 5036 -7.6 120
228 | 1355 25.0 25
229 | 11936 10.799999999999999 30
230 | 2570 0.2666666666666 360
231 | 13439 7.0 7
232 | 7654 6.0 6
233 | 3784 -57.142857142857146 200
234 | 16486 -2205.000000000014 37
235 | 10750 45.0 4.5
236 | 8438 0.49166666666666664 900
237 | 9681 650140.0 65
238 | 12868 8.888888888888857 80
239 | 8067 -12.0 10
240 | 21421 1.7777777777777777 1.77777777778
241 | 14754 100.8 115.2
242 | 9078 3.2129396788997536e-05 0
243 | 17140 1517472.0 2246
244 | 7722 5625.0 20
245 | 2791 41.0 98
246 | 6648 226.66666666666663 40
247 | 3775 144.0 144
248 | 15793 0.16666666666666666 0.166666666667
249 | 1466 0.2375 84.375
250 | 3501 -4.008016032064128 0.2
251 | 20781 1280.0 2000
252 | 10980 25967.5 6110
253 | 1393 200.0 200
254 | 771 21.0 21
255 | 22667 1.150537634408602 3.6
256 | 15748 1.6 1.6
257 | 15803 0.017361111111111112 36000
258 | 19868 2360.0 344
259 | 19390 -0.0011111111111095238 77.7777777778
260 | 15140 44.444444444488894 60
261 | 5859 20.0 20
262 | 1705 159.0 208
263 | 1411 224.0 800
264 | 22239 -0.039473684210526314 15
265 | 15208 140.0 8120
266 | 12465 128.40000000000597 0.206896551724
267 | 1789 0.10641025641025642 1.56
268 | 19129 6.24 0.1
269 | 14906 0.00019994001799460164 125
270 | 17781 -0.030454318522216674 34
271 | 20076 49.000000000006125 64
272 | 18326 8.0 8
273 | 14762 810.0000000000002 9
274 | 20828 2400.0 2400
275 | 18063 14.0 14
276 | 14100 0.4 2.4
277 | 1538 18.181818181818183 8
278 | 17983 215.99999999987295 216
279 | 15937 488.5714285714286 488.571428571
280 | 9111 -0.36 2.2
281 | 1302 432.0 432
282 | 14281 2.8 2.8
283 | 9252 0.0 27.75
284 | 253 5.0 5
285 | 9208 2.2857142857139183 180
286 | 17422 3605400.0 10
287 | 10328 3.0 15
288 | 2747 0.85 6
289 | 10869 1080.6779999999999 1998
290 | 2195 300.0 7
291 | 1595 1896.0 2121
292 | 21239 67.5 67.5
293 | 4664 270.0 270
294 | 8033 8.0 8
295 | 6953 4.694444444444445 8
296 | 11657 -0.759493670886076 0.25
297 | 17902 4800.0 1800
298 | 7911 -288.0 3
299 | 19018 0.8042328042326032 75
300 | 1592 7.1000000000000005 8.7
301 | 17266 18.0 18
302 | 22981 36.0 36
303 | 21855 450.0 43
304 | 21575 2250.0 9000
305 | 12587 48.0 87
306 | 1911 839.99999999937 840
307 | 10565 0.2 0.2
308 | 4999 0.37142857142799995 0.571428571429
309 | 6152 0.5 0.5
310 | 18609 40.0 40
311 | 14332 -1.4933333333324 9
312 | 3439 -0.5833333333335 0.833333333333
313 | 13381 42.0 42
314 | 17611 1680.0 400
315 | 3478 8000.0 264
316 | 3852 1.7 12
317 | 17504 8.0 18
318 | 12505 1.4666666666666668 0.7
319 | 8452 4000.0 40
320 | 17850 8.0 56
321 | 6997 432.222222222222 264
322 | 14179 200.0 162
323 | 2894 -9.0 28
324 | 11644 11.0 11
325 | 643 1119.0 867
326 | 18153 459.0 459
327 | 12969 45946.56 252
328 | 2337 0.05 0.95
329 | 10407 0.30000000000000027 0.5
330 | 6439 0.525 700
331 | 14142 0.96 0.04
332 | 19900 -6.0 2
333 | 22332 86.4 86.4
334 | 7086 34300.0 28
335 | 12537 5.0 0.1
336 | 13501 2.9699999999999998 2.4
337 | 2141 41.15384615384615 840
338 | 7207 3.0 50
339 | 20436 0.01 4.5
340 | 21726 12.100000000000001 12.1
341 | 9828 -0.0011129660545358931 1500
342 | 3115 -3392.0 73
343 | 7066 0.002802144249513158 216
344 | 18232 0.8333333333325 0.266666666667
345 | 17668 6.0 22
346 | 514 3.0 3
347 | 14444 0.5333333333336 150
348 | 19741 8.7 8.7
349 | 2732 240.0 135
350 | 18517 5.0 5
351 | 12797 0.19999999999760004 18
352 | 9856 504.0 48
353 | 2466 15.333333333333 10
354 | 17388 3600.0 3600
355 | 9650 31.0 31
356 | 1988 -5634.67 130.13
357 | 20281 0.09999999999999998 0.1
358 | 842 280721.0 5
359 | 8137 213.54545454545456 26100
360 | 10287 -0.0026041666666666665 392
361 | 9251 35.555555555519994 360
362 | 2866 5.0 90
363 | 16212 2.0 2
364 | 4197 1185230.0 490
365 | 5069 4037.04 6.1
366 | 9566 0.13793103448275862 170
367 | 15995 10.0 10
368 | 18319 122.4 489.6
369 | 853 40.0 40
370 | 11135 301.45 780
371 | 18772 -0.18388429752066118 39.2
372 | 18618 8.0 5
373 | 12567 0.6 240
374 | 12256 66.66666666666667 96
375 | 11894 -3745.2 1.5
376 | 11538 50.0 50
377 | 7295 50.0 50
378 | 8089 0.166666666667 10
379 | 6780 0.625 0.35
380 | 19497 5.429246293061726e+168 1.5
381 | 15177 0.16 0.16
382 | 8565 36.0 36
383 | 21876 0.06666666666666667 143
384 | 7931 -43.0 87.6
385 | 11060 0.39999999999999997 0.4
386 | 15573 14.0 14
387 | 1553 -0.031055900621086957 34.5
388 | 11683 114.0 114
389 | 8269 0.0009302325581395349 1075
390 | 17025 4018.0 58
391 | 7364 7.0 7
392 | 14314 0.036000000000000004 0.036
393 | 20327 -28320.0 0.125
394 | 10799 -5.333333333333 0.833333333333
395 | 11165 198.0 54
396 | 15036 4.5 7.5
397 | 11767 78.000000000078 78
398 | 17369 -113.89830508474576 39
399 | 8319 -50.00000000022998 528
400 | 406 1926.0 7
401 | 13945 33.666666666666664 50
402 | 10050 0.011908931698774083 30
403 | 11429 160.0 25
404 | 17479 8.9 8.9
405 | 11659 0.0015282730514503335 120
406 | 16156 93500.00000000001 93500
407 | 10333 107.2 48
408 | 12203 0.06584362139920823 0.0833333333333
409 | 20631 -52680.0 7920
410 | 5009 64.00000000008001 64
411 | 3258 0.005333333333333333 30
412 | 17281 10.5 4
413 | 6922 250.0 730
414 | 9165 840.0 840
415 | 1249 3.0 600
416 | 9342 14400.0 1
417 | 11078 0.08662696264212236 0.25
418 | 2102 4.8076923076923075 16
419 | 5137 0.19999999999999998 0.25
420 | 18088 11.0 11
421 | 19267 0.5625 52
422 | 7293 180.0 180
423 | 11509 1.666666666666 0.972222222222
424 | 1591 0.9944751381215425 225
425 | 5846 53.333333333333336 300
426 | 4346 -1349.42 783
427 | 17849 0.033333333333300005 5.45454545455
428 | 10505 2.6315789473684212 40
429 | 3202 105.0 105
430 | 9511 20.0 16
431 | 11061 60.0 60
432 | 2316 12.0 3
433 | 14575 0.435 0.125
434 | 9289 0.004166666666666667 32
435 | 11873 18.0 198
436 | 17694 -50.0 50
437 | 10001 2.4000000000000004 0.4
438 | 1520 8.421052631578947 5
439 | 8463 -59.416666666667 28
440 | 10077 18.96 18.96
441 | 21794 -15.0 0.25
442 | 19545 -0.0046875 200
443 | 17535 1980.0 1980
444 | 22821 1.3333333333333333 4
445 | 21557 1542.6 52
446 | 1140 132.0 9
447 | 15890 0.2 0.8
448 | 17134 1698.6666666666667 5
449 | 3475 0.0004982559798166703 0.999002991027
450 | 9849 400.0 400
451 | 13309 2064.15 310
452 | 20185 4560.0 10
453 | 18748 146.285714285696 112
454 | 2957 -6.75 80
455 | 1929 2000500.0 528.35
456 | 7607 36.0 36
457 | 327 22500.0 0
458 | 11541 1100.0 1100
459 | 7657 75.0 5
460 | 9592 23.0 15
461 | 13295 7.0 8
462 | 12387 56.0 5
463 | 10930 3.0 3
464 | 8289 3300.0 108
465 | 14335 0.006 238
466 | 3500 8000.0 8000
467 | 22994 -28.0 60
468 | 21882 324.0 36
469 | 18002 -20.0 3
470 | 11879 -0.053106744556558685 6
471 | 7417 39.8 40.2
472 | 6088 11.0 27
473 | 19437 168.8 15.6
474 | 14847 26.0 26
475 | 15901 28.8 228.8
476 | 15948 144.0 144
477 | 5063 45.2 80
478 | 11790 840.0000000001528 840
479 | 4829 0.32000000000000006 1.33333333333
480 | 4159 42.0 42
481 | 20513 25.0 200
482 | 13691 72.0 72
483 | 6892 -14160.0 0.2
484 | 8793 0.13333333333333333 30.6
485 | 5231 0.05555555555555555 18
486 | 21370 -0.08333333333299997 2
487 | 2506 750.0 750
488 | 12936 540.0 2040
489 | 7536 9.142857142857144 7.8
490 | 16035 4.800000000000001 3.57
491 | 2916 6.0 6
492 | 20803 12.0 12
493 | 11214 0.30000000000004357 0.3
494 | 12802 56980.0 76
495 | 15384 nan 95.76
496 | 1919 46.444444444398 171
497 | 1795 -29.0 3
498 | 23066 -75.01854140914709 53.9466666667
499 | 21639 75.00000000000001 75
500 | 4493 -3.0000000000094285 45
501 | 21176 3.9599999999999995 6
502 | 22059 210.0 90
503 | 10632 -0.20746058091286307 720
504 | 12514 63.9999999999744 36
505 | 1291 2304.0 26
506 | 7942 2.0000000000035003 2
507 | 829 30.0 30
508 | 17470 110.00000000002501 110
509 | 432 1008.0 27
510 | 610 -224400.0 17
511 | 13490 30.25000000003025 0.590909090909
512 | 18573 3.8461538461538463 0.3125
513 | 2651 0.25 4
514 | 2625 0.013888888888888888 72
515 | 1035 -1.877167755990811 33.5
516 | 4299 9.0 15
517 | 22620 102.0 102
518 | 9682 7.0 2500
519 | 20963 0.0011111111111111111 330
520 | 17516 500.0 500
521 | 12323 1.3333333333333333 40
522 | 14903 80.99999999998988 81
523 | 17078 35.999999999951996 36
524 | 10910 640.0 11
525 | 20800 1000.0 1000
526 | 8590 760.0 2
527 | 11446 16.25 16.25
528 | 7487 87120.0 0.2
529 | 5006 2800.0 1520
530 | 8101 1.0 2.5
531 | 11476 12.407407407416665 11
532 | 17789 -434.99999999999983 200
533 | 14524 80.0 620
534 | 2349 31.249999999999996 31.25
535 | 4383 2268.0 2268
536 | 12008 350.0 100
537 | 13992 19.0 20
538 | 9448 2.0 2500
539 | 14773 0.02 8.33333333333
540 | 3003 4.2749999999999995 90
541 | 3301 9.32 4.66
542 | 5918 81.08108108108108 100
543 | 20952 -36.5625 6
544 | 9883 500.0 500
545 | 22714 85.33333333333333 170
546 | 21612 270.0 270
547 | 10474 1.1800000000000002 14.4
548 | 14965 15.0 15
549 | 3487 24.975 5
550 | 19875 20.0 20
551 | 16757 11.0 38
552 | 5754 9.0 9
553 | 13635 0.06153846153846154 1040
554 | 22110 12.0 12
555 | 8535 3016.0 308
556 | 11727 -3.04 56
557 | 15141 354.0 120
558 | 18419 16.0 1
559 | 10867 105.0 45
560 | 11204 0.33333333333333337 15
561 | 9874 0.525252525253 0.525252525253
562 | 12075 2.0 12.5
563 | 17268 0.02 33
564 | 11792 4320.0 4320
565 | 7240 399.0 399
566 | 12602 300.0 300
567 | 3932 11.818181818181818 713
568 | 10298 51.0 17
569 | 7150 3574.2857142857147 700
570 | 13982 0.5833333333329717 0.583333333333
571 | 5273 0.050724637681159424 26
572 | 22259 90.0 60
573 | 12349 4.0 20
574 | 5906 195.0 195
575 | 16642 480.0 480
576 | 3255 404.3076923076923 84
577 | 8248 59.0 61
578 | 20022 1.8571428571428572 0.533333333333
579 | 19973 15.0 0.0666666666667
580 | 13573 316.0 31
581 | 22928 0.00021132713440405747 7
582 | 6427 1.5 21
583 | 9658 75.00000000033751 75
584 | 163 0.125 0.125
585 | 12855 -0.6612244897959183 95
586 | 15740 162.0 27
587 | 14855 19659.0 390
588 | 13741 -0.05259259259259259 1.4332
589 | 22202 840.0 840
590 | 20771 2268.0 1764
591 | 18851 24.0 25
592 | 6931 -29.666666666666668 120
593 | 15346 -10.25316455696202 110
594 | 10002 76.5 75
595 | 17583 17600.0 1000
596 | 11602 48.333333333333336 80
597 | 14609 3.4 2.4
598 | 15897 5985.0 5985
599 | 21211 60.00000000000001 60
600 | 18405 7.2 3
601 | 19680 28.800000000000004 5.5
602 | 11450 59.999999999879996 50
603 | 12691 2.5 2.5
604 | 7012 13.5 13.5
605 | 6096 84.0 38
606 | 18450 0.9319727891157142 0.3
607 | 21348 800.0 800
608 | 13807 0.0196078431372549 0.0196078431373
609 | 3561 0.6666666666666666 0.5
610 | 14069 6036.0 20700
611 | 10589 21875.428571428572 1500
612 | 2954 35.555555555573335 20
613 | 16273 1.0345071300973196 1
614 | 8287 501.0 3
615 | 19878 154.13333333333333 6
616 | 1611 27327.89527242847 282.6
617 | 1883 349.9999999999416 350
618 | 20297 244.72299999999998 2
619 | 14752 0.8571428571428571 3.71428571429
620 | 8819 0.625 1.3
621 | 3064 -0.6020066889632106 500
622 | 15886 0.3333333333333333 0.5
623 | 4358 240.0 375
624 | 14955 4.0 4
625 | 4622 9.0 11.75
626 | 10707 32.0 32
627 | 2599 60.0 12
628 | 4752 6.5 66
629 | 7154 18.0 2
630 | 22380 -0.039603960396039604 1368
631 | 21487 -0.375 36
632 | 7844 1123.2 1123.2
633 | 11081 12.0 3
634 | 18385 0.00881542699724518 200
635 | 7185 145.0 145
636 | 2216 57.5 30
637 | 8199 0.7142857142857143 16
638 | 16493 180.0 65
639 | 11169 1350.0 60
640 | 2972 315.0 315
641 | 20043 2.5 2
642 | 22444 85.0 15
643 | 2498 5.999999999994 6
644 | 5404 66.0 66
645 | 1125 31999.999999996002 4000
646 | 14461 367.0 367
647 | 16284 1.2954545454545454 104
648 | 12335 -4.341875 15.7
649 | 18883 1250.0 1100
650 | 18318 78.0 4
651 | 12119 2000.000000001 2000
652 | 7490 3.0 3
653 | 16758 49.0 200000
654 | 20234 7.363636363636363 10
655 | 35 1.5 4
656 | 20657 18720.0 3120
657 | 440 6.0 6
658 | 4219 17.43576388888889 35.7
659 | 20848 22.1 9.25
660 | 22775 36.57142857142857 256
661 | 5250 5040.0 5160
662 | 15092 24.0 24
663 | 11334 10430.0 10430
664 | 2717 33.0 720
665 | 22148 4.9999999999995834 5
666 | 969 3600.0 3600
667 | 15090 19.45 19.45
668 | 5483 -4.3500000000000005 40
669 | 3094 33.23076923077026 33
670 | 7598 726.0 180
671 | 20519 448.00000000000006 64
672 | 14076 75.0 75
673 | 18959 10.0 250
674 | 8614 600.0 600
675 | 5126 280.0 40
676 | 10725 0.0044444444444444444 37.5
677 | 19368 11.666666666667 16
678 | 2519 -48.0 1296
679 | 20186 0.1875 192
680 | 17901 1587.5 1587.5
681 | 22051 36.0 36
682 | 15481 -49.0 63
683 | 2598 6.0714285714309995 9.33333333333
684 | 16489 3.95959595959596 2
685 | 4142 12.0 14
686 | 7258 33.5999999999328 224
687 | 7212 2.520080321285141 275
688 | 15434 300.0 300
689 | 14117 28.0 37
690 | 13947 23800.0 23800
691 | 15323 26.0 26
692 | 2587 599.0 15
693 | 18617 1000.0 1000
694 | 1235 6.0 0.5
695 | 3193 0.30845681789999996 1.111
696 | 18156 4.986666666666666 6
697 | 11665 2400.0 2400
698 | 4963 0.0666666666668 560
699 | 22633 52.5 8
700 | 18553 23.666666666667 9
701 | 19730 40.0 40
702 | 10398 480.0 480
703 | 2801 0.041666666666666664 27
704 | 9245 61.0 61
705 | 14014 -3.8285714285714287 4
706 | 10772 0.4444444444444444 13
707 | 11132 7.0 12
708 | 2075 14592.0 16
709 | 19122 15675.0 15675
710 | 4922 -0.999375 1920
711 | 19595 0.020408163265309524 49
712 | 14531 0.25 0.375
713 | 7781 1.5902777777777777 5.5
714 | 3161 1.666666666665 562.5
715 | 5687 3600.000000000001 40
716 | 10590 23.5 23.5
717 | 6057 19.700000000000003 12.1
718 | 7553 2.9999999999994 3
719 | 588 60.0 45
720 | 21505 420.0 90
721 | 12215 0.05656565656565657 523
722 | 11985 52.0 420
723 | 8241 130.0 130
724 | 13746 1.26 1.26
725 | 12878 0.25 57
726 | 243 1000.0 240
727 | 7221 5.6 6
728 | 20984 0.0026666666666666666 20
729 | 16019 497.77777777777777 490
730 | 19176 512.0 800
731 | 754 34.29 38.75
732 | 17290 -0.8275862068965517 2
733 | 19158 0.7000000000000001 0.7
734 | 13814 0.20833333333333331 4.8
735 | 22682 -153.0 32
736 | 306 50.000000000025004 50
737 | 15599 0.625 0.15
738 | 18902 -331.1639871382636 718
739 | 20360 60.25 15
740 | 19866 62.800000000000004 31.4
741 | 4415 0.5833333333333334 10
742 | 15083 0.8571428571428571 845
743 | 2772 0.08 0.08
744 | 20388 2400.0 2300
745 | 12577 4200.0 4200
746 | 18707 -66.66666666666666 0.047619047619
747 | 1514 48.0 48
748 | 6187 55.800000000000004 55.8
749 | 16895 36.0 144
750 | 14372 1.0 3
751 | 11528 112.0 112
752 | 16601 -77.0 42
753 | 16638 11.999999999994001 12
754 | 8047 0.21 0.21
755 | 20300 67.5 210
756 | 362 0.16666666666666666 0.166666666667
757 | 13792 0.33999999999999997 0.34
758 | 8321 6.0 36
759 | 22396 0.24999999999999994 0.444444444444
760 | 12072 135459.0 12
761 | 11294 0.9999999999835012 1
762 | 22395 913.6363636363635 60
763 | 21325 1.35 0.825
764 | 6137 72.0 72
765 | 7028 1.2195121951219512 8217
766 | 11258 0.7972222222222223 0.73
767 | 6581 30.0 330
768 | 21662 7410.0 204
769 | 10977 5.0 5
770 | 13450 0.1250521050437682 4800
771 | 6254 2405.0 1080
772 | 15248 90.428571428571 255
773 | 14217 2.5 2.5
774 | 17256 0.993816425120773 180
775 | 15553 5.0 5
776 | 11029 11.052631578947368 10
777 | 4436 3.3600000000000003 4.2
778 | 6900 18.0 18
779 | 3180 0.00199700449326011 0.25
780 | 12115 5.0 5
781 | 8839 11666.666666666666 17
782 | 17180 810.0 70
783 | 12069 0.8 0.657142857143
784 | 4872 119.0 177
785 | 14058 -1.8606847697756788 2424
786 | 7211 6.48 1.3
787 | 706 -9.485714285714286 620
788 | 8561 -11.666666666642 8
789 | 16946 -224.25 400
790 | 3111 -0.05128205128205128 90
791 | 15917 28.8 28.8
792 | 14731 100.0 100
793 | 3421 150.0 30
794 | 3645 8.333333333333002 8
795 | 244 6000.0 3840
796 | 5868 31.999999999920004 32
797 | 18252 3.0 3
798 | 14126 3.39 3.39
799 | 22739 12.495 12.495
800 | 1131 1.8285714285714285 0.8
801 | 6130 8.0 8
802 | 12775 0.08333333333333333 0.0833333333333
803 | 10684 499.0 3
804 | 20151 17.0 22
805 | 2929 0.4285714285715 0.428571428571
806 | 20140 3.0 3
807 | 20611 4.9 4
808 | 3950 27.428571428568 21
809 | 4250 24.0 44
810 | 22183 145.0 145
811 | 1739 15300.0 15300
812 | 6832 12.0 12
813 | 13861 112.0 96
814 | 6992 80.5 40
815 | 14143 435344.0 1157
816 | 3268 175.0 175
817 | 20583 -50.90909090909091 4370
818 | 2255 1500.0 600
819 | 3583 0.16666666666666666 0.0833333333333
820 | 19867 480.0 720
821 | 21555 -0.009259259259259259 112
822 | 6022 26.052631578947366 15
823 | 10299 -4.5 2
824 | 19587 300.0 800
825 | 17421 -4.1000000000000005 0.833333333333
826 | 8030 0.16666666666666666 54
827 | 4012 39.111111111111114 80
828 | 19921 23.0 23
829 | 1462 13000.0 20
830 | 11718 5.0 5
831 | 16702 76.0 76
832 | 17880 96.0 96
833 | 11982 1.32 1.32
834 | 3271 10850.0 10850
835 | 5985 2.0 251.2
836 | 1790 -47.800000000000004 900
837 | 2305 4.32 44.5
838 | 1392 6388.8 6388.8
839 | 12149 0.037037037037037035 324
840 | 15961 71.35714285714286 1850
841 | 13217 0.25 4
842 | 16746 -480.0 960
843 | 5449 438636005.19200003 3
844 | 20321 -143.8666666666668 240
845 | 10640 0.15789473684210525 4000
846 | 16227 220.0 60
847 | 8558 -0.222222222222 0.333333333333
848 | 18301 1.25 0.2
849 | 14235 10.0 10
850 | 2455 24.03 24.03
851 | 20259 19.800000000000008 19.8
852 | 19655 75.0 70
853 | 13275 323.9999999997408 100
854 |
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/output/run_dolphin.txt:
--------------------------------------------------------------------------------
1 | 10
2 | 297
3 | 250 25 30
4 | loading finished
5 | RECUR EPOCH: 0, loss: 1.1682798597547743, train_acc_e: 0.6151079136690647, train_acc_i: 0.5679012345679012 time: 0.08729670445124309
6 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586
7 | originial 0
8 | saving... 0.76
9 | saveing ok!
10 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335
11 |
12 | RECUR EPOCH: 1, loss: 0.9512225806467818, train_acc_e: 0.6276978417266187, train_acc_i: 0.6008230452674898 time: 0.08907195727030436
13 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586
14 | originial 0.76
15 | saving... 0.76
16 | saveing ok!
17 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335
18 |
19 | RECUR EPOCH: 2, loss: 0.88405234725387, train_acc_e: 0.6276978417266187, train_acc_i: 0.6008230452674898 time: 0.09000606139500936
20 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586
21 | originial 0.76
22 | saving... 0.76
23 | saveing ok!
24 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335
25 |
26 | RECUR EPOCH: 3, loss: 0.8569288528505176, train_acc_e: 0.6276978417266187, train_acc_i: 0.6008230452674898 time: 0.09406771659851074
27 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586
28 | originial 0.76
29 | saving... 0.76
30 | saveing ok!
31 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335
32 |
33 | RECUR EPOCH: 4, loss: 0.8268176341743626, train_acc_e: 0.6420863309352518, train_acc_i: 0.6131687242798354 time: 0.09377397696177164
34 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.5733333333333334, test_acc_i: 0.4827586206896552
35 | originial 0.76
36 | saving... 0.8
37 | saveing ok!
38 | final test temp_acc:0.5, ans_acc:0.5
39 |
40 | RECUR EPOCH: 5, loss: 0.8387280844857172, train_acc_e: 0.6636690647482014, train_acc_i: 0.6460905349794238 time: 0.09163166284561157
41 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.5733333333333334, test_acc_i: 0.4827586206896552
42 | originial 0.8
43 | saving... 0.8
44 | saveing ok!
45 | final test temp_acc:0.5, ans_acc:0.5
46 |
47 | RECUR EPOCH: 6, loss: 0.7824347774678297, train_acc_e: 0.6690647482014388, train_acc_i: 0.6460905349794238 time: 0.09390345414479574
48 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.5466666666666666, test_acc_i: 0.4482758620689655
49 | jumping...
50 |
51 | RECUR EPOCH: 7, loss: 0.7914263799847889, train_acc_e: 0.658273381294964, train_acc_i: 0.6255144032921811 time: 0.09514609972635905
52 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.5866666666666667, test_acc_i: 0.4827586206896552
53 | jumping...
54 |
55 | RECUR EPOCH: 8, loss: 0.7353058999457968, train_acc_e: 0.6960431654676259, train_acc_i: 0.6337448559670782 time: 0.09473386208216349
56 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.5866666666666667, test_acc_i: 0.4827586206896552
57 | originial 0.8
58 | saving... 0.8
59 | saveing ok!
60 | final test temp_acc:0.5, ans_acc:0.5
61 |
62 | RECUR EPOCH: 9, loss: 0.7557127485549989, train_acc_e: 0.6672661870503597, train_acc_i: 0.6255144032921811 time: 0.09168526728947958
63 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.56, test_acc_i: 0.3793103448275862
64 | jumping...
65 |
66 | RECUR EPOCH: 10, loss: 0.6924405353059494, train_acc_e: 0.6816546762589928, train_acc_i: 0.5967078189300411 time: 0.09236328999201457
67 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.8, test_acc_e: 0.5866666666666667, test_acc_i: 0.4827586206896552
68 | originial 0.8
69 | saving... 0.8
70 | saveing ok!
71 | final test temp_acc:0.5, ans_acc:0.4666666666666667
72 |
73 | RECUR EPOCH: 11, loss: 0.6715376661638174, train_acc_e: 0.6942446043165468, train_acc_i: 0.654320987654321 time: 0.09408228794733683
74 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.48, test_acc_i: 0.2413793103448276
75 | jumping...
76 |
77 | RECUR EPOCH: 12, loss: 0.6259040008356542, train_acc_e: 0.710431654676259, train_acc_i: 0.6378600823045267 time: 0.09618140856424967
78 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.64, test_acc_e: 0.4666666666666667, test_acc_i: 0.3103448275862069
79 | jumping...
80 |
81 | RECUR EPOCH: 13, loss: 0.6323614552187822, train_acc_e: 0.7140287769784173, train_acc_i: 0.654320987654321 time: 0.09348649183909098
82 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.5333333333333333, test_acc_i: 0.27586206896551724
83 | jumping...
84 |
85 | RECUR EPOCH: 14, loss: 0.6169184696527175, train_acc_e: 0.7302158273381295, train_acc_i: 0.6378600823045267 time: 0.09691879351933798
86 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6266666666666667, test_acc_i: 0.4482758620689655
87 | jumping...
88 |
89 | RECUR EPOCH: 15, loss: 0.5792982862810049, train_acc_e: 0.7050359712230215, train_acc_i: 0.6460905349794238 time: 0.09498010873794556
90 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.6, test_acc_e: 0.5333333333333333, test_acc_i: 0.27586206896551724
91 | jumping...
92 |
93 | RECUR EPOCH: 16, loss: 0.6057494206683626, train_acc_e: 0.7140287769784173, train_acc_i: 0.6584362139917695 time: 0.093644118309021
94 | valid_acc_e: 0.5535714285714286, valid_acc_i: 0.4, test_acc_e: 0.5066666666666667, test_acc_i: 0.27586206896551724
95 | jumping...
96 |
97 | RECUR EPOCH: 17, loss: 0.5289612836798523, train_acc_e: 0.7589928057553957, train_acc_i: 0.6625514403292181 time: 0.09545961221059164
98 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.64, test_acc_e: 0.5733333333333334, test_acc_i: 0.4482758620689655
99 | jumping...
100 |
101 | RECUR EPOCH: 18, loss: 0.5154155585991502, train_acc_e: 0.7589928057553957, train_acc_i: 0.6790123456790124 time: 0.09287167390187581
102 | valid_acc_e: 0.625, valid_acc_i: 0.52, test_acc_e: 0.5333333333333333, test_acc_i: 0.27586206896551724
103 | jumping...
104 |
105 | RECUR EPOCH: 19, loss: 0.5485985622484497, train_acc_e: 0.7446043165467626, train_acc_i: 0.6502057613168725 time: 0.09662172396977743
106 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.52, test_acc_e: 0.5866666666666667, test_acc_i: 0.3448275862068966
107 | jumping...
108 |
109 | RECUR EPOCH: 20, loss: 0.46297485249523274, train_acc_e: 0.7931654676258992, train_acc_i: 0.7242798353909465 time: 0.0916707158088684
110 | valid_acc_e: 0.6428571428571429, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655
111 | jumping...
112 |
113 | RECUR EPOCH: 21, loss: 0.4903093247747225, train_acc_e: 0.7805755395683454, train_acc_i: 0.7078189300411523 time: 0.09656307299931845
114 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.6, test_acc_e: 0.6266666666666667, test_acc_i: 0.4482758620689655
115 | jumping...
116 |
117 | RECUR EPOCH: 22, loss: 0.39395640612629707, train_acc_e: 0.8201438848920863, train_acc_i: 0.7366255144032922 time: 0.09862770636876424
118 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.68, test_acc_e: 0.6, test_acc_i: 0.4827586206896552
119 | jumping...
120 |
121 | RECUR EPOCH: 23, loss: 0.36069519137158806, train_acc_e: 0.841726618705036, train_acc_i: 0.7695473251028807 time: 0.09261656999588012
122 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6266666666666667, test_acc_i: 0.41379310344827586
123 | jumping...
124 |
125 | RECUR EPOCH: 24, loss: 0.3598016642739253, train_acc_e: 0.829136690647482, train_acc_i: 0.7489711934156379 time: 0.09933442274729411
126 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.64, test_acc_i: 0.5517241379310345
127 | jumping...
128 |
129 | RECUR EPOCH: 25, loss: 0.4195781421268918, train_acc_e: 0.8075539568345323, train_acc_i: 0.7448559670781894 time: 0.09181210199991861
130 | valid_acc_e: 0.48214285714285715, valid_acc_i: 0.36, test_acc_e: 0.5733333333333334, test_acc_i: 0.3448275862068966
131 | jumping...
132 |
133 | RECUR EPOCH: 26, loss: 0.42452051021434645, train_acc_e: 0.8093525179856115, train_acc_i: 0.7489711934156379 time: 0.09510379234949748
134 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655
135 | jumping...
136 |
137 | RECUR EPOCH: 27, loss: 0.332387818230523, train_acc_e: 0.8489208633093526, train_acc_i: 0.7736625514403292 time: 0.09798026084899902
138 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.64, test_acc_i: 0.5172413793103449
139 | originial 0.8
140 | saving... 0.8
141 | saveing ok!
142 | final test temp_acc:0.5, ans_acc:0.5
143 |
144 | RECUR EPOCH: 28, loss: 0.32374130060643325, train_acc_e: 0.8489208633093526, train_acc_i: 0.7654320987654321 time: 0.09599583148956299
145 | valid_acc_e: 0.625, valid_acc_i: 0.48, test_acc_e: 0.6133333333333333, test_acc_i: 0.3103448275862069
146 | jumping...
147 |
148 | RECUR EPOCH: 29, loss: 0.3817319320553124, train_acc_e: 0.8219424460431655, train_acc_i: 0.7695473251028807 time: 0.09603599707285564
149 | valid_acc_e: 0.5357142857142857, valid_acc_i: 0.32, test_acc_e: 0.6, test_acc_i: 0.3793103448275862
150 | jumping...
151 |
152 | RECUR EPOCH: 30, loss: 0.3141926777215652, train_acc_e: 0.8525179856115108, train_acc_i: 0.7901234567901234 time: 0.0925633192062378
153 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6, test_acc_i: 0.4827586206896552
154 | jumping...
155 |
156 | RECUR EPOCH: 31, loss: 0.22721145869282539, train_acc_e: 0.8812949640287769, train_acc_i: 0.8148148148148148 time: 0.09542576471964519
157 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.64, test_acc_e: 0.72, test_acc_i: 0.4827586206896552
158 | jumping...
159 |
160 | RECUR EPOCH: 32, loss: 0.21175358324875065, train_acc_e: 0.8902877697841727, train_acc_i: 0.8106995884773662 time: 0.09543842077255249
161 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.6, test_acc_e: 0.7066666666666667, test_acc_i: 0.5172413793103449
162 | jumping...
163 |
164 | RECUR EPOCH: 33, loss: 0.16183604134453666, train_acc_e: 0.9190647482014388, train_acc_i: 0.8765432098765432 time: 0.09457332690556844
165 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.68, test_acc_i: 0.5517241379310345
166 | jumping...
167 |
168 | RECUR EPOCH: 34, loss: 0.17849770151538613, train_acc_e: 0.9118705035971223, train_acc_i: 0.8600823045267489 time: 0.0989931066830953
169 | valid_acc_e: 0.6428571428571429, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655
170 | jumping...
171 |
172 | RECUR EPOCH: 35, loss: 0.17104502617086403, train_acc_e: 0.9100719424460432, train_acc_i: 0.8683127572016461 time: 0.09347049792607626
173 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.56, test_acc_e: 0.64, test_acc_i: 0.4827586206896552
174 | jumping...
175 |
176 | RECUR EPOCH: 36, loss: 0.16812368626457183, train_acc_e: 0.9100719424460432, train_acc_i: 0.8600823045267489 time: 0.09435887336730957
177 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.41379310344827586
178 | jumping...
179 |
180 | RECUR EPOCH: 37, loss: 0.15582417485154706, train_acc_e: 0.920863309352518, train_acc_i: 0.8641975308641975 time: 0.09741787115732829
181 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.48, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655
182 | jumping...
183 |
184 | RECUR EPOCH: 38, loss: 0.17178444214809088, train_acc_e: 0.8992805755395683, train_acc_i: 0.8353909465020576 time: 0.09656809568405152
185 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.5172413793103449
186 | jumping...
187 |
188 | RECUR EPOCH: 39, loss: 0.23990754335505482, train_acc_e: 0.9100719424460432, train_acc_i: 0.8395061728395061 time: 0.09630475044250489
189 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655
190 | jumping...
191 |
192 | RECUR EPOCH: 40, loss: 0.13023324586727, train_acc_e: 0.9370503597122302, train_acc_i: 0.8930041152263375 time: 0.09295337200164795
193 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655
194 | jumping...
195 |
196 | RECUR EPOCH: 41, loss: 0.14373765438181874, train_acc_e: 0.9334532374100719, train_acc_i: 0.8930041152263375 time: 0.09432597160339355
197 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.6, test_acc_e: 0.5733333333333334, test_acc_i: 0.3793103448275862
198 | jumping...
199 |
200 | RECUR EPOCH: 42, loss: 0.10655531520215572, train_acc_e: 0.9496402877697842, train_acc_i: 0.9218106995884774 time: 0.09521870215733846
201 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.68, test_acc_i: 0.4482758620689655
202 | jumping...
203 |
204 | RECUR EPOCH: 43, loss: 0.10571475958627928, train_acc_e: 0.9460431654676259, train_acc_i: 0.9012345679012346 time: 0.09794450203577677
205 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.7066666666666667, test_acc_i: 0.5517241379310345
206 | jumping...
207 |
208 | RECUR EPOCH: 44, loss: 0.10741040772861904, train_acc_e: 0.9406474820143885, train_acc_i: 0.9053497942386831 time: 0.09870184659957885
209 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.76, test_acc_e: 0.68, test_acc_i: 0.5862068965517241
210 | jumping...
211 |
212 | RECUR EPOCH: 45, loss: 0.13996641748726613, train_acc_e: 0.9334532374100719, train_acc_i: 0.8765432098765432 time: 0.09284635384877522
213 | valid_acc_e: 0.625, valid_acc_i: 0.56, test_acc_e: 0.7066666666666667, test_acc_i: 0.4827586206896552
214 | jumping...
215 |
216 | RECUR EPOCH: 46, loss: 0.12933335284637326, train_acc_e: 0.9370503597122302, train_acc_i: 0.897119341563786 time: 0.09647434155146281
217 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.3793103448275862
218 | jumping...
219 |
220 | RECUR EPOCH: 47, loss: 0.06780000196562873, train_acc_e: 0.9676258992805755, train_acc_i: 0.9382716049382716 time: 0.0968562642733256
221 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.76, test_acc_i: 0.5862068965517241
222 | jumping...
223 |
224 | RECUR EPOCH: 48, loss: 0.0724956984382598, train_acc_e: 0.9676258992805755, train_acc_i: 0.934156378600823 time: 0.09332544406255086
225 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6666666666666666, test_acc_i: 0.5517241379310345
226 | jumping...
227 |
228 | RECUR EPOCH: 49, loss: 0.08274855584274103, train_acc_e: 0.947841726618705, train_acc_i: 0.9053497942386831 time: 0.09943470557530722
229 | valid_acc_e: 0.5714285714285714, valid_acc_i: 0.44, test_acc_e: 0.64, test_acc_i: 0.41379310344827586
230 | jumping...
231 |
232 | RECUR EPOCH: 50, loss: 0.06540786852071315, train_acc_e: 0.9568345323741008, train_acc_i: 0.9176954732510288 time: 0.09516379833221436
233 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.52, test_acc_e: 0.6933333333333334, test_acc_i: 0.41379310344827586
234 | jumping...
235 |
236 | RECUR EPOCH: 51, loss: 0.05518364053940086, train_acc_e: 0.9658273381294964, train_acc_i: 0.9465020576131687 time: 0.09435760180155436
237 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.5517241379310345
238 | jumping...
239 |
240 | RECUR EPOCH: 52, loss: 0.09833989197334635, train_acc_e: 0.9586330935251799, train_acc_i: 0.9094650205761317 time: 0.0969707727432251
241 | valid_acc_e: 0.5892857142857143, valid_acc_i: 0.52, test_acc_e: 0.6666666666666666, test_acc_i: 0.4827586206896552
242 | jumping...
243 |
244 | RECUR EPOCH: 53, loss: 0.09804084958362971, train_acc_e: 0.9532374100719424, train_acc_i: 0.9300411522633745 time: 0.09771036307017009
245 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552
246 | jumping...
247 |
248 | RECUR EPOCH: 54, loss: 0.09636338147116297, train_acc_e: 0.9622302158273381, train_acc_i: 0.9300411522633745 time: 0.09834134578704834
249 | valid_acc_e: 0.5535714285714286, valid_acc_i: 0.52, test_acc_e: 0.68, test_acc_i: 0.4482758620689655
250 | jumping...
251 |
252 | RECUR EPOCH: 55, loss: 0.07666621182435825, train_acc_e: 0.9532374100719424, train_acc_i: 0.9135802469135802 time: 0.0945337176322937
253 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.68, test_acc_e: 0.7466666666666667, test_acc_i: 0.5517241379310345
254 | jumping...
255 |
256 | RECUR EPOCH: 56, loss: 0.04281204975681541, train_acc_e: 0.9820143884892086, train_acc_i: 0.9588477366255144 time: 0.09585176308949789
257 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.6666666666666666, test_acc_i: 0.5862068965517241
258 | jumping...
259 |
260 | RECUR EPOCH: 57, loss: 0.07147320019610134, train_acc_e: 0.9586330935251799, train_acc_i: 0.9259259259259259 time: 0.09756191174189249
261 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.6, test_acc_e: 0.6666666666666666, test_acc_i: 0.41379310344827586
262 | jumping...
263 |
264 | RECUR EPOCH: 58, loss: 0.049979879440348825, train_acc_e: 0.9712230215827338, train_acc_i: 0.9465020576131687 time: 0.09521978696187337
265 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6666666666666666, test_acc_i: 0.5172413793103449
266 | jumping...
267 |
268 | RECUR EPOCH: 59, loss: 0.03579042790603245, train_acc_e: 0.9820143884892086, train_acc_i: 0.9629629629629629 time: 0.09564072688420613
269 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.41379310344827586
270 | jumping...
271 |
272 | RECUR EPOCH: 60, loss: 0.07674725838158847, train_acc_e: 0.960431654676259, train_acc_i: 0.934156378600823 time: 0.09347194035847982
273 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.6266666666666667, test_acc_i: 0.4482758620689655
274 | jumping...
275 |
276 | RECUR EPOCH: 61, loss: 0.03581807494868712, train_acc_e: 0.9748201438848921, train_acc_i: 0.9506172839506173 time: 0.09602559407552083
277 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.7066666666666667, test_acc_i: 0.5517241379310345
278 | jumping...
279 |
280 | RECUR EPOCH: 62, loss: 0.02495744525466436, train_acc_e: 0.9802158273381295, train_acc_i: 0.9629629629629629 time: 0.09728382031122844
281 | valid_acc_e: 0.625, valid_acc_i: 0.6, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655
282 | jumping...
283 |
284 | RECUR EPOCH: 63, loss: 0.039477180061026365, train_acc_e: 0.9784172661870504, train_acc_i: 0.9629629629629629 time: 0.09547470808029175
285 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.68, test_acc_i: 0.5172413793103449
286 | jumping...
287 |
288 | RECUR EPOCH: 64, loss: 0.041618613725520455, train_acc_e: 0.9820143884892086, train_acc_i: 0.9629629629629629 time: 0.09983280102411905
289 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.64, test_acc_e: 0.6133333333333333, test_acc_i: 0.5172413793103449
290 | jumping...
291 |
292 | RECUR EPOCH: 65, loss: 0.027657700135322757, train_acc_e: 0.9856115107913669, train_acc_i: 0.9753086419753086 time: 0.0956899086634318
293 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6666666666666666, test_acc_i: 0.5172413793103449
294 | jumping...
295 |
296 | RECUR EPOCH: 66, loss: 0.037495824657840494, train_acc_e: 0.9838129496402878, train_acc_i: 0.9629629629629629 time: 0.09479515552520752
297 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.7333333333333333, test_acc_i: 0.4827586206896552
298 | jumping...
299 |
300 | RECUR EPOCH: 67, loss: 0.0549890123644974, train_acc_e: 0.9730215827338129, train_acc_i: 0.9506172839506173 time: 0.0966543157895406
301 | valid_acc_e: 0.5892857142857143, valid_acc_i: 0.48, test_acc_e: 0.6666666666666666, test_acc_i: 0.3793103448275862
302 | jumping...
303 |
304 | RECUR EPOCH: 68, loss: 0.023487998686209627, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09579426447550456
305 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655
306 | jumping...
307 |
308 | RECUR EPOCH: 69, loss: 0.029502890176243253, train_acc_e: 0.9838129496402878, train_acc_i: 0.9711934156378601 time: 0.09850571552912395
309 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552
310 | jumping...
311 |
312 | RECUR EPOCH: 70, loss: 0.050007173891175434, train_acc_e: 0.9784172661870504, train_acc_i: 0.9588477366255144 time: 0.09385395050048828
313 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655
314 | jumping...
315 |
316 | RECUR EPOCH: 71, loss: 0.0374770586382705, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09573783079783121
317 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.64, test_acc_i: 0.4827586206896552
318 | originial 0.8
319 | saving... 0.8
320 | saveing ok!
321 | final test temp_acc:0.5333333333333333, ans_acc:0.5
322 |
323 | RECUR EPOCH: 72, loss: 0.0281009430472007, train_acc_e: 0.9856115107913669, train_acc_i: 0.9670781893004116 time: 0.09709215958913167
324 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655
325 | jumping...
326 |
327 | RECUR EPOCH: 73, loss: 0.01708699586521451, train_acc_e: 0.987410071942446, train_acc_i: 0.9753086419753086 time: 0.0944095253944397
328 | valid_acc_e: 0.8214285714285714, valid_acc_i: 0.8, test_acc_e: 0.64, test_acc_i: 0.4827586206896552
329 | originial 0.8
330 | saving... 0.8
331 | saveing ok!
332 | final test temp_acc:0.5, ans_acc:0.5
333 |
334 | RECUR EPOCH: 74, loss: 0.05629130771743908, train_acc_e: 0.9838129496402878, train_acc_i: 0.9670781893004116 time: 0.09714263677597046
335 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.72, test_acc_e: 0.6, test_acc_i: 0.41379310344827586
336 | jumping...
337 |
338 | RECUR EPOCH: 75, loss: 0.04158908542659548, train_acc_e: 0.9748201438848921, train_acc_i: 0.9506172839506173 time: 0.09388278722763062
339 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6666666666666666, test_acc_i: 0.4827586206896552
340 | jumping...
341 |
342 | RECUR EPOCH: 76, loss: 0.04258053513711372, train_acc_e: 0.9838129496402878, train_acc_i: 0.9670781893004116 time: 0.09425394932428996
343 | valid_acc_e: 0.8035714285714286, valid_acc_i: 0.84, test_acc_e: 0.6, test_acc_i: 0.4482758620689655
344 | originial 0.8
345 | saving... 0.84
346 | saveing ok!
347 | final test temp_acc:0.4666666666666667, ans_acc:0.43333333333333335
348 |
349 | RECUR EPOCH: 77, loss: 0.06083750326201749, train_acc_e: 0.9676258992805755, train_acc_i: 0.9465020576131687 time: 0.09651113351186116
350 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655
351 | jumping...
352 |
353 | RECUR EPOCH: 78, loss: 0.04671263096877086, train_acc_e: 0.9802158273381295, train_acc_i: 0.9588477366255144 time: 0.09435170888900757
354 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.72, test_acc_e: 0.6933333333333334, test_acc_i: 0.4827586206896552
355 | jumping...
356 |
357 | RECUR EPOCH: 79, loss: 0.05908534688867414, train_acc_e: 0.9820143884892086, train_acc_i: 0.9588477366255144 time: 0.09631927410761515
358 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.6, test_acc_e: 0.7333333333333333, test_acc_i: 0.5172413793103449
359 | jumping...
360 |
361 | RECUR EPOCH: 80, loss: 0.028193465131239145, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09260084629058837
362 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655
363 | jumping...
364 |
365 | RECUR EPOCH: 81, loss: 0.009333069808014627, train_acc_e: 0.9928057553956835, train_acc_i: 0.9876543209876543 time: 0.0965872049331665
366 | valid_acc_e: 0.6428571428571429, valid_acc_i: 0.52, test_acc_e: 0.6266666666666667, test_acc_i: 0.5172413793103449
367 | jumping...
368 |
369 | RECUR EPOCH: 82, loss: 0.011288232396168964, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09728428125381469
370 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655
371 | jumping...
372 |
373 | RECUR EPOCH: 83, loss: 0.023994814577692574, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09791800578435263
374 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.68, test_acc_i: 0.4482758620689655
375 | jumping...
376 |
377 | RECUR EPOCH: 84, loss: 0.011793449894743568, train_acc_e: 0.9964028776978417, train_acc_i: 0.9917695473251029 time: 0.098715607325236
378 | valid_acc_e: 0.75, valid_acc_i: 0.68, test_acc_e: 0.68, test_acc_i: 0.41379310344827586
379 | jumping...
380 |
381 | RECUR EPOCH: 85, loss: 0.01421276430902167, train_acc_e: 0.9892086330935251, train_acc_i: 0.9753086419753086 time: 0.09430506229400634
382 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6266666666666667, test_acc_i: 0.5517241379310345
383 | jumping...
384 |
385 | RECUR EPOCH: 86, loss: 0.007714040408777111, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09514733552932739
386 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6133333333333333, test_acc_i: 0.41379310344827586
387 | jumping...
388 |
389 | RECUR EPOCH: 87, loss: 0.02914390094011041, train_acc_e: 0.9928057553956835, train_acc_i: 0.9917695473251029 time: 0.09735808769861858
390 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.68, test_acc_i: 0.5172413793103449
391 | jumping...
392 |
393 | RECUR EPOCH: 88, loss: 0.029182899762481573, train_acc_e: 0.9928057553956835, train_acc_i: 0.9835390946502057 time: 0.09592658281326294
394 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552
395 | jumping...
396 |
397 | RECUR EPOCH: 89, loss: 0.009332968670407256, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09735538562138875
398 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.64, test_acc_i: 0.4827586206896552
399 | jumping...
400 |
401 | RECUR EPOCH: 90, loss: 0.024998605864528765, train_acc_e: 0.9820143884892086, train_acc_i: 0.9670781893004116 time: 0.09479767481486003
402 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.64, test_acc_e: 0.5866666666666667, test_acc_i: 0.4482758620689655
403 | jumping...
404 |
405 | RECUR EPOCH: 91, loss: 0.022708293832378622, train_acc_e: 0.987410071942446, train_acc_i: 0.9711934156378601 time: 0.09715214967727662
406 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.72, test_acc_e: 0.56, test_acc_i: 0.4482758620689655
407 | jumping...
408 |
409 | RECUR EPOCH: 92, loss: 0.0042395228826644, train_acc_e: 0.9982014388489209, train_acc_i: 0.9958847736625515 time: 0.09556546608606974
410 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.72, test_acc_e: 0.68, test_acc_i: 0.5172413793103449
411 | jumping...
412 |
413 | RECUR EPOCH: 93, loss: 0.037599779972078376, train_acc_e: 0.9946043165467626, train_acc_i: 0.9917695473251029 time: 0.09529571135838827
414 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655
415 | jumping...
416 |
417 | RECUR EPOCH: 94, loss: 0.019692288941623254, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09814429680506388
418 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.64, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552
419 | jumping...
420 |
421 | RECUR EPOCH: 95, loss: 0.012651146805381456, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09417590697606405
422 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.68, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655
423 | jumping...
424 |
425 | RECUR EPOCH: 96, loss: 0.003612281639634827, train_acc_e: 1.0, train_acc_i: 1.0 time: 0.0961928923924764
426 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.64, test_acc_i: 0.4827586206896552
427 | jumping...
428 |
429 | RECUR EPOCH: 97, loss: 0.04159402045126192, train_acc_e: 0.9910071942446043, train_acc_i: 0.9876543209876543 time: 0.0963845173517863
430 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.68, test_acc_i: 0.4482758620689655
431 | jumping...
432 |
433 | RECUR EPOCH: 98, loss: 0.028848348652997624, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.0943213144938151
434 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.76, test_acc_e: 0.6933333333333334, test_acc_i: 0.6206896551724138
435 | jumping...
436 |
437 | RECUR EPOCH: 99, loss: 0.017882701150659538, train_acc_e: 0.9928057553956835, train_acc_i: 0.9835390946502057 time: 0.09760322570800781
438 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655
439 | jumping...
440 |
441 | RECUR EPOCH: 100, loss: 0.004860475782013724, train_acc_e: 0.9982014388489209, train_acc_i: 0.9958847736625515 time: 0.09372621774673462
442 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.64, test_acc_e: 0.7466666666666667, test_acc_i: 0.4827586206896552
443 | jumping...
444 |
445 |
--------------------------------------------------------------------------------
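The log above alternates per-epoch metric lines with checkpoint messages: the run saves only when valid_acc_i improves and prints "jumping..." otherwise, so each "final test" line trails a new best validation score. A small sketch that recovers the best validation epoch, using only field names visible in the log:

import re

# Walk a run_dolphin.txt-style log and report the epoch with the
# highest valid_acc_i. Field names are exactly those printed above.
epoch_re = re.compile(r"RECUR EPOCH: (\d+)")
valid_re = re.compile(r"valid_acc_i: ([0-9.]+)")

best = (-1.0, None)  # (accuracy, epoch)
epoch = None
with open("run_dolphin.txt") as f:
    for line in f:
        m = epoch_re.search(line)
        if m:
            epoch = int(m.group(1))
        m = valid_re.search(line)
        if m and float(m.group(1)) > best[0]:
            best = (float(m.group(1)), epoch)

print(best)  # for the log above: (0.84, 76)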
/3_T-RNN_&_baselines/output/run_dolphin_rep.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Phybiolo57/MathWordProblemSolver/a7753a8762370f20fce86ea3001554c4f927b563/3_T-RNN_&_baselines/output/run_dolphin_rep.txt
--------------------------------------------------------------------------------
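This dump keeps only the raw URL for the replicate-run log (presumably it was too large to inline). To fetch it for the same kind of inspection (urllib here is just one option; any HTTP client works):

import urllib.request

# Download the full run_dolphin_rep.txt referenced by the URL above.
url = ("https://raw.githubusercontent.com/Phybiolo57/MathWordProblemSolver/"
       "a7753a8762370f20fce86ea3001554c4f927b563/"
       "3_T-RNN_&_baselines/output/run_dolphin_rep.txt")
urllib.request.urlretrieve(url, "run_dolphin_rep.txt")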
/3_T-RNN_&_baselines/src/data_loader.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import time
4 |
5 | import random
6 | import pickle
7 | import sys
8 |
9 | from utils import *
10 |
11 | class DataLoader:
12 | def __init__(self, dataset):#, dataset, math23k_vocab_list, math23k_decode_list):
13 |
14 | self.dataset = dataset
15 |
16 | math23k_vocab_list = ['PAD','', 'EOS', '1', 'PI']
17 | math23k_decode_list = ['temp_0','temp_1','temp_2','temp_3','temp_4','temp_5','temp_6','temp_7',\
18 | 'temp_8', 'temp_9', 'temp_10', 'temp_11', 'temp_12','temp_13',#'temp_14',
19 | '1', 'PI', 'PAD', '','EOS']
20 |
21 | #math23k_decode_list = ['temp_0','temp_1','temp_2','temp_3','1', 'PI', 'PAD', '','EOS']
22 | for k, v in dataset.items():
23 | for elem in v['template_text'].split(' '):
24 | if elem not in math23k_vocab_list:
25 | math23k_vocab_list.append(elem)
26 |
27 | self.train_list, self.valid_list, self.test_list = split_by_feilong_23k(dataset)
28 | self.data_set = dataset
29 | self.vocab_list = math23k_vocab_list
30 | self.vocab_dict = dict([(elem, idx) for idx, elem in enumerate(self.vocab_list)])
31 | self.vocab_len = len(self.vocab_list)
32 |
33 | self.decode_classes_list = math23k_decode_list
34 | self.decode_classes_dict = dict([(elem, idx) for idx, elem in enumerate(self.decode_classes_list)])
35 | self.decode_classed_len = len(self.decode_classes_list)
36 |
37 | self.decode_emb_idx = [self.vocab_dict[elem] for elem in self.decode_classes_list]
38 |
39 |
40 |
41 |
42 | def __len__(self):
43 | return len(self.data_set)  # number of loaded problems
44 |
45 | def _data_batch_preprocess(self, data_batch):
46 | batch_encode_idx = []
47 | batch_encode_len = []
48 | batch_encode_num_pos = []
49 |
50 | batch_decode_idx = []
51 | batch_decode_emb_idx = []
52 | batch_decode_len = []
53 |
54 | batch_idxs = []
55 | batch_text = []
56 | batch_num_list = []
57 | batch_solution = []
58 |
59 | batch_post_equation = []
60 |
61 | batch_gd_tree = []
62 | #batch_mask = []
63 |
64 | #dict_keys(['numtemp_order', 'index', 'post_template', 'ans', 'num_list', 'template_text', 'mid_template', 'expression', 'text'])
65 | for elem in data_batch:
66 | idx = elem[0]
67 | encode_sen = elem[1]['template_text'][:]
68 | encode_sen_idx = string_2_idx_sen(encode_sen.strip().split(' '), self.vocab_dict)
69 |
70 | batch_encode_idx.append(encode_sen_idx)
71 | batch_encode_len.append(len(encode_sen_idx))
72 | batch_encode_num_pos.append(elem[1]['num_position'][:])
73 |
74 | decode_sen = elem[1]['numtemp_order'][:]
75 | #print (decode_sen)
76 | decode_sen.append('EOS')
77 | decode_sen_idx = string_2_idx_sen(decode_sen, self.decode_classes_dict)
78 | decode_sen_emb_idx = [self.decode_emb_idx[elem] for elem in decode_sen_idx]
79 |
80 | batch_decode_idx.append(decode_sen_idx)
81 | batch_decode_emb_idx.append(decode_sen_emb_idx)
82 | batch_decode_len.append(len(decode_sen_idx))
83 |
84 | batch_idxs.append(idx)
85 | batch_text.append(encode_sen)
86 | batch_num_list.append(elem[1]['num_list'][:])
87 | batch_solution.append(elem[1]['ans'][:])
88 | batch_post_equation.append(elem[1]['post_template'][:])
89 | #print (elem[1]['post_template'][2:])
90 |
91 | batch_gd_tree.append(elem[1]['gd_tree_list'])
92 |
93 |
94 | max_encode_len = max(batch_encode_len)
95 | batch_encode_pad_idx = []
96 | #max_decode_len = max(batch_decode_len)
97 | max_decode_len = 22
98 | batch_decode_pad_idx = []
99 | batch_decode_pad_emb_idx = []
100 |
101 |
102 | for i in range(len(data_batch)):
103 | encode_sen_idx = batch_encode_idx[i]
104 | encode_sen_pad_idx = pad_sen(\
105 | encode_sen_idx, max_encode_len, self.vocab_dict['PAD'])
106 | batch_encode_pad_idx.append(encode_sen_pad_idx)
107 |
108 | decode_sen_idx = batch_decode_idx[i]
109 | decode_sen_pad_idx = pad_sen(\
110 | decode_sen_idx, max_decode_len, self.decode_classes_dict['PAD'])
111 | #decode_sen_idx, max_decode_len, self.decode_classes_dict['PAD_token'])
112 | batch_decode_pad_idx.append(decode_sen_pad_idx)
113 | decode_sen_pad_emb_idx = [self.decode_emb_idx[elem] for elem in decode_sen_pad_idx]
114 | batch_decode_pad_emb_idx.append(decode_sen_pad_emb_idx)
115 |
116 | batch_data_dict = dict()
117 |
118 | batch_data_dict['batch_encode_idx'] = batch_encode_idx
119 | batch_data_dict['batch_encode_len'] = batch_encode_len
120 | batch_data_dict['batch_encode_num_pos'] = batch_encode_num_pos
121 | batch_data_dict['batch_encode_pad_idx'] = batch_encode_pad_idx
122 |
123 | batch_data_dict['batch_decode_idx'] = batch_decode_idx
124 | batch_data_dict['batch_decode_len'] = batch_decode_len
125 | batch_data_dict['batch_decode_emb_idx'] = batch_decode_emb_idx
126 | batch_data_dict['batch_decode_pad_idx'] = batch_decode_pad_idx
127 | batch_data_dict['batch_decode_pad_emb_idx'] = batch_decode_pad_emb_idx
128 |
129 | batch_data_dict['batch_index'] = batch_idxs
130 | batch_data_dict['batch_text'] = batch_text
131 | batch_data_dict['batch_num_list'] = batch_num_list
132 | batch_data_dict['batch_solution'] = batch_solution
133 | batch_data_dict['batch_post_equation'] = batch_post_equation
134 |
135 | batch_data_dict['batch_gd_tree'] = batch_gd_tree
136 |
137 | if len(data_batch) != 1:
138 | new_batch_data_dict = self._sorted_batch(batch_data_dict)
139 | else:
140 | new_batch_data_dict = batch_data_dict
141 | return new_batch_data_dict
142 |
143 | def _sorted_batch(self, batch_data_dict):
144 | new_batch_data_dict = dict()
145 | batch_encode_len = np.array(batch_data_dict['batch_encode_len'])
146 | sort_idx = np.argsort(-batch_encode_len)
147 | for key, value in batch_data_dict.items():
148 | new_batch_data_dict[key] = np.array(value)[sort_idx]
149 | return new_batch_data_dict
150 |
151 | def get_batch(self, data_list, batch_size, verbose=0):
152 | #print data_list
153 | batch_num = (len(data_list) + batch_size - 1) // batch_size  # ceiling division; avoids yielding an empty final batch
154 | for idx in range(batch_num):
155 | batch_start = idx*batch_size
156 | batch_end = min((idx+1)*batch_size, len(data_list))
157 | #print batch_start, batch_end, len(data_list)
158 | batch_data_dict = self._data_batch_preprocess(data_list[batch_start: batch_end])
159 | yield batch_data_dict
160 |
161 |
--------------------------------------------------------------------------------
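A minimal driving sketch for the loader above, mirroring how main.py (the next file) uses it: records need 'gd_tree_list' attached before construction, and get_batch yields dicts keyed as in _data_batch_preprocess. The dataset path and helper names come from main.py and utils.py; the rest is read off the source above:

# Sketch: build the DataLoader and pull one mini-batch.
from data_loader import DataLoader
from utils import read_data_json, form_gdtree  # both used the same way in main.py

dataset = read_data_json("./data/final_dolphin_data_replicate.json")
for elem in dataset.values():
    elem['gd_tree_list'] = form_gdtree(elem)  # main.py attaches these first

loader = DataLoader(dataset)
# train_list holds (index, record) pairs, as _data_batch_preprocess expects
for batch in loader.get_batch(loader.train_list, batch_size=32):
    print(len(batch['batch_text']), 'problems in first batch')
    break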
/3_T-RNN_&_baselines/src/main.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import time
4 | from tqdm import tqdm_notebook as tqdm
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import Parameter
9 | import torch.nn.functional as F
10 | from torch.utils.data import Dataset
11 | from torch.autograd import Variable
12 | import torch.optim as optim
13 | from torch.nn.utils import clip_grad_norm_
14 | import logging
15 | import random
16 | import pickle
17 | import sys
18 |
19 | import os
20 |
21 | from trainer import *
22 | from utils import *
23 | from model import *
24 | from data_loader import *
25 |
26 |
27 | import warnings
28 | warnings.filterwarnings("ignore")
29 |
30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
32 |
33 | #np.random.seed(123)
34 | #random.seed(123)
35 |
36 | def main():
37 |
38 | dataset = read_data_json("./data/final_dolphin_data_replicate.json")
39 | #emb_vectors = np.load('./data/emb_100.npy')
40 | #dict_keys(['text', 'ans', 'mid_template', 'num_position', 'post_template', 'num_list', \
41 | #'template_text', 'expression', 'numtemp_order', 'index', 'gd_tree_list'])
42 | count = 0
43 | max_l = 0
44 | #norm_templates = read_data_json("./data/post_dup_templates_num.json")
45 | for key, elem in dataset.items():
46 | #print (elem['post_template'])
47 | #print (norm_templates[key])
48 | #print ()
49 | #elem['post_template'] = norm_templates[key]
50 | elem['gd_tree_list'] = form_gdtree(elem)
51 | if len(elem['gd_tree_list']):
52 | #print (elem.keys())
53 | #print (elem['text'])
54 | #print (elem['mid_template'])
55 | #print (elem['post_template'])
56 | #print (elem['post_template'][2:])
57 | l = max([int(i.split('_')[1]) for i in set(elem['post_template']) if 'temp' in i])
58 | if max_l < l:
59 | max_l = l
60 | count += 1
61 | print (max_l)
62 | #print (elem['gd_tree_list'])
63 | print (count)
64 |
65 | data_loader = DataLoader(dataset)
66 | print ('loading finished')
67 |
68 | if os.path.isfile("./ablation_recursive-Copy1.log"):
69 | os.remove("./ablation_recursive-Copy1.log")
70 |
71 | logger = logging.getLogger()
72 | fhandler = logging.FileHandler(filename='mylog.log', mode='w')
73 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
74 | fhandler.setFormatter(formatter)
75 | logger.addHandler(fhandler)
76 | logger.setLevel(logging.DEBUG)
77 |
78 |
79 | params = {"batch_size": 32,
80 | "start_epoch" : 1,
81 | "n_epoch": 50, # 100, 50
82 | "rnn_classes":4, # 5, 4
83 | "save_file": "model_xxx_att.pkl"
84 | }
85 | encode_params = {
86 | "emb_size": 100,
87 | "hidden_size": 160, # 160, 128
88 | "input_dropout_p": 0.2,
89 | "dropout_p": 0.5,
90 | "n_layers": 2,
91 | "bidirectional": True,
92 | "rnn_cell": None,
93 | "rnn_cell_name": 'lstm',
94 | "variable_lengths_flag": True
95 | }
96 |
97 | #dataset = read_data_json("/home/wanglei/aaai_2019/pointer_math_dqn/dataset/source_2/math23k_final.json")
98 | #emb_vectors = np.load('/home/wanglei/aaai_2019/parsing_for_mwp/data/source_2/emb_100.npy')
99 | #data_loader = DataLoader(dataset)
100 | #print ('loading finished')
101 |
102 |
103 | recu_nn = RecursiveNN(data_loader.vocab_len, encode_params['emb_size'], params["rnn_classes"])
104 | #recu_nn = recu_nn.cuda()
105 | recu_nn = recu_nn.to(device)
106 | self_att_recu_tree = Self_ATT_RTree(data_loader, encode_params, recu_nn)
107 | #self_att_recu_tree = self_att_recu_tree.cuda()
108 | self_att_recu_tree = self_att_recu_tree.to(device)
109 | #for name, params in self_att_recu_tree.named_children():
110 | # print (name, params)
111 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self_att_recu_tree.parameters()), \
112 | lr=0.01, momentum=0.9, dampening=0.0)
113 |
114 | trainer = Trainer(data_loader, params)
115 | trainer.train(self_att_recu_tree, optimizer)
116 |
117 |
118 | if __name__ == "__main__":
119 |     main()
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/src/model.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import time
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import Parameter
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset
10 | from torch.autograd import Variable
11 | import torch.optim as optim
12 | from torch.nn.utils import clip_grad_norm_
13 | import logging
14 | import random
15 | import pickle
16 | import sys
17 |
18 | from utils import *
19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |
21 | class BaseRNN(nn.Module):
22 | def __init__(self, vocab_size, emb_size, hidden_size, input_dropout_p, dropout_p, \
23 | n_layers, rnn_cell_name):
24 | super(BaseRNN, self).__init__()
25 | self.vocab_size = vocab_size
26 | self.emb_size = emb_size
27 | self.hidden_size = hidden_size
28 | self.n_layers = n_layers
29 | self.input_dropout_p = input_dropout_p
30 | self.input_dropout = nn.Dropout(p=input_dropout_p)
31 | self.rnn_cell_name = rnn_cell_name
32 | if rnn_cell_name.lower() == 'lstm':
33 | self.rnn_cell = nn.LSTM
34 | elif rnn_cell_name.lower() == 'gru':
35 | self.rnn_cell = nn.GRU
36 | else:
37 | raise ValueError("Unsupported RNN Cell: {0}".format(rnn_cell_name))
38 | self.dropout_p = dropout_p
39 |
40 | def forward(self, *args, **kwargs):
41 | raise NotImplementedError()
42 |
43 | class EncoderRNN(BaseRNN):
44 | def __init__(self, vocab_size, embed_model, emb_size=100, hidden_size=128, \
45 | input_dropout_p=0, dropout_p=0, n_layers=1, bidirectional=False, \
46 | rnn_cell=None, rnn_cell_name='gru', variable_lengths_flag=True):
47 | super(EncoderRNN, self).__init__(vocab_size, emb_size, hidden_size,
48 | input_dropout_p, dropout_p, n_layers, rnn_cell_name)
49 | self.variable_lengths_flag = variable_lengths_flag
50 | self.bidirectional = bidirectional
51 | self.embedding = embed_model
52 | if rnn_cell == None:
53 | self.rnn = self.rnn_cell(emb_size, hidden_size, n_layers,
54 | batch_first=True, bidirectional=bidirectional, dropout=dropout_p)
55 | else:
56 | self.rnn = rnn_cell
57 |
58 | def forward(self, input_var, input_lengths=None):
59 | embedded = self.embedding(input_var)
60 | embedded = self.input_dropout(embedded)
61 | #pdb.set_trace()
62 | if self.variable_lengths_flag:
63 | embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True)
64 | output, hidden = self.rnn(embedded)
65 | if self.variable_lengths_flag:
66 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
67 | return output, hidden
68 |
69 | class Attention_1(nn.Module):
70 | def __init__(self, input_size, output_size):
71 | super(Attention_1, self).__init__()
72 | self.linear_out = nn.Linear(input_size, output_size)
73 | #self.mask = Parameter(torch.ones(1), requires_grad=False)
74 |
75 | #def set_mask(self, batch_size, input_length, num_pos):
76 | # self.mask = self.mask.repeat(input_length).unsqueeze(0).repeat(batch_size, 1)
77 | # for mask_i in range(batch_size):
78 | # self.mask[mask_i][num_pos[mask_i]] = 1
79 | def init_mask(self, size_0, size_1, input_length):
80 | mask = Parameter(torch.ones(1), requires_grad=False)
81 | mask = mask.repeat(size_1).unsqueeze(0).repeat(size_0, 1)
82 | #for i in range(input_length)
83 | input_index = list(range(input_length))
84 | for i in range(size_0):
85 | mask[i][input_index] = 0
86 | #print (mask)
87 |         mask = mask.bool()  # masked_fill_ expects a bool mask in recent PyTorch
88 | #mask = mask.cuda()
89 | mask = mask.to(device)
90 | return mask
91 |
92 |
93 |
94 | def _forward(self, output, context, input_lengths, mask):
95 | '''
96 | output: len x hidden_size
97 | context: num_len x hidden_size
98 | input_lengths: torch scalar
99 | '''
100 | #print (output.size()) torch.Size([5, 256])
101 |         #print (context.size()) torch.Size([80, 256])
102 | #print (input_lengths)
103 | attn = torch.matmul(output, context.transpose(1,0))
104 | #print (attn.size()) 0 x 1
105 | attn.data.masked_fill_(mask, -float('inf'))
106 | attn = F.softmax(attn, dim=1)
107 | #print (attn)
108 | mix = torch.matmul(attn, context)
109 | #print ("mix:", mix)
110 | #print ("output:", output)
111 | combined = torch.cat((mix, output), dim=1)
112 | #print ("combined:",combined)
113 |         output = torch.tanh(self.linear_out(combined))  # F.tanh is deprecated
114 |
115 | #print ("output:",output)
116 | #print ("------------")
117 | #print ()
118 | return output, attn
119 |
120 |
121 |
122 |
123 | def forward(self, output, context, num_pos, input_lengths):
124 | '''
125 |         output: decoder, (batch, 1, hidden_dim2)
126 |         context: from encoder, (batch, n, hidden_dim1)
127 |         in practice dim2 == dim1; otherwise the matrix multiplication below is impossible
128 |
129 | '''
130 | batch_size = output.size(0)
131 | hidden_size = output.size(2)
132 | input_size = context.size(1)
133 | #print ('att:', hidden_size, input_size)
134 | #print ("context", context.size())
135 |
136 | attn_list = []
137 | mask_list = []
138 | output_list = []
139 | for b_i in range(batch_size):
140 | per_output = output[b_i]
141 | per_num_pos = num_pos[b_i]
142 | #print(context, num_pos)
143 | current_output = per_output[per_num_pos]
144 | per_mask = self.init_mask(len(per_num_pos), input_size, input_lengths[b_i])
145 | mask_list.append(per_mask)
146 | #print ("current_context:", current_context.size())
147 | per_output, per_attn = self._forward(current_output, context[b_i], input_lengths[b_i], per_mask)
148 | #for p_j in range(len(per_num_pos)):
149 | # current_context = per_context[per_num_pos[p_j]]
150 | # print ("c_context:", current_context.size())
151 | output_list.append(per_output)
152 | attn_list.append(per_attn)
153 |
154 |
155 |
156 | #self.set_mask(batch_size, input_size, num_pos)
157 | # (b, o, dim) * (b, dim, i) -> (b, o, i)
158 | '''
159 | attn = torch.bmm(output, context.transpose(1,2))
160 | if self.mask is not None:
161 | attn.data.masked_fill_(self.mask, -float('inf'))
162 | attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size)
163 |
164 | # (b, o, i) * (b, i, dim) -> (b, o, dim)
165 | mix = torch.bmm(attn, context)
166 |
167 | combined = torch.cat((mix, output), dim=2)
168 |
169 | #output = F.tanh(self.linear_out(combined.view(-1, 2*hidden_size)))\
170 | .view(batch_size, -1, hidden_size)
171 |
172 | # output: (b, o, dim)
173 | # attn : (b, o, i)
174 | #return output, attn
175 | '''
176 | return output_list, attn_list
177 |
178 | class RecursiveNN(nn.Module):
179 | def __init__(self, vocabSize, embedSize=100, numClasses=5):
180 | super(RecursiveNN, self).__init__()
181 | #self.embedding = nn.Embedding(int(vocabSize), embedSize)
182 | #self.self-att -> embedding
183 | #self.att
184 | #self.self_att_model = self_att_model
185 | self.W = nn.Linear(2*embedSize, embedSize, bias=True)
186 | self.projection = nn.Linear(embedSize, numClasses, bias=True)
187 | self.activation = F.relu
188 | self.nodeProbList = []
189 | self.labelList = []
190 | self.classes = ['+','-','*','/','^']
191 |
192 | def leaf_emb(self, node, num_embed, look_up):
193 | if node.is_leaf:
194 | #try:
195 | node.node_emb = num_embed[look_up.index(node.root_value)]
196 | # except:
197 | # print(node)
198 | else:
199 | self.leaf_emb(node.left_tree, num_embed, look_up)
200 | self.leaf_emb(node.right_tree, num_embed, look_up)
201 |
202 | def traverse(self, node):
203 | if node.is_leaf:
204 | currentNode = node.node_emb.unsqueeze(0)
205 | else:
206 | #currentNode = self.activation(self.W(torch.cat((self.traverse(node.left_tree),self.traverse(node.right_tree)),1)))
207 | left_vector = self.traverse(node.left_tree)#.unsqueeze(0)
208 | right_vector = self.traverse(node.right_tree)#.unsqueeze(0)
209 | #print (left_vector)
210 | combined_v = torch.cat((left_vector, right_vector),1)
211 | currentNode = self.activation(self.W(combined_v))
212 | node.node_emb = currentNode.squeeze(0)
213 | assert node.is_leaf==False, "error is leaf"
214 | #self.nodeProbList.append(self.projection(currentNode))
215 | proj_probs = self.projection(currentNode)
216 | self.nodeProbList.append(proj_probs)
217 | #node.numclass_probs = proj_probs
218 | self.labelList.append(self.classes.index(node.root_value))
219 | return currentNode
220 |
221 | def forward(self, tree_node, num_embed, look_up):
222 |
223 | self.nodeProbList = []
224 | self.labelList = []
225 | self.leaf_emb(tree_node, num_embed, look_up)
226 | self.traverse(tree_node)
227 | self.labelList = torch.LongTensor(self.labelList)
228 |
229 | #self.labelList = self.labelList.cuda()
230 | #batch_loss_l = torch.FloatTensor([0])[0].cuda()
231 | batch_loss_l = torch.FloatTensor([0])[0].to(device)
232 | self.labelList = self.labelList.to(device)
233 | #print (torch.cat(self.nodeProbList).size())
234 | #print (self.labelList)
235 | return torch.cat(self.nodeProbList)#, tree_node
236 |
237 | def getLoss_train(self, tree_node, num_embed, look_up):
238 | nodes = self.forward(tree_node, num_embed, look_up)
239 | predictions = nodes.max(dim=1)[1]
240 | loss = F.cross_entropy(input=nodes, target=self.labelList)
241 |
242 | #exit(0)
243 | #print (predictions.size())
244 | #print ()
245 | acc_elemwise, acc_t = self.compute_acc_elemwise(predictions, self.labelList)
246 | acc_integrate = self.compute_acc_integrate(predictions, self.labelList)
247 | return predictions, loss, acc_elemwise, acc_t, acc_integrate#, tree_node
248 |
249 | def test_forward(self, tree_node, num_embed, look_up):
250 | nodes = self.forward(tree_node, num_embed, look_up)
251 | predictions = nodes.max(dim=1)[1]
252 | acc_elemwise, acc_t = self.compute_acc_elemwise(predictions, self.labelList)
253 | acc_integrate = self.compute_acc_integrate(predictions, self.labelList)
254 | return predictions, acc_elemwise, acc_t, acc_integrate#, tree_node
255 |
256 | def predict_traverse(self, node):
257 | if node.is_leaf:
258 | currentNode = node.node_emb.unsqueeze(0)
259 | else:
260 | #currentNode = self.activation(self.W(torch.cat((self.traverse(node.left_tree),self.traverse(node.right_tree)),1)))
261 | left_vector = self.predict_traverse(node.left_tree)#.unsqueeze(0)
262 | right_vector = self.predict_traverse(node.right_tree)#.unsqueeze(0)
263 | #print (left_vector)
264 | combined_v = torch.cat((left_vector, right_vector),1)
265 | currentNode = self.activation(self.W(combined_v))
266 | node.node_emb = currentNode.squeeze(0)
267 | assert node.is_leaf==False, "error is leaf"
268 | #self.nodeProbList.append(self.projection(currentNode))
269 | proj_probs = self.projection(currentNode)
270 | node_id = proj_probs.max(dim=1)[1]
271 | node_marker = self.classes[node_id]
272 | node.root_value = node_marker
273 | #print ('_++_', node_marker)
274 | #print ("_++_", proj_probs.size())
275 | #self.nodeProbList.append(proj_probs)
276 | #node.numclass_probs = proj_probs
277 | #self.labelList.append(self.classes.index(node.root_value))
278 | return currentNode
279 |
280 | def predict(self, tree_node, num_embed, look_up):#, num_list, gold_ans):
281 | self.leaf_emb(tree_node, num_embed, look_up)
282 | self.predict_traverse(tree_node)
283 | post_equ = post_order(tree_node)
284 | #print ('tst:', post_equ)
285 | #pred_ans = post_solver(post_equ)
286 |
287 | return tree_node, post_equ#, pred_ans
288 |
289 | def compute_acc_elemwise(self, pred_tensor, label_tensor):
290 | return torch.sum((pred_tensor == label_tensor).int()).item() , len(pred_tensor)
291 |
292 | def compute_acc_integrate(self, pred_tensor, label_tensor):
293 | return 1 if torch.equal(pred_tensor, label_tensor) else 0
294 |
295 | def evaluate(self, trees):
296 | n = nAll = correctRoot = correctAll = 0.0
297 |         return (correctRoot / n if n else 0.0), (correctAll / nAll if nAll else 0.0)  # guarded: was a guaranteed ZeroDivisionError
298 |
299 | def forward_one_layer(self, left_node, right_node):
300 | left_vector = left_node.node_emb.unsqueeze(0)
301 | right_vector = right_node.node_emb.unsqueeze(0)
302 | combined_v = torch.cat((left_vector, right_vector),1)
303 | currentNode = self.activation(self.W(combined_v))
304 | root_node = BinaryTree()
305 | root_node.is_leaf = False
306 | root_node.node_emb = currentNode.squeeze(0)
307 | proj_probs = self.projection(currentNode)
308 | #print ('recur:', proj_probs)
309 | #print ('r_m',proj_probs.max(1)[1][0].item())
310 | pred_idx = proj_probs.max(1)[1][0].item()
311 | root_node.root_value = self.classes[pred_idx]
312 | root_node.left_tree = left_node
313 | root_node.right_tree = right_node
314 | return root_node, proj_probs
315 |
316 | class Self_ATT_RTree(nn.Module):
317 | def __init__(self, data_loader, encode_params, RecursiveNN):
318 | super(Self_ATT_RTree, self).__init__()
319 | self.data_loader = data_loader
320 | self.encode_params = encode_params
321 | self.embed_model = nn.Embedding(data_loader.vocab_len, encode_params['emb_size'])
322 | #self.embed_model = self.embed_model.cuda()
323 | self.embed_model = self.embed_model.to(device)
324 | self.encoder = EncoderRNN(vocab_size = data_loader.vocab_len,
325 | embed_model = self.embed_model,
326 | emb_size = encode_params['emb_size'],
327 | hidden_size = encode_params['hidden_size'],
328 | input_dropout_p = encode_params['input_dropout_p'],
329 | dropout_p = encode_params['dropout_p'],
330 | n_layers = encode_params['n_layers'],
331 | bidirectional = encode_params['bidirectional'],
332 | rnn_cell = encode_params['rnn_cell'],
333 | rnn_cell_name = encode_params['rnn_cell_name'],
334 | variable_lengths_flag = encode_params['variable_lengths_flag'])
335 | if encode_params['bidirectional'] == True:
336 | self.self_attention = Attention_1(encode_params['hidden_size']*4, encode_params['emb_size'])
337 | else:
338 | self.self_attention = Attention_1(encode_params['hidden_size']*2, encode_params['emb_size'])
339 |
340 | if encode_params['bidirectional'] == True:
341 | decoder_hidden_size = encode_params['hidden_size']*2
342 |
343 | self.recur_nn = RecursiveNN
344 |
345 | self._prepare_for_recur()
346 | self._prepare_for_pointer()
347 |
348 | def _prepare_for_recur(self):
349 | self.fixed_num_symbol = ['1', 'PI']
350 | self.fixed_num_idx = [self.data_loader.vocab_dict[elem] for elem in self.fixed_num_symbol]
351 | self.fixed_num = torch.LongTensor(self.fixed_num_idx)
352 |
353 | #self.fixed_num = self.fixed_num.cuda()
354 | self.fixed_num = self.fixed_num.to(device)
355 |
356 | self.fixed_num_emb = self.embed_model(self.fixed_num)
357 |
358 | def _prepare_for_pointer(self):
359 | self.fixed_p_num_symbol = ['EOS','1', 'PI']
360 | self.fixed_p_num_idx = [self.data_loader.vocab_dict[elem] for elem in self.fixed_p_num_symbol]
361 | self.fixed_p_num = torch.LongTensor(self.fixed_p_num_idx)
362 |
363 | #self.fixed_p_num = self.fixed_p_num.cuda()
364 | self.fixed_p_num = self.fixed_p_num.to(device)
365 |
366 | self.fixed_p_num_emb = self.embed_model(self.fixed_p_num)
367 |
368 | def forward(self, input_tensor, input_lengths, num_pos, b_gd_tree):
369 | encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)
370 | en_output_list, en_attn_list = self.self_attention(encoder_outputs, encoder_outputs, num_pos, input_lengths)
371 |
372 | batch_size = len(en_output_list)
373 |
374 | batch_predictions = []
375 | batch_acc_e_list = []
376 | batch_acc_e_t_list = []
377 | batch_acc_i_list = []
378 | #batch_loss_l = torch.FloatTensor([0])[0].cuda()
379 | batch_loss_l = torch.FloatTensor([0])[0].to(device)
380 | batch_count = 0
381 | for b_i in range(batch_size):
382 | #en_output = en_output_list[b_i]
383 | en_output = torch.cat([self.fixed_num_emb, en_output_list[b_i]], dim=0)
384 | #print (num_pos)
385 | look_up = self.fixed_num_symbol + ['temp_'+str(temp_i) for temp_i in range(len(num_pos[b_i]))]
386 | #print (mid_order(b_gd_tree))
387 | if len(b_gd_tree[b_i]) == 0:
388 | continue
389 | gd_tree_node = b_gd_tree[b_i][-1]
390 |
391 | #print (en_output)
392 | #print (look_up.index('temp_0'))
393 | #print (en_output[look_up.index("temp_0")])
394 | #print ()
395 |
396 | #print (mid_order(gd_tree_node))
397 | #self.recur_nn(gd_tree_node, en_output, look_up)
398 | p, l, acc_e, acc_e_t, acc_i = self.recur_nn.getLoss_train(gd_tree_node, en_output, look_up)
399 | #print (p)tensor([ 0, 2])
400 | #print (l)tensor(1.5615)
401 | #print (l)
402 | #print (post_order(gd_tree_node))
403 | #get_info_teacher_pointer(gd_tree_node)
404 |
405 | batch_predictions.append(p)
406 | batch_acc_e_list.append(acc_e)
407 | batch_acc_e_t_list.append(acc_e_t)
408 | batch_acc_i_list.append(acc_i)
409 | #batch_loss_l.append(l)
410 | #batch_loss_l += l
411 | batch_loss_l = torch.sum(torch.cat([ batch_loss_l.unsqueeze(0), l.unsqueeze(0)], 0))
412 | batch_count += 1
413 |
414 | #print (post_order(gd_tree_final))
415 | #print ("hhhhhh:", en_output)
416 | #print (batch_loss_l)
417 | #torch.cat(batch_loss_l)
418 | #print (torch.sum(torch.cat(batch_loss_l)))
419 | #print ()
420 | return batch_predictions, batch_loss_l, batch_count, batch_acc_e_list, batch_acc_e_t_list, batch_acc_i_list
421 |
422 | def test_forward_recur(self, input_tensor, input_lengths, num_pos, b_gd_tree):
423 | encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)
424 | en_output_list, en_attn_list = self.self_attention(encoder_outputs, encoder_outputs, num_pos, input_lengths)
425 |
426 | batch_size = len(en_output_list)
427 |
428 | batch_predictions = []
429 | batch_acc_e_list = []
430 | batch_acc_e_t_list = []
431 | batch_acc_i_list = []
432 | batch_count = 0
433 | for b_i in range(batch_size):
434 | #en_output = en_output_list[b_i]
435 | en_output = torch.cat([self.fixed_num_emb, en_output_list[b_i]], dim=0)
436 | #print (num_pos)
437 | look_up = self.fixed_num_symbol + ['temp_'+str(temp_i) for temp_i in range(len(num_pos[b_i]))]
438 | #print (b_gd_tree[b_i])
439 | if len(b_gd_tree[b_i]) == 0:
440 | continue
441 | gd_tree_node = b_gd_tree[b_i][-1]
442 |
443 | #print (en_output)
444 | #print (look_up.index('temp_0'))
445 | #print (en_output[look_up.index("temp_0")])
446 | #print ()
447 |
448 | #self.recur_nn(gd_tree_node, en_output, look_up)
449 | p, acc_e, acc_e_t, acc_i = self.recur_nn.test_forward(gd_tree_node, en_output, look_up)
450 | #print (p)tensor([ 0, 2])
451 | #print (l)tensor(1.5615)
452 | #print (l)
453 | #print (post_order(gd_tree_node))
454 | #get_info_teacher_pointer(gd_tree_node)
455 |
456 | batch_predictions.append(p)
457 | batch_acc_e_list.append(acc_e)
458 | batch_acc_e_t_list.append(acc_e_t)
459 | batch_acc_i_list.append(acc_i)
460 | #batch_loss_l.append(l)
461 | #batch_loss_l += l
462 | batch_count += 1
463 |
464 | #print (post_order(gd_tree_final))
465 | #print ("hhhhhh:", en_output)
466 | #print (batch_loss_l)
467 | #torch.cat(batch_loss_l)
468 | #print (torch.sum(torch.cat(batch_loss_l)))
469 | #print ()
470 | return batch_predictions, batch_count, batch_acc_e_list, batch_acc_e_t_list, batch_acc_i_list
471 |
472 | def predict_forward_recur(self, input_tensor, input_lengths, num_pos, batch_seq_tree, batch_flags):
473 | encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths)
474 | en_output_list, en_attn_list = self.self_attention(encoder_outputs, encoder_outputs, num_pos, input_lengths)
475 | #print ('-x-x-x-',en_attn_list[0])
476 | batch_size = len(en_output_list)
477 | batch_pred_tree_node = []
478 | batch_pred_post_equ = []
479 | #batch_pred_ans = []
480 |
481 |
482 | for b_i in range(batch_size):
483 |
484 | flag = batch_flags[b_i]
485 | #num_list = batch_num_list[b_i]
486 | if flag == 1:
487 |
488 |
489 | en_output = torch.cat([self.fixed_num_emb, en_output_list[b_i]], dim=0)
490 |
491 | look_up = self.fixed_num_symbol + ['temp_'+str(temp_i) for temp_i in range(len(num_pos[b_i]))]
492 |
493 | seq_node = batch_seq_tree[b_i]
494 | #num_list = batch_num_list[i]
495 | #gold_ans = batch_solution[i]
496 | tree_node, post_equ = self.recur_nn.predict(seq_node, en_output, look_up)#, num_list, gold_ans)
497 | #p, acc_e, acc_e_t, acc_i = self.recur_nn.test_forward(gd_tree_node, en_output, look_up)
498 |
499 | batch_pred_tree_node.append(tree_node)
500 | batch_pred_post_equ.append(post_equ)
501 |
502 | else:
503 | batch_pred_tree_node.append(None)
504 | batch_pred_post_equ.append([])
505 |
506 | return batch_pred_tree_node, batch_pred_post_equ#,en_attn_list
507 |
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/src/replicate.py:
--------------------------------------------------------------------------------
1 | ######################################################################
2 | # File: replicate.py
3 | # Author: Vishal Dey
4 | # Created on: 11 Dec 2019
5 | #######################################################################
6 | '''
7 | Synopsis: Expands Dolphin300 into Dolphin1500 (both derived from Dolphin18k):
8 |           duplicates templates, substitutes fresh quantities and recomputes answers
9 | '''
10 |
11 | import os
12 | import sys
13 | import json
14 | import numpy as np
15 | from copy import deepcopy
16 |
17 | from utils import *
18 |
19 | np.random.seed(123)
20 |
21 | def read_json(fname):
22 | with open(os.path.join('./data/', fname), 'r') as fp:
23 | return json.load(fp)
24 |
25 | def write_json(fname, data):
26 | with open(os.path.join('./data/', fname), 'w') as fp:
27 | json.dump(data, fp)
28 |
29 |
30 | def duplicate(entry):
31 | text = entry['text']
32 | #equation = entry['expression']
33 | num_list = entry['num_list']
34 | post_temp = entry['post_template']
35 |
36 | dup_num_list = []
37 | dup_text = []
38 |
39 | for i in num_list:
40 | if '.' in i:
41 | temp = str(np.random.rand())[:len(i)]
42 | while (temp in num_list or temp in dup_num_list):
43 | temp = str(np.random.rand())[:len(i)]
44 | dup_num_list.append(temp)
45 | else:
46 | temp = str(np.random.randint(123))
47 | while (temp in num_list or temp in dup_num_list):
48 | temp = str(np.random.randint(10000))
49 | dup_num_list.append(temp)
50 |
51 | for each in text.split():
52 | if each in num_list:
53 | dup_text.append(dup_num_list[num_list.index(each)])
54 | else:
55 | dup_text.append(each)
56 |
57 | post_equ = []
58 | for i in post_temp:
59 | if 'temp' in i:
60 | num = dup_num_list[int(i.split('_')[1])]
61 | post_equ.append(num)
62 | else:
63 | post_equ.append(i)
64 | try:
65 | ans = post_solver(post_equ)
66 | except:
67 | return None
68 | # print(dup_num_list, ' '.join(dup_text), ans)
69 | return dup_num_list, ' '.join(dup_text), ans
70 |
71 |
72 | def main():
73 | dolphin = read_json('final_dolphin_data.json')
74 |
75 | n_repeats = 5
76 | max_id = int(max(list(dolphin.keys())))
77 |
78 | new_dolphin = {}
79 |
80 | for k, v in dolphin.items():
81 | new_dolphin[k] = deepcopy(v)
82 |
83 | for i in range(n_repeats):
84 | max_id += 1
85 | tmp = duplicate(v)
86 | if tmp:
87 | new_dolphin[str(max_id)] = deepcopy(v)
88 | new_dolphin[str(max_id)]['index'] = str(max_id)
89 | new_dolphin[str(max_id)]['text'] = tmp[1]
90 | new_dolphin[str(max_id)]['num_list'] = tmp[0]
91 | new_dolphin[str(max_id)]['ans'] = tmp[-1]
92 | # print(new_dolphin)
93 |
94 | print(len(new_dolphin.keys()))
95 |
96 | write_json('final_dolphin_data_replicate.json', new_dolphin)
97 | index = list(new_dolphin.keys())
98 |
99 | valid_size = int(0.2*len(index))
100 |
101 |     valid_ids = np.random.choice(index, valid_size, replace=False)  # sample ids without repeats
102 | test_size = int(valid_size/2)
103 | valid_size -= test_size
104 |
105 | print(len(index)-valid_size-test_size, valid_size, test_size)
106 |
107 | # validation ids
108 | write_json('valid_ids_dolphin.json', valid_ids[:valid_size].tolist())
109 | # test ids
110 | fp1 = open('./data/id_ans_test_dolphin', 'w')
111 |
112 | testtemp = []
113 | for id in valid_ids[valid_size:]:
114 | print(id + '\t' + new_dolphin[id]['ans'], file=fp1)
115 | tmp = ' '.join(new_dolphin[id]['post_template'])
116 | for ch in ['+', '-', '*', '/']:
117 | tmp = tmp.replace(ch, '^')
118 | testtemp.append([id, tmp.split()])
119 |
120 | fp1.close()
121 |
122 | # test templates masked
123 | write_json('pg_norm_test_dolphin.json', testtemp)
124 |
125 |
126 | if __name__ == '__main__':
127 |     main()
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/src/retrieval_model.py:
--------------------------------------------------------------------------------
1 | ######################################################################
2 | # File: retrieval_model.py
3 | # Author: Vishal Dey
4 | # Created on: 11 Dec 2019
5 | #######################################################################
6 | '''
7 | Synopsis: Create w2v for corresponding string description of each problem
8 | Reads in pretrained Word2vec vectors and adds the word vectors to obtain
9 | phrase vectors.
10 | Either compute TF-IDF / w2v based cosine similarity
11 | '''
12 |
13 | import os
14 | import json
15 | import sys
16 |
17 | import numpy as np
18 | from sklearn.feature_extraction.text import TfidfVectorizer
19 | from sklearn.metrics.pairwise import linear_kernel, cosine_similarity
20 | from copy import deepcopy
21 |
22 | # for reproducibility
23 | np.random.seed(123)
24 |
25 |
26 | # read json file
27 | def read_json(fname):
28 | with open(os.path.join('./data/', fname), 'r') as fp:
29 | return json.load(fp)
30 |
31 |
32 | # compute TF-iDF based most similar problem
33 | def find_similar_tfidf(tfidf_matrix, tfidf_vector):
34 | cosine_similarities = linear_kernel(tfidf_vector, tfidf_matrix).flatten()
35 | related_docs_indices = [i for i in cosine_similarities.argsort()[::-1]]
36 | #print(related_docs_indices[0], cosine_similarities[related_docs_indices[0]])
37 | return related_docs_indices[0]
38 |
39 | # compute w2v cosine based most similar problem
40 | def find_similar_w2v(w2vmatrix, w2vemb, query, EMB_DIM, minv=0, maxv=1):
41 | query_emb = []
42 | for word in query.split():
43 | if word in w2vemb:
44 | query_emb.append(np.array(list(map(float, w2vemb[word]))))
45 | else:
46 | query_emb.append(np.random.uniform(minv, maxv, EMB_DIM))
47 |
48 | cosine_similarities = linear_kernel(query_emb, w2vmatrix).flatten()
49 | related_docs_indices = [i for i in cosine_similarities.argsort()[::-1]]
50 | print(related_docs_indices)
51 | return related_docs_indices[0]
52 |
53 |
54 | # load w2v problem
55 | def load_w2v(w2v_file):
56 | EMB_DIM = 0
57 | w2v_emb = {}
58 | minv = sys.float_info.max
59 |     maxv = -sys.float_info.max  # note: sys.float_info.min is the smallest *positive* float, not the most negative value
60 |
61 | with open(os.path.join('./w2v', w2v_file), 'r') as fp:
62 | EMB_DIM = int(fp.readline().split()[1])
63 | for line in fp.readlines():
64 |             tmp = line.split()
65 |             w2v_emb[tmp[0]] = vec = list(map(float, tmp[1:]))  # store the vector keyed by word; the original never filled w2v_emb
66 |             minv = min(minv, min(vec))
67 |             maxv = max(maxv, max(vec))
68 |
69 | print(EMB_DIM, minv, maxv)
70 | return w2v_emb, EMB_DIM, minv, maxv
71 |
72 |
73 | # compute w2v matrix (problem x EMB_DIM)
74 | def compute_w2vmatrix(docs, w2v_emb, EMB_DIM, minv=0, maxv=1):
75 | w2vmatrix = []
76 | for doc in docs:
77 | emb = [0]*EMB_DIM
78 | for word in doc.split():
79 | if word in w2v_emb:
80 | emb += np.array(list(map(float, w2v_emb[word])))
81 | else:
82 | emb += np.random.uniform(minv, maxv, EMB_DIM)
83 | emb /= len(doc.split())
84 |
85 | w2vmatrix.append(emb)
86 | return w2vmatrix
87 |
88 |
89 | # post fix equation solver
90 | def post_solver(post_equ):
91 | stack = []
92 | op_list = ['+', '-', '/', '*', '^']
93 |
94 | # if len(post_equ)-2*len([i for i in post_equ if (any(c.isdigit() for c in i))]) >= 0:
95 | # print(post_equ)
96 | # return 0
97 | for elem in post_equ:
98 | if elem not in op_list:
99 | op_v = elem
100 | if '%' in op_v:
101 | op_v = float(op_v[:-1])/100.0
102 | if op_v == 'PI':
103 | op_v = np.pi
104 | stack.append(str(op_v))
105 | elif elem in op_list:
106 | op_v_1 = stack.pop()
107 | op_v_1 = float(op_v_1)
108 | op_v_2 = stack.pop()
109 | op_v_2 = float(op_v_2)
110 | if elem == '+':
111 | stack.append(str(op_v_2+op_v_1))
112 | elif elem == '-':
113 | stack.append(str(op_v_2-op_v_1))
114 | elif elem == '*':
115 | stack.append(str(op_v_2*op_v_1))
116 | elif elem == '/':
117 | if op_v_1 == 0:
118 | return np.nan
119 | stack.append(str(op_v_2/op_v_1))
120 | else:
121 | stack.append(str(op_v_2**op_v_1))
122 | return stack.pop()
123 |
124 | # get post fix equation from the template
125 | def get_equation(template, num_list):
126 | equation = []
127 | for x in template:
128 | if 'temp' in x:
129 | equation.append(str(num_list[int(x.split('_')[1])]))
130 | else:
131 | equation.append(x)
132 | return equation
133 |
134 |
135 |
136 | def main():
137 | math23k = read_json('math23k_final.json')
138 | #math23k = read_json('final_dolphin_data.json')
139 |
140 | # read in test IDs and answer - gold truth
141 | test_ids = []
142 | with open('./data/id_ans_test', 'r') as fp:
143 | for line in fp.readlines():
144 | test_ids.append(line.split('\t')[0])
145 | # del math23k['1682']
146 |
147 | test_length = len(test_ids)
148 |
149 | train = {}
150 | test = {}
151 | corpus = []
152 |
153 | # make training corpus text
154 | for k, v in math23k.items():
155 | if k not in test_ids:
156 | train[k] = v
157 | corpus.append(v['template_text'])
158 | else:
159 | test[k] = v
160 |
161 | corpus = list(corpus)
162 | indices = list(train.keys())
163 |
164 | print(len(corpus), test_length)
165 |
166 | # tfidf vectorizer
167 | vectorizer = TfidfVectorizer()
168 | corpus_tfidf = vectorizer.fit_transform(corpus)
169 |
170 | # load w2v
171 | #w2v_emb, EMB_DIM, minv, maxv = load_w2v('crawl-300d-2M.vec')
172 | #w2vmatrix = compute_w2vmatrix(corpus, w2v_emb, EMB_DIM, minv, maxv)
173 |
174 | #print(len(w2vmatrix))
175 | similar_indices = {}
176 |
177 | for k, v in test.items():
178 | query_tfidf = vectorizer.transform([v['template_text']])
179 | i = find_similar_tfidf(corpus_tfidf, query_tfidf)
180 | #i = find_similar_w2v(w2vmatrix, w2v_emb, v['template_text'], EMB_DIM, minv, maxv)
181 | #i = jaccard_similarity(corpus, v['template_text'])
182 | #print(i)
183 | similar_indices[k] = indices[i]
184 |
185 | num_correct = 0
186 | # outputs the prediction
187 | fout = open('pred.txt', 'w')
188 |
189 | for k, v in similar_indices.items():
190 | template = math23k[v]['post_template']
191 | num_list = math23k[k]['num_list']
192 | ans = math23k[k]['ans']
193 |
194 | if (max([int(s.split('temp_')[1]) for s in set(template) if 'temp' in s]) < len(num_list)):
195 | pred_ans = post_solver(get_equation(template, num_list))
196 | print(math23k[k]['index'], pred_ans, ans, file=fout)
197 | if (np.isclose(float(pred_ans), float(ans))):
198 | num_correct += 1
199 | else:
200 | print('wrong template prediction: mismatch with len of num list')
201 |
202 | print('Accuracy = ', str(float(num_correct)/test_length))
203 | fout.close()
204 |
205 | if __name__ == '__main__':
206 |     main()
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/src/trainer.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from model import *
3 | import os
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 | class Trainer():
8 | def __init__(self, data_loader, params):
9 | self.data_loader = data_loader
10 | self.params = params
11 | self.train_len = len(data_loader.train_list)
12 | self.valid_len = len(data_loader.valid_list)
13 | self.test_len = len(data_loader.test_list)
14 | #self.optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0.0)
15 | #self.pg_seq = dict(read_data_json("./data/pg_seq_norm_0828.json")) #dict
16 | self.pg_seq = dict(read_data_json("./data/pg_norm_test_dolphin.json")) #dict
17 |
18 | def _train_batch_recur(self, model, batch_encode_pad_idx, batch_encode_num_pos, batch_encode_len, batch_gd_tree):
19 | batch_encode_pad_idx_tensor = torch.LongTensor(batch_encode_pad_idx)
20 | batch_encode_tensor_len = torch.LongTensor(batch_encode_len)
21 |
22 | batch_encode_pad_idx_tensor = batch_encode_pad_idx_tensor.to(device)#cuda()
23 | batch_encode_tensor_len = batch_encode_tensor_len.to(device)#cuda()
24 | #print ("batch_encode_num_pos",batch_encode_num_pos)
25 | b_pred, b_loss, b_count, b_acc_e, b_acc_e_t, b_acc_i = model(batch_encode_pad_idx_tensor, batch_encode_tensor_len, \
26 | batch_encode_num_pos, batch_gd_tree)
27 | self.optimizer.zero_grad()
28 | #print (b_loss)
29 | b_loss.backward(retain_graph=True)
30 | clip_grad_norm_(model.parameters(), 5, norm_type=2.)
31 | self.optimizer.step()
32 | return b_pred, b_loss.item(), b_count, b_acc_e, b_acc_e_t, b_acc_i
33 |
34 | def _test_recur(self, model, data_list):
35 | batch_size = self.params['batch_size']
36 | data_generator = self.data_loader.get_batch(data_list, batch_size)
37 | test_pred = []
38 | test_count = 0
39 | test_acc_e = []
40 | test_acc_e_t = []
41 | test_acc_i = []
42 | for batch_elem in data_generator:
43 | batch_encode_idx = batch_elem['batch_encode_idx']
44 | batch_encode_pad_idx = batch_elem['batch_encode_pad_idx']
45 | batch_encode_num_pos = batch_elem['batch_encode_num_pos']
46 | batch_encode_len = batch_elem['batch_encode_len']
47 |
48 | batch_decode_idx = batch_elem['batch_decode_idx']
49 |
50 | batch_gd_tree = batch_elem['batch_gd_tree']
51 |
52 | batch_encode_pad_idx_tensor = torch.LongTensor(batch_encode_pad_idx)
53 | batch_encode_tensor_len = torch.LongTensor(batch_encode_len)
54 |
55 | batch_encode_pad_idx_tensor = batch_encode_pad_idx_tensor.to(device)#cuda()
56 | batch_encode_tensor_len = batch_encode_tensor_len.to(device)#cuda()
57 |
58 | b_pred, b_count, b_acc_e, b_acc_e_t, b_acc_i = model.test_forward_recur(batch_encode_pad_idx_tensor, batch_encode_tensor_len, \
59 | batch_encode_num_pos, batch_gd_tree)
60 |
61 | test_pred+= b_pred
62 | test_count += b_count
63 | test_acc_e += b_acc_e
64 | test_acc_e_t += b_acc_e_t
65 | test_acc_i += b_acc_i
66 | return test_pred, test_count, test_acc_e, test_acc_e_t, test_acc_i
67 |
68 | def predict_joint_batch(self, model, batch_encode_pad_idx, batch_encode_num_pos, batch_encode_len, batch_gd_tree, batch_index, batch_num_list):
69 | batch_encode_pad_idx_tensor = torch.LongTensor(batch_encode_pad_idx)
70 | batch_encode_tensor_len = torch.LongTensor(batch_encode_len)
71 |
72 | batch_encode_pad_idx_tensor = batch_encode_pad_idx_tensor.to(device)#cuda()
73 | batch_encode_tensor_len = batch_encode_tensor_len.to(device)#cuda()
74 |
75 | #alphas = 'abcdefghijklmnopqrstuvwxyz'
76 | alphas = list(map(str, list(range(0, 14))))
77 | batch_seq_tree = []
78 | batch_flags = []
79 | batch_num_len = []
80 | for i in range(len(batch_index)):
81 | index = batch_index[i]
82 | op_template = self.pg_seq[index]
83 | new_op_temps = []
84 | num_len = 0
85 | #print(op_template)
86 | for temp_elem in op_template:
87 | if 'temp' in temp_elem:
88 | num_idx = alphas.index(temp_elem[5:])
89 | #num_idx = temp_elem[5:]
90 | #new_op_temps.append('temp_'+str(num_idx))
91 | new_op_temps.append(temp_elem)
92 | if num_len < num_idx:
93 | num_len = num_idx
94 | else:
95 | new_op_temps.append(temp_elem)
96 |
97 | #print (new_op_temps)
98 | #print (op_template)
99 | #print
100 | #print ("0000", op_template)
101 | try:
102 | temp_tree = construct_tree_opblank(new_op_temps[:])
103 | except:
104 | #print ("error:", new_op_temps)
105 | temp_tree = construct_tree_opblank(['temp_0', 'temp_1', '^'])
106 | batch_seq_tree.append(temp_tree)
107 | num_list = batch_num_list[i]
108 | if num_len >= len(num_list):
109 | print ('error num len',new_op_temps, num_list)
110 | batch_flags.append(0)
111 | else:
112 | batch_flags.append(1)
113 |
114 | #print(index, temp_tree.__str__())
115 | #print(new_op_temps, num_list)
116 |
117 | #print('Batch flags: ', batch_flags)
118 | #print(batch_index[0], batch_seq_tree[0].__str__())
119 | batch_pred_tree_node, batch_pred_post_equ = model.predict_forward_recur(batch_encode_pad_idx_tensor, batch_encode_tensor_len,\
120 | batch_encode_num_pos, batch_seq_tree, batch_flags)
121 |
122 | return batch_pred_tree_node, batch_pred_post_equ
123 | #model.xxx(batch_encode_pad_idx_tensor, batch_encode_tensor_len, batch_encode_num_pos, batch_gd_tree, batch_flags)
124 |
125 | def predict_joint(self, model):
126 | batch_size = self.params['batch_size']
127 | data_generator = self.data_loader.get_batch(self.data_loader.test_list, batch_size)
128 | test_pred = []
129 | test_count = 0
130 | test_acc_e = []
131 | test_acc_e_t = []
132 | test_acc_i = []
133 |
134 | test_temp_acc = 0.0
135 | test_ans_acc = 0.0
136 |
137 | save_info = []
138 |
139 | for batch_elem in data_generator:
140 | batch_encode_idx = batch_elem['batch_encode_idx'][:]
141 | batch_encode_pad_idx = batch_elem['batch_encode_pad_idx'][:]
142 | batch_encode_num_pos = batch_elem['batch_encode_num_pos'][:]
143 | batch_encode_len = batch_elem['batch_encode_len'][:]
144 |
145 | batch_decode_idx = batch_elem['batch_decode_idx'][:]
146 |
147 | batch_gd_tree = batch_elem['batch_gd_tree'][:]
148 |
149 | batch_index = batch_elem['batch_index']
150 | batch_num_list = batch_elem['batch_num_list'][:]
151 | batch_solution = batch_elem['batch_solution']
152 | batch_post_equation = batch_elem['batch_post_equation']
153 |
154 | #b_pred, b_count, b_acc_e, b_acc_e_t, b_acc_i =
155 | batch_pred_tree_node, batch_pred_post_equ = self.predict_joint_batch(model, batch_encode_pad_idx, batch_encode_num_pos,batch_encode_len, batch_gd_tree, batch_index, batch_num_list)
156 |
157 | for i in range(len(batch_solution)):
158 | pred_post_equ = batch_pred_post_equ[i]
159 | #pred_ans = batch_pred_ans[i]
160 | gold_post_equ = batch_post_equation[i]
161 | gold_ans = batch_solution[i]
162 | idx = batch_index[i]
163 | pgseq = self.pg_seq[idx]
164 | num_list = batch_num_list[i]
165 |
166 | #print (pred_post_equ)
167 | #print (pgseq)
168 | #print (gold_post_equ)
169 | #print (num_list)
170 |
171 | if pred_post_equ == []:
172 | pred_ans = -float('inf')
173 | else:
174 | pred_post_equ_ali = []
175 | for elem in pred_post_equ:
176 | if 'temp' in elem:
177 | num_idx = int(elem[5:])
178 | num_marker = num_list[num_idx]
179 | pred_post_equ_ali.append(str(num_marker))
180 | elif 'PI' == elem:
181 | pred_post_equ_ali.append("3.141592653589793")
182 | else:
183 | pred_post_equ_ali.append(elem)
184 | try:
185 | pred_ans = post_solver(pred_post_equ_ali)
186 | except:
187 | pred_ans = -float('inf')
188 |
189 |
190 | if abs(float(pred_ans)-float(gold_ans)) < 1e-5:
191 | test_ans_acc += 1
192 | if ' '.join(pred_post_equ) == ' '.join(gold_post_equ):
193 | test_temp_acc += 1
194 | #print (pred_ans, gold_ans)
195 | #print ()
196 | save_info.append({"idx":idx, "pred_post_eq":pred_post_equ, "gold_post_equ":gold_post_equ, "pred_ans":pred_ans,"gold_ans": gold_ans})
197 |
198 |
199 | #test_pred+= b_pred
200 | #test_count += b_count
201 | #test_acc_e += b_acc_e
202 | #test_acc_e_t += b_acc_e_t
203 | #test_acc_i += b_acc_i
204 | print ("final test temp_acc:{}, ans_acc:{}".format(test_temp_acc/self.test_len, test_ans_acc/self.test_len))
205 | #logging.debug("final test temp_acc:{}, ans_acc:{}".format(test_ans_acc/self.test_len, test_ans_acc/self.test_len))
206 | write_data_json(save_info, "./result_recur/dolphin_rep/save_info_"+str(test_temp_acc/self.test_len)+"_"+str(test_ans_acc/self.test_len)+".json")
207 | return save_info
208 |
209 |
210 | def _train_recur_epoch(self, model, start_epoch, n_epoch):
211 | batch_size = self.params['batch_size']
212 | data_loader = self.data_loader
213 | train_list = data_loader.train_list
214 | valid_list = data_loader.valid_list
215 | test_list = data_loader.test_list
216 |
217 |
218 | valid_max_acc = 0
219 | for epoch in range(start_epoch, n_epoch + 1):
220 | epoch_loss = 0
221 | epoch_pred = []
222 | epoch_count = 0
223 | epoch_acc_e = []
224 | epoch_acc_e_t = []
225 | epoch_acc_i = []
226 | train_generator = self.data_loader.get_batch(train_list, batch_size)
227 | s_time = time.time()
228 | xx = 0
229 | for batch_elem in train_generator:
230 | batch_encode_idx = batch_elem['batch_encode_idx']
231 | batch_encode_pad_idx = batch_elem['batch_encode_pad_idx']
232 | batch_encode_num_pos = batch_elem['batch_encode_num_pos']
233 | batch_encode_len = batch_elem['batch_encode_len']
234 |
235 | batch_decode_idx = batch_elem['batch_decode_idx']
236 |
237 | batch_gd_tree = batch_elem['batch_gd_tree']
238 |
239 |
240 | b_pred, b_loss, b_count, b_acc_e, b_acc_e_t, b_acc_i = self._train_batch_recur(model, batch_encode_pad_idx, \
241 | batch_encode_num_pos,batch_encode_len, batch_gd_tree)
242 | epoch_loss += b_loss
243 | epoch_pred+= b_pred
244 | epoch_count += b_count
245 | epoch_acc_e += b_acc_e
246 | epoch_acc_e_t += b_acc_e_t
247 | epoch_acc_i += b_acc_i
248 | xx += 1
249 | #print (xx)
250 | #print (b_pred)
251 | #break
252 |
253 |
254 | e_time = time.time()
255 |
256 | #print ("ee": epoch_acc_e)
257 | #print ("et": epoch_acc_e_t)
258 | #print ("recur epoch: {}, loss: {}, acc_e: {}, acc_i: {} time: {}".\
259 | # format(epoch, epoch_loss/epoch_count, sum(epoch_acc_e)*1.0/sum(epoch_acc_e_t), sum(epoch_acc_i)/epoch_count, \
260 | # (e_time-s_time)/60))
261 |
262 | valid_pred, valid_count, valid_acc_e, valid_acc_e_t, valid_acc_i = self._test_recur(model, valid_list)
263 | test_pred, test_count, test_acc_e, test_acc_e_t, test_acc_i = self._test_recur(model, test_list)
264 |
265 | #print ('**********1', test_pred)
266 | #print ('**********2', test_count)
267 | #print ('**********3', test_acc_e)
268 | #print ('**********4', test_acc_e_t)
269 | #print ('**********5', test_acc_i)
270 |
271 | print ("RECUR EPOCH: {}, loss: {}, train_acc_e: {}, train_acc_i: {} time: {}".\
272 | format(epoch, epoch_loss/epoch_count, sum(epoch_acc_e)*1.0/sum(epoch_acc_e_t), sum(epoch_acc_i)/epoch_count, \
273 | (e_time-s_time)/60))
274 |
275 | print ("valid_acc_e: {}, valid_acc_i: {}, test_acc_e: {}, test_acc_i: {}".format(sum(valid_acc_e)*1.0/sum(valid_acc_e_t), sum(valid_acc_i)*1.0/valid_count, sum(test_acc_e)*1.0/sum(test_acc_e_t), sum(test_acc_i)*1.0/test_count))
276 |
277 |
278 |             valid_acc = sum(valid_acc_i)*1.0/valid_count  # model selection is on validation accuracy
279 |             if valid_acc >= valid_max_acc:
280 |                 print ("original best:", valid_max_acc)
281 |                 valid_max_acc = valid_acc
282 | print ("saving...", valid_max_acc)
283 | if os.path.exists(self.params['save_file']):
284 | os.remove(self.params['save_file'])
285 | torch.save(model, self.params['save_file'])
286 |                 print ("saving ok!")
287 | self.predict_joint(model)
288 |
289 | # save recursive data: format: {'id': , 'right_flag':, 'predict_result', 'ground_result'}
290 | # save joint data ['id':, 'right_flag':, predict_result, ground_result]
291 |
292 | else:
293 | print ("jumping...")
294 |
295 | print ()
296 |
297 |
298 | def train(self, model, optimizer):
299 | self.optimizer = optimizer
300 | #self._train_recur_epoch(model, 0, 0)
301 | #self._train_pointer_epoch(model, 0, 0)
302 | #self._train_joint_epoch(model,0, 0)
303 | #self._train_recur_epoch(model, 1, 1)
304 | #self._train_pointer_epoch(model, 1, 1)
305 | #self._train_joint_epoch(model,1, 1)
306 | #self._train_joint_epoch(model,0, 10)
307 |         self._train_recur_epoch(model, 0, 100)  # note: epoch range hard-coded here; params['start_epoch'] and params['n_epoch'] are not consulted
308 |
--------------------------------------------------------------------------------
/3_T-RNN_&_baselines/src/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import time
4 | import torch  # needed by encoder_hidden_process/_cat_directions below
5 | import random
6 | import pickle
7 | import sys
8 |
9 | def read_data_json(filename):
10 | with open(filename, 'r') as f:
11 | return json.load(f)
12 |
13 | def write_data_json(data, filename):
14 | with open(filename, 'w') as f:
15 | json.dump(data, f, indent=4)
16 |
17 | def save_pickle(data, filename):
18 | with open(filename, 'wb') as f:
19 | pickle.dump(data, f)
20 |
21 | def split_by_feilong_23k(data_dict):
22 | #t_path = "./data/id_ans_test"
23 | #v_path = "./data/valid_ids.json"
24 | t_path = "./data/id_ans_test_dolphin"
25 | v_path = "./data/valid_ids_dolphin.json"
26 | valid_ids = read_data_json(v_path)
27 | test_ids = []
28 | with open(t_path, 'r') as f:
29 | for line in f:
30 | test_id = line.strip().split('\t')[0]
31 | test_ids.append(test_id)
32 | train_list = []
33 | test_list = []
34 | valid_list = []
35 | for key, value in data_dict.items():
36 | if key in test_ids:
37 | test_list.append((key, value))
38 | elif key in valid_ids:
39 | valid_list.append((key, value))
40 | else:
41 | train_list.append((key, value))
42 | print (len(train_list), len(valid_list), len(test_list))
43 | return train_list, valid_list, test_list
44 |
45 | def string_2_idx_sen(sen, vocab_dict):
46 | #print(sen)
47 | return [vocab_dict[word] for word in sen]
48 |
49 | def pad_sen(sen_idx_list, max_len=115, pad_idx=1):
50 | return sen_idx_list + [pad_idx]*(max_len-len(sen_idx_list))
51 |
52 | def encoder_hidden_process(encoder_hidden, bidirectional):
53 | if encoder_hidden is None:
54 | return None
55 | if isinstance(encoder_hidden, tuple):
56 | encoder_hidden = tuple([_cat_directions(h, bidirectional) for h in encoder_hidden])
57 | else:
58 | encoder_hidden = _cat_directions(encoder_hidden, bidirectional)
59 | return encoder_hidden
60 |
61 | def _cat_directions(h, bidirectional):
62 | if bidirectional:
63 | h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)
64 | return h
65 |
66 | def post_solver(post_equ):
67 | stack = []
68 | op_list = ['+', '-', '/', '*', '^']
69 | for elem in post_equ:
70 | if elem not in op_list:
71 | op_v = elem
72 | if '%' in op_v:
73 | op_v = float(op_v[:-1])/100.0
74 | stack.append(str(op_v))
75 | elif elem in op_list:
76 | op_v_1 = stack.pop()
77 | op_v_1 = float(op_v_1)
78 | op_v_2 = stack.pop()
79 | op_v_2 = float(op_v_2)
80 | if elem == '+':
81 | stack.append(str(op_v_2+op_v_1))
82 | elif elem == '-':
83 | stack.append(str(op_v_2-op_v_1))
84 | elif elem == '*':
85 | stack.append(str(op_v_2*op_v_1))
86 | elif elem == '/':
87 | if op_v_1 == 0:
88 |                 return np.nan
89 | stack.append(str(op_v_2/op_v_1))
90 | else:
91 | stack.append(str(op_v_2**op_v_1))
92 | return stack.pop()
93 |
94 |
95 | class BinaryTree():
96 | def __init__(self):
97 | self._init_node_info()
98 |
99 | def __str__(self):
100 | return mid_order(self)
101 |
102 | def _init_node_info(self):
103 | self.root_value = None
104 | self.left_tree = None
105 | self.right_tree = None
106 | #self.height = 0
107 | self.is_leaf = True
108 | self.pre_order_list = []
109 | self.mid_order_list = []
110 | self.post_order_list = []
111 | self.node_emb = None
112 |
113 | def _pre_order(root, order_list):
114 | if root == None:
115 | return
116 | #print ('pre\t', root.value,)
117 | order_list.append(root.root_value)
118 | _pre_order(root.left_tree, order_list)
119 | _pre_order(root.right_tree, order_list)
120 |
121 | def pre_order(root):
122 | order_list = []
123 | _pre_order(root, order_list)
124 | #print ("post:\t", order_list)
125 | return order_list
126 |
127 | def _mid_order(root, order_list):
128 | if root == None:
129 | return
130 | _mid_order(root.left_tree, order_list)
131 | #print ('mid\t', root.root_value,)
132 | order_list.append(root.root_value)
133 | _mid_order(root.right_tree, order_list)
134 |
135 | def mid_order(root):
136 | order_list = []
137 | _mid_order(root, order_list)
138 | #print ("post:\t", order_list)
139 | return order_list
140 |
141 | def _post_order(root, order_list):
142 | if root == None:
143 | return
144 | _post_order(root.left_tree, order_list)
145 | _post_order(root.right_tree, order_list)
146 | #print (root.root_value, ' ->\t', root.node_emb,)
147 | order_list.append(root.root_value)
148 |
149 | def post_order(root):
150 | order_list = []
151 | #print ("post:")
152 | _post_order(root, order_list)
153 | #print ("post:\t", order_list)
154 | #print ()
155 | return order_list
156 |
157 | def construct_tree(post_equ):
158 | stack = []
159 | op_list = ['+', '-', '/', '*', '^']
160 | for elem in post_equ:
161 | node = BinaryTree()
162 | node.root_value = elem
163 | if elem in op_list:
164 | node.right_tree = stack.pop()
165 | node.left_tree = stack.pop()
166 |
167 | node.is_leaf = False
168 |
169 | stack.append(node)
170 | else:
171 | stack.append(node)
172 | return stack.pop()
173 |
174 | def form_gdtree(elem):
175 | post_equ = elem['post_template']
176 | tree = construct_tree(post_equ)
177 | gd_list = []
178 | def _form_detail(root):
179 | if root == None:
180 | return
181 | _form_detail(root.left_tree)
182 | _form_detail(root.right_tree)
183 | #if root.left_tree != None and root.right_tree != None:
184 | if root.is_leaf == False:
185 | gd_list.append(root)
186 | _form_detail(tree)
187 | #print ('+++++++++')
188 | #print (gd_list)
189 | #print (post_equ)
190 | #for elem in gd_list:
191 | # print (post_order(elem))
192 | # print (elem.root_value)
193 | #print ('---------')
194 | #print ()
195 | return gd_list[:]
196 |
197 | def construct_tree_opblank(post_equ):
198 | stack = []
199 | op_list = ['', '^']
200 | for elem in post_equ:
201 | node = BinaryTree()
202 | node.root_value = elem
203 | if elem in op_list:
204 | node.right_tree = stack.pop()
205 | node.left_tree = stack.pop()
206 |             right_len = len(post_order(node.right_tree))  # compare subtree sizes, not the lists themselves
207 |             left_len = len(post_order(node.left_tree))
208 |             if left_len <= right_len:
209 | node.temp_value = node.left_tree.temp_value
210 | else:
211 | node.temp_value = node.right_tree.temp_value
212 | node.is_leaf = False
213 |
214 | stack.append(node)
215 | else:
216 | node.temp_value = node.root_value
217 | stack.append(node)
218 | return stack.pop()
219 |
220 | def form_seqtree(seq_tree):
221 | #post_equ = elem[u'mask_post_equ_list'][2:]
222 | #tree = construct_tree(post_equ)
223 | seq_list = []
224 | def _form_detail(root):
225 | if root == None:
226 | return
227 | _form_detail(root.left_tree)
228 | _form_detail(root.right_tree)
229 | #if root.left_tree != None and root.right_tree != None:
230 | if root.is_leaf == False:
231 | seq_list.append(root)
232 | _form_detail(seq_tree)
233 | return seq_list[:]
234 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Soumava Banerjee
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | ## Experimental Evaluation of Math Word Problem Solver
2 |
3 | ### Our contributions:
4 | - Creation of a new English corpus of arithmetic word problems dealing with {+, -, *, /} and linear variables only. We named it Dolphin300; it is a subset of the publicly available Dolphin18k.
5 | - Creation of equation templates and normalization of equations on par with the Math23K dataset [1].
6 | - Experimental evaluation of T-RNN and retrieval baselines on Math23K, Dolphin300 and Dolphin1500.
7 |
8 | ## Sample data processing and cleaning:
9 |
10 | - What is the value of five times the sum of twice of three-fourths and nine? ==> What is the value of 5 times the sum of 2 of 3/4 and 9?
11 | - please help with math! What percentage of 306,800 equals eight thousands? ==> please help with math! What percentage of 306800 equals 8000?
12 | - help!!!!!!!(please) i cant figure this out!? what is the sum of 4 2/5 and 17 3/7 ? ==> help!!!!!!!(please) i cant figure this out!? what is the sum of 22/5 and 122/7 ?
13 | - math homework help? question: 15 is 25% of what number? ==> math homework help? question: 15 is 25/100 of what number?
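
The rule-based logic behind these transformations lives in 2_Data_cleaning/MWP_DataCleaning.py. As a minimal illustrative sketch (not the repository's actual rules), a normalizer of this kind maps number words and simple fractions to digits and strips thousands separators:

```python
import re

# Hypothetical mini-normalizer; the real rules are in 2_Data_cleaning/MWP_DataCleaning.py.
WORD2NUM = {
    'eight thousands': '8000',
    'three-fourths': '3/4',
    'twice': '2',
    'five': '5',
    'nine': '9',
}

def normalize(text):
    # Replace longer phrases first so 'eight thousands' is not split by a shorter rule.
    for phrase in sorted(WORD2NUM, key=len, reverse=True):
        text = re.sub(r'\b' + re.escape(phrase) + r'\b', WORD2NUM[phrase], text)
    # Drop thousands separators: 306,800 -> 306800.
    return re.sub(r'(?<=\d),(?=\d{3})', '', text)

print(normalize('What is the value of five times the sum of twice of three-fourths and nine?'))
# -> What is the value of 5 times the sum of 2 of 3/4 and 9?
```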
14 |
15 | ## List of folders:
16 | - Web scraping: contains the code (OriginalDataExtractor.py) to scrape the math word problems from Yahoo Answers. Basic data cleaning is also carried out (CleanVersionExtractor.py) to get the questions into the desired format.
17 | - Data_Cleaning: contains the code for cleaning the Dolphin dataset. MWP_DataCleaning.py holds all the rule-based and filtering logic for transforming the candidate Dolphin problems into cleaned templates. Inside the cleaned_data_examples folder, uncleaned_dolphin_data.csv contains the raw data from the Dolphin dataset, and filtered_cleaned_dolphin_data.json contains the cleaned templates filtered out of that csv.
18 | - T-RNN and baselines: contains the T-RNN code and baseline models; the output folder within contains runs of the retrieval model (named pred_*.txt) and runs of T-RNN (named run_*.txt).
19 |
20 |
21 | ## Implementation:
22 |
23 | - Implemented in a Python >=3.6 environment with PyTorch.
24 | - We used part of the T-RNN code [1] and added further implementations for Math23K.
25 | - We used replicate.py and MWP_DataCleaning.py to replicate the data and to process the raw, noisy Dolphin18k data (example commands follow this list).
26 | - Finally, we obtain Dolphin300 from the cleaning pipeline and Dolphin1500 by running replicate.py on Dolphin300.
27 | - Run the T-RNN code as:
28 |       $ python 3_T-RNN_&_baselines/src/main.py
29 |   (Please see the details in the code to change input files.)
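
For the Dolphin pipeline, the replication and retrieval-baseline scripts are run the same way. A sketch, assuming the working directory is 3_T-RNN_&_baselines (both scripts read and write under ./data/):

      $ python src/replicate.py        # writes final_dolphin_data_replicate.json plus the valid/test id splits
      $ python src/retrieval_model.py  # TF-IDF retrieval baseline; writes predictions to pred.txt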
30 |
31 | ### T-RNN for Math23K
32 | - In the template prediction module, we use a pre-trained word embedding with 128 units, a two-layer Bi-LSTM with 256 hidden units as the encoder, and a two-layer LSTM with 512 hidden units as the decoder. As the optimizer, we use Adam with the learning rate set to 1e-3, β1 = 0.9 and β2 = 0.99. In the answer generation module, we use an embedding layer with 100 units and a two-layer Bi-LSTM with 160 hidden units; SGD with learning rate 0.01 and momentum factor 0.9 is used to optimize this module. In both components, the number of epochs, mini-batch size and dropout rate are set to 100, 32 and 0.5 respectively.
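
The answer-generation optimizer is the one set up in 3_T-RNN_&_baselines/src/main.py; a self-contained sketch (with a placeholder nn.Linear standing in for the Self_ATT_RTree model built there):

```python
import torch
import torch.nn as nn

model = nn.Linear(100, 160)  # placeholder; src/main.py builds the Self_ATT_RTree model here

# SGD with lr=0.01 and momentum=0.9, as used for the answer generation module
optimizer = torch.optim.SGD(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=0.01, momentum=0.9, dampening=0.0)
```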
33 |
34 | ### T-RNN for Dolphin1500 & Dolphin300
35 |
36 | - Template prediction module: 128 embedding units, a two-layer Bi-LSTM with 256 hidden units as the encoder, and a two-layer LSTM with 512 hidden units as the decoder.
37 |   Adam optimizer with default parameters.
38 |   Answer generation module: an embedding layer with 100 units, a two-layer Bi-LSTM with 160 hidden units, and RNN classes = 4. SGD with learning rate 0.01 and momentum factor 0.9 is used to optimize this module. The number of epochs, mini-batch size and dropout rate are set to 50, 32 and 0.5 respectively. These settings map directly onto the dictionaries in src/main.py, as sketched below.
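
These correspond to the `params` and `encode_params` dictionaries in src/main.py; the Dolphin values are shown below, with the Math23K alternatives noted in comments (per the inline comments in main.py):

```python
# Dolphin configuration, as set in 3_T-RNN_&_baselines/src/main.py
params = {
    "batch_size": 32,
    "start_epoch": 1,
    "n_epoch": 50,      # 100 for Math23K
    "rnn_classes": 4,   # 5 for Math23K
    "save_file": "model_xxx_att.pkl",
}
encode_params = {
    "emb_size": 100,
    "hidden_size": 160,
    "input_dropout_p": 0.2,
    "dropout_p": 0.5,
    "n_layers": 2,
    "bidirectional": True,
    "rnn_cell": None,          # falls back to the cell named below
    "rnn_cell_name": "lstm",
    "variable_lengths_flag": True,
}
```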
39 |
40 |
41 | References:
42 |
43 | [1] Lei Wang, Dongxiang Zhang, Jipeng Zhang, Xing Xu, Lianli Gao, Bingtian Dai, and Heng Tao Shen. Template-based math word problem solvers with recursive neural networks. 2019.
44 |
45 | [2] Yan Wang, Xiaojiang Liu, and Shuming Shi. Deep neural solver for math word problems. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pages 845–854, 2017.
46 |
--------------------------------------------------------------------------------