├── .gitignore
├── README.md
├── codalab
│   ├── answer.txt
│   └── scoring_program
│       ├── README.md
│       ├── evaluate.py
│       └── metadata
├── config
│   ├── default.config
│   └── qa
│       ├── all_multi
│       │   ├── HAF.config
│       │   ├── bidaf.config
│       │   └── comatch.config
│       ├── all_single
│       │   ├── HAF.config
│       │   ├── bidaf.config
│       │   └── comatch.config
│       ├── typea_multi
│       │   ├── HAF.config
│       │   ├── bidaf.config
│       │   └── comatch.config
│       ├── typea_single
│       │   ├── HAF.config
│       │   ├── bert.config
│       │   ├── bidaf.config
│       │   └── comatch.config
│       ├── typeb_multi
│       │   ├── HAF.config
│       │   ├── bidaf.config
│       │   └── comatch.config
│       └── typeb_single
│           ├── HAF.config
│           ├── bidaf.config
│           └── comatch.config
├── config_parser
│   ├── __init__.py
│   └── parser.py
├── dataset
│   ├── __init__.py
│   └── nlp
│       ├── JsonFromFiles.py
│       └── __init__.py
├── formatter
│   ├── Basic.py
│   ├── __init__.py
│   └── qa
│       ├── Bert.py
│       ├── Char.py
│       ├── Comatch.py
│       ├── HAF.py
│       └── Word.py
├── model
│   ├── __init__.py
│   ├── encoder
│   │   ├── BertEncoder.py
│   │   ├── CNNEncoder.py
│   │   ├── GRUEncoder.py
│   │   ├── LSTMEncoder.py
│   │   └── __init__.py
│   ├── layer
│   │   ├── Attention.py
│   │   └── __init__.py
│   ├── loss.py
│   ├── optimizer.py
│   └── qa
│       ├── Bert.py
│       ├── BiDAF.py
│       ├── CoMatch.py
│       ├── HAF.py
│       └── util.py
├── reader
│   ├── __init__.py
│   └── reader.py
├── test.py
├── tools
│   ├── __init__.py
│   ├── accuracy_init.py
│   ├── accuracy_tool.py
│   ├── dataset_tool.py
│   ├── eval_tool.py
│   ├── init_tool.py
│   ├── output_init.py
│   ├── output_tool.py
│   ├── test_tool.py
│   └── train_tool.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 | *swp
3 | *swo
4 | .idea
5 | __pycache__
6 | *un~
7 |
8 | config/default_local.config
9 |
10 | temp
11 | notebook
12 | result
13 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # JEC-QA
2 |
3 | This is the repository of the paper ``JEC-QA: A Legal-Domain Question Answering Dataset``.
4 |
5 | ## Content
6 |
7 | - [Usage](#usage)
8 | - [Reference](#reference)
9 |
10 | ## Usage
11 |
12 | To learn more about this dataset, please visit the [homepage](http://jecqa.thunlp.org/).
13 |
14 | To learn how to use the framework, please visit [pytorch-worker](https://github.com/haoxizhong/pytorch-worker).
15 |
16 | You can find a sample submission and the evaluation program for CodaLab in the directory ``codalab``.
17 |
18 | ## Reference
19 |
20 | If you use the dataset, please cite the paper as follows:
21 |
22 | ```
23 | @inproceedings{zhong2019jec,
24 |     title={JEC-QA: A Legal-Domain Question Answering Dataset},
25 |     author={Zhong, Haoxi and Xiao, Chaojun and Tu, Cunchao and Zhang, Tianyang and Liu, Zhiyuan and Sun, Maosong},
26 |     booktitle={Proceedings of AAAI},
27 |     year={2020},
28 | }
29 | ```
--------------------------------------------------------------------------------
/codalab/answer.txt:
--------------------------------------------------------------------------------
1 | {
2 |     "1_10": [
3 |         "A"
4 |     ]
5 | }
6 |
--------------------------------------------------------------------------------
/codalab/scoring_program/README.md:
--------------------------------------------------------------------------------
1 | ### Building an evaluation program that works with CodaLab
2 |
3 | This example uses Python.
4 |
5 | `evaluate.py` - the scoring script; it checks the submitted answers against the truth data and writes the accuracy scores.
6 | `setup.py` - this is a file that enables py2exe to build a Windows executable of the evaluate.py script.
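For this competition, `evaluate.py` reads the submission from `$input/res/answer.txt` (a JSON object mapping question ids to lists of chosen options, as in `codalab/answer.txt` above) and the gold data from `$input/ref/truth.txt` (a JSON object whose entry for each question id holds the gold `answer` list and a `source` field), then writes `scores.txt` to `$output`. A local sanity check can therefore be run as, e.g., `python evaluate.py <input_dir> <output_dir>` (hypothetical directory names, mirroring the `command` in `metadata` below).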
7 | `metadata` - this is a file that lists the contents of the program.zip bundle for the CodaLab system.
8 |
9 | Once these pieces are assembled, they are packaged as program.zip, which CodaLab can then use to evaluate the submissions
10 | for a competition.
--------------------------------------------------------------------------------
/codalab/scoring_program/evaluate.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import sys
3 | import os
4 | import os.path
5 | import json
6 |
7 | input_dir = sys.argv[1]
8 | output_dir = sys.argv[2]
9 |
10 | submit_dir = os.path.join(input_dir, 'res')
11 | truth_dir = os.path.join(input_dir, 'ref')
12 |
13 | if not os.path.isdir(submit_dir):
14 |     print("%s doesn't exist" % submit_dir)
15 |
16 | if os.path.isdir(submit_dir) and os.path.isdir(truth_dir):
17 |     if not os.path.exists(output_dir):
18 |         os.makedirs(output_dir)
19 |
20 |     output_filename = os.path.join(output_dir, 'scores.txt')
21 |     output_file = open(output_filename, 'w')  # text mode; 'wb' would fail under Python 3, since strings are written below
22 |
23 |     truth_file = os.path.join(truth_dir, "truth.txt")
24 |     truth = json.load(open(truth_file, "r"))
25 |
26 |     submission_answer_file = os.path.join(submit_dir, "answer.txt")
27 |     submission_answer = json.load(open(submission_answer_file, "r"))
28 |
29 |     try:
30 |         ans = {}
31 |         for name in truth:
32 |             ans[name] = {"answer": truth[name]["answer"], "predict": [], "source": truth[name]["source"]}
33 |         for name in submission_answer:
34 |             if name in ans.keys():
35 |                 ans[name]["predict"] = submission_answer[name]
36 |
37 |         res = [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]  # res[source][subset] = [correct, total]; sources 0/1 are the two question types (cf. kds/cas below), 2 = overall; subset 0 = single-answer questions only, 1 = all questions
38 |         for name in ans:
39 |             t = ans[name]["source"]
40 |             if len(ans[name]["answer"]) == 1:
41 |                 s = 0
42 |             else:
43 |                 s = 1
44 |
45 |             c = 0
46 |             if set(ans[name]["answer"]) == set(ans[name]["predict"]):
47 |                 c = 1
48 |             for a in range(s, 2):
49 |                 res[t][a][0] += c
50 |                 res[t][a][1] += 1
51 |                 res[2][a][0] += c
52 |                 res[2][a][1] += 1
53 |
54 |         output_file.write("kds:%.3lf\n" % (100.0 * res[0][0][0] / res[0][0][1]))
55 |         output_file.write("kda:%.3lf\n" % (100.0 * res[0][1][0] / res[0][1][1]))
56 |         output_file.write("cas:%.3lf\n" % (100.0 * res[1][0][0] / res[1][0][1]))
57 |         output_file.write("caa:%.3lf\n" % (100.0 * res[1][1][0] / res[1][1][1]))
58 |         output_file.write("as:%.3lf\n" % (100.0 * res[2][0][0] / res[2][0][1]))
59 |         output_file.write("aa:%.3lf\n" % (100.0 * res[2][1][0] / res[2][1][1]))
60 |
61 |         output_file.close()
62 |
63 |         print("kds:%.3lf\n" % (100.0 * res[0][0][0] / res[0][0][1]))
64 |         print("kda:%.3lf\n" % (100.0 * res[0][1][0] / res[0][1][1]))
65 |         print("cas:%.3lf\n" % (100.0 * res[1][0][0] / res[1][0][1]))
66 |         print("caa:%.3lf\n" % (100.0 * res[1][1][0] / res[1][1][1]))
67 |         print("as:%.3lf\n" % (100.0 * res[2][0][0] / res[2][0][1]))
68 |         print("aa:%.3lf\n" % (100.0 * res[2][1][0] / res[2][1][1]))
69 |
70 |     except Exception as e:
71 |         print(e)
72 |         raise e
73 |
--------------------------------------------------------------------------------
/codalab/scoring_program/metadata:
--------------------------------------------------------------------------------
1 | command: python $program/evaluate.py $input $output
2 | description: JEC-QA evaluation program.
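# note: $program, $input and $output above are placeholders that CodaLab fills in with the scoring-program, input and output directories at run time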
3 | -------------------------------------------------------------------------------- /config/default.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | shuffle = True 3 | 4 | reader_num = 16 5 | 6 | optimizer = adam 7 | learning_rate = 1e-3 8 | weight_decay = 0 9 | step_size = 1 10 | lr_multiplier = 1 11 | 12 | [eval] #eval parameters 13 | shuffle = False 14 | 15 | reader_num = 16 16 | 17 | [data] #data parameters 18 | test_dataset_type = JsonFromFiles 19 | test_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 20 | test_file_list = 0_test.json,1_test.json 21 | 22 | reduce = False 23 | 24 | [model] #model parameters 25 | 26 | [output] #output parameters 27 | output_time = 1 28 | test_time = 1 29 | -------------------------------------------------------------------------------- /config/qa/all_multi/HAF.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 64 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 0.95 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = HAF 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json,1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = HAF 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json,1_test.json 31 | 32 | test_formatter_type = HAF 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | question_len = 64 38 | option_len = 48 39 | passage_len = 256 40 | 41 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 42 | 43 | [model] #model parameters 44 | model_name = HAF 45 | 46 | hidden_size = 256 47 | 48 | bi_direction = True 49 | num_layers = 2 50 | 51 | [output] #output parameters 52 | model_path = /data/disk3/private/zhx/jecqa/model/all/multi 53 | model_name = HAF 54 | 55 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 56 | 57 | output_function = Basic 58 | -------------------------------------------------------------------------------- /config/qa/all_multi/bidaf.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 160 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = WordQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json,1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = WordQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json,1_test.json 31 | 32 | test_formatter_type = WordQA 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | max_len1 = 128 38 | max_len2 = 384 39 | 40 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 41 | 42 | [model] #model parameters 43 | model_name = BiDAF 44 | 45 | hidden_size = 256 46 | 47 | bi_direction = True 48 | 
num_layers = 2 49 | 50 | [output] #output parameters 51 | model_path = /data/disk3/private/zhx/jecqa/model/all/multi 52 | model_name = BiDAF 53 | 54 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 55 | 56 | output_function = Basic 57 | -------------------------------------------------------------------------------- /config/qa/all_multi/comatch.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 128 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = ComatchQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json,1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = ComatchQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json,1_test.json 31 | 32 | test_formatter_type = ComatchQA 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 38 | sent_max_len = 50 39 | max_sent = 50 40 | 41 | 42 | [model] #model parameters 43 | model_name = Comatch 44 | 45 | hidden_size = 150 46 | 47 | dropout = 0.2 48 | 49 | [output] #output parameters 50 | model_path = /data/disk3/private/zhx/jecqa/model/all/multi 51 | model_name = Comatch 52 | 53 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 54 | 55 | output_function = Basic 56 | -------------------------------------------------------------------------------- /config/qa/all_single/HAF.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 64 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 0.95 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = HAF 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json,1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = HAF 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json,1_test.json 31 | 32 | test_formatter_type = HAF 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | question_len = 64 38 | option_len = 48 39 | passage_len = 256 40 | 41 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 42 | 43 | [model] #model parameters 44 | model_name = HAF 45 | 46 | hidden_size = 256 47 | 48 | bi_direction = True 49 | num_layers = 2 50 | 51 | [output] #output parameters 52 | model_path = /data/disk3/private/zhx/jecqa/model/all/single 53 | model_name = HAF 54 | 55 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 56 | 57 | output_function = Basic 58 | -------------------------------------------------------------------------------- /config/qa/all_single/bidaf.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 160 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | 
reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = WordQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json,1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = WordQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json,1_test.json 31 | 32 | test_formatter_type = WordQA 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | max_len1 = 128 38 | max_len2 = 384 39 | 40 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 41 | 42 | [model] #model parameters 43 | model_name = BiDAF 44 | 45 | hidden_size = 256 46 | 47 | bi_direction = True 48 | num_layers = 2 49 | 50 | [output] #output parameters 51 | model_path = /data/disk3/private/zhx/jecqa/model/all/single 52 | model_name = BiDAF 53 | 54 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 55 | 56 | output_function = Basic 57 | -------------------------------------------------------------------------------- /config/qa/all_single/comatch.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 128 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = ComatchQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json,1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = ComatchQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json,1_test.json 31 | 32 | test_formatter_type = ComatchQA 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 38 | sent_max_len = 50 39 | max_sent = 50 40 | 41 | 42 | [model] #model parameters 43 | model_name = Comatch 44 | 45 | hidden_size = 150 46 | 47 | dropout = 0.2 48 | 49 | [output] #output parameters 50 | model_path = /data/disk3/private/zhx/jecqa/model/all/single 51 | model_name = Comatch 52 | 53 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 54 | 55 | output_function = Basic 56 | -------------------------------------------------------------------------------- /config/qa/typea_multi/HAF.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 64 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 0.95 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = HAF 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = HAF 29 | valid_data_path = 
/data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = HAF 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | question_len = 64 38 | option_len = 48 39 | passage_len = 256 40 | 41 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 42 | 43 | [model] #model parameters 44 | model_name = HAF 45 | 46 | hidden_size = 256 47 | 48 | bi_direction = True 49 | num_layers = 2 50 | 51 | [output] #output parameters 52 | model_path = /data/disk3/private/zhx/jecqa/model/typea/multi 53 | model_name = HAF 54 | 55 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 56 | 57 | output_function = Basic 58 | -------------------------------------------------------------------------------- /config/qa/typea_multi/bidaf.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 160 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = WordQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = WordQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = WordQA 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | max_len1 = 128 38 | max_len2 = 384 39 | 40 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 41 | 42 | [model] #model parameters 43 | model_name = BiDAF 44 | 45 | hidden_size = 256 46 | 47 | bi_direction = True 48 | num_layers = 2 49 | 50 | [output] #output parameters 51 | model_path = /data/disk3/private/zhx/jecqa/model/typea/multi 52 | model_name = BiDAF 53 | 54 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 55 | 56 | output_function = Basic 57 | -------------------------------------------------------------------------------- /config/qa/typea_multi/comatch.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 128 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = ComatchQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = ComatchQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = ComatchQA 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 38 | sent_max_len = 50 39 | max_sent = 50 40 | 41 | 42 | [model] #model parameters 43 | model_name = Comatch 44 | 45 | hidden_size = 150 46 | 47 | dropout = 0.2 48 | 49 | [output] #output parameters 50 | model_path = /data/disk3/private/zhx/jecqa/model/typea/multi 51 | model_name = Comatch 52 | 53 | tensorboard_path = 
/data/disk3/private/zhx/jecqa/tensorboard 54 | 55 | output_function = Basic 56 | -------------------------------------------------------------------------------- /config/qa/typea_single/HAF.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 64 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 0.95 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = HAF 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = HAF 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = HAF 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | question_len = 64 38 | option_len = 48 39 | passage_len = 256 40 | 41 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 42 | 43 | [model] #model parameters 44 | model_name = HAF 45 | 46 | hidden_size = 256 47 | 48 | bi_direction = True 49 | num_layers = 2 50 | 51 | [output] #output parameters 52 | model_path = /data/disk3/private/zhx/jecqa/model/typea/single 53 | model_name = HAF 54 | 55 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 56 | 57 | output_function = Basic 58 | -------------------------------------------------------------------------------- /config/qa/typea_single/bert.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 16 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = bert_adam 10 | learning_rate = 1e-5 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = BertQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/origin/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = BertQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/origin/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = BertQA 33 | test_data_path = /data/disk3/private/zhx/jecqa/data/origin/v1 34 | 35 | topk = 16 36 | multi_choice = False 37 | 38 | max_len1 = 64 39 | max_len2 = 192 40 | 41 | [model] #model parameters 42 | model_name = Bert 43 | 44 | bert_path = /data/disk1/private/zhx/bert/chinese 45 | 46 | hidden_size = 768 47 | 48 | [output] #output parameters 49 | model_path = /data/disk3/private/zhx/jecqa/model/typea/single 50 | model_name = Bert 51 | 52 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 53 | 54 | output_function = Basic 55 | -------------------------------------------------------------------------------- /config/qa/typea_single/bidaf.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 160 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 
18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = WordQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = WordQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = WordQA 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | max_len1 = 128 38 | max_len2 = 384 39 | 40 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 41 | 42 | [model] #model parameters 43 | model_name = BiDAF 44 | 45 | hidden_size = 256 46 | 47 | bi_direction = True 48 | num_layers = 2 49 | 50 | [output] #output parameters 51 | model_path = /data/disk3/private/zhx/jecqa/model/typea/single 52 | model_name = BiDAF 53 | 54 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 55 | 56 | output_function = Basic 57 | -------------------------------------------------------------------------------- /config/qa/typea_single/comatch.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 128 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = ComatchQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 0_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = ComatchQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 0_test.json 31 | 32 | test_formatter_type = ComatchQA 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 38 | sent_max_len = 50 39 | max_sent = 50 40 | 41 | 42 | [model] #model parameters 43 | model_name = Comatch 44 | 45 | hidden_size = 150 46 | 47 | dropout = 0.2 48 | 49 | [output] #output parameters 50 | model_path = /data/disk3/private/zhx/jecqa/model/typea/single 51 | model_name = Comatch 52 | 53 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 54 | 55 | output_function = Basic 56 | -------------------------------------------------------------------------------- /config/qa/typeb_multi/HAF.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 64 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 0.95 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = HAF 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = HAF 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 1_test.json 31 | 32 | test_formatter_type = HAF 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | question_len = 64 38 | option_len = 48 39 | passage_len = 256 40 | 41 | word2id = 
/data/disk3/private/zhx/jecqa/data/word2id.txt 42 | 43 | [model] #model parameters 44 | model_name = HAF 45 | 46 | hidden_size = 256 47 | 48 | bi_direction = True 49 | num_layers = 2 50 | 51 | [output] #output parameters 52 | model_path = /data/disk3/private/zhx/jecqa/model/typeb/multi 53 | model_name = HAF 54 | 55 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 56 | 57 | output_function = Basic 58 | -------------------------------------------------------------------------------- /config/qa/typeb_multi/bidaf.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 160 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = WordQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = WordQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 1_test.json 31 | 32 | test_formatter_type = WordQA 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | max_len1 = 128 38 | max_len2 = 384 39 | 40 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 41 | 42 | [model] #model parameters 43 | model_name = BiDAF 44 | 45 | hidden_size = 256 46 | 47 | bi_direction = True 48 | num_layers = 2 49 | 50 | [output] #output parameters 51 | model_path = /data/disk3/private/zhx/jecqa/model/typeb/multi 52 | model_name = BiDAF 53 | 54 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 55 | 56 | output_function = Basic 57 | -------------------------------------------------------------------------------- /config/qa/typeb_multi/comatch.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 128 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = ComatchQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = ComatchQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 1_test.json 31 | 32 | test_formatter_type = ComatchQA 33 | 34 | topk = 18 35 | multi_choice = True 36 | 37 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 38 | sent_max_len = 50 39 | max_sent = 50 40 | 41 | 42 | [model] #model parameters 43 | model_name = Comatch 44 | 45 | hidden_size = 150 46 | 47 | dropout = 0.2 48 | 49 | [output] #output parameters 50 | model_path = /data/disk3/private/zhx/jecqa/model/typeb/multi 51 | model_name = Comatch 52 | 53 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 54 | 55 | output_function = Basic 56 | -------------------------------------------------------------------------------- /config/qa/typeb_single/HAF.config: 
-------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 64 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 0.95 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = HAF 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = HAF 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 1_test.json 31 | 32 | test_formatter_type = HAF 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | question_len = 64 38 | option_len = 48 39 | passage_len = 256 40 | 41 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 42 | 43 | [model] #model parameters 44 | model_name = HAF 45 | 46 | hidden_size = 256 47 | 48 | bi_direction = True 49 | num_layers = 2 50 | 51 | [output] #output parameters 52 | model_path = /data/disk3/private/zhx/jecqa/model/typeb/single 53 | model_name = HAF 54 | 55 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 56 | 57 | output_function = Basic 58 | -------------------------------------------------------------------------------- /config/qa/typeb_single/bidaf.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 160 3 | batch_size = 12 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 12 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = WordQA 24 | train_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = WordQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 1_test.json 31 | 32 | test_formatter_type = WordQA 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | max_len1 = 128 38 | max_len2 = 384 39 | 40 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 41 | 42 | [model] #model parameters 43 | model_name = BiDAF 44 | 45 | hidden_size = 256 46 | 47 | bi_direction = True 48 | num_layers = 2 49 | 50 | [output] #output parameters 51 | model_path = /data/disk3/private/zhx/jecqa/model/typeb/single 52 | model_name = BiDAF 53 | 54 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 55 | 56 | output_function = Basic 57 | -------------------------------------------------------------------------------- /config/qa/typeb_single/comatch.config: -------------------------------------------------------------------------------- 1 | [train] #train parameters 2 | epoch = 128 3 | batch_size = 1 4 | 5 | shuffle = True 6 | 7 | reader_num = 16 8 | 9 | optimizer = adam 10 | learning_rate = 1e-3 11 | step_size = 1 12 | lr_multiplier = 1 13 | 14 | [eval] #eval parameters 15 | batch_size = 1 16 | 17 | shuffle = False 18 | 19 | reader_num = 16 20 | 21 | [data] #data parameters 22 | train_dataset_type = JsonFromFiles 23 | train_formatter_type = ComatchQA 24 | train_data_path = 
/data/disk3/private/zhx/jecqa/data/cutted/v1 25 | train_file_list = 1_train.json 26 | 27 | valid_dataset_type = JsonFromFiles 28 | valid_formatter_type = ComatchQA 29 | valid_data_path = /data/disk3/private/zhx/jecqa/data/cutted/v1 30 | valid_file_list = 1_test.json 31 | 32 | test_formatter_type = ComatchQA 33 | 34 | topk = 18 35 | multi_choice = False 36 | 37 | word2id = /data/disk3/private/zhx/jecqa/data/word2id.txt 38 | sent_max_len = 50 39 | max_sent = 50 40 | 41 | 42 | [model] #model parameters 43 | model_name = Comatch 44 | 45 | hidden_size = 150 46 | 47 | dropout = 0.2 48 | 49 | [output] #output parameters 50 | model_path = /data/disk3/private/zhx/jecqa/model/typeb/single 51 | model_name = Comatch 52 | 53 | tensorboard_path = /data/disk3/private/zhx/jecqa/tensorboard 54 | 55 | output_function = Basic 56 | -------------------------------------------------------------------------------- /config_parser/__init__.py: -------------------------------------------------------------------------------- 1 | from .parser import create_config 2 | -------------------------------------------------------------------------------- /config_parser/parser.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import functools 4 | 5 | 6 | class ConfigParser: 7 | def __init__(self, *args, **params): 8 | self.default_config = configparser.RawConfigParser(*args, **params) 9 | self.local_config = configparser.RawConfigParser(*args, **params) 10 | self.config = configparser.RawConfigParser(*args, **params) 11 | 12 | def read(self, filenames, encoding=None): 13 | if os.path.exists("config/default_local.config"): 14 | self.local_config.read("config/default_local.config", encoding=encoding) 15 | else: 16 | self.local_config.read("config/default.config", encoding=encoding) 17 | 18 | self.default_config.read("config/default.config", encoding=encoding) 19 | self.config.read(filenames, encoding=encoding) 20 | 21 | 22 | def _build_func(func_name): 23 | @functools.wraps(getattr(configparser.RawConfigParser, func_name)) 24 | def func(self, *args, **kwargs): 25 | try: 26 | return getattr(self.config, func_name)(*args, **kwargs) 27 | except Exception as e: 28 | try: 29 | return getattr(self.local_config, func_name)(*args, **kwargs) 30 | except Exception as e: 31 | return getattr(self.default_config, func_name)(*args, **kwargs) 32 | 33 | return func 34 | 35 | 36 | def create_config(path): 37 | for func_name in dir(configparser.RawConfigParser): 38 | if not func_name.startswith('_') and func_name != "read": 39 | setattr(ConfigParser, func_name, _build_func(func_name)) 40 | 41 | config = ConfigParser() 42 | config.read(path) 43 | 44 | return config 45 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .nlp.JsonFromFiles import JsonFromFilesDataset 2 | 3 | dataset_list = { 4 | "JsonFromFiles": JsonFromFilesDataset 5 | } 6 | -------------------------------------------------------------------------------- /dataset/nlp/JsonFromFiles.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from torch.utils.data import Dataset 4 | import random 5 | 6 | from tools.dataset_tool import dfs_search 7 | 8 | 9 | class JsonFromFilesDataset(Dataset): 10 | def __init__(self, config, mode, encoding="utf8", *args, **params): 11 | self.config = config 12 | self.mode = 
mode
13 |         self.file_list = []
14 |         self.data_path = config.get("data", "%s_data_path" % mode)
15 |         self.encoding = encoding
16 |
17 |         filename_list = config.get("data", "%s_file_list" % mode).replace(" ", "").split(",")
18 |         recursive = False
19 |
20 |         multi = config.getboolean("data", "multi_choice")
21 |
22 |         for name in filename_list:
23 |             self.file_list = self.file_list + dfs_search(os.path.join(self.data_path, name), recursive)
24 |         self.file_list.sort()
25 |
26 |         self.data = []
27 |         for filename in self.file_list:
28 |             f = open(filename, "r", encoding=encoding)
29 |             for line in f:
30 |                 data = json.loads(line)
31 |                 if (not multi) and len(data["answer"]) != 1:  # in single-choice mode, keep multi-answer questions only when testing
32 |                     if mode != "test":
33 |                         continue
34 |                 self.data.append(data)  # reuse the object parsed above instead of calling json.loads(line) a second time
35 |
36 |         if mode == "train":
37 |             random.shuffle(self.data)
38 |
39 |         self.reduce = config.getboolean("data", "reduce")
40 |         if mode != "train":
41 |             self.reduce = False
42 |         if self.reduce:
43 |             self.reduce_ratio = config.getfloat("data", "reduce_ratio")
44 |
45 |     def __getitem__(self, item):
46 |         if self.reduce:
47 |             return self.data[random.randint(0, len(self.data) - 1)]
48 |         return self.data[item]
49 |
50 |     def __len__(self):
51 |         if self.reduce:
52 |             return int(self.reduce_ratio * len(self.data))
53 |         return len(self.data)
54 |
--------------------------------------------------------------------------------
/dataset/nlp/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/jec-qa/8706c64ff62637e61cd8a729815f585a6df3b3f1/dataset/nlp/__init__.py
--------------------------------------------------------------------------------
/formatter/Basic.py:
--------------------------------------------------------------------------------
1 | class BasicFormatter:
2 |     def __init__(self, config, mode, *args, **params):
3 |         self.config = config
4 |         self.mode = mode
5 |
6 |     def process(self, data, config, mode, *args, **params):
7 |         return data
8 |
9 |
--------------------------------------------------------------------------------
/formatter/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | from formatter.qa.Bert import BertQA
4 | from formatter.qa.Char import CharQA
5 | from formatter.qa.Comatch import ComatchingFormatter
6 | from formatter.qa.HAF import HAFQA
7 | from formatter.qa.Word import WordQA
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 | formatter_list = {
12 |     "BertQA": BertQA,
13 |     "CharQA": CharQA,
14 |     "WordQA": WordQA,
15 |     "ComatchQA": ComatchingFormatter,
16 |     "HAF": HAFQA
17 | }
18 |
19 |
20 | def init_formatter(config, mode, *args, **params):
21 |     temp_mode = mode
22 |     if mode != "train":
23 |         try:
24 |             config.get("data", "%s_formatter_type" % temp_mode)
25 |         except Exception as e:
26 |             logger.warning(
27 |                 "[reader] %s_formatter_type has not been defined in config file, use [data] train_formatter_type instead." % temp_mode)
28 |             temp_mode = "train"
29 |     which = config.get("data", "%s_formatter_type" % temp_mode)
30 |
31 |     if which in formatter_list:
32 |         formatter = formatter_list[which](config, mode, *args, **params)
33 |
34 |         return formatter
35 |     else:
36 |         logger.error("There is no formatter called %s, check your config."
% which) 37 | raise NotImplementedError 38 | -------------------------------------------------------------------------------- /formatter/qa/Bert.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import os 5 | from pytorch_pretrained_bert import BertTokenizer 6 | 7 | 8 | class BertQA: 9 | def __init__(self, config, mode): 10 | self.max_len1 = config.getint("data", "max_len1") 11 | self.max_len2 = config.getint("data", "max_len2") 12 | 13 | self.tokenizer = BertTokenizer.from_pretrained(os.path.join(config.get("model", "bert_path"), "vocab.txt")) 14 | self.k = config.getint("data", "topk") 15 | 16 | def convert(self, tokens, which, l): 17 | mask = [] 18 | tokenx = [] 19 | 20 | tokens = self.tokenizer.tokenize(tokens) 21 | ids = self.tokenizer.convert_tokens_to_ids(tokens) 22 | 23 | for a in range(0, len(ids)): 24 | mask.append(1) 25 | tokenx.append(which) 26 | 27 | while len(ids) < l: 28 | ids.append(self.tokenizer.vocab["[PAD]"]) 29 | mask.append(0) 30 | tokenx.append(which) 31 | 32 | ids = torch.LongTensor(ids) 33 | mask = torch.LongTensor(mask) 34 | tokenx = torch.LongTensor(tokenx) 35 | 36 | return ids, mask, tokenx 37 | 38 | def process(self, data, config, mode, *args, **params): 39 | txt = [] 40 | mask = [] 41 | token = [] 42 | label = [] 43 | idx = [] 44 | 45 | for temp_data in data: 46 | idx.append(temp_data["id"]) 47 | if config.getboolean("data", "multi_choice"): 48 | label_x = 0 49 | if "A" in temp_data["answer"]: 50 | label_x += 1 51 | if "B" in temp_data["answer"]: 52 | label_x += 2 53 | if "C" in temp_data["answer"]: 54 | label_x += 4 55 | if "D" in temp_data["answer"]: 56 | label_x += 8 57 | else: 58 | label_x = 0 59 | if "A" in temp_data["answer"]: 60 | label_x = 0 61 | if "B" in temp_data["answer"]: 62 | label_x = 1 63 | if "C" in temp_data["answer"]: 64 | label_x = 2 65 | if "D" in temp_data["answer"]: 66 | label_x = 3 67 | 68 | label.append(label_x) 69 | 70 | temp_text = [] 71 | temp_mask = [] 72 | temp_token = [] 73 | 74 | for option in ["A", "B", "C", "D"]: 75 | res = temp_data["statement"] + temp_data["option_list"][option] 76 | text = [] 77 | 78 | for a in range(0, len(res)): 79 | text = text + [res[a]] 80 | text = text[0:self.max_len1] 81 | 82 | txt1, mask1, token1 = self.convert(text, 0, self.max_len1) 83 | 84 | ref = [] 85 | k = [0, 1, 2, 6, 12, 7, 13, 3, 8, 9, 14, 15, 4, 10, 16, 5, 16, 17] 86 | for a in range(0, self.k): 87 | res = temp_data["reference"][option][k[a]] 88 | text = [] 89 | 90 | for a in range(0, len(res)): 91 | text = text + [res[a]] 92 | text = text[0:self.max_len2] 93 | 94 | txt2, mask2, token2 = self.convert(text, 1, self.max_len2) 95 | 96 | temp_text.append(torch.cat([txt1, txt2])) 97 | temp_mask.append(torch.cat([mask1, mask2])) 98 | temp_token.append(torch.cat([token1, token2])) 99 | 100 | txt.append(torch.stack(temp_text)) 101 | mask.append(torch.stack(temp_mask)) 102 | token.append(torch.stack(temp_token)) 103 | 104 | txt = torch.stack(txt) 105 | mask = torch.stack(mask) 106 | token = torch.stack(token) 107 | label = torch.LongTensor(np.array(label, dtype=np.int32)) 108 | 109 | return {"text": txt, "mask": mask, "token": token, 'label': label, "id": idx} 110 | -------------------------------------------------------------------------------- /formatter/qa/Char.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import os 5 | from pytorch_pretrained_bert import 
BertTokenizer 6 | 7 | 8 | class CharQA: 9 | def __init__(self, config, mode): 10 | self.max_len1 = config.getint("data", "max_len1") 11 | self.max_len2 = config.getint("data", "max_len2") 12 | 13 | self.tokenizer = BertTokenizer.from_pretrained(config.get("data", "word2id")) 14 | self.k = config.getint("data", "topk") 15 | 16 | def convert(self, tokens, l): 17 | tokens = self.tokenizer.tokenize(tokens) 18 | while len(tokens) < l: 19 | tokens.append("[PAD]") 20 | tokens = tokens[:l] 21 | ids = self.tokenizer.convert_tokens_to_ids(tokens) 22 | 23 | return ids 24 | 25 | def process(self, data, config, mode, *args, **params): 26 | context = [] 27 | question = [] 28 | label = [] 29 | 30 | for temp_data in data: 31 | if config.getboolean("data", "multi_choice"): 32 | label_x = 0 33 | if "A" in temp_data["answer"]: 34 | label_x += 1 35 | if "B" in temp_data["answer"]: 36 | label_x += 2 37 | if "C" in temp_data["answer"]: 38 | label_x += 4 39 | if "D" in temp_data["answer"]: 40 | label_x += 8 41 | else: 42 | label_x = 0 43 | if "A" in temp_data["answer"]: 44 | label_x = 0 45 | if "B" in temp_data["answer"]: 46 | label_x = 1 47 | if "C" in temp_data["answer"]: 48 | label_x = 2 49 | if "D" in temp_data["answer"]: 50 | label_x = 3 51 | 52 | label.append(label_x) 53 | 54 | temp_context = [] 55 | temp_question = [] 56 | 57 | for option in ["A", "B", "C", "D"]: 58 | res = temp_data["statement"] + temp_data["option_list"][option] 59 | text = [] 60 | temp_question.append(self.convert(res, self.max_len1)) 61 | 62 | ref = [] 63 | k = [0, 1, 2, 6, 12, 7, 13, 3, 8, 9, 14, 15, 4, 10, 16, 5, 16, 17] 64 | for a in range(0, self.k): 65 | res = temp_data["reference"][option][k[a]] 66 | 67 | ref.append(self.convert(res, self.max_len2)) 68 | 69 | temp_context.append(ref) 70 | 71 | context.append(temp_context) 72 | question.append(temp_question) 73 | 74 | question = torch.LongTensor(question) 75 | context = torch.LongTensor(context) 76 | label = torch.LongTensor(np.array(label, dtype=np.int32)) 77 | 78 | return {"context": context, "question": question, 'label': label} 79 | -------------------------------------------------------------------------------- /formatter/qa/Comatch.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import jieba 5 | import random 6 | 7 | 8 | class ComatchingFormatter: 9 | def __init__(self, config, mode): 10 | self.word2id = json.load(open(config.get("data", "word2id"), "r")) 11 | 12 | self.sent_max_len = config.getint("data", "sent_max_len") 13 | self.max_sent = config.getint("data", "max_sent") 14 | self.k = config.getint("data", "topk") 15 | 16 | self.symbol = [",", ".", "?", "\"", "”", "。", "?", ""] 17 | self.last_symbol = [".", "?", "。", "?"] 18 | 19 | def transform(self, word): 20 | if not (word in self.word2id.keys()): 21 | return self.word2id["UNK"] 22 | else: 23 | return self.word2id[word] 24 | 25 | def seq2tensor(self, sents, max_len): 26 | sent_len_max = max([len(s) for s in sents]) 27 | sent_len_max = min(sent_len_max, max_len) 28 | 29 | sent_tensor = torch.LongTensor(len(sents), sent_len_max).zero_() 30 | 31 | sent_len = torch.LongTensor(len(sents)).zero_() 32 | for s_id, sent in enumerate(sents): 33 | sent_len[s_id] = len(sent) 34 | for w_id, word in enumerate(sent): 35 | if w_id >= sent_len_max: break 36 | sent_tensor[s_id][w_id] = self.transform(word) 37 | return [sent_tensor, sent_len] 38 | 39 | def seq2Htensor(self, docs, max_sent, max_sent_len, v1=0, v2=0): 40 | sent_num_max = 
max([len(s) for s in docs]) 41 | sent_num_max = min(sent_num_max, max_sent) 42 | sent_len_max = max([len(w) for s in docs for w in s]) 43 | sent_len_max = min(sent_len_max, max_sent_len) 44 | sent_num_max = max(sent_num_max, v1) 45 | sent_len_max = max(sent_len_max, v2) 46 | 47 | sent_tensor = torch.LongTensor(len(docs), sent_num_max, sent_len_max).zero_() 48 | sent_len = torch.LongTensor(len(docs), sent_num_max).zero_() 49 | doc_len = torch.LongTensor(len(docs)).zero_() 50 | for d_id, doc in enumerate(docs): 51 | doc_len[d_id] = len(doc) 52 | for s_id, sent in enumerate(doc): 53 | if s_id >= sent_num_max: break 54 | sent_len[d_id][s_id] = len(sent) 55 | for w_id, word in enumerate(sent): 56 | if w_id >= sent_len_max: break 57 | sent_tensor[d_id][s_id][w_id] = self.transform(word) 58 | return [sent_tensor, doc_len, sent_len] 59 | 60 | def gen_max(self, docs, max_sent, max_sent_len): 61 | sent_num_max = max([len(s) for s in docs]) 62 | sent_num_max = min(sent_num_max, max_sent) 63 | sent_len_max = max([len(w) for s in docs for w in s]) 64 | sent_len_max = min(sent_len_max, max_sent_len) 65 | 66 | return sent_num_max, sent_len_max 67 | 68 | def parse(self, sent): 69 | result = [] 70 | for word in sent: 71 | if len(word) == 0: 72 | continue 73 | 74 | result.append(word) 75 | 76 | return result 77 | 78 | def parseH(self, sent): 79 | result = [] 80 | temp = [] 81 | for word in sent: 82 | temp.append(word) 83 | last = False 84 | for symbol in self.last_symbol: 85 | if word == symbol: 86 | last = True 87 | if last: 88 | result.append(temp) 89 | temp = [] 90 | 91 | if len(temp) != 0: 92 | result.append(temp) 93 | 94 | return result 95 | 96 | def process(self, data, config, mode, *args, **params): 97 | document = [[], [], [], []] 98 | option = [] 99 | question = [] 100 | label = [] 101 | 102 | for temp_data in data: 103 | question.append(self.parse(temp_data["statement"])) 104 | 105 | if config.getboolean("data", "multi_choice"): 106 | option.append([self.parse(temp_data["option_list"]["A"]), 107 | self.parse(temp_data["option_list"]["B"]), 108 | self.parse(temp_data["option_list"]["C"]), 109 | self.parse(temp_data["option_list"]["D"])]) 110 | 111 | label_x = 0 112 | if "A" in temp_data["answer"]: 113 | label_x += 1 114 | if "B" in temp_data["answer"]: 115 | label_x += 2 116 | if "C" in temp_data["answer"]: 117 | label_x += 4 118 | if "D" in temp_data["answer"]: 119 | label_x += 8 120 | else: 121 | option.append([self.parse(temp_data["option_list"]["A"]), 122 | self.parse(temp_data["option_list"]["B"]), 123 | self.parse(temp_data["option_list"]["C"]), 124 | self.parse(temp_data["option_list"]["D"])]) 125 | 126 | label_x = 0 127 | if "A" in temp_data["answer"]: 128 | label_x = 0 129 | if "B" in temp_data["answer"]: 130 | label_x = 1 131 | if "C" in temp_data["answer"]: 132 | label_x = 2 133 | if "D" in temp_data["answer"]: 134 | label_x = 3 135 | 136 | temp = [] 137 | for a in range(0, 4): 138 | arr = ["A", "B", "C", "D"] 139 | res = [] 140 | k = [0, 1, 2, 6, 12, 7, 13, 3, 8, 9, 14, 15, 4, 10, 16, 5, 16, 17] 141 | for b in range(0, self.k): 142 | res.append(self.parseH(temp_data["reference"][arr[a]][k[b]])) 143 | document[a].append(res) 144 | 145 | label.append(label_x) 146 | 147 | v1 = 0 148 | v2 = 0 149 | for a in range(0, 4): 150 | for b in range(0, len(document[a])): 151 | v1t, v2t = self.gen_max(document[a][b], self.max_sent, self.sent_max_len) 152 | v1 = max(v1, v1t) 153 | v2 = max(v2, v2t) 154 | option = self.seq2Htensor(option, self.max_sent, self.sent_max_len) 155 | question = 
self.seq2tensor(question, self.sent_max_len) 156 | 157 | for a in range(0, 4): 158 | for b in range(0, len(document[a])): 159 | document[a][b] = self.seq2Htensor(document[a][b], self.max_sent, self.sent_max_len, v1, v2) 160 | 161 | document_sent = [] 162 | document_len = [] 163 | do = [] 164 | for a in range(0, 4): 165 | d = [] 166 | ds = [] 167 | dl = [] 168 | for b in range(0, len(document[a])): 169 | d.append(document[a][b][0]) 170 | ds.append(document[a][b][1]) 171 | dl.append(document[a][b][2]) 172 | 173 | d = torch.stack(d) 174 | ds = torch.stack(ds) 175 | dl = torch.stack(dl) 176 | 177 | do.append(d) 178 | document_sent.append(ds) 179 | document_len.append(dl) 180 | 181 | document = torch.stack(do) 182 | document_len = torch.stack(document_len) 183 | document_sent = torch.stack(document_sent) 184 | 185 | document = torch.transpose(document, 1, 0) 186 | document_len = torch.transpose(document_len, 1, 0) 187 | document_sent = torch.transpose(document_sent, 1, 0) 188 | 189 | label = torch.tensor(label, dtype=torch.long) 190 | 191 | return { 192 | "question": question[0], 193 | "question_len": question[1], 194 | "option": option[0], 195 | "option_sent": option[1], 196 | "option_len": option[2], 197 | "document": document, 198 | "document_sent": document_sent, 199 | "document_len": document_len, 200 | "label": label 201 | } 202 | -------------------------------------------------------------------------------- /formatter/qa/HAF.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import os 5 | from pytorch_pretrained_bert import BertTokenizer 6 | 7 | 8 | class HAFQA: 9 | def __init__(self, config, mode): 10 | self.question_len = config.getint("data", "question_len") 11 | self.option_len = config.getint("data", "option_len") 12 | self.passage_len = config.getint("data", "passage_len") 13 | 14 | self.word2id = json.load(open(config.get("data", "word2id"), "r")) 15 | self.k = config.getint("data", "topk") 16 | 17 | def convert_tokens_to_ids(self, tokens): 18 | arr = [] 19 | for a in range(0, len(tokens)): 20 | if tokens[a] in self.word2id: 21 | arr.append(self.word2id[tokens[a]]) 22 | else: 23 | arr.append(self.word2id["UNK"]) 24 | return arr 25 | 26 | def convert(self, tokens, l): 27 | while len(tokens) < l: 28 | tokens.append("[PAD]") 29 | tokens = tokens[:l] 30 | ids = self.convert_tokens_to_ids(tokens) 31 | 32 | return ids 33 | 34 | def process(self, data, config, mode, *args, **params): 35 | passage = [] 36 | question = [] 37 | option = [] 38 | label = [] 39 | idx = [] 40 | 41 | for temp_data in data: 42 | idx.append(temp_data["id"]) 43 | if config.getboolean("data", "multi_choice"): 44 | label_x = 0 45 | if "A" in temp_data["answer"]: 46 | label_x += 1 47 | if "B" in temp_data["answer"]: 48 | label_x += 2 49 | if "C" in temp_data["answer"]: 50 | label_x += 4 51 | if "D" in temp_data["answer"]: 52 | label_x += 8 53 | else: 54 | label_x = 0 55 | if "A" in temp_data["answer"]: 56 | label_x = 0 57 | if "B" in temp_data["answer"]: 58 | label_x = 1 59 | if "C" in temp_data["answer"]: 60 | label_x = 2 61 | if "D" in temp_data["answer"]: 62 | label_x = 3 63 | 64 | label.append(label_x) 65 | 66 | temp_passage = [] 67 | temp_option = [] 68 | question.append(self.convert(temp_data["statement"], self.question_len)) 69 | 70 | for option_ in ["A", "B", "C", "D"]: 71 | temp_option.append(self.convert(temp_data["option_list"][option_], self.option_len)) 72 | 73 | ref = [] 74 | k = [0, 1, 2, 6, 12, 7, 13, 3, 8, 
9, 14, 15, 4, 10, 16, 5, 16, 17] 75 | for a in range(0, self.k): 76 | res = temp_data["reference"][option_][k[a]] 77 | 78 | ref.append(self.convert(res, self.passage_len)) 79 | 80 | temp_passage.append(ref) 81 | 82 | passage.append(temp_passage) 83 | option.append(temp_option) 84 | 85 | question = torch.LongTensor(question) 86 | passage = torch.LongTensor(passage) 87 | option = torch.LongTensor(option) 88 | label = torch.LongTensor(np.array(label, dtype=np.int32)) 89 | 90 | return {"passage": passage, "option": option, "question": question, 'label': label, "id": idx} 91 | -------------------------------------------------------------------------------- /formatter/qa/Word.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import os 5 | from pytorch_pretrained_bert import BertTokenizer 6 | 7 | 8 | class WordQA: 9 | def __init__(self, config, mode): 10 | self.max_len1 = config.getint("data", "max_len1") 11 | self.max_len2 = config.getint("data", "max_len2") 12 | 13 | self.word2id = json.load(open(config.get("data", "word2id"), "r")) 14 | self.k = config.getint("data", "topk") 15 | 16 | def convert_tokens_to_ids(self, tokens): 17 | arr = [] 18 | for a in range(0, len(tokens)): 19 | if tokens[a] in self.word2id: 20 | arr.append(self.word2id[tokens[a]]) 21 | else: 22 | arr.append(self.word2id["UNK"]) 23 | return arr 24 | 25 | def convert(self, tokens, l, bk=False): 26 | while len(tokens) < l: 27 | tokens.append("PAD") 28 | if bk: 29 | tokens = tokens[len(tokens) - l:] 30 | else: 31 | tokens = tokens[:l] 32 | ids = self.convert_tokens_to_ids(tokens) 33 | 34 | return ids 35 | 36 | def process(self, data, config, mode, *args, **params): 37 | context = [] 38 | question = [] 39 | label = [] 40 | idx = [] 41 | 42 | for temp_data in data: 43 | idx.append(temp_data["id"]) 44 | if config.getboolean("data", "multi_choice"): 45 | label_x = 0 46 | if "A" in temp_data["answer"]: 47 | label_x += 1 48 | if "B" in temp_data["answer"]: 49 | label_x += 2 50 | if "C" in temp_data["answer"]: 51 | label_x += 4 52 | if "D" in temp_data["answer"]: 53 | label_x += 8 54 | else: 55 | label_x = 0 56 | if "A" in temp_data["answer"]: 57 | label_x = 0 58 | if "B" in temp_data["answer"]: 59 | label_x = 1 60 | if "C" in temp_data["answer"]: 61 | label_x = 2 62 | if "D" in temp_data["answer"]: 63 | label_x = 3 64 | 65 | label.append(label_x) 66 | 67 | temp_context = [] 68 | temp_question = [] 69 | 70 | for option in ["A", "B", "C", "D"]: 71 | res = temp_data["statement"] + temp_data["option_list"][option] 72 | text = [] 73 | temp_question.append(self.convert(res, self.max_len1, bk=True)) 74 | 75 | ref = [] 76 | k = [0, 1, 2, 6, 12, 7, 13, 3, 8, 9, 14, 15, 4, 10, 16, 5, 16, 17] 77 | for a in range(0, self.k): 78 | res = temp_data["reference"][option][k[a]] 79 | 80 | ref.append(self.convert(res, self.max_len2)) 81 | 82 | temp_context.append(ref) 83 | 84 | context.append(temp_context) 85 | question.append(temp_question) 86 | 87 | question = torch.LongTensor(question) 88 | context = torch.LongTensor(context) 89 | label = torch.LongTensor(np.array(label, dtype=np.int32)) 90 | 91 | return {"context": context, "question": question, 'label': label, "id": idx} 92 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from model.qa.Bert import BertQA 2 | from model.qa.BiDAF import BiDAFQA 3 | from model.qa.CoMatch 
import CoMatching 4 | from model.qa.HAF import HAF 5 | 6 | model_list = { 7 | "Bert": BertQA, 8 | "BiDAF": BiDAFQA, 9 | "Comatch": CoMatching, 10 | "HAF": HAF 11 | } 12 | 13 | 14 | def get_model(model_name): 15 | if model_name in model_list.keys(): 16 | return model_list[model_name] 17 | else: 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /model/encoder/BertEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch_pretrained_bert import BertModel 5 | 6 | 7 | class BertEncoder(nn.Module): 8 | def __init__(self, config, gpu_list, *args, **params): 9 | super(BertEncoder, self).__init__() 10 | 11 | self.bert = BertModel.from_pretrained(config.get("model", "bert_path")) 12 | 13 | def forward(self, x): 14 | _, y = self.bert(x, output_all_encoded_layers=False) 15 | 16 | return y 17 | -------------------------------------------------------------------------------- /model/encoder/CNNEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class CNNEncoder(nn.Module): 7 | def __init__(self, config, gpu_list, *args, **params): 8 | super(CNNEncoder, self).__init__() 9 | 10 | self.emb_dim = config.getint("model", "hidden_size") 11 | self.output_dim = self.emb_dim // 4 12 | 13 | self.min_gram = 2 14 | self.max_gram = 5 15 | self.convs = [] 16 | for a in range(self.min_gram, self.max_gram + 1): 17 | self.convs.append(nn.Conv2d(1, self.output_dim, (a, self.emb_dim))) 18 | 19 | self.convs = nn.ModuleList(self.convs) 20 | self.feature_len = self.emb_dim 21 | self.relu = nn.ReLU() 22 | 23 | def forward(self, x): 24 | batch_size = x.size()[0] 25 | 26 | x = x.view(batch_size, 1, -1, self.emb_dim) 27 | 28 | conv_out = [] 29 | gram = self.min_gram 30 | for conv in self.convs: 31 | y = self.relu(conv(x)) 32 | y = torch.max(y, dim=2)[0].view(batch_size, -1) 33 | 34 | conv_out.append(y) 35 | gram += 1 36 | 37 | conv_out = torch.cat(conv_out, dim=1) 38 | 39 | return conv_out 40 | -------------------------------------------------------------------------------- /model/encoder/GRUEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GRUEncoder(nn.Module): 7 | def __init__(self, config, gpu_list, *args, **params): 8 | super(GRUEncoder, self).__init__() 9 | 10 | self.hidden_size = config.getint("model", "hidden_size") 11 | self.bi = config.getboolean("model", "bi_direction") 12 | self.output_size = self.hidden_size 13 | self.num_layers = config.getint("model", "num_layers") 14 | if self.bi: 15 | self.output_size = self.output_size // 2 16 | 17 | self.gru = nn.GRU(input_size=self.hidden_size, hidden_size=self.output_size, num_layers=self.num_layers, 18 | batch_first=True, bidirectional=self.bi) 19 | 20 | def forward(self, x): 21 | h_, c = self.gru(x) 22 | 23 | h = torch.max(h_, dim=1)[0] 24 | 25 | return h, h_ 26 | -------------------------------------------------------------------------------- /model/encoder/LSTMEncoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class LSTMEncoder(nn.Module): 7 | def __init__(self, config, gpu_list, 
*args, **params): 8 | super(LSTMEncoder, self).__init__() 9 | 10 | self.hidden_size = config.getint("model", "hidden_size") 11 | self.bi = config.getboolean("model", "bi_direction") 12 | self.output_size = self.hidden_size 13 | self.num_layers = config.getint("model", "num_layers") 14 | if self.bi: 15 | self.output_size = self.output_size // 2 16 | 17 | self.lstm = nn.LSTM(input_size=self.hidden_size, hidden_size=self.output_size, 18 | num_layers=self.num_layers, batch_first=True, bidirectional=self.bi) 19 | 20 | def forward(self, x): 21 | batch_size = x.size()[0] 22 | seq_len = x.size()[1] 23 | # print(x.size()) 24 | # print(batch_size, self.num_layers + int(self.bi) * self.num_layers, self.output_size) 25 | hidden = ( 26 | torch.autograd.Variable( 27 | torch.zeros(self.num_layers + int(self.bi) * self.num_layers, batch_size, self.output_size)).cuda(), 28 | torch.autograd.Variable( 29 | torch.zeros(self.num_layers + int(self.bi) * self.num_layers, batch_size, self.output_size)).cuda()) 30 | 31 | h_, c = self.lstm(x, hidden) 32 | 33 | h = torch.max(h_, dim=1)[0] 34 | 35 | return h, h_ 36 | -------------------------------------------------------------------------------- /model/encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/jec-qa/8706c64ff62637e61cd8a729815f585a6df3b3f1/model/encoder/__init__.py -------------------------------------------------------------------------------- /model/layer/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, config, gpu_list, *args, **params): 8 | super(Attention, self).__init__() 9 | 10 | self.hidden_size = config.getint("model", "hidden_size") 11 | self.fc = nn.Linear(self.hidden_size, self.hidden_size) 12 | 13 | def forward(self, x, y): 14 | x_ = x # self.fc(x) 15 | y_ = torch.transpose(y, 1, 2) 16 | a_ = torch.bmm(x_, y_) 17 | 18 | x_atten = torch.softmax(a_, dim=2) 19 | x_atten = torch.bmm(x_atten, y) 20 | 21 | y_atten = torch.softmax(a_, dim=1) 22 | y_atten = torch.bmm(torch.transpose(y_atten, 2, 1), x) 23 | 24 | return x_atten, y_atten, a_ 25 | -------------------------------------------------------------------------------- /model/layer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/jec-qa/8706c64ff62637e61cd8a729815f585a6df3b3f1/model/layer/__init__.py -------------------------------------------------------------------------------- /model/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import numpy as np 6 | 7 | 8 | class MultiLabelSoftmaxLoss(nn.Module): 9 | def __init__(self, config, task_num=0): 10 | super(MultiLabelSoftmaxLoss, self).__init__() 11 | self.task_num = task_num 12 | self.criterion = [] 13 | for a in range(0, self.task_num): 14 | try: 15 | ratio = config.getfloat("train", "loss_weight_%d" % a) 16 | self.criterion.append( 17 | nn.CrossEntropyLoss(weight=torch.from_numpy(np.array([1.0, ratio], dtype=np.float32)).cuda())) 18 | # print_info("Task %d with weight %.3lf" % (task, ratio)) 19 | except Exception as e: 20 | self.criterion.append(nn.CrossEntropyLoss()) 21 | 22 | def forward(self, outputs, labels): 23 | loss 
= 0 24 | for a in range(0, len(outputs[0])): 25 | o = outputs[:, a, :].view(outputs.size()[0], -1) 26 | loss += self.criterion[a](o, labels[:, a]) 27 | 28 | return loss 29 | 30 | 31 | def multi_label_cross_entropy_loss(outputs, labels): 32 | labels = labels.float() 33 | temp = outputs 34 | res = - labels * torch.log(temp) - (1 - labels) * torch.log(1 - temp) 35 | res = torch.mean(torch.sum(res, dim=1)) 36 | 37 | return res 38 | 39 | 40 | def cross_entropy_loss(outputs, labels): 41 | criterion = nn.CrossEntropyLoss() 42 | return criterion(outputs, labels) 43 | 44 | 45 | class FocalLoss(nn.Module): 46 | def __init__(self, gamma=0, alpha=None, size_average=True): 47 | super(FocalLoss, self).__init__() 48 | self.gamma = gamma 49 | self.alpha = alpha 50 | self.size_average = size_average 51 | 52 | def forward(self, input, target): 53 | if input.dim() > 2: 54 | input = input.view(input.size(0), input.size(1), -1) # N,C,H,W => N,C,H*W 55 | input = input.transpose(1, 2) # N,C,H*W => N,H*W,C 56 | input = input.contiguous().view(-1, input.size(2)) # N,H*W,C => N*H*W,C 57 | target = target.view(-1, 1) 58 | 59 | logpt = F.log_softmax(input, dim=1) 60 | logpt = logpt.gather(1, target) 61 | logpt = logpt.view(-1) 62 | pt = Variable(logpt.data.exp()) 63 | 64 | if self.alpha is not None: 65 | if self.alpha.type() != input.data.type(): 66 | self.alpha = self.alpha.type_as(input.data) 67 | at = self.alpha.gather(0, target.data.view(-1)) 68 | logpt = logpt * Variable(at) 69 | 70 | loss = -1 * (1 - pt) ** self.gamma * logpt 71 | if self.size_average: 72 | return loss.mean() 73 | else: 74 | return loss.sum() 75 | -------------------------------------------------------------------------------- /model/optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.optim as optim 2 | from pytorch_pretrained_bert import BertAdam 3 | 4 | 5 | def init_optimizer(model, config, *args, **params): 6 | optimizer_type = config.get("train", "optimizer") 7 | learning_rate = config.getfloat("train", "learning_rate") 8 | if optimizer_type == "adam": 9 | optimizer = optim.Adam(model.parameters(), lr=learning_rate, 10 | weight_decay=config.getfloat("train", "weight_decay")) 11 | elif optimizer_type == "sgd": 12 | optimizer = optim.SGD(model.parameters(), lr=learning_rate, 13 | weight_decay=config.getfloat("train", "weight_decay")) 14 | elif optimizer_type == "bert_adam": 15 | optimizer = BertAdam(model.parameters(), lr=learning_rate, 16 | weight_decay=config.getfloat("train", "weight_decay")) 17 | else: 18 | raise NotImplementedError 19 | 20 | return optimizer 21 | -------------------------------------------------------------------------------- /model/qa/Bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from pytorch_pretrained_bert import BertModel 5 | 6 | from tools.accuracy_tool import single_label_top1_accuracy 7 | 8 | 9 | class BertQA(nn.Module): 10 | def __init__(self, config, gpu_list, *args, **params): 11 | super(BertQA, self).__init__() 12 | 13 | self.bert = BertModel.from_pretrained(config.get("model", "bert_path")) 14 | self.rank_module = nn.Linear(768 * config.getint("data", "topk"), 1) 15 | 16 | self.criterion = nn.CrossEntropyLoss() 17 | 18 | self.multi = config.getboolean("data", "multi_choice") 19 | self.multi_module = nn.Linear(4, 16) 20 | self.accuracy_function = single_label_top1_accuracy 21 | 22 | def init_multi_gpu(self, device, config, *args, 
**params): 23 | self.bert = nn.DataParallel(self.bert, device_ids=device) 24 | 25 | def forward(self, data, config, gpu_list, acc_result, mode): 26 | text = data["text"] 27 | token = data["token"] 28 | mask = data["mask"] 29 | 30 | batch = text.size()[0] 31 | option = text.size()[1] 32 | k = config.getint("data", "topk") 33 | option = option // k 34 | text = text.view(text.size()[0] * text.size()[1], text.size()[2]) 35 | token = token.view(token.size()[0] * token.size()[1], token.size()[2]) 36 | mask = mask.view(mask.size()[0] * mask.size()[1], mask.size()[2]) 37 | 38 | encode, y = self.bert.forward(text, token, mask, output_all_encoded_layers=False) 39 | 40 | y = y.view(batch * option, -1) 41 | y = self.rank_module(y) 42 | 43 | y = y.view(batch, option) 44 | 45 | if self.multi: 46 | y = self.multi_module(y) 47 | 48 | label = data["label"] 49 | loss = self.criterion(y, label) 50 | acc_result = self.accuracy_function(y, label, config, acc_result) 51 | return {"loss": loss, "acc_result": acc_result} 52 | -------------------------------------------------------------------------------- /model/qa/BiDAF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.encoder.LSTMEncoder import LSTMEncoder 6 | from model.layer.Attention import Attention 7 | from tools.accuracy_tool import single_label_top1_accuracy 8 | from model.qa.util import generate_ans 9 | 10 | 11 | class BiDAFQA(nn.Module): 12 | def __init__(self, config, gpu_list, *args, **params): 13 | super(BiDAFQA, self).__init__() 14 | 15 | self.hidden_size = config.getint("model", "hidden_size") 16 | self.word_num = 0 17 | f = open(config.get("data", "word2id"), "r", encoding="utf8") 18 | for line in f: 19 | self.word_num += 1 20 | 21 | self.embedding = nn.Embedding(self.word_num, self.hidden_size) 22 | self.context_encoder = LSTMEncoder(config, gpu_list, *args, **params) 23 | self.question_encoder = LSTMEncoder(config, gpu_list, *args, **params) 24 | self.attention = Attention(config, gpu_list, *args, **params) 25 | 26 | self.rank_module = nn.Linear(self.hidden_size * 2 * config.getint("data", "topk"), 1) 27 | 28 | self.criterion = nn.CrossEntropyLoss() 29 | 30 | self.multi = config.getboolean("data", "multi_choice") 31 | self.multi_module = nn.Linear(4, 16) 32 | self.accuracy_function = single_label_top1_accuracy 33 | 34 | def init_multi_gpu(self, device, config, *args, **params): 35 | pass 36 | # self.bert = nn.DataParallel(self.bert, device_ids=device) 37 | 38 | def forward(self, data, config, gpu_list, acc_result, mode): 39 | context = data["context"] 40 | question = data["question"] 41 | 42 | batch = question.size()[0] 43 | option = question.size()[1] 44 | k = config.getint("data", "topk") 45 | 46 | context = context.view(batch * option * k, -1) 47 | question = question.view(batch, option, 1, -1).repeat(1, 1, k, 1) 48 | question = question.view(batch * option * k, -1) 49 | context = self.embedding(context) 50 | question = self.embedding(question) 51 | 52 | _, context = self.context_encoder(context) 53 | _, question = self.question_encoder(question) 54 | 55 | c, q, a = self.attention(context, question) 56 | 57 | y = torch.cat([torch.max(c, dim=1)[0], torch.max(q, dim=1)[0]], dim=1) 58 | 59 | y = y.view(batch * option, -1) 60 | y = self.rank_module(y) 61 | 62 | y = y.view(batch, option) 63 | 64 | if self.multi: 65 | y = self.multi_module(y) 66 | 67 | if mode != "test": 68 | label = data["label"] 69 | loss = 
self.criterion(y, label) 70 | acc_result = self.accuracy_function(y, label, config, acc_result) 71 | return {"loss": loss, "acc_result": acc_result} 72 | 73 | return {"output": generate_ans(data["id"], y)} 74 | -------------------------------------------------------------------------------- /model/qa/CoMatch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright 2015 Singapore Management University (SMU). All Rights Reserved. 3 | Permission to use, copy, modify and distribute this software and its documentation for purposes of research, teaching and general academic pursuits, without fee and without a signed licensing agreement, is hereby granted, provided that the above copyright statement, this paragraph and the following paragraph on disclaimer appear in all copies, modifications, and distributions. Contact Singapore Management University, Intellectual Property Management Office at iie@smu.edu.sg, for commercial licensing opportunities. 4 | This software is provided by the copyright holder and creator "as is" and any express or implied warranties, including, but not Limited to, the implied warranties of merchantability and fitness for a particular purpose are disclaimed. In no event shall SMU or the creator be liable for any direct, indirect, incidental, special, exemplary or consequential damages, however caused arising in any way out of the use of this software. 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | import numpy as np 11 | 12 | import json 13 | from tools.accuracy_tool import single_label_top1_accuracy 14 | 15 | 16 | def masked_softmax(vector, seq_lens): 17 | mask = vector.new(vector.size()).zero_() 18 | for i in range(seq_lens.size(0)): 19 | mask[i, :, :seq_lens[i]] = 1 20 | mask = Variable(mask, requires_grad=False) 21 | 22 | if mask is None: 23 | result = torch.nn.functional.softmax(vector, dim=-1) 24 | else: 25 | result = torch.nn.functional.softmax(vector * mask, dim=-1) 26 | result = result * mask 27 | result = result / (result.sum(dim=-1, keepdim=True) + 1e-13) 28 | return result 29 | 30 | 31 | class MatchNet(nn.Module): 32 | def __init__(self, mem_dim, dropoutP): 33 | super(MatchNet, self).__init__() 34 | self.map_linear = nn.Linear(2 * mem_dim, 2 * mem_dim) 35 | self.trans_linear = nn.Linear(mem_dim, mem_dim) 36 | self.drop_module = nn.Dropout(dropoutP) 37 | 38 | def forward(self, inputs): 39 | proj_p, proj_q, seq_len = inputs 40 | trans_q = self.trans_linear(proj_q) 41 | att_weights = proj_p.bmm(torch.transpose(proj_q, 1, 2)) 42 | att_norm = masked_softmax(att_weights, seq_len) 43 | 44 | att_vec = att_norm.bmm(proj_q) 45 | elem_min = att_vec - proj_p 46 | elem_mul = att_vec * proj_p 47 | all_con = torch.cat([elem_min, elem_mul], 2) 48 | output = nn.ReLU()(self.map_linear(all_con)) 49 | return output 50 | 51 | 52 | class MaskLSTM(nn.Module): 53 | def __init__(self, in_dim, out_dim, layers=1, batch_first=True, bidirectional=True, dropoutP=0.3): 54 | super(MaskLSTM, self).__init__() 55 | self.lstm_module = nn.LSTM(in_dim, out_dim, layers, batch_first=batch_first, bidirectional=bidirectional, 56 | dropout=dropoutP) 57 | self.drop_module = nn.Dropout(dropoutP) 58 | 59 | def forward(self, inputs): 60 | input, seq_lens = inputs 61 | mask_in = input.new(input.size()).zero_() 62 | for i in range(seq_lens.size(0)): 63 | mask_in[i, :seq_lens[i]] = 1 64 | mask_in = Variable(mask_in, requires_grad=False) 65 | 66 | input_drop = 
self.drop_module(input * mask_in) 67 | 68 | H, _ = self.lstm_module(input_drop) 69 | 70 | mask = H.new(H.size()).zero_() 71 | for i in range(seq_lens.size(0)): 72 | mask[i, :seq_lens[i]] = 1 73 | mask = Variable(mask, requires_grad=False) 74 | 75 | output = H * mask 76 | 77 | return output 78 | 79 | 80 | class CoMatch(nn.Module): 81 | def __init__(self, config): 82 | super(CoMatch, self).__init__() 83 | self.emb_dim = config.getint("model", "hidden_size") # 300 84 | self.mem_dim = config.getint("model", "hidden_size") # 150 85 | self.dropoutP = config.getfloat("model", "dropout") # args.dropoutP 0.2 86 | # self.cuda_bool = args.cuda 87 | 88 | self.word_num = len(json.load(open(config.get("data", "word2id"), "r"))) 89 | 90 | self.embs = nn.Embedding(self.word_num, self.emb_dim) 91 | 92 | self.encoder = MaskLSTM(self.emb_dim, self.mem_dim, dropoutP=self.dropoutP) 93 | self.l_encoder = MaskLSTM(self.mem_dim * 8, self.mem_dim, dropoutP=self.dropoutP) 94 | self.h_encoder = MaskLSTM(self.mem_dim * 2, self.mem_dim, dropoutP=0) 95 | 96 | self.match_module = MatchNet(self.mem_dim * 2, self.dropoutP) 97 | 98 | self.rank_module = nn.Linear(self.mem_dim * 2 * config.getint("data", "topk"), 1) 99 | 100 | self.multi = config.getboolean("data", "multi_choice") 101 | self.multi_module = nn.Linear(4, 16) 102 | 103 | self.drop_module = nn.Dropout(self.dropoutP) 104 | 105 | def init_multi_gpu(self, device, config, *args, **params): 106 | self.embs = nn.DataParallel(self.embs) 107 | self.encoder = nn.DataParallel(self.encoder) 108 | self.l_encoder = nn.DataParallel(self.l_encoder) 109 | self.h_encoder = nn.DataParallel(self.h_encoder) 110 | self.match_module = nn.DataParallel(self.match_module) 111 | self.rank_module = nn.DataParallel(self.rank_module) 112 | 113 | def forward(self, inputs): 114 | documents, questions, options = inputs 115 | d_word, d_h_len, d_l_len = documents 116 | o_word, o_h_len, o_l_len = options 117 | q_word, q_len = questions 118 | # print("d_word", d_word.size()) 119 | # print("d_h_len", d_h_len.size()) 120 | # print("d_l_len", d_l_len.size()) 121 | # print("o_word", o_word.size()) 122 | # print("o_h_len", o_h_len.size()) 123 | # print("o_l_len", o_l_len.size()) 124 | # print("q_word", q_word.size()) 125 | # print("q_len", q_len.size()) 126 | 127 | batch = d_word.size()[0] 128 | option = d_word.size()[1] 129 | k = d_word.size()[2] 130 | 131 | d_embs = self.drop_module(self.embs(d_word)) 132 | # d_embs = torch.zeros(d_embs.shape).cuda()  # debugging leftover; if enabled it would zero out every document embedding 133 | o_embs = self.drop_module(self.embs(o_word)) 134 | q_embs = self.drop_module(self.embs(q_word)) 135 | # print("d_embs", d_embs.size()) 136 | # print("o_embs", o_embs.size()) 137 | # print("q_embs", q_embs.size()) 138 | 139 | d_hidden = self.encoder( 140 | [d_embs.view(d_embs.size(0) * d_embs.size(1) * d_embs.size(2) * d_embs.size(3), d_embs.size(4), 141 | self.emb_dim), 142 | d_l_len.view(-1)]) 143 | o_hidden = self.encoder( 144 | [o_embs.view(o_embs.size(0) * o_embs.size(1), o_embs.size(2), self.emb_dim), o_l_len.view(-1)]) 145 | q_hidden = self.encoder([q_embs, q_len]) 146 | 147 | # print("d_hidden", d_hidden.size()) 148 | # print("o_hidden", o_hidden.size()) 149 | # print("q_hidden", q_hidden.size()) 150 | 151 | # d_hidden_3d = d_hidden.view(d_embs.size(0), d_embs.size(1) * d_embs.size(2), d_hidden.size(-1)) 152 | # d_hidden_3d_repeat = d_hidden_3d.repeat(1, o_embs.size(1), 1).view(d_hidden_3d.size(0) * o_embs.size(1), 153 | # d_hidden_3d.size(1), d_hidden_3d.size(2)) 154 | d_hidden_3d_repeat = d_hidden.view(d_word.size()[0] * d_word.size()[2] * 
o_embs.size(1), -1, 155 | d_hidden.size()[-1]) 156 | 157 | # print("d_hidden_3d", d_hidden_3d.size()) 158 | # print("d_hidden_3d_repeat", d_hidden_3d_repeat.size()) 159 | 160 | q_hidden_repeat = q_hidden.repeat(1, o_embs.size(1) * d_word.size()[2], 1).view( 161 | q_hidden.size()[0] * o_embs.size(1) * d_word.size()[2], q_hidden.size()[1], q_hidden.size()[2]) 162 | q_len_repeat = q_len.repeat(o_embs.size(1) * d_word.size()[2]) 163 | # print("q_hidden_repeat", q_hidden_repeat.size()) 164 | # print("q_len_repeat", q_len_repeat.size()) 165 | 166 | o_hidden_repeat = o_hidden.repeat(1, 1, d_word.size()[2], ).view(o_hidden.size()[0] * d_word.size()[2], 167 | o_hidden.size()[1], o_hidden.size()[2]) 168 | o_l_len_repeat = o_l_len.repeat(1, d_word.size()[2]).view(o_l_len.size()[0] * d_word.size()[2], 169 | o_l_len.size()[1]) 170 | # print("o_hidden_repeat", o_hidden_repeat.size()) 171 | # print("o_l_len_repeat", o_l_len_repeat.size()) 172 | 173 | do_match = self.match_module([d_hidden_3d_repeat, o_hidden_repeat, o_l_len_repeat.view(-1)]) 174 | dq_match = self.match_module([d_hidden_3d_repeat, q_hidden_repeat, q_len_repeat]) 175 | 176 | # print("do_match", do_match.size()) 177 | # print("dq_match", dq_match.size()) 178 | 179 | dq_match_repeat = dq_match 180 | # print("dq_match_repeat", dq_match_repeat.size()) 181 | 182 | co_match = torch.cat([do_match, dq_match_repeat], -1) 183 | 184 | # print("co_match", co_match.size()) 185 | 186 | co_match_hier = co_match.view(d_embs.size(0) * o_embs.size(1) * d_embs.size(2) * d_embs.size(3), d_embs.size(4), 187 | -1) 188 | # print("co_match_hier", co_match_hier.size()) 189 | 190 | l_hidden = self.l_encoder([co_match_hier, d_l_len.view(-1)]) 191 | # print("l_hidden", l_hidden.size()) 192 | l_hidden_pool, _ = l_hidden.max(1) 193 | # print("l_hidden_pool", l_hidden_pool.size()) 194 | 195 | h_hidden = self.h_encoder( 196 | [l_hidden_pool.view(d_embs.size(0) * o_embs.size(1) * d_embs.size(2), d_embs.size(3), -1), 197 | d_h_len.view(-1, 1).view(-1)]) 198 | # print("h_hidden", h_hidden.size()) 199 | h_hidden_pool, _ = h_hidden.max(1) 200 | # print("h_hidden_pool", h_hidden_pool.size()) 201 | 202 | # o_rep = h_hidden_pool.view(d_embs.size(0), o_embs.size(1), -1) 203 | # print("o_rep", o_rep.size()) 204 | # output = self.rank_module(o_rep).squeeze(2) 205 | 206 | o_rep = h_hidden_pool.view(d_embs.size(0), o_embs.size(1), -1) 207 | output = self.rank_module(o_rep).squeeze(2) 208 | 209 | if self.multi: 210 | output = self.multi_module(output) 211 | 212 | return output 213 | 214 | 215 | class CoMatching(nn.Module): 216 | def __init__(self, config, gpu_list, *args, **params): 217 | super(CoMatching, self).__init__() 218 | 219 | self.co_match = CoMatch(config) 220 | 221 | self.criterion = nn.CrossEntropyLoss() 222 | self.accuracy_function = single_label_top1_accuracy 223 | 224 | def init_multi_gpu(self, device, config, *args, **params): 225 | self.co_match.init_multi_gpu(device, config, *args, **params) 226 | 227 | def forward(self, data, config, gpu_list, acc_result, mode): 228 | q, ql = data["question"], data["question_len"] 229 | o, oh, ol = data["option"], data["option_sent"], data["option_len"] 230 | d, dh, dl = data["document"], data["document_sent"], data["document_len"] 231 | label = data["label"] 232 | 233 | x = [[d, dh, dl], [q, ql], [o, oh, ol]] 234 | y = self.co_match(x) 235 | 236 | loss = self.criterion(y, label) 237 | acc_result = self.accuracy_function(y, label, config, acc_result) 238 | return {"loss": loss, "acc_result": acc_result} 239 | 
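# A minimal standalone sketch (an editor's illustration, not code from the repository)
# of the answer encoding that CoMatching and the other QA models rely on: the
# formatters pack a multi-choice answer set into a 4-bit label (A=1, B=2, C=4, D=8),
# and model/qa/util.generate_ans inverts it. The helper names below are illustrative.

def encode_answer(answer):
    # ["A", "C"] -> 5, mirroring the label_x accumulation in the formatters
    label = 0
    for bit, option in [(1, "A"), (2, "B"), (4, "C"), (8, "D")]:
        if option in answer:
            label += bit
    return label

def decode_answer(label):
    # 5 -> ["A", "C"], mirroring the greedy bit decomposition in generate_ans
    # (which emits letters in D-to-A order; sorted here for readability)
    answer = []
    for bit, option in [(8, "D"), (4, "C"), (2, "B"), (1, "A")]:
        if label >= bit:
            answer.append(option)
            label -= bit
    return sorted(answer)

assert encode_answer(["A", "C"]) == 5
assert decode_answer(5) == ["A", "C"]
assert all(decode_answer(encode_answer(s)) == s for s in (["A"], ["B", "D"], ["A", "B", "C", "D"]))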
-------------------------------------------------------------------------------- /model/qa/HAF.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from model.encoder.GRUEncoder import GRUEncoder 6 | # from model.layer.Attention import Attention 7 | from tools.accuracy_tool import single_label_top1_accuracy 8 | from model.qa.util import generate_ans 9 | 10 | """class BiAttention(nn.Module): 11 | def __init__(self, config, gpu_list, *args, **params): 12 | super(BiAttention, self).__init__() 13 | 14 | self.attention = Attention(config, gpu_list, *args, **params) 15 | 16 | def init_multi_gpu(self, device, config, *args, **params): 17 | pass 18 | 19 | def forward(self, x1, x2): 20 | c, q, a = self.attention(x1, x2) 21 | 22 | y = torch.cat([torch.max(c, dim=1)[0], torch.max(q, dim=1)[0]], dim=1) 23 | 24 | return y 25 | """ 26 | 27 | 28 | class Attention(nn.Module): 29 | def __init__(self, config, gpu_list, *args, **params): 30 | super(Attention, self).__init__() 31 | 32 | self.hidden_size = config.getint("model", "hidden_size") 33 | self.fc = nn.Linear(self.hidden_size, self.hidden_size) 34 | 35 | def forward(self, x, y): 36 | x_ = self.fc(x) 37 | y_ = torch.transpose(y, 1, 2) 38 | a_ = torch.bmm(x_, y_) 39 | 40 | s = torch.softmax(a_, dim=2) 41 | a = torch.mean(s, dim=1) 42 | 43 | return a 44 | 45 | 46 | class HAF(nn.Module): 47 | def __init__(self, config, gpu_list, *args, **params): 48 | super(HAF, self).__init__() 49 | 50 | self.hidden_size = config.getint("model", "hidden_size") 51 | self.word_num = 0 52 | f = open(config.get("data", "word2id"), "r", encoding="utf8") 53 | for line in f: 54 | self.word_num += 1 55 | 56 | self.embedding = nn.Embedding(self.word_num, self.hidden_size) 57 | 58 | self.question_encoder = GRUEncoder(config, gpu_list, *args, **params) 59 | self.passage_encoder = GRUEncoder(config, gpu_list, *args, **params) 60 | self.option_encoder = GRUEncoder(config, gpu_list, *args, **params) 61 | self.s = GRUEncoder(config, gpu_list, *args, **params) 62 | 63 | self.q2p = Attention(config, gpu_list, *args, **params) 64 | self.q2o = Attention(config, gpu_list, *args, **params) 65 | self.o2p = Attention(config, gpu_list, *args, **params) 66 | self.oc = Attention(config, gpu_list, *args, **params) 67 | 68 | self.wp = nn.Linear(self.hidden_size, self.hidden_size * 2) 69 | self.score = nn.Linear( 70 | config.getint("data", "topk") * config.getint("data", "option_len") * config.getint("data", "passage_len"), 71 | 1) 72 | 73 | self.criterion = nn.CrossEntropyLoss() 74 | 75 | self.multi = config.getboolean("data", "multi_choice") 76 | self.multi_module = nn.Linear(4, 16) 77 | self.accuracy_function = single_label_top1_accuracy 78 | 79 | def init_multi_gpu(self, device, config, *args, **params): 80 | pass 81 | # self.bert = nn.DataParallel(self.bert, device_ids=device) 82 | 83 | def forward(self, data, config, gpu_list, acc_result, mode): 84 | passage = data["passage"] 85 | question = data["question"] 86 | option = data["option"] 87 | 88 | batch = question.size()[0] 89 | option_num = option.size()[1] 90 | k = config.getint("data", "topk") 91 | 92 | passage = passage.view(batch * option_num * k, -1) 93 | question = question.view(batch, -1) 94 | option = option.view(batch * option_num, -1) 95 | # print(passage.size(), question.size(), option.size()) 96 | 97 | passage = self.embedding(passage) 98 | question = self.embedding(question) 99 | option = self.embedding(option) 
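# Overview of the remainder of forward(): the three GRU encoders contextualize
# passage, question and option; q2p / q2o produce question-conditioned weights
# over passage and option tokens (vp, vo); the weighted passage is re-encoded
# by self.s into vpp and then weighted from the option side (rp); options are
# self-attended (vop) and contrasted as ro = [vo, vo - vop]; finally a bilinear
# interaction between rp and ro (self.wp followed by bmm) is flattened and
# mapped by self.score to one logit per option.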
100 | # print(passage.size(), question.size(), option.size()) 101 | 102 | _, passage = self.passage_encoder(passage) 103 | _, question = self.question_encoder(question) 104 | _, option = self.option_encoder(option) 105 | # print(passage.size(), question.size(), option.size()) 106 | 107 | passage = passage.view(batch * option_num * k, -1, self.hidden_size) 108 | question = question.view(batch, 1, 1, -1, self.hidden_size).repeat(1, option_num, k, 1, 1).view( 109 | batch * option_num * k, -1, self.hidden_size) 110 | option = option.view(batch, option_num, 1, -1, self.hidden_size).repeat(1, 1, k, 1, 1).view( 111 | batch * option_num * k, -1, self.hidden_size) 112 | # print(passage.size(), question.size(), option.size()) 113 | 114 | vp = self.q2p(question, passage).view(batch * option_num * k, -1, 1) 115 | # print("vp", vp.size()) 116 | vp = vp * passage 117 | # print("vp", vp.size()) 118 | vo = self.q2o(question, option).view(batch * option_num * k, -1, 1) 119 | # print("vo", vo.size()) 120 | vo = vo * option 121 | # print("vo", vo.size()) 122 | _, vpp = self.s(vp) 123 | # print("vpp", vpp.size()) 124 | rp = self.o2p(vo, vpp).view(batch * option_num * k, -1, 1) 125 | # print("rp", rp.size()) 126 | rp = rp * vpp 127 | # print("rp", rp.size()) 128 | vop = self.oc(vo, vo).view(batch * option_num * k, -1, 1) 129 | # print("vop", vop.size()) 130 | vop = vop * vo 131 | # print("vop", vop.size()) 132 | ro = torch.cat([vo, vo - vop], dim=2) 133 | # print("ro", ro.size()) 134 | 135 | s = self.wp(rp) 136 | ro = torch.transpose(ro, 2, 1) 137 | s = torch.bmm(s, ro) 138 | # print(s.size()) 139 | s = s.view(batch * option_num, -1) 140 | s = self.score(s) 141 | y = s.view(batch, option_num) 142 | # print(y.size()) 143 | 144 | 145 | if self.multi: 146 | y = self.multi_module(y) 147 | 148 | if mode != "test": 149 | label = data["label"] 150 | loss = self.criterion(y, label) 151 | acc_result = self.accuracy_function(y, label, config, acc_result) 152 | return {"loss": loss, "acc_result": acc_result} 153 | 154 | return {"output": generate_ans(data["id"], y)} 155 | -------------------------------------------------------------------------------- /model/qa/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def generate_ans(id_list, ans_list): 5 | result = [] 6 | for a in range(0, len(id_list)): 7 | idx = id_list[a] 8 | ans = ans_list[a] 9 | if len(ans) == 4: 10 | ans = [["A", "B", "C", "D"][int(torch.max(ans, dim=0)[1])]] 11 | else: 12 | ans_ = int(torch.max(ans, dim=0)[1]) 13 | ans = [] 14 | for x, y in [(8, "D"), (4, "C"), (2, "B"), (1, "A")]: 15 | if ans_ >= x: 16 | ans.append(y) 17 | ans_ -= x 18 | result.append({"id": idx, "answer": ans}) 19 | 20 | return result 21 | -------------------------------------------------------------------------------- /reader/__init__.py: -------------------------------------------------------------------------------- 1 | from .reader import init_dataset, init_test_dataset 2 | -------------------------------------------------------------------------------- /reader/reader.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | import logging 3 | 4 | import formatter as form 5 | from dataset import dataset_list 6 | 7 | logger = logging.getLogger(__name__) 8 | 9 | collate_fn = {} 10 | formatter = {} 11 | 12 | 13 | def init_formatter(config, task_list, *args, **params): 14 | for task in task_list: 15 | formatter[task] = 
form.init_formatter(config, task, *args, **params) 16 | 17 | def train_collate_fn(data): 18 | return formatter["train"].process(data, config, "train") 19 | 20 | def valid_collate_fn(data): 21 | return formatter["valid"].process(data, config, "valid") 22 | 23 | def test_collate_fn(data): 24 | return formatter["test"].process(data, config, "test") 25 | 26 | if task == "train": 27 | collate_fn[task] = train_collate_fn 28 | elif task == "valid": 29 | collate_fn[task] = valid_collate_fn 30 | else: 31 | collate_fn[task] = test_collate_fn 32 | 33 | 34 | def init_one_dataset(config, mode, *args, **params): 35 | temp_mode = mode 36 | if mode != "train": 37 | try: 38 | config.get("data", "%s_dataset_type" % temp_mode) 39 | except Exception as e: 40 | logger.warning( 41 | "[reader] %s_dataset_type has not been defined in config file, use [data] train_dataset_type instead." % temp_mode) 42 | temp_mode = "train" 43 | which = config.get("data", "%s_dataset_type" % temp_mode) 44 | 45 | if which in dataset_list: 46 | dataset = dataset_list[which](config, mode, *args, **params) 47 | batch_size = config.getint("train", "batch_size") 48 | shuffle = config.getboolean("train", "shuffle") 49 | reader_num = config.getint("train", "reader_num") 50 | drop_last = True 51 | if mode in ["valid", "test"]: 52 | if mode == "test": 53 | drop_last = False 54 | 55 | try: 56 | batch_size = config.getint("eval", "batch_size") 57 | except Exception as e: 58 | logger.warning("[eval] batch size has not been defined in config file, use [train] batch_size instead.") 59 | 60 | try: 61 | shuffle = config.getboolean("eval", "shuffle") 62 | except Exception as e: 63 | shuffle = False 64 | logger.warning("[eval] shuffle has not been defined in config file, use false as default.") 65 | try: 66 | reader_num = config.getint("eval", "reader_num") 67 | except Exception as e: 68 | logger.warning("[eval] reader num has not been defined in config file, use [train] reader num instead.") 69 | 70 | dataloader = DataLoader(dataset=dataset, 71 | batch_size=batch_size, 72 | shuffle=shuffle, 73 | num_workers=reader_num, 74 | collate_fn=collate_fn[mode], 75 | drop_last=drop_last) 76 | 77 | return dataloader 78 | else: 79 | logger.error("There is no dataset called %s, check your config." 
% which) 80 | raise NotImplementedError 81 | 82 | 83 | def init_test_dataset(config, *args, **params): 84 | init_formatter(config, ["test"], *args, **params) 85 | test_dataset = init_one_dataset(config, "test", *args, **params) 86 | 87 | return test_dataset 88 | 89 | 90 | def init_dataset(config, *args, **params): 91 | init_formatter(config, ["train", "valid"], *args, **params) 92 | train_dataset = init_one_dataset(config, "train", *args, **params) 93 | valid_dataset = init_one_dataset(config, "valid", *args, **params) 94 | 95 | return train_dataset, valid_dataset 96 | 97 | 98 | if __name__ == "__main__": 99 | pass 100 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import logging 5 | import json 6 | 7 | from tools.init_tool import init_all 8 | from config_parser import create_config 9 | from tools.test_tool import test 10 | 11 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 12 | datefmt='%m/%d/%Y %H:%M:%S', 13 | level=logging.INFO) 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--config', '-c', help="specific config file", required=True) 20 | parser.add_argument('--gpu', '-g', help="gpu id list") 21 | parser.add_argument('--checkpoint', help="checkpoint file path", required=True) 22 | parser.add_argument('--result', help="result file path", required=True) 23 | args = parser.parse_args() 24 | 25 | configFilePath = args.config 26 | 27 | use_gpu = True 28 | gpu_list = [] 29 | if args.gpu is None: 30 | use_gpu = False 31 | else: 32 | use_gpu = True 33 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 34 | 35 | device_list = args.gpu.split(",") 36 | for a in range(0, len(device_list)): 37 | gpu_list.append(a)  # after CUDA_VISIBLE_DEVICES is set, the visible devices are renumbered from 0 38 | 39 | os.system("clear") 40 | 41 | config = create_config(configFilePath) 42 | 43 | cuda = torch.cuda.is_available() 44 | logger.info("CUDA available: %s" % str(cuda)) 45 | if not cuda and len(gpu_list) > 0: 46 | logger.error("CUDA is not available but a specific gpu id was given") 47 | raise NotImplementedError 48 | 49 | parameters = init_all(config, gpu_list, args.checkpoint, "test") 50 | 51 | json.dump(test(parameters, config, gpu_list), open(args.result, "w", encoding="utf8"), ensure_ascii=False, 52 | sort_keys=True, indent=2) 53 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/jec-qa/8706c64ff62637e61cd8a729815f585a6df3b3f1/tools/__init__.py -------------------------------------------------------------------------------- /tools/accuracy_init.py: -------------------------------------------------------------------------------- 1 | from .accuracy_tool import single_label_top1_accuracy, single_label_top2_accuracy, multi_label_accuracy, \ 2 | null_accuracy_function, log_distance_accuracy_function 3 | 4 | accuracy_function_dic = { 5 | "SingleLabelTop1": single_label_top1_accuracy, 6 | "MultiLabel": multi_label_accuracy, 7 | "Null": null_accuracy_function, 8 | "LogDis": log_distance_accuracy_function 9 | } 10 | 11 | 12 | def init_accuracy_function(config, *args, **params): 13 | name = config.get("output", "accuracy_method") 14 | if name in accuracy_function_dic: 15 | return accuracy_function_dic[name] 16 | 
else: 17 | raise NotImplementedError 18 | -------------------------------------------------------------------------------- /tools/accuracy_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | logger = logging.getLogger(__name__) 5 | 6 | 7 | def get_prf(res): 8 | # According to https://github.com/dice-group/gerbil/wiki/Precision,-Recall-and-F1-measure 9 | if res["TP"] == 0: 10 | if res["FP"] == 0 and res["FN"] == 0: 11 | precision = 1.0 12 | recall = 1.0 13 | f1 = 1.0 14 | else: 15 | precision = 0.0 16 | recall = 0.0 17 | f1 = 0.0 18 | else: 19 | precision = 1.0 * res["TP"] / (res["TP"] + res["FP"]) 20 | recall = 1.0 * res["TP"] / (res["TP"] + res["FN"]) 21 | f1 = 2 * precision * recall / (precision + recall) 22 | 23 | return precision, recall, f1 24 | 25 | 26 | def gen_micro_macro_result(res): 27 | precision = [] 28 | recall = [] 29 | f1 = [] 30 | total = {"TP": 0, "FP": 0, "FN": 0, "TN": 0} 31 | for a in range(0, len(res)): 32 | total["TP"] += res[a]["TP"] 33 | total["FP"] += res[a]["FP"] 34 | total["FN"] += res[a]["FN"] 35 | total["TN"] += res[a]["TN"] 36 | 37 | p, r, f = get_prf(res[a]) 38 | precision.append(p) 39 | recall.append(r) 40 | f1.append(f) 41 | 42 | micro_precision, micro_recall, micro_f1 = get_prf(total) 43 | 44 | macro_precision = 0 45 | macro_recall = 0 46 | macro_f1 = 0 47 | for a in range(0, len(f1)): 48 | macro_precision += precision[a] 49 | macro_recall += recall[a] 50 | macro_f1 += f1[a] 51 | 52 | macro_precision /= len(f1) 53 | macro_recall /= len(f1) 54 | macro_f1 /= len(f1) 55 | 56 | return { 57 | "mip": round(micro_precision, 4), 58 | "mir": round(micro_recall, 4), 59 | "mif": round(micro_f1, 4), 60 | "map": round(macro_precision, 4), 61 | "mar": round(macro_recall, 4), 62 | "maf": round(macro_f1, 4) 63 | } 64 | 65 | 66 | def null_accuracy_function(outputs, label, config, result=None): 67 | return None 68 | 69 | 70 | def log_distance_accuracy_function(outputs, label, config, result=None): 71 | if result is None: 72 | result = [0, 0] 73 | 74 | result[0] += outputs.size()[0] 75 | result[1] += float(torch.sum(torch.log(torch.abs(torch.clamp(outputs, 0, 450) - torch.clamp(label, 0, 450)) + 1))) 76 | 77 | return result 78 | 79 | 80 | def single_label_top1_accuracy(outputs, label, config, result=None): 81 | if result is None: 82 | result = [] 83 | id1 = torch.max(outputs, dim=1)[1] 84 | # id2 = torch.max(label, dim=1)[1] 85 | id2 = label 86 | nr_classes = outputs.size(1) 87 | while len(result) < nr_classes: 88 | result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 89 | for a in range(0, len(id1)): 90 | # if len(result) < a: 91 | # result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 92 | 93 | it_is = int(id1[a]) 94 | should_be = int(id2[a]) 95 | if it_is == should_be: 96 | result[it_is]["TP"] += 1 97 | else: 98 | result[it_is]["FP"] += 1 99 | result[should_be]["FN"] += 1 100 | 101 | return result 102 | 103 | 104 | def multi_label_accuracy(outputs, label, config, result=None): 105 | if len(label[0]) != len(outputs[0]): 106 | raise ValueError('Input dimensions of labels and outputs must match.') 107 | 108 | if len(outputs.size()) > 2: 109 | outputs = outputs.view(outputs.size()[0], -1, 2) 110 | outputs = torch.nn.Softmax(dim=2)(outputs) 111 | outputs = outputs[:, :, 1] 112 | 113 | outputs = outputs.data 114 | labels = label.data 115 | 116 | if result is None: 117 | result = [] 118 | 119 | total = 0 120 | nr_classes = outputs.size(1) 121 | 122 | while len(result) < nr_classes: 123 | 
result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 124 | 125 | for i in range(nr_classes): 126 | outputs1 = (outputs[:, i] >= 0.5).long() 127 | labels1 = (labels[:, i].float() >= 0.5).long() 128 | total += int((labels1 * outputs1).sum()) 129 | total += int(((1 - labels1) * (1 - outputs1)).sum()) 130 | 131 | if result is None: 132 | continue 133 | 134 | # if len(result) < i: 135 | # result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 136 | 137 | result[i]["TP"] += int((labels1 * outputs1).sum()) 138 | result[i]["FN"] += int((labels1 * (1 - outputs1)).sum()) 139 | result[i]["FP"] += int(((1 - labels1) * outputs1).sum()) 140 | result[i]["TN"] += int(((1 - labels1) * (1 - outputs1)).sum()) 141 | 142 | return result 143 | 144 | 145 | def single_label_top2_accuracy(outputs, label, config, result=None): 146 | raise NotImplementedError 147 | # still bug here 148 | 149 | if result is None: 150 | result = [] 151 | # print(label) 152 | 153 | id1 = torch.max(outputs, dim=1)[1] 154 | # id2 = torch.max(label, dim=1)[1] 155 | id2 = label 156 | nr_classes = outputs.size(1) 157 | while len(result) < nr_classes: 158 | result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 159 | for a in range(0, len(id1)): 160 | # if len(result) < a: 161 | # result.append({"TP": 0, "FN": 0, "FP": 0, "TN": 0}) 162 | 163 | it_is = int(id1[a]) 164 | should_be = int(id2[a]) 165 | if it_is == should_be: 166 | result[it_is]["TP"] += 1 167 | else: 168 | result[it_is]["FP"] += 1 169 | result[should_be]["FN"] += 1 170 | 171 | _, prediction = torch.topk(outputs, 2, 1, largest=True) 172 | prediction1 = prediction[:, 0:1] 173 | prediction2 = prediction[:, 1:] 174 | 175 | prediction1 = prediction1.view(-1) 176 | prediction2 = prediction2.view(-1) 177 | 178 | return result 179 | -------------------------------------------------------------------------------- /tools/dataset_tool.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def dfs_search(path, recursive): 5 | if os.path.isfile(path): 6 | return [path] 7 | file_list = [] 8 | name_list = os.listdir(path) 9 | name_list.sort() 10 | for filename in name_list: 11 | real_path = os.path.join(path, filename) 12 | 13 | if os.path.isdir(real_path): 14 | if recursive: 15 | file_list = file_list + dfs_search(real_path, recursive) 16 | else: 17 | file_list.append(real_path) 18 | 19 | return file_list 20 | -------------------------------------------------------------------------------- /tools/eval_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.optim import lr_scheduler 6 | from tensorboardX import SummaryWriter 7 | from timeit import default_timer as timer 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def gen_time_str(t): 13 | t = int(t) 14 | minute = t // 60 15 | second = t % 60 16 | return '%2d:%02d' % (minute, second) 17 | 18 | 19 | def output_value(epoch, mode, step, time, loss, info, end, config): 20 | try: 21 | delimiter = config.get("output", "delimiter") 22 | except Exception as e: 23 | delimiter = " " 24 | s = "" 25 | s = s + str(epoch) + " " 26 | while len(s) < 7: 27 | s += " " 28 | s = s + str(mode) + " " 29 | while len(s) < 14: 30 | s += " " 31 | s = s + str(step) + " " 32 | while len(s) < 25: 33 | s += " " 34 | s += str(time) 35 | while len(s) < 40: 36 | s += " " 37 | s += str(loss) 38 | while len(s) < 48: 39 | s += " " 40 | s += str(info) 41 | s = s.replace(" ", 
delimiter) 42 | if not (end is None): 43 | print(s, end=end) 44 | else: 45 | print(s) 46 | 47 | 48 | def valid(model, dataset, epoch, writer, config, gpu_list, output_function, mode="valid"): 49 | model.eval() 50 | 51 | acc_result = None 52 | total_loss = 0 53 | cnt = 0 54 | total_len = len(dataset) 55 | start_time = timer() 56 | output_info = "" 57 | 58 | output_time = config.getint("output", "output_time") 59 | step = -1 60 | more = "" 61 | if total_len < 10000: 62 | more = "\t" 63 | 64 | for step, data in enumerate(dataset): 65 | for key in data.keys(): 66 | if isinstance(data[key], torch.Tensor): 67 | if len(gpu_list) > 0: 68 | data[key] = Variable(data[key].cuda()) 69 | else: 70 | data[key] = Variable(data[key]) 71 | 72 | results = model(data, config, gpu_list, acc_result, "valid") 73 | 74 | loss, acc_result = results["loss"], results["acc_result"] 75 | total_loss += float(loss) 76 | cnt += 1 77 | 78 | if step % output_time == 0: 79 | delta_t = timer() - start_time 80 | 81 | output_value(epoch, mode, "%d/%d" % (step + 1, total_len), "%s/%s" % ( 82 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 83 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 84 | 85 | if step == -1: 86 | logger.error("There is no data given to the model in this epoch, check your data.") 87 | raise NotImplementedError 88 | 89 | delta_t = timer() - start_time 90 | output_info = output_function(acc_result, config) 91 | output_value(epoch, mode, "%d/%d" % (step + 1, total_len), "%s/%s" % ( 92 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 93 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 94 | 95 | writer.add_scalar(config.get("output", "model_name") + "_eval_epoch", float(total_loss) / (step + 1), 96 | epoch) 97 | 98 | model.train() 99 | -------------------------------------------------------------------------------- /tools/init_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from reader.reader import init_dataset, init_formatter, init_test_dataset 5 | from model import get_model 6 | from model.optimizer import init_optimizer 7 | from .output_init import init_output_function 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def init_all(config, gpu_list, checkpoint, mode, *args, **params): 13 | result = {} 14 | 15 | logger.info("Begin to initialize dataset and formatter...") 16 | if mode == "train": 17 | init_formatter(config, ["train", "valid"], *args, **params) 18 | result["train_dataset"], result["valid_dataset"] = init_dataset(config, *args, **params) 19 | else: 20 | init_formatter(config, ["test"], *args, **params) 21 | result["test_dataset"] = init_test_dataset(config, *args, **params) 22 | 23 | logger.info("Begin to initialize models...") 24 | 25 | model = get_model(config.get("model", "model_name"))(config, gpu_list, *args, **params) 26 | optimizer = init_optimizer(model, config, *args, **params) 27 | trained_epoch = -1 28 | global_step = 0 29 | 30 | if len(gpu_list) > 0: 31 | model = model.cuda() 32 | 33 | try: 34 | model.init_multi_gpu(gpu_list, config, *args, **params) 35 | except Exception as e: 36 | logger.warning("No init_multi_gpu implemented in the model, use single gpu instead.") 37 | 38 | try: 39 | parameters = torch.load(checkpoint) 40 | model.load_state_dict(parameters["model"]) 41 | 42 | if mode == "train": 43 | trained_epoch = parameters["trained_epoch"] 44 | if config.get("train", 
"optimizer") == parameters["optimizer_name"]: 45 | optimizer.load_state_dict(parameters["optimizer"]) 46 | else: 47 | logger.warning("Optimizer changed, do not load parameters of optimizer.") 48 | 49 | if "global_step" in parameters: 50 | global_step = parameters["global_step"] 51 | except Exception as e: 52 | information = "Cannot load checkpoint file with error %s" % str(e) 53 | if mode == "test": 54 | logger.error(information) 55 | raise e 56 | else: 57 | logger.warning(information) 58 | 59 | result["model"] = model 60 | if mode == "train": 61 | result["optimizer"] = optimizer 62 | result["trained_epoch"] = trained_epoch 63 | result["output_function"] = init_output_function(config) 64 | result["global_step"] = global_step 65 | 66 | logger.info("Initialize done.") 67 | 68 | return result 69 | -------------------------------------------------------------------------------- /tools/output_init.py: -------------------------------------------------------------------------------- 1 | from .output_tool import basic_output_function, null_output_function 2 | 3 | output_function_dic = { 4 | "Basic": basic_output_function, 5 | "Null": null_output_function 6 | } 7 | 8 | 9 | def init_output_function(config, *args, **params): 10 | name = config.get("output", "output_function") 11 | 12 | if name in output_function_dic: 13 | return output_function_dic[name] 14 | else: 15 | raise NotImplementedError 16 | -------------------------------------------------------------------------------- /tools/output_tool.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | from .accuracy_tool import gen_micro_macro_result 4 | 5 | 6 | def null_output_function(data, config, *args, **params): 7 | return "" 8 | 9 | 10 | def basic_output_function(data, config, *args, **params): 11 | temp = gen_micro_macro_result(data) 12 | result = [] 13 | for name in ["mip"]: 14 | result.append(round(temp[name] * 100, 2)) 15 | 16 | return json.dumps(result, sort_keys=True) 17 | -------------------------------------------------------------------------------- /tools/test_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from timeit import default_timer as timer 6 | 7 | from tools.eval_tool import gen_time_str, output_value 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | def test(parameters, config, gpu_list): 13 | model = parameters["model"] 14 | dataset = parameters["test_dataset"] 15 | model.eval() 16 | 17 | acc_result = None 18 | total_loss = 0 19 | cnt = 0 20 | total_len = len(dataset) 21 | start_time = timer() 22 | output_info = "testing" 23 | 24 | output_time = config.getint("output", "output_time") 25 | step = -1 26 | result = [] 27 | 28 | for step, data in enumerate(dataset): 29 | for key in data.keys(): 30 | if isinstance(data[key], torch.Tensor): 31 | if len(gpu_list) > 0: 32 | data[key] = Variable(data[key].cuda()) 33 | else: 34 | data[key] = Variable(data[key]) 35 | 36 | results = model(data, config, gpu_list, acc_result, "test") 37 | result = result + results["output"] 38 | cnt += 1 39 | 40 | if step % output_time == 0: 41 | delta_t = timer() - start_time 42 | 43 | output_value(0, "test", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 44 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 45 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 46 | 47 | if step == -1: 48 | 
logger.error("There is no data given to the model in this epoch, check your data.") 49 | raise NotImplementedError 50 | 51 | delta_t = timer() - start_time 52 | output_info = "testing" 53 | output_value(0, "test", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 54 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 55 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 56 | 57 | res = {} 58 | for x in result: 59 | res[x["id"]] = x["answer"] 60 | 61 | return res 62 | -------------------------------------------------------------------------------- /tools/train_tool.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import torch 4 | from torch.autograd import Variable 5 | from torch.optim import lr_scheduler 6 | from tensorboardX import SummaryWriter 7 | import shutil 8 | from timeit import default_timer as timer 9 | 10 | from tools.eval_tool import valid, gen_time_str, output_value 11 | from tools.init_tool import init_test_dataset, init_formatter 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | def checkpoint(filename, model, optimizer, trained_epoch, config, global_step): 17 | model_to_save = model.module if hasattr(model, 'module') else model 18 | save_params = { 19 | "model": model_to_save.state_dict(), 20 | "optimizer_name": config.get("train", "optimizer"), 21 | "optimizer": optimizer.state_dict(), 22 | "trained_epoch": trained_epoch, 23 | "global_step": global_step 24 | } 25 | 26 | try: 27 | torch.save(save_params, filename) 28 | except Exception as e: 29 | logger.warning("Cannot save models with error %s, continue anyway" % str(e)) 30 | 31 | 32 | def train(parameters, config, gpu_list, do_test=False): 33 | epoch = config.getint("train", "epoch") 34 | batch_size = config.getint("train", "batch_size") 35 | 36 | output_time = config.getint("output", "output_time") 37 | test_time = config.getint("output", "test_time") 38 | 39 | output_path = os.path.join(config.get("output", "model_path"), config.get("output", "model_name")) 40 | if os.path.exists(output_path): 41 | logger.warning("Output path exists, check whether need to change a name of model") 42 | os.makedirs(output_path, exist_ok=True) 43 | 44 | trained_epoch = parameters["trained_epoch"] + 1 45 | model = parameters["model"] 46 | optimizer = parameters["optimizer"] 47 | dataset = parameters["train_dataset"] 48 | global_step = parameters["global_step"] 49 | output_function = parameters["output_function"] 50 | 51 | if do_test: 52 | init_formatter(config, ["test"]) 53 | test_dataset = init_test_dataset(config) 54 | 55 | if trained_epoch == 0: 56 | shutil.rmtree( 57 | os.path.join(config.get("output", "tensorboard_path"), config.get("output", "model_name")), True) 58 | 59 | os.makedirs(os.path.join(config.get("output", "tensorboard_path"), config.get("output", "model_name")), 60 | exist_ok=True) 61 | 62 | writer = SummaryWriter(os.path.join(config.get("output", "tensorboard_path"), config.get("output", "model_name")), 63 | config.get("output", "model_name")) 64 | 65 | step_size = config.getint("train", "step_size") 66 | gamma = config.getfloat("train", "lr_multiplier") 67 | exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma) 68 | exp_lr_scheduler.step(trained_epoch) 69 | 70 | logger.info("Training start....") 71 | 72 | print("Epoch Stage Iterations Time Usage Loss Output Information") 73 | 74 | total_len = len(dataset) 75 | more = "" 76 | if total_len < 10000: 77 | more = "\t" 
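# Per-epoch loop below: advance the LR scheduler, run the batch loop
# (zero_grad -> forward -> backward -> optimizer.step), refresh the console
# line every output_time steps, log the per-iteration loss to tensorboard,
# checkpoint after every epoch, and run validation (plus the optional test
# pass) every test_time epochs.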
78 | for epoch_num in range(trained_epoch, epoch): 79 | start_time = timer() 80 | current_epoch = epoch_num 81 | 82 | exp_lr_scheduler.step(current_epoch) 83 | 84 | acc_result = None 85 | total_loss = 0 86 | 87 | output_info = "" 88 | step = -1 89 | for step, data in enumerate(dataset): 90 | for key in data.keys(): 91 | if isinstance(data[key], torch.Tensor): 92 | if len(gpu_list) > 0: 93 | data[key] = Variable(data[key].cuda()) 94 | else: 95 | data[key] = Variable(data[key]) 96 | 97 | optimizer.zero_grad() 98 | 99 | results = model(data, config, gpu_list, acc_result, "train") 100 | 101 | loss, acc_result = results["loss"], results["acc_result"] 102 | total_loss += float(loss) 103 | 104 | loss.backward() 105 | optimizer.step() 106 | 107 | if step % output_time == 0: 108 | output_info = output_function(acc_result, config) 109 | 110 | delta_t = timer() - start_time 111 | 112 | output_value(current_epoch, "train", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 113 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 114 | "%.3lf" % (total_loss / (step + 1)), output_info, '\r', config) 115 | 116 | global_step += 1 117 | writer.add_scalar(config.get("output", "model_name") + "_train_iter", float(loss), global_step) 118 | 119 | if step == -1: 120 | logger.error("There is no data given to the model in this epoch, check your data.") 121 | raise NotImplementedError 122 | 123 | output_value(current_epoch, "train", "%d/%d" % (step + 1, total_len), "%s/%s" % ( 124 | gen_time_str(delta_t), gen_time_str(delta_t * (total_len - step - 1) / (step + 1))), 125 | "%.3lf" % (total_loss / (step + 1)), output_info, None, config) 126 | 127 | checkpoint(os.path.join(output_path, "%d.pkl" % current_epoch), model, optimizer, current_epoch, config, 128 | global_step) 129 | writer.add_scalar(config.get("output", "model_name") + "_train_epoch", float(total_loss) / (step + 1), 130 | current_epoch) 131 | 132 | if current_epoch % test_time == 0: 133 | with torch.no_grad(): 134 | valid(model, parameters["valid_dataset"], current_epoch, writer, config, gpu_list, output_function) 135 | if do_test: 136 | valid(model, test_dataset, current_epoch, writer, config, gpu_list, output_function, mode="test") 137 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import logging 5 | 6 | from tools.init_tool import init_all 7 | from config_parser import create_config 8 | from tools.train_tool import train 9 | 10 | logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 11 | datefmt='%m/%d/%Y %H:%M:%S', 12 | level=logging.INFO) 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | if __name__ == "__main__": 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--config', '-c', help="specific config file", required=True) 19 | parser.add_argument('--gpu', '-g', help="gpu id list") 20 | parser.add_argument('--checkpoint', help="checkpoint file path") 21 | parser.add_argument('--do_test', help="do test while training or not", action="store_true") 22 | args = parser.parse_args() 23 | 24 | configFilePath = args.config 25 | 26 | use_gpu = True 27 | gpu_list = [] 28 | if args.gpu is None: 29 | use_gpu = False 30 | else: 31 | use_gpu = True 32 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 33 | 34 | device_list = args.gpu.split(",") 35 | for a in range(0, len(device_list)): 36 | 
gpu_list.append(a)  # after CUDA_VISIBLE_DEVICES is set, the visible devices are renumbered from 0 37 | 38 | os.system("clear") 39 | 40 | config = create_config(configFilePath) 41 | 42 | cuda = torch.cuda.is_available() 43 | logger.info("CUDA available: %s" % str(cuda)) 44 | if not cuda and len(gpu_list) > 0: 45 | logger.error("CUDA is not available but a specific gpu id was given") 46 | raise NotImplementedError 47 | 48 | parameters = init_all(config, gpu_list, args.checkpoint, "train") 49 | do_test = False 50 | if args.do_test: 51 | do_test = True 52 | 53 | train(parameters, config, gpu_list, do_test) 54 | --------------------------------------------------------------------------------
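A closing usage sketch (an editor's illustration under stated assumptions, not code from the repository): tools/train_tool.checkpoint writes one "<epoch>.pkl" file per epoch under <output.model_path>/<output.model_name>, containing the model and optimizer state dicts plus a few bookkeeping fields, so a checkpoint can be inspected or reloaded on its own. The checkpoint path below is hypothetical.

import torch

ckpt = torch.load("output/HAF/3.pkl", map_location="cpu")  # hypothetical path

# bookkeeping fields written by train_tool.checkpoint
print(ckpt["optimizer_name"], ckpt["trained_epoch"], ckpt["global_step"])
print(list(ckpt["model"].keys())[:5])  # first few parameter tensor names

# To resume or evaluate, rebuild the model from its config and load the weights:
#     model = get_model(config.get("model", "model_name"))(config, gpu_list)
#     model.load_state_dict(ckpt["model"])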