├── .gitignore ├── LICENSE ├── README.md ├── __init__.py ├── convert_retriever_result.py ├── evaluate.py ├── inference_configs ├── program_generation_inference.yaml ├── question_classification_inference.yaml ├── retriever_inference.yaml └── span_selection_inference.yaml ├── lightning_modules ├── __init__.py ├── callbacks │ ├── __init__.py │ ├── program_generation_save_prediction_callback.py │ ├── question_classification_save_prediction_callback.py │ ├── retriever_save_prediction_callback.py │ └── span_selection_save_prediction_callback.py ├── datasets │ ├── .ipynb_checkpoints │ │ ├── __init__-checkpoint.py │ │ ├── datasets_util-checkpoint.py │ │ └── retriever_reader-checkpoint.py │ ├── __init__.py │ ├── program_generation_reader.py │ ├── question_classification_reader.py │ ├── retriever_reader.py │ └── span_selection_reader.py ├── models │ ├── __init__.py │ ├── program_generation_model.py │ ├── question_classification_model.py │ ├── retriever_model.py │ └── span_selection_model.py └── patches │ ├── __init__.py │ └── patched_loggers.py ├── requirements.txt ├── trainer.py ├── training_configs ├── program_generation_finetuning.yaml ├── question_classification_finetuning.yaml ├── retriever_finetuning.yaml └── span_selection_finetuning.yaml ├── txt_files ├── constant_list.txt └── operation_list.txt └── utils ├── datasets_util.py ├── program_generation_utils.py ├── retriever_utils.py ├── span_selection_utils.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/* 2 | results 3 | .DS_Store 4 | log.txt 5 | *predictions.json 6 | *.pyc 7 | __pycache__/ 8 | *copy* 9 | dataset/* 10 | lightning_logs 11 | .ipynb_checkpoints 12 | output 13 | checkpoints 14 | wandb 15 | *.json 16 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Penn State NLP Group 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MultiHiertt 2 | Data and code for ACL 2022 paper "MultiHiertt: Numerical Reasoning over Multi Hierarchical Tabular and Textual Data" 3 | 4 | 5 | ## Requirements 6 | - python 3.9.7 7 | - pytorch 1.10.2, 8 | - pytorch-lightning 1.5.10 9 | - huggingface transformers 4.18.0 10 | - run `pip install -r requirements.txt` to install rest of the dependencies 11 | 12 | ## Leaderboard 13 | - The leaderboard for the private test data is held on [CodaLab](https://codalab.lisn.upsaclay.fr/competitions/6738) 14 | 15 | ## Main Files Structures 16 | ```shell 17 | dataset/ 18 | training_configs/ & inference_configs/ # Configuration files for training and inference 19 | lightning_modules/: 20 | models/ # Implementation for each module 21 | datasets/ # Dataloaders 22 | callbacks/ # Callbacks for saving predictions 23 | utils/ # Utilities for modules 24 | txt_files/ # Txt files such as constant_list.txt, etc 25 | output/ # Predictions and intermediate results 26 | checkpoint/ 27 | 28 | convert_retriever_result.py # convert inference of Fact Retrieving & Question Type Classification Module into model input of Reasoning Modules. 29 | trainer.py 30 | evaluate.py 31 | ``` 32 | 33 | ## Dataset 34 | The dataset is stored as json files [Download Link](https://drive.google.com/drive/folders/1ituEWZ5F7G9T9AZ0kzZZLrHNhRigHCZJ?usp=sharing), each entry has the following format: 35 | 36 | ``` 37 | "uid": unique example id; 38 | "paragraphs": the list of sentences in the document; 39 | "tables": the list of tables in HTML format in the document; 40 | "table_description": the list of table descriptions for each data cell in tables. Generated by the pre-processing script; 41 | "qa": { 42 | "question": the question; 43 | "answer": the answer; 44 | "program": the reasoning program; 45 | "text_evidence": the list of indices of gold supporting text facts; 46 | "table_evidence": the list of indices of gold supporting table facts; 47 | } 48 | ``` 49 | 50 | ## MT2Net 51 | We provide the model checkpoints in [Hugging Face](https://huggingface.co/datasets/yilunzhao/MultiHiertt/tree/main). Download them (`*.ckpt`) into the directory `checkpoints`. 52 | ### 1. Fact Retrieving & Question Type Classification Module 53 | #### 1.1 Training Stage 54 | - Edit `training_configs/retriever_finetuning.yaml` & `training_configs/question_classification_finetuning.yaml` to set your own project and data path. 55 | - Run the following commands to train the model. 56 | ``` 57 | export PYTHONPATH=`pwd`; python trainer.py {fit, validate} --config training_configs/*_finetuning.yaml 58 | ``` 59 | #### 1.2 Inference Stage 60 | - Edit `inference_configs/retriever_inference.yaml` & `inference_configs/retriever_inference.yaml` to set your own project and data path. 61 | - Run the following commands to get the intermediate results for {Train, Dev, Test} set, respectively. 62 | ``` 63 | export PYTHONPATH=`pwd`; python trainer.py predict --ckpt_path checkpoints/*_model.ckpt --config inference_configs/*_inference.yaml 64 | ``` 65 | where `checkpoints/*_model.ckpt` can be replaced by the checkpoint path from training stage. And the inference set or files should be specified in *_inference.yaml. 66 | 67 | ### 2. 
Reasoning Module Input Generation 68 | - Prepare `output/retriever_output/{train, dev, test}.json` & `output/question_classification_output/{train, dev, test}.json` from Step 1. 69 | 70 | - Run the following commands to convert predictions of Fact Retrieving & Question Type Classification Module for {Train, Dev, Test} into model input of Reasoning Module, respectively. 71 | ``` 72 | python convert_retriever_result.py 73 | ``` 74 | The output files are stored in `dataset/reasoning_module_input`, where `*_training.json` is used for the training stage and `*_inference.json` is used for the inference stage. 75 | 76 | ### 3. Reasoning Module 77 | #### 3.1 Training Stage 78 | - Edit `training_configs/program_generation_finetuning.yaml` & `training_configs/span_selection_finetuning.yaml` to set your own project and data path. 79 | - Run the following commands to train the model and generate the prediction files. 80 | ``` 81 | export PYTHONPATH=`pwd`; python trainer.py fit --config training_configs/*_finetuning.yaml 82 | ``` 83 | #### 3.2 Inference Stage 84 | - Edit `inference_configs/program_generation_inference.yaml` & `inference_configs/span_selection_inference.yaml` to set your own project and data path. 85 | - Run the following commands to get the prediction file for {Dev, Test} set 86 | ``` 87 | export PYTHONPATH=`pwd`; python trainer.py predict --ckpt_path checkpoints/*_model.ckpt --config inference_configs/*_inference.yaml 88 | ``` 89 | where `checkpoints/*_model.ckpt` can be replaced by the checkpoint path from training stage. And the inference set or files should be specified in *_inference.yaml. 90 | 91 | 92 | ## Evaluation 93 | Run the following commands to get the prediction file for {Dev, Test} set (and the performance on the Dev set), respectively. 94 | ``` 95 | python evaluate.py dataset/{test, dev}.json 96 | ``` 97 | The prediction file with the following format will be generated in the directory `output/final_predictions`: 98 | ``` 99 | [ 100 | { 101 | "uid": "bd2ce4dbf70d43e094d93d314b30bd39", 102 | "predicted_ans": "106.0", 103 | "predicted_program": [] 104 | }, 105 | ... 106 | ] 107 | ``` 108 | For test set, Please zip the generated test prediction file `test_predictions.json` into `test_predictions.zip`; and submit `test_predictions.zip` to [CodaLab](https://codalab.lisn.upsaclay.fr/competitions/6738) to get the final score. Please exactly match the filename. 109 | 110 | ## Any Questions? 111 | For any issues or questions, kindly email us at: Yilun Zhao yilun.zhao@yale.edu. 
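## End-to-End Inference Example
The individual commands from Sections 1-3 and Evaluation chain together as follows for the test set. This is a sketch: it assumes the released checkpoints have been downloaded into `checkpoints/` under the filenames used in the config comments, and that each `*_inference.yaml` has been edited to point at the corresponding test files (Step 2 also expects the retriever and question classification outputs for the train and dev sets).
```
export PYTHONPATH=`pwd`

# 1. Fact Retrieving & Question Type Classification (repeat for train/dev/test by editing the yaml files)
python trainer.py predict --ckpt_path checkpoints/retriever_model.ckpt --config inference_configs/retriever_inference.yaml
python trainer.py predict --ckpt_path checkpoints/question_classification_model.ckpt --config inference_configs/question_classification_inference.yaml

# 2. Build reasoning module inputs in dataset/reasoning_module_input/
python convert_retriever_result.py

# 3. Reasoning Modules
python trainer.py predict --ckpt_path checkpoints/program_generation_model.ckpt --config inference_configs/program_generation_inference.yaml
python trainer.py predict --ckpt_path checkpoints/span_selection_model.ckpt --config inference_configs/span_selection_inference.yaml

# Combine the predictions (and, for the dev set, compute Exact Match / F1)
python evaluate.py dataset/test.json

# Package the test predictions for CodaLab submission
cd output/final_predictions && zip test_predictions.zip test_predictions.json
```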
112 | 113 | ## Citation 114 | ``` 115 | @inproceedings{zhao-etal-2022-multihiertt, 116 | title = "{M}ulti{H}iertt: Numerical Reasoning over Multi Hierarchical Tabular and Textual Data", 117 | author = "Zhao, Yilun and 118 | Li, Yunxiang and 119 | Li, Chenying and 120 | Zhang, Rui", 121 | booktitle = "Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)", 122 | month = may, 123 | year = "2022", 124 | address = "Dublin, Ireland", 125 | publisher = "Association for Computational Linguistics", 126 | url = "https://aclanthology.org/2022.acl-long.454", 127 | pages = "6588--6600", 128 | } 129 | ``` 130 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/__init__.py -------------------------------------------------------------------------------- /convert_retriever_result.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import json 4 | import os 5 | import sys 6 | import random 7 | 8 | 9 | ### for single sent retrieve 10 | def convert_train(json_in, json_out, topn, max_len = 256): 11 | with open(json_in) as f_in: 12 | data = json.load(f_in) 13 | 14 | for each_data in data: 15 | try: 16 | gold_inds = [] 17 | cur_len = 0 18 | table_retrieved = each_data["table_retrieved_all"] 19 | text_retrieved = each_data["text_retrieved_all"] 20 | all_retrieved = table_retrieved + text_retrieved 21 | 22 | gold_table_inds = each_data["qa"]["table_evidence"] 23 | gold_text_inds = each_data["qa"]["text_evidence"] 24 | for ind in gold_table_inds: 25 | gold_inds.append(ind) 26 | cur_len += len(each_data["table_description"][ind].split()) 27 | 28 | for ind in gold_text_inds: 29 | gold_inds.append(ind) 30 | try: 31 | cur_len += len(each_data["paragraphs"][ind].split()) 32 | except: 33 | continue 34 | 35 | false_retrieved = [] 36 | for tmp in all_retrieved: 37 | if tmp["ind"] not in gold_inds: 38 | false_retrieved.append(tmp) 39 | 40 | sorted_dict = sorted(false_retrieved, key=lambda kv: kv["score"], reverse=True) 41 | res_n = topn - len(gold_inds) 42 | 43 | other_cands = [] 44 | while res_n > 0 and cur_len < max_len: 45 | next_false_retrieved = sorted_dict.pop(0) 46 | if next_false_retrieved["score"] < 0: 47 | break 48 | 49 | if type(next_false_retrieved["ind"]) == int: 50 | cur_len += len(each_data["paragraphs"][next_false_retrieved["ind"]].split()) 51 | other_cands.append(next_false_retrieved["ind"]) 52 | res_n -= 1 53 | else: 54 | cur_len += len(each_data["table_description"][next_false_retrieved["ind"]].split()) 55 | other_cands.append(next_false_retrieved["ind"]) 56 | res_n -= 1 57 | 58 | # recover the original order in the document 59 | input_inds = gold_inds + other_cands 60 | context = get_context(each_data, input_inds) 61 | each_data["model_input"] = context 62 | del each_data["table_retrieved_all"] 63 | del each_data["text_retrieved_all"] 64 | except: 65 | print(each_data["uid"]) 66 | 67 | with open(json_out, "w") as f: 68 | json.dump(data, f, indent=4) 69 | 70 | def convert_test(retriever_json_in, question_classification_json_in, json_out, topn, max_len = 256): 71 | with open(retriever_json_in) as f_in: 72 | data = json.load(f_in) 73 | 74 | with open(question_classification_json_in) as f_in: 75 | qc_data = json.load(f_in) 76 | 77 | qc_map = {} 78 | 
for example in qc_data: 79 | qc_map[example["uid"]] = example["pred"] 80 | 81 | for each_data in data: 82 | cur_len = 0 83 | table_retrieved = each_data["table_retrieved_all"] 84 | text_retrieved = each_data["text_retrieved_all"] 85 | all_retrieved = table_retrieved + text_retrieved 86 | 87 | cands_retrieved = [] 88 | for tmp in all_retrieved: 89 | cands_retrieved.append(tmp) 90 | 91 | sorted_dict = sorted(cands_retrieved, key=lambda kv: kv["score"], reverse=True) 92 | res_n = topn 93 | 94 | other_cands = [] 95 | 96 | while res_n > 0 and cur_len < max_len: 97 | next_false_retrieved = sorted_dict.pop(0) 98 | if next_false_retrieved["score"] < 0: 99 | break 100 | 101 | if type(next_false_retrieved["ind"]) == int: 102 | cur_len += len(each_data["paragraphs"][next_false_retrieved["ind"]].split()) 103 | other_cands.append(next_false_retrieved["ind"]) 104 | res_n -= 1 105 | else: 106 | cur_len += len(each_data["table_description"][next_false_retrieved["ind"]].split()) 107 | other_cands.append(next_false_retrieved["ind"]) 108 | res_n -= 1 109 | 110 | # recover the original order in the document 111 | input_inds = other_cands 112 | context = get_context(each_data, input_inds) 113 | each_data["model_input"] = context 114 | 115 | each_data["qa"]["predicted_question_type"] = qc_map[each_data["uid"]] 116 | del each_data["table_retrieved_all"] 117 | del each_data["text_retrieved_all"] 118 | 119 | 120 | with open(json_out, "w") as f: 121 | json.dump(data, f, indent=4) 122 | 123 | def get_context(each_data, input_inds): 124 | context = [] 125 | table_sent_map = get_table_sent_map(each_data["paragraphs"]) 126 | inds_map = {} 127 | for ind in input_inds: 128 | if type(ind) == str: 129 | table_ind = int(ind.split("-")[0]) 130 | sent_ind = table_sent_map[table_ind] 131 | if sent_ind not in inds_map: 132 | inds_map[sent_ind] = [ind] 133 | else: 134 | if type(inds_map[sent_ind]) == int: 135 | inds_map[sent_ind] = [ind] 136 | else: 137 | inds_map[sent_ind].append(ind) 138 | else: 139 | if ind not in inds_map: 140 | inds_map[ind] = ind 141 | 142 | for sent_ind in sorted(inds_map.keys()): 143 | if type(inds_map[sent_ind]) != list: 144 | context.append(sent_ind) 145 | else: 146 | for table_ind in sorted(inds_map[sent_ind]): 147 | context.append(table_ind) 148 | 149 | return context 150 | 151 | def get_table_sent_map(paragraphs): 152 | table_index = 0 153 | table_sent_map = {} 154 | for i, sent in enumerate(paragraphs): 155 | if sent.startswith("## Table "): 156 | table_sent_map[table_index] = i 157 | table_index += 1 158 | return table_sent_map 159 | 160 | 161 | 162 | if __name__ == '__main__': 163 | 164 | json_dir_in = "output/retriever_output" 165 | question_classification_json_dir_in = "output/question_classification_output" 166 | json_dir_out = "dataset/reasoning_module_input" 167 | os.makedirs(json_dir_out, exist_ok = True) 168 | 169 | topn, max_len = 10, 256 170 | 171 | mode_names = ["train", "test", "dev"] 172 | for mode in mode_names: 173 | json_in = os.path.join(json_dir_in, f"{mode}.json") 174 | question_classification_json_in = os.path.join(question_classification_json_dir_in, f"{mode}.json") 175 | json_out_train = os.path.join(json_dir_out, mode + "_training.json") 176 | json_out_inference = os.path.join(json_dir_out, mode + "_inference.json") 177 | 178 | if mode == "train": 179 | convert_train(json_in, json_out_train, topn, max_len) 180 | if mode == "dev": 181 | convert_train(json_in, json_out_train, topn, max_len) 182 | convert_test(json_in, question_classification_json_in, json_out_inference, topn, 
max_len) 183 | elif mode == "test": 184 | convert_test(json_in, question_classification_json_in, json_out_inference, topn, max_len) 185 | 186 | print(f"Convert {mode} set done") 187 | 188 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import json, os, re 2 | from utils.span_selection_utils import * 3 | from utils.program_generation_utils import * 4 | import math 5 | 6 | def evaluate_program_result(pred_prog, gold_prog): 7 | ''' 8 | execution acc 9 | execution acc = exact match = f1 10 | ''' 11 | invalid_flag, exe_res = eval_program(pred_prog) 12 | 13 | gold = program_tokenization(gold_prog) 14 | invalid_flag, exe_gold_res = eval_program(gold) 15 | 16 | if invalid_flag: 17 | print(gold) 18 | if exe_res == exe_gold_res: 19 | exe_acc = 1 20 | else: 21 | exe_acc = 0 22 | 23 | return exe_acc, exe_acc 24 | 25 | def evaluate_span_program_result(span_ans, prog_ans): 26 | span_ans = str(span_ans) 27 | if str_to_num(span_ans) != "n/a": 28 | span_ans = str_to_num(span_ans) 29 | if math.isclose(prog_ans, span_ans, abs_tol= min(abs(min(prog_ans, span_ans) / 1000), 0.1)): 30 | exact_match, f1 = 1, 1 31 | else: 32 | exact_match, f1 = 0, 0 33 | else: 34 | exact_match, f1 = get_span_selection_metrics(span_ans, str(prog_ans)) 35 | return exact_match, f1 36 | 37 | def combine_predictions(span_selection_json_in, program_generation_json_in, test_file_json_in, output_dir): 38 | span_selection_data = json.load(open(span_selection_json_in)) 39 | program_generation_data = json.load(open(program_generation_json_in)) 40 | orig_data = json.load(open(test_file_json_in)) 41 | 42 | prediction_dict = {} 43 | for example in span_selection_data + program_generation_data: 44 | uid = example["uid"] 45 | pred_ans = example["predicted_ans"] 46 | pred_program = example["predicted_program"] 47 | 48 | if uid in prediction_dict: 49 | print(f"uid {uid} already in prediction_dict") 50 | else: 51 | prediction_dict[uid] = { 52 | "uid": uid, 53 | "predicted_ans": pred_ans, 54 | "predicted_program": pred_program 55 | } 56 | 57 | output_data = [] 58 | for example in orig_data: 59 | output_data.append(prediction_dict[example["uid"]]) 60 | 61 | mode = "dev" if "dev" in test_file_json_in else "test" 62 | output_file = os.path.join(output_dir, f"{mode}_predictions.json") 63 | json.dump(output_data, open(output_file, "w"), indent=4) 64 | 65 | print(f"{mode}: Combine {len(span_selection_data)} examples from span selection output, {len(program_generation_data)} examples from program generation output. 
The prediction are generated in {output_file}") 66 | 67 | return prediction_dict 68 | 69 | 70 | def evaluation_prediction_result(span_selection_json_in, program_generation_json_in, test_file_json_in, output_dir): 71 | exact_match_total, f1_total = 0, 0 72 | prediction_dict = combine_predictions(span_selection_json_in, program_generation_json_in, test_file_json_in, output_dir) 73 | 74 | if "test" in test_file_json_in: 75 | print("Please submit the test prediction file to CodaLab to get the results") 76 | return 77 | 78 | orig_data = json.load(open(test_file_json_in)) 79 | num_examples = len(orig_data) 80 | 81 | for example in orig_data: 82 | uid = example["uid"] 83 | pred = prediction_dict[uid] 84 | 85 | gold_prog = example["qa"]["program"] 86 | gold_ans = example["qa"]["answer"] 87 | 88 | # both program generation 89 | if pred["predicted_program"] and gold_prog: 90 | exact_acc, f1_acc = evaluate_program_result(pred["predicted_program"], gold_prog) 91 | # both span selection 92 | elif not pred["predicted_program"] and not gold_prog: 93 | exact_acc, f1_acc = get_span_selection_metrics(pred["predicted_ans"], gold_ans) 94 | # gold is span selection, pred is program generation 95 | elif not pred["predicted_program"] and gold_prog: 96 | exact_acc, f1_acc = evaluate_span_program_result(span_ans = pred["predicted_ans"], prog_ans = gold_ans) 97 | # gold is program generation, pred is span selection 98 | elif pred["predicted_program"] and not gold_prog: 99 | exact_acc, f1_acc = evaluate_span_program_result(span_ans = gold_ans, prog_ans = pred["predicted_ans"]) 100 | 101 | exact_match_total += exact_acc 102 | f1_total += f1_acc 103 | exact_match_score, f1_score = exact_match_total / num_examples, f1_total / num_examples 104 | print(f"Exact Match Score: {exact_match_score}, F1 Score: {f1_score}") 105 | 106 | return exact_match_score, f1_score 107 | 108 | 109 | 110 | if __name__ == '__main__': 111 | test_path = sys.argv[1] 112 | if "dev" in test_path: 113 | mode = "dev" 114 | elif "test" in test_path: 115 | mode = "test" 116 | else: 117 | raise ValueError("Cannot recognize the file name") 118 | 119 | output_dir = "output" 120 | span_selection_dir = "span_selection_output" 121 | program_generation_dir = "program_generation_output" 122 | 123 | span_selection_json_in = os.path.join(output_dir, span_selection_dir, f"{mode}_predictions.json") 124 | program_generation_json_in = os.path.join(output_dir, program_generation_dir, f"{mode}_predictions.json") 125 | test_file_json_in = os.path.join("dataset", f"{mode}.json") 126 | 127 | prediction_output_dir = os.path.join(output_dir, "final_predictions") 128 | os.makedirs(prediction_output_dir, exist_ok=True) 129 | evaluation_prediction_result(span_selection_json_in, program_generation_json_in, test_file_json_in, prediction_output_dir) 130 | -------------------------------------------------------------------------------- /inference_configs/program_generation_inference.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0] 4 | callbacks: 5 | - class_path: lightning_modules.callbacks.program_generation_save_prediction_callback.SavePredictionCallback 6 | init_args: 7 | test_set: &prediction_set test # {dev, test} 8 | input_dir: &input_dir_path dataset/reasoning_module_input 9 | output_dir: output/program_generation_output 10 | model_name: &transformer roberta-base 11 | program_length: &prog_len 30 12 | input_length: &input_max_len 512 13 | entity_name: &selected_entity_name 
predicted_question_type 14 | 15 | accelerator: gpu 16 | strategy: ddp_find_unused_parameters_false 17 | 18 | model: 19 | class_path: lightning_modules.models.program_generation_model.ProgramGenerationModel 20 | init_args: 21 | model_name: *transformer 22 | program_length: *prog_len 23 | input_length: *input_max_len 24 | max_step_ind: 11 25 | dropout_rate: 0.1 26 | num_decoder_layers: 1 27 | n_best_size: 20 28 | sep_attention: True 29 | layer_norm: True 30 | optimizer: 31 | init_args: 32 | lr: 0.0 33 | lr_scheduler: 34 | name: linear 35 | init_args: 36 | num_warmup_steps: 100 37 | num_training_steps: 10000 38 | test_set: *prediction_set 39 | entity_name: *selected_entity_name 40 | input_dir: *input_dir_path 41 | 42 | data: 43 | class_path: lightning_modules.datasets.program_generation_reader.ProgramGenerationPredictionDataModule 44 | init_args: 45 | model_name: *transformer 46 | max_seq_length: 512 47 | max_program_length: 30 48 | batch_size: 64 49 | test_file_path: ./dataset/reasoning_module_input/test_inference.json # also change line 7 50 | test_max_instances: -1 51 | entity_name: *selected_entity_name 52 | 53 | # clear; export PYTHONPATH=`pwd`; python trainer.py predict --ckpt_path checkpoints/program_generation_model.ckpt --config inference_configs/program_generation_inference.yaml 54 | -------------------------------------------------------------------------------- /inference_configs/question_classification_inference.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0] 4 | callbacks: 5 | - class_path: lightning_modules.callbacks.question_classification_save_prediction_callback.SavePredictionCallback 6 | init_args: 7 | test_set: &prediction_set test 8 | input_dir: &input_dir_path dataset 9 | output_dir: output/question_classification_output 10 | 11 | accelerator: gpu 12 | strategy: ddp_find_unused_parameters_false 13 | 14 | 15 | model: 16 | class_path: lightning_modules.models.question_classification_model.QuestionClassificationModel 17 | init_args: 18 | model_name: &transformer roberta-base 19 | optimizer: 20 | init_args: 21 | lr: 0.0 22 | lr_scheduler: 23 | name: linear 24 | init_args: 25 | num_warmup_steps: 100 26 | num_training_steps: 10000 27 | test_set: *prediction_set 28 | 29 | data: 30 | class_path: lightning_modules.datasets.question_classification_reader.QuestionClassificationPredictionDataModule 31 | init_args: 32 | model_name: *transformer 33 | batch_size: 64 34 | test_file_path: dataset/test.json # {dataset/test.json, dataset/dev.json}, also change line 7 35 | test_max_instances: -1 36 | 37 | # for inference: 38 | # clear; export PYTHONPATH=`pwd`; python trainer.py predict --ckpt_path checkpoints/question_classification_model.ckpt --config inference_configs/question_classification_inference.yaml 39 | -------------------------------------------------------------------------------- /inference_configs/retriever_inference.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0] 4 | callbacks: 5 | - class_path: lightning_modules.callbacks.retriever_save_prediction_callback.SavePredictionCallback 6 | init_args: 7 | test_set: &prediction_set test 8 | input_dir: &input_dir_path dataset 9 | output_dir: output/retriever_output 10 | 11 | accelerator: gpu 12 | strategy: ddp 13 | 14 | model: 15 | class_path: lightning_modules.models.retriever_model.RetrieverModel 16 | init_args: 17 | transformer_model_name: &transformer 
roberta-base 18 | topn: 10 19 | dropout_rate: 0.1 20 | optimizer: 21 | init_args: 22 | lr: 0.0 23 | lr_scheduler: 24 | name: linear 25 | init_args: 26 | num_warmup_steps: 100 27 | num_training_steps: 10000∂ 28 | 29 | data: 30 | class_path: lightning_modules.datasets.retriever_reader.RetrieverPredictionDataModule 31 | init_args: 32 | transformer_model_name: *transformer 33 | batch_size: 512 34 | num_workers: 8 35 | test_file_path: dataset/test.json # {train, dev, test}, also change line 7 36 | test_max_instances: -1 37 | 38 | # for inference: 39 | # clear; export PYTHONPATH=`pwd`; python trainer.py predict --ckpt_path checkpoints/retriever_model.ckpt --config inference_configs/retriever_inference.yaml 40 | -------------------------------------------------------------------------------- /inference_configs/span_selection_inference.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0] 4 | callbacks: 5 | - class_path: lightning_modules.callbacks.span_selection_save_prediction_callback.SavePredictionCallback 6 | init_args: 7 | test_set: &prediction_set test # {dev, test}, also change line 34 8 | input_dir: &input_dir_path dataset/reasoning_module_input 9 | output_dir: output/span_selection_output 10 | 11 | accelerator: gpu 12 | strategy: ddp_find_unused_parameters_false 13 | 14 | model: 15 | class_path: lightning_modules.models.span_selection_model.SpanSelectionModel 16 | init_args: 17 | model_name: &transformer t5-base 18 | optimizer: 19 | init_args: 20 | lr: 0.0 21 | lr_scheduler: 22 | name: linear 23 | init_args: 24 | num_warmup_steps: 100 25 | num_training_steps: 10000 26 | test_set: test_inference.json # no usage here 27 | input_dir: *input_dir_path 28 | 29 | data: 30 | class_path: lightning_modules.datasets.span_selection_reader.SpanSelectionInferenceDataModule 31 | init_args: 32 | model_name: *transformer 33 | batch_size: 128 34 | test_file_path: ./dataset/reasoning_module_input/test_inference.json 35 | test_max_instances: -1 36 | entity_name: predicted_question_type 37 | 38 | 39 | # For Inference 40 | # clear; export PYTHONPATH=`pwd`; python trainer.py predict --ckpt_path checkpoints/span_selection_model.ckpt --config inference_configs/span_selection_inference.yaml -------------------------------------------------------------------------------- /lightning_modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/lightning_modules/__init__.py -------------------------------------------------------------------------------- /lightning_modules/callbacks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/lightning_modules/callbacks/__init__.py -------------------------------------------------------------------------------- /lightning_modules/callbacks/program_generation_save_prediction_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from typing import Any, Dict, Optional, List 7 | from pytorch_lightning.callbacks import Callback 8 | from pathlib import Path 9 | from transformers import AutoConfig, AutoTokenizer, AutoModel 10 | from utils.program_generation_utils import * 11 | from utils.utils 
import * 12 | 13 | op_list, const_list = get_op_const_list() 14 | 15 | class SavePredictionCallback(Callback): 16 | def __init__(self, test_set: str, input_dir: str, output_dir: str, model_name: str, program_length: int, input_length: int, entity_name: str): 17 | self.test_set = test_set 18 | self.input_dir = input_dir 19 | self.output_dir = output_dir 20 | self.predictions = [] 21 | 22 | self.model_name = model_name 23 | self.program_length = program_length 24 | self.input_length = input_length 25 | self.entity_name = entity_name 26 | 27 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 28 | 29 | def on_predict_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", 30 | outputs: List[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None: 31 | self.predictions.extend(outputs) 32 | 33 | def on_predict_epoch_end(self, trainer, pl_module, outputs) -> None: 34 | test_file = os.path.join(self.input_dir, f"{self.test_set}_inference.json") 35 | 36 | with open(test_file) as input_file: 37 | input_data = json.load(input_file) 38 | 39 | data_ori = [] 40 | for entry in input_data: 41 | example = read_mathqa_entry(entry, self.tokenizer, self.entity_name) 42 | if example: 43 | data_ori.append(example) 44 | 45 | kwargs = { 46 | "examples": data_ori, 47 | "tokenizer": self.tokenizer, 48 | "max_seq_length": self.input_length, 49 | "max_program_length": self.program_length, 50 | "is_training": False, 51 | "op_list": op_list, 52 | "op_list_size": len(op_list), 53 | "const_list": const_list, 54 | "const_list_size": len(const_list), 55 | "verbose": True 56 | } 57 | 58 | data = convert_examples_to_features(**kwargs) 59 | 60 | all_results = [] 61 | 62 | 63 | for output_dict in self.predictions: 64 | all_results.append( 65 | RawResult( 66 | unique_id=output_dict["unique_id"], 67 | logits=output_dict["logits"], 68 | loss=None 69 | )) 70 | 71 | all_predictions, all_nbest = compute_predictions( 72 | data_ori, 73 | data, 74 | all_results, 75 | n_best_size=1, 76 | max_program_length=self.program_length, 77 | tokenizer=self.tokenizer, 78 | op_list=op_list, 79 | op_list_size=len(op_list), 80 | const_list=const_list, 81 | const_list_size=len(const_list)) 82 | 83 | output_data = [] 84 | for i in all_nbest: 85 | pred = all_nbest[i][0] 86 | uid = pred["id"] 87 | pred_prog = pred["pred_prog"] 88 | invalid_flag, pred_ans = eval_program(pred_prog) 89 | if invalid_flag == 1: 90 | pred_ans = -float("inf") 91 | output_data.append({"uid": uid, "predicted_ans": pred_ans, "predicted_program": pred_prog}) 92 | 93 | os.makedirs(self.output_dir, exist_ok=True) 94 | 95 | output_file = os.path.join(self.output_dir, f"{self.test_set}_predictions.json") 96 | json.dump(output_data, open(output_file, "w"), indent = 4) 97 | print(f"Predictions saved to {output_file}") 98 | # reset the predictions 99 | self.predictions = [] 100 | -------------------------------------------------------------------------------- /lightning_modules/callbacks/question_classification_save_prediction_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from typing import Any, Dict, Optional, List 7 | from pytorch_lightning.callbacks import Callback 8 | from pathlib import Path 9 | 10 | 11 | class SavePredictionCallback(Callback): 12 | def __init__(self, test_set: str, input_dir: str, output_dir: str): 13 | self.test_set = test_set 14 | self.input_dir = input_dir 15 | self.output_dir = 
output_dir 16 | self.predictions = [] 17 | 18 | def on_predict_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", 19 | outputs: List[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None: 20 | preds = outputs["preds"].detach().cpu().numpy() 21 | 22 | for i, uid in enumerate(outputs["uids"]): 23 | result = { 24 | "uid": uid, 25 | "pred": "arithmetic" if int(preds[i]) == 1 else "span_selection", 26 | } 27 | self.predictions.append(result) 28 | 29 | def on_predict_epoch_end(self, trainer, pl_module, outputs) -> None: 30 | # save the predictions 31 | os.makedirs(self.output_dir, exist_ok=True) 32 | output_prediction_file = os.path.join(self.output_dir, f"{self.test_set}.json") 33 | 34 | json.dump(self.predictions, open(output_prediction_file, "w"), indent = 4) 35 | print(f"generate {self.test_set} inference file in {output_prediction_file}") 36 | 37 | self.predictions = [] 38 | -------------------------------------------------------------------------------- /lightning_modules/callbacks/retriever_save_prediction_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from typing import Any, Dict, Optional, List 7 | from pytorch_lightning.callbacks import Callback 8 | from pathlib import Path 9 | from utils.retriever_utils import * 10 | 11 | class SavePredictionCallback(Callback): 12 | def __init__(self, test_set: str, input_dir: str, output_dir: str): 13 | self.test_set = test_set 14 | self.input_dir = input_dir 15 | self.output_dir = output_dir 16 | self.predictions = [] 17 | 18 | def on_predict_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", 19 | outputs: List[Dict[str, Any]], batch: Any, batch_idx: int, dataloader_idx: int) -> None: 20 | self.predictions.extend(outputs) 21 | 22 | def on_predict_epoch_end(self, trainer, pl_module, outputs) -> None: 23 | # save the predictions 24 | all_logits = [] 25 | all_filename_id = [] 26 | all_ind = [] 27 | for output_dict in self.predictions: 28 | all_logits.append(output_dict["logits"]) 29 | all_filename_id.append(output_dict["filename_id"]) 30 | all_ind.append(output_dict["ind"]) 31 | 32 | test_file = os.path.join(self.input_dir, f"{self.test_set}.json") 33 | 34 | os.makedirs(self.output_dir, exist_ok=True) 35 | output_prediction_file = os.path.join(self.output_dir, f"{self.test_set}.json") 36 | 37 | retrieve_inference(all_logits, all_filename_id, all_ind, output_prediction_file, test_file) 38 | print(f"generate {self.test_set} inference file in {output_prediction_file}") 39 | 40 | self.predictions = [] 41 | -------------------------------------------------------------------------------- /lightning_modules/callbacks/span_selection_save_prediction_callback.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import pytorch_lightning as pl 5 | 6 | from typing import Any, Dict, Optional, List 7 | from pytorch_lightning.callbacks import Callback 8 | from pathlib import Path 9 | from utils.retriever_utils import * 10 | 11 | class SavePredictionCallback(Callback): 12 | def __init__(self, test_set: str, input_dir: str, output_dir: str): 13 | self.test_set = test_set 14 | self.input_dir = input_dir 15 | self.output_dir = output_dir 16 | self.predictions = [] 17 | 18 | def on_predict_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", 19 | outputs: List[Dict[str, Any]], 
batch: Any, batch_idx: int, dataloader_idx: int) -> None: 20 | self.predictions.extend(outputs) 21 | 22 | def on_predict_epoch_end(self, trainer, pl_module, outputs) -> None: 23 | # save the predictions 24 | all_filename_id = [] 25 | all_preds = [] 26 | for output_dict in self.predictions: 27 | pred = output_dict["preds"] 28 | unique_id = output_dict["uid"] 29 | 30 | all_filename_id.append(unique_id) 31 | all_preds.append(pred) 32 | 33 | output_data = [] 34 | for filename_id, pred in zip(all_filename_id, all_preds): 35 | output_example = { 36 | "uid": filename_id, 37 | "predicted_ans": pred, 38 | "predicted_program": [] 39 | } 40 | output_data.append(output_example) 41 | 42 | os.makedirs(self.output_dir, exist_ok=True) 43 | output_prediction_file = os.path.join(self.output_dir, f"{self.test_set}_predictions.json") 44 | json.dump(output_data, open(output_prediction_file, "w"), indent=4) 45 | print(f"generate {self.test_set}.json file in {output_prediction_file}") 46 | 47 | self.predictions = [] 48 | -------------------------------------------------------------------------------- /lightning_modules/datasets/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/lightning_modules/datasets/.ipynb_checkpoints/__init__-checkpoint.py -------------------------------------------------------------------------------- /lightning_modules/datasets/.ipynb_checkpoints/datasets_util-checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import io, tokenize, re 3 | import ast, astunparse 4 | 5 | from typing import Tuple, Optional, List, Union 6 | 7 | def right_pad_sequences(sequences: List[torch.Tensor], batch_first: bool = True, padding_value: Union[int, bool] = 0, 8 | max_len: int = -1, device: torch.device = None) -> torch.Tensor: 9 | assert all([len(seq.shape) == 1 for seq in sequences]) 10 | max_len = max_len if max_len > 0 else max(len(s) for s in sequences) 11 | device = device if device is not None else sequences[0].device 12 | 13 | padded_seqs = [] 14 | for seq in sequences: 15 | padded_seqs.append(torch.cat(seq, (torch.full((max_len - seq.shape[0],), padding_value, dtype=torch.long).to(device)))) 16 | return torch.stack(padded_seqs) -------------------------------------------------------------------------------- /lightning_modules/datasets/.ipynb_checkpoints/retriever_reader-checkpoint.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | import os 5 | import torch 6 | 7 | from typing import Dict, Iterable, List, Any, Optional, Union 8 | 9 | from pytorch_lightning import LightningDataModule 10 | from torch.utils.data import Dataset 11 | 12 | from datasets_util import right_pad_sequences 13 | 14 | from transformers import AutoTokenizer 15 | import retriever_utils as retriever_utils 16 | from retriever_utils import * 17 | from utils import * 18 | from torch.utils.data import DataLoader 19 | 20 | os.environ['TOKENIZERS_PARALLELISM']='0' 21 | 22 | class MHQADataset(Dataset): 23 | def __init__( 24 | self, 25 | transformer_model_name: str, 26 | file_path: str, 27 | max_instances: int, 28 | mode: str = "train", 29 | **kwargs): 30 | super().__init__(**kwargs) 31 | 32 | assert mode in ["train", "test", "valid"] 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(transformer_model_name) 35 | 
36 | self.max_instances = max_instances 37 | self.mode = mode 38 | self.instances = self.read(file_path, self.tokenizer) 39 | 40 | 41 | def read(self, input_path: str, tokenizer) -> Iterable[Dict[str, Any]]: 42 | with open(input_path) as input_file: 43 | input_data = json.load(input_file)[:self.max_instances] 44 | 45 | examples = [] 46 | for entry in input_data: 47 | examples.append(retriever_utils.read_mathqa_entry(entry, tokenizer)) 48 | 49 | if self.mode == "train": 50 | kwargs = {"examples": examples, 51 | "tokenizer": tokenizer, 52 | "option": "rand", 53 | "is_training": True, 54 | "max_seq_length": 512, 55 | } 56 | else: 57 | kwargs = {"examples": examples, 58 | "tokenizer": tokenizer, 59 | "option": "rand", 60 | "is_training": False, 61 | "max_seq_length": 512, 62 | } 63 | 64 | features = convert_examples_to_features(**kwargs) 65 | data_pos, neg_sent, irrelevant_neg_table, relevant_neg_table = features[0], features[1], features[2], features[3] 66 | 67 | 68 | if self.mode == "train": 69 | random.shuffle(neg_sent) 70 | random.shuffle(irrelevant_neg_table) 71 | random.shuffle(relevant_neg_table) 72 | data = data_pos + relevant_neg_table[:min(len(relevant_neg_table),len(data_pos) * 3)] + irrelevant_neg_table[:min(len(irrelevant_neg_table),len(data_pos) * 2)] + neg_sent[:min(len(neg_sent),len(data_pos))] 73 | else: 74 | data = data_pos + neg_sent + irrelevant_neg_table + relevant_neg_table 75 | print(self.mode, len(data)) 76 | return data 77 | 78 | def __getitem__(self, idx: int): 79 | return self.instances[idx] 80 | 81 | def __len__(self): 82 | return len(self.instances) 83 | 84 | def truncate(self, max_instances): 85 | truncated_instances = self.instances[max_instances:] 86 | self.instances = self.instances[:max_instances] 87 | return truncated_instances 88 | 89 | def extend(self, instances): 90 | self.instances.extend(instances) 91 | 92 | def customized_collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, Any]: 93 | result_dict = {} 94 | for k in examples[0].keys(): 95 | try: 96 | result_dict[k] = right_pad_sequences([torch.tensor(ex[k]) for ex in examples], 97 | batch_first=True, padding_value=0) 98 | except: 99 | result_dict[k] = [ex[k] for ex in examples] 100 | return result_dict 101 | 102 | class MHQADataModule(LightningDataModule): 103 | def __init__(self, 104 | transformer_model_name: str, 105 | batch_size: int = 1, 106 | val_batch_size: int = 1, 107 | test_batch_size: int = 1, 108 | train_file_path: str = None, 109 | num_workers: int = 8, 110 | val_file_path: str = None, 111 | test_file_path: str = None, 112 | train_max_instances: int = sys.maxsize, 113 | val_max_instances: int = sys.maxsize, 114 | test_max_instances: int = sys.maxsize): 115 | super().__init__() 116 | self.transformer_model_name = transformer_model_name 117 | 118 | self.batch_size = batch_size 119 | self.val_batch_size = val_batch_size 120 | self.test_batch_size = test_batch_size 121 | self.num_workers = num_workers 122 | 123 | self.train_file_path = train_file_path 124 | self.val_file_path = val_file_path 125 | self.test_file_path = test_file_path 126 | 127 | self.train_max_instances = train_max_instances 128 | self.val_max_instances = val_max_instances 129 | self.test_max_instances = test_max_instances 130 | 131 | self.train_data = None 132 | self.val_data = None 133 | 134 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 135 | def setup(self, stage: Optional[str] = None): 136 | assert stage in ["fit", "validate", "test"] 137 | 138 | train_data = MHQADataset(transformer_model_name = 
self.transformer_model_name, 139 | file_path=self.train_file_path, 140 | max_instances=self.train_max_instances, 141 | mode="train") 142 | self.train_data = train_data 143 | 144 | val_data = MHQADataset(transformer_model_name = self.transformer_model_name, 145 | file_path=self.val_file_path, 146 | max_instances=self.val_max_instances, 147 | mode="valid") 148 | self.val_data = val_data 149 | 150 | test_data = MHQADataset(transformer_model_name = self.transformer_model_name, 151 | file_path=self.test_file_path, 152 | max_instances=self.test_max_instances, 153 | mode="test") 154 | self.test_data = test_data 155 | 156 | def train_dataloader(self): 157 | if self.train_data is None: 158 | self.setup(stage="fit") 159 | 160 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, collate_fn=customized_collate_fn, num_workers = self.num_workers) 161 | return dtloader 162 | 163 | def val_dataloader(self): 164 | if self.val_data is None: 165 | self.setup(stage="validate") 166 | 167 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size, shuffle=True, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 168 | return dtloader 169 | 170 | def test_dataloader(self): 171 | if self.test_data is None: 172 | self.setup(stage="test") 173 | 174 | dtloader = DataLoader(self.test_data, batch_size=self.test_batch_size, shuffle=True, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 175 | return dtloader 176 | 177 | 178 | class MHQAPredictionDataModule(LightningDataModule): 179 | def __init__(self, 180 | transformer_model_name: str, 181 | batch_size: int = 1, 182 | num_workers: int = 8, 183 | test_file_path: str = None, 184 | test_max_instances: int = sys.maxsize): 185 | super().__init__() 186 | self.transformer_model_name = transformer_model_name 187 | 188 | self.batch_size = batch_size 189 | self.num_workers = num_workers 190 | 191 | self.test_file_path = test_file_path 192 | 193 | self.test_max_instances = test_max_instances 194 | 195 | self.test_data = None 196 | 197 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 198 | def setup(self, stage: Optional[str] = None): 199 | assert stage in ["test"] 200 | 201 | test_data = MHQADataset(transformer_model_name = self.transformer_model_name, 202 | file_path=self.test_file_path, 203 | max_instances=self.test_max_instances, 204 | mode="test") 205 | self.test_data = test_data 206 | 207 | def test_dataloader(self): 208 | if self.test_data is None: 209 | self.setup(stage="test") 210 | 211 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=True, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 212 | return dtloader -------------------------------------------------------------------------------- /lightning_modules/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/lightning_modules/datasets/__init__.py -------------------------------------------------------------------------------- /lightning_modules/datasets/program_generation_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | import os 5 | import torch 6 | 7 | from typing import Dict, Iterable, List, Any, Optional, Union 8 | 9 | from pytorch_lightning import LightningDataModule 10 | from 
torch.utils.data import Dataset 11 | 12 | from transformers import AutoTokenizer 13 | 14 | from utils.program_generation_utils import * 15 | from utils.utils import * 16 | from utils.datasets_util import right_pad_sequences 17 | from torch.utils.data import DataLoader 18 | 19 | os.environ['TOKENIZERS_PARALLELISM']='0' 20 | op_list, const_list = get_op_const_list() 21 | reserved_token_size = len(op_list) + len(const_list) 22 | 23 | class ProgramGenerationDataset(Dataset): 24 | def __init__( 25 | self, 26 | model_name: str, 27 | file_path: str, 28 | max_seq_length: int, 29 | max_program_length: int, 30 | max_instances: int, 31 | mode: str = "train", 32 | entity_name: str = "question_type", 33 | **kwargs): 34 | super().__init__(**kwargs) 35 | 36 | assert mode in ["train", "test", "valid"] 37 | 38 | self.max_seq_length = max_seq_length 39 | self.max_program_length = max_program_length 40 | 41 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 42 | 43 | self.max_instances = max_instances 44 | self.mode = mode 45 | self.entity_name = entity_name 46 | self.instances = self.read(file_path, self.tokenizer, self.entity_name) 47 | 48 | print(f"read {len(self.instances)} {self.mode} examples") 49 | 50 | def read(self, input_path: str, tokenizer, entity_name: str) -> Iterable[Dict[str, Any]]: 51 | with open(input_path) as input_file: 52 | if self.max_instances > 0: 53 | input_data = json.load(input_file)[:self.max_instances] 54 | else: 55 | input_data = json.load(input_file) 56 | 57 | examples = [] 58 | for entry in input_data: 59 | example = read_mathqa_entry(entry, tokenizer, entity_name) 60 | if example: 61 | examples.append(example) 62 | 63 | 64 | kwargs = { 65 | "examples": examples, 66 | "tokenizer": tokenizer, 67 | "max_seq_length": self.max_seq_length, 68 | "max_program_length": self.max_program_length, 69 | "is_training": True, 70 | "op_list": op_list, 71 | "op_list_size": len(op_list), 72 | "const_list": const_list, 73 | "const_list_size": len(const_list), 74 | "verbose": True 75 | } 76 | 77 | if self.mode != "train": 78 | kwargs["is_training"] = False 79 | 80 | data = convert_examples_to_features(**kwargs) 81 | return data 82 | 83 | def __getitem__(self, idx: int): 84 | return self.instances[idx] 85 | 86 | def __len__(self): 87 | return len(self.instances) 88 | 89 | def truncate(self, max_instances): 90 | truncated_instances = self.instances[max_instances:] 91 | self.instances = self.instances[:max_instances] 92 | return truncated_instances 93 | 94 | def extend(self, instances): 95 | self.instances.extend(instances) 96 | 97 | def customized_collate_fn(examples: List) -> Dict[str, Any]: 98 | result_dict = {} 99 | for k in examples[0].keys(): 100 | try: 101 | result_dict[k] = right_pad_sequences([torch.tensor(ex[k]) for ex in examples], 102 | batch_first=True, padding_value=0) 103 | except: 104 | result_dict[k] = [ex[k] for ex in examples] 105 | return result_dict 106 | 107 | class ProgramGenerationDataModule(LightningDataModule): 108 | def __init__(self, 109 | model_name: str, 110 | max_seq_length: int, 111 | max_program_length: int, 112 | batch_size: int = 1, 113 | val_batch_size: int = 1, 114 | train_file_path: str = None, 115 | val_file_path: str = None, 116 | train_max_instances: int = sys.maxsize, 117 | val_max_instances: int = sys.maxsize, 118 | entity_name: str = "question_type"): 119 | super().__init__() 120 | self.model_name = model_name 121 | self.max_seq_length = max_seq_length 122 | self.max_program_length = max_program_length 123 | 124 | self.batch_size = batch_size 125 
| self.val_batch_size = val_batch_size 126 | 127 | self.train_file_path = train_file_path 128 | self.val_file_path = val_file_path 129 | 130 | self.train_max_instances = train_max_instances 131 | self.val_max_instances = val_max_instances 132 | 133 | self.entity_name = entity_name 134 | 135 | self.train_data = None 136 | self.val_data = None 137 | 138 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 139 | def setup(self, stage: Optional[str] = None): 140 | assert stage in ["fit", "validate", "test"] 141 | 142 | train_data = ProgramGenerationDataset(model_name=self.model_name, 143 | file_path=self.train_file_path, 144 | max_seq_length = self.max_seq_length, 145 | max_program_length = self.max_program_length, 146 | max_instances = self.train_max_instances, 147 | mode = "train", 148 | entity_name = self.entity_name) 149 | 150 | self.train_data = train_data 151 | 152 | val_data = ProgramGenerationDataset(model_name=self.model_name, 153 | file_path=self.val_file_path, 154 | max_seq_length = self.max_seq_length, 155 | max_program_length = self.max_program_length, 156 | max_instances=self.val_max_instances, 157 | mode="valid", 158 | entity_name = self.entity_name) 159 | self.val_data = val_data 160 | 161 | def train_dataloader(self): 162 | if self.train_data is None: 163 | self.setup(stage="fit") 164 | 165 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, collate_fn=customized_collate_fn) 166 | return dtloader 167 | 168 | def val_dataloader(self): 169 | if self.val_data is None: 170 | self.setup(stage="validate") 171 | 172 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn) 173 | return dtloader 174 | 175 | 176 | class ProgramGenerationPredictionDataModule(LightningDataModule): 177 | def __init__(self, 178 | model_name: str, 179 | max_seq_length: int, 180 | max_program_length: int, 181 | batch_size: int = 1, 182 | test_file_path: str = None, 183 | test_max_instances: int = sys.maxsize, 184 | entity_name: str = "question_type"): 185 | super().__init__() 186 | self.model_name = model_name 187 | self.max_seq_length = max_seq_length 188 | self.max_program_length = max_program_length 189 | 190 | self.batch_size = batch_size 191 | 192 | self.test_file_path = test_file_path 193 | self.test_max_instances = test_max_instances 194 | 195 | self.test_data = None 196 | 197 | self.entity_name = entity_name 198 | 199 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 200 | def setup(self, stage: Optional[str] = None): 201 | assert stage in ["test", "predict"] 202 | 203 | test_data = ProgramGenerationDataset(model_name=self.model_name, 204 | file_path=self.test_file_path, 205 | max_seq_length = self.max_seq_length, 206 | max_program_length = self.max_program_length, 207 | max_instances=self.test_max_instances, 208 | mode="test", 209 | entity_name = self.entity_name) 210 | self.test_data = test_data 211 | 212 | def test_dataloader(self): 213 | if self.test_data is None: 214 | self.setup(stage="test") 215 | 216 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn) 217 | return dtloader 218 | 219 | def predict_dataloader(self): 220 | if self.test_data is None: 221 | self.setup(stage="predict") 222 | 223 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn) 224 | return dtloader 
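# Illustrative usage sketch: trainer.py normally builds this prediction data module from
# inference_configs/program_generation_inference.yaml. The values below mirror that config and
# assume dataset/reasoning_module_input/test_inference.json has already been produced by
# convert_retriever_result.py.
if __name__ == "__main__":
    dm = ProgramGenerationPredictionDataModule(
        model_name="roberta-base",
        max_seq_length=512,
        max_program_length=30,
        batch_size=64,
        test_file_path="dataset/reasoning_module_input/test_inference.json",
        entity_name="predicted_question_type",
    )
    dm.setup(stage="predict")
    # Each batch is a dict of padded tensors / lists assembled by customized_collate_fn.
    first_batch = next(iter(dm.predict_dataloader()))
    print({k: type(v) for k, v in first_batch.items()})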
-------------------------------------------------------------------------------- /lightning_modules/datasets/question_classification_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | import os 5 | import torch 6 | 7 | from typing import Dict, Iterable, List, Any, Optional, Union 8 | 9 | from pytorch_lightning import LightningDataModule 10 | from torch.utils.data import Dataset 11 | 12 | from utils.datasets_util import right_pad_sequences 13 | 14 | from transformers import AutoTokenizer 15 | from utils.utils import * 16 | from torch.utils.data import DataLoader 17 | 18 | os.environ['TOKENIZERS_PARALLELISM']='0' 19 | 20 | class QuestionClassificationDataset(Dataset): 21 | def __init__( 22 | self, 23 | model_name: str, 24 | file_path: str, 25 | max_instances: int, 26 | mode: str = "train", 27 | **kwargs): 28 | super().__init__(**kwargs) 29 | 30 | assert mode in ["train", "test", "valid", "predict"] 31 | 32 | self.tokenizer = AutoTokenizer.from_pretrained(model_name) 33 | 34 | self.max_instances = max_instances 35 | self.mode = mode 36 | self.instances = self.read(file_path, self.tokenizer, mode) 37 | 38 | print(f"read {len(self.instances)} {self.mode} examples") 39 | 40 | def read(self, input_path: str, tokenizer, mode) -> Iterable[Dict[str, Any]]: 41 | with open(input_path) as input_file: 42 | if self.max_instances > 0: 43 | input_data = json.load(input_file)[:self.max_instances] 44 | else: 45 | input_data = json.load(input_file) 46 | 47 | data = [] 48 | for entry in input_data: 49 | feature = {} 50 | input_text_encoded = tokenizer.encode_plus(entry["qa"]["question"], 51 | max_length=128, 52 | pad_to_max_length=True) 53 | input_ids = input_text_encoded["input_ids"] 54 | input_mask = input_text_encoded["attention_mask"] 55 | 56 | feature = { 57 | "uid": entry["uid"], 58 | "question": entry["qa"]["question"], 59 | "input_ids": input_ids, 60 | "input_mask": input_mask, 61 | } 62 | 63 | if mode != "predict": 64 | feature["labels"] = 1 if entry["qa"]["question_type"] == "arithmetic" else 0 65 | 66 | data.append(feature) 67 | 68 | return data 69 | 70 | def __getitem__(self, idx: int): 71 | return self.instances[idx] 72 | 73 | def __len__(self): 74 | return len(self.instances) 75 | 76 | def truncate(self, max_instances): 77 | truncated_instances = self.instances[max_instances:] 78 | self.instances = self.instances[:max_instances] 79 | return truncated_instances 80 | 81 | def extend(self, instances): 82 | self.instances.extend(instances) 83 | 84 | def customized_collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, Any]: 85 | result_dict = {} 86 | for k in examples[0].keys(): 87 | try: 88 | if k == "labels": 89 | result_dict[k] = torch.tensor([example[k] for example in examples]) 90 | else: 91 | result_dict[k] = right_pad_sequences([torch.tensor(ex[k]) for ex in examples], 92 | batch_first=True, padding_value=0) 93 | except: 94 | result_dict[k] = [ex[k] for ex in examples] 95 | return result_dict 96 | 97 | class QuestionClassificationDataModule(LightningDataModule): 98 | def __init__(self, 99 | model_name: str, 100 | batch_size: int = 1, 101 | val_batch_size: int = 1, 102 | train_file_path: str = None, 103 | val_file_path: str = None, 104 | num_workers: int = 8, 105 | train_max_instances: int = sys.maxsize, 106 | val_max_instances: int = sys.maxsize): 107 | super().__init__() 108 | self.transformer_model_name = model_name 109 | 110 | self.batch_size = batch_size 111 | self.val_batch_size = 
val_batch_size 112 | self.num_workers = num_workers 113 | 114 | self.train_file_path = train_file_path 115 | self.val_file_path = val_file_path 116 | 117 | self.train_max_instances = train_max_instances 118 | self.val_max_instances = val_max_instances 119 | 120 | self.train_data = None 121 | self.val_data = None 122 | 123 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 124 | def setup(self, stage: Optional[str] = None): 125 | assert stage in ["fit", "validate", "test"] 126 | 127 | train_data = QuestionClassificationDataset(model_name = self.transformer_model_name, 128 | file_path=self.train_file_path, 129 | max_instances=self.train_max_instances, 130 | mode="train") 131 | self.train_data = train_data 132 | 133 | val_data = QuestionClassificationDataset(model_name = self.transformer_model_name, 134 | file_path=self.val_file_path, 135 | max_instances=self.val_max_instances, 136 | mode="valid") 137 | self.val_data = val_data 138 | 139 | def train_dataloader(self): 140 | if self.train_data is None: 141 | self.setup(stage="fit") 142 | 143 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, collate_fn=customized_collate_fn, num_workers = self.num_workers) 144 | return dtloader 145 | 146 | def val_dataloader(self): 147 | if self.val_data is None: 148 | self.setup(stage="validate") 149 | 150 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size, shuffle=True, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 151 | return dtloader 152 | 153 | 154 | class QuestionClassificationPredictionDataModule(LightningDataModule): 155 | def __init__(self, 156 | model_name: str, 157 | batch_size: int = 1, 158 | num_workers: int = 8, 159 | test_file_path: str = None, 160 | test_max_instances: int = sys.maxsize): 161 | super().__init__() 162 | self.transformer_model_name = model_name 163 | 164 | self.batch_size = batch_size 165 | self.num_workers = num_workers 166 | 167 | self.test_file_path = test_file_path 168 | 169 | self.test_max_instances = test_max_instances 170 | 171 | self.test_data = None 172 | 173 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 174 | def setup(self, stage: Optional[str] = None): 175 | assert stage in ["test", "predict"] 176 | 177 | test_data = QuestionClassificationDataset(model_name = self.transformer_model_name, 178 | file_path=self.test_file_path, 179 | max_instances=self.test_max_instances, 180 | mode=stage) 181 | self.test_data = test_data 182 | 183 | def test_dataloader(self): 184 | if self.test_data is None: 185 | self.setup(stage="test") 186 | 187 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 188 | return dtloader 189 | 190 | def predict_dataloader(self): 191 | if self.test_data is None: 192 | self.setup(stage="predict") 193 | 194 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 195 | return dtloader -------------------------------------------------------------------------------- /lightning_modules/datasets/retriever_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | import os 5 | import torch 6 | 7 | from typing import Dict, Iterable, List, Any, Optional, Union 8 | 9 | from pytorch_lightning import LightningDataModule 10 
| from torch.utils.data import Dataset 11 | 12 | from utils.datasets_util import right_pad_sequences 13 | 14 | from transformers import AutoTokenizer 15 | import utils.retriever_utils as retriever_utils 16 | from utils.retriever_utils import * 17 | from utils.utils import * 18 | from torch.utils.data import DataLoader 19 | 20 | os.environ['TOKENIZERS_PARALLELISM']='0' 21 | 22 | class RetrieverDataset(Dataset): 23 | def __init__( 24 | self, 25 | transformer_model_name: str, 26 | file_path: str, 27 | max_instances: int, 28 | mode: str = "train", 29 | **kwargs): 30 | super().__init__(**kwargs) 31 | 32 | assert mode in ["train", "test", "valid"] 33 | 34 | self.tokenizer = AutoTokenizer.from_pretrained(transformer_model_name) 35 | 36 | self.max_instances = max_instances 37 | self.mode = mode 38 | self.instances = self.read(file_path, self.tokenizer) 39 | 40 | 41 | def read(self, input_path: str, tokenizer) -> Iterable[Dict[str, Any]]: 42 | with open(input_path) as input_file: 43 | if self.max_instances > 0: 44 | input_data = json.load(input_file)[:self.max_instances] 45 | else: 46 | input_data = json.load(input_file) 47 | 48 | examples = [] 49 | for entry in input_data: 50 | examples.append(retriever_utils.read_mathqa_entry(entry, tokenizer)) 51 | 52 | if self.mode == "train": 53 | kwargs = {"examples": examples, 54 | "tokenizer": tokenizer, 55 | "option": "rand", 56 | "is_training": True, 57 | "max_seq_length": 512, 58 | } 59 | else: 60 | kwargs = {"examples": examples, 61 | "tokenizer": tokenizer, 62 | "option": "rand", 63 | "is_training": False, 64 | "max_seq_length": 512, 65 | } 66 | 67 | features = convert_examples_to_features(**kwargs) 68 | data_pos, neg_sent, irrelevant_neg_table, relevant_neg_table = features[0], features[1], features[2], features[3] 69 | 70 | 71 | if self.mode == "train": 72 | random.shuffle(neg_sent) 73 | random.shuffle(irrelevant_neg_table) 74 | random.shuffle(relevant_neg_table) 75 | data = data_pos + relevant_neg_table[:min(len(relevant_neg_table),len(data_pos) * 3)] + irrelevant_neg_table[:min(len(irrelevant_neg_table),len(data_pos) * 2)] + neg_sent[:min(len(neg_sent),len(data_pos))] 76 | else: 77 | data = data_pos + neg_sent + irrelevant_neg_table + relevant_neg_table 78 | print(self.mode, len(data)) 79 | return data 80 | 81 | def __getitem__(self, idx: int): 82 | return self.instances[idx] 83 | 84 | def __len__(self): 85 | return len(self.instances) 86 | 87 | def truncate(self, max_instances): 88 | truncated_instances = self.instances[max_instances:] 89 | self.instances = self.instances[:max_instances] 90 | return truncated_instances 91 | 92 | def extend(self, instances): 93 | self.instances.extend(instances) 94 | 95 | def customized_collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, Any]: 96 | result_dict = {} 97 | for k in examples[0].keys(): 98 | try: 99 | result_dict[k] = right_pad_sequences([torch.tensor(ex[k]) for ex in examples], 100 | batch_first=True, padding_value=0) 101 | except: 102 | result_dict[k] = [ex[k] for ex in examples] 103 | return result_dict 104 | 105 | class RetrieverDataModule(LightningDataModule): 106 | def __init__(self, 107 | transformer_model_name: str, 108 | batch_size: int = 1, 109 | val_batch_size: int = 1, 110 | train_file_path: str = None, 111 | num_workers: int = 8, 112 | val_file_path: str = None, 113 | train_max_instances: int = sys.maxsize, 114 | val_max_instances: int = sys.maxsize): 115 | super().__init__() 116 | self.transformer_model_name = transformer_model_name 117 | 118 | self.batch_size = batch_size 119 | 
self.val_batch_size = val_batch_size 120 | 121 | self.num_workers = num_workers 122 | self.train_file_path = train_file_path 123 | self.val_file_path = val_file_path 124 | 125 | self.train_max_instances = train_max_instances 126 | self.val_max_instances = val_max_instances 127 | 128 | self.train_data = None 129 | self.val_data = None 130 | 131 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 132 | def setup(self, stage: Optional[str] = None): 133 | assert stage in ["fit", "validate", "test"] 134 | 135 | train_data = RetrieverDataset(transformer_model_name = self.transformer_model_name, 136 | file_path=self.train_file_path, 137 | max_instances=self.train_max_instances, 138 | mode="train") 139 | self.train_data = train_data 140 | 141 | val_data = RetrieverDataset(transformer_model_name = self.transformer_model_name, 142 | file_path=self.val_file_path, 143 | max_instances=self.val_max_instances, 144 | mode="valid") 145 | self.val_data = val_data 146 | 147 | 148 | def train_dataloader(self): 149 | if self.train_data is None: 150 | self.setup(stage="fit") 151 | 152 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, collate_fn=customized_collate_fn, num_workers = self.num_workers) 153 | return dtloader 154 | 155 | def val_dataloader(self): 156 | if self.val_data is None: 157 | self.setup(stage="validate") 158 | 159 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size, shuffle=True, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 160 | return dtloader 161 | 162 | 163 | class RetrieverPredictionDataModule(LightningDataModule): 164 | def __init__(self, 165 | transformer_model_name: str, 166 | batch_size: int = 1, 167 | num_workers: int = 8, 168 | test_file_path: str = None, 169 | test_max_instances: int = sys.maxsize): 170 | super().__init__() 171 | self.transformer_model_name = transformer_model_name 172 | 173 | self.batch_size = batch_size 174 | self.num_workers = num_workers 175 | 176 | self.test_file_path = test_file_path 177 | 178 | self.test_max_instances = test_max_instances 179 | 180 | self.test_data = None 181 | 182 | # OPTIONAL, called for every GPU/machine (assigning state is OK) 183 | def setup(self, stage: Optional[str] = None): 184 | assert stage in ["test", "predict"] 185 | 186 | test_data = RetrieverDataset(transformer_model_name = self.transformer_model_name, 187 | file_path=self.test_file_path, 188 | max_instances=self.test_max_instances, 189 | mode="test") 190 | self.test_data = test_data 191 | 192 | def test_dataloader(self): 193 | if self.test_data is None: 194 | self.setup(stage="test") 195 | 196 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 197 | return dtloader 198 | 199 | def predict_dataloader(self): 200 | if self.test_data is None: 201 | self.setup(stage="predict") 202 | 203 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn, num_workers = self.num_workers) 204 | return dtloader -------------------------------------------------------------------------------- /lightning_modules/datasets/span_selection_reader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import sys 4 | import os 5 | import torch 6 | 7 | from typing import Dict, Iterable, List, Any, Optional, Union 8 | 9 | from 
pytorch_lightning import LightningDataModule 10 | from torch.utils.data import Dataset 11 | 12 | from utils.datasets_util import right_pad_sequences 13 | 14 | from transformers import AutoTokenizer 15 | from utils.span_selection_utils import * 16 | from utils.utils import * 17 | from torch.utils.data import DataLoader 18 | # from torch.utils.data import DataLoader 19 | # from torch.utils.data import DataLoader 20 | # set environment variable to avoid deadlocks, see: 21 | # https://docs.allennlp.org/main/api/data/data_loaders/multiprocess_data_loader/#multiprocessdataloader.common_issues 22 | os.environ['TOKENIZERS_PARALLELISM']='0' 23 | 24 | class SpanSelectionDataset(Dataset): 25 | def __init__( 26 | self, 27 | model_name: str, 28 | file_path: str, 29 | max_instances: int, 30 | mode: str = "train", 31 | entity_name: str = "question_type", 32 | **kwargs): 33 | super().__init__(**kwargs) 34 | 35 | assert mode in ["train", "test", "valid"] 36 | 37 | self.model_name = model_name 38 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 39 | self.max_instances = max_instances 40 | self.mode = mode 41 | self.entity_name = entity_name 42 | 43 | self.instances = self.read(file_path, self.tokenizer, self.entity_name) 44 | 45 | print(f"read {len(self.instances)} {self.mode} examples") 46 | 47 | def read(self, input_path: str, tokenizer, entity_name) -> Iterable[Dict[str, Any]]: 48 | with open(input_path) as input_file: 49 | if self.max_instances > 0: 50 | input_data = json.load(input_file)[:self.max_instances] 51 | else: 52 | input_data = json.load(input_file) 53 | 54 | examples = [] 55 | for entry in input_data: 56 | example = read_mathqa_entry(entry, tokenizer, entity_name) 57 | if example: 58 | examples.append(example) 59 | 60 | 61 | kwargs = { 62 | "examples": examples, 63 | "tokenizer": tokenizer, 64 | "max_seq_length": 512, 65 | } 66 | 67 | self.entity_name = entity_name 68 | data = convert_examples_to_features(**kwargs) 69 | return data 70 | 71 | def __getitem__(self, idx: int): 72 | return self.instances[idx] 73 | 74 | def __len__(self): 75 | return len(self.instances) 76 | 77 | def truncate(self, max_instances): 78 | truncated_instances = self.instances[max_instances:] 79 | self.instances = self.instances[:max_instances] 80 | return truncated_instances 81 | 82 | def extend(self, instances): 83 | self.instances.extend(instances) 84 | 85 | def customized_collate_fn(examples: List[Dict[str, Any]]) -> Dict[str, Any]: 86 | result_dict = {} 87 | for k in examples[0].keys(): 88 | try: 89 | result_dict[k] = right_pad_sequences([torch.tensor(ex[k]) for ex in examples], 90 | batch_first=True, padding_value=0) 91 | except: 92 | result_dict[k] = [ex[k] for ex in examples] 93 | return result_dict 94 | 95 | class SpanSelectionDataModule(LightningDataModule): 96 | def __init__(self, 97 | model_name: str, 98 | batch_size: int = 1, 99 | val_batch_size: int = 1, 100 | train_file_path: str = None, 101 | val_file_path: str = None, 102 | train_max_instances: int = sys.maxsize, 103 | val_max_instances: int = sys.maxsize, 104 | entity_name: str = "question_type"): 105 | super().__init__() 106 | self.model_name = model_name 107 | self.batch_size = batch_size 108 | self.val_batch_size = val_batch_size 109 | 110 | self.train_file_path = train_file_path 111 | self.val_file_path = val_file_path 112 | 113 | self.train_max_instances = train_max_instances 114 | self.val_max_instances = val_max_instances 115 | 116 | self.entity_name = entity_name 117 | 118 | self.train_data = None 119 | self.val_data = None 120 | 121 | #
OPTIONAL, called for every GPU/machine (assigning state is OK) 122 | def setup(self, stage: Optional[str] = None): 123 | assert stage in ["fit", "validate"] 124 | 125 | train_data = SpanSelectionDataset(model_name = self.model_name, 126 | file_path=self.train_file_path, 127 | max_instances=self.train_max_instances, 128 | mode="train", 129 | entity_name = self.entity_name) 130 | self.train_data = train_data 131 | 132 | val_data = SpanSelectionDataset(model_name = self.model_name, 133 | file_path=self.val_file_path, 134 | max_instances=self.val_max_instances, 135 | mode="valid", 136 | entity_name = self.entity_name) 137 | self.val_data = val_data 138 | 139 | def train_dataloader(self): 140 | if self.train_data is None: 141 | self.setup(stage="fit") 142 | 143 | dtloader = DataLoader(self.train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, collate_fn=customized_collate_fn) 144 | return dtloader 145 | 146 | def val_dataloader(self): 147 | if self.val_data is None: 148 | self.setup(stage="validate") 149 | 150 | dtloader = DataLoader(self.val_data, batch_size=self.val_batch_size, shuffle=True, drop_last=False, collate_fn=customized_collate_fn) 151 | return dtloader 152 | 153 | class SpanSelectionInferenceDataModule(LightningDataModule): 154 | def __init__(self, 155 | model_name: str, 156 | batch_size: int = 1, 157 | test_file_path: str = None, 158 | test_max_instances: int = sys.maxsize, 159 | entity_name: str = "question_type"): 160 | super().__init__() 161 | self.model_name = model_name 162 | self.batch_size = batch_size 163 | self.test_file_path = test_file_path 164 | self.test_max_instances = test_max_instances 165 | self.entity_name = entity_name 166 | self.test_data = None 167 | 168 | def setup(self, stage: Optional[str] = None): 169 | assert stage in ["predict", "test"] 170 | 171 | test_data = SpanSelectionDataset(model_name = self.model_name, 172 | file_path=self.test_file_path, 173 | max_instances=self.test_max_instances, 174 | mode="test", 175 | entity_name = self.entity_name) 176 | self.test_data = test_data 177 | 178 | def predict_dataloader(self): 179 | if self.test_data is None: 180 | self.setup(stage="predict") 181 | 182 | dtloader = DataLoader(self.test_data, batch_size=self.batch_size, shuffle=False, drop_last=False, collate_fn=customized_collate_fn) 183 | return dtloader -------------------------------------------------------------------------------- /lightning_modules/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/lightning_modules/models/__init__.py -------------------------------------------------------------------------------- /lightning_modules/models/program_generation_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import LightningModule 9 | from utils.program_generation_utils import * 10 | from utils.utils import * 11 | from transformers import AutoConfig, AutoTokenizer, AutoModel 12 | from typing import Optional, Dict, Any, Tuple, List 13 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup 14 | from transformers.optimization import get_cosine_schedule_with_warmup 15 | 16 | op_list, const_list 
= get_op_const_list() 17 | reserved_token_size = len(op_list) + len(const_list) 18 | 19 | class ProgramGenerationModel(LightningModule): 20 | 21 | def __init__(self, 22 | model_name: str, 23 | program_length: int, 24 | input_length: int, 25 | max_step_ind: int, 26 | dropout_rate: float, 27 | num_decoder_layers: int, 28 | n_best_size: int, 29 | sep_attention: bool, 30 | layer_norm: bool, 31 | warmup_steps: int = 0, 32 | optimizer: Dict[str, Any] = None, 33 | lr_scheduler: Dict[str, Any] = None, 34 | test_set: str = "dev_training.json", 35 | entity_name: str = "question_type", 36 | input_dir: str = "dataset/program_generator_input", 37 | load_ckpt_file: str = None, 38 | ) -> None: 39 | 40 | super().__init__() 41 | 42 | self.model_name = model_name 43 | self.const_list = const_list 44 | self.op_list = op_list 45 | self.op_list_size = len(op_list) 46 | self.const_list_size = len(const_list) 47 | self.reserved_token_size = self.op_list_size + self.const_list_size 48 | self.max_step_ind = max_step_ind 49 | 50 | self.program_length = program_length 51 | self.input_length = input_length 52 | self.num_decoder_layers = num_decoder_layers 53 | self.n_best_size = n_best_size 54 | 55 | self.test_set = test_set 56 | self.entity_name = entity_name 57 | self.input_dir = input_dir 58 | 59 | self.sep_attention = sep_attention 60 | self.layer_norm = layer_norm 61 | 62 | self.reserved_ind = nn.Parameter(torch.arange( 63 | 0, self.reserved_token_size), requires_grad=False) 64 | self.reserved_go = nn.Parameter(torch.arange(op_list.index( 65 | 'GO'), op_list.index('GO') + 1), requires_grad=False) 66 | 67 | self.reserved_para = nn.Parameter(torch.arange(op_list.index( 68 | ')'), op_list.index(')') + 1), requires_grad=False) 69 | 70 | # masking for decoding at test time 71 | op_ones = nn.Parameter(torch.ones( 72 | self.op_list_size), requires_grad=False) 73 | op_zeros = nn.Parameter(torch.zeros( 74 | self.op_list_size), requires_grad=False) 75 | other_ones = nn.Parameter(torch.ones( 76 | input_length + self.const_list_size), requires_grad=False) 77 | other_zeros = nn.Parameter(torch.zeros( 78 | input_length + self.const_list_size), requires_grad=False) 79 | self.op_only_mask = nn.Parameter( 80 | torch.cat((op_ones, other_zeros), 0), requires_grad=False) 81 | self.seq_only_mask = nn.Parameter( 82 | torch.cat((op_zeros, other_ones), 0), requires_grad=False) 83 | 84 | # for ")" 85 | para_before_ones = nn.Parameter(torch.ones( 86 | op_list.index(')')), requires_grad=False) 87 | para_after_ones = nn.Parameter(torch.ones( 88 | input_length + self.reserved_token_size - op_list.index(')') - 1), requires_grad=False) 89 | para_zero = nn.Parameter(torch.zeros(1), requires_grad=False) 90 | self.para_mask = nn.Parameter(torch.cat( 91 | (para_before_ones, para_zero, para_after_ones), 0), requires_grad=False) 92 | 93 | # for step embedding 94 | # self.step_masks = [] 95 | all_tmp_list = self.op_list + self.const_list 96 | self.step_masks = nn.Parameter(torch.zeros( 97 | self.max_step_ind, input_length + self.reserved_token_size), requires_grad=False) 98 | for i in range(self.max_step_ind): 99 | this_step_mask_ind = all_tmp_list.index("#" + str(i)) 100 | self.step_masks[i, this_step_mask_ind] = 1.0 101 | 102 | 103 | self.model = AutoModel.from_pretrained(self.model_name) 104 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 105 | self.model_config = AutoConfig.from_pretrained(self.model_name) 106 | 107 | self.hidden_size = self.model_config.hidden_size 108 | 109 | self.cls_prj = nn.Linear(self.hidden_size,
self.hidden_size, bias=True) 110 | self.cls_dropout = nn.Dropout(dropout_rate) 111 | 112 | self.seq_prj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 113 | self.seq_dropout = nn.Dropout(dropout_rate) 114 | 115 | self.reserved_token_embedding = nn.Embedding( 116 | self.reserved_token_size, self.hidden_size) 117 | 118 | # attentions 119 | self.decoder_history_attn_prj = nn.Linear( 120 | self.hidden_size, self.hidden_size, bias=True) 121 | self.decoder_history_attn_dropout = nn.Dropout(dropout_rate) 122 | 123 | self.question_attn_prj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) 124 | self.question_attn_dropout = nn.Dropout(dropout_rate) 125 | 126 | self.question_summary_attn_prj = nn.Linear( 127 | self.hidden_size, self.hidden_size, bias=True) 128 | self.question_summary_attn_dropout = nn.Dropout(dropout_rate) 129 | 130 | if self.sep_attention: 131 | self.input_embeddings_prj = nn.Linear( 132 | self.hidden_size*3, self.hidden_size, bias=True) 133 | else: 134 | self.input_embeddings_prj = nn.Linear( 135 | self.hidden_size*2, self.hidden_size, bias=True) 136 | self.input_embeddings_layernorm = nn.LayerNorm([1, self.hidden_size]) 137 | 138 | self.option_embeddings_prj = nn.Linear( 139 | self.hidden_size*2, self.hidden_size, bias=True) 140 | 141 | # decoder lstm 142 | self.rnn = torch.nn.LSTM(input_size=self.hidden_size, hidden_size=self.hidden_size, 143 | num_layers=self.num_decoder_layers, batch_first=True) 144 | 145 | # step vector 146 | self.decoder_step_proj = nn.Linear( 147 | 3*self.hidden_size, self.hidden_size, bias=True) 148 | self.decoder_step_proj_dropout = nn.Dropout(dropout_rate) 149 | 150 | self.step_mix_proj = nn.Linear( 151 | self.hidden_size*2, self.hidden_size, bias=True) 152 | 153 | 154 | self.warmup_steps = warmup_steps 155 | self.criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=-1) 156 | 157 | self.predictions: List[Dict[str, Any]] = [] 158 | 159 | self.opt_params = optimizer["init_args"] 160 | self.lrs_params = lr_scheduler 161 | 162 | 163 | def forward(self, is_training, input_ids, input_mask, segment_ids, option_mask, program_ids, program_mask, metadata) -> List[Dict[str, Any]]: 164 | 165 | input_ids = torch.tensor(input_ids).to("cuda") 166 | input_mask = torch.tensor(input_mask).to("cuda") 167 | segment_ids = torch.tensor(segment_ids).to("cuda") 168 | option_mask = torch.tensor(option_mask).to("cuda") 169 | program_ids = torch.tensor(program_ids).to("cuda") 170 | program_mask = torch.tensor(program_mask).to("cuda") 171 | 172 | bert_outputs = self.model( 173 | input_ids=input_ids, attention_mask=input_mask, token_type_ids=segment_ids) 174 | 175 | bert_sequence_output = bert_outputs.last_hidden_state 176 | bert_pooled_output = bert_sequence_output[:, 0, :] 177 | batch_size, seq_length, bert_dim = list(bert_sequence_output.size()) 178 | split_program_ids = torch.split(program_ids, 1, dim=1) 179 | 180 | pooled_output = self.cls_prj(bert_pooled_output) 181 | pooled_output = self.cls_dropout(pooled_output) 182 | 183 | sequence_output = self.seq_prj(bert_sequence_output) 184 | sequence_output = self.seq_dropout(sequence_output) 185 | 186 | op_embeddings = self.reserved_token_embedding(self.reserved_ind) 187 | op_embeddings = op_embeddings.repeat(batch_size, 1, 1) 188 | 189 | init_decoder_output = self.reserved_token_embedding(self.reserved_go) 190 | decoder_output = init_decoder_output.repeat(batch_size, 1, 1) 191 | 192 | logits = [] 193 | 194 | # [batch, op + seq len, hidden] 195 | initial_option_embeddings = torch.cat( 196 | 
[op_embeddings, sequence_output], dim=1) 197 | 198 | if self.sep_attention: 199 | decoder_history = decoder_output 200 | else: 201 | decoder_history = torch.unsqueeze(pooled_output, dim=-1) 202 | 203 | decoder_state_h = torch.zeros(1, batch_size, self.hidden_size, device = "cuda") 204 | decoder_state_c = torch.zeros(1, batch_size, self.hidden_size, device = "cuda") 205 | 206 | float_input_mask = input_mask.float() 207 | float_input_mask = torch.unsqueeze(float_input_mask, dim=-1) 208 | 209 | this_step_new_op_emb = initial_option_embeddings 210 | 211 | for cur_step in range(self.program_length): 212 | 213 | # decoder history att 214 | decoder_history_attn_vec = self.decoder_history_attn_prj( 215 | decoder_output) 216 | decoder_history_attn_vec = self.decoder_history_attn_dropout( 217 | decoder_history_attn_vec) 218 | 219 | decoder_history_attn_w = torch.matmul( 220 | decoder_history, torch.transpose(decoder_history_attn_vec, 1, 2)) 221 | decoder_history_attn_w = F.softmax(decoder_history_attn_w, dim=1) 222 | 223 | decoder_history_ctx_embeddings = torch.matmul( 224 | torch.transpose(decoder_history_attn_w, 1, 2), decoder_history) 225 | 226 | if self.sep_attention: 227 | # input seq att 228 | question_attn_vec = self.question_attn_prj(decoder_output) 229 | question_attn_vec = self.question_attn_dropout( 230 | question_attn_vec) 231 | 232 | question_attn_w = torch.matmul( 233 | sequence_output, torch.transpose(question_attn_vec, 1, 2)) 234 | question_attn_w -= 1e6 * (1 - float_input_mask) 235 | question_attn_w = F.softmax(question_attn_w, dim=1) 236 | 237 | question_ctx_embeddings = torch.matmul( 238 | torch.transpose(question_attn_w, 1, 2), sequence_output) 239 | 240 | # another input seq att 241 | question_summary_vec = self.question_summary_attn_prj( 242 | decoder_output) 243 | question_summary_vec = self.question_summary_attn_dropout( 244 | question_summary_vec) 245 | 246 | question_summary_w = torch.matmul( 247 | sequence_output, torch.transpose(question_summary_vec, 1, 2)) 248 | question_summary_w -= 1e6 * (1 - float_input_mask) 249 | question_summary_w = F.softmax(question_summary_w, dim=1) 250 | 251 | question_summary_embeddings = torch.matmul( 252 | torch.transpose(question_summary_w, 1, 2), sequence_output) 253 | 254 | if self.sep_attention: 255 | concat_input_embeddings = torch.cat([decoder_history_ctx_embeddings, 256 | question_ctx_embeddings, 257 | decoder_output], dim=-1) 258 | else: 259 | concat_input_embeddings = torch.cat([decoder_history_ctx_embeddings, 260 | decoder_output], dim=-1) 261 | 262 | input_embeddings = self.input_embeddings_prj( 263 | concat_input_embeddings) 264 | 265 | if self.layer_norm: 266 | input_embeddings = self.input_embeddings_layernorm( 267 | input_embeddings) 268 | 269 | question_option_vec = this_step_new_op_emb * question_summary_embeddings 270 | option_embeddings = torch.cat( 271 | [this_step_new_op_emb, question_option_vec], dim=-1) 272 | 273 | option_embeddings = self.option_embeddings_prj(option_embeddings) 274 | option_logits = torch.matmul( 275 | option_embeddings, torch.transpose(input_embeddings, 1, 2)) 276 | option_logits = torch.squeeze( 277 | option_logits, dim=2) # [batch, op + seq_len] 278 | option_logits -= 1e6 * (1 - option_mask) 279 | logits.append(option_logits) 280 | 281 | if is_training: 282 | program_index = torch.unsqueeze( 283 | split_program_ids[cur_step], dim=1) 284 | else: 285 | # constrain decoding 286 | if cur_step % 4 == 0 or (cur_step + 1) % 4 == 0: 287 | # op round 288 | option_logits -= 1e6 * self.seq_only_mask 289 | 
else: 290 | # number round 291 | option_logits -= 1e6 * self.op_only_mask 292 | 293 | if (cur_step + 1) % 4 == 0: 294 | # ")" round 295 | option_logits -= 1e6 * self.para_mask 296 | # print(program_index) 297 | 298 | program_index = torch.argmax( 299 | option_logits, axis=-1, keepdim=True) 300 | 301 | program_index = torch.unsqueeze( 302 | program_index, dim=1 303 | ) 304 | 305 | if (cur_step + 1) % 4 == 0: 306 | # update op embeddings 307 | this_step_index = cur_step // 4 308 | this_step_list_index = ( 309 | self.op_list + self.const_list).index("#" + str(this_step_index)) 310 | this_step_mask = self.step_masks[this_step_index, :] 311 | 312 | decoder_step_vec = self.decoder_step_proj( 313 | concat_input_embeddings) 314 | decoder_step_vec = self.decoder_step_proj_dropout( 315 | decoder_step_vec) 316 | decoder_step_vec = torch.squeeze(decoder_step_vec) 317 | 318 | this_step_new_emb = decoder_step_vec # [batch, hidden] 319 | 320 | this_step_new_emb = torch.unsqueeze(this_step_new_emb, 1) 321 | this_step_new_emb = this_step_new_emb.repeat( 322 | 1, self.reserved_token_size+self.input_length, 1) # [batch, op seq, hidden] 323 | 324 | this_step_mask = torch.unsqueeze( 325 | this_step_mask, 0) # [1, op seq] 326 | # print(this_step_mask) 327 | 328 | this_step_mask = torch.unsqueeze( 329 | this_step_mask, 2) # [1, op seq, 1] 330 | this_step_mask = this_step_mask.repeat( 331 | batch_size, 1, self.hidden_size) # [batch, op seq, hidden] 332 | 333 | this_step_new_op_emb = torch.where( 334 | this_step_mask > 0, this_step_new_emb, initial_option_embeddings) 335 | 336 | # print(program_index.size()) 337 | program_index = torch.repeat_interleave( 338 | program_index, self.hidden_size, dim=2) # [batch, 1, hidden] 339 | 340 | input_program_embeddings = torch.gather( 341 | option_embeddings, dim=1, index=program_index) 342 | 343 | decoder_output, (decoder_state_h, decoder_state_c) = self.rnn( 344 | input_program_embeddings, (decoder_state_h, decoder_state_c)) 345 | decoder_history = torch.cat( 346 | [decoder_history, input_program_embeddings], dim=1) 347 | 348 | 349 | logits = torch.stack(logits, dim=1) 350 | 351 | output_dicts = [] 352 | for i in range(len(metadata)): 353 | output_dicts.append({"logits": logits[i], "unique_id": metadata[i]["unique_id"]}) 354 | return output_dicts 355 | 356 | 357 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 358 | input_ids = batch["input_ids"] 359 | input_mask = batch["input_mask"] 360 | segment_ids = batch["segment_ids"] 361 | program_ids = batch["program_ids"] 362 | program_mask = batch["program_mask"] 363 | option_mask = batch["option_mask"] 364 | is_training = True 365 | 366 | program_ids = torch.tensor(program_ids).to("cuda") 367 | program_mask = torch.tensor(program_mask).to("cuda") 368 | 369 | metadata = [{"unique_id": filename_id} for filename_id in batch["unique_id"]] 370 | 371 | output_dicts = self(is_training, input_ids, input_mask, segment_ids, option_mask, program_ids, program_mask, metadata) 372 | 373 | logits = [] 374 | for output_dict in output_dicts: 375 | logits.append(output_dict["logits"]) 376 | logits = torch.stack(logits) 377 | loss = self.criterion(logits.view(-1, logits.shape[-1]), program_ids.view(-1)) 378 | loss = loss * program_mask.view(-1) 379 | 380 | self.log("loss", loss.sum(), on_step=True, on_epoch=True, prog_bar=True, logger=True) 381 | return {"loss": loss.sum()} 382 | 383 | def on_fit_start(self) -> None: 384 | # save the code using wandb 385 | if self.logger: 386 | # if logger is 
initialized, save the code 387 | self.logger[0].log_code() 388 | else: 389 | print("logger is not initialized, code will not be saved") 390 | 391 | return super().on_fit_start() 392 | 393 | def validation_step(self, batch: torch.Tensor, batch_idx: int): 394 | input_ids = batch["input_ids"] 395 | input_mask = batch["input_mask"] 396 | segment_ids = batch["segment_ids"] 397 | option_mask = batch["option_mask"] 398 | program_ids = batch["program_ids"] 399 | program_mask = batch["program_mask"] 400 | is_training = False 401 | 402 | program_ids = torch.tensor(program_ids).to("cuda") 403 | 404 | metadata = [{"unique_id": filename_id} for filename_id in batch["unique_id"]] 405 | 406 | output_dicts = self(is_training, input_ids, input_mask, segment_ids, option_mask, program_ids, program_mask, metadata) 407 | 408 | logits = [] 409 | for output_dict in output_dicts: 410 | logits.append(output_dict["logits"]) 411 | logits = torch.stack(logits) 412 | 413 | loss = self.criterion(logits.view(-1, logits.shape[-1]), program_ids.view(-1)) 414 | self.log("val_loss", loss) 415 | 416 | return output_dicts 417 | 418 | def validation_step_end(self, outputs: List[Dict[str, Any]]) -> None: 419 | self.predictions.extend(outputs) 420 | 421 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: 422 | test_file = os.path.join(self.input_dir, self.test_set) 423 | with open(test_file) as input_file: 424 | input_data = json.load(input_file) 425 | 426 | data_ori = [] 427 | for entry in input_data: 428 | example = read_mathqa_entry(entry, self.tokenizer, self.entity_name) 429 | if example: 430 | data_ori.append(example) 431 | 432 | kwargs = { 433 | "examples": data_ori, 434 | "tokenizer": self.tokenizer, 435 | "max_seq_length": self.input_length, 436 | "max_program_length": self.program_length, 437 | "is_training": False, 438 | "op_list": op_list, 439 | "op_list_size": len(op_list), 440 | "const_list": const_list, 441 | "const_list_size": len(const_list), 442 | "verbose": True 443 | } 444 | 445 | data = convert_examples_to_features(**kwargs) 446 | 447 | all_results = [] 448 | 449 | 450 | for output_dict in self.predictions: 451 | all_results.append( 452 | RawResult( 453 | unique_id=output_dict["unique_id"], 454 | logits=output_dict["logits"], 455 | loss=None 456 | )) 457 | 458 | all_predictions, all_nbest = compute_predictions( 459 | data_ori, 460 | data, 461 | all_results, 462 | n_best_size=self.n_best_size, 463 | max_program_length=self.program_length, 464 | tokenizer=self.tokenizer, 465 | op_list=op_list, 466 | op_list_size=len(op_list), 467 | const_list=const_list, 468 | const_list_size=len(const_list)) 469 | 470 | exe_acc = evaluate_result(all_nbest, test_file, program_mode="seq") 471 | 472 | self.log("exe_acc", exe_acc) 473 | 474 | # reset the predictions 475 | self.predictions = [] 476 | 477 | def predict_step(self, batch: torch.Tensor, batch_idx: int): 478 | input_ids = batch["input_ids"] 479 | input_mask = batch["input_mask"] 480 | segment_ids = batch["segment_ids"] 481 | option_mask = batch["option_mask"] 482 | program_ids = batch["program_ids"] 483 | program_mask = batch["program_mask"] 484 | is_training = False 485 | 486 | program_ids = torch.tensor(program_ids).to("cuda") 487 | 488 | metadata = [{"unique_id": filename_id} for filename_id in batch["unique_id"]] 489 | 490 | output_dicts = self(is_training, input_ids, input_mask, segment_ids, option_mask, program_ids, program_mask, metadata) 491 | 492 | return output_dicts 493 | 494 | def predict_step_end(self, outputs: List[Dict[str, Any]]) -> 
None: 495 | self.predictions.extend(outputs) 496 | 497 | def configure_optimizers(self): 498 | optimizer = AdamW(self.parameters(), **self.opt_params) 499 | if self.lrs_params["name"] == "cosine": 500 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 501 | elif self.lrs_params["name"] == "linear": 502 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 503 | elif self.lrs_params["name"] == "constant": 504 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 505 | else: 506 | raise ValueError(f"lr_scheduler {self.lrs_params} is not supported") 507 | 508 | return {"optimizer": optimizer, 509 | "lr_scheduler": { 510 | "scheduler": lr_scheduler, 511 | "interval": "step" 512 | } 513 | } -------------------------------------------------------------------------------- /lightning_modules/models/question_classification_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os, json 3 | from torch import nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | import math 7 | import numpy as np 8 | import pytorch_lightning as pl 9 | from pytorch_lightning import LightningModule 10 | from transformers import AutoModel, AutoConfig, AutoModelForSequenceClassification 11 | from typing import Optional, Dict, Any, Tuple, List 12 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup 13 | from transformers.optimization import get_cosine_schedule_with_warmup 14 | import datasets 15 | 16 | 17 | class QuestionClassificationModel(LightningModule): 18 | 19 | def __init__(self, 20 | model_name: str, 21 | warmup_steps: int = 0, 22 | optimizer: Dict[str, Any] = None, 23 | lr_scheduler: Dict[str, Any] = None, 24 | test_set: str = "dev", 25 | ) -> None: 26 | 27 | super().__init__() 28 | self.transformer_model_name = model_name 29 | 30 | self.model_config = AutoConfig.from_pretrained(self.transformer_model_name, num_labels=2) 31 | self.model = AutoModelForSequenceClassification.from_pretrained(self.transformer_model_name, config=self.model_config) 32 | self.metric = datasets.load_metric('precision') 33 | 34 | self.test_set = test_set 35 | 36 | self.warmup_steps = warmup_steps 37 | self.opt_params = optimizer["init_args"] 38 | self.lrs_params = lr_scheduler 39 | 40 | def forward(self, **inputs) -> List[Dict[str, Any]]: 41 | return self.model(**inputs) 42 | 43 | 44 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 45 | input_ids = torch.tensor(batch["input_ids"]).to("cuda") 46 | attention_mask = torch.tensor(batch["input_mask"]).to("cuda") 47 | labels = torch.tensor(batch["labels"]).to("cuda") 48 | 49 | outputs = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 50 | loss = outputs.loss 51 | self.log("loss", loss, on_step=True, on_epoch=True) 52 | 53 | return loss 54 | 55 | def on_fit_start(self) -> None: 56 | # save the code using wandb 57 | if self.logger: 58 | # if logger is initialized, save the code 59 | self.logger[0].log_code() 60 | else: 61 | print("logger is not initialized, code will not be saved") 62 | 63 | return super().on_fit_start() 64 | 65 | def validation_step(self, batch: torch.Tensor, batch_idx: int): 66 | input_ids = torch.tensor(batch["input_ids"]).to("cuda") 67 | attention_mask = torch.tensor(batch["input_mask"]).to("cuda") 68 | labels = 
torch.tensor(batch["labels"]).to("cuda") 69 | 70 | outputs = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 71 | 72 | loss = outputs.loss 73 | self.log("val_loss", loss) 74 | 75 | logits = outputs.logits 76 | preds = torch.argmax(logits, dim=1) 77 | labels = batch["labels"] 78 | uids = batch["uid"] 79 | 80 | return {"preds": preds, "labels": labels, "uids": uids} 81 | 82 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: 83 | preds = torch.cat([x["preds"] for x in outputs]).detach().cpu().numpy() 84 | labels = torch.cat([x["labels"] for x in outputs]).detach().cpu().numpy() 85 | 86 | self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True) 87 | print("precision:", self.metric.compute(predictions=preds, references=labels)) 88 | 89 | def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 90 | return self.validation_step(batch, batch_idx) 91 | 92 | def test_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: 93 | self.validation_epoch_end(outputs) 94 | 95 | def predict_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 96 | input_ids = torch.tensor(batch["input_ids"]).to("cuda") 97 | attention_mask = torch.tensor(batch["input_mask"]).to("cuda") 98 | 99 | outputs = self(input_ids=input_ids, attention_mask=attention_mask, labels=None) 100 | 101 | logits = outputs.logits 102 | preds = torch.argmax(logits, dim=1) 103 | uids = batch["uid"] 104 | 105 | return {"preds": preds, "uids": uids} 106 | 107 | def configure_optimizers(self): 108 | optimizer = AdamW(self.parameters(), **self.opt_params) 109 | if self.lrs_params["name"] == "cosine": 110 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 111 | elif self.lrs_params["name"] == "linear": 112 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 113 | elif self.lrs_params["name"] == "constant": 114 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 115 | else: 116 | raise ValueError(f"lr_scheduler {self.lrs_params} is not supported") 117 | 118 | return {"optimizer": optimizer, 119 | "lr_scheduler": { 120 | "scheduler": lr_scheduler, 121 | "interval": "step" 122 | } 123 | } -------------------------------------------------------------------------------- /lightning_modules/models/retriever_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import LightningModule 9 | from utils.retriever_utils import * 10 | from utils.utils import * 11 | from transformers import AutoModel, AutoTokenizer, AutoConfig 12 | from typing import Optional, Dict, Any, Tuple, List 13 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup 14 | from transformers.optimization import get_cosine_schedule_with_warmup 15 | 16 | 17 | class RetrieverModel(LightningModule): 18 | 19 | def __init__(self, 20 | transformer_model_name: str, 21 | topn: int, 22 | dropout_rate: float, 23 | warmup_steps: int = 0, 24 | optimizer: Dict[str, Any] = None, 25 | lr_scheduler: Dict[str, Any] = None, 26 | ) -> None: 27 | 28 | super().__init__() 29 | 30 | self.topn = topn 31 | self.dropout_rate = dropout_rate 32 | 
self.transformer_model_name = transformer_model_name 33 | 34 | self.model = AutoModel.from_pretrained(self.transformer_model_name) 35 | self.warmup_steps = warmup_steps 36 | self.model_config = AutoConfig.from_pretrained(self.transformer_model_name) 37 | 38 | self.criterion = nn.CrossEntropyLoss(reduction='none', ignore_index=-1) 39 | self.predictions: List[Dict[str, Any]] = [] 40 | 41 | self.opt_params = optimizer["init_args"] 42 | self.lrs_params = lr_scheduler 43 | 44 | hidden_size = self.model_config.hidden_size 45 | self.cls_prj = nn.Linear(hidden_size, hidden_size, bias=True) 46 | self.cls_dropout = nn.Dropout(self.dropout_rate) 47 | 48 | self.cls_final = nn.Linear(hidden_size, 2, bias=True) 49 | 50 | self.predictions = [] 51 | 52 | def forward(self, input_ids, attention_mask, segment_ids, metadata) -> List[Dict[str, Any]]: 53 | 54 | input_ids = torch.tensor(input_ids).to("cuda") 55 | attention_mask = torch.tensor(attention_mask).to("cuda") 56 | segment_ids = torch.tensor(segment_ids).to("cuda") 57 | 58 | bert_outputs = self.model( 59 | input_ids=input_ids, attention_mask=attention_mask, token_type_ids=segment_ids) 60 | 61 | bert_sequence_output = bert_outputs.last_hidden_state 62 | 63 | bert_pooled_output = bert_sequence_output[:, 0, :] 64 | 65 | pooled_output = self.cls_prj(bert_pooled_output) 66 | pooled_output = self.cls_dropout(pooled_output) 67 | 68 | logits = self.cls_final(pooled_output) 69 | output_dicts = [] 70 | for i in range(len(metadata)): 71 | output_dicts.append({"logits": logits[i], "filename_id": metadata[i]["filename_id"], "ind": metadata[i]["ind"]}) 72 | return output_dicts 73 | 74 | 75 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 76 | input_ids = batch["input_ids"] 77 | attention_mask = batch["input_mask"] 78 | segment_ids = batch["segment_ids"] 79 | labels = batch["label"] 80 | labels = torch.tensor(labels).to("cuda") 81 | 82 | metadata = [{"filename_id": filename_id, "ind": ind} for filename_id, ind in zip(batch["filename_id"], batch["ind"])] 83 | 84 | output_dicts = self(input_ids, attention_mask, segment_ids, metadata) 85 | 86 | logits = [] 87 | for output_dict in output_dicts: 88 | logits.append(output_dict["logits"]) 89 | logits = torch.stack(logits) 90 | loss = self.criterion(logits.view(-1, logits.shape[-1]), labels.view(-1)) 91 | 92 | self.log("loss", loss.sum(), on_step=True, on_epoch=True, prog_bar=True, logger=True) 93 | return {"loss": loss.sum()} 94 | 95 | def on_fit_start(self) -> None: 96 | # save the code using wandb 97 | if self.logger: 98 | # if logger is initialized, save the code 99 | self.logger[0].log_code() 100 | else: 101 | print("logger is not initialized, code will not be saved") 102 | 103 | return super().on_fit_start() 104 | 105 | def validation_step(self, batch: torch.Tensor, batch_idx: int): 106 | input_ids = batch["input_ids"] 107 | attention_mask = batch["input_mask"] 108 | segment_ids = batch["segment_ids"] 109 | 110 | labels = batch["label"] 111 | labels = torch.tensor(labels).to("cuda") 112 | 113 | metadata = [{"filename_id": filename_id, "ind": ind} for filename_id, ind in zip(batch["filename_id"], batch["ind"])] 114 | 115 | output_dicts = self(input_ids, attention_mask, segment_ids, metadata) 116 | 117 | logits = [] 118 | for output_dict in output_dicts: 119 | logits.append(output_dict["logits"]) 120 | logits = torch.stack(logits) 121 | loss = self.criterion(logits.view(-1, logits.shape[-1]), labels.view(-1)) 122 | self.log("val_loss", loss) 123 | return output_dicts 
124 | 125 | def predict_step(self, batch: torch.Tensor, batch_idx: int): 126 | input_ids = batch["input_ids"] 127 | attention_mask = batch["input_mask"] 128 | segment_ids = batch["segment_ids"] 129 | 130 | metadata = [{"filename_id": filename_id, "ind": ind} for filename_id, ind in zip(batch["filename_id"], batch["ind"])] 131 | 132 | output_dicts = self(input_ids, attention_mask, segment_ids, metadata) 133 | return output_dicts 134 | 135 | 136 | def predict_step_end(self, outputs: List[Dict[str, Any]]) -> None: 137 | self.predictions.extend(outputs) 138 | 139 | 140 | def configure_optimizers(self): 141 | optimizer = AdamW(self.parameters(), **self.opt_params) 142 | if self.lrs_params["name"] == "cosine": 143 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 144 | elif self.lrs_params["name"] == "linear": 145 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 146 | elif self.lrs_params["name"] == "constant": 147 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 148 | else: 149 | raise ValueError(f"lr_scheduler {self.lrs_params} is not supported") 150 | 151 | return {"optimizer": optimizer, 152 | "lr_scheduler": { 153 | "scheduler": lr_scheduler, 154 | "interval": "step" 155 | } 156 | } -------------------------------------------------------------------------------- /lightning_modules/models/span_selection_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | import math 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | from pytorch_lightning import LightningModule 9 | from utils.span_selection_utils import * 10 | from utils.utils import * 11 | from transformers import T5ForConditionalGeneration, AutoTokenizer 12 | from typing import Optional, Dict, Any, Tuple, List 13 | from transformers.optimization import AdamW, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup 14 | from transformers.optimization import get_cosine_schedule_with_warmup 15 | 16 | class SpanSelectionModel(LightningModule): 17 | 18 | def __init__(self, 19 | model_name: str, 20 | optimizer: Dict[str, Any] = None, 21 | lr_scheduler: Dict[str, Any] = None, 22 | load_ckpt_file: str = None, 23 | test_set: str = "dev_training.json", 24 | input_dir: str = "dataset/reasoning_module_input", 25 | ) -> None: 26 | 27 | super().__init__() 28 | self.model_name = model_name 29 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) 30 | self.model = T5ForConditionalGeneration.from_pretrained(self.model_name) 31 | 32 | self.lrs_params = lr_scheduler 33 | self.opt_params = optimizer["init_args"] 34 | 35 | self.test_set = test_set 36 | self.input_dir = input_dir 37 | 38 | self.predictions = [] 39 | 40 | def forward(self, input_ids, attention_mask, label_ids) -> List[Dict[str, Any]]: 41 | input_ids = torch.tensor(input_ids).to("cuda") 42 | attention_mask = torch.tensor(attention_mask).to("cuda") 43 | label_ids = torch.tensor(label_ids).to("cuda") 44 | 45 | loss = self.model( 46 | input_ids=input_ids, attention_mask=attention_mask, labels = label_ids).get("loss") 47 | 48 | return {"loss": loss} 49 | 50 | 51 | def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Dict[str, torch.Tensor]: 52 | input_ids = batch["input_ids"] 53 | attention_mask = batch["input_mask"] 54 | label_ids = batch["label_ids"] 55 | label_ids 
= torch.tensor(label_ids).to("cuda") 56 | 57 | 58 | output_dict = self(input_ids, attention_mask, label_ids) 59 | 60 | loss = output_dict["loss"] 61 | 62 | self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) 63 | return {"loss": loss} 64 | 65 | def on_fit_start(self) -> None: 66 | # save the code using wandb 67 | if self.logger: 68 | # if logger is initialized, save the code 69 | self.logger[0].log_code() 70 | else: 71 | print("logger is not initialized, code will not be saved") 72 | 73 | return super().on_fit_start() 74 | 75 | def validation_step(self, batch: torch.Tensor, batch_idx: int): 76 | input_ids = batch["input_ids"] 77 | attention_mask = batch["input_mask"] 78 | input_ids = torch.tensor(input_ids).to("cuda") 79 | attention_mask = torch.tensor(attention_mask).to("cuda") 80 | 81 | labels = batch["label"] 82 | label_ids = batch["label_ids"] 83 | label_ids = torch.tensor(label_ids).to("cuda") 84 | 85 | generated_ids = self.model.generate( 86 | input_ids=input_ids, 87 | attention_mask=attention_mask, 88 | ) 89 | preds = [ 90 | self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) 91 | for g in generated_ids 92 | ] 93 | 94 | 95 | output_dict = self(input_ids, attention_mask, label_ids = label_ids) 96 | 97 | unique_ids = batch["uid"] 98 | output_dict["preds"] = {} 99 | for i, unique_id in enumerate(unique_ids): 100 | output_dict["preds"][unique_id] = (preds[i], labels[i]) 101 | 102 | loss = output_dict["loss"] 103 | 104 | self.log("val_loss", loss) 105 | return output_dict 106 | 107 | def validation_step_end(self, outputs: List[Dict[str, Any]]) -> None: 108 | self.predictions.append(outputs) 109 | 110 | def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> None: 111 | all_filename_id = [] 112 | all_preds = [] 113 | all_labels = [] 114 | for output_dict in self.predictions: 115 | preds = output_dict["preds"] 116 | for unique_id, pred in preds.items(): 117 | all_filename_id.append(unique_id) 118 | all_preds.append(pred[0]) 119 | all_labels.append(pred[1]) 120 | 121 | 122 | test_file = os.path.join(self.input_dir, self.test_set) 123 | res = 0 124 | res = span_selection_evaluate(all_preds, all_filename_id, test_file) 125 | 126 | self.log("exact_match", res[0]) 127 | self.log("f1", res[1]) 128 | print(f"exact_match: {res[0]}, f1: {res[1]}") 129 | # reset the predictions 130 | self.predictions = [] 131 | 132 | def predict_step(self, batch: torch.Tensor, batch_idx: int): 133 | input_ids = batch["input_ids"] 134 | attention_mask = batch["input_mask"] 135 | input_ids = torch.tensor(input_ids).to("cuda") 136 | attention_mask = torch.tensor(attention_mask).to("cuda") 137 | 138 | generated_ids = self.model.generate( 139 | input_ids=input_ids, 140 | attention_mask=attention_mask, 141 | ) 142 | preds = [ 143 | self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) 144 | for g in generated_ids 145 | ] 146 | 147 | unique_ids = batch["uid"] 148 | output_dict = [] 149 | for i, unique_id in enumerate(unique_ids): 150 | output_dict.append({"uid": unique_id, "preds": preds[i]}) 151 | return output_dict 152 | 153 | def configure_optimizers(self): 154 | optimizer = AdamW(self.parameters(), **self.opt_params) 155 | if self.lrs_params["name"] == "cosine": 156 | lr_scheduler = get_cosine_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 157 | elif self.lrs_params["name"] == "linear": 158 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 159 | elif 
self.lrs_params["name"] == "constant": 160 | lr_scheduler = get_constant_schedule_with_warmup(optimizer, **self.lrs_params["init_args"]) 161 | else: 162 | raise ValueError(f"lr_scheduler {self.lrs_params} is not supported") 163 | 164 | return {"optimizer": optimizer, 165 | "lr_scheduler": { 166 | "scheduler": lr_scheduler, 167 | "interval": "step" 168 | } 169 | } -------------------------------------------------------------------------------- /lightning_modules/patches/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/psunlpgroup/MultiHiertt/45bd9ccdf3142ea059bd5e69c0afb83437fa539c/lightning_modules/patches/__init__.py -------------------------------------------------------------------------------- /lightning_modules/patches/patched_loggers.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from typing import Optional, Union, List 4 | from pytorch_lightning.loggers import NeptuneLogger, CSVLogger, TensorBoardLogger, WandbLogger 5 | from pytorch_lightning.utilities import rank_zero_only 6 | 7 | class PatchedNeptuneLogger(NeptuneLogger): 8 | def __init__(self, project_name: str, *args, **kwargs): 9 | api_key = os.getenv('NEPTUNE_API_KEY') 10 | if api_key is None: 11 | raise ValueError("Please provide an API key for the neptune logger in the env vars.") 12 | # exp_name = os.getenv('PL_LOG_DIR').split('/')[0] 13 | # exp_name = os.getenv('AMLT_JOB_NAME') 14 | 15 | kwargs['api_key'] = api_key 16 | # kwargs['experiment_id'] = exp_name 17 | kwargs['project'] = project_name 18 | kwargs['source_files'] = ['**/*.py', '**/*.yaml', '**/*.sh'] 19 | 20 | super().__init__(*args, **kwargs) 21 | 22 | class PatchedWandbLogger(WandbLogger): 23 | def __init__(self, entity: str, project: str, name: str, log_model: bool, save_code: bool, 24 | tags: List[str] = None, *args, **kwargs): 25 | 26 | kwargs['entity'] = entity 27 | kwargs['save_code'] = save_code 28 | 29 | # remove the preceding folder name 30 | processed_name = name.split('/')[-1] 31 | if tags is None: 32 | kwargs['tags'] = processed_name.split('-') 33 | else: 34 | kwargs['tags'] = tags 35 | 36 | super().__init__(name=processed_name, project=project, log_model=log_model, *args, **kwargs) 37 | 38 | @rank_zero_only 39 | def log_code(self): 40 | # log the yaml and py files 41 | root = "." 42 | print(f"saving all files in {os.path.abspath(root)}") 43 | result = self.experiment.log_code(root=root, 44 | include_fn=(lambda path: path.endswith(".py") or \ 45 | path.endswith(".yaml")), 46 | exclude_fn=(lambda path: ".venv" in path or \ 47 | "debug-tmp" in path)) 48 | if result is not None: 49 | print("########################################") 50 | print("######## Logged code to wandb.
#########") 51 | print("########################################") 52 | else: 53 | print("######## logger inited but not successfully saved #########") 54 | 55 | class PatchedCSVLogger(CSVLogger): 56 | def __init__(self, 57 | name: Optional[str] = "default", 58 | version: Optional[Union[int, str]] = None, 59 | prefix: str = "", 60 | ): 61 | save_dir = os.getenv('PL_LOG_DIR') 62 | super().__init__(save_dir, name, version, prefix) 63 | 64 | class PatchedTensorBoardLogger(TensorBoardLogger): 65 | def __init__(self, 66 | name: Optional[str] = "default", 67 | version: Optional[Union[int, str]] = None, 68 | log_graph: bool = False, 69 | default_hp_metric: bool = True, 70 | prefix: str = "", 71 | sub_dir: Optional[str] = None, 72 | ): 73 | save_dir = os.getenv('PL_LOG_DIR') 74 | super().__init__(save_dir, name, version, log_graph, default_hp_metric, prefix, sub_dir) 75 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.1.0 2 | aiohttp==3.8.1 3 | aiosignal==1.2.0 4 | anyio==3.6.1 5 | argon2-cffi==21.3.0 6 | argon2-cffi-bindings==21.2.0 7 | asttokens==2.0.5 8 | astunparse==1.6.3 9 | async-timeout==4.0.2 10 | attrs==21.4.0 11 | Babel==2.10.3 12 | backcall==0.2.0 13 | beautifulsoup4==4.11.1 14 | bleach==5.0.0 15 | brotlipy==0.7.0 16 | cachetools==5.2.0 17 | certifi==2021.5.30 18 | cffi==1.14.6 19 | chardet==4.0.0 20 | charset-normalizer==2.1.0 21 | click==8.1.3 22 | conda==4.10.3 23 | conda-package-handling==1.7.3 24 | cryptography==3.4.7 25 | cycler==0.11.0 26 | datasets==2.3.2 27 | debugpy==1.6.0 28 | decorator==5.1.1 29 | deepspeed==0.6.5 30 | defusedxml==0.7.1 31 | dill==0.3.5.1 32 | docker-pycreds==0.4.0 33 | docstring-parser==0.14.1 34 | entrypoints==0.4 35 | executing==0.8.3 36 | fastjsonschema==2.15.3 37 | filelock==3.7.1 38 | fonttools==4.33.3 39 | frozenlist==1.3.0 40 | fsspec==2022.5.0 41 | future==0.18.2 42 | gitdb==4.0.9 43 | GitPython==3.1.27 44 | google-auth==2.8.0 45 | google-auth-oauthlib==0.4.6 46 | grpcio==1.46.3 47 | hjson==3.0.2 48 | huggingface-hub==0.8.1 49 | idna==2.10 50 | importlib-metadata==4.11.4 51 | importlib-resources==5.8.0 52 | ipykernel==6.15.0 53 | ipython==8.4.0 54 | ipython-genutils==0.2.0 55 | ipywidgets==7.7.0 56 | jedi==0.18.1 57 | Jinja2==3.1.2 58 | joblib==1.1.0 59 | json5==0.9.8 60 | jsonargparse==4.10.2 61 | jsonschema==4.6.0 62 | jupyter-client==7.3.4 63 | jupyter-core==4.10.0 64 | jupyter-server==1.17.1 65 | jupyterlab==3.4.3 66 | jupyterlab-language-pack-zh-CN==3.4.post1 67 | jupyterlab-pygments==0.2.2 68 | jupyterlab-server==2.14.0 69 | jupyterlab-widgets==1.1.0 70 | kiwisolver==1.4.3 71 | Markdown==3.3.7 72 | MarkupSafe==2.1.1 73 | matplotlib==3.5.2 74 | matplotlib-inline==0.1.3 75 | mistune==0.8.4 76 | mpmath==1.2.1 77 | multidict==6.0.2 78 | multiprocess==0.70.13 79 | nbclassic==0.3.7 80 | nbclient==0.6.4 81 | nbconvert==6.5.0 82 | nbformat==5.4.0 83 | nest-asyncio==1.5.5 84 | ninja==1.10.2.3 85 | notebook==6.4.12 86 | notebook-shim==0.1.0 87 | numpy==1.22.4 88 | oauthlib==3.2.0 89 | packaging==21.3 90 | pandas==1.4.3 91 | pandocfilters==1.5.0 92 | parso==0.8.3 93 | pathtools==0.1.2 94 | pexpect==4.8.0 95 | pickleshare==0.7.5 96 | Pillow==9.1.1 97 | pip==21.1.3 98 | prometheus-client==0.14.1 99 | promise==2.3 100 | prompt-toolkit==3.0.29 101 | protobuf==3.19.4 102 | psutil==5.9.1 103 | ptyprocess==0.7.0 104 | pure-eval==0.2.2 105 | py-cpuinfo==8.0.0 106 | pyarrow==8.0.0 107 | pyasn1==0.4.8 108 | 
pyasn1-modules==0.2.8 109 | pycosat==0.6.3 110 | pycparser==2.20 111 | pyDeprecate==0.3.1 112 | Pygments==2.12.0 113 | pyOpenSSL==20.0.1 114 | pyparsing==3.0.9 115 | pyrsistent==0.18.1 116 | PySocks==1.7.1 117 | python-dateutil==2.8.2 118 | pytorch-lightning==1.5.10 119 | pytz==2022.1 120 | PyYAML==6.0 121 | pyzmq==23.2.0 122 | regex==2022.6.2 123 | requests==2.25.1 124 | requests-oauthlib==1.3.1 125 | responses==0.18.0 126 | rsa==4.8 127 | ruamel-yaml-conda==0.15.100 128 | scikit-learn==1.1.1 129 | scipy==1.8.1 130 | Send2Trash==1.8.0 131 | sentencepiece==0.1.96 132 | sentry-sdk==1.6.0 133 | setproctitle==1.2.3 134 | setuptools==59.5.0 135 | shortuuid==1.0.9 136 | six==1.16.0 137 | sklearn==0.0 138 | smmap==5.0.0 139 | sniffio==1.2.0 140 | soupsieve==2.3.2.post1 141 | stack-data==0.3.0 142 | supervisor==4.2.4 143 | sympy==1.10.1 144 | tensorboard==2.9.1 145 | tensorboard-data-server==0.6.1 146 | tensorboard-plugin-wit==1.8.1 147 | terminado==0.15.0 148 | threadpoolctl==3.1.0 149 | tinycss2==1.1.1 150 | tokenizers==0.12.1 151 | torch==1.11.0+cu113 152 | torchmetrics==0.9.2 153 | torchvision==0.12.0+cu113 154 | tornado==6.1 155 | tqdm==4.64.0 156 | traitlets==5.3.0 157 | transformers==4.20.1 158 | typing-extensions==4.2.0 159 | urllib3==1.26.6 160 | wandb==0.12.20 161 | wcwidth==0.2.5 162 | webencodings==0.5.1 163 | websocket-client==1.3.3 164 | Werkzeug==2.1.2 165 | wheel==0.36.2 166 | widgetsnbextension==3.6.0 167 | xxhash==3.0.0 168 | yarl==1.7.2 169 | zipp==3.8.0 170 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from pytorch_lightning import LightningModule, LightningDataModule 2 | from pytorch_lightning.utilities.cli import LightningCLI 3 | 4 | # see https://github.com/PyTorchLightning/pytorch-lightning/issues/10349 5 | import warnings 6 | 7 | warnings.filterwarnings( 8 | "ignore", ".*Trying to infer the `batch_size` from an ambiguous collection.*" 9 | ) 10 | 11 | cli = LightningCLI(LightningModule, LightningDataModule, 12 | subclass_mode_model=True, subclass_mode_data=True, 13 | save_config_callback=None) -------------------------------------------------------------------------------- /training_configs/program_generation_finetuning.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0, 3, 4, 7] 4 | gradient_clip_val: 1.0 5 | default_root_dir: &exp_name results/MH_program_generation 6 | check_val_every_n_epoch: 4 7 | max_epochs: &max_epochs 60 8 | log_every_n_steps: 1 9 | logger: 10 | - class_path: lightning_modules.patches.patched_loggers.PatchedWandbLogger 11 | init_args: 12 | entity: yilunzhao 13 | project: mh_program_generator 14 | name: *exp_name 15 | log_model: False 16 | save_code: True 17 | offline: False 18 | callbacks: 19 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 20 | init_args: 21 | monitor: val_loss 22 | mode: min 23 | filename: '{step}-{val_loss:.4f}' 24 | save_top_k: 5 25 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 26 | init_args: 27 | logging_interval: step 28 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 29 | init_args: 30 | refresh_rate: 1 31 | 32 | accelerator: gpu 33 | strategy: ddp 34 | accumulate_grad_batches: 2 35 | 36 | model: 37 | class_path: lightning_modules.models.program_generation_model.ProgramGenerationModel 38 | init_args: 39 | model_name: &transformer 
roberta-base 40 | program_length: 30 41 | input_length: 512 42 | max_step_ind: 11 43 | dropout_rate: 0.1 44 | num_decoder_layers: 1 45 | n_best_size: 20 46 | sep_attention: True 47 | layer_norm: True 48 | optimizer: 49 | init_args: 50 | lr: 2.0e-5 51 | betas: 52 | - 0.9 53 | - 0.999 54 | eps: 1.0e-8 55 | weight_decay: 0.1 56 | lr_scheduler: 57 | name: linear 58 | init_args: 59 | num_warmup_steps: 100 60 | num_training_steps: 10000 61 | test_set: dev_training.json 62 | entity_name: &selected_entity_name question_type 63 | input_dir: dataset/reasoning_module_input 64 | 65 | data: 66 | class_path: lightning_modules.datasets.program_generation_reader.ProgramGenerationDataModule 67 | init_args: 68 | model_name: *transformer 69 | max_seq_length: 512 70 | max_program_length: 30 71 | batch_size: 24 72 | val_batch_size: 24 # when using multi-GPU, wandb will automatically average the exe-acc 73 | train_file_path: ./dataset/reasoning_module_input/train_training.json 74 | val_file_path: ./dataset/reasoning_module_input/dev_training.json 75 | train_max_instances: -1 76 | val_max_instances: -1 77 | entity_name: *selected_entity_name 78 | 79 | # clear; export PYTHONPATH=`pwd`; python trainer.py fit --config training_configs/program_generation_finetuning.yaml -------------------------------------------------------------------------------- /training_configs/question_classification_finetuning.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0, 1, 2, 3] 4 | gradient_clip_val: 1.0 5 | default_root_dir: &exp_name results/MH_question_classification 6 | check_val_every_n_epoch: 5 7 | max_epochs: &max_epochs 20 8 | log_every_n_steps: 1 9 | logger: 10 | - class_path: lightning_modules.patches.patched_loggers.PatchedWandbLogger 11 | init_args: 12 | entity: yilunzhao 13 | project: MH_question_classification 14 | name: *exp_name 15 | log_model: False 16 | save_code: True 17 | offline: False 18 | callbacks: 19 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 20 | init_args: 21 | monitor: val_loss 22 | mode: min 23 | filename: '{step}-{val_loss:.4f}' 24 | save_top_k: 2 25 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 26 | init_args: 27 | logging_interval: step 28 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 29 | init_args: 30 | refresh_rate: 1 31 | - class_path: lightning_modules.callbacks.question_classification_save_prediction_callback.SavePredictionCallback 32 | init_args: 33 | test_set: &prediction_set test 34 | input_dir: &input_dir_path dataset 35 | output_dir: question_classification_output 36 | 37 | accelerator: gpu 38 | strategy: ddp_find_unused_parameters_false 39 | accumulate_grad_batches: 2 40 | 41 | 42 | model: 43 | class_path: lightning_modules.models.question_classification_model.QuestionClassificationModel 44 | init_args: 45 | model_name: &transformer roberta-base 46 | optimizer: 47 | init_args: 48 | lr: 2.0e-5 49 | betas: 50 | - 0.9 51 | - 0.999 52 | eps: 1.0e-8 53 | weight_decay: 0.1 54 | lr_scheduler: 55 | name: linear 56 | init_args: 57 | num_warmup_steps: 100 58 | num_training_steps: 10000 59 | test_set: dev 60 | 61 | data: 62 | class_path: lightning_modules.datasets.question_classification_reader.QuestionClassificationDataModule 63 | init_args: 64 | model_name: *transformer 65 | batch_size: 192 66 | val_batch_size: 512 67 | train_file_path: dataset/train.json 68 | val_file_path: dataset/dev.json 69 | train_max_instances: -1 70 | 
val_max_instances: -1 71 | 72 | # clear; export PYTHONPATH=`pwd`; python trainer.py fit --config training_configs/question_classification_finetuning.yaml -------------------------------------------------------------------------------- /training_configs/retriever_finetuning.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [0, 1, 3, 6] 4 | gradient_clip_val: 1.0 5 | default_root_dir: &exp_name results/MH_retriever 6 | val_check_interval: 1.0 7 | max_epochs: &max_epochs 5 8 | log_every_n_steps: 1 9 | logger: 10 | - class_path: lightning_modules.patches.patched_loggers.PatchedWandbLogger 11 | init_args: 12 | entity: yilunzhao 13 | project: mh_retriever 14 | name: *exp_name 15 | log_model: False 16 | save_code: True 17 | offline: True 18 | callbacks: 19 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 20 | init_args: 21 | monitor: val_loss 22 | mode: min 23 | filename: '{step}-{val_loss:.4f}' 24 | save_top_k: 3 25 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 26 | init_args: 27 | logging_interval: step 28 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 29 | init_args: 30 | refresh_rate: 1 31 | 32 | accelerator: gpu 33 | strategy: ddp 34 | accumulate_grad_batches: 2 35 | 36 | model: 37 | class_path: lightning_modules.models.retriever_model.RetrieverModel 38 | init_args: 39 | transformer_model_name: &transformer roberta-base 40 | topn: 10 41 | dropout_rate: 0.1 42 | optimizer: 43 | init_args: 44 | lr: 2.0e-5 45 | betas: 46 | - 0.9 47 | - 0.999 48 | eps: 1.0e-8 49 | weight_decay: 0.1 50 | lr_scheduler: 51 | name: linear 52 | init_args: 53 | num_warmup_steps: 100 54 | num_training_steps: 10000 55 | 56 | data: 57 | class_path: lightning_modules.datasets.retriever_reader.RetrieverDataModule 58 | init_args: 59 | transformer_model_name: *transformer 60 | batch_size: 24 61 | val_batch_size: 64 62 | num_workers: 8 63 | train_file_path: dataset/train.json 64 | val_file_path: dataset/dev.json 65 | train_max_instances: -1 66 | val_max_instances: -1 67 | 68 | # clear; export PYTHONPATH=`pwd`; python trainer.py fit --config training_configs/retriever_finetuning.yaml 69 | -------------------------------------------------------------------------------- /training_configs/span_selection_finetuning.yaml: -------------------------------------------------------------------------------- 1 | seed_everything: 333 2 | trainer: 3 | gpus: [1, 3, 4, 7] 4 | default_root_dir: &exp_name results/MH_span_selection 5 | check_val_every_n_epoch: 4 6 | max_epochs: &max_epochs 40 7 | # progress_bar_refresh_rate: 1 8 | log_every_n_steps: 1 9 | logger: 10 | - class_path: lightning_modules.patches.patched_loggers.PatchedWandbLogger 11 | init_args: 12 | entity: yilunzhao 13 | project: MH_span_selection 14 | name: *exp_name 15 | log_model: False 16 | save_code: True 17 | offline: False 18 | callbacks: 19 | - class_path: pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint 20 | init_args: 21 | monitor: val_loss 22 | mode: min 23 | filename: '{step}-{val_loss:.4f}' 24 | save_top_k: 5 25 | - class_path: pytorch_lightning.callbacks.LearningRateMonitor 26 | init_args: 27 | logging_interval: step 28 | - class_path: pytorch_lightning.callbacks.progress.TQDMProgressBar 29 | init_args: 30 | refresh_rate: 1 31 | 32 | accelerator: gpu 33 | strategy: ddp_find_unused_parameters_false 34 | accumulate_grad_batches: 2 35 | 36 | model: 37 | class_path: 
lightning_modules.models.span_selection_model.SpanSelectionModel 38 | init_args: 39 | model_name: &transformer t5-base 40 | optimizer: 41 | init_args: 42 | lr: 2.0e-5 43 | # lr: 0.0 44 | betas: 45 | - 0.9 46 | - 0.999 47 | eps: 1.0e-8 48 | weight_decay: 0.1 49 | lr_scheduler: 50 | name: linear 51 | init_args: 52 | num_warmup_steps: 100 53 | num_training_steps: 10000 54 | test_set: dev_training.json 55 | input_dir: dataset/reasoning_module_input 56 | 57 | data: 58 | class_path: lightning_modules.datasets.span_selection_reader.SpanSelectionDataModule 59 | init_args: 60 | model_name: *transformer 61 | batch_size: 20 62 | val_batch_size: 32 63 | train_file_path: ./dataset/reasoning_module_input/train_training.json 64 | val_file_path: ./dataset/reasoning_module_input/dev_training.json 65 | train_max_instances: -1 66 | val_max_instances: -1 67 | entity_name: question_type 68 | 69 | 70 | # clear; export PYTHONPATH=`pwd`; python trainer.py fit --config training_configs/span_selection_finetuning.yaml -------------------------------------------------------------------------------- /txt_files/constant_list.txt: -------------------------------------------------------------------------------- 1 | CONST_2 2 | CONST_1 3 | CONST_3 4 | CONST_4 5 | CONST_5 6 | CONST_6 7 | CONST_7 8 | CONST_8 9 | CONST_9 10 | CONST_10 11 | CONST_100 12 | CONST_1000 13 | CONST_10000 14 | CONST_100000 15 | CONST_1000000 16 | CONST_10000000 17 | CONST_1000000000 18 | CONST_M1 19 | #0 20 | #1 21 | #2 22 | #3 23 | #4 24 | #5 25 | #6 26 | #7 27 | #8 28 | #9 29 | #10 30 | NONE -------------------------------------------------------------------------------- /txt_files/operation_list.txt: -------------------------------------------------------------------------------- 1 | add 2 | subtract 3 | multiply 4 | divide 5 | exp 6 | greater 7 | table_sum 8 | table_average 9 | table_max 10 | table_min -------------------------------------------------------------------------------- /utils/datasets_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import io, tokenize, re 3 | import ast, astunparse 4 | 5 | from typing import Tuple, Optional, List, Union 6 | 7 | def right_pad_sequences(sequences: List[torch.Tensor], batch_first: bool = True, padding_value: Union[int, bool] = 0, 8 | max_len: int = -1, device: torch.device = None) -> torch.Tensor: 9 | assert all([len(seq.shape) == 1 for seq in sequences]) 10 | max_len = max_len if max_len > 0 else max(len(s) for s in sequences) 11 | device = device if device is not None else sequences[0].device 12 | 13 | padded_seqs = [] 14 | for seq in sequences: 15 | padded_seqs.append(torch.cat(seq, (torch.full((max_len - seq.shape[0],), padding_value, dtype=torch.long).to(device)))) 16 | return torch.stack(padded_seqs) -------------------------------------------------------------------------------- /utils/program_generation_utils.py: -------------------------------------------------------------------------------- 1 | """MathQA utils. 
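A note on `right_pad_sequences` in `utils/datasets_util.py` above: the `torch.cat` call appears to pass the sequence tensor and the padding tensor as two separate positional arguments, so the padding ends up in the `dim` slot and the call would raise a `TypeError` when the collator runs. A minimal corrected sketch (same intent; assumes the inputs are 1-D LongTensors of token ids, and keeps `batch_first` only for signature parity):

```python
import torch
from typing import List, Union

def right_pad_sequences(sequences: List[torch.Tensor], batch_first: bool = True,
                        padding_value: Union[int, bool] = 0, max_len: int = -1,
                        device: torch.device = None) -> torch.Tensor:
    """Right-pad a list of 1-D token-id tensors to a common length and stack them."""
    assert all(len(seq.shape) == 1 for seq in sequences)
    max_len = max_len if max_len > 0 else max(len(s) for s in sequences)
    device = device if device is not None else sequences[0].device

    padded_seqs = []
    for seq in sequences:
        pad = torch.full((max_len - seq.shape[0],), padding_value,
                         dtype=torch.long, device=device)
        # both pieces go to torch.cat as a single tuple of tensors
        padded_seqs.append(torch.cat((seq, pad)))
    return torch.stack(padded_seqs)  # shape: (batch, max_len)
```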
2 | """ 3 | import argparse 4 | import collections 5 | import json 6 | import numpy as np 7 | import os 8 | import re 9 | import string 10 | import sys 11 | import random 12 | import enum 13 | import six 14 | import copy 15 | from six.moves import map 16 | from six.moves import range 17 | from six.moves import zip 18 | import math 19 | import tqdm 20 | from sympy import simplify 21 | from utils.utils import * 22 | 23 | all_ops = ["add", "subtract", "multiply", "divide", "exp"] 24 | 25 | sys.path.insert(0, '../utils/') 26 | max_seq_length = 512 27 | max_program_length = 30 28 | 29 | class MathQAExample( 30 | collections.namedtuple( 31 | "MathQAExample", 32 | "id original_question question_tokens options answer \ 33 | numbers number_indices original_program program" 34 | )): 35 | 36 | def convert_single_example(self, *args, **kwargs): 37 | return convert_single_mathqa_example(self, *args, **kwargs) 38 | 39 | 40 | def tokenize(tokenizer, text, apply_basic_tokenization=False): 41 | """Tokenizes text, optionally looking up special tokens separately. 42 | 43 | Args: 44 | tokenizer: a tokenizer from bert.tokenization.FullTokenizer 45 | text: text to tokenize 46 | apply_basic_tokenization: If True, apply the basic tokenization. If False, 47 | apply the full tokenization (basic + wordpiece). 48 | 49 | Returns: 50 | tokenized text. 51 | 52 | A special token is any text with no spaces enclosed in square brackets with no 53 | space, so we separate those out and look them up in the dictionary before 54 | doing actual tokenization. 55 | """ 56 | 57 | _SPECIAL_TOKENS_RE = re.compile(r"^<[^ ]*>$", re.UNICODE) 58 | 59 | tokenize_fn = tokenizer.tokenize 60 | if apply_basic_tokenization: 61 | tokenize_fn = tokenizer.basic_tokenizer.tokenize 62 | 63 | tokens = [] 64 | for token in text.split(" "): 65 | if _SPECIAL_TOKENS_RE.match(token): 66 | if token in tokenizer.get_vocab(): 67 | tokens.append(token) 68 | else: 69 | tokens.append(tokenizer.unk_token) 70 | else: 71 | tokens.extend(tokenize_fn(token)) 72 | 73 | return tokens 74 | 75 | 76 | def program_tokenization(original_program): 77 | original_program = original_program.split(',') 78 | program = [] 79 | for tok in original_program: 80 | tok = tok.strip() 81 | cur_tok = '' 82 | for c in tok: 83 | if c == ')': 84 | if cur_tok != '': 85 | program.append(cur_tok) 86 | cur_tok = '' 87 | cur_tok += c 88 | if c in ['(', ')']: 89 | program.append(cur_tok) 90 | cur_tok = '' 91 | if cur_tok != '': 92 | program.append(cur_tok) 93 | program.append('EOF') 94 | return program 95 | 96 | 97 | def convert_single_mathqa_example(example, is_training, tokenizer, max_seq_length, 98 | max_program_length, op_list, op_list_size, 99 | const_list, const_list_size, 100 | cls_token, sep_token): 101 | """Converts a single MathQAExample into an InputFeature.""" 102 | features = [] 103 | question_tokens = example.question_tokens 104 | if len(question_tokens) > max_seq_length - 2: 105 | print("too long") 106 | question_tokens = question_tokens[:max_seq_length - 2] 107 | tokens = [cls_token] + question_tokens + [sep_token] 108 | segment_ids = [0] * len(tokens) 109 | 110 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 111 | 112 | 113 | input_mask = [1] * len(input_ids) 114 | for ind, offset in enumerate(example.number_indices): 115 | if offset < len(input_mask): 116 | input_mask[offset] = 2 117 | else: 118 | if is_training == True: 119 | return features 120 | 121 | padding = [0] * (max_seq_length - len(input_ids)) 122 | input_ids.extend(padding) 123 | input_mask.extend(padding) 124 | 
segment_ids.extend(padding) 125 | 126 | # print(len(input_ids)) 127 | assert len(input_ids) == max_seq_length 128 | assert len(input_mask) == max_seq_length 129 | assert len(segment_ids) == max_seq_length 130 | 131 | number_mask = [tmp - 1 for tmp in input_mask] 132 | for ind in range(len(number_mask)): 133 | if number_mask[ind] < 0: 134 | number_mask[ind] = 0 135 | option_mask = [1, 0, 0, 1] + [1] * (len(op_list) + len(const_list) - 4) 136 | option_mask = option_mask + number_mask 137 | option_mask = [float(tmp) for tmp in option_mask] 138 | 139 | for ind in range(len(input_mask)): 140 | if input_mask[ind] > 1: 141 | input_mask[ind] = 1 142 | 143 | numbers = example.numbers 144 | number_indices = example.number_indices 145 | program = example.program 146 | if program is not None and is_training: 147 | program_ids = prog_token_to_indices(program, numbers, number_indices, 148 | max_seq_length, op_list, op_list_size, 149 | const_list, const_list_size) 150 | if not program_ids: 151 | return None 152 | 153 | program_mask = [1] * len(program_ids) 154 | program_ids = program_ids[:max_program_length] 155 | program_mask = program_mask[:max_program_length] 156 | if len(program_ids) < max_program_length: 157 | padding = [0] * (max_program_length - len(program_ids)) 158 | program_ids.extend(padding) 159 | program_mask.extend(padding) 160 | else: 161 | program = "" 162 | program_ids = [0] * max_program_length 163 | program_mask = [0] * max_program_length 164 | assert len(program_ids) == max_program_length 165 | assert len(program_mask) == max_program_length 166 | 167 | this_input_features = { 168 | "id": example.id, 169 | "unique_id": -1, 170 | "example_index": -1, 171 | "tokens": tokens, 172 | "question": example.original_question, 173 | "input_ids": input_ids, 174 | "input_mask": input_mask, 175 | "option_mask": option_mask, 176 | "segment_ids": segment_ids, 177 | "options": example.options, 178 | "answer": example.answer, 179 | "program": program, 180 | "program_ids": program_ids, 181 | "program_weight": 1.0, 182 | "program_mask": program_mask 183 | } 184 | 185 | features.append(this_input_features) 186 | return features 187 | 188 | 189 | def read_mathqa_entry(entry, tokenizer, entity_name): 190 | if entry["qa"][entity_name] != "arithmetic": 191 | return None 192 | 193 | 194 | context = "" 195 | for idx in entry["model_input"]: 196 | if type(idx) == int: 197 | context += entry["paragraphs"][idx][:-1] 198 | context += " " 199 | 200 | else: 201 | context += entry["table_description"][idx][:-1] 202 | context += " " 203 | 204 | question = entry["qa"]["question"] 205 | this_id = entry["uid"] 206 | 207 | original_question = question + " " + tokenizer.sep_token + " " + context.strip() 208 | 209 | options = entry["qa"]["answer"] if "answer" in entry["qa"] else None 210 | answer = entry["qa"]["answer"] if "answer" in entry["qa"] else None 211 | 212 | original_question_tokens = original_question.split(' ') 213 | numbers = [] 214 | number_indices = [] 215 | question_tokens = [] 216 | 217 | # TODO 218 | for i, tok in enumerate(original_question_tokens): 219 | num = str_to_num(tok) 220 | if num is not None: 221 | if num != "n/a": 222 | numbers.append(str(num)) 223 | else: 224 | numbers.append(tok) 225 | number_indices.append(len(question_tokens)) 226 | if tok and tok[0] == '.': 227 | numbers.append(str(str_to_num(tok[1:]))) 228 | number_indices.append(len(question_tokens) + 1) 229 | tok_proc = tokenize(tokenizer, tok) 230 | question_tokens.extend(tok_proc) 231 | 232 | 233 | 234 | original_program = 
entry["qa"]['program'] if "program" in entry["qa"] else None 235 | if original_program: 236 | program = program_tokenization(original_program) 237 | else: 238 | program = None 239 | 240 | 241 | return MathQAExample( 242 | id=this_id, 243 | original_question=original_question, 244 | question_tokens=question_tokens, 245 | options=options, 246 | answer=answer, 247 | numbers=numbers, 248 | number_indices=number_indices, 249 | original_program=original_program, 250 | program=program) 251 | 252 | 253 | def _compute_softmax(scores): 254 | """Compute softmax probability over raw logits.""" 255 | if scores == None: 256 | return [] 257 | 258 | max_score = None 259 | for score in scores: 260 | if max_score is None or score > max_score: 261 | max_score = score 262 | 263 | exp_scores = [] 264 | total_sum = 0.0 265 | for score in scores: 266 | x = math.exp(score - max_score) 267 | exp_scores.append(x) 268 | total_sum += x 269 | 270 | probs = [] 271 | for score in exp_scores: 272 | probs.append(score / total_sum) 273 | return probs 274 | 275 | def read_examples(input_path, tokenizer, op_list, const_list, log_file): 276 | """Read a json file into a list of examples.""" 277 | with open(input_path) as input_file: 278 | input_data = json.load(input_file) 279 | 280 | examples = [] 281 | for entry in input_data: 282 | examples.append(read_mathqa_entry(entry, tokenizer)) 283 | program = examples[-1].program 284 | return input_data, examples, op_list, const_list 285 | 286 | def convert_examples_to_features(examples, 287 | tokenizer, 288 | max_seq_length, 289 | max_program_length, 290 | is_training, 291 | op_list, 292 | op_list_size, 293 | const_list, 294 | const_list_size, 295 | verbose=True): 296 | """Converts a list of DropExamples into InputFeatures.""" 297 | unique_id = 1000000000 298 | res = [] 299 | for (example_index, example) in enumerate(examples): 300 | features = example.convert_single_example( 301 | is_training=is_training, 302 | tokenizer=tokenizer, 303 | max_seq_length=max_seq_length, 304 | max_program_length=max_program_length, 305 | op_list=op_list, 306 | op_list_size=op_list_size, 307 | const_list=const_list, 308 | const_list_size=const_list_size, 309 | cls_token=tokenizer.cls_token, 310 | sep_token=tokenizer.sep_token) 311 | 312 | if features: 313 | for feature in features: 314 | feature["unique_id"] = unique_id 315 | feature["example_index"] = example_index 316 | res.append(feature) 317 | unique_id += 1 318 | 319 | return res 320 | 321 | 322 | RawResult = collections.namedtuple( 323 | "RawResult", 324 | "unique_id logits loss") 325 | 326 | 327 | def compute_prog_from_logits(logits, max_program_length, example, 328 | template=None): 329 | pred_prog_ids = [] 330 | op_stack = [] 331 | loss = 0 332 | for cur_step in range(max_program_length): 333 | cur_logits = logits[cur_step] 334 | cur_pred_softmax = _compute_softmax(cur_logits) 335 | cur_pred_token = np.argmax(cur_logits.cpu()) 336 | loss -= np.log(cur_pred_softmax[cur_pred_token]) 337 | pred_prog_ids.append(cur_pred_token) 338 | if cur_pred_token == 0: 339 | break 340 | return pred_prog_ids, loss 341 | 342 | 343 | def compute_predictions(all_examples, all_features, all_results, n_best_size, 344 | max_program_length, tokenizer, op_list, op_list_size, 345 | const_list, const_list_size): 346 | """Computes final predictions based on logits.""" 347 | example_index_to_features = collections.defaultdict(list) 348 | for feature in all_features: 349 | example_index_to_features[feature["example_index"]].append(feature) 350 | 351 | unique_id_to_result = 
{} 352 | for result in all_results: 353 | unique_id_to_result[result.unique_id] = result 354 | 355 | _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name 356 | "PrelimPrediction", [ 357 | "feature_index", "logits" 358 | ]) 359 | 360 | all_predictions = collections.OrderedDict() 361 | all_predictions["pred_programs"] = collections.OrderedDict() 362 | all_predictions["ref_programs"] = collections.OrderedDict() 363 | all_nbest = collections.OrderedDict() 364 | for (example_index, example) in enumerate(all_examples): 365 | if example_index not in example_index_to_features: 366 | continue 367 | features = example_index_to_features[example_index] 368 | prelim_predictions = [] 369 | for (feature_index, feature) in enumerate(features): 370 | if feature["unique_id"] not in unique_id_to_result: 371 | continue 372 | result = unique_id_to_result[feature["unique_id"]] 373 | logits = result.logits 374 | prelim_predictions.append( 375 | _PrelimPrediction( 376 | feature_index=feature_index, 377 | logits=logits)) 378 | 379 | _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name 380 | "NbestPrediction", "options answer program_ids program") 381 | 382 | nbest = [] 383 | for pred in prelim_predictions: 384 | if len(nbest) >= n_best_size: 385 | break 386 | program = example.program 387 | pred_prog_ids, loss = compute_prog_from_logits(pred.logits, 388 | max_program_length, 389 | example) 390 | pred_prog = indices_to_prog(pred_prog_ids, 391 | example.numbers, 392 | example.number_indices, 393 | max_seq_length, 394 | op_list, op_list_size, 395 | const_list, const_list_size 396 | ) 397 | nbest.append( 398 | _NbestPrediction( 399 | options=example.options, 400 | answer=example.answer, 401 | program_ids=pred_prog_ids, 402 | program=pred_prog)) 403 | 404 | # assert len(nbest) >= 1 405 | if len(nbest) == 0: 406 | continue 407 | 408 | nbest_json = [] 409 | for (i, entry) in enumerate(nbest): 410 | output = collections.OrderedDict() 411 | output["id"] = example.id 412 | output["options"] = entry.options 413 | output["ref_answer"] = entry.answer 414 | output["pred_prog"] = [str(prog) for prog in entry.program] 415 | output["ref_prog"] = example.program 416 | output["question_tokens"] = example.question_tokens 417 | output["numbers"] = example.numbers 418 | output["number_indices"] = example.number_indices 419 | nbest_json.append(output) 420 | 421 | assert len(nbest_json) >= 1 422 | 423 | all_predictions["pred_programs"][example_index] = nbest_json[0]["pred_prog"] 424 | all_predictions["ref_programs"][example_index] = nbest_json[0]["ref_prog"] 425 | all_nbest[example_index] = nbest_json 426 | 427 | return all_predictions, all_nbest 428 | 429 | 430 | def write_predictions(all_predictions, output_prediction_file): 431 | """Writes final predictions in json format.""" 432 | 433 | with open(output_prediction_file, "w") as writer: 434 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 435 | 436 | 437 | def process_row(row_in): 438 | 439 | row_out = [] 440 | invalid_flag = 0 441 | 442 | for num in row_in: 443 | num = num.replace("$", "").strip() 444 | num = num.split("(")[0].strip() 445 | 446 | num = str_to_num(num) 447 | 448 | if num == "n/a": 449 | invalid_flag = 1 450 | break 451 | 452 | row_out.append(num) 453 | 454 | if invalid_flag: 455 | return "n/a" 456 | 457 | return row_out 458 | 459 | 460 | def reprog_to_seq(prog_in, is_gold): 461 | ''' 462 | predicted recursive program to list program 463 | ["divide(", "72", "multiply(", "6", "210", ")", ")"] 464 | ["multiply(", 
"6", "210", ")", "divide(", "72", "#0", ")"] 465 | ''' 466 | 467 | st = [] 468 | res = [] 469 | 470 | try: 471 | num = 0 472 | for tok in prog_in: 473 | if tok != ")": 474 | st.append(tok) 475 | else: 476 | this_step_vec = [")"] 477 | for _ in range(3): 478 | this_step_vec.append(st[-1]) 479 | st = st[:-1] 480 | res.extend(this_step_vec[::-1]) 481 | st.append("#" + str(num)) 482 | num += 1 483 | except: 484 | if is_gold: 485 | raise ValueError 486 | 487 | return res 488 | 489 | 490 | def eval_program(program): 491 | ''' 492 | calculate the numerical results of the program 493 | ''' 494 | 495 | invalid_flag = 0 496 | this_res = "n/a" 497 | 498 | try: 499 | program = program[:-1] # remove EOF 500 | # check structure 501 | for ind, token in enumerate(program): 502 | if ind % 4 == 0: 503 | if token.strip("(") not in all_ops: 504 | return 1, "n/a" 505 | if (ind + 1) % 4 == 0: 506 | if token != ")": 507 | return 1, "n/a" 508 | 509 | program = "|".join(program) 510 | steps = program.split(")")[:-1] 511 | 512 | res_dict = {} 513 | 514 | for ind, step in enumerate(steps): 515 | step = step.strip() 516 | 517 | if len(step.split("(")) > 2: 518 | invalid_flag = 1 519 | break 520 | op = step.split("(")[0].strip("|").strip() 521 | args = step.split("(")[1].strip("|").strip() 522 | 523 | arg1 = args.split("|")[0].strip() 524 | arg2 = args.split("|")[1].strip() 525 | 526 | if "#" in arg1: 527 | arg1 = res_dict[int(arg1.replace("#", ""))] 528 | else: 529 | arg1 = str_to_num(arg1) 530 | if arg1 == "n/a": 531 | invalid_flag = 1 532 | break 533 | 534 | if "#" in arg2: 535 | arg2 = res_dict[int(arg2.replace("#", ""))] 536 | else: 537 | arg2 = str_to_num(arg2) 538 | if arg2 == "n/a": 539 | invalid_flag = 1 540 | break 541 | 542 | if op == "add": 543 | this_res = arg1 + arg2 544 | elif op == "subtract": 545 | this_res = arg1 - arg2 546 | elif op == "multiply": 547 | this_res = arg1 * arg2 548 | elif op == "divide": 549 | this_res = arg1 / arg2 550 | elif op == "exp": 551 | this_res = arg1 ** arg2 552 | 553 | res_dict[ind] = this_res 554 | 555 | if this_res != "n/a": 556 | this_res = round(this_res, 5) 557 | 558 | except: 559 | invalid_flag = 1 560 | 561 | return invalid_flag, this_res 562 | 563 | 564 | def evaluate_result(all_nbest, json_ori, program_mode): 565 | ''' 566 | execution acc 567 | program acc 568 | ''' 569 | 570 | data = all_nbest 571 | 572 | with open(json_ori) as f_in: 573 | data_ori = json.load(f_in) 574 | 575 | data_dict = {} 576 | for each_data in data_ori: 577 | assert each_data["uid"] not in data_dict 578 | data_dict[each_data["uid"]] = each_data 579 | 580 | exe_correct = 0 581 | 582 | res_list = [] 583 | all_res_list = [] 584 | 585 | for tmp in data: 586 | each_data = data[tmp][0] 587 | each_id = each_data["id"] 588 | 589 | each_ori_data = data_dict[each_id] 590 | gold_res = each_ori_data["qa"]["answer"] 591 | 592 | pred = each_data["pred_prog"] 593 | gold = each_data["ref_prog"] 594 | 595 | if program_mode == "nest": 596 | if pred[-1] == "EOF": 597 | pred = pred[:-1] 598 | pred = reprog_to_seq(pred, is_gold=False) 599 | pred += ["EOF"] 600 | gold = gold[:-1] 601 | gold = reprog_to_seq(gold, is_gold=True) 602 | gold += ["EOF"] 603 | 604 | invalid_flag, exe_res = eval_program(pred) 605 | 606 | if invalid_flag == 0: 607 | if exe_res == gold_res: 608 | exe_correct += 1 609 | 610 | each_ori_data["qa"]["predicted"] = pred 611 | 612 | if exe_res != gold_res: 613 | res_list.append(each_ori_data) 614 | all_res_list.append(each_ori_data) 615 | 616 | exe_acc = float(exe_correct) / len(data) 617 | 618 | 
print("All: ", len(data)) 619 | print("Exe acc: ", exe_acc) 620 | 621 | return exe_acc 622 | -------------------------------------------------------------------------------- /utils/retriever_utils.py: -------------------------------------------------------------------------------- 1 | """MathQA utils. 2 | """ 3 | import argparse 4 | import collections 5 | import json 6 | import numpy as np 7 | import os 8 | import re 9 | import string 10 | import sys 11 | import random 12 | import enum 13 | import six 14 | import copy 15 | from six.moves import map 16 | from six.moves import range 17 | from six.moves import zip 18 | from tqdm import tqdm 19 | from utils.utils import * 20 | 21 | _SPECIAL_TOKENS_RE = re.compile(r"^\[[^ ]*\]$", re.UNICODE) 22 | 23 | class MathQAExample( 24 | collections.namedtuple( 25 | "MathQAExample", 26 | "filename_id question paragraphs table_descriptions \ 27 | pos_sent_ids pos_table_ids" 28 | )): 29 | def convert_single_example(self, *args, **kwargs): 30 | return convert_single_mathqa_example(self, *args, **kwargs) 31 | 32 | 33 | class InputFeatures(object): 34 | """A single set of features of data.""" 35 | 36 | def __init__(self, 37 | filename_id, 38 | retrieve_ind, 39 | tokens, 40 | input_ids, 41 | segment_ids, 42 | input_mask, 43 | label): 44 | 45 | self.filename_id = filename_id 46 | self.retrieve_ind = retrieve_ind 47 | self.tokens = tokens 48 | self.input_ids = input_ids 49 | self.input_mask = input_mask 50 | self.segment_ids = segment_ids 51 | self.label = label 52 | 53 | 54 | def tokenize(tokenizer, text, apply_basic_tokenization=False): 55 | """Tokenizes text, optionally looking up special tokens separately. 56 | 57 | Args: 58 | tokenizer: a tokenizer from bert.tokenization.FullTokenizer 59 | text: text to tokenize 60 | apply_basic_tokenization: If True, apply the basic tokenization. If False, 61 | apply the full tokenization (basic + wordpiece). 62 | 63 | Returns: 64 | tokenized text. 65 | 66 | A special token is any text with no spaces enclosed in square brackets with no 67 | space, so we separate those out and look them up in the dictionary before 68 | doing actual tokenization. 
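To make the program-evaluation path in `utils/program_generation_utils.py` above concrete: `program_tokenization` flattens the comma-separated program string into operator/operand tokens plus a trailing `EOF`, and `eval_program` executes the steps in order, with `#k` referring to the result of step k. A small usage sketch (the program string is invented, and the repo root is assumed to be on `PYTHONPATH` as in the README commands):

```python
from utils.program_generation_utils import program_tokenization, eval_program

# hypothetical two-step program: add two numbers, then divide the intermediate result (#0) by 2
prog = program_tokenization("add(100, 10), divide(#0, 2)")
# ['add(', '100', '10', ')', 'divide(', '#0', '2', ')', 'EOF']

invalid_flag, result = eval_program(prog)
# invalid_flag == 0, result == 55.0  (results are rounded to 5 decimal places)
```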
69 | """ 70 | 71 | _SPECIAL_TOKENS_RE = re.compile(r"^<[^ ]*>$", re.UNICODE) 72 | 73 | tokenize_fn = tokenizer.tokenize 74 | if apply_basic_tokenization: 75 | tokenize_fn = tokenizer.basic_tokenizer.tokenize 76 | 77 | tokens = [] 78 | for token in text.split(" "): 79 | if _SPECIAL_TOKENS_RE.match(token): 80 | if token in tokenizer.get_vocab(): 81 | tokens.append(token) 82 | else: 83 | tokens.append(tokenizer.unk_token) 84 | else: 85 | tokens.extend(tokenize_fn(token)) 86 | 87 | return tokens 88 | 89 | def remove_space(text_in): 90 | res = [] 91 | 92 | for tmp in text_in.split(" "): 93 | if tmp != "": 94 | res.append(tmp) 95 | 96 | return " ".join(res) 97 | 98 | 99 | def wrap_single_pair(tokenizer, question, context, label, max_seq_length, 100 | cls_token, sep_token): 101 | ''' 102 | single pair of question, context, label feature 103 | ''' 104 | 105 | question_tokens = tokenize(tokenizer, question) 106 | this_gold_tokens = tokenize(tokenizer, context) 107 | 108 | tokens = [cls_token] + question_tokens + [sep_token] 109 | segment_ids = [0] * len(tokens) 110 | 111 | tokens += this_gold_tokens 112 | segment_ids.extend([0] * len(this_gold_tokens)) 113 | 114 | if len(tokens) > max_seq_length: 115 | tokens = tokens[:max_seq_length-1] 116 | tokens += [sep_token] 117 | segment_ids = segment_ids[:max_seq_length] 118 | 119 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 120 | input_mask = [1] * len(input_ids) 121 | 122 | padding = [0] * (max_seq_length - len(input_ids)) 123 | input_ids.extend(padding) 124 | input_mask.extend(padding) 125 | segment_ids.extend(padding) 126 | 127 | assert len(input_ids) == max_seq_length 128 | assert len(input_mask) == max_seq_length 129 | assert len(segment_ids) == max_seq_length 130 | 131 | this_input_feature = { 132 | "context": context, 133 | "tokens": tokens, 134 | "input_ids": input_ids, 135 | "input_mask": input_mask, 136 | "segment_ids": segment_ids, 137 | "label": label 138 | } 139 | 140 | return this_input_feature 141 | 142 | 143 | def convert_single_mathqa_example(example, option, is_training, tokenizer, max_seq_length, 144 | cls_token, sep_token): 145 | """Converts a single MathQAExample into Multiple Retriever Features.""" 146 | """ option: tf idf or all""" 147 | """train: 1:3 pos neg. 
Test: all""" 148 | 149 | pos_features, neg_sent_features, irrelevant_neg_table_features, relevant_neg_table_features = [], [], [], [] 150 | 151 | question = example.question 152 | 153 | # positive examples 154 | # tables = example.tables 155 | paragraphs = example.paragraphs 156 | pos_text_ids = example.pos_sent_ids 157 | pos_table_ids = example.pos_table_ids 158 | table_descriptions = example.table_descriptions 159 | 160 | relevant_table_ids = set([i.split("-")[0] for i in pos_table_ids]) 161 | 162 | for sent_idx, sent in enumerate(paragraphs): 163 | if sent_idx in pos_text_ids: 164 | this_input_feature = wrap_single_pair( 165 | tokenizer, example.question, sent, 1, max_seq_length, 166 | cls_token, sep_token) 167 | else: 168 | this_input_feature = wrap_single_pair( 169 | tokenizer, example.question, sent, 0, max_seq_length, 170 | cls_token, sep_token) 171 | this_input_feature["ind"] = sent_idx 172 | this_input_feature["filename_id"] = example.filename_id 173 | 174 | if sent_idx in pos_text_ids: 175 | pos_features.append(this_input_feature) 176 | else: 177 | neg_sent_features.append(this_input_feature) 178 | 179 | for cell_idx in table_descriptions: 180 | this_gold_sent = table_descriptions[cell_idx] 181 | if cell_idx in pos_table_ids: 182 | this_input_feature = wrap_single_pair( 183 | tokenizer, question, this_gold_sent, 1, max_seq_length, 184 | cls_token, sep_token) 185 | this_input_feature["ind"] = cell_idx 186 | this_input_feature["filename_id"] = example.filename_id 187 | pos_features.append(this_input_feature) 188 | else: 189 | ti = cell_idx.split("-")[0] 190 | this_input_feature = wrap_single_pair( 191 | tokenizer, question, this_gold_sent, 0, max_seq_length, 192 | cls_token, sep_token) 193 | this_input_feature["ind"] = cell_idx 194 | this_input_feature["filename_id"] = example.filename_id 195 | if ti in relevant_table_ids: 196 | relevant_neg_table_features.append(this_input_feature) 197 | else: 198 | irrelevant_neg_table_features.append(this_input_feature) 199 | 200 | return pos_features, neg_sent_features, irrelevant_neg_table_features, relevant_neg_table_features 201 | 202 | def read_examples(input_path, tokenizer, op_list, const_list, log_file): 203 | """Read a json file into a list of examples.""" 204 | 205 | write_log(log_file, "Reading " + input_path) 206 | with open(input_path) as input_file: 207 | input_data = json.load(input_file) 208 | 209 | examples = [] 210 | for entry in input_data: 211 | examples.append(read_mathqa_entry(entry, tokenizer)) 212 | 213 | return input_data, examples, op_list, const_list 214 | 215 | def read_mathqa_entry(entry, tokenizer): 216 | 217 | question = entry["qa"]["question"] 218 | 219 | paragraphs = entry["paragraphs"] 220 | # tables = entry["tables"] 221 | 222 | if 'text_evidence' in entry["qa"]: 223 | pos_sent_ids = entry["qa"]['text_evidence'] 224 | pos_table_ids = entry["qa"]['table_evidence'] 225 | else: # test set 226 | pos_sent_ids = [] 227 | pos_table_ids = [] 228 | 229 | 230 | table_descriptions = entry["table_description"] 231 | filename_id = entry["uid"] 232 | 233 | return MathQAExample( 234 | filename_id=filename_id, 235 | question=question, 236 | paragraphs=paragraphs, 237 | # tables=tables, 238 | table_descriptions=table_descriptions, 239 | pos_sent_ids=pos_sent_ids, 240 | pos_table_ids=pos_table_ids, 241 | ) 242 | 243 | 244 | def convert_examples_to_features(examples, 245 | tokenizer, 246 | max_seq_length, 247 | option, 248 | is_training, 249 | ): 250 | """Converts a list of DropExamples into InputFeatures.""" 251 | res, 
res_neg_sent, res_irrelevant_neg_table, res_relevant_neg_table = [], [], [], [] 252 | for (example_index, example) in tqdm(enumerate(examples)): 253 | pos_features, neg_sent_features, irrelevant_neg_table_features, relevant_neg_table_features = example.convert_single_example( 254 | tokenizer=tokenizer, 255 | max_seq_length=max_seq_length, 256 | option=option, 257 | is_training=is_training, 258 | cls_token=tokenizer.cls_token, 259 | sep_token=tokenizer.sep_token) 260 | 261 | res.extend(pos_features) 262 | res_neg_sent.extend(neg_sent_features) 263 | res_irrelevant_neg_table.extend(irrelevant_neg_table_features) 264 | res_relevant_neg_table.extend(relevant_neg_table_features) 265 | 266 | return res, res_neg_sent, res_irrelevant_neg_table, res_relevant_neg_table 267 | 268 | 269 | 270 | def retrieve_evaluate(all_logits, all_filename_ids, all_inds, ori_file, topn): 271 | ''' 272 | save results to file. calculate recall 273 | ''' 274 | 275 | res_filename = {} 276 | res_filename_inds = {} 277 | 278 | print(len(all_logits)) 279 | for this_logit, this_filename_id, this_ind in zip(all_logits, all_filename_ids, all_inds): 280 | 281 | if this_filename_id not in res_filename: 282 | res_filename[this_filename_id] = [] 283 | res_filename_inds[this_filename_id] = [] 284 | 285 | if this_ind not in res_filename_inds[this_filename_id]: 286 | res_filename[this_filename_id].append({ 287 | "score": this_logit[1].item(), 288 | "ind": this_ind 289 | }) 290 | res_filename_inds[this_filename_id].append(this_ind) 291 | 292 | 293 | 294 | with open(ori_file) as f: 295 | data_all = json.load(f) 296 | 297 | # take top ten 298 | all_recall = 0.0 299 | # all_recall_3 = 0.0 300 | 301 | count_data = 0 302 | for data in data_all: 303 | this_filename_id = data["uid"] 304 | 305 | if this_filename_id not in res_filename: 306 | continue 307 | count_data += 1 308 | this_res = res_filename[this_filename_id] 309 | 310 | sorted_dict = sorted(this_res, key=lambda kv: kv["score"], reverse=True) 311 | 312 | # sorted_dict = sorted_dict[:topn] 313 | 314 | gold_sent_inds = data["qa"]["text_evidence"] 315 | gold_table_inds = data["qa"]["table_evidence"] 316 | 317 | # table rows 318 | table_retrieved = [] 319 | text_retrieved = [] 320 | 321 | # all retrieved 322 | table_re_all = [] 323 | text_re_all = [] 324 | 325 | correct = 0 326 | # correct_3 = 0 327 | 328 | for tmp in sorted_dict[:topn]: 329 | if type(tmp["ind"]) == str: 330 | table_retrieved.append(tmp) 331 | if tmp["ind"] in gold_table_inds: 332 | correct += 1 333 | else: 334 | text_retrieved.append(tmp) 335 | if tmp["ind"] in gold_sent_inds: 336 | correct += 1 337 | 338 | # if tmp["ind"] in gold_inds: 339 | # correct += 1 340 | # # print(sorted_dict) 341 | for tmp in sorted_dict: 342 | if type(tmp["ind"]) == str: 343 | table_re_all.append(tmp) 344 | else: 345 | text_re_all.append(tmp) 346 | 347 | # for tmp in sorted_dict[:3]: 348 | # if tmp["ind"] in gold_inds: 349 | # correct_3 += 1 350 | 351 | all_recall += (float(correct) / (len(gold_table_inds) + len(gold_sent_inds))) 352 | # all_recall_3 += (float(correct_3) / len(gold_inds)) 353 | 354 | data["table_retrieved_all"] = table_re_all 355 | data["text_retrieved_all"] = text_re_all 356 | 357 | # res_3 = all_recall_3 / len(data_all) 358 | res = all_recall / len(data_all) 359 | 360 | # res_message = "Top 3: " + str(res_3) + "\n" + "Top 5: " + str(res) + "\n" 361 | res_message = f"Top {topn}: {res}\n" 362 | 363 | return res, res_message 364 | 365 | 366 | def retrieve_inference(all_logits, all_filename_ids, all_inds, 
output_prediction_file, ori_file): 367 | ''' 368 | save results to file. calculate recall 369 | ''' 370 | 371 | res_filename = {} 372 | res_filename_inds = {} 373 | 374 | for this_logit, this_filename_id, this_ind in zip(all_logits, all_filename_ids, all_inds): 375 | if this_filename_id not in res_filename: 376 | res_filename[this_filename_id] = [] 377 | res_filename_inds[this_filename_id] = [] 378 | 379 | if this_ind not in res_filename_inds[this_filename_id]: 380 | res_filename[this_filename_id].append({ 381 | "score": this_logit[1].item(), 382 | "ind": this_ind 383 | }) 384 | res_filename_inds[this_filename_id].append(this_ind) 385 | 386 | 387 | 388 | with open(ori_file) as f: 389 | data_all = json.load(f) 390 | 391 | 392 | output_data = [] 393 | for data in data_all: 394 | table_re_all = [] 395 | text_re_all = [] 396 | this_filename_id = data["uid"] 397 | 398 | if this_filename_id not in res_filename: 399 | continue 400 | 401 | this_res = res_filename[this_filename_id] 402 | sorted_dict = sorted(this_res, key=lambda kv: kv["score"], reverse=True) 403 | 404 | for tmp in sorted_dict: 405 | if type(tmp["ind"]) == str: 406 | table_re_all.append(tmp) 407 | else: 408 | text_re_all.append(tmp) 409 | 410 | data["table_retrieved_all"] = table_re_all 411 | data["text_retrieved_all"] = text_re_all 412 | output_data.append(data) 413 | 414 | with open(output_prediction_file, "w") as f: 415 | json.dump(output_data, f, indent=4) 416 | 417 | return None -------------------------------------------------------------------------------- /utils/span_selection_utils.py: -------------------------------------------------------------------------------- 1 | """MathQA utils. 2 | """ 3 | import argparse 4 | import collections 5 | import json 6 | import numpy as np 7 | import os 8 | import re 9 | import string 10 | import sys 11 | import random 12 | import enum 13 | import six 14 | import copy 15 | from six.moves import map 16 | from six.moves import range 17 | from six.moves import zip 18 | import math 19 | import tqdm 20 | from sympy import simplify 21 | from utils.utils import * 22 | from collections import defaultdict 23 | from typing import Any, Dict, List, Set, Tuple, Union, Optional 24 | import json 25 | import argparse 26 | import string 27 | import re 28 | from scipy.optimize import linear_sum_assignment 29 | 30 | all_ops = ["add", "subtract", "multiply", "divide", "exp"] 31 | 32 | sys.path.insert(0, '../utils/') 33 | max_seq_length = 512 34 | max_program_length = 30 35 | 36 | class MathQAExample( 37 | collections.namedtuple( 38 | "MathQAExample", 39 | "id original_question question_tokens answer" 40 | )): 41 | 42 | def convert_single_example(self, *args, **kwargs): 43 | return convert_single_mathqa_example(self, *args, **kwargs) 44 | 45 | 46 | def tokenize(tokenizer, text, apply_basic_tokenization=False): 47 | """Tokenizes text, optionally looking up special tokens separately. 48 | 49 | Args: 50 | tokenizer: a tokenizer from bert.tokenization.FullTokenizer 51 | text: text to tokenize 52 | apply_basic_tokenization: If True, apply the basic tokenization. If False, 53 | apply the full tokenization (basic + wordpiece). 54 | 55 | Returns: 56 | tokenized text. 57 | 58 | A special token is any text with no spaces enclosed in square brackets with no 59 | space, so we separate those out and look them up in the dictionary before 60 | doing actual tokenization. 
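For the retriever features built in `utils/retriever_utils.py` above, `wrap_single_pair` encodes one (question, candidate fact) pair into fixed-length inputs with a binary relevance label; `convert_single_mathqa_example` then calls it once per sentence and per table-cell description. A hedged usage sketch (the question and context strings are invented; assumes `transformers` is installed and the repo root is on `PYTHONPATH`):

```python
from transformers import AutoTokenizer
from utils.retriever_utils import wrap_single_pair

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

feature = wrap_single_pair(
    tokenizer,
    question="What was the total revenue in 2019?",       # hypothetical question
    context="Total revenue was $1,200 million in 2019.",  # hypothetical supporting fact
    label=1,                                              # 1 = gold evidence, 0 = negative
    max_seq_length=512,
    cls_token=tokenizer.cls_token,
    sep_token=tokenizer.sep_token,
)

assert len(feature["input_ids"]) == 512
assert len(feature["input_mask"]) == 512 and feature["label"] == 1
```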
61 | """ 62 | 63 | _SPECIAL_TOKENS_RE = re.compile(r"^<[^ ]*>$", re.UNICODE) 64 | 65 | tokenize_fn = tokenizer.tokenize 66 | if apply_basic_tokenization: 67 | tokenize_fn = tokenizer.basic_tokenizer.tokenize 68 | 69 | tokens = [] 70 | for token in text.split(" "): 71 | if _SPECIAL_TOKENS_RE.match(token): 72 | if token in tokenizer.get_vocab(): 73 | tokens.append(token) 74 | else: 75 | tokens.append(tokenizer.unk_token) 76 | else: 77 | tokens.extend(tokenize_fn(token)) 78 | 79 | return tokens 80 | 81 | 82 | 83 | def convert_single_mathqa_example(example, tokenizer, max_seq_length): 84 | """Converts a single MathQAExample into an InputFeature.""" 85 | # input_ids = tokenizer.convert_tokens_to_ids(tokens) 86 | input_text_encoded = tokenizer.encode_plus(example.original_question, 87 | max_length=max_seq_length, 88 | pad_to_max_length=True) 89 | input_ids = input_text_encoded["input_ids"] 90 | input_mask = input_text_encoded["attention_mask"] 91 | 92 | label_encoded = tokenizer.encode_plus(str(example.answer), 93 | max_length=16, 94 | pad_to_max_length=True) 95 | label_ids = label_encoded["input_ids"] 96 | 97 | this_input_feature = { 98 | "uid": example.id, 99 | "tokens": example.question_tokens, 100 | "question": example.original_question, 101 | "input_ids": input_ids, 102 | "input_mask": input_mask, 103 | "label_ids": label_ids, 104 | "label": str(example.answer) 105 | } 106 | 107 | return this_input_feature 108 | 109 | 110 | def read_mathqa_entry(entry, tokenizer, entity_name): 111 | if entry["qa"][entity_name] != "span_selection": 112 | return None 113 | 114 | 115 | context = "" 116 | for idx in entry["model_input"]: 117 | if type(idx) == int: 118 | context += entry["paragraphs"][idx][:-1] 119 | context += " " 120 | 121 | else: 122 | context += entry["table_description"][idx][:-1] 123 | context += " " 124 | 125 | question = entry["qa"]["question"] 126 | this_id = entry["uid"] 127 | 128 | original_question = f"Question: {question} Context: {context.strip()}" 129 | if "answer" in entry["qa"]: 130 | answer = entry["qa"]["answer"] 131 | else: 132 | answer = "" 133 | if type(answer) != str: 134 | answer = str(int(answer)) 135 | 136 | original_question_tokens = original_question.split(' ') 137 | 138 | 139 | return MathQAExample( 140 | id=this_id, 141 | original_question=original_question, 142 | question_tokens=original_question_tokens, 143 | answer=answer) 144 | 145 | 146 | def read_examples(input_path, tokenizer): 147 | """Read a json file into a list of examples.""" 148 | with open(input_path) as input_file: 149 | input_data = json.load(input_file) 150 | 151 | examples = [] 152 | for entry in tqdm(input_data): 153 | examples.append(read_mathqa_entry(entry, tokenizer)) 154 | return input_data, examples 155 | 156 | def convert_examples_to_features(examples, 157 | tokenizer, 158 | max_seq_length, 159 | verbose=True): 160 | """Converts a list of DropExamples into InputFeatures.""" 161 | res = [] 162 | for (example_index, example) in enumerate(examples): 163 | feature = example.convert_single_example( 164 | tokenizer=tokenizer, 165 | max_seq_length=max_seq_length 166 | ) 167 | res.append(feature) 168 | return res 169 | 170 | 171 | def write_predictions(all_predictions, output_prediction_file): 172 | """Writes final predictions in json format.""" 173 | 174 | with open(output_prediction_file, "w") as writer: 175 | writer.write(json.dumps(all_predictions, indent=4) + "\n") 176 | 177 | 178 | 179 | # From here through get_metric was originally copied from: 180 | # 
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py 181 | def _remove_articles(text: str) -> str: 182 | regex = re.compile(r"\b(a|an|the)\b", re.UNICODE) 183 | return re.sub(regex, " ", text) 184 | 185 | 186 | def _white_space_fix(text: str) -> str: 187 | return " ".join(text.split()) 188 | 189 | 190 | EXCLUDE = set(string.punctuation) 191 | 192 | 193 | def _remove_punc(text: str) -> str: 194 | if not _is_number(text): 195 | return "".join(ch for ch in text if ch not in EXCLUDE) 196 | else: 197 | return text 198 | 199 | 200 | def _lower(text: str) -> str: 201 | return text.lower() 202 | 203 | 204 | def _tokenize(text: str) -> List[str]: 205 | return re.split(" |-", text) 206 | 207 | 208 | def _normalize_answer(text: str) -> str: 209 | """Lower text and remove punctuation, articles and extra whitespace.""" 210 | 211 | parts = [ 212 | _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token))))) 213 | for token in _tokenize(text) 214 | ] 215 | parts = [part for part in parts if part.strip()] 216 | normalized = " ".join(parts).strip() 217 | return normalized 218 | 219 | 220 | def _is_number(text: str) -> bool: 221 | try: 222 | float(text) 223 | return True 224 | except ValueError: 225 | return False 226 | 227 | 228 | def _normalize_number(text: str) -> str: 229 | if _is_number(text): 230 | return str(float(text)) 231 | else: 232 | return text 233 | 234 | 235 | def _answer_to_bags( 236 | answer: Union[str, List[str], Tuple[str, ...]] 237 | ) -> Tuple[List[str], List[Set[str]]]: 238 | if isinstance(answer, (list, tuple)): 239 | raw_spans = answer 240 | else: 241 | raw_spans = [answer] 242 | normalized_spans: List[str] = [] 243 | token_bags = [] 244 | for raw_span in raw_spans: 245 | normalized_span = _normalize_answer(raw_span) 246 | normalized_spans.append(normalized_span) 247 | token_bags.append(set(normalized_span.split())) 248 | return normalized_spans, token_bags 249 | 250 | 251 | def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]: 252 | """ 253 | Takes gold and predicted answer sets and first finds the optimal 1-1 alignment 254 | between them and gets maximum metric values over all the answers. 
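The DROP-style normalization above lower-cases each token, drops articles and punctuation, canonicalizes numbers, and treats hyphens as token separators before the answer bags are compared. For example (input string invented):

```python
from utils.span_selection_utils import _normalize_answer

_normalize_answer("The $1,234 Net-Income!")
# -> '1234.0 net income'
```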
255 | """ 256 | scores = np.zeros([len(gold), len(predicted)]) 257 | for gold_index, gold_item in enumerate(gold): 258 | for pred_index, pred_item in enumerate(predicted): 259 | if _match_numbers_if_present(gold_item, pred_item): 260 | scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item) 261 | row_ind, col_ind = linear_sum_assignment(-scores) 262 | 263 | max_scores = np.zeros([max(len(gold), len(predicted))]) 264 | for row, column in zip(row_ind, col_ind): 265 | max_scores[row] = max(max_scores[row], scores[row, column]) 266 | return max_scores 267 | 268 | 269 | def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float: 270 | intersection = len(gold_bag.intersection(predicted_bag)) 271 | if not predicted_bag: 272 | precision = 1.0 273 | else: 274 | precision = intersection / float(len(predicted_bag)) 275 | if not gold_bag: 276 | recall = 1.0 277 | else: 278 | recall = intersection / float(len(gold_bag)) 279 | f1 = ( 280 | (2 * precision * recall) / (precision + recall) 281 | if not (precision == 0.0 and recall == 0.0) 282 | else 0.0 283 | ) 284 | return f1 285 | 286 | 287 | def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool: 288 | gold_numbers = set() 289 | predicted_numbers = set() 290 | for word in gold_bag: 291 | if _is_number(word): 292 | gold_numbers.add(word) 293 | for word in predicted_bag: 294 | if _is_number(word): 295 | predicted_numbers.add(word) 296 | if (not gold_numbers) or gold_numbers.intersection(predicted_numbers): 297 | return True 298 | return False 299 | 300 | 301 | def get_span_selection_metrics( 302 | predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]] 303 | ) -> Tuple[float, float]: 304 | """ 305 | Takes a predicted answer and a gold answer (that are both either a string or a list of 306 | strings), and returns exact match and the DROP F1 metric for the prediction. If you are 307 | writing a script for evaluating objects in memory (say, the output of predictions during 308 | validation, or while training), this is the function you want to call, after using 309 | :func:`answer_json_to_strings` when reading the gold answer from the released data file. 
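`get_span_selection_metrics` then returns (exact match, token-level F1) for a predicted/gold answer pair. A short example with invented answer strings:

```python
from utils.span_selection_utils import get_span_selection_metrics

get_span_selection_metrics("the net income", "net income")
# -> (1.0, 1.0): articles and case are normalized away before comparison

get_span_selection_metrics("net income", "net income before taxes")
# -> (0.0, 0.67): no exact match; F1 = 2 * (2/2) * (2/4) / ((2/2) + (2/4))
```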
310 | """ 311 | predicted_bags = _answer_to_bags(predicted) 312 | gold_bags = _answer_to_bags(gold) 313 | 314 | if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]): 315 | exact_match = 1.0 316 | else: 317 | exact_match = 0.0 318 | 319 | f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) 320 | f1 = np.mean(f1_per_bag) 321 | f1 = round(f1, 2) 322 | return exact_match, f1 323 | 324 | def span_selection_evaluate(all_preds, all_filename_id, test_file): 325 | ''' 326 | Exact Match 327 | F1 328 | ''' 329 | results = [] 330 | exact_match, f1 = 0, 0 331 | with open(test_file) as f_in: 332 | data_ori = json.load(f_in) 333 | 334 | data_dict = {} 335 | for each_data in data_ori: 336 | assert each_data["uid"] not in data_dict 337 | data_dict[each_data["uid"]] = each_data["qa"]["answer"] 338 | 339 | for pred, uid in zip(all_preds, all_filename_id): 340 | gold = data_dict[uid] 341 | if type(gold) != str: 342 | gold = str(int(gold)) 343 | 344 | cur_exact_match, cur_f1 = get_span_selection_metrics(pred, gold) 345 | 346 | result = {"uid": uid, "answer": gold, "predicted_answer": pred, "exact_match": exact_match, "f1": f1} 347 | results.append(result) 348 | 349 | exact_match += cur_exact_match 350 | f1 += cur_f1 351 | 352 | exact_match = exact_match / len(all_preds) 353 | f1 = f1 / len(all_preds) 354 | return exact_match, f1 355 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | """MathQA utils. 2 | """ 3 | import json 4 | from tqdm import tqdm 5 | 6 | def str_to_num(text): 7 | text = text.replace("$","") 8 | text = text.replace(",", "") 9 | text = text.replace("-", "") 10 | text = text.replace("%", "") 11 | try: 12 | num = float(text) 13 | except ValueError: 14 | if "const_" in text: 15 | text = text.replace("const_", "") 16 | if text == "m1": 17 | text = "-1" 18 | num = float(text) 19 | else: 20 | num = "n/a" 21 | return num 22 | 23 | 24 | def prog_token_to_indices(prog, numbers, number_indices, max_seq_length, 25 | op_list, op_list_size, const_list, 26 | const_list_size): 27 | prog_indices = [] 28 | for i, token in enumerate(prog): 29 | if token in op_list: 30 | prog_indices.append(op_list.index(token)) 31 | elif token in const_list: 32 | prog_indices.append(op_list_size + const_list.index(token)) 33 | else: 34 | if token in numbers: 35 | cur_num_idx = numbers.index(token) 36 | else: 37 | cur_num_idx = -1 38 | for num_idx, num in enumerate(numbers): 39 | if str_to_num(num) == str_to_num(token) or (str_to_num(num) != "n/a" and str_to_num(num) / 100 == str_to_num(token)): 40 | cur_num_idx = num_idx 41 | break 42 | 43 | if cur_num_idx == -1: 44 | return None 45 | prog_indices.append(op_list_size + const_list_size + 46 | number_indices[cur_num_idx]) 47 | return prog_indices 48 | 49 | 50 | def indices_to_prog(program_indices, numbers, number_indices, max_seq_length, 51 | op_list, op_list_size, const_list, const_list_size): 52 | prog = [] 53 | for i, prog_id in enumerate(program_indices): 54 | if prog_id < op_list_size: 55 | prog.append(op_list[prog_id]) 56 | elif prog_id < op_list_size + const_list_size: 57 | prog.append(const_list[prog_id - op_list_size]) 58 | else: 59 | prog.append(numbers[number_indices.index(prog_id - op_list_size 60 | - const_list_size)]) 61 | return prog 62 | 63 | 64 | def write_log(log_file, s): 65 | print(s) 66 | with open(log_file, 'a') as f: 67 | f.write(s+'\n') 68 | 69 | 70 | 71 | def 
read_txt(input_path): 72 | """Read a txt file into a list.""" 73 | with open(input_path) as input_file: 74 | input_data = input_file.readlines() 75 | items = [] 76 | for line in input_data: 77 | items.append(line.strip()) 78 | return items 79 | 80 | def get_op_const_list(): 81 | op_list_file = "../txt_files/operation_list.txt" 82 | const_list_file = "../txt_files/constant_list.txt" 83 | op_list = read_txt(op_list_file) 84 | op_list = [op + '(' for op in op_list] 85 | op_list = ['EOF', 'UNK', 'GO', ')'] + op_list 86 | const_list = read_txt(const_list_file) 87 | const_list = [const.lower().replace('.', '_') for const in const_list] 88 | return op_list, const_list 89 | 90 | 91 | def write_predictions(all_predictions, output_prediction_file): 92 | """Writes final predictions in json format.""" 93 | 94 | with open(output_prediction_file, "w") as writer: 95 | writer.write(json.dumps(all_predictions, indent=4) + "\n") --------------------------------------------------------------------------------
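Finally, `str_to_num` in `utils/utils.py` is the shared number parser used by both the program-generation and retriever utilities: it strips currency, comma, percent, and minus formatting, maps the `const_m1` constant to -1, and returns the string "n/a" when parsing fails (percent-vs-decimal matching is handled separately in `prog_token_to_indices`). A short usage sketch:

```python
from utils.utils import str_to_num

str_to_num("$1,234.5")   # -> 1234.5  ($ and , are stripped)
str_to_num("12.5%")      # -> 12.5    (% is dropped, not divided by 100)
str_to_num("-10")        # -> 10.0    (note: the minus sign is stripped as well)
str_to_num("const_m1")   # -> -1.0
str_to_num("three")      # -> 'n/a'   (non-numeric text falls through)
```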