├── 1_WebScraping_Dolphin ├── CleanVersionExtractor.py └── OriginalDataExtractor.py ├── 2_Data_cleaning ├── Initial_Data_Cleaning.ipynb ├── MWP_DataCleaning.py └── cleaned_data_examples │ ├── filtered_cleaned_dolphin_data.json │ └── uncleaned_dolphin_data.csv ├── 3_T-RNN_&_baselines ├── output │ ├── pred_dolphin_rep_tfidf.txt │ ├── pred_dolphin_tfidf.txt │ ├── pred_math23k_tfidf.txt │ ├── run_dolphin.txt │ ├── run_dolphin_rep.txt │ └── run_math23k.txt └── src │ ├── data_loader.py │ ├── main.py │ ├── model.py │ ├── replicate.py │ ├── retrieval_model.py │ ├── trainer.py │ └── utils.py ├── LICENSE.md └── readme.md /1_WebScraping_Dolphin/CleanVersionExtractor.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import difflib 3 | import pickle 4 | import json 5 | import threading 6 | import sys 7 | import getopt 8 | 9 | reload(sys) 10 | sys.setdefaultencoding('utf8') 11 | 12 | 13 | def generate_clean_question_file(diff_file, raw_file, out_clean_question_file): 14 | pkl_diff = open(diff_file, 'rb') 15 | records = pickle.load(pkl_diff) 16 | word_prob_groups = [] 17 | fp = open(raw_file) 18 | word_probs_datas = json.load(fp, encoding='utf8') 19 | for word_prob_data in word_probs_datas: 20 | qid = word_prob_data["id"] 21 | # if qid.find("20130905131554AArnMlQ") != -1: 22 | # print "debug" 23 | original_text = word_prob_data["original_text"] 24 | text = original_text 25 | diff_oprs = records[qid] 26 | for tag, i1, i2, j1, j2, ins_text in reversed(diff_oprs): 27 | if tag == 'delete': 28 | text = text[0:i1] + text[i2:] 29 | elif tag == 'insert': 30 | text = text[0:i1] + ins_text + text[i2:] 31 | elif tag == 'replace': 32 | text = text[0:i1] + ins_text + text[i2:] 33 | word_prob_data["text"] = text 34 | word_prob_groups.append(word_prob_data) 35 | fp.close() 36 | fp = open(out_clean_question_file, 'w') 37 | json.dump(word_prob_groups, fp, indent=2, ensure_ascii=False) 38 | fp.close() 39 | 40 | 41 | if __name__ == 
"__main__": 42 | original_file, diff_file, out_file = "", "", "" 43 | try: 44 | opts, args = getopt.getopt(sys.argv[1:], 'i:d:o:') 45 | except getopt.GetoptError: 46 | print 'invalid input format' 47 | sys.exit(2) 48 | for opt, arg in opts: 49 | if opt == "-i": 50 | original_file = arg 51 | elif opt == "-d": 52 | diff_file = arg 53 | elif opt == "-o": 54 | out_file = arg 55 | original_file = "eval_original.json" 56 | diff_file = "eval_diff.pkl" 57 | out_file = "eval_cleaned.json" 58 | generate_clean_question_file(diff_file, original_file, out_file) 59 | -------------------------------------------------------------------------------- /1_WebScraping_Dolphin/OriginalDataExtractor.py: -------------------------------------------------------------------------------- 1 | import urllib 2 | import difflib 3 | import pickle 4 | import json 5 | import threading 6 | import sys 7 | import getopt 8 | 9 | reload(sys) 10 | sys.setdefaultencoding('utf8') 11 | 12 | 13 | def getHtml(url): 14 | page = urllib.urlopen(url) 15 | html = page.read() 16 | return html 17 | 18 | 19 | def extract_raw_problom(str_html): 20 | meta_problem_title = "" 23 | str_content = "" 24 | i_problem_title_start = str_html.find(meta_problem_title) 25 | if i_problem_title_start != -1: 26 | i_problem_title_end = str_html.find(problem_text_end, i_problem_title_start) 27 | str_problem_title = str_html[i_problem_title_start + len(meta_problem_title):i_problem_title_end] 28 | if str_problem_title != "": 29 | str_content += str_problem_title + "\n" 30 | i_problem_text_start = str_html.find(meta_problem_text) 31 | if i_problem_text_start != -1: 32 | i_problem_text_end = str_html.find(problem_text_end, i_problem_text_start) 33 | str_problem_text = str_html[i_problem_text_start + len(meta_problem_text):i_problem_text_end] 34 | if str_problem_text != "": 35 | str_content += str_problem_text 36 | str_content = str_content.replace("&#39;", "\'").replace("&#92;", "\\").replace("&lt;", "<").replace( 37 | "&gt;", 
">").replace("&quot;", "\"") 38 | str_content = str_content.replace("&", "&") 39 | return str_content 40 | 41 | 42 | def is_page_not_found(str_content): 43 | if str_content.find("This question does not exist or is under review.") != -1: 44 | return True 45 | return False 46 | 47 | 48 | '''step 1: TODO multi-thread''' 49 | 50 | 51 | def generate_raw_question_file(input_file, out_raw_question_file): 52 | word_prob_groups = [] 53 | fp = open(input_file, 'r') 54 | word_probs_datas = json.load(fp) 55 | index = 0 56 | for word_prob_data in word_probs_datas: 57 | url = word_prob_data["id"].replace("yahoo.answers.", "") 58 | print url 59 | url = "https://answers.yahoo.com/question/index?qid=" + url 60 | html = getHtml(url) 61 | # fpage = open("dev_pages\\" + str(index) + ".txt") 62 | # html = fpage.read() 63 | str_content = extract_raw_problom(html) 64 | word_prob_data["original_text"] = str_content 65 | word_prob_groups.append(word_prob_data) 66 | index += 1 67 | fp.close() 68 | fp = open(out_raw_question_file, 'w') 69 | json.dump(word_prob_groups, fp, indent=2, ensure_ascii=False) 70 | fp.close() 71 | 72 | 73 | def assign_original_text(qids, qid2index, word_probs_datas): 74 | for qid in qids: 75 | print qid 76 | url = qid.replace("yahoo.answers.", "") 77 | url = "https://answers.yahoo.com/question/index?qid=" + url 78 | html = getHtml(url) 79 | # fpage = open("dev_pages\\" + str(qid2index[qid]) + ".txt") 80 | # html = fpage.read() 81 | str_content = extract_raw_problom(html) 82 | word_probs_datas[qid2index[qid]]["original_text"] = str_content 83 | 84 | 85 | def generate_raw_question_file_multi_thread(input_file, out_raw_question_file, ithread=10): 86 | qid2index = {} 87 | fp = open(input_file, 'r') 88 | word_probs_datas = json.load(fp) 89 | index = 0 90 | split_qids = [] 91 | for i in range(0, ithread): 92 | split_qids.append([]) 93 | for word_prob_data in word_probs_datas: 94 | qid = word_prob_data["id"] 95 | qid2index[qid] = index 96 | split_qids[index % 
ithread].append(qid) 97 | index += 1 98 | fp.close() 99 | threads = [] 100 | for i in range(0, ithread): 101 | # tid = threading.Thread(target=assign_original_text, args=(split_qids[i], qid2index, word_probs_datas, ) ) 102 | # tid.start() 103 | try: 104 | tid = threading.Thread(target=assign_original_text, args=(split_qids[i], qid2index, word_probs_datas,)) 105 | tid.start() 106 | threads.append(tid) 107 | except: 108 | print "Error: unable to start thread" 109 | while True: 110 | is_alive = False 111 | for tid in threads: 112 | if tid.isAlive(): 113 | is_alive = True 114 | break 115 | if is_alive == False: 116 | break 117 | fp = open(out_raw_question_file, 'w') 118 | json.dump(word_probs_datas, fp, indent=2, ensure_ascii=False) 119 | fp.close() 120 | 121 | 122 | '''temp, to be deleted''' 123 | 124 | 125 | def generate_diff_file(json_file, out_diff_file): 126 | id2text = {} 127 | fp = open("dev_dataset_full.json") 128 | word_probs_datas = json.load(fp, encoding='utf8') 129 | for word_prob_data in word_probs_datas: 130 | qid = word_prob_data["id"] 131 | text = word_prob_data["text"] 132 | id2text[qid] = text 133 | fp.close() 134 | fp = open(json_file) 135 | word_probs_datas = json.load(fp, encoding='utf8') 136 | records = {} 137 | for word_prob_data in word_probs_datas: 138 | qid = word_prob_data["id"] 139 | # text = word_prob_data["text"] 140 | text = id2text[qid] 141 | original_text = word_prob_data["original_text"] 142 | matcher = difflib.SequenceMatcher(None, original_text, text) 143 | new_record = [] 144 | for rec in matcher.get_opcodes(): 145 | if rec[0] == 'equal': 146 | continue 147 | if rec[0] == 'insert' or rec[0] == 'replace': 148 | rec = rec + (text[rec[3]:rec[4]],) 149 | else: 150 | rec = rec + ("",) 151 | new_record.append(rec) 152 | records[qid] = new_record 153 | fp.close() 154 | output = open(out_diff_file, 'wb') 155 | pickle.dump(records, output) 156 | output.close() 157 | 158 | 159 | def validate(gold_file, merge_file): 160 | fp = open(gold_file) 
161 | word_probs_datas = json.load(fp) 162 | id2text, id2ori = {}, {} 163 | for word_prob_data in word_probs_datas: 164 | qid = word_prob_data["id"] 165 | text = word_prob_data["text"] 166 | if word_prob_data.has_key("original_text"): 167 | original_text = word_prob_data["original_text"] 168 | id2ori[qid] = original_text 169 | id2text[qid] = text 170 | fp.close() 171 | fp1 = open(merge_file) 172 | word_probs_data_merge = json.load(fp1) 173 | for word_prob_data in word_probs_data_merge: 174 | qid = word_prob_data["id"] 175 | text = word_prob_data["text"] 176 | if word_prob_data.has_key("original_text"): 177 | original_text = word_prob_data["original_text"] 178 | if id2ori.has_key(qid) and original_text != id2ori[qid]: 179 | print qid 180 | if text != id2text[qid]: 181 | print qid 182 | fp1.close() 183 | 184 | 185 | if __name__ == "__main__": 186 | ''' 187 | fp = open("eval_dataset_full.json") 188 | word_probs_datas = json.load(fp) 189 | for word_prob_data in word_probs_datas: 190 | if word_prob_data.has_key("type"): 191 | del word_prob_data["type"] 192 | fp = open("eval_urls1.json", 'w') 193 | json.dump(word_probs_datas, fp, indent = 2, ensure_ascii=False) 194 | fp.close() 195 | ''' 196 | 197 | url_file, out_file = "", "" 198 | try: 199 | opts, args = getopt.getopt(sys.argv[1:], 'i:t:o:') 200 | except getopt.GetoptError: 201 | print 'invalid input format' 202 | sys.exit(2) 203 | for opt, arg in opts: 204 | if opt == "-i": 205 | url_file = arg 206 | elif opt == "-t": 207 | i_thread = int(arg) 208 | elif opt == "-o": 209 | out_file = arg 210 | url_file = "eval_urls.json" 211 | out_file = "eval_original.json" 212 | i_thread = 10 213 | generate_raw_question_file_multi_thread(url_file, out_file, i_thread) 214 | # config = "dev" 215 | # generate_diff_file(config + "_original.json", config + "_diff.pkl") 216 | -------------------------------------------------------------------------------- /2_Data_cleaning/Initial_Data_Cleaning.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Initial Data Cleaning.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "toc_visible": true, 10 | "machine_shape": "hm" 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | } 16 | }, 17 | "cells": [ 18 | { 19 | "cell_type": "markdown", 20 | "metadata": { 21 | "id": "ihfXrF0zE-5H", 22 | "colab_type": "text" 23 | }, 24 | "source": [ 25 | "Import Libraries\n", 26 | "--" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "metadata": { 32 | "id": "ivCIhm4zyRyI", 33 | "colab_type": "code", 34 | "outputId": "69d1f301-789c-4c09-8059-bb9a72a53c1a", 35 | "colab": { 36 | "base_uri": "https://localhost:8080/", 37 | "height": 35 38 | } 39 | }, 40 | "source": [ 41 | "import json, csv\n", 42 | "import numpy as np\n", 43 | "import pandas as pd\n", 44 | "import re\n", 45 | "import warnings\n", 46 | "\n", 47 | "warnings.filterwarnings('ignore')\n", 48 | "\n", 49 | "print(\"Libraries Imported\")" 50 | ], 51 | "execution_count": 1, 52 | "outputs": [ 53 | { 54 | "output_type": "stream", 55 | "text": [ 56 | "Libraries Imported\n" 57 | ], 58 | "name": "stdout" 59 | } 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "metadata": { 65 | "id": "mvxifhcnAx6h", 66 | "colab_type": "code", 67 | "colab": {} 68 | }, 69 | "source": [ 70 | "def getText(doc):\n", 71 | " doc = str(doc)\n", 72 | " doc = doc.lower().strip()\n", 73 | " doc = re.sub('\\n', ' ', doc)\n", 74 | " doc = re.sub(r'\\s+', ' ', doc)\n", 75 | " m = re.search(r'',doc)\n", 76 | " m1 = re.search(r'',doc)\n", 77 | "\n", 78 | " if m != None and m1!= None:\n", 79 | " text = str(m.group(1)) + ' ' + str(m1.group(1))\n", 80 | " else:\n", 81 | " text = \"No match\"\n", 82 | "\n", 83 | " return text" 84 | ], 85 | "execution_count": 0, 86 | "outputs": [] 87 | }, 88 | { 89 | "cell_type": "markdown", 90 | 
"metadata": { 91 | "id": "8EXob7wzFNcz", 92 | "colab_type": "text" 93 | }, 94 | "source": [ 95 | "Preparing Datasets\n", 96 | "--" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "metadata": { 102 | "id": "YgcH2x0NY6Rz", 103 | "colab_type": "code", 104 | "colab": {} 105 | }, 106 | "source": [ 107 | "data = pd.read_json('eval_cleaned.json')\n", 108 | "\n", 109 | "for i, row in data.iterrows():\n", 110 | " if re.match(r\"^\\\"/>\n", 163 | "Int64Index: 6625 entries, 19 to 14722\n", 164 | "Data columns (total 4 columns):\n", 165 | "text 6625 non-null object\n", 166 | "ans 6625 non-null object\n", 167 | "equations 6625 non-null object\n", 168 | "unknowns 6625 non-null object\n", 169 | "dtypes: object(4)\n", 170 | "memory usage: 258.8+ KB\n" 171 | ], 172 | "name": "stdout" 173 | } 174 | ] 175 | }, 176 | { 177 | "cell_type": "markdown", 178 | "metadata": { 179 | "id": "l_j4sf8UFRAO", 180 | "colab_type": "text" 181 | }, 182 | "source": [ 183 | "Data Cleaning\n", 184 | "--" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "metadata": { 190 | "id": "PdhRZ_R1FTCI", 191 | "colab_type": "code", 192 | "outputId": "136792b6-a0ff-4cc7-b1d7-5ae7cf3f4d62", 193 | "colab": { 194 | "base_uri": "https://localhost:8080/", 195 | "height": 69 196 | } 197 | }, 198 | "source": [ 199 | "import nltk\n", 200 | "nltk.download('stopwords')\n", 201 | "\n", 202 | "from nltk.corpus import stopwords\n", 203 | "\n", 204 | "import spacy\n", 205 | "\n", 206 | "# import string\n", 207 | "\n", 208 | "import re\n", 209 | "\n", 210 | "nlp = spacy.load(\"en\")\n", 211 | "\n", 212 | "nltk_stopwords = set(stopwords.words('english'))\n", 213 | "\n", 214 | "spacy_stopwords = nlp.Defaults.stop_words\n", 215 | "\n", 216 | "stopset = nltk_stopwords.union(spacy_stopwords)\n", 217 | "\n", 218 | 
"stopset.difference_update([\"a\",\"more\",\"less\",\"than\",\"one\",\"two\",\"three\",\"four\",\"five\",\"six\",\"seven\",\"eight\",\"nine\",\"ten\",\"eleven\",\"twelve\",\"fifteen\",\"twenty\",\"forty\",\"sixty\",\"fifty\",\"hundred\",\"once\",\"first\",\"second\",\"third\"])\n", 219 | "\n", 220 | "punctuation = \"!\\\"#$&',;?@\\_`{|}~\"\n", 221 | "\n", 222 | "def getText(doc):\n", 223 | " doc = str(doc)\n", 224 | " doc = doc.lower().strip()\n", 225 | " doc = re.sub('\\n', ' ', doc)\n", 226 | " doc = re.sub(r'\\s+', ' ', doc)\n", 227 | " m = re.search(r'',doc)\n", 228 | " m1 = re.search(r'',doc)\n", 229 | "\n", 230 | " if m != None and m1!= None:\n", 231 | " text = str(m.group(1)) + ' ' + str(m1.group(1))\n", 232 | " else:\n", 233 | " text = \"No match\"\n", 234 | "\n", 235 | " return text\n", 236 | "\n", 237 | "\n", 238 | "def cleanData(doc):\n", 239 | " doc = str(doc)\n", 240 | " doc = doc.lower().strip()\n", 241 | " doc = re.sub('\\n', ' ', doc)\n", 242 | " doc = re.sub(r'\\s+', ' ', doc)\n", 243 | " pattern = '\"/>\n", 330 | "Int64Index: 6487 entries, 19 to 14722\n", 331 | "Data columns (total 5 columns):\n", 332 | "text 6487 non-null object\n", 333 | "ans 6487 non-null object\n", 334 | "equations 6487 non-null object\n", 335 | "unknowns 6487 non-null object\n", 336 | "cleaned_text 6487 non-null object\n", 337 | "dtypes: object(5)\n", 338 | "memory usage: 304.1+ KB\n" 339 | ], 340 | "name": "stdout" 341 | } 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": { 347 | "id": "OpkqG1aVO2Z5", 348 | "colab_type": "text" 349 | }, 350 | "source": [ 351 | "Data Modelling (Archieve)\n", 352 | "--" 353 | ] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "metadata": { 358 | "id": "WvbmU8_8XLBX", 359 | "colab_type": "code", 360 | "outputId": "c7fd0123-7402-4bdd-e6ff-7d2fcdb47475", 361 | "colab": { 362 | "base_uri": "https://localhost:8080/", 363 | "height": 399 364 | } 365 | }, 366 | "source": [ 367 | "import pandas as pd\n", 368 | "\n", 369 | 
"data = pd.read_csv('new_cleaned_data.csv')\n", 370 | "\n", 371 | "from sklearn.model_selection import train_test_split\n", 372 | "\n", 373 | "trainData, testData = train_test_split(data, test_size = 0.2)\n", 374 | "\n", 375 | "trainData.rename(columns={'cleaned text': 'cleaned_text'}, inplace=True)\n", 376 | "\n", 377 | "testData.rename(columns={'cleaned text': 'cleaned_text'}, inplace=True)\n", 378 | "\n", 379 | "trainData = trainData.reset_index(drop=True)\n", 380 | "\n", 381 | "testData = testData.reset_index(drop=True)\n", 382 | "\n", 383 | "print(trainData.info())\n", 384 | "\n", 385 | "print(testData.info())" 386 | ], 387 | "execution_count": 0, 388 | "outputs": [ 389 | { 390 | "output_type": "stream", 391 | "text": [ 392 | "\n", 393 | "Int64Index: 1896 entries, 1302 to 771\n", 394 | "Data columns (total 5 columns):\n", 395 | "ans 1896 non-null object\n", 396 | "cleaned_text 1896 non-null object\n", 397 | "equations 1896 non-null object\n", 398 | "text 1896 non-null object\n", 399 | "unknowns 1895 non-null object\n", 400 | "dtypes: object(5)\n", 401 | "memory usage: 88.9+ KB\n", 402 | "None\n", 403 | "\n", 404 | "Int64Index: 474 entries, 1928 to 1444\n", 405 | "Data columns (total 5 columns):\n", 406 | "ans 474 non-null object\n", 407 | "cleaned_text 474 non-null object\n", 408 | "equations 474 non-null object\n", 409 | "text 474 non-null object\n", 410 | "unknowns 471 non-null object\n", 411 | "dtypes: object(5)\n", 412 | "memory usage: 22.2+ KB\n", 413 | "None\n" 414 | ], 415 | "name": "stdout" 416 | } 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "metadata": { 422 | "id": "52SBxzo3HKMV", 423 | "colab_type": "code", 424 | "colab": {} 425 | }, 426 | "source": [ 427 | "from sklearn.feature_extraction.text import TfidfVectorizer\n", 428 | "\n", 429 | "# print(data.info())\n", 430 | "\n", 431 | "tfidf = TfidfVectorizer(sublinear_tf=True, min_df=1, norm='l2', encoding='latin-1', ngram_range=(1, 2), stop_words= None)\n", 432 | "\n", 433 | 
"tfidf.fit(trainData['cleaned_text'])\n", 434 | "\n", 435 | "features = tfidf.transform(trainData['cleaned_text']).toarray()\n", 436 | "\n", 437 | "# test = \"Three times the first of three consecutive odd integers is 3 more than twice the third . What is the third integer ?\"\n", 438 | "\n", 439 | "# testClean = cleanData(test)\n", 440 | "\n", 441 | "# print(testClean)\n", 442 | "\n", 443 | "# test_feature = tfidf.transform([testClean]).toarray()\n", 444 | "\n", 445 | "test_features = tfidf.transform(testData['cleaned_text']).toarray()\n", 446 | "\n", 447 | "# print(test_features)\n", 448 | "\n", 449 | "# print(features)" 450 | ], 451 | "execution_count": 0, 452 | "outputs": [] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "metadata": { 457 | "id": "LzubaLyqhBJl", 458 | "colab_type": "code", 459 | "outputId": "625bbd87-2908-4c73-e70d-5d999793212c", 460 | "colab": { 461 | "base_uri": "https://localhost:8080/", 462 | "height": 433 463 | } 464 | }, 465 | "source": [ 466 | "testData['matchedQuestion'] = ''\n", 467 | "testData['matchedEq'] = ''\n", 468 | "\n", 469 | "print(trainData.info())\n", 470 | "\n", 471 | "print(testData.info())" 472 | ], 473 | "execution_count": 0, 474 | "outputs": [ 475 | { 476 | "output_type": "stream", 477 | "text": [ 478 | "\n", 479 | "RangeIndex: 1896 entries, 0 to 1895\n", 480 | "Data columns (total 5 columns):\n", 481 | "ans 1896 non-null object\n", 482 | "cleaned_text 1896 non-null object\n", 483 | "equations 1896 non-null object\n", 484 | "text 1896 non-null object\n", 485 | "unknowns 1895 non-null object\n", 486 | "dtypes: object(5)\n", 487 | "memory usage: 74.2+ KB\n", 488 | "None\n", 489 | "\n", 490 | "RangeIndex: 474 entries, 0 to 473\n", 491 | "Data columns (total 7 columns):\n", 492 | "ans 474 non-null object\n", 493 | "cleaned_text 474 non-null object\n", 494 | "equations 474 non-null object\n", 495 | "text 474 non-null object\n", 496 | "unknowns 471 non-null object\n", 497 | "matchedQuestion 474 non-null object\n", 498 | 
"matchedEq 474 non-null object\n", 499 | "dtypes: object(7)\n", 500 | "memory usage: 26.0+ KB\n", 501 | "None\n" 502 | ], 503 | "name": "stdout" 504 | } 505 | ] 506 | }, 507 | { 508 | "cell_type": "code", 509 | "metadata": { 510 | "id": "T8xvBdI6gmKo", 511 | "colab_type": "code", 512 | "colab": {} 513 | }, 514 | "source": [ 515 | "from scipy import spatial\n", 516 | "\n", 517 | "score = 0\n", 518 | "index = 0\n", 519 | "\n", 520 | "\n", 521 | "for i, row1 in testData.iterrows():\n", 522 | " score = 0\n", 523 | " for j, row2 in trainData.iterrows():\n", 524 | " similarity = 1 - spatial.distance.cosine(test_features[i], features[j])\n", 525 | " if similarity > score:\n", 526 | " score = similarity\n", 527 | " testData.at[i,'matchedQuestion'] = row2['cleaned_text']\n", 528 | " testData.at[i, 'matchedEq'] = row2['equations']" 529 | ], 530 | "execution_count": 0, 531 | "outputs": [] 532 | }, 533 | { 534 | "cell_type": "code", 535 | "metadata": { 536 | "id": "BS9iUcSVNYpD", 537 | "colab_type": "code", 538 | "colab": {} 539 | }, 540 | "source": [ 541 | "testData.info()\n", 542 | "\n", 543 | "testData.to_csv(\"cosineSimilarity.csv\", index = False)" 544 | ], 545 | "execution_count": 0, 546 | "outputs": [] 547 | }, 548 | { 549 | "cell_type": "code", 550 | "metadata": { 551 | "id": "VIrqPv3czZEw", 552 | "colab_type": "code", 553 | "colab": {} 554 | }, 555 | "source": [ 556 | "from scipy import spatial\n", 557 | "\n", 558 | "score = 0\n", 559 | "index = 0\n", 560 | "\n", 561 | "def similarity(sen1, sen2):\n", 562 | " score = np.dot(sen1, sen2)/(np.linalg.norm(sen1)*np.linalg.norm(sen2))\n", 563 | " return score\n", 564 | "\n", 565 | "for i, row1 in testData.iterrows():\n", 566 | " score = 0\n", 567 | " for j, row2 in trainData.iterrows():\n", 568 | " similarity = 1 - similarity(test_features[i], features[j])\n", 569 | " if similarity > score:\n", 570 | " score = similarity\n", 571 | " testData.at[i,'matchedQuestion'] = row2['cleaned_text']\n", 572 | " testData.at[i, 
'matchedEq'] = row2['equations']" 573 | ], 574 | "execution_count": 0, 575 | "outputs": [] 576 | }, 577 | { 578 | "cell_type": "code", 579 | "metadata": { 580 | "id": "18jlqxG3DJa3", 581 | "colab_type": "code", 582 | "outputId": "106fb3d7-d495-4b08-eeea-90864e0bd389", 583 | "colab": { 584 | "base_uri": "https://localhost:8080/", 585 | "height": 225 586 | } 587 | }, 588 | "source": [ 589 | "testData.info()\n", 590 | "\n", 591 | "testData.to_csv(\"generalSimilarity.csv\", index = False)" 592 | ], 593 | "execution_count": 0, 594 | "outputs": [ 595 | { 596 | "output_type": "stream", 597 | "text": [ 598 | "\n", 599 | "RangeIndex: 474 entries, 0 to 473\n", 600 | "Data columns (total 7 columns):\n", 601 | "ans 474 non-null object\n", 602 | "cleaned_text 474 non-null object\n", 603 | "equations 474 non-null object\n", 604 | "text 474 non-null object\n", 605 | "unknowns 471 non-null object\n", 606 | "matchedQuestion 474 non-null object\n", 607 | "matchedEq 474 non-null object\n", 608 | "dtypes: object(7)\n", 609 | "memory usage: 26.0+ KB\n" 610 | ], 611 | "name": "stdout" 612 | } 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "metadata": { 618 | "id": "smFQYt8gHHxR", 619 | "colab_type": "code", 620 | "colab": {} 621 | }, 622 | "source": [ 623 | "from math import *\n", 624 | "\n", 625 | "score = 0\n", 626 | "index = 0\n", 627 | "\n", 628 | "def jaccard_similarity(sen1, sen2):\n", 629 | " intersection = len(set.intersection(*[set(sen1), set(sen2)]))\n", 630 | "\n", 631 | " union = len(set.union(*[set(sen1), set(set2)]))\n", 632 | "\n", 633 | " score = intersection/float(union)\n", 634 | "\n", 635 | " return score\n", 636 | "\n", 637 | "for i, row1 in testData.iterrows():\n", 638 | " score = 0\n", 639 | " for j, row2 in trainData.iterrows():\n", 640 | " similarity = jaccard_similarity(test_features[i], features[j])\n", 641 | " if similarity > score:\n", 642 | " score = similarity\n", 643 | " testData.at[i,'matchedQuestion'] = row2['cleaned_text']\n", 644 | " 
testData.at[i, 'matchedEq'] = row2['equations']" 645 | ], 646 | "execution_count": 0, 647 | "outputs": [] 648 | }, 649 | { 650 | "cell_type": "code", 651 | "metadata": { 652 | "id": "hwLHfc5gNbYF", 653 | "colab_type": "code", 654 | "colab": {} 655 | }, 656 | "source": [ 657 | "testData.info()\n", 658 | "\n", 659 | "testData.to_csv(\"jaccardSimilarity.csv\", index = False)" 660 | ], 661 | "execution_count": 0, 662 | "outputs": [] 663 | } 664 | ] 665 | } -------------------------------------------------------------------------------- /2_Data_cleaning/MWP_DataCleaning.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import re 4 | import json 5 | import nltk 6 | from word2number import w2n 7 | from nltk.stem.snowball import SnowballStemmer 8 | 9 | df = pd.read_csv('./Cleaned Data/trainData_univariable.csv') 10 | 11 | df = df[np.invert(np.array(df['text'].isna()))] 12 | numMap = {"twice": 2, "double": 2, "thrice": 3, "half": "1/2", "tenth": "1/10", "quarter": "1/4", "fifth": "1/5"} 13 | fraction = {"third": "/3", "half": "/2", "fourth": "/4", "sixth": "/6", "fifth": "/5", "seventh": "/7", "eighth": "/8", 14 | "ninth": "/9", "tenth": "/10", "eleventh": "/11", "twelfth": "/12", "thirteenth": "/13", 15 | "fourteenth": "/14", "fifteenth": "/15", "sixteenth": "/16", "seventeenth": "/17", "eighteenth": "/18", 16 | "nineteenth": "/19", "twentieth": "/20", "%": "/100"} 17 | 18 | stemmer = SnowballStemmer(language='english') 19 | 20 | 21 | # Convert fractions to floating point numbers 22 | def convert_to_float(frac_str): 23 | try: 24 | return float(frac_str) 25 | except ValueError: 26 | num, denom = frac_str.split('/') 27 | try: 28 | leading, num = num.split(' ') 29 | whole = float(leading) 30 | except ValueError: 31 | whole = 0 32 | frac = float(num) / float(denom) 33 | return whole - frac if whole < 0 else whole + frac 34 | 35 | 36 | # Convert infix expression to suffix expression 37 | def 
postfix_equation(equ_list): 38 | stack = [] 39 | post_equ = [] 40 | op_list = ['+', '-', '*', '/', '^'] 41 | priori = {'^': 3, '*': 2, '/': 2, '+': 1, '-': 1} 42 | for elem in equ_list: 43 | if elem == '(': 44 | stack.append('(') 45 | elif elem == ')': 46 | while 1: 47 | op = stack.pop() 48 | if op == '(': 49 | break 50 | else: 51 | post_equ.append(op) 52 | elif elem in op_list: 53 | while 1: 54 | if not stack: 55 | break 56 | elif stack[-1] == '(': 57 | break 58 | elif priori[elem] > priori[stack[-1]]: 59 | break 60 | else: 61 | op = stack.pop() 62 | post_equ.append(op) 63 | stack.append(elem) 64 | else: 65 | post_equ.append(elem) 66 | while stack: 67 | post_equ.append(stack.pop()) 68 | return post_equ 69 | 70 | 71 | # Identification and filtering of univariable equations 72 | char_set = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '(', ')', '+', '-', '/', '*'} 73 | df_cleaned = pd.DataFrame() 74 | for id_, row in df.iterrows(): 75 | l, r = row['equations'].split("=", 1) 76 | lSet, rSet = set(l.replace(" ", "")), set(r.replace(" ", "")) 77 | flagl = (len(l.strip()) == 1 and not l.strip().isdigit() and len(rSet - char_set) == 0) 78 | flagr = (len(r.strip()) == 1 and not r.strip().isdigit() and len(lSet - char_set) == 0) 79 | if flagl or flagr: 80 | if flagr: 81 | row['equations'] = r + '=' + l 82 | 83 | df_cleaned = df_cleaned.append(row) 84 | 85 | k = 0 86 | 87 | numLists = {} 88 | numLists_idx = {} 89 | eqLists_idx = {} 90 | eqLists = {} 91 | texts = {} 92 | equations_List = {} 93 | final_ans = {} 94 | numListMAP = {} 95 | final_replaced_text = {} 96 | final_replaced_eq = {} 97 | final_replaced_eq_post = {} 98 | final_number_list = {} 99 | final_num_postn_list = {} 100 | numtemp_order = {} 101 | 102 | for id_, row in df_cleaned.iterrows(): 103 | # Converting fractions and ordinals to values through appropriate sub-routines and string builder ops 104 | if not bool(re.search(r'[\d]+', row['text'])): 105 | continue 106 | sb = "" 107 | numSb = "" 108 | val = 0 
109 | prevToken = "" 110 | for tokens in nltk.word_tokenize((row['text'])): 111 | try: 112 | val += w2n.word_to_num(tokens) 113 | except ValueError: 114 | if val > 0: 115 | sb = sb + " " + str(val) 116 | if tokens in fraction: 117 | sb = sb + fraction[tokens] 118 | elif stemmer.stem(tokens) in fraction: 119 | sb = sb + fraction[stemmer.stem(tokens)] 120 | else: 121 | sb = sb + " " + tokens 122 | val = 0 123 | else: 124 | if tokens in numMap.keys(): 125 | sb = sb + " " + str(numMap[tokens]) 126 | else: 127 | sb = sb + " " + tokens 128 | prevToken = tokens 129 | 130 | re.sub(' +', ' ', sb) 131 | sb = sb.strip() 132 | 133 | # Re-structure equation anomalies for normalization 134 | row['equations'] = row['equations'].replace(' ', '') 135 | eqs = row['equations'][0] 136 | 137 | for i in range(1, len(row['equations']) - 1): 138 | if row['equations'][i] == '(': 139 | if row['equations'][i - 1] == ')' or row['equations'][i - 1].isdigit(): 140 | eqs += '*' + row['equations'][i] 141 | else: 142 | eqs += row['equations'][i] 143 | elif row['equations'][i + 1].isdigit() and row['equations'][i] == ')': 144 | eqs += row['equations'][i] + "*" 145 | 146 | elif row['equations'][i] == '-' and (row['equations'][i - 1].isdigit() or row['equations'][i - 1] == ')') and \ 147 | row['equations'][i + 1].isdigit(): 148 | eqs += "+" + row['equations'][i] 149 | else: 150 | eqs += row['equations'][i] 151 | eqs += row['equations'][-1] 152 | 153 | # Extract different number types from the text in order 154 | numList = re.findall(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', sb) 155 | if len(numList) == 0: 156 | continue 157 | # Extract different number types from the equation in order 158 | eqNumList = re.findall(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', eqs) 159 | # Get the positions of the numbers in the text 160 | id_pattern = [m.span() for m in re.finditer(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', sb)] 161 | # Get the positions of the numbers in the equation 162 | eqid_pattern = 
[m.span() for m in re.finditer(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', eqs)] 163 | if len(set(numList)) < len(set(eqNumList)) or len(set(eqNumList) - set(numList)) != 0: 164 | continue 165 | 166 | numLists[id_] = numList 167 | numLists_idx[id_] = id_pattern 168 | eqLists_idx[id_] = eqid_pattern 169 | texts[id_] = sb 170 | eqLists[id_] = eqNumList 171 | equations_List[id_] = eqs 172 | final_ans[id_] = row['ans'] 173 | 174 | # Adding space between numbers for ease of identification 175 | for id, text in texts.items(): 176 | modified_text = '' 177 | prev = 0 178 | for (start, end) in numLists_idx[id]: 179 | modified_text += text[prev:start] + ' ' + text[start:end] + ' ' 180 | prev = end 181 | modified_text += text[end:] 182 | 183 | # Re-evaluate positions of variables in the modified text 184 | numList = re.findall(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', modified_text) 185 | if len(numList) == 0: 186 | continue 187 | id_pattern = [m.span() for m in re.finditer(r'-?[\d]* *[\d]+ *\/ *?[\d]*|-?[\d]+\.?[\d]*', modified_text)] 188 | 189 | # assert (len(numList) == len(numLists[id])) 190 | # for i in range(len(numList)): 191 | # assert (numList[i].strip() == numLists[id][i].strip()) 192 | 193 | numLists[id] = numList 194 | numLists_idx[id] = id_pattern 195 | texts[id] = modified_text 196 | 197 | # Replace questions text with template variables 198 | for id, vals in numLists.items(): 199 | numListMap = {} 200 | num_list = [] 201 | num_postion = [] 202 | k = 0 203 | for i in range(len(vals)): 204 | tmp = convert_to_float(vals[i]) if vals[i].find("/") != -1 else vals[i] 205 | if tmp not in numListMap.keys(): 206 | num_list.append(str(tmp)) 207 | num_postion.append(numLists_idx[id][i][0]) 208 | numListMap[tmp] = 'temp_' + str(k) 209 | k += 1 210 | final_number_list[id] = num_list 211 | final_num_postn_list[id] = num_postion 212 | 213 | numListMAP[id] = numListMap 214 | replaced_text = '' 215 | prev_index = 0 216 | final_num_list = [] 217 | num_temps = [] 218 | for 
i in range(len(vals)): 219 | tmp = convert_to_float(vals[i]) if vals[i].find("/") != -1 else vals[i] 220 | final_num_list.append(tmp) 221 | start = numLists_idx[id][i][0] 222 | end = numLists_idx[id][i][1] 223 | unchanged = texts[id][prev_index:start] 224 | changed = texts[id][start:end] 225 | replaced_text = replaced_text + unchanged + numListMap[tmp] 226 | num_temps.append(numListMap[tmp]) 227 | prev_index = end 228 | replaced_text += texts[id][numLists_idx[id][i][-1]:] 229 | final_replaced_text[id] = replaced_text 230 | numtemp_order[id] = num_temps 231 | 232 | # Replace equations with template variables 233 | for id, vals in eqLists.items(): 234 | equation = equations_List[id] 235 | replaced_eq = '' 236 | prev_index = 0 237 | for i in range(len(vals)): 238 | tmp = convert_to_float(vals[i]) if vals[i].find("/") != -1 else vals[i] 239 | start = eqLists_idx[id][i][0] 240 | end = eqLists_idx[id][i][1] 241 | unchanged = equation[prev_index:start] 242 | changed = equation[start:end] 243 | replaced_eq = replaced_eq + " " + unchanged + " " + numListMAP[id][tmp] 244 | prev_index = end 245 | 246 | replaced_eq += " " + equation[eqLists_idx[id][i][-1]:] 247 | 248 | list_eqs = replaced_eq.split(" ") 249 | equation_normalized = [] 250 | 251 | for i in range(len(list_eqs)): 252 | if list_eqs[i] == '': 253 | continue 254 | elif list_eqs[i].find('temp') == -1 and len(list_eqs[i]) > 1: 255 | equation_normalized.extend(list(list_eqs[i])) 256 | else: 257 | equation_normalized.append(list_eqs[i]) 258 | 259 | final_replaced_eq[id] = equation_normalized 260 | 261 | eq_norm_postfix = postfix_equation(equation_normalized) 262 | final_replaced_eq_post[id] = eq_norm_postfix 263 | 264 | # Calculate word position for numbers in text 265 | for id, text in texts.items(): 266 | words = [i.replace(' ', '') for i in text.split()] 267 | numpos = [] 268 | 269 | min_ind = {i.replace(' ', ''): float('inf') for i in final_number_list[id]} 270 | 271 | for i in range(len(words)): 272 | if words[i] in 
[j.replace(' ', '') for j in numLists[id]]: 273 | tmp = convert_to_float(words[i]) if words[i].find("/") != -1 else words[i] 274 | min_ind[str(tmp)] = min(min_ind[str(tmp)], i) 275 | 276 | for val in final_number_list[id]: 277 | numpos.append(min_ind[str(val).replace(' ', '')]) 278 | 279 | final_num_postn_list[id] = numpos 280 | 281 | # Create json dump from the updated dictionaries 282 | data = {} 283 | for id in final_replaced_text.keys(): 284 | data_template = {} 285 | data_template["template_text"] = final_replaced_text[id] 286 | data_template["expression"] = equations_List[id] 287 | data_template["mid_template"] = final_replaced_eq[id] 288 | data_template["num_list"] = final_number_list[id] 289 | data_template["index"] = str(id) 290 | data_template["numtemp_order"] = numtemp_order[id] 291 | data_template["post_template"] = final_replaced_eq_post[id][2:] 292 | if final_ans[id].find("|") != -1: 293 | data_template["ans"] = final_ans[id].split("|")[1].strip() 294 | elif final_ans[id].find("/") != -1: 295 | data_template["ans"] = str(convert_to_float(final_ans[id])) 296 | else: 297 | data_template["ans"] = final_ans[id] 298 | data_template["text"] = texts[id] 299 | data_template["num_position"] = final_num_postn_list[id] 300 | data[id] = data_template 301 | 302 | with open('final_dolphin_data.json', 'w') as f: 303 | json.dump(data, f) 304 | -------------------------------------------------------------------------------- /2_Data_cleaning/cleaned_data_examples/filtered_cleaned_dolphin_data.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": { 3 | "template_text": "what is the sum of temp_0 and temp_0 ?", 4 | "expression": "x=1+1", 5 | "mid_template": [ 6 | "x", 7 | "=", 8 | "temp_0", 9 | "+", 10 | "temp_0" 11 | ], 12 | "num_list": [ 13 | "1" 14 | ], 15 | "index": "1", 16 | "numtemp_order": [ 17 | "temp_0", 18 | "temp_0" 19 | ], 20 | "post_template": [ 21 | "temp_0", 22 | "temp_0", 23 | "+" 24 | ], 25 | "ans": "2", 26 
| "text": "what is the sum of 1 and 1 ?", 27 | "num_position": [ 28 | 5 29 | ] 30 | }, 31 | "65": { 32 | "template_text": "how many temp_0 digit numbers can you make using the digits temp_1 , temp_2 , temp_0 , and temp_3 if the hundreds digit is prime and repetition of a digit is not permitted ?", 33 | "expression": "n=2*3*2", 34 | "mid_template": [ 35 | "n", 36 | "=", 37 | "temp_2", 38 | "*", 39 | "temp_0", 40 | "*", 41 | "temp_2" 42 | ], 43 | "num_list": [ 44 | "3", 45 | "1", 46 | "2", 47 | "4" 48 | ], 49 | "index": "65", 50 | "numtemp_order": [ 51 | "temp_0", 52 | "temp_1", 53 | "temp_2", 54 | "temp_0", 55 | "temp_3" 56 | ], 57 | "post_template": [ 58 | "temp_2", 59 | "temp_0", 60 | "*", 61 | "temp_2", 62 | "*" 63 | ], 64 | "ans": "12", 65 | "text": "how many 3 digit numbers can you make using the digits 1 , 2 , 3 , and 4 if the hundreds digit is prime and repetition of a digit is not permitted ?", 66 | "num_position": [ 67 | 2, 68 | 11, 69 | 13, 70 | 18 71 | ] 72 | } 73 | } -------------------------------------------------------------------------------- /2_Data_cleaning/cleaned_data_examples/uncleaned_dolphin_data.csv: -------------------------------------------------------------------------------- 1 | text,ans,equations,unknowns 2 | a number subtracted from 17 gives the quotient of 48 and -8. find the number.,23,17 - x =48 / -8,x 3 | what is the sum of 1 and 1?,2,x = 1 + 1,x 4 | which value is a solution to 60+2x=12x+10?,5,60+2*x=12*x+10,x 5 | "phil found that the sum of twice a number and -21 is 129 greater than the opposite of the number, what is the number?",50,(2*n + (-21)) = 129 + (-1 * n),n 6 | forty-two more than a number divided by four is equal to fifty-two more than the same number divided by fourteen. 
what is the number?,56,42 + n / 4 = 52 + n / 14,n 7 | "a number when multiplied by 7/18 instead of 7/8 and got the result 770 less than the actual result , find the original number?",1584,7/18*n = 7/8 * n - 770,n 8 | "if one- third of a number is subtracted from three-fourths of that number, the difference is 15. what is the number?",36,3/4*n - 1/3*n = 15,n 9 | one third of a number is 5 less than half of the same number. what is the number?,30,1/3 * n = 1/2 * n - 5,n 10 | 3 times a number is 7 more than twice the number? what is the number?,7,3*n = 7 + 2*n,n 11 | four times a number is 36 greater than the product of the number and -2. what is the number?,6,4*n = 36 + (-2)*n,n 12 | five times a number subtracted from seven times the number is a result of 18?,9,7*n - 5*n = 18,n 13 | "six times a number equals 3 times the number , increased by 24. find the number",8,6*m = 3*m + 24,m 14 | what fraction when added to its reciprocal is equal to 13/6?,3/2 or 2/3,f + 1/f = 13/6,f 15 | if a number is added to its reciprocal the sum is 25/12. find the number.,4/3 or 3/4 | 1.333 or 0.75,n + 1/n = 25/12,n 16 | "when 60% of a number is added to the number, the result is 192. what is the number?",120,60/100*n + n = 192,n 17 | the sum of an integer and its square is 30. find the number?,5 or -6,x + x^2 = 30,x 18 | "if f(x) = x + 2, find the value of x if f(x) = 12.",10,x + 2 = 12,x 19 | "if .4% of a quantity is 2.4, find the total quantity.",600,0.4/100*q = 2.4,q 20 | "what percentage of 306,000 equals 8,200?",0.0268,x * 306000 = 8200,x 21 | 4% of what number is 34?,850,4/100*n = 34,n 22 | 30 is 75% of what number?,40,30 = 75/100*n,n 23 | what percent of 54 is 135?,2.5,x * 54 = 135,x 24 | 56 is 75% of what number?,224/3 | 74.667,56 = 75/100*x,x 25 | if 18% of a number z is 54 find the value of z.,300,18/100*z = 54,z 26 | 35% of what number is 70?,200,35/100*n = 70,n 27 | i have the number 640 and the number 80. 
i need to figure out what percentage 80 is of 640.,0.125,p*640=80,p 28 | what is the value of x?here's the equation: 3/5(x+2)=12.,18,3/5(x+2)=12,x 29 | "if 5 is decreased by 3 times a number, the result is -4",3,5 - 3*n = -4,n 30 | "if a number is doubled and the result decreased by 35, the result is 25. what is the number?",30,2*n - 35 = 25,n 31 | five times the sum of a number and eight is twenty-five. find the number.,-3,5(n + 8) = 25,n 32 | "if 6 times a number is added to 8, the result is 56. what is the number",8,8 + 6*n = 56,n 33 | the difference of 10 and the product of 6 and n is 21.,-11/6 | -1.833,10 - 6*n = 21,n 34 | three more than the product of a number and 4 is 15. find the number.,3,4*x + 3 = 15,x 35 | three less than twice what number is -7?,-2,2*x-3 = -7,x 36 | the sum of 3 times a number and 17 is 5. what is the number?,-4,3*x + 17 = 5,x 37 | "if you double a number and then add 32,the result is 80. what is the number?",24,2*n + 32 = 80,n 38 | two less than the product of a number and -5 is -42? what is the number?,8,-5 * n - 2 = -42,n 39 | two more than twice a number is equal to the number itself. find the number.,-2,2 + 2*n = n,n 40 | "when the square of an integer is added to ten times the integer, the sum is zero. what is the integer?",-10 or 0,10*n + n^2 = 0,n 41 | two consecutive odd integers have a sum of 48. what is the largest of the two integers?,25,(2*n+1)+(2*n+3) = 48,2*n+3 42 | four times an integer plus 5 is equal to the square of the integer. what is the integer?,-1 or 5,4*n + 5 = n^2,n 43 | "robert thinks of a number. when he multiples his number by 5 and subtracts 16 from the result, he gets the same answer as to when he adds 10 to his number and multiples the result by 3. find the number robert is thinking of.",23,5*x - 16 = (x + 10) * 3,x 44 | the sum of two times a number and eight is equal to three times the difference between the number and four. 
what's the number?,20,2*n + 8 = 3(n - 4),n 45 | twice the sum of a number and sixteen is five less than three times the number?find the number?,37,2(n + 16) = 3*n - 5,n 46 | find a number such that subtracting its reciprocal from the number gives 16/15.,5/3 or -3/5 | 1.667 or -0.6,n - 1/n = 16/15,n 47 | "if 240 : 3 = x : 5, what's the value of x?",400,240/3 = x/5,x 48 | "given that 7 : 2 = x : 8,find the value of x",28,7/2 = x/8,x 49 | how many eight-digit telephone numbers are possible if the first digit must be nonzero?,90000000 | 9E7,n = 9*10*10*10*10*10*10*10,n 50 | 20% of the number is 3. find the number?,15,0.2*x = 3,x 51 | 5 times a number equals 20what is the number?,4,5*n = 20,n 52 | 15 is 25% of what number?,60,15 = 0.25*x,x 53 | 36 is 60% of what number?,60,36 = 0.6*x,x 54 | four times the difference of a number and one is equal to six times the sum of the number and three. find the number.,-11,4(x - 1) = 6(x + 3),x 55 | 16 more than a number is 20 more than twice the number.,-4,16 + n = 20 + 2*n,n 56 | "when 4 is added to 5 times a number,the number increases by 50. find the number.",23/2 | 11.5,5*n + 4 = n + 50,n 57 | the sum of a number and its reciprocal is 6.41. what is the number?,6.25 or 0.16 | 25/4 or 4/25,n + 1/n = 6.41,n 58 | the sum of a number and its reciprocal is 13. find this number.,12.923 or 0.0774,n + 1/n = 13,n 59 | the sum of a number and twice its reciprocal is 3. what is the number?,2 or 1,x + 2/x = 3,x 60 | "when the number x is increased by x%, it becomes 24. find x.",20 or -120,x*(1+0.01*x) = 24,x 61 | "if a number is multiplied by 5 more than twice that number, the result is 12. find the number?",3/2 or -4 | 1.5 or -4,x * (2*x + 5) = 12,x 62 | a 104 degree angle with 2 congruent sides is called? a 104 degree angle with 2 congruent sides is called?,4 or 2,x^2 + 8 = 6*x,x 63 | the sum of a number and seven times the number is 112. 
what is the number?,14,x + 7*x = 112,x 64 | "how many 3-digit positive integers are odd and do not contain the digit ""5""?",288,n = 8 * 9 * 4,n 65 | "how many 3 digit even numbers can be made using the digits 1, 2, 3, 4, 6, 7, if no digit is repeated?",60,n=3*5*4,n 66 | "how many 3 digit numbers can you make using the digits 1, 2, 3, and 4 if the hundreds digit is prime and repetition of a digit is not permitted?",12,n = 2 * 3 * 2,n 67 | a number is 110 less than its square. find all such numbers.,-10 or 11,n = n^2 - 110,n 68 | the sum of one third of a number and its reciprocal is the same as 49 divided by the number. find the number,12 or -12,n/3 + 1/n = 49/n,n 69 | "what number, when added to the number three or multiplied to the three, gives the same result?",3/2 | 1.5,n + 3 = 3*n,n 70 | if the average of 20 different positive numbers is 20 then what is the greatest possible number among these 20 numbers?,210,(1+(20-1))*(20-1)/2 + x = 20*20,x 71 | "if 9n^2 - 30n + c is a perfect square for all integers n, what is the value of c?",25,30^2 = 4 * 9 * c,c 72 | what is the least value of k if 3x^2 + 6x + k is never negative?,3,6^2 = 4*3*k,k 73 | find the value of c that makes x^2+18x+c a perfect square trinomial.,81,18^2 = 4 * 1 * c,c 74 | find this quantity: one half of the quantity multiplied by one third of the quantity equals 24.,12 or -12,(1/2)q * (1/3)q = 24,q 75 | "if the expression x^2 - kx - 12 is equal to zero when x=4, what is the value of k?",1,4^2 - 4*k - 12 = 0,k 76 | a 3- digit number in which one digit is the average of the other two digits is called an average number. 456 is an average number because 5 is the average of 4 and 6. 
how many three digit average numbers are there?,45,n = 9 * 5,n 77 | "what is 8% of 2,000.",160,x = 2000 * 0.08,x 78 | what's 8% of 5.00?,0.4,x = 0.08 * 5,x 79 | one fourth of one third is the same as one half of what fraction?,1/6,(1/4) * (1/3) = (1/2) * x,x 80 | "when the number n is multiplied by 4, the result is the same as when 4 is added to n. what is the value of 3n?",4,4*n = 4 + n,3*n 81 | "when the number w is multiplied by 4, the result is the same as when 4 is added to w. what is the value of 3w?",4,4*w = w + 4,3*w 82 | how many whole numbers less than 500 have seven as the sum of their digits?,30,n = 1 + 7 + 7 + 6 + 5 + 4,n 83 | "if 4 less than 3 times a certain number is 2 more than the number, what is the number",3,3*x - 4 = 2 + x,x 84 | "when twice a number is decreased by 8, the result is the number increased by 7. find the number?",15,2*x - 8 = x + 7,x 85 | 3 less than 5 times a number is 11 more than the number? find the number.,7/2 | 3.5,5*n - 3 = 11 + n,n 86 | "the first three terms of an arithmetic sequence are (12-p), 2p, (4p-5) respectively where p is a constant.find the value of p",7,2*p - (12-p) = (4*p-5) - 2*p,p 87 | "if the product of -3 and the opposite of a number is decreased by 7, the result is 1 greater than the number. what is the number?",4,-3*(-1 * n) - 7 = 1 + n,n 88 | what is the difference between the sum of the first 2004 positive integers and the sum of the next 2004 positive?,4016016,d = 2004 * (1 + 2004 - 1),d 89 | the sum of 2/3 and four times a number is equal to 5/6 subtracted from five times the number. find the number.,3/2 | 1.5,2/3 + 4*n = 5*n - 5/6,n 90 | "find the difference between the product of 26/22 and 3.09 and the sum of 3,507.00, 2.08, 11.50 and 16,712.00",-22251821/1100 | -20228.928,"x = 26/22 * 3.09 - (3507.00 + 2.08 + 11.50 + 16,712.00)",x 91 | what is 5/9 times 36?,20,x = 5/9 * 36,x 92 | what number is 5 sixths of 100?,250/3 | 83.333,x = 5/6 * 100,x 93 | 5/6 of a number is 3/4. 
what is the number?,9/10 | 0.9,5/6 * n = 3/4,n 94 | what is the value of x in the following equation?3x - 4(x + 1) + 10 = 0,6,3*x - 4(x + 1) + 10 = 0,x 95 | "a whole number between 100 and 1000 is to be formed so that one of the digits is 6, and all the digits are different. how many numbers are possible.",200,x = 9*8 + 8*8 + 8*8,x 96 | how many permutations of digits in the number 12345 will make an even number?,48,n = 2 * (4 * 3 * 2 * 1),n 97 | how many 1/4 ounces are in 56 pounds?,3584,x*(1/4) = 56 * 16,x 98 | what number minus one-half is equal to negative one-half ?,0,x - 1/2 = -1/2,x 99 | 16 over 24 minus 18 over 24?,-1/12 | -0.0833,x = 16/24 - 18/24,x 100 | "the quotient of -54 and 9, subtracted from 8 times a number, is -18. what's the number?",-3,8*n - (-54)/9 = -18,n 101 | "if 30% of 40% of a positive # is equal to 20% of w% of the same #, what is the value of w?",60,30/100*40/100*n = 20/100*w/100*n,w 102 | square root six over square root 7 multiplied by square root 14 over square root 3.,2,x = (sqrt(6)/sqrt(7)) * (sqrt(14)/sqrt(3)),x 103 | "if 4 is added to the square of a composite integer, the result is 14 less than 9 times that integer. find the integer.",6,4 + n^2 = 9*n - 14,n 104 | "how many four-digit numbers can be formed using the digits 0, 1, 2, 3, 4, 5, 6, 7, 8, and 9 if the first digit cannot be 0? repeated digits are allowed.",9000,n = 9 * 10 * 10 * 10,n 105 | the equation is 10x - 19 = 9 - 4x. 
what's the value of x?,2,10*x - 19 = 9 - 4*x,x 106 | twelve less than 2 times a number is equal to 15 minus 7 times the number.,3,2*n - 12 = 15 - 7*n,n 107 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/output/pred_dolphin_rep_tfidf.txt: -------------------------------------------------------------------------------- 1 | 954 216216.0 216216.0 2 | 956 299484.0 299484.0 3 | 2015 1924.0 1924.0 4 | 976 -0.71875 -0.71875 5 | 982 5893157296026.0 5893157296026.0 6 | 985 78706.0 78706.0 7 | 999 223404480.0 223404480.0 8 | 1040 5184000.0 5184000.0 9 | 1004 274.0 274.0 10 | 600 3000.0 3000 11 | 1011 117740.0 117740.0 12 | 1012 10080.0 10080.0 13 | 686 6750.0 6750 14 | 1046 26.543513101659002 26.543513101659002 15 | 1066 625.0 625.0 16 | 1067 560.0 560 17 | 1069 66030.0 66030.0 18 | 41 0.30000000000000004 0.3 19 | 1088 87280691400.0 87280691400.0 20 | 1092 2976.0 2976.0 21 | 1095 376.0 376.0 22 | 1099 56.0 56.0 23 | 1104 3873860100.0 3873860100.0 24 | 1105 453.0 453 25 | 1118 52000.0 52000 26 | 1122 1026.0 1026.0 27 | 1133 11020.0 11020 28 | 1147 5382013074.0 5382013074.0 29 | 1157 385.0 385.0 30 | 1164 1026.0 1026.0 31 | 1173 1682.0 1682.0 32 | 1179 91931.0 91931.0 33 | 1186 418200.0 418200.0 34 | 1194 48440106.0 48440106.0 35 | 1198 6774144.0 6774144.0 36 | 1200 27440.0 27440.0 37 | 1203 112390880.0 112390880.0 38 | 1217 3562120.0 3562120.0 39 | 1230 788112.0 788112.0 40 | 1232 3162.0 3162.0 41 | 1247 225.0 225 42 | 1251 -0.5428571428571428 -0.5428571428571428 43 | 1254 254562.0 254562.0 44 | 617 360.0 360 45 | 110 2205.0 2205 46 | 1295 7851720.0 7851720.0 47 | 1296 218400.0 218400.0 48 | 1315 3146.0 3146.0 49 | 1320 103247.0 103247.0 50 | 1325 4968.0 4968.0 51 | 1331 2675.0 2675.0 52 | 1338 53.0 53.0 53 | 1340 1.4308943089430894 1.4308943089430894 54 | 1352 4567612665.0 4567612665.0 55 | 1359 5285280.0 5285280.0 56 | 1367 475716.0 475716.0 57 | 1371 702.0 702.0 58 | 1372 5985.0 5985.0 59 | 1681 8576.0 8576.0 60 | 
1392 6873.0 6873.0 61 | 1396 240.0 240 62 | 1397 48668.0 48668.0 63 | 1402 104580.0 104580.0 64 | 1422 8528085.0 8528085.0 65 | 1438 106.5 106.5 66 | 1444 4758.0 4758.0 67 | 1448 204.0 204.0 68 | 2217 13720224.0 13720224.0 69 | 1449 0.1363636363636 0.136363636364 70 | 1455 0.26954863318499683 0.26954863318499683 71 | 1458 0.8073981542854229 0.8073981542854229 72 | 1473 180.0 180 73 | 1475 10528848.0 10528848.0 74 | 1476 10142808.0 10142808.0 75 | 1488 144.0 144 76 | 1503 39900.0 39900.0 77 | 1514 280.0 280.0 78 | 1518 0.4444444444444444 0.444444444444 79 | 1526 80410.0 80410.0 80 | 1541 13910400.0 13910400.0 81 | 1548 67704.0 67704.0 82 | 1568 23.525339373123398 23.525339373123398 83 | 1576 687.0 687.0 84 | 1577 864.0 864 85 | 1592 4592.0 4592.0 86 | 1602 700.0 700.0 87 | 292 31536000.0 31536000 88 | 1632 -18.0 -18.0 89 | 1648 1936.0 1936.0 90 | 1653 2744.0 2744.0 91 | 1762 155.88888888888889 155.88888888888889 92 | 1661 228.0 228.0 93 | 1674 134.0 134.0 94 | 1687 486.0 486 95 | 1692 1232.0 1232.0 96 | 1711 22950000.0 22950000.0 97 | 1730 167.0 167.0 98 | 1731 94.0 94.0 99 | 1752 0.84 0.84 100 | 1779 -0.75 -0.75 101 | 1784 0.9037037037037037 0.9037037037037037 102 | 2316 1369.0 1369.0 103 | 1900 82.0 82.0 104 | 1919 31000486944.0 31000486944.0 105 | 1920 1890108.0 1890108.0 106 | 1922 6.0 6 107 | 1929 46746.0 46746.0 108 | 1942 205.0 205.0 109 | 1943 133.0 133.0 110 | 1950 1031510.0 1031510.0 111 | 1952 86700.0 86700.0 112 | 1955 1440.0 1440 113 | 1957 291456.0 291456.0 114 | 1965 132.0 132.0 115 | 1981 5750352.0 5750352.0 116 | 1995 131424.0 131424.0 117 | 1997 1082004.0 1082004.0 118 | 2002 100.0 100.0 119 | 2026 462384.0 462384.0 120 | 2033 37004.0 37004.0 121 | 2053 22.053571428571427 22.053571428571427 122 | 2067 1560896.0 1560896.0 123 | 385 80.0 80 124 | 2093 1649061.0 1649061.0 125 | 2099 1331.0 1331.0 126 | 2100 64.0 64.0 127 | 2103 59319.0 59319.0 128 | 2109 8.35483870967742 8.35483870967742 129 | 2118 67811.0 67811.0 130 | 2137 7605.0 7605.0 131 | 2146 
3996.0 3996.0 132 | 2167 1800.0 1800.0 133 | 2168 4004.0 4004.0 134 | 2183 4224.0 4224.0 135 | 2209 2.6382978723404253 2.6382978723404253 136 | 2242 867510.0 867510.0 137 | 2252 75548.0 75548.0 138 | 2259 1600.0 1600.0 139 | 2272 39.0 39.0 140 | 2282 17280000.0 17280000.0 141 | 455 11.0 11 142 | 2356 4.8 4.8 143 | 473 210.0 210 144 | 2360 551448.0 551448.0 145 | 2364 429975.0 429975.0 146 | 2367 213759.0 213759.0 147 | 2393 6862.0 6862.0 148 | 2394 44100.0 44100.0 149 | 2413 5720.0 5720.0 150 | 2419 0.008927575047427742 0.008927575047427742 151 | 497 3.0 3 152 | 2424 7826.0 7826.0 153 | 2427 6014.0 6014.0 154 | 2428 1479.0 1479.0 155 | 2435 33284498950.0 33284498950.0 156 | 2439 289304.0 289304.0 157 | 2465 86.11981566820276 86.11981566820276 158 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/output/pred_dolphin_tfidf.txt: -------------------------------------------------------------------------------- 1 | 1536 3150.0 60 2 | 1561 20.0 4752 3 | 37 5.0 6 4 | 2123 -6.0 2 5 | 1110 0.0 105 6 | 1551 78.0 78 7 | 2152 25088.0 0.875 8 | 1652 105.0 315 9 | 1301 24.0 120 10 | 717 1944.0 216 11 | 2278 168.0 18 12 | 247 16.0 0.033 13 | 253 12.0 80 14 | 267 75.0 0.615384615385 15 | 1826 1660.0 3160 16 | 1670 72.0 52 17 | 299 5.0 7 18 | 2140 80.0 20 19 | 1690 88.0 408 20 | 1863 -8.0 8 21 | 1370 144.0 144 22 | 455 11.0 11 23 | 1754 1.0 0.833333333333 24 | 1492 180.0 180 25 | 1621 8.0 3.42857142857 26 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/output/pred_math23k_tfidf.txt: -------------------------------------------------------------------------------- 1 | 7534 1100.0 60 2 | 7652 -0.8333333333299999 0.266666666667 3 | 23054 -124.0 43 4 | 14980 56.35673624288425 7 5 | 12837 -18.0 21 6 | 978 -178.833333333333 30 7 | 3513 4.0 4 8 | 15534 120.75 68.5714285714 9 | 12831 180.0 36 10 | 21881 799.4 240 11 | 18638 1500.0 1500 12 | 14552 0.3 0.3 13 | 21887 840.2 945 14 | 
1818 60.0 60 15 | 5902 400.0 400 16 | 17497 15.0 315 17 | 20458 16.666666666666668 25 18 | 12053 -28.125 15 19 | 11062 796000.0 190 20 | 7828 4.8 3.072 21 | 13693 1203.0 750 22 | 18174 9.0 9 23 | 6679 21.0 21 24 | 6181 250.0 250 25 | 923 4000.0 4000 26 | 21017 0.8 81920 27 | 7686 0.05714285714285714 72 28 | 10874 2369.0 2369 29 | 12148 0.7291666666643749 1 30 | 5726 23005.0 47 31 | 13009 1152.0 1152 32 | 8509 500.0 320 33 | 2339 8.2 9 34 | 12042 0.3333333333333333 6 35 | 15159 0.34594594594594597 12.75 36 | 18333 425.0 42 37 | 7395 49.999999999999986 200 38 | 6061 88014.70999999999 1449.3 39 | 7729 27.041666666666668 2174 40 | 19245 3.761904761904762 5 41 | 316 20.65853658536585 2.9 42 | 6682 3120.0 3120 43 | 15563 96.0 150 44 | 11959 20.0 20 45 | 1120 16.0 16 46 | 233 54.666666666720666 12 47 | 15688 40.0 40 48 | 17147 481.0 481 49 | 21323 245.0 320 50 | 3456 57.292775000000006 0.47 51 | 20787 0.12 0.12 52 | 14915 12.0 12 53 | 1970 15000.0 15000 54 | 12906 0.5384615384615384 0.538461538462 55 | 13206 2160.0 30 56 | 15150 -1.44 936 57 | 581 8.0 8 58 | 3416 8.53981128 8.53981128 59 | 19230 6248.0 6336 60 | 14974 4.5 11 61 | 16206 5100.0 5100 62 | 9130 -5.75 0.125 63 | 3843 49.999999999950006 50 64 | 10464 160.000000000064 160 65 | 16208 7.0 6000 66 | 5364 78500.0 10 67 | 22153 -0.0030303030303030303 320 68 | 22057 30.000000000000007 13.5 69 | 9820 100.0 100 70 | 23017 0.47619047619064286 112 71 | 14404 7.9866888519134775 180 72 | 5413 0.65 0.6 73 | 17285 485.833333333245 318 74 | 21946 755.3325 126.75 75 | 6951 135.375 150 76 | 6051 25.0 20 77 | 15757 2.9899999999999998 3.05 78 | 20445 64.0 440 79 | 12348 80.0 80 80 | 18952 13.333333333333334 12 81 | 4050 -46.199999999999996 46.2 82 | 5899 0.9090909090909091 111 83 | 4412 200.0 200 84 | 15464 0.6351791530944625 61 85 | 12813 0.16666666666675 4.29166666667 86 | 5701 19.7 15.5 87 | 16257 160.0 14 88 | 1830 61350.0 61350 89 | 17773 5.0 40 90 | 10045 36.0 9600 91 | 4640 0.42857142857099995 0.428571428571 92 | 1688 
0.14285714285714285 62 93 | 9962 90.60166666666667 120 94 | 20647 0.08256880733944955 0.09 95 | 15678 0.5 0.5 96 | 19222 -1.9 13.6 97 | 1888 1.804 6.76 98 | 18362 672.0 5600 99 | 21301 -277.21774193548384 1.2 100 | 11417 0.0 0.333333333333 101 | 13907 -34.0 65 102 | 22852 384.0 7 103 | 11276 2242.0 1229 104 | 16422 -17567.0 49 105 | 7257 77.999999999976 4 106 | 20524 3320.0 470 107 | 21203 0.6 100 108 | 6155 1.0125 81 109 | 9322 1.5 450 110 | 4129 52.881844380403464 6 111 | 12922 27.0 27 112 | 4234 36.1 36.1 113 | 7459 4200.0 170 114 | 18453 31.999999999992 18 115 | 12472 0.42857142857200003 0.428571428571 116 | 1329 160.0 160 117 | 14341 32.000000000016 32 118 | 4084 9230.0 19520 119 | 22574 1230.0 15 120 | 11350 7.0 7 121 | 22726 1008.0 28 122 | 14706 1.4 1.4 123 | 20302 28.0 28 124 | 10437 0.35714285714285715 715 125 | 18415 649.999999998 600 126 | 22015 89.28571428574999 50 127 | 22195 15.135135135135135 11 128 | 16997 13499775.0 75 129 | 17108 -0.007035175879396985 280 130 | 4289 600.0 600 131 | 21848 75.0 75 132 | 18360 576.0 400 133 | 2350 -6.442953020134228 58 134 | 3744 0.0016638935108153079 54 135 | 10838 0.09090909090909091 0.1 136 | 12421 320.0 130 137 | 22255 0.20930232558139536 80 138 | 13213 -0.15384615384615385 0.2 139 | 4716 1049.76 1.2 140 | 23146 70.0 70 141 | 18585 10.05 0.5 142 | 4504 -0.041666666666666664 28 143 | 17301 1850.0 1850 144 | 22294 23.999999999999996 24 145 | 15835 15.0 15 146 | 4118 -34.399999999999196 46 147 | 17109 280.0 96 148 | 9080 0.92 4600 149 | 23085 10.714285714286 10.7142857143 150 | 16001 2.6666666666666665 48 151 | 21186 -450.0 1200 152 | 2267 104.0 564 153 | 13141 810.00000000081 81 154 | 21632 12.5 6.5 155 | 5135 22.0 22 156 | 3589 530.0 530 157 | 4468 1920.0 4 158 | 6744 2.0996052742063494e-05 63 159 | 10394 0.4 0.4 160 | 7123 100.0 100 161 | 5205 -2.9 0.3 162 | 5089 1.1 11 163 | 10529 76.0 76 164 | 7098 672.0 672 165 | 11573 4.3 4.3 166 | 15765 45.0 225 167 | 7198 250.00000000000003 250 168 | 16277 8.5 1.5 169 | 
13466 100000.0 10 170 | 6910 27.0 120 171 | 21695 8.0 8 172 | 15334 336.000000000336 168 173 | 16628 -15.172413793103448 40 174 | 18685 0.0033333333333333335 50 175 | 3757 -1780.0 10 176 | 5144 1.1904761904761905 930 177 | 3379 1.6666666666666667 21 178 | 1536 1.111111111111111 200 179 | 6725 7.0 7 180 | 2479 9.999999999989999 80 181 | 9143 0.14 0.14 182 | 3724 -0.7629999999999999 3.88 183 | 742 -3.1578947368421053 0.25 184 | 4079 0.49 6825 185 | 22672 0.7484375 1120 186 | 19291 10.0 10 187 | 17403 1160.0 12 188 | 11231 0.4444444444444444 0.444444444444 189 | 15583 108.4 123 190 | 15263 2000.0 200 191 | 20131 -27.999999999984 20 192 | 8945 450.0 60 193 | 6827 0.14285714285714285 0.142857142857 194 | 8040 0.26785714285725 0.267857142857 195 | 22913 -100.0 35 196 | 7496 5400.0 5400 197 | 6776 409.0 169 198 | 12186 27.47222222222222 15 199 | 1675 -56.25 1350 200 | 12474 0.4444444444444444 0.222222222222 201 | 1779 60.4 360 202 | 14732 75.0 55 203 | 2332 1000.0 90 204 | 3514 33.84615384605621 300 205 | 8444 126.0 126 206 | 19806 360.0 240 207 | 3906 2810.0 6 208 | 7555 3.75 3.75 209 | 15320 -34.980000000000004 2000 210 | 16903 14112.0 392 211 | 12994 5306.0 5306 212 | 21128 30.0 30 213 | 3946 50.000000000025004 120 214 | 13680 400.0 400 215 | 15377 11.523809523809524 55 216 | 14874 -36.0 40 217 | 19586 64.0 64 218 | 22011 -0.0234375 258 219 | 13229 288.0 6 220 | 4179 -0.25 0.75 221 | 19377 0.8235294117647058 58 222 | 8453 711.0 1935 223 | 6470 5.25 200 224 | 118 3.0 3 225 | 5853 20.0 20 226 | 21398 8001.6 1250 227 | 5036 -7.6 120 228 | 1355 25.0 25 229 | 11936 10.799999999999999 30 230 | 2570 0.2666666666666 360 231 | 13439 7.0 7 232 | 7654 6.0 6 233 | 3784 -57.142857142857146 200 234 | 16486 -2205.000000000014 37 235 | 10750 45.0 4.5 236 | 8438 0.49166666666666664 900 237 | 9681 650140.0 65 238 | 12868 8.888888888888857 80 239 | 8067 -12.0 10 240 | 21421 1.7777777777777777 1.77777777778 241 | 14754 100.8 115.2 242 | 9078 3.2129396788997536e-05 0 243 | 17140 1517472.0 
2246 244 | 7722 5625.0 20 245 | 2791 41.0 98 246 | 6648 226.66666666666663 40 247 | 3775 144.0 144 248 | 15793 0.16666666666666666 0.166666666667 249 | 1466 0.2375 84.375 250 | 3501 -4.008016032064128 0.2 251 | 20781 1280.0 2000 252 | 10980 25967.5 6110 253 | 1393 200.0 200 254 | 771 21.0 21 255 | 22667 1.150537634408602 3.6 256 | 15748 1.6 1.6 257 | 15803 0.017361111111111112 36000 258 | 19868 2360.0 344 259 | 19390 -0.0011111111111095238 77.7777777778 260 | 15140 44.444444444488894 60 261 | 5859 20.0 20 262 | 1705 159.0 208 263 | 1411 224.0 800 264 | 22239 -0.039473684210526314 15 265 | 15208 140.0 8120 266 | 12465 128.40000000000597 0.206896551724 267 | 1789 0.10641025641025642 1.56 268 | 19129 6.24 0.1 269 | 14906 0.00019994001799460164 125 270 | 17781 -0.030454318522216674 34 271 | 20076 49.000000000006125 64 272 | 18326 8.0 8 273 | 14762 810.0000000000002 9 274 | 20828 2400.0 2400 275 | 18063 14.0 14 276 | 14100 0.4 2.4 277 | 1538 18.181818181818183 8 278 | 17983 215.99999999987295 216 279 | 15937 488.5714285714286 488.571428571 280 | 9111 -0.36 2.2 281 | 1302 432.0 432 282 | 14281 2.8 2.8 283 | 9252 0.0 27.75 284 | 253 5.0 5 285 | 9208 2.2857142857139183 180 286 | 17422 3605400.0 10 287 | 10328 3.0 15 288 | 2747 0.85 6 289 | 10869 1080.6779999999999 1998 290 | 2195 300.0 7 291 | 1595 1896.0 2121 292 | 21239 67.5 67.5 293 | 4664 270.0 270 294 | 8033 8.0 8 295 | 6953 4.694444444444445 8 296 | 11657 -0.759493670886076 0.25 297 | 17902 4800.0 1800 298 | 7911 -288.0 3 299 | 19018 0.8042328042326032 75 300 | 1592 7.1000000000000005 8.7 301 | 17266 18.0 18 302 | 22981 36.0 36 303 | 21855 450.0 43 304 | 21575 2250.0 9000 305 | 12587 48.0 87 306 | 1911 839.99999999937 840 307 | 10565 0.2 0.2 308 | 4999 0.37142857142799995 0.571428571429 309 | 6152 0.5 0.5 310 | 18609 40.0 40 311 | 14332 -1.4933333333324 9 312 | 3439 -0.5833333333335 0.833333333333 313 | 13381 42.0 42 314 | 17611 1680.0 400 315 | 3478 8000.0 264 316 | 3852 1.7 12 317 | 17504 8.0 18 318 | 12505 
1.4666666666666668 0.7 319 | 8452 4000.0 40 320 | 17850 8.0 56 321 | 6997 432.222222222222 264 322 | 14179 200.0 162 323 | 2894 -9.0 28 324 | 11644 11.0 11 325 | 643 1119.0 867 326 | 18153 459.0 459 327 | 12969 45946.56 252 328 | 2337 0.05 0.95 329 | 10407 0.30000000000000027 0.5 330 | 6439 0.525 700 331 | 14142 0.96 0.04 332 | 19900 -6.0 2 333 | 22332 86.4 86.4 334 | 7086 34300.0 28 335 | 12537 5.0 0.1 336 | 13501 2.9699999999999998 2.4 337 | 2141 41.15384615384615 840 338 | 7207 3.0 50 339 | 20436 0.01 4.5 340 | 21726 12.100000000000001 12.1 341 | 9828 -0.0011129660545358931 1500 342 | 3115 -3392.0 73 343 | 7066 0.002802144249513158 216 344 | 18232 0.8333333333325 0.266666666667 345 | 17668 6.0 22 346 | 514 3.0 3 347 | 14444 0.5333333333336 150 348 | 19741 8.7 8.7 349 | 2732 240.0 135 350 | 18517 5.0 5 351 | 12797 0.19999999999760004 18 352 | 9856 504.0 48 353 | 2466 15.333333333333 10 354 | 17388 3600.0 3600 355 | 9650 31.0 31 356 | 1988 -5634.67 130.13 357 | 20281 0.09999999999999998 0.1 358 | 842 280721.0 5 359 | 8137 213.54545454545456 26100 360 | 10287 -0.0026041666666666665 392 361 | 9251 35.555555555519994 360 362 | 2866 5.0 90 363 | 16212 2.0 2 364 | 4197 1185230.0 490 365 | 5069 4037.04 6.1 366 | 9566 0.13793103448275862 170 367 | 15995 10.0 10 368 | 18319 122.4 489.6 369 | 853 40.0 40 370 | 11135 301.45 780 371 | 18772 -0.18388429752066118 39.2 372 | 18618 8.0 5 373 | 12567 0.6 240 374 | 12256 66.66666666666667 96 375 | 11894 -3745.2 1.5 376 | 11538 50.0 50 377 | 7295 50.0 50 378 | 8089 0.166666666667 10 379 | 6780 0.625 0.35 380 | 19497 5.429246293061726e+168 1.5 381 | 15177 0.16 0.16 382 | 8565 36.0 36 383 | 21876 0.06666666666666667 143 384 | 7931 -43.0 87.6 385 | 11060 0.39999999999999997 0.4 386 | 15573 14.0 14 387 | 1553 -0.031055900621086957 34.5 388 | 11683 114.0 114 389 | 8269 0.0009302325581395349 1075 390 | 17025 4018.0 58 391 | 7364 7.0 7 392 | 14314 0.036000000000000004 0.036 393 | 20327 -28320.0 0.125 394 | 10799 -5.333333333333 
0.833333333333 395 | 11165 198.0 54 396 | 15036 4.5 7.5 397 | 11767 78.000000000078 78 398 | 17369 -113.89830508474576 39 399 | 8319 -50.00000000022998 528 400 | 406 1926.0 7 401 | 13945 33.666666666666664 50 402 | 10050 0.011908931698774083 30 403 | 11429 160.0 25 404 | 17479 8.9 8.9 405 | 11659 0.0015282730514503335 120 406 | 16156 93500.00000000001 93500 407 | 10333 107.2 48 408 | 12203 0.06584362139920823 0.0833333333333 409 | 20631 -52680.0 7920 410 | 5009 64.00000000008001 64 411 | 3258 0.005333333333333333 30 412 | 17281 10.5 4 413 | 6922 250.0 730 414 | 9165 840.0 840 415 | 1249 3.0 600 416 | 9342 14400.0 1 417 | 11078 0.08662696264212236 0.25 418 | 2102 4.8076923076923075 16 419 | 5137 0.19999999999999998 0.25 420 | 18088 11.0 11 421 | 19267 0.5625 52 422 | 7293 180.0 180 423 | 11509 1.666666666666 0.972222222222 424 | 1591 0.9944751381215425 225 425 | 5846 53.333333333333336 300 426 | 4346 -1349.42 783 427 | 17849 0.033333333333300005 5.45454545455 428 | 10505 2.6315789473684212 40 429 | 3202 105.0 105 430 | 9511 20.0 16 431 | 11061 60.0 60 432 | 2316 12.0 3 433 | 14575 0.435 0.125 434 | 9289 0.004166666666666667 32 435 | 11873 18.0 198 436 | 17694 -50.0 50 437 | 10001 2.4000000000000004 0.4 438 | 1520 8.421052631578947 5 439 | 8463 -59.416666666667 28 440 | 10077 18.96 18.96 441 | 21794 -15.0 0.25 442 | 19545 -0.0046875 200 443 | 17535 1980.0 1980 444 | 22821 1.3333333333333333 4 445 | 21557 1542.6 52 446 | 1140 132.0 9 447 | 15890 0.2 0.8 448 | 17134 1698.6666666666667 5 449 | 3475 0.0004982559798166703 0.999002991027 450 | 9849 400.0 400 451 | 13309 2064.15 310 452 | 20185 4560.0 10 453 | 18748 146.285714285696 112 454 | 2957 -6.75 80 455 | 1929 2000500.0 528.35 456 | 7607 36.0 36 457 | 327 22500.0 0 458 | 11541 1100.0 1100 459 | 7657 75.0 5 460 | 9592 23.0 15 461 | 13295 7.0 8 462 | 12387 56.0 5 463 | 10930 3.0 3 464 | 8289 3300.0 108 465 | 14335 0.006 238 466 | 3500 8000.0 8000 467 | 22994 -28.0 60 468 | 21882 324.0 36 469 | 18002 -20.0 3 470 | 11879 
-0.053106744556558685 6 471 | 7417 39.8 40.2 472 | 6088 11.0 27 473 | 19437 168.8 15.6 474 | 14847 26.0 26 475 | 15901 28.8 228.8 476 | 15948 144.0 144 477 | 5063 45.2 80 478 | 11790 840.0000000001528 840 479 | 4829 0.32000000000000006 1.33333333333 480 | 4159 42.0 42 481 | 20513 25.0 200 482 | 13691 72.0 72 483 | 6892 -14160.0 0.2 484 | 8793 0.13333333333333333 30.6 485 | 5231 0.05555555555555555 18 486 | 21370 -0.08333333333299997 2 487 | 2506 750.0 750 488 | 12936 540.0 2040 489 | 7536 9.142857142857144 7.8 490 | 16035 4.800000000000001 3.57 491 | 2916 6.0 6 492 | 20803 12.0 12 493 | 11214 0.30000000000004357 0.3 494 | 12802 56980.0 76 495 | 15384 nan 95.76 496 | 1919 46.444444444398 171 497 | 1795 -29.0 3 498 | 23066 -75.01854140914709 53.9466666667 499 | 21639 75.00000000000001 75 500 | 4493 -3.0000000000094285 45 501 | 21176 3.9599999999999995 6 502 | 22059 210.0 90 503 | 10632 -0.20746058091286307 720 504 | 12514 63.9999999999744 36 505 | 1291 2304.0 26 506 | 7942 2.0000000000035003 2 507 | 829 30.0 30 508 | 17470 110.00000000002501 110 509 | 432 1008.0 27 510 | 610 -224400.0 17 511 | 13490 30.25000000003025 0.590909090909 512 | 18573 3.8461538461538463 0.3125 513 | 2651 0.25 4 514 | 2625 0.013888888888888888 72 515 | 1035 -1.877167755990811 33.5 516 | 4299 9.0 15 517 | 22620 102.0 102 518 | 9682 7.0 2500 519 | 20963 0.0011111111111111111 330 520 | 17516 500.0 500 521 | 12323 1.3333333333333333 40 522 | 14903 80.99999999998988 81 523 | 17078 35.999999999951996 36 524 | 10910 640.0 11 525 | 20800 1000.0 1000 526 | 8590 760.0 2 527 | 11446 16.25 16.25 528 | 7487 87120.0 0.2 529 | 5006 2800.0 1520 530 | 8101 1.0 2.5 531 | 11476 12.407407407416665 11 532 | 17789 -434.99999999999983 200 533 | 14524 80.0 620 534 | 2349 31.249999999999996 31.25 535 | 4383 2268.0 2268 536 | 12008 350.0 100 537 | 13992 19.0 20 538 | 9448 2.0 2500 539 | 14773 0.02 8.33333333333 540 | 3003 4.2749999999999995 90 541 | 3301 9.32 4.66 542 | 5918 81.08108108108108 100 543 | 20952 -36.5625 
6 544 | 9883 500.0 500 545 | 22714 85.33333333333333 170 546 | 21612 270.0 270 547 | 10474 1.1800000000000002 14.4 548 | 14965 15.0 15 549 | 3487 24.975 5 550 | 19875 20.0 20 551 | 16757 11.0 38 552 | 5754 9.0 9 553 | 13635 0.06153846153846154 1040 554 | 22110 12.0 12 555 | 8535 3016.0 308 556 | 11727 -3.04 56 557 | 15141 354.0 120 558 | 18419 16.0 1 559 | 10867 105.0 45 560 | 11204 0.33333333333333337 15 561 | 9874 0.525252525253 0.525252525253 562 | 12075 2.0 12.5 563 | 17268 0.02 33 564 | 11792 4320.0 4320 565 | 7240 399.0 399 566 | 12602 300.0 300 567 | 3932 11.818181818181818 713 568 | 10298 51.0 17 569 | 7150 3574.2857142857147 700 570 | 13982 0.5833333333329717 0.583333333333 571 | 5273 0.050724637681159424 26 572 | 22259 90.0 60 573 | 12349 4.0 20 574 | 5906 195.0 195 575 | 16642 480.0 480 576 | 3255 404.3076923076923 84 577 | 8248 59.0 61 578 | 20022 1.8571428571428572 0.533333333333 579 | 19973 15.0 0.0666666666667 580 | 13573 316.0 31 581 | 22928 0.00021132713440405747 7 582 | 6427 1.5 21 583 | 9658 75.00000000033751 75 584 | 163 0.125 0.125 585 | 12855 -0.6612244897959183 95 586 | 15740 162.0 27 587 | 14855 19659.0 390 588 | 13741 -0.05259259259259259 1.4332 589 | 22202 840.0 840 590 | 20771 2268.0 1764 591 | 18851 24.0 25 592 | 6931 -29.666666666666668 120 593 | 15346 -10.25316455696202 110 594 | 10002 76.5 75 595 | 17583 17600.0 1000 596 | 11602 48.333333333333336 80 597 | 14609 3.4 2.4 598 | 15897 5985.0 5985 599 | 21211 60.00000000000001 60 600 | 18405 7.2 3 601 | 19680 28.800000000000004 5.5 602 | 11450 59.999999999879996 50 603 | 12691 2.5 2.5 604 | 7012 13.5 13.5 605 | 6096 84.0 38 606 | 18450 0.9319727891157142 0.3 607 | 21348 800.0 800 608 | 13807 0.0196078431372549 0.0196078431373 609 | 3561 0.6666666666666666 0.5 610 | 14069 6036.0 20700 611 | 10589 21875.428571428572 1500 612 | 2954 35.555555555573335 20 613 | 16273 1.0345071300973196 1 614 | 8287 501.0 3 615 | 19878 154.13333333333333 6 616 | 1611 27327.89527242847 282.6 617 | 1883 
349.9999999999416 350 618 | 20297 244.72299999999998 2 619 | 14752 0.8571428571428571 3.71428571429 620 | 8819 0.625 1.3 621 | 3064 -0.6020066889632106 500 622 | 15886 0.3333333333333333 0.5 623 | 4358 240.0 375 624 | 14955 4.0 4 625 | 4622 9.0 11.75 626 | 10707 32.0 32 627 | 2599 60.0 12 628 | 4752 6.5 66 629 | 7154 18.0 2 630 | 22380 -0.039603960396039604 1368 631 | 21487 -0.375 36 632 | 7844 1123.2 1123.2 633 | 11081 12.0 3 634 | 18385 0.00881542699724518 200 635 | 7185 145.0 145 636 | 2216 57.5 30 637 | 8199 0.7142857142857143 16 638 | 16493 180.0 65 639 | 11169 1350.0 60 640 | 2972 315.0 315 641 | 20043 2.5 2 642 | 22444 85.0 15 643 | 2498 5.999999999994 6 644 | 5404 66.0 66 645 | 1125 31999.999999996002 4000 646 | 14461 367.0 367 647 | 16284 1.2954545454545454 104 648 | 12335 -4.341875 15.7 649 | 18883 1250.0 1100 650 | 18318 78.0 4 651 | 12119 2000.000000001 2000 652 | 7490 3.0 3 653 | 16758 49.0 200000 654 | 20234 7.363636363636363 10 655 | 35 1.5 4 656 | 20657 18720.0 3120 657 | 440 6.0 6 658 | 4219 17.43576388888889 35.7 659 | 20848 22.1 9.25 660 | 22775 36.57142857142857 256 661 | 5250 5040.0 5160 662 | 15092 24.0 24 663 | 11334 10430.0 10430 664 | 2717 33.0 720 665 | 22148 4.9999999999995834 5 666 | 969 3600.0 3600 667 | 15090 19.45 19.45 668 | 5483 -4.3500000000000005 40 669 | 3094 33.23076923077026 33 670 | 7598 726.0 180 671 | 20519 448.00000000000006 64 672 | 14076 75.0 75 673 | 18959 10.0 250 674 | 8614 600.0 600 675 | 5126 280.0 40 676 | 10725 0.0044444444444444444 37.5 677 | 19368 11.666666666667 16 678 | 2519 -48.0 1296 679 | 20186 0.1875 192 680 | 17901 1587.5 1587.5 681 | 22051 36.0 36 682 | 15481 -49.0 63 683 | 2598 6.0714285714309995 9.33333333333 684 | 16489 3.95959595959596 2 685 | 4142 12.0 14 686 | 7258 33.5999999999328 224 687 | 7212 2.520080321285141 275 688 | 15434 300.0 300 689 | 14117 28.0 37 690 | 13947 23800.0 23800 691 | 15323 26.0 26 692 | 2587 599.0 15 693 | 18617 1000.0 1000 694 | 1235 6.0 0.5 695 | 3193 0.30845681789999996 
1.111 696 | 18156 4.986666666666666 6 697 | 11665 2400.0 2400 698 | 4963 0.0666666666668 560 699 | 22633 52.5 8 700 | 18553 23.666666666667 9 701 | 19730 40.0 40 702 | 10398 480.0 480 703 | 2801 0.041666666666666664 27 704 | 9245 61.0 61 705 | 14014 -3.8285714285714287 4 706 | 10772 0.4444444444444444 13 707 | 11132 7.0 12 708 | 2075 14592.0 16 709 | 19122 15675.0 15675 710 | 4922 -0.999375 1920 711 | 19595 0.020408163265309524 49 712 | 14531 0.25 0.375 713 | 7781 1.5902777777777777 5.5 714 | 3161 1.666666666665 562.5 715 | 5687 3600.000000000001 40 716 | 10590 23.5 23.5 717 | 6057 19.700000000000003 12.1 718 | 7553 2.9999999999994 3 719 | 588 60.0 45 720 | 21505 420.0 90 721 | 12215 0.05656565656565657 523 722 | 11985 52.0 420 723 | 8241 130.0 130 724 | 13746 1.26 1.26 725 | 12878 0.25 57 726 | 243 1000.0 240 727 | 7221 5.6 6 728 | 20984 0.0026666666666666666 20 729 | 16019 497.77777777777777 490 730 | 19176 512.0 800 731 | 754 34.29 38.75 732 | 17290 -0.8275862068965517 2 733 | 19158 0.7000000000000001 0.7 734 | 13814 0.20833333333333331 4.8 735 | 22682 -153.0 32 736 | 306 50.000000000025004 50 737 | 15599 0.625 0.15 738 | 18902 -331.1639871382636 718 739 | 20360 60.25 15 740 | 19866 62.800000000000004 31.4 741 | 4415 0.5833333333333334 10 742 | 15083 0.8571428571428571 845 743 | 2772 0.08 0.08 744 | 20388 2400.0 2300 745 | 12577 4200.0 4200 746 | 18707 -66.66666666666666 0.047619047619 747 | 1514 48.0 48 748 | 6187 55.800000000000004 55.8 749 | 16895 36.0 144 750 | 14372 1.0 3 751 | 11528 112.0 112 752 | 16601 -77.0 42 753 | 16638 11.999999999994001 12 754 | 8047 0.21 0.21 755 | 20300 67.5 210 756 | 362 0.16666666666666666 0.166666666667 757 | 13792 0.33999999999999997 0.34 758 | 8321 6.0 36 759 | 22396 0.24999999999999994 0.444444444444 760 | 12072 135459.0 12 761 | 11294 0.9999999999835012 1 762 | 22395 913.6363636363635 60 763 | 21325 1.35 0.825 764 | 6137 72.0 72 765 | 7028 1.2195121951219512 8217 766 | 11258 0.7972222222222223 0.73 767 | 6581 30.0 330 768 | 
21662 7410.0 204 769 | 10977 5.0 5 770 | 13450 0.1250521050437682 4800 771 | 6254 2405.0 1080 772 | 15248 90.428571428571 255 773 | 14217 2.5 2.5 774 | 17256 0.993816425120773 180 775 | 15553 5.0 5 776 | 11029 11.052631578947368 10 777 | 4436 3.3600000000000003 4.2 778 | 6900 18.0 18 779 | 3180 0.00199700449326011 0.25 780 | 12115 5.0 5 781 | 8839 11666.666666666666 17 782 | 17180 810.0 70 783 | 12069 0.8 0.657142857143 784 | 4872 119.0 177 785 | 14058 -1.8606847697756788 2424 786 | 7211 6.48 1.3 787 | 706 -9.485714285714286 620 788 | 8561 -11.666666666642 8 789 | 16946 -224.25 400 790 | 3111 -0.05128205128205128 90 791 | 15917 28.8 28.8 792 | 14731 100.0 100 793 | 3421 150.0 30 794 | 3645 8.333333333333002 8 795 | 244 6000.0 3840 796 | 5868 31.999999999920004 32 797 | 18252 3.0 3 798 | 14126 3.39 3.39 799 | 22739 12.495 12.495 800 | 1131 1.8285714285714285 0.8 801 | 6130 8.0 8 802 | 12775 0.08333333333333333 0.0833333333333 803 | 10684 499.0 3 804 | 20151 17.0 22 805 | 2929 0.4285714285715 0.428571428571 806 | 20140 3.0 3 807 | 20611 4.9 4 808 | 3950 27.428571428568 21 809 | 4250 24.0 44 810 | 22183 145.0 145 811 | 1739 15300.0 15300 812 | 6832 12.0 12 813 | 13861 112.0 96 814 | 6992 80.5 40 815 | 14143 435344.0 1157 816 | 3268 175.0 175 817 | 20583 -50.90909090909091 4370 818 | 2255 1500.0 600 819 | 3583 0.16666666666666666 0.0833333333333 820 | 19867 480.0 720 821 | 21555 -0.009259259259259259 112 822 | 6022 26.052631578947366 15 823 | 10299 -4.5 2 824 | 19587 300.0 800 825 | 17421 -4.1000000000000005 0.833333333333 826 | 8030 0.16666666666666666 54 827 | 4012 39.111111111111114 80 828 | 19921 23.0 23 829 | 1462 13000.0 20 830 | 11718 5.0 5 831 | 16702 76.0 76 832 | 17880 96.0 96 833 | 11982 1.32 1.32 834 | 3271 10850.0 10850 835 | 5985 2.0 251.2 836 | 1790 -47.800000000000004 900 837 | 2305 4.32 44.5 838 | 1392 6388.8 6388.8 839 | 12149 0.037037037037037035 324 840 | 15961 71.35714285714286 1850 841 | 13217 0.25 4 842 | 16746 -480.0 960 843 | 5449 
438636005.19200003 3 844 | 20321 -143.8666666666668 240 845 | 10640 0.15789473684210525 4000 846 | 16227 220.0 60 847 | 8558 -0.222222222222 0.333333333333 848 | 18301 1.25 0.2 849 | 14235 10.0 10 850 | 2455 24.03 24.03 851 | 20259 19.800000000000008 19.8 852 | 19655 75.0 70 853 | 13275 323.9999999997408 100 854 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/output/run_dolphin.txt: -------------------------------------------------------------------------------- 1 | 10 2 | 297 3 | 250 25 30 4 | loading finished 5 | RECUR EPOCH: 0, loss: 1.1682798597547743, train_acc_e: 0.6151079136690647, train_acc_i: 0.5679012345679012 time: 0.08729670445124309 6 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586 7 | originial 0 8 | saving... 0.76 9 | saveing ok! 10 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335 11 | 12 | RECUR EPOCH: 1, loss: 0.9512225806467818, train_acc_e: 0.6276978417266187, train_acc_i: 0.6008230452674898 time: 0.08907195727030436 13 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586 14 | originial 0.76 15 | saving... 0.76 16 | saveing ok! 17 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335 18 | 19 | RECUR EPOCH: 2, loss: 0.88405234725387, train_acc_e: 0.6276978417266187, train_acc_i: 0.6008230452674898 time: 0.09000606139500936 20 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586 21 | originial 0.76 22 | saving... 0.76 23 | saveing ok! 24 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335 25 | 26 | RECUR EPOCH: 3, loss: 0.8569288528505176, train_acc_e: 0.6276978417266187, train_acc_i: 0.6008230452674898 time: 0.09406771659851074 27 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.49333333333333335, test_acc_i: 0.41379310344827586 28 | originial 0.76 29 | saving... 
0.76 30 | saveing ok! 31 | final test temp_acc:0.43333333333333335, ans_acc:0.43333333333333335 32 | 33 | RECUR EPOCH: 4, loss: 0.8268176341743626, train_acc_e: 0.6420863309352518, train_acc_i: 0.6131687242798354 time: 0.09377397696177164 34 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.5733333333333334, test_acc_i: 0.4827586206896552 35 | originial 0.76 36 | saving... 0.8 37 | saveing ok! 38 | final test temp_acc:0.5, ans_acc:0.5 39 | 40 | RECUR EPOCH: 5, loss: 0.8387280844857172, train_acc_e: 0.6636690647482014, train_acc_i: 0.6460905349794238 time: 0.09163166284561157 41 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.5733333333333334, test_acc_i: 0.4827586206896552 42 | originial 0.8 43 | saving... 0.8 44 | saveing ok! 45 | final test temp_acc:0.5, ans_acc:0.5 46 | 47 | RECUR EPOCH: 6, loss: 0.7824347774678297, train_acc_e: 0.6690647482014388, train_acc_i: 0.6460905349794238 time: 0.09390345414479574 48 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.5466666666666666, test_acc_i: 0.4482758620689655 49 | jumping... 50 | 51 | RECUR EPOCH: 7, loss: 0.7914263799847889, train_acc_e: 0.658273381294964, train_acc_i: 0.6255144032921811 time: 0.09514609972635905 52 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.5866666666666667, test_acc_i: 0.4827586206896552 53 | jumping... 54 | 55 | RECUR EPOCH: 8, loss: 0.7353058999457968, train_acc_e: 0.6960431654676259, train_acc_i: 0.6337448559670782 time: 0.09473386208216349 56 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.5866666666666667, test_acc_i: 0.4827586206896552 57 | originial 0.8 58 | saving... 0.8 59 | saveing ok! 60 | final test temp_acc:0.5, ans_acc:0.5 61 | 62 | RECUR EPOCH: 9, loss: 0.7557127485549989, train_acc_e: 0.6672661870503597, train_acc_i: 0.6255144032921811 time: 0.09168526728947958 63 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.56, test_acc_i: 0.3793103448275862 64 | jumping... 
65 | 66 | RECUR EPOCH: 10, loss: 0.6924405353059494, train_acc_e: 0.6816546762589928, train_acc_i: 0.5967078189300411 time: 0.09236328999201457 67 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.8, test_acc_e: 0.5866666666666667, test_acc_i: 0.4827586206896552 68 | originial 0.8 69 | saving... 0.8 70 | saveing ok! 71 | final test temp_acc:0.5, ans_acc:0.4666666666666667 72 | 73 | RECUR EPOCH: 11, loss: 0.6715376661638174, train_acc_e: 0.6942446043165468, train_acc_i: 0.654320987654321 time: 0.09408228794733683 74 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.48, test_acc_i: 0.2413793103448276 75 | jumping... 76 | 77 | RECUR EPOCH: 12, loss: 0.6259040008356542, train_acc_e: 0.710431654676259, train_acc_i: 0.6378600823045267 time: 0.09618140856424967 78 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.64, test_acc_e: 0.4666666666666667, test_acc_i: 0.3103448275862069 79 | jumping... 80 | 81 | RECUR EPOCH: 13, loss: 0.6323614552187822, train_acc_e: 0.7140287769784173, train_acc_i: 0.654320987654321 time: 0.09348649183909098 82 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.5333333333333333, test_acc_i: 0.27586206896551724 83 | jumping... 84 | 85 | RECUR EPOCH: 14, loss: 0.6169184696527175, train_acc_e: 0.7302158273381295, train_acc_i: 0.6378600823045267 time: 0.09691879351933798 86 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6266666666666667, test_acc_i: 0.4482758620689655 87 | jumping... 88 | 89 | RECUR EPOCH: 15, loss: 0.5792982862810049, train_acc_e: 0.7050359712230215, train_acc_i: 0.6460905349794238 time: 0.09498010873794556 90 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.6, test_acc_e: 0.5333333333333333, test_acc_i: 0.27586206896551724 91 | jumping... 
92 | 93 | RECUR EPOCH: 16, loss: 0.6057494206683626, train_acc_e: 0.7140287769784173, train_acc_i: 0.6584362139917695 time: 0.093644118309021 94 | valid_acc_e: 0.5535714285714286, valid_acc_i: 0.4, test_acc_e: 0.5066666666666667, test_acc_i: 0.27586206896551724 95 | jumping... 96 | 97 | RECUR EPOCH: 17, loss: 0.5289612836798523, train_acc_e: 0.7589928057553957, train_acc_i: 0.6625514403292181 time: 0.09545961221059164 98 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.64, test_acc_e: 0.5733333333333334, test_acc_i: 0.4482758620689655 99 | jumping... 100 | 101 | RECUR EPOCH: 18, loss: 0.5154155585991502, train_acc_e: 0.7589928057553957, train_acc_i: 0.6790123456790124 time: 0.09287167390187581 102 | valid_acc_e: 0.625, valid_acc_i: 0.52, test_acc_e: 0.5333333333333333, test_acc_i: 0.27586206896551724 103 | jumping... 104 | 105 | RECUR EPOCH: 19, loss: 0.5485985622484497, train_acc_e: 0.7446043165467626, train_acc_i: 0.6502057613168725 time: 0.09662172396977743 106 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.52, test_acc_e: 0.5866666666666667, test_acc_i: 0.3448275862068966 107 | jumping... 108 | 109 | RECUR EPOCH: 20, loss: 0.46297485249523274, train_acc_e: 0.7931654676258992, train_acc_i: 0.7242798353909465 time: 0.0916707158088684 110 | valid_acc_e: 0.6428571428571429, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655 111 | jumping... 112 | 113 | RECUR EPOCH: 21, loss: 0.4903093247747225, train_acc_e: 0.7805755395683454, train_acc_i: 0.7078189300411523 time: 0.09656307299931845 114 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.6, test_acc_e: 0.6266666666666667, test_acc_i: 0.4482758620689655 115 | jumping... 116 | 117 | RECUR EPOCH: 22, loss: 0.39395640612629707, train_acc_e: 0.8201438848920863, train_acc_i: 0.7366255144032922 time: 0.09862770636876424 118 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.68, test_acc_e: 0.6, test_acc_i: 0.4827586206896552 119 | jumping... 
120 | 121 | RECUR EPOCH: 23, loss: 0.36069519137158806, train_acc_e: 0.841726618705036, train_acc_i: 0.7695473251028807 time: 0.09261656999588012 122 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6266666666666667, test_acc_i: 0.41379310344827586 123 | jumping... 124 | 125 | RECUR EPOCH: 24, loss: 0.3598016642739253, train_acc_e: 0.829136690647482, train_acc_i: 0.7489711934156379 time: 0.09933442274729411 126 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.64, test_acc_i: 0.5517241379310345 127 | jumping... 128 | 129 | RECUR EPOCH: 25, loss: 0.4195781421268918, train_acc_e: 0.8075539568345323, train_acc_i: 0.7448559670781894 time: 0.09181210199991861 130 | valid_acc_e: 0.48214285714285715, valid_acc_i: 0.36, test_acc_e: 0.5733333333333334, test_acc_i: 0.3448275862068966 131 | jumping... 132 | 133 | RECUR EPOCH: 26, loss: 0.42452051021434645, train_acc_e: 0.8093525179856115, train_acc_i: 0.7489711934156379 time: 0.09510379234949748 134 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655 135 | jumping... 136 | 137 | RECUR EPOCH: 27, loss: 0.332387818230523, train_acc_e: 0.8489208633093526, train_acc_i: 0.7736625514403292 time: 0.09798026084899902 138 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.64, test_acc_i: 0.5172413793103449 139 | originial 0.8 140 | saving... 0.8 141 | saveing ok! 142 | final test temp_acc:0.5, ans_acc:0.5 143 | 144 | RECUR EPOCH: 28, loss: 0.32374130060643325, train_acc_e: 0.8489208633093526, train_acc_i: 0.7654320987654321 time: 0.09599583148956299 145 | valid_acc_e: 0.625, valid_acc_i: 0.48, test_acc_e: 0.6133333333333333, test_acc_i: 0.3103448275862069 146 | jumping... 
147 | 148 | RECUR EPOCH: 29, loss: 0.3817319320553124, train_acc_e: 0.8219424460431655, train_acc_i: 0.7695473251028807 time: 0.09603599707285564 149 | valid_acc_e: 0.5357142857142857, valid_acc_i: 0.32, test_acc_e: 0.6, test_acc_i: 0.3793103448275862 150 | jumping... 151 | 152 | RECUR EPOCH: 30, loss: 0.3141926777215652, train_acc_e: 0.8525179856115108, train_acc_i: 0.7901234567901234 time: 0.0925633192062378 153 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6, test_acc_i: 0.4827586206896552 154 | jumping... 155 | 156 | RECUR EPOCH: 31, loss: 0.22721145869282539, train_acc_e: 0.8812949640287769, train_acc_i: 0.8148148148148148 time: 0.09542576471964519 157 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.64, test_acc_e: 0.72, test_acc_i: 0.4827586206896552 158 | jumping... 159 | 160 | RECUR EPOCH: 32, loss: 0.21175358324875065, train_acc_e: 0.8902877697841727, train_acc_i: 0.8106995884773662 time: 0.09543842077255249 161 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.6, test_acc_e: 0.7066666666666667, test_acc_i: 0.5172413793103449 162 | jumping... 163 | 164 | RECUR EPOCH: 33, loss: 0.16183604134453666, train_acc_e: 0.9190647482014388, train_acc_i: 0.8765432098765432 time: 0.09457332690556844 165 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.68, test_acc_i: 0.5517241379310345 166 | jumping... 167 | 168 | RECUR EPOCH: 34, loss: 0.17849770151538613, train_acc_e: 0.9118705035971223, train_acc_i: 0.8600823045267489 time: 0.0989931066830953 169 | valid_acc_e: 0.6428571428571429, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655 170 | jumping... 171 | 172 | RECUR EPOCH: 35, loss: 0.17104502617086403, train_acc_e: 0.9100719424460432, train_acc_i: 0.8683127572016461 time: 0.09347049792607626 173 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.56, test_acc_e: 0.64, test_acc_i: 0.4827586206896552 174 | jumping... 
175 | 176 | RECUR EPOCH: 36, loss: 0.16812368626457183, train_acc_e: 0.9100719424460432, train_acc_i: 0.8600823045267489 time: 0.09435887336730957 177 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.41379310344827586 178 | jumping... 179 | 180 | RECUR EPOCH: 37, loss: 0.15582417485154706, train_acc_e: 0.920863309352518, train_acc_i: 0.8641975308641975 time: 0.09741787115732829 181 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.48, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655 182 | jumping... 183 | 184 | RECUR EPOCH: 38, loss: 0.17178444214809088, train_acc_e: 0.8992805755395683, train_acc_i: 0.8353909465020576 time: 0.09656809568405152 185 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.5172413793103449 186 | jumping... 187 | 188 | RECUR EPOCH: 39, loss: 0.23990754335505482, train_acc_e: 0.9100719424460432, train_acc_i: 0.8395061728395061 time: 0.09630475044250489 189 | valid_acc_e: 0.75, valid_acc_i: 0.76, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655 190 | jumping... 191 | 192 | RECUR EPOCH: 40, loss: 0.13023324586727, train_acc_e: 0.9370503597122302, train_acc_i: 0.8930041152263375 time: 0.09295337200164795 193 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655 194 | jumping... 195 | 196 | RECUR EPOCH: 41, loss: 0.14373765438181874, train_acc_e: 0.9334532374100719, train_acc_i: 0.8930041152263375 time: 0.09432597160339355 197 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.6, test_acc_e: 0.5733333333333334, test_acc_i: 0.3793103448275862 198 | jumping... 199 | 200 | RECUR EPOCH: 42, loss: 0.10655531520215572, train_acc_e: 0.9496402877697842, train_acc_i: 0.9218106995884774 time: 0.09521870215733846 201 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.68, test_acc_i: 0.4482758620689655 202 | jumping... 
203 | 204 | RECUR EPOCH: 43, loss: 0.10571475958627928, train_acc_e: 0.9460431654676259, train_acc_i: 0.9012345679012346 time: 0.09794450203577677 205 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.7066666666666667, test_acc_i: 0.5517241379310345 206 | jumping... 207 | 208 | RECUR EPOCH: 44, loss: 0.10741040772861904, train_acc_e: 0.9406474820143885, train_acc_i: 0.9053497942386831 time: 0.09870184659957885 209 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.76, test_acc_e: 0.68, test_acc_i: 0.5862068965517241 210 | jumping... 211 | 212 | RECUR EPOCH: 45, loss: 0.13996641748726613, train_acc_e: 0.9334532374100719, train_acc_i: 0.8765432098765432 time: 0.09284635384877522 213 | valid_acc_e: 0.625, valid_acc_i: 0.56, test_acc_e: 0.7066666666666667, test_acc_i: 0.4827586206896552 214 | jumping... 215 | 216 | RECUR EPOCH: 46, loss: 0.12933335284637326, train_acc_e: 0.9370503597122302, train_acc_i: 0.897119341563786 time: 0.09647434155146281 217 | valid_acc_e: 0.6071428571428571, valid_acc_i: 0.56, test_acc_e: 0.6666666666666666, test_acc_i: 0.3793103448275862 218 | jumping... 219 | 220 | RECUR EPOCH: 47, loss: 0.06780000196562873, train_acc_e: 0.9676258992805755, train_acc_i: 0.9382716049382716 time: 0.0968562642733256 221 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.76, test_acc_i: 0.5862068965517241 222 | jumping... 223 | 224 | RECUR EPOCH: 48, loss: 0.0724956984382598, train_acc_e: 0.9676258992805755, train_acc_i: 0.934156378600823 time: 0.09332544406255086 225 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6666666666666666, test_acc_i: 0.5517241379310345 226 | jumping... 227 | 228 | RECUR EPOCH: 49, loss: 0.08274855584274103, train_acc_e: 0.947841726618705, train_acc_i: 0.9053497942386831 time: 0.09943470557530722 229 | valid_acc_e: 0.5714285714285714, valid_acc_i: 0.44, test_acc_e: 0.64, test_acc_i: 0.41379310344827586 230 | jumping... 
231 | 232 | RECUR EPOCH: 50, loss: 0.06540786852071315, train_acc_e: 0.9568345323741008, train_acc_i: 0.9176954732510288 time: 0.09516379833221436 233 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.52, test_acc_e: 0.6933333333333334, test_acc_i: 0.41379310344827586 234 | jumping... 235 | 236 | RECUR EPOCH: 51, loss: 0.05518364053940086, train_acc_e: 0.9658273381294964, train_acc_i: 0.9465020576131687 time: 0.09435760180155436 237 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.5517241379310345 238 | jumping... 239 | 240 | RECUR EPOCH: 52, loss: 0.09833989197334635, train_acc_e: 0.9586330935251799, train_acc_i: 0.9094650205761317 time: 0.0969707727432251 241 | valid_acc_e: 0.5892857142857143, valid_acc_i: 0.52, test_acc_e: 0.6666666666666666, test_acc_i: 0.4827586206896552 242 | jumping... 243 | 244 | RECUR EPOCH: 53, loss: 0.09804084958362971, train_acc_e: 0.9532374100719424, train_acc_i: 0.9300411522633745 time: 0.09771036307017009 245 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552 246 | jumping... 247 | 248 | RECUR EPOCH: 54, loss: 0.09636338147116297, train_acc_e: 0.9622302158273381, train_acc_i: 0.9300411522633745 time: 0.09834134578704834 249 | valid_acc_e: 0.5535714285714286, valid_acc_i: 0.52, test_acc_e: 0.68, test_acc_i: 0.4482758620689655 250 | jumping... 251 | 252 | RECUR EPOCH: 55, loss: 0.07666621182435825, train_acc_e: 0.9532374100719424, train_acc_i: 0.9135802469135802 time: 0.0945337176322937 253 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.68, test_acc_e: 0.7466666666666667, test_acc_i: 0.5517241379310345 254 | jumping... 255 | 256 | RECUR EPOCH: 56, loss: 0.04281204975681541, train_acc_e: 0.9820143884892086, train_acc_i: 0.9588477366255144 time: 0.09585176308949789 257 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.6666666666666666, test_acc_i: 0.5862068965517241 258 | jumping... 
259 | 260 | RECUR EPOCH: 57, loss: 0.07147320019610134, train_acc_e: 0.9586330935251799, train_acc_i: 0.9259259259259259 time: 0.09756191174189249 261 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.6, test_acc_e: 0.6666666666666666, test_acc_i: 0.41379310344827586 262 | jumping... 263 | 264 | RECUR EPOCH: 58, loss: 0.049979879440348825, train_acc_e: 0.9712230215827338, train_acc_i: 0.9465020576131687 time: 0.09521978696187337 265 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6666666666666666, test_acc_i: 0.5172413793103449 266 | jumping... 267 | 268 | RECUR EPOCH: 59, loss: 0.03579042790603245, train_acc_e: 0.9820143884892086, train_acc_i: 0.9629629629629629 time: 0.09564072688420613 269 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.41379310344827586 270 | jumping... 271 | 272 | RECUR EPOCH: 60, loss: 0.07674725838158847, train_acc_e: 0.960431654676259, train_acc_i: 0.934156378600823 time: 0.09347194035847982 273 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.6266666666666667, test_acc_i: 0.4482758620689655 274 | jumping... 275 | 276 | RECUR EPOCH: 61, loss: 0.03581807494868712, train_acc_e: 0.9748201438848921, train_acc_i: 0.9506172839506173 time: 0.09602559407552083 277 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.7066666666666667, test_acc_i: 0.5517241379310345 278 | jumping... 279 | 280 | RECUR EPOCH: 62, loss: 0.02495744525466436, train_acc_e: 0.9802158273381295, train_acc_i: 0.9629629629629629 time: 0.09728382031122844 281 | valid_acc_e: 0.625, valid_acc_i: 0.6, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655 282 | jumping... 283 | 284 | RECUR EPOCH: 63, loss: 0.039477180061026365, train_acc_e: 0.9784172661870504, train_acc_i: 0.9629629629629629 time: 0.09547470808029175 285 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.76, test_acc_e: 0.68, test_acc_i: 0.5172413793103449 286 | jumping... 
287 | 288 | RECUR EPOCH: 64, loss: 0.041618613725520455, train_acc_e: 0.9820143884892086, train_acc_i: 0.9629629629629629 time: 0.09983280102411905 289 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.64, test_acc_e: 0.6133333333333333, test_acc_i: 0.5172413793103449 290 | jumping... 291 | 292 | RECUR EPOCH: 65, loss: 0.027657700135322757, train_acc_e: 0.9856115107913669, train_acc_i: 0.9753086419753086 time: 0.0956899086634318 293 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6666666666666666, test_acc_i: 0.5172413793103449 294 | jumping... 295 | 296 | RECUR EPOCH: 66, loss: 0.037495824657840494, train_acc_e: 0.9838129496402878, train_acc_i: 0.9629629629629629 time: 0.09479515552520752 297 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.7333333333333333, test_acc_i: 0.4827586206896552 298 | jumping... 299 | 300 | RECUR EPOCH: 67, loss: 0.0549890123644974, train_acc_e: 0.9730215827338129, train_acc_i: 0.9506172839506173 time: 0.0966543157895406 301 | valid_acc_e: 0.5892857142857143, valid_acc_i: 0.48, test_acc_e: 0.6666666666666666, test_acc_i: 0.3793103448275862 302 | jumping... 303 | 304 | RECUR EPOCH: 68, loss: 0.023487998686209627, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09579426447550456 305 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655 306 | jumping... 307 | 308 | RECUR EPOCH: 69, loss: 0.029502890176243253, train_acc_e: 0.9838129496402878, train_acc_i: 0.9711934156378601 time: 0.09850571552912395 309 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552 310 | jumping... 311 | 312 | RECUR EPOCH: 70, loss: 0.050007173891175434, train_acc_e: 0.9784172661870504, train_acc_i: 0.9588477366255144 time: 0.09385395050048828 313 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655 314 | jumping... 
315 | 316 | RECUR EPOCH: 71, loss: 0.0374770586382705, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09573783079783121 317 | valid_acc_e: 0.7857142857142857, valid_acc_i: 0.8, test_acc_e: 0.64, test_acc_i: 0.4827586206896552 318 | originial 0.8 319 | saving... 0.8 320 | saveing ok! 321 | final test temp_acc:0.5333333333333333, ans_acc:0.5 322 | 323 | RECUR EPOCH: 72, loss: 0.0281009430472007, train_acc_e: 0.9856115107913669, train_acc_i: 0.9670781893004116 time: 0.09709215958913167 324 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655 325 | jumping... 326 | 327 | RECUR EPOCH: 73, loss: 0.01708699586521451, train_acc_e: 0.987410071942446, train_acc_i: 0.9753086419753086 time: 0.0944095253944397 328 | valid_acc_e: 0.8214285714285714, valid_acc_i: 0.8, test_acc_e: 0.64, test_acc_i: 0.4827586206896552 329 | originial 0.8 330 | saving... 0.8 331 | saveing ok! 332 | final test temp_acc:0.5, ans_acc:0.5 333 | 334 | RECUR EPOCH: 74, loss: 0.05629130771743908, train_acc_e: 0.9838129496402878, train_acc_i: 0.9670781893004116 time: 0.09714263677597046 335 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.72, test_acc_e: 0.6, test_acc_i: 0.41379310344827586 336 | jumping... 337 | 338 | RECUR EPOCH: 75, loss: 0.04158908542659548, train_acc_e: 0.9748201438848921, train_acc_i: 0.9506172839506173 time: 0.09388278722763062 339 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.6666666666666666, test_acc_i: 0.4827586206896552 340 | jumping... 341 | 342 | RECUR EPOCH: 76, loss: 0.04258053513711372, train_acc_e: 0.9838129496402878, train_acc_i: 0.9670781893004116 time: 0.09425394932428996 343 | valid_acc_e: 0.8035714285714286, valid_acc_i: 0.84, test_acc_e: 0.6, test_acc_i: 0.4482758620689655 344 | originial 0.8 345 | saving... 0.84 346 | saveing ok! 
347 | final test temp_acc:0.4666666666666667, ans_acc:0.43333333333333335 348 | 349 | RECUR EPOCH: 77, loss: 0.06083750326201749, train_acc_e: 0.9676258992805755, train_acc_i: 0.9465020576131687 time: 0.09651113351186116 350 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655 351 | jumping... 352 | 353 | RECUR EPOCH: 78, loss: 0.04671263096877086, train_acc_e: 0.9802158273381295, train_acc_i: 0.9588477366255144 time: 0.09435170888900757 354 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.72, test_acc_e: 0.6933333333333334, test_acc_i: 0.4827586206896552 355 | jumping... 356 | 357 | RECUR EPOCH: 79, loss: 0.05908534688867414, train_acc_e: 0.9820143884892086, train_acc_i: 0.9588477366255144 time: 0.09631927410761515 358 | valid_acc_e: 0.6607142857142857, valid_acc_i: 0.6, test_acc_e: 0.7333333333333333, test_acc_i: 0.5172413793103449 359 | jumping... 360 | 361 | RECUR EPOCH: 80, loss: 0.028193465131239145, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09260084629058837 362 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655 363 | jumping... 364 | 365 | RECUR EPOCH: 81, loss: 0.009333069808014627, train_acc_e: 0.9928057553956835, train_acc_i: 0.9876543209876543 time: 0.0965872049331665 366 | valid_acc_e: 0.6428571428571429, valid_acc_i: 0.52, test_acc_e: 0.6266666666666667, test_acc_i: 0.5172413793103449 367 | jumping... 368 | 369 | RECUR EPOCH: 82, loss: 0.011288232396168964, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09728428125381469 370 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655 371 | jumping... 
372 | 373 | RECUR EPOCH: 83, loss: 0.023994814577692574, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09791800578435263 374 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.68, test_acc_i: 0.4482758620689655 375 | jumping... 376 | 377 | RECUR EPOCH: 84, loss: 0.011793449894743568, train_acc_e: 0.9964028776978417, train_acc_i: 0.9917695473251029 time: 0.098715607325236 378 | valid_acc_e: 0.75, valid_acc_i: 0.68, test_acc_e: 0.68, test_acc_i: 0.41379310344827586 379 | jumping... 380 | 381 | RECUR EPOCH: 85, loss: 0.01421276430902167, train_acc_e: 0.9892086330935251, train_acc_i: 0.9753086419753086 time: 0.09430506229400634 382 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6266666666666667, test_acc_i: 0.5517241379310345 383 | jumping... 384 | 385 | RECUR EPOCH: 86, loss: 0.007714040408777111, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09514733552932739 386 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.6133333333333333, test_acc_i: 0.41379310344827586 387 | jumping... 388 | 389 | RECUR EPOCH: 87, loss: 0.02914390094011041, train_acc_e: 0.9928057553956835, train_acc_i: 0.9917695473251029 time: 0.09735808769861858 390 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.68, test_acc_e: 0.68, test_acc_i: 0.5172413793103449 391 | jumping... 392 | 393 | RECUR EPOCH: 88, loss: 0.029182899762481573, train_acc_e: 0.9928057553956835, train_acc_i: 0.9835390946502057 time: 0.09592658281326294 394 | valid_acc_e: 0.6964285714285714, valid_acc_i: 0.64, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552 395 | jumping... 396 | 397 | RECUR EPOCH: 89, loss: 0.009332968670407256, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09735538562138875 398 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.68, test_acc_e: 0.64, test_acc_i: 0.4827586206896552 399 | jumping... 
400 | 401 | RECUR EPOCH: 90, loss: 0.024998605864528765, train_acc_e: 0.9820143884892086, train_acc_i: 0.9670781893004116 time: 0.09479767481486003 402 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.64, test_acc_e: 0.5866666666666667, test_acc_i: 0.4482758620689655 403 | jumping... 404 | 405 | RECUR EPOCH: 91, loss: 0.022708293832378622, train_acc_e: 0.987410071942446, train_acc_i: 0.9711934156378601 time: 0.09715214967727662 406 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.72, test_acc_e: 0.56, test_acc_i: 0.4482758620689655 407 | jumping... 408 | 409 | RECUR EPOCH: 92, loss: 0.0042395228826644, train_acc_e: 0.9982014388489209, train_acc_i: 0.9958847736625515 time: 0.09556546608606974 410 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.72, test_acc_e: 0.68, test_acc_i: 0.5172413793103449 411 | jumping... 412 | 413 | RECUR EPOCH: 93, loss: 0.037599779972078376, train_acc_e: 0.9946043165467626, train_acc_i: 0.9917695473251029 time: 0.09529571135838827 414 | valid_acc_e: 0.7142857142857143, valid_acc_i: 0.68, test_acc_e: 0.6533333333333333, test_acc_i: 0.4482758620689655 415 | jumping... 416 | 417 | RECUR EPOCH: 94, loss: 0.019692288941623254, train_acc_e: 0.9910071942446043, train_acc_i: 0.9794238683127572 time: 0.09814429680506388 418 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.64, test_acc_e: 0.6533333333333333, test_acc_i: 0.4827586206896552 419 | jumping... 420 | 421 | RECUR EPOCH: 95, loss: 0.012651146805381456, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.09417590697606405 422 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.68, test_acc_e: 0.6933333333333334, test_acc_i: 0.4482758620689655 423 | jumping... 424 | 425 | RECUR EPOCH: 96, loss: 0.003612281639634827, train_acc_e: 1.0, train_acc_i: 1.0 time: 0.0961928923924764 426 | valid_acc_e: 0.75, valid_acc_i: 0.72, test_acc_e: 0.64, test_acc_i: 0.4827586206896552 427 | jumping... 
428 | 429 | RECUR EPOCH: 97, loss: 0.04159402045126192, train_acc_e: 0.9910071942446043, train_acc_i: 0.9876543209876543 time: 0.0963845173517863 430 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.6, test_acc_e: 0.68, test_acc_i: 0.4482758620689655 431 | jumping... 432 | 433 | RECUR EPOCH: 98, loss: 0.028848348652997624, train_acc_e: 0.9946043165467626, train_acc_i: 0.9876543209876543 time: 0.0943213144938151 434 | valid_acc_e: 0.7678571428571429, valid_acc_i: 0.76, test_acc_e: 0.6933333333333334, test_acc_i: 0.6206896551724138 435 | jumping... 436 | 437 | RECUR EPOCH: 99, loss: 0.017882701150659538, train_acc_e: 0.9928057553956835, train_acc_i: 0.9835390946502057 time: 0.09760322570800781 438 | valid_acc_e: 0.7321428571428571, valid_acc_i: 0.72, test_acc_e: 0.6666666666666666, test_acc_i: 0.4482758620689655 439 | jumping... 440 | 441 | RECUR EPOCH: 100, loss: 0.004860475782013724, train_acc_e: 0.9982014388489209, train_acc_i: 0.9958847736625515 time: 0.09372621774673462 442 | valid_acc_e: 0.6785714285714286, valid_acc_i: 0.64, test_acc_e: 0.7466666666666667, test_acc_i: 0.4827586206896552 443 | jumping... 
444 | 445 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/output/run_dolphin_rep.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Phybiolo57/MathWordProblemSolver/a7753a8762370f20fce86ea3001554c4f927b563/3_T-RNN_&_baselines/output/run_dolphin_rep.txt -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/data_loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import time 4 | 5 | import random 6 | import pickle 7 | import sys 8 | 9 | from utils import * 10 | 11 | class DataLoader: 12 | def __init__(self, dataset):#, dataset, math23k_vocab_list, math23k_decode_list): 13 | 14 | self.dataset = dataset 15 | 16 | math23k_vocab_list = ['PAD','', 'EOS', '1', 'PI'] 17 | math23k_decode_list = ['temp_0','temp_1','temp_2','temp_3','temp_4','temp_5','temp_6','temp_7',\ 18 | 'temp_8', 'temp_9', 'temp_10', 'temp_11', 'temp_12','temp_13',#'temp_14', 19 | '1', 'PI', 'PAD', '','EOS'] 20 | 21 | #math23k_decode_list = ['temp_0','temp_1','temp_2','temp_3','1', 'PI', 'PAD', '','EOS'] 22 | for k, v in dataset.items(): 23 | for elem in v['template_text'].split(' '): 24 | if elem not in math23k_vocab_list: 25 | math23k_vocab_list.append(elem) 26 | 27 | self.train_list, self.valid_list, self.test_list = split_by_feilong_23k(dataset) 28 | self.data_set = dataset 29 | self.vocab_list = math23k_vocab_list 30 | self.vocab_dict = dict([(elem, idx) for idx, elem in enumerate(self.vocab_list)]) 31 | self.vocab_len = len(self.vocab_list) 32 | 33 | self.decode_classes_list = math23k_decode_list 34 | self.decode_classes_dict = dict([(elem, idx) for idx, elem in enumerate(self.decode_classes_list)]) 35 | self.decode_classed_len = len(self.decode_classes_list) 36 | 37 | self.decode_emb_idx = [self.vocab_dict[elem] for elem in 
self.decode_classes_list] 38 | 39 | 40 | 41 | 42 | def __len__(self): 43 | return self.data_size 44 | 45 | def _data_batch_preprocess(self, data_batch): 46 | batch_encode_idx = [] 47 | batch_encode_len = [] 48 | batch_encode_num_pos = [] 49 | 50 | batch_decode_idx = [] 51 | batch_decode_emb_idx = [] 52 | batch_decode_len = [] 53 | 54 | batch_idxs = [] 55 | batch_text = [] 56 | batch_num_list = [] 57 | batch_solution = [] 58 | 59 | batch_post_equation = [] 60 | 61 | batch_gd_tree = [] 62 | #batch_mask = [] 63 | 64 | #dict_keys(['numtemp_order', 'index', 'post_template', 'ans', 'num_list', 'template_text', 'mid_template', 'expression', 'text']) 65 | for elem in data_batch: 66 | idx = elem[0] 67 | encode_sen = elem[1]['template_text'][:] 68 | encode_sen_idx = string_2_idx_sen(encode_sen.strip().split(' '), self.vocab_dict) 69 | 70 | batch_encode_idx.append(encode_sen_idx) 71 | batch_encode_len.append(len(encode_sen_idx)) 72 | batch_encode_num_pos.append(elem[1]['num_position'][:]) 73 | 74 | decode_sen = elem[1]['numtemp_order'][:] 75 | #print (decode_sen) 76 | decode_sen.append('EOS') 77 | decode_sen_idx = string_2_idx_sen(decode_sen, self.decode_classes_dict) 78 | decode_sen_emb_idx = [self.decode_emb_idx[elem] for elem in decode_sen_idx] 79 | 80 | batch_decode_idx.append(decode_sen_idx) 81 | batch_decode_emb_idx.append(decode_sen_emb_idx) 82 | batch_decode_len.append(len(decode_sen_idx)) 83 | 84 | batch_idxs.append(idx) 85 | batch_text.append(encode_sen) 86 | batch_num_list.append(elem[1]['num_list'][:]) 87 | batch_solution.append(elem[1]['ans'][:]) 88 | batch_post_equation.append(elem[1]['post_template'][:]) 89 | #print (elem[1]['post_template'][2:]) 90 | 91 | batch_gd_tree.append(elem[1]['gd_tree_list']) 92 | 93 | 94 | max_encode_len = max(batch_encode_len) 95 | batch_encode_pad_idx = [] 96 | #max_decode_len = max(batch_decode_len) 97 | max_decode_len = 22 98 | batch_decode_pad_idx = [] 99 | batch_decode_pad_emb_idx = [] 100 | 101 | 102 | for i in 
range(len(data_batch)): 103 | encode_sen_idx = batch_encode_idx[i] 104 | encode_sen_pad_idx = pad_sen(\ 105 | encode_sen_idx, max_encode_len, self.vocab_dict['PAD']) 106 | batch_encode_pad_idx.append(encode_sen_pad_idx) 107 | 108 | decode_sen_idx = batch_decode_idx[i] 109 | decode_sen_pad_idx = pad_sen(\ 110 | decode_sen_idx, max_decode_len, self.decode_classes_dict['PAD']) 111 | #decode_sen_idx, max_decode_len, self.decode_classes_dict['PAD_token']) 112 | batch_decode_pad_idx.append(decode_sen_pad_idx) 113 | decode_sen_pad_emb_idx = [self.decode_emb_idx[elem] for elem in decode_sen_pad_idx] 114 | batch_decode_pad_emb_idx.append(decode_sen_pad_emb_idx) 115 | 116 | batch_data_dict = dict() 117 | 118 | batch_data_dict['batch_encode_idx'] = batch_encode_idx 119 | batch_data_dict['batch_encode_len'] = batch_encode_len 120 | batch_data_dict['batch_encode_num_pos'] = batch_encode_num_pos 121 | batch_data_dict['batch_encode_pad_idx'] = batch_encode_pad_idx 122 | 123 | batch_data_dict['batch_decode_idx'] = batch_decode_idx 124 | batch_data_dict['batch_decode_len'] = batch_decode_len 125 | batch_data_dict['batch_decode_emb_idx'] = batch_decode_emb_idx 126 | batch_data_dict['batch_decode_pad_idx'] = batch_decode_pad_idx 127 | batch_data_dict['batch_decode_pad_emb_idx'] = batch_decode_pad_emb_idx 128 | 129 | batch_data_dict['batch_index'] = batch_idxs 130 | batch_data_dict['batch_text'] = batch_text 131 | batch_data_dict['batch_num_list'] = batch_num_list 132 | batch_data_dict['batch_solution'] = batch_solution 133 | batch_data_dict['batch_post_equation'] = batch_post_equation 134 | 135 | batch_data_dict['batch_gd_tree'] = batch_gd_tree 136 | 137 | if len(data_batch) != 1: 138 | new_batch_data_dict = self._sorted_batch(batch_data_dict) 139 | else: 140 | new_batch_data_dict = batch_data_dict 141 | return new_batch_data_dict 142 | 143 | def _sorted_batch(self, batch_data_dict): 144 | new_batch_data_dict = dict() 145 | batch_encode_len = 
np.array(batch_data_dict['batch_encode_len']) 146 | sort_idx = np.argsort(-batch_encode_len) 147 | for key, value in batch_data_dict.items(): 148 | new_batch_data_dict[key] = np.array(value)[sort_idx] 149 | return new_batch_data_dict 150 | 151 | def get_batch(self, data_list, batch_size, verbose=0): 152 | #print data_list 153 | batch_num = int(len(data_list)/batch_size)+1 154 | for idx in range(batch_num): 155 | batch_start = idx*batch_size 156 | batch_end = min((idx+1)*batch_size, len(data_list)) 157 | #print batch_start, batch_end, len(data_list) 158 | batch_data_dict = self._data_batch_preprocess(data_list[batch_start: batch_end]) 159 | yield batch_data_dict 160 | 161 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/main.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import time 4 | from tqdm import tqdm_notebook as tqdm 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import Parameter 9 | import torch.nn.functional as F 10 | from torch.utils.data import Dataset 11 | from torch.autograd import Variable 12 | import torch.optim as optim 13 | from torch.nn.utils import clip_grad_norm_ 14 | import logging 15 | import random 16 | import pickle 17 | import sys 18 | 19 | import os 20 | 21 | from trainer import * 22 | from utils import * 23 | from model import * 24 | from data_loader import * 25 | 26 | 27 | import warnings 28 | warnings.filterwarnings("ignore") 29 | 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 32 | 33 | #np.random.seed(123) 34 | #random.seed(123) 35 | 36 | def main(): 37 | 38 | dataset = read_data_json("./data/final_dolphin_data_replicate.json") 39 | #emb_vectors = np.load('./data/emb_100.npy') 40 | #dict_keys(['text', 'ans', 'mid_template', 'num_position', 'post_template', 'num_list', \ 41 | #'template_text', 'expression', 
'numtemp_order', 'index', 'gd_tree_list']) 42 | count = 0 43 | max_l = 0 44 | #norm_templates = read_data_json("./data/post_dup_templates_num.json") 45 | for key, elem in dataset.items(): 46 | #print (elem['post_template']) 47 | #print (norm_templates[key]) 48 | #print () 49 | #elem['post_template'] = norm_templates[key] 50 | elem['gd_tree_list'] = form_gdtree(elem) 51 | if len(elem['gd_tree_list']): 52 | #print (elem.keys()) 53 | #print (elem['text']) 54 | #print (elem['mid_template']) 55 | #print (elem['post_template']) 56 | #print (elem['post_template'][2:]) 57 | l = max([int(i.split('_')[1]) for i in set(elem['post_template']) if 'temp' in i]) 58 | if max_l < l: 59 | max_l = l 60 | count += 1 61 | print (max_l) 62 | #print (elem['gd_tree_list']) 63 | print (count) 64 | 65 | data_loader = DataLoader(dataset) 66 | print ('loading finished') 67 | 68 | if os.path.isfile("./ablation_recursive-Copy1.log"): 69 | os.remove("./ablation_recursive-Copy1.log") 70 | 71 | logger = logging.getLogger() 72 | fhandler = logging.FileHandler(filename='mylog.log', mode='w') 73 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 74 | fhandler.setFormatter(formatter) 75 | logger.addHandler(fhandler) 76 | logger.setLevel(logging.DEBUG) 77 | 78 | 79 | params = {"batch_size": 32, 80 | "start_epoch" : 1, 81 | "n_epoch": 50, # 100, 50 82 | "rnn_classes":4, # 5, 4 83 | "save_file": "model_xxx_att.pkl" 84 | } 85 | encode_params = { 86 | "emb_size": 100, 87 | "hidden_size": 160, # 160, 128 88 | "input_dropout_p": 0.2, 89 | "dropout_p": 0.5, 90 | "n_layers": 2, 91 | "bidirectional": True, 92 | "rnn_cell": None, 93 | "rnn_cell_name": 'lstm', 94 | "variable_lengths_flag": True 95 | } 96 | 97 | #dataset = read_data_json("/home/wanglei/aaai_2019/pointer_math_dqn/dataset/source_2/math23k_final.json") 98 | #emb_vectors = np.load('/home/wanglei/aaai_2019/parsing_for_mwp/data/source_2/emb_100.npy') 99 | #data_loader = DataLoader(dataset) 100 | #print ('loading 
finished') 101 | 102 | 103 | recu_nn = RecursiveNN(data_loader.vocab_len, encode_params['emb_size'], params["rnn_classes"]) 104 | #recu_nn = recu_nn.cuda() 105 | recu_nn = recu_nn.to(device) 106 | self_att_recu_tree = Self_ATT_RTree(data_loader, encode_params, recu_nn) 107 | #self_att_recu_tree = self_att_recu_tree.cuda() 108 | self_att_recu_tree = self_att_recu_tree.to(device) 109 | #for name, params in self_att_recu_tree.named_children(): 110 | # print (name, params) 111 | optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, self_att_recu_tree.parameters()), \ 112 | lr=0.01, momentum=0.9, dampening=0.0) 113 | 114 | trainer = Trainer(data_loader, params) 115 | trainer.train(self_att_recu_tree, optimizer) 116 | 117 | 118 | main() 119 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import time 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import Parameter 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset 10 | from torch.autograd import Variable 11 | import torch.optim as optim 12 | from torch.nn.utils import clip_grad_norm_ 13 | import logging 14 | import random 15 | import pickle 16 | import sys 17 | 18 | from utils import * 19 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 20 | 21 | class BaseRNN(nn.Module): 22 | def __init__(self, vocab_size, emb_size, hidden_size, input_dropout_p, dropout_p, \ 23 | n_layers, rnn_cell_name): 24 | super(BaseRNN, self).__init__() 25 | self.vocab_size = vocab_size 26 | self.emb_size = emb_size 27 | self.hidden_size = hidden_size 28 | self.n_layers = n_layers 29 | self.input_dropout_p = input_dropout_p 30 | self.input_dropout = nn.Dropout(p=input_dropout_p) 31 | self.rnn_cell_name = rnn_cell_name 32 | if rnn_cell_name.lower() == 'lstm': 33 | self.rnn_cell = nn.LSTM 
34 | elif rnn_cell_name.lower() == 'gru': 35 | self.rnn_cell = nn.GRU 36 | else: 37 | raise ValueError("Unsupported RNN Cell: {0}".format(rnn_cell_name)) 38 | self.dropout_p = dropout_p 39 | 40 | def forward(self, *args, **kwargs): 41 | raise NotImplementedError() 42 | 43 | class EncoderRNN(BaseRNN): 44 | def __init__(self, vocab_size, embed_model, emb_size=100, hidden_size=128, \ 45 | input_dropout_p=0, dropout_p=0, n_layers=1, bidirectional=False, \ 46 | rnn_cell=None, rnn_cell_name='gru', variable_lengths_flag=True): 47 | super(EncoderRNN, self).__init__(vocab_size, emb_size, hidden_size, 48 | input_dropout_p, dropout_p, n_layers, rnn_cell_name) 49 | self.variable_lengths_flag = variable_lengths_flag 50 | self.bidirectional = bidirectional 51 | self.embedding = embed_model 52 | if rnn_cell == None: 53 | self.rnn = self.rnn_cell(emb_size, hidden_size, n_layers, 54 | batch_first=True, bidirectional=bidirectional, dropout=dropout_p) 55 | else: 56 | self.rnn = rnn_cell 57 | 58 | def forward(self, input_var, input_lengths=None): 59 | embedded = self.embedding(input_var) 60 | embedded = self.input_dropout(embedded) 61 | #pdb.set_trace() 62 | if self.variable_lengths_flag: 63 | embedded = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, batch_first=True) 64 | output, hidden = self.rnn(embedded) 65 | if self.variable_lengths_flag: 66 | output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) 67 | return output, hidden 68 | 69 | class Attention_1(nn.Module): 70 | def __init__(self, input_size, output_size): 71 | super(Attention_1, self).__init__() 72 | self.linear_out = nn.Linear(input_size, output_size) 73 | #self.mask = Parameter(torch.ones(1), requires_grad=False) 74 | 75 | #def set_mask(self, batch_size, input_length, num_pos): 76 | # self.mask = self.mask.repeat(input_length).unsqueeze(0).repeat(batch_size, 1) 77 | # for mask_i in range(batch_size): 78 | # self.mask[mask_i][num_pos[mask_i]] = 1 79 | def init_mask(self, size_0, size_1, 
input_length): 80 | mask = Parameter(torch.ones(1), requires_grad=False) 81 | mask = mask.repeat(size_1).unsqueeze(0).repeat(size_0, 1) 82 | #for i in range(input_length) 83 | input_index = list(range(input_length)) 84 | for i in range(size_0): 85 | mask[i][input_index] = 0 86 | #print (mask) 87 | mask = mask.byte() 88 | #mask = mask.cuda() 89 | mask = mask.to(device) 90 | return mask 91 | 92 | 93 | 94 | def _forward(self, output, context, input_lengths, mask): 95 | ''' 96 | output: len x hidden_size 97 | context: num_len x hidden_size 98 | input_lengths: torch scalar 99 | ''' 100 | #print (output.size()) torch.Size([5, 256]) 101 | #print (.size()) torch.Size([80, 256]) 102 | #print (input_lengths) 103 | attn = torch.matmul(output, context.transpose(1,0)) 104 | #print (attn.size()) 0 x 1 105 | attn.data.masked_fill_(mask, -float('inf')) 106 | attn = F.softmax(attn, dim=1) 107 | #print (attn) 108 | mix = torch.matmul(attn, context) 109 | #print ("mix:", mix) 110 | #print ("output:", output) 111 | combined = torch.cat((mix, output), dim=1) 112 | #print ("combined:",combined) 113 | output = F.tanh(self.linear_out(combined)) 114 | 115 | #print ("output:",output) 116 | #print ("------------") 117 | #print () 118 | return output, attn 119 | 120 | 121 | 122 | 123 | def forward(self, output, context, num_pos, input_lengths): 124 | ''' 125 | output: decoder, (batch, 1, hiddem_dim2) 126 | context: from encoder, (batch, n, hidden_dim1) 127 | actually, dim2 == dim1, otherwise cannot do matrix multiplication 128 | 129 | ''' 130 | batch_size = output.size(0) 131 | hidden_size = output.size(2) 132 | input_size = context.size(1) 133 | #print ('att:', hidden_size, input_size) 134 | #print ("context", context.size()) 135 | 136 | attn_list = [] 137 | mask_list = [] 138 | output_list = [] 139 | for b_i in range(batch_size): 140 | per_output = output[b_i] 141 | per_num_pos = num_pos[b_i] 142 | #print(context, num_pos) 143 | current_output = per_output[per_num_pos] 144 | per_mask = 
self.init_mask(len(per_num_pos), input_size, input_lengths[b_i]) 145 | mask_list.append(per_mask) 146 | #print ("current_context:", current_context.size()) 147 | per_output, per_attn = self._forward(current_output, context[b_i], input_lengths[b_i], per_mask) 148 | #for p_j in range(len(per_num_pos)): 149 | # current_context = per_context[per_num_pos[p_j]] 150 | # print ("c_context:", current_context.size()) 151 | output_list.append(per_output) 152 | attn_list.append(per_attn) 153 | 154 | 155 | 156 | #self.set_mask(batch_size, input_size, num_pos) 157 | # (b, o, dim) * (b, dim, i) -> (b, o, i) 158 | ''' 159 | attn = torch.bmm(output, context.transpose(1,2)) 160 | if self.mask is not None: 161 | attn.data.masked_fill_(self.mask, -float('inf')) 162 | attn = F.softmax(attn.view(-1, input_size), dim=1).view(batch_size, -1, input_size) 163 | 164 | # (b, o, i) * (b, i, dim) -> (b, o, dim) 165 | mix = torch.bmm(attn, context) 166 | 167 | combined = torch.cat((mix, output), dim=2) 168 | 169 | #output = F.tanh(self.linear_out(combined.view(-1, 2*hidden_size)))\ 170 | .view(batch_size, -1, hidden_size) 171 | 172 | # output: (b, o, dim) 173 | # attn : (b, o, i) 174 | #return output, attn 175 | ''' 176 | return output_list, attn_list 177 | 178 | class RecursiveNN(nn.Module): 179 | def __init__(self, vocabSize, embedSize=100, numClasses=5): 180 | super(RecursiveNN, self).__init__() 181 | #self.embedding = nn.Embedding(int(vocabSize), embedSize) 182 | #self.self-att -> embedding 183 | #self.att 184 | #self.self_att_model = self_att_model 185 | self.W = nn.Linear(2*embedSize, embedSize, bias=True) 186 | self.projection = nn.Linear(embedSize, numClasses, bias=True) 187 | self.activation = F.relu 188 | self.nodeProbList = [] 189 | self.labelList = [] 190 | self.classes = ['+','-','*','/','^'] 191 | 192 | def leaf_emb(self, node, num_embed, look_up): 193 | if node.is_leaf: 194 | #try: 195 | node.node_emb = num_embed[look_up.index(node.root_value)] 196 | # except: 197 | # print(node) 
198 | else: 199 | self.leaf_emb(node.left_tree, num_embed, look_up) 200 | self.leaf_emb(node.right_tree, num_embed, look_up) 201 | 202 | def traverse(self, node): 203 | if node.is_leaf: 204 | currentNode = node.node_emb.unsqueeze(0) 205 | else: 206 | #currentNode = self.activation(self.W(torch.cat((self.traverse(node.left_tree),self.traverse(node.right_tree)),1))) 207 | left_vector = self.traverse(node.left_tree)#.unsqueeze(0) 208 | right_vector = self.traverse(node.right_tree)#.unsqueeze(0) 209 | #print (left_vector) 210 | combined_v = torch.cat((left_vector, right_vector),1) 211 | currentNode = self.activation(self.W(combined_v)) 212 | node.node_emb = currentNode.squeeze(0) 213 | assert node.is_leaf==False, "error is leaf" 214 | #self.nodeProbList.append(self.projection(currentNode)) 215 | proj_probs = self.projection(currentNode) 216 | self.nodeProbList.append(proj_probs) 217 | #node.numclass_probs = proj_probs 218 | self.labelList.append(self.classes.index(node.root_value)) 219 | return currentNode 220 | 221 | def forward(self, tree_node, num_embed, look_up): 222 | 223 | self.nodeProbList = [] 224 | self.labelList = [] 225 | self.leaf_emb(tree_node, num_embed, look_up) 226 | self.traverse(tree_node) 227 | self.labelList = torch.LongTensor(self.labelList) 228 | 229 | #self.labelList = self.labelList.cuda() 230 | #batch_loss_l = torch.FloatTensor([0])[0].cuda() 231 | batch_loss_l = torch.FloatTensor([0])[0].to(device) 232 | self.labelList = self.labelList.to(device) 233 | #print (torch.cat(self.nodeProbList).size()) 234 | #print (self.labelList) 235 | return torch.cat(self.nodeProbList)#, tree_node 236 | 237 | def getLoss_train(self, tree_node, num_embed, look_up): 238 | nodes = self.forward(tree_node, num_embed, look_up) 239 | predictions = nodes.max(dim=1)[1] 240 | loss = F.cross_entropy(input=nodes, target=self.labelList) 241 | 242 | #exit(0) 243 | #print (predictions.size()) 244 | #print () 245 | acc_elemwise, acc_t = self.compute_acc_elemwise(predictions, 
self.labelList) 246 | acc_integrate = self.compute_acc_integrate(predictions, self.labelList) 247 | return predictions, loss, acc_elemwise, acc_t, acc_integrate#, tree_node 248 | 249 | def test_forward(self, tree_node, num_embed, look_up): 250 | nodes = self.forward(tree_node, num_embed, look_up) 251 | predictions = nodes.max(dim=1)[1] 252 | acc_elemwise, acc_t = self.compute_acc_elemwise(predictions, self.labelList) 253 | acc_integrate = self.compute_acc_integrate(predictions, self.labelList) 254 | return predictions, acc_elemwise, acc_t, acc_integrate#, tree_node 255 | 256 | def predict_traverse(self, node): 257 | if node.is_leaf: 258 | currentNode = node.node_emb.unsqueeze(0) 259 | else: 260 | #currentNode = self.activation(self.W(torch.cat((self.traverse(node.left_tree),self.traverse(node.right_tree)),1))) 261 | left_vector = self.predict_traverse(node.left_tree)#.unsqueeze(0) 262 | right_vector = self.predict_traverse(node.right_tree)#.unsqueeze(0) 263 | #print (left_vector) 264 | combined_v = torch.cat((left_vector, right_vector),1) 265 | currentNode = self.activation(self.W(combined_v)) 266 | node.node_emb = currentNode.squeeze(0) 267 | assert node.is_leaf==False, "error is leaf" 268 | #self.nodeProbList.append(self.projection(currentNode)) 269 | proj_probs = self.projection(currentNode) 270 | node_id = proj_probs.max(dim=1)[1] 271 | node_marker = self.classes[node_id] 272 | node.root_value = node_marker 273 | #print ('_++_', node_marker) 274 | #print ("_++_", proj_probs.size()) 275 | #self.nodeProbList.append(proj_probs) 276 | #node.numclass_probs = proj_probs 277 | #self.labelList.append(self.classes.index(node.root_value)) 278 | return currentNode 279 | 280 | def predict(self, tree_node, num_embed, look_up):#, num_list, gold_ans): 281 | self.leaf_emb(tree_node, num_embed, look_up) 282 | self.predict_traverse(tree_node) 283 | post_equ = post_order(tree_node) 284 | #print ('tst:', post_equ) 285 | #pred_ans = post_solver(post_equ) 286 | 287 | return 
tree_node, post_equ#, pred_ans 288 | 289 | def compute_acc_elemwise(self, pred_tensor, label_tensor): 290 | return torch.sum((pred_tensor == label_tensor).int()).item() , len(pred_tensor) 291 | 292 | def compute_acc_integrate(self, pred_tensor, label_tensor): 293 | return 1 if torch.equal(pred_tensor, label_tensor) else 0 294 | 295 | def evaluate(self, trees): 296 | n = nAll = correctRoot = correctAll = 0.0 297 | return correctRoot / n, correctAll/nAll 298 | 299 | def forward_one_layer(self, left_node, right_node): 300 | left_vector = left_node.node_emb.unsqueeze(0) 301 | right_vector = right_node.node_emb.unsqueeze(0) 302 | combined_v = torch.cat((left_vector, right_vector),1) 303 | currentNode = self.activation(self.W(combined_v)) 304 | root_node = BinaryTree() 305 | root_node.is_leaf = False 306 | root_node.node_emb = currentNode.squeeze(0) 307 | proj_probs = self.projection(currentNode) 308 | #print ('recur:', proj_probs) 309 | #print ('r_m',proj_probs.max(1)[1][0].item()) 310 | pred_idx = proj_probs.max(1)[1][0].item() 311 | root_node.root_value = self.classes[pred_idx] 312 | root_node.left_tree = left_node 313 | root_node.right_tree = right_node 314 | return root_node, proj_probs 315 | 316 | class Self_ATT_RTree(nn.Module): 317 | def __init__(self, data_loader, encode_params, RecursiveNN): 318 | super(Self_ATT_RTree, self).__init__() 319 | self.data_loader = data_loader 320 | self.encode_params = encode_params 321 | self.embed_model = nn.Embedding(data_loader.vocab_len, encode_params['emb_size']) 322 | #self.embed_model = self.embed_model.cuda() 323 | self.embed_model = self.embed_model.to(device) 324 | self.encoder = EncoderRNN(vocab_size = data_loader.vocab_len, 325 | embed_model = self.embed_model, 326 | emb_size = encode_params['emb_size'], 327 | hidden_size = encode_params['hidden_size'], 328 | input_dropout_p = encode_params['input_dropout_p'], 329 | dropout_p = encode_params['dropout_p'], 330 | n_layers = encode_params['n_layers'], 331 | bidirectional 
= encode_params['bidirectional'], 332 | rnn_cell = encode_params['rnn_cell'], 333 | rnn_cell_name = encode_params['rnn_cell_name'], 334 | variable_lengths_flag = encode_params['variable_lengths_flag']) 335 | if encode_params['bidirectional'] == True: 336 | self.self_attention = Attention_1(encode_params['hidden_size']*4, encode_params['emb_size']) 337 | else: 338 | self.self_attention = Attention_1(encode_params['hidden_size']*2, encode_params['emb_size']) 339 | 340 | if encode_params['bidirectional'] == True: 341 | decoder_hidden_size = encode_params['hidden_size']*2 342 | 343 | self.recur_nn = RecursiveNN 344 | 345 | self._prepare_for_recur() 346 | self._prepare_for_pointer() 347 | 348 | def _prepare_for_recur(self): 349 | self.fixed_num_symbol = ['1', 'PI'] 350 | self.fixed_num_idx = [self.data_loader.vocab_dict[elem] for elem in self.fixed_num_symbol] 351 | self.fixed_num = torch.LongTensor(self.fixed_num_idx) 352 | 353 | #self.fixed_num = self.fixed_num.cuda() 354 | self.fixed_num = self.fixed_num.to(device) 355 | 356 | self.fixed_num_emb = self.embed_model(self.fixed_num) 357 | 358 | def _prepare_for_pointer(self): 359 | self.fixed_p_num_symbol = ['EOS','1', 'PI'] 360 | self.fixed_p_num_idx = [self.data_loader.vocab_dict[elem] for elem in self.fixed_p_num_symbol] 361 | self.fixed_p_num = torch.LongTensor(self.fixed_p_num_idx) 362 | 363 | #self.fixed_p_num = self.fixed_p_num.cuda() 364 | self.fixed_p_num = self.fixed_p_num.to(device) 365 | 366 | self.fixed_p_num_emb = self.embed_model(self.fixed_p_num) 367 | 368 | def forward(self, input_tensor, input_lengths, num_pos, b_gd_tree): 369 | encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths) 370 | en_output_list, en_attn_list = self.self_attention(encoder_outputs, encoder_outputs, num_pos, input_lengths) 371 | 372 | batch_size = len(en_output_list) 373 | 374 | batch_predictions = [] 375 | batch_acc_e_list = [] 376 | batch_acc_e_t_list = [] 377 | batch_acc_i_list = [] 378 | #batch_loss_l = 
torch.FloatTensor([0])[0].cuda() 379 | batch_loss_l = torch.FloatTensor([0])[0].to(device) 380 | batch_count = 0 381 | for b_i in range(batch_size): 382 | #en_output = en_output_list[b_i] 383 | en_output = torch.cat([self.fixed_num_emb, en_output_list[b_i]], dim=0) 384 | #print (num_pos) 385 | look_up = self.fixed_num_symbol + ['temp_'+str(temp_i) for temp_i in range(len(num_pos[b_i]))] 386 | #print (mid_order(b_gd_tree)) 387 | if len(b_gd_tree[b_i]) == 0: 388 | continue 389 | gd_tree_node = b_gd_tree[b_i][-1] 390 | 391 | #print (en_output) 392 | #print (look_up.index('temp_0')) 393 | #print (en_output[look_up.index("temp_0")]) 394 | #print () 395 | 396 | #print (mid_order(gd_tree_node)) 397 | #self.recur_nn(gd_tree_node, en_output, look_up) 398 | p, l, acc_e, acc_e_t, acc_i = self.recur_nn.getLoss_train(gd_tree_node, en_output, look_up) 399 | #print (p)tensor([ 0, 2]) 400 | #print (l)tensor(1.5615) 401 | #print (l) 402 | #print (post_order(gd_tree_node)) 403 | #get_info_teacher_pointer(gd_tree_node) 404 | 405 | batch_predictions.append(p) 406 | batch_acc_e_list.append(acc_e) 407 | batch_acc_e_t_list.append(acc_e_t) 408 | batch_acc_i_list.append(acc_i) 409 | #batch_loss_l.append(l) 410 | #batch_loss_l += l 411 | batch_loss_l = torch.sum(torch.cat([ batch_loss_l.unsqueeze(0), l.unsqueeze(0)], 0)) 412 | batch_count += 1 413 | 414 | #print (post_order(gd_tree_final)) 415 | #print ("hhhhhh:", en_output) 416 | #print (batch_loss_l) 417 | #torch.cat(batch_loss_l) 418 | #print (torch.sum(torch.cat(batch_loss_l))) 419 | #print () 420 | return batch_predictions, batch_loss_l, batch_count, batch_acc_e_list, batch_acc_e_t_list, batch_acc_i_list 421 | 422 | def test_forward_recur(self, input_tensor, input_lengths, num_pos, b_gd_tree): 423 | encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths) 424 | en_output_list, en_attn_list = self.self_attention(encoder_outputs, encoder_outputs, num_pos, input_lengths) 425 | 426 | batch_size = len(en_output_list) 427 
| 428 | batch_predictions = [] 429 | batch_acc_e_list = [] 430 | batch_acc_e_t_list = [] 431 | batch_acc_i_list = [] 432 | batch_count = 0 433 | for b_i in range(batch_size): 434 | #en_output = en_output_list[b_i] 435 | en_output = torch.cat([self.fixed_num_emb, en_output_list[b_i]], dim=0) 436 | #print (num_pos) 437 | look_up = self.fixed_num_symbol + ['temp_'+str(temp_i) for temp_i in range(len(num_pos[b_i]))] 438 | #print (b_gd_tree[b_i]) 439 | if len(b_gd_tree[b_i]) == 0: 440 | continue 441 | gd_tree_node = b_gd_tree[b_i][-1] 442 | 443 | #print (en_output) 444 | #print (look_up.index('temp_0')) 445 | #print (en_output[look_up.index("temp_0")]) 446 | #print () 447 | 448 | #self.recur_nn(gd_tree_node, en_output, look_up) 449 | p, acc_e, acc_e_t, acc_i = self.recur_nn.test_forward(gd_tree_node, en_output, look_up) 450 | #print (p)tensor([ 0, 2]) 451 | #print (l)tensor(1.5615) 452 | #print (l) 453 | #print (post_order(gd_tree_node)) 454 | #get_info_teacher_pointer(gd_tree_node) 455 | 456 | batch_predictions.append(p) 457 | batch_acc_e_list.append(acc_e) 458 | batch_acc_e_t_list.append(acc_e_t) 459 | batch_acc_i_list.append(acc_i) 460 | #batch_loss_l.append(l) 461 | #batch_loss_l += l 462 | batch_count += 1 463 | 464 | #print (post_order(gd_tree_final)) 465 | #print ("hhhhhh:", en_output) 466 | #print (batch_loss_l) 467 | #torch.cat(batch_loss_l) 468 | #print (torch.sum(torch.cat(batch_loss_l))) 469 | #print () 470 | return batch_predictions, batch_count, batch_acc_e_list, batch_acc_e_t_list, batch_acc_i_list 471 | 472 | def predict_forward_recur(self, input_tensor, input_lengths, num_pos, batch_seq_tree, batch_flags): 473 | encoder_outputs, encoder_hidden = self.encoder(input_tensor, input_lengths) 474 | en_output_list, en_attn_list = self.self_attention(encoder_outputs, encoder_outputs, num_pos, input_lengths) 475 | #print ('-x-x-x-',en_attn_list[0]) 476 | batch_size = len(en_output_list) 477 | batch_pred_tree_node = [] 478 | batch_pred_post_equ = [] 479 | 
#batch_pred_ans = [] 480 | 481 | 482 | for b_i in range(batch_size): 483 | 484 | flag = batch_flags[b_i] 485 | #num_list = batch_num_list[b_i] 486 | if flag == 1: 487 | 488 | 489 | en_output = torch.cat([self.fixed_num_emb, en_output_list[b_i]], dim=0) 490 | 491 | look_up = self.fixed_num_symbol + ['temp_'+str(temp_i) for temp_i in range(len(num_pos[b_i]))] 492 | 493 | seq_node = batch_seq_tree[b_i] 494 | #num_list = batch_num_list[i] 495 | #gold_ans = batch_solution[i] 496 | tree_node, post_equ = self.recur_nn.predict(seq_node, en_output, look_up)#, num_list, gold_ans) 497 | #p, acc_e, acc_e_t, acc_i = self.recur_nn.test_forward(gd_tree_node, en_output, look_up) 498 | 499 | batch_pred_tree_node.append(tree_node) 500 | batch_pred_post_equ.append(post_equ) 501 | 502 | else: 503 | batch_pred_tree_node.append(None) 504 | batch_pred_post_equ.append([]) 505 | 506 | return batch_pred_tree_node, batch_pred_post_equ#,en_attn_list 507 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/replicate.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # File: replicate.py 3 | # Author: Vishal Dey 4 | # Created on: 11 Dec 2019 5 | ####################################################################### 6 | ''' 7 | Synopsis: Duplicates Dolphin300 to Dolphin1500 from Dolphin18k 8 | duplicates templates, substitutes quantities and recompute ans 9 | ''' 10 | 11 | import os 12 | import sys 13 | import json 14 | import numpy as np 15 | from copy import deepcopy 16 | 17 | from utils import * 18 | 19 | np.random.seed(123) 20 | 21 | def read_json(fname): 22 | with open(os.path.join('./data/', fname), 'r') as fp: 23 | return json.load(fp) 24 | 25 | def write_json(fname, data): 26 | with open(os.path.join('./data/', fname), 'w') as fp: 27 | json.dump(data, fp) 28 | 29 | 30 | def duplicate(entry): 31 | text = entry['text'] 32 | 
#equation = entry['expression'] 33 | num_list = entry['num_list'] 34 | post_temp = entry['post_template'] 35 | 36 | dup_num_list = [] 37 | dup_text = [] 38 | 39 | for i in num_list: 40 | if '.' in i: 41 | temp = str(np.random.rand())[:len(i)] 42 | while (temp in num_list or temp in dup_num_list): 43 | temp = str(np.random.rand())[:len(i)] 44 | dup_num_list.append(temp) 45 | else: 46 | temp = str(np.random.randint(123)) 47 | while (temp in num_list or temp in dup_num_list): 48 | temp = str(np.random.randint(10000)) 49 | dup_num_list.append(temp) 50 | 51 | for each in text.split(): 52 | if each in num_list: 53 | dup_text.append(dup_num_list[num_list.index(each)]) 54 | else: 55 | dup_text.append(each) 56 | 57 | post_equ = [] 58 | for i in post_temp: 59 | if 'temp' in i: 60 | num = dup_num_list[int(i.split('_')[1])] 61 | post_equ.append(num) 62 | else: 63 | post_equ.append(i) 64 | try: 65 | ans = post_solver(post_equ) 66 | except: 67 | return None 68 | # print(dup_num_list, ' '.join(dup_text), ans) 69 | return dup_num_list, ' '.join(dup_text), ans 70 | 71 | 72 | def main(): 73 | dolphin = read_json('final_dolphin_data.json') 74 | 75 | n_repeats = 5 76 | max_id = int(max(list(dolphin.keys()))) 77 | 78 | new_dolphin = {} 79 | 80 | for k, v in dolphin.items(): 81 | new_dolphin[k] = deepcopy(v) 82 | 83 | for i in range(n_repeats): 84 | max_id += 1 85 | tmp = duplicate(v) 86 | if tmp: 87 | new_dolphin[str(max_id)] = deepcopy(v) 88 | new_dolphin[str(max_id)]['index'] = str(max_id) 89 | new_dolphin[str(max_id)]['text'] = tmp[1] 90 | new_dolphin[str(max_id)]['num_list'] = tmp[0] 91 | new_dolphin[str(max_id)]['ans'] = tmp[-1] 92 | # print(new_dolphin) 93 | 94 | print(len(new_dolphin.keys())) 95 | 96 | write_json('final_dolphin_data_replicate.json', new_dolphin) 97 | index = list(new_dolphin.keys()) 98 | 99 | valid_size = int(0.2*len(index)) 100 | 101 | valid_ids = np.random.choice(index, valid_size) 102 | test_size = int(valid_size/2) 103 | valid_size -= test_size 104 | 105 | 
print(len(index)-valid_size-test_size, valid_size, test_size) 106 | 107 | # validation ids 108 | write_json('valid_ids_dolphin.json', valid_ids[:valid_size].tolist()) 109 | # test ids 110 | fp1 = open('./data/id_ans_test_dolphin', 'w') 111 | 112 | testtemp = [] 113 | for id in valid_ids[valid_size:]: 114 | print(id + '\t' + new_dolphin[id]['ans'], file=fp1) 115 | tmp = ' '.join(new_dolphin[id]['post_template']) 116 | for ch in ['+', '-', '*', '/']: 117 | tmp = tmp.replace(ch, '^') 118 | testtemp.append([id, tmp.split()]) 119 | 120 | fp1.close() 121 | 122 | # test templates masked 123 | write_json('pg_norm_test_dolphin.json', testtemp) 124 | 125 | 126 | main() 127 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/retrieval_model.py: -------------------------------------------------------------------------------- 1 | ###################################################################### 2 | # File: retrieval_model.py 3 | # Author: Vishal Dey 4 | # Created on: 11 Dec 2019 5 | ####################################################################### 6 | ''' 7 | Synopsis: Create w2v for corresponding string description of each problem 8 | Reads in pretrianed Word2vec vectors and add each word vector to obtain 9 | phrase vectors. 
10 | Either compute TF-IDF / w2v based cosine similarity 11 | ''' 12 | 13 | import os 14 | import json 15 | import sys 16 | 17 | import numpy as np 18 | from sklearn.feature_extraction.text import TfidfVectorizer 19 | from sklearn.metrics.pairwise import linear_kernel, cosine_similarity 20 | from copy import deepcopy 21 | 22 | # for reproducibility 23 | np.random.choice(123) 24 | 25 | 26 | # read json file 27 | def read_json(fname): 28 | with open(os.path.join('./data/', fname), 'r') as fp: 29 | return json.load(fp) 30 | 31 | 32 | # compute TF-iDF based most similar problem 33 | def find_similar_tfidf(tfidf_matrix, tfidf_vector): 34 | cosine_similarities = linear_kernel(tfidf_vector, tfidf_matrix).flatten() 35 | related_docs_indices = [i for i in cosine_similarities.argsort()[::-1]] 36 | #print(related_docs_indices[0], cosine_similarities[related_docs_indices[0]]) 37 | return related_docs_indices[0] 38 | 39 | # compute w2v cosine based most similar problem 40 | def find_similar_w2v(w2vmatrix, w2vemb, query, EMB_DIM, minv=0, maxv=1): 41 | query_emb = [] 42 | for word in query.split(): 43 | if word in w2vemb: 44 | query_emb.append(np.array(list(map(float, w2vemb[word])))) 45 | else: 46 | query_emb.append(np.random.uniform(minv, maxv, EMB_DIM)) 47 | 48 | cosine_similarities = linear_kernel(query_emb, w2vmatrix).flatten() 49 | related_docs_indices = [i for i in cosine_similarities.argsort()[::-1]] 50 | print(related_docs_indices) 51 | return related_docs_indices[0] 52 | 53 | 54 | # load w2v problem 55 | def load_w2v(w2v_file): 56 | EMB_DIM = 0 57 | w2v_emb = {} 58 | minv = sys.float_info.max 59 | maxv = sys.float_info.min 60 | 61 | with open(os.path.join('./w2v', w2v_file), 'r') as fp: 62 | EMB_DIM = int(fp.readline().split()[1]) 63 | for line in fp.readlines(): 64 | tmp = line.split() 65 | tmp = list(map(float, tmp[1:])) 66 | minv = min(minv, min(tmp)) 67 | maxv = max(maxv, max(tmp)) 68 | 69 | print(EMB_DIM, minv, maxv) 70 | return w2v_emb, EMB_DIM, minv, maxv 71 | 72 
# compute w2v matrix (problem x EMB_DIM)
def compute_w2vmatrix(docs, w2v_emb, EMB_DIM, minv=0, maxv=1):
    """Embed each document as the mean of its word vectors.

    Unknown words get a uniform random vector in ``[minv, maxv)``. An empty
    document yields the zero vector. The original started from a Python list,
    so ``+=`` extended the list and ``/=`` then raised TypeError.

    Returns:
        A list with one ``EMB_DIM`` float array per document.
    """
    w2vmatrix = []
    for doc in docs:
        words = doc.split()
        emb = np.zeros(EMB_DIM)
        for word in words:
            if word in w2v_emb:
                emb += np.asarray(w2v_emb[word], dtype=float)
            else:
                emb += np.random.uniform(minv, maxv, EMB_DIM)
        if words:
            emb /= len(words)
        w2vmatrix.append(emb)
    return w2vmatrix


# post fix equation solver
def post_solver(post_equ):
    """Evaluate a postfix equation given as a list of string tokens.

    Operands may end in ``%`` (divided by 100) or be ``'PI'``. Operators are
    ``+ - / *`` and ``^`` (power).

    Returns:
        The result as a string, or ``np.nan`` on division by zero.

    Raises:
        IndexError: on a malformed equation (stack underflow / empty input).
        ValueError: if an operand is not a number.
    """
    stack = []
    op_list = ['+', '-', '/', '*', '^']

    for elem in post_equ:
        if elem not in op_list:
            op_v = elem
            if '%' in op_v:
                op_v = float(op_v[:-1]) / 100.0
            if op_v == 'PI':
                op_v = np.pi
            stack.append(str(op_v))
            continue

        right = float(stack.pop())
        left = float(stack.pop())
        if elem == '+':
            stack.append(str(left + right))
        elif elem == '-':
            stack.append(str(left - right))
        elif elem == '*':
            stack.append(str(left * right))
        elif elem == '/':
            if right == 0:
                return np.nan
            stack.append(str(left / right))
        else:
            stack.append(str(left ** right))
    return stack.pop()


# get post fix equation from the template
def get_equation(template, num_list):
    """Replace every ``temp_<i>`` token of *template* with ``str(num_list[i])``."""
    return [str(num_list[int(x.split('_')[1])]) if 'temp' in x else x
            for x in template]


def main():
    """TF-IDF retrieval baseline.

    For every test problem, reuse the template of the most similar training
    problem, solve it with the test problem's numbers, and report answer
    accuracy. Predictions go to ``pred.txt``.
    """
    math23k = read_json('math23k_final.json')
    #math23k = read_json('final_dolphin_data.json')

    # read in test IDs and answer - gold truth
    with open('./data/id_ans_test', 'r') as fp:
        test_ids = [line.split('\t')[0] for line in fp]
    test_id_set = set(test_ids)
    test_length = len(test_ids)

    train = {}
    test = {}
    corpus = []

    # make training corpus text
    for k, v in math23k.items():
        if k not in test_id_set:
            train[k] = v
            corpus.append(v['template_text'])
        else:
            test[k] = v

    indices = list(train.keys())

    print(len(corpus), test_length)

    # tfidf vectorizer
    vectorizer = TfidfVectorizer()
    corpus_tfidf = vectorizer.fit_transform(corpus)

    # w2v alternative: load_w2v('crawl-300d-2M.vec') + compute_w2vmatrix(...)
    # and find_similar_w2v(...) in place of find_similar_tfidf below.
    similar_indices = {}
    for k, v in test.items():
        query_tfidf = vectorizer.transform([v['template_text']])
        i = find_similar_tfidf(corpus_tfidf, query_tfidf)
        similar_indices[k] = indices[i]

    num_correct = 0
    # outputs the prediction
    with open('pred.txt', 'w') as fout:
        for k, v in similar_indices.items():
            template = math23k[v]['post_template']
            num_list = math23k[k]['num_list']
            ans = math23k[k]['ans']

            # highest number slot the retrieved template refers to (-1 if none)
            max_slot = max((int(s.split('temp_')[1]) for s in set(template) if 'temp' in s),
                           default=-1)
            if max_slot < len(num_list):
                pred_ans = post_solver(get_equation(template, num_list))
                print(math23k[k]['index'], pred_ans, ans, file=fout)
                if np.isclose(float(pred_ans), float(ans)):
                    num_correct += 1
            else:
                print('wrong template prediction: mismatch with len of num list')

    accuracy = float(num_correct) / test_length if test_length else 0.0
    print('Accuracy = ', str(accuracy))


if __name__ == "__main__":
    main()

-------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/trainer.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from model import * 3 | import os 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | class Trainer(): 8 | def __init__(self, data_loader, params): 9 | self.data_loader = data_loader 10 | self.params = params 11 | self.train_len = len(data_loader.train_list) 12 | self.valid_len = len(data_loader.valid_list) 13 | self.test_len = len(data_loader.test_list) 14 | #self.optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, dampening=0.0) 15 | #self.pg_seq = dict(read_data_json("./data/pg_seq_norm_0828.json")) #dict 16 | self.pg_seq = dict(read_data_json("./data/pg_norm_test_dolphin.json")) #dict 17 | 18 | def _train_batch_recur(self, model, batch_encode_pad_idx, batch_encode_num_pos, batch_encode_len, batch_gd_tree): 19 | batch_encode_pad_idx_tensor = torch.LongTensor(batch_encode_pad_idx) 20 | batch_encode_tensor_len = torch.LongTensor(batch_encode_len) 21 | 22 | batch_encode_pad_idx_tensor = batch_encode_pad_idx_tensor.to(device)#cuda() 23 | batch_encode_tensor_len = batch_encode_tensor_len.to(device)#cuda() 24 | #print ("batch_encode_num_pos",batch_encode_num_pos) 25 | b_pred, b_loss, b_count, b_acc_e, b_acc_e_t, b_acc_i = model(batch_encode_pad_idx_tensor, batch_encode_tensor_len, \ 26 | batch_encode_num_pos, batch_gd_tree) 27 | self.optimizer.zero_grad() 28 | #print (b_loss) 29 | b_loss.backward(retain_graph=True) 30 | clip_grad_norm_(model.parameters(), 5, norm_type=2.) 
31 | self.optimizer.step() 32 | return b_pred, b_loss.item(), b_count, b_acc_e, b_acc_e_t, b_acc_i 33 | 34 | def _test_recur(self, model, data_list): 35 | batch_size = self.params['batch_size'] 36 | data_generator = self.data_loader.get_batch(data_list, batch_size) 37 | test_pred = [] 38 | test_count = 0 39 | test_acc_e = [] 40 | test_acc_e_t = [] 41 | test_acc_i = [] 42 | for batch_elem in data_generator: 43 | batch_encode_idx = batch_elem['batch_encode_idx'] 44 | batch_encode_pad_idx = batch_elem['batch_encode_pad_idx'] 45 | batch_encode_num_pos = batch_elem['batch_encode_num_pos'] 46 | batch_encode_len = batch_elem['batch_encode_len'] 47 | 48 | batch_decode_idx = batch_elem['batch_decode_idx'] 49 | 50 | batch_gd_tree = batch_elem['batch_gd_tree'] 51 | 52 | batch_encode_pad_idx_tensor = torch.LongTensor(batch_encode_pad_idx) 53 | batch_encode_tensor_len = torch.LongTensor(batch_encode_len) 54 | 55 | batch_encode_pad_idx_tensor = batch_encode_pad_idx_tensor.to(device)#cuda() 56 | batch_encode_tensor_len = batch_encode_tensor_len.to(device)#cuda() 57 | 58 | b_pred, b_count, b_acc_e, b_acc_e_t, b_acc_i = model.test_forward_recur(batch_encode_pad_idx_tensor, batch_encode_tensor_len, \ 59 | batch_encode_num_pos, batch_gd_tree) 60 | 61 | test_pred+= b_pred 62 | test_count += b_count 63 | test_acc_e += b_acc_e 64 | test_acc_e_t += b_acc_e_t 65 | test_acc_i += b_acc_i 66 | return test_pred, test_count, test_acc_e, test_acc_e_t, test_acc_i 67 | 68 | def predict_joint_batch(self, model, batch_encode_pad_idx, batch_encode_num_pos, batch_encode_len, batch_gd_tree, batch_index, batch_num_list): 69 | batch_encode_pad_idx_tensor = torch.LongTensor(batch_encode_pad_idx) 70 | batch_encode_tensor_len = torch.LongTensor(batch_encode_len) 71 | 72 | batch_encode_pad_idx_tensor = batch_encode_pad_idx_tensor.to(device)#cuda() 73 | batch_encode_tensor_len = batch_encode_tensor_len.to(device)#cuda() 74 | 75 | #alphas = 'abcdefghijklmnopqrstuvwxyz' 76 | alphas = list(map(str, 
list(range(0, 14)))) 77 | batch_seq_tree = [] 78 | batch_flags = [] 79 | batch_num_len = [] 80 | for i in range(len(batch_index)): 81 | index = batch_index[i] 82 | op_template = self.pg_seq[index] 83 | new_op_temps = [] 84 | num_len = 0 85 | #print(op_template) 86 | for temp_elem in op_template: 87 | if 'temp' in temp_elem: 88 | num_idx = alphas.index(temp_elem[5:]) 89 | #num_idx = temp_elem[5:] 90 | #new_op_temps.append('temp_'+str(num_idx)) 91 | new_op_temps.append(temp_elem) 92 | if num_len < num_idx: 93 | num_len = num_idx 94 | else: 95 | new_op_temps.append(temp_elem) 96 | 97 | #print (new_op_temps) 98 | #print (op_template) 99 | #print 100 | #print ("0000", op_template) 101 | try: 102 | temp_tree = construct_tree_opblank(new_op_temps[:]) 103 | except: 104 | #print ("error:", new_op_temps) 105 | temp_tree = construct_tree_opblank(['temp_0', 'temp_1', '^']) 106 | batch_seq_tree.append(temp_tree) 107 | num_list = batch_num_list[i] 108 | if num_len >= len(num_list): 109 | print ('error num len',new_op_temps, num_list) 110 | batch_flags.append(0) 111 | else: 112 | batch_flags.append(1) 113 | 114 | #print(index, temp_tree.__str__()) 115 | #print(new_op_temps, num_list) 116 | 117 | #print('Batch flags: ', batch_flags) 118 | #print(batch_index[0], batch_seq_tree[0].__str__()) 119 | batch_pred_tree_node, batch_pred_post_equ = model.predict_forward_recur(batch_encode_pad_idx_tensor, batch_encode_tensor_len,\ 120 | batch_encode_num_pos, batch_seq_tree, batch_flags) 121 | 122 | return batch_pred_tree_node, batch_pred_post_equ 123 | #model.xxx(batch_encode_pad_idx_tensor, batch_encode_tensor_len, batch_encode_num_pos, batch_gd_tree, batch_flags) 124 | 125 | def predict_joint(self, model): 126 | batch_size = self.params['batch_size'] 127 | data_generator = self.data_loader.get_batch(self.data_loader.test_list, batch_size) 128 | test_pred = [] 129 | test_count = 0 130 | test_acc_e = [] 131 | test_acc_e_t = [] 132 | test_acc_i = [] 133 | 134 | test_temp_acc = 0.0 135 | 
test_ans_acc = 0.0 136 | 137 | save_info = [] 138 | 139 | for batch_elem in data_generator: 140 | batch_encode_idx = batch_elem['batch_encode_idx'][:] 141 | batch_encode_pad_idx = batch_elem['batch_encode_pad_idx'][:] 142 | batch_encode_num_pos = batch_elem['batch_encode_num_pos'][:] 143 | batch_encode_len = batch_elem['batch_encode_len'][:] 144 | 145 | batch_decode_idx = batch_elem['batch_decode_idx'][:] 146 | 147 | batch_gd_tree = batch_elem['batch_gd_tree'][:] 148 | 149 | batch_index = batch_elem['batch_index'] 150 | batch_num_list = batch_elem['batch_num_list'][:] 151 | batch_solution = batch_elem['batch_solution'] 152 | batch_post_equation = batch_elem['batch_post_equation'] 153 | 154 | #b_pred, b_count, b_acc_e, b_acc_e_t, b_acc_i = 155 | batch_pred_tree_node, batch_pred_post_equ = self.predict_joint_batch(model, batch_encode_pad_idx, batch_encode_num_pos,batch_encode_len, batch_gd_tree, batch_index, batch_num_list) 156 | 157 | for i in range(len(batch_solution)): 158 | pred_post_equ = batch_pred_post_equ[i] 159 | #pred_ans = batch_pred_ans[i] 160 | gold_post_equ = batch_post_equation[i] 161 | gold_ans = batch_solution[i] 162 | idx = batch_index[i] 163 | pgseq = self.pg_seq[idx] 164 | num_list = batch_num_list[i] 165 | 166 | #print (pred_post_equ) 167 | #print (pgseq) 168 | #print (gold_post_equ) 169 | #print (num_list) 170 | 171 | if pred_post_equ == []: 172 | pred_ans = -float('inf') 173 | else: 174 | pred_post_equ_ali = [] 175 | for elem in pred_post_equ: 176 | if 'temp' in elem: 177 | num_idx = int(elem[5:]) 178 | num_marker = num_list[num_idx] 179 | pred_post_equ_ali.append(str(num_marker)) 180 | elif 'PI' == elem: 181 | pred_post_equ_ali.append("3.141592653589793") 182 | else: 183 | pred_post_equ_ali.append(elem) 184 | try: 185 | pred_ans = post_solver(pred_post_equ_ali) 186 | except: 187 | pred_ans = -float('inf') 188 | 189 | 190 | if abs(float(pred_ans)-float(gold_ans)) < 1e-5: 191 | test_ans_acc += 1 192 | if ' '.join(pred_post_equ) == ' 
'.join(gold_post_equ): 193 | test_temp_acc += 1 194 | #print (pred_ans, gold_ans) 195 | #print () 196 | save_info.append({"idx":idx, "pred_post_eq":pred_post_equ, "gold_post_equ":gold_post_equ, "pred_ans":pred_ans,"gold_ans": gold_ans}) 197 | 198 | 199 | #test_pred+= b_pred 200 | #test_count += b_count 201 | #test_acc_e += b_acc_e 202 | #test_acc_e_t += b_acc_e_t 203 | #test_acc_i += b_acc_i 204 | print ("final test temp_acc:{}, ans_acc:{}".format(test_temp_acc/self.test_len, test_ans_acc/self.test_len)) 205 | #logging.debug("final test temp_acc:{}, ans_acc:{}".format(test_ans_acc/self.test_len, test_ans_acc/self.test_len)) 206 | write_data_json(save_info, "./result_recur/dolphin_rep/save_info_"+str(test_temp_acc/self.test_len)+"_"+str(test_ans_acc/self.test_len)+".json") 207 | return save_info 208 | 209 | 210 | def _train_recur_epoch(self, model, start_epoch, n_epoch): 211 | batch_size = self.params['batch_size'] 212 | data_loader = self.data_loader 213 | train_list = data_loader.train_list 214 | valid_list = data_loader.valid_list 215 | test_list = data_loader.test_list 216 | 217 | 218 | valid_max_acc = 0 219 | for epoch in range(start_epoch, n_epoch + 1): 220 | epoch_loss = 0 221 | epoch_pred = [] 222 | epoch_count = 0 223 | epoch_acc_e = [] 224 | epoch_acc_e_t = [] 225 | epoch_acc_i = [] 226 | train_generator = self.data_loader.get_batch(train_list, batch_size) 227 | s_time = time.time() 228 | xx = 0 229 | for batch_elem in train_generator: 230 | batch_encode_idx = batch_elem['batch_encode_idx'] 231 | batch_encode_pad_idx = batch_elem['batch_encode_pad_idx'] 232 | batch_encode_num_pos = batch_elem['batch_encode_num_pos'] 233 | batch_encode_len = batch_elem['batch_encode_len'] 234 | 235 | batch_decode_idx = batch_elem['batch_decode_idx'] 236 | 237 | batch_gd_tree = batch_elem['batch_gd_tree'] 238 | 239 | 240 | b_pred, b_loss, b_count, b_acc_e, b_acc_e_t, b_acc_i = self._train_batch_recur(model, batch_encode_pad_idx, \ 241 | batch_encode_num_pos,batch_encode_len, 
batch_gd_tree) 242 | epoch_loss += b_loss 243 | epoch_pred+= b_pred 244 | epoch_count += b_count 245 | epoch_acc_e += b_acc_e 246 | epoch_acc_e_t += b_acc_e_t 247 | epoch_acc_i += b_acc_i 248 | xx += 1 249 | #print (xx) 250 | #print (b_pred) 251 | #break 252 | 253 | 254 | e_time = time.time() 255 | 256 | #print ("ee": epoch_acc_e) 257 | #print ("et": epoch_acc_e_t) 258 | #print ("recur epoch: {}, loss: {}, acc_e: {}, acc_i: {} time: {}".\ 259 | # format(epoch, epoch_loss/epoch_count, sum(epoch_acc_e)*1.0/sum(epoch_acc_e_t), sum(epoch_acc_i)/epoch_count, \ 260 | # (e_time-s_time)/60)) 261 | 262 | valid_pred, valid_count, valid_acc_e, valid_acc_e_t, valid_acc_i = self._test_recur(model, valid_list) 263 | test_pred, test_count, test_acc_e, test_acc_e_t, test_acc_i = self._test_recur(model, test_list) 264 | 265 | #print ('**********1', test_pred) 266 | #print ('**********2', test_count) 267 | #print ('**********3', test_acc_e) 268 | #print ('**********4', test_acc_e_t) 269 | #print ('**********5', test_acc_i) 270 | 271 | print ("RECUR EPOCH: {}, loss: {}, train_acc_e: {}, train_acc_i: {} time: {}".\ 272 | format(epoch, epoch_loss/epoch_count, sum(epoch_acc_e)*1.0/sum(epoch_acc_e_t), sum(epoch_acc_i)/epoch_count, \ 273 | (e_time-s_time)/60)) 274 | 275 | print ("valid_acc_e: {}, valid_acc_i: {}, test_acc_e: {}, test_acc_i: {}".format(sum(valid_acc_e)*1.0/sum(valid_acc_e_t), sum(valid_acc_i)*1.0/valid_count, sum(test_acc_e)*1.0/sum(test_acc_e_t), sum(test_acc_i)*1.0/test_count)) 276 | 277 | 278 | test_acc = sum(valid_acc_i)*1.0/valid_count 279 | if test_acc >= valid_max_acc: 280 | print ("originial", valid_max_acc) 281 | valid_max_acc = test_acc 282 | print ("saving...", valid_max_acc) 283 | if os.path.exists(self.params['save_file']): 284 | os.remove(self.params['save_file']) 285 | torch.save(model, self.params['save_file']) 286 | print ("saveing ok!") 287 | self.predict_joint(model) 288 | 289 | # save recursive data: format: {'id': , 'right_flag':, 'predict_result', 
'ground_result'} 290 | # save joint data ['id':, 'right_flag':, predict_result, ground_result] 291 | 292 | else: 293 | print ("jumping...") 294 | 295 | print () 296 | 297 | 298 | def train(self, model, optimizer): 299 | self.optimizer = optimizer 300 | #self._train_recur_epoch(model, 0, 0) 301 | #self._train_pointer_epoch(model, 0, 0) 302 | #self._train_joint_epoch(model,0, 0) 303 | #self._train_recur_epoch(model, 1, 1) 304 | #self._train_pointer_epoch(model, 1, 1) 305 | #self._train_joint_epoch(model,1, 1) 306 | #self._train_joint_epoch(model,0, 10) 307 | self._train_recur_epoch(model, 0, 100) 308 | -------------------------------------------------------------------------------- /3_T-RNN_&_baselines/src/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import time 4 | 5 | import random 6 | import pickle 7 | import sys 8 | 9 | def read_data_json(filename): 10 | with open(filename, 'r') as f: 11 | return json.load(f) 12 | 13 | def write_data_json(data, filename): 14 | with open(filename, 'w') as f: 15 | json.dump(data, f, indent=4) 16 | 17 | def save_pickle(data, filename): 18 | with open(filename, 'wb') as f: 19 | pickle.dump(data, f) 20 | 21 | def split_by_feilong_23k(data_dict): 22 | #t_path = "./data/id_ans_test" 23 | #v_path = "./data/valid_ids.json" 24 | t_path = "./data/id_ans_test_dolphin" 25 | v_path = "./data/valid_ids_dolphin.json" 26 | valid_ids = read_data_json(v_path) 27 | test_ids = [] 28 | with open(t_path, 'r') as f: 29 | for line in f: 30 | test_id = line.strip().split('\t')[0] 31 | test_ids.append(test_id) 32 | train_list = [] 33 | test_list = [] 34 | valid_list = [] 35 | for key, value in data_dict.items(): 36 | if key in test_ids: 37 | test_list.append((key, value)) 38 | elif key in valid_ids: 39 | valid_list.append((key, value)) 40 | else: 41 | train_list.append((key, value)) 42 | print (len(train_list), len(valid_list), len(test_list)) 43 | return train_list, 
valid_list, test_list 44 | 45 | def string_2_idx_sen(sen, vocab_dict): 46 | #print(sen) 47 | return [vocab_dict[word] for word in sen] 48 | 49 | def pad_sen(sen_idx_list, max_len=115, pad_idx=1): 50 | return sen_idx_list + [pad_idx]*(max_len-len(sen_idx_list)) 51 | 52 | def encoder_hidden_process(encoder_hidden, bidirectional): 53 | if encoder_hidden is None: 54 | return None 55 | if isinstance(encoder_hidden, tuple): 56 | encoder_hidden = tuple([_cat_directions(h, bidirectional) for h in encoder_hidden]) 57 | else: 58 | encoder_hidden = _cat_directions(encoder_hidden, bidirectional) 59 | return encoder_hidden 60 | 61 | def _cat_directions(h, bidirectional): 62 | if bidirectional: 63 | h = torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2) 64 | return h 65 | 66 | def post_solver(post_equ): 67 | stack = [] 68 | op_list = ['+', '-', '/', '*', '^'] 69 | for elem in post_equ: 70 | if elem not in op_list: 71 | op_v = elem 72 | if '%' in op_v: 73 | op_v = float(op_v[:-1])/100.0 74 | stack.append(str(op_v)) 75 | elif elem in op_list: 76 | op_v_1 = stack.pop() 77 | op_v_1 = float(op_v_1) 78 | op_v_2 = stack.pop() 79 | op_v_2 = float(op_v_2) 80 | if elem == '+': 81 | stack.append(str(op_v_2+op_v_1)) 82 | elif elem == '-': 83 | stack.append(str(op_v_2-op_v_1)) 84 | elif elem == '*': 85 | stack.append(str(op_v_2*op_v_1)) 86 | elif elem == '/': 87 | if op_v_1 == 0: 88 | return nan 89 | stack.append(str(op_v_2/op_v_1)) 90 | else: 91 | stack.append(str(op_v_2**op_v_1)) 92 | return stack.pop() 93 | 94 | 95 | class BinaryTree(): 96 | def __init__(self): 97 | self._init_node_info() 98 | 99 | def __str__(self): 100 | return mid_order(self) 101 | 102 | def _init_node_info(self): 103 | self.root_value = None 104 | self.left_tree = None 105 | self.right_tree = None 106 | #self.height = 0 107 | self.is_leaf = True 108 | self.pre_order_list = [] 109 | self.mid_order_list = [] 110 | self.post_order_list = [] 111 | self.node_emb = None 112 | 113 | def _pre_order(root, order_list): 114 | if 
root == None: 115 | return 116 | #print ('pre\t', root.value,) 117 | order_list.append(root.root_value) 118 | _pre_order(root.left_tree, order_list) 119 | _pre_order(root.right_tree, order_list) 120 | 121 | def pre_order(root): 122 | order_list = [] 123 | _pre_order(root, order_list) 124 | #print ("post:\t", order_list) 125 | return order_list 126 | 127 | def _mid_order(root, order_list): 128 | if root == None: 129 | return 130 | _mid_order(root.left_tree, order_list) 131 | #print ('mid\t', root.root_value,) 132 | order_list.append(root.root_value) 133 | _mid_order(root.right_tree, order_list) 134 | 135 | def mid_order(root): 136 | order_list = [] 137 | _mid_order(root, order_list) 138 | #print ("post:\t", order_list) 139 | return order_list 140 | 141 | def _post_order(root, order_list): 142 | if root == None: 143 | return 144 | _post_order(root.left_tree, order_list) 145 | _post_order(root.right_tree, order_list) 146 | #print (root.root_value, ' ->\t', root.node_emb,) 147 | order_list.append(root.root_value) 148 | 149 | def post_order(root): 150 | order_list = [] 151 | #print ("post:") 152 | _post_order(root, order_list) 153 | #print ("post:\t", order_list) 154 | #print () 155 | return order_list 156 | 157 | def construct_tree(post_equ): 158 | stack = [] 159 | op_list = ['+', '-', '/', '*', '^'] 160 | for elem in post_equ: 161 | node = BinaryTree() 162 | node.root_value = elem 163 | if elem in op_list: 164 | node.right_tree = stack.pop() 165 | node.left_tree = stack.pop() 166 | 167 | node.is_leaf = False 168 | 169 | stack.append(node) 170 | else: 171 | stack.append(node) 172 | return stack.pop() 173 | 174 | def form_gdtree(elem): 175 | post_equ = elem['post_template'] 176 | tree = construct_tree(post_equ) 177 | gd_list = [] 178 | def _form_detail(root): 179 | if root == None: 180 | return 181 | _form_detail(root.left_tree) 182 | _form_detail(root.right_tree) 183 | #if root.left_tree != None and root.right_tree != None: 184 | if root.is_leaf == False: 185 | 
gd_list.append(root) 186 | _form_detail(tree) 187 | #print ('+++++++++') 188 | #print (gd_list) 189 | #print (post_equ) 190 | #for elem in gd_list: 191 | # print (post_order(elem)) 192 | # print (elem.root_value) 193 | #print ('---------') 194 | #print () 195 | return gd_list[:] 196 | 197 | def construct_tree_opblank(post_equ): 198 | stack = [] 199 | op_list = ['', '^'] 200 | for elem in post_equ: 201 | node = BinaryTree() 202 | node.root_value = elem 203 | if elem in op_list: 204 | node.right_tree = stack.pop() 205 | node.left_tree = stack.pop() 206 | right_len = post_order(node.right_tree) 207 | left_len = post_order(node.left_tree) 208 | if left_len <= right_len: 209 | node.temp_value = node.left_tree.temp_value 210 | else: 211 | node.temp_value = node.right_tree.temp_value 212 | node.is_leaf = False 213 | 214 | stack.append(node) 215 | else: 216 | node.temp_value = node.root_value 217 | stack.append(node) 218 | return stack.pop() 219 | 220 | def form_seqtree(seq_tree): 221 | #post_equ = elem[u'mask_post_equ_list'][2:] 222 | #tree = construct_tree(post_equ) 223 | seq_list = [] 224 | def _form_detail(root): 225 | if root == None: 226 | return 227 | _form_detail(root.left_tree) 228 | _form_detail(root.right_tree) 229 | #if root.left_tree != None and root.right_tree != None: 230 | if root.is_leaf == False: 231 | seq_list.append(root) 232 | _form_detail(seq_tree) 233 | return seq_list[:] 234 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Soumava Banerjee 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the 
Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | ## Experimental Evaluation of Math Word Problem Solver 2 | 3 | ### Our contributions: 4 | - Creation of a new English corpus of arithmetic word problems involving only {+, -, *, /} and linear variables. We named it Dolphin300; it is a subset of the publicly available Dolphin18k. 5 | - Creation of equation templates and normalization of equations on par with the Math23K dataset [2]. 6 | - Experimental evaluation of T-RNN and retrieval baselines on Math23K, Dolphin300 and Dolphin1500. 7 | 8 | ## Sample data processing and cleaning: 9 | 10 | - What is the value of five times the sum of twice of three-fourths and nine? ==> What is the value of 5 times the sum of 2 of 3/4 and 9? 11 | - please help with math! What percentage of 306,800 equals eight thousands? ==> please help with math! What percentage of 306800 equals 8000? 12 | - help!!!!!!!(please) i cant figure this out!? what is the sum of 4 2/5 and 17 3/7 ? ==> help!!!!!!!(please) i cant figure this out!? what is the sum of 22/5 and 122/7 ? 13 | - math homework help?
question: 15 is 25% of what number? ==> math homework help? question: 15 is 25/100 of what number? 14 | 15 | ## List of folders: 16 | - Web scraping: contains the code (OriginalDataExtractor.py) to scrape math word problems from Yahoo Answers. Basic data cleaning is also carried out (CleanVersionExtractor.py) to get the questions into the desired format. 17 | - Data_Cleaning: contains the code for cleaning the Dolphin dataset. MWP_DataCleaning.py holds all the rule-based and filtering logic for transforming the candidate Dolphin datasets into cleaned templates. Inside the cleaned_data_examples folder, uncleaned_dolphin_data.csv contains the raw data from the Dolphin dataset, and filtered_cleaned_dolphin_data.json contains the filtered, cleaned templates (as JSON) derived from that CSV. 18 | - T-RNN and baselines: contains the T-RNN code and the baseline models; the output folder within it contains runs of the retrieval model (named pred_*.txt) and runs of T-RNN (named run_*.txt). 19 | 20 | 21 | ## Implementation: 22 | 23 | - Implemented in Python >= 3.6 with PyTorch. 24 | - We used part of the T-RNN code [1] and added further implementations for Math23K. 25 | - We used replicate.py and MWP_DataCleaning.py to replicate data and to process the raw, noisy Dolphin18k data. 26 | - Finally, we obtain Dolphin300, and then Dolphin1500 by running replicate.py on Dolphin300. 27 | - Run the T-RNN code as: 28 | $ python "3_T-RNN_&_baselines/src/main.py" 29 | (See the code for details on how to change the input files.) 30 | 31 | ### T-RNN for Math23K 32 | - In the template prediction module, we use a pre-trained word embedding with 128 units, a two-layer Bi-LSTM with 256 hidden units as the encoder, and a two-layer LSTM with 512 hidden units as the decoder. As for the optimizer, we use Adam with the learning rate set to 1e-3, β1 = 0.9 and β2 = 0.99. In the answer generation module, we use an embedding layer with 100 units and a two-layer Bi-LSTM with 160 hidden units.
SGD with a learning rate of 0.01 and a momentum factor of 0.9 is used to optimize this module. In both components, the number of epochs, mini-batch size and dropout rate are set to 100, 32 and 0.5, respectively. 33 | 34 | ### T-RNN for Dolphin1500 & Dolphin300 35 | 36 | - Template prediction module: an embedding layer with 128 units, a two-layer Bi-LSTM with 256 hidden units as the encoder, and a two-layer LSTM with 512 hidden units as the decoder. 37 | Adam optimizer with default parameters. 38 | Answer generation module: an embedding layer with 100 units, a two-layer Bi-LSTM with 160 hidden units, RNN classes = 4. SGD with a learning rate of 0.01 and a momentum factor of 0.9 is used to optimize this module. The number of epochs, mini-batch size and dropout rate are set to 50, 32 and 0.5, respectively. 39 | 40 | 41 | References: 42 | 43 | [1] Lei Wang, Dongxiang Zhang, Jipeng Zhang, Xing Xu, Lianli Gao, Bingtian Dai, and Heng Tao Shen. Template-based math word problem solvers with recursive neural networks. 2019. 44 | 45 | [2] Yan Wang, Xiaojiang Liu, and Shuming Shi. Deep neural solver for math word problems. In Proceedings of the 2017 Conference on Empirical Methods in Natural Language Processing, pages 845–854, 2017. 46 | --------------------------------------------------------------------------------