├── utils ├── __init__.py ├── maths.py ├── utils.py ├── convert_obqa.py ├── optimization_utils.py ├── parser_utils.py ├── convert_csqa.py ├── tokenization_utils.py ├── grounding.py ├── conceptnet.py ├── graph.py └── data_utils.py ├── .gitignore ├── figs ├── task.png └── overview.png ├── download_preprocessed_data.sh ├── eval_qagnn__csqa.sh ├── eval_qagnn__obqa.sh ├── LICENSE ├── eval_qagnn__medqa_usmle.sh ├── download_raw_data.sh ├── run_qagnn__csqa.sh ├── run_qagnn__obqa.sh ├── run_qagnn__medqa_usmle.sh ├── README.md ├── modeling ├── modeling_encoder.py └── modeling_qagnn.py ├── preprocess.py ├── utils_biomed └── preprocess_medqa_usmle.ipynb └── qagnn.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | 3 | .ipynb_checkpoints 4 | 5 | saved* 6 | -------------------------------------------------------------------------------- /figs/task.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michiyasunaga/qagnn/HEAD/figs/task.png -------------------------------------------------------------------------------- /figs/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/michiyasunaga/qagnn/HEAD/figs/overview.png -------------------------------------------------------------------------------- /download_preprocessed_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | mv data data_old 4 | 5 | wget https://nlp.stanford.edu/projects/myasu/QAGNN/data_preprocessed_release.zip 6 | unzip data_preprocessed_release.zip 7 | mv data_preprocessed_release data 8 | -------------------------------------------------------------------------------- /utils/maths.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import sparse 3 | 4 | 5 | def normalize_sparse_adj(A, sparse_type='coo'): 6 | """ 7 | normalize A along the second axis 8 | 9 | A: scipy.sparse matrix 10 | sparse_type: str (optional, default 'coo') 11 | returns: scipy.sparse.coo_marix 12 | """ 13 | in_degree = np.array(A.sum(1)).reshape(-1) 14 | in_degree[in_degree == 0] = 1e-5 15 | d_inv = sparse.diags(1 / in_degree) 16 | A = getattr(d_inv.dot(A), 'to' + sparse_type)() 17 | return A 18 | -------------------------------------------------------------------------------- /eval_qagnn__csqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | dt=`date '+%Y%m%d_%H%M%S'` 5 | 6 | 7 | dataset="csqa" 8 | model='roberta-large' 9 | shift 10 | shift 11 | args=$@ 12 | 13 | 14 | echo "******************************" 15 | echo "dataset: $dataset" 16 | echo "******************************" 17 | 18 | save_dir_pref='saved_models' 19 | mkdir -p $save_dir_pref 20 | 21 | ###### Eval ###### 22 | python3 -u qagnn.py --dataset $dataset \ 23 | --train_adj data/${dataset}/graph/train.graph.adj.pk \ 24 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \ 25 | --test_adj data/${dataset}/graph/test.graph.adj.pk \ 26 | --train_statements data/${dataset}/statement/train.statement.jsonl \ 27 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \ 28 | --test_statements 
data/${dataset}/statement/test.statement.jsonl \ 29 | --save_model \ 30 | --save_dir saved_models \ 31 | --mode eval_detail \ 32 | --load_model_path saved_models/csqa_model_hf3.4.0.pt \ 33 | $args 34 | -------------------------------------------------------------------------------- /eval_qagnn__obqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | dt=`date '+%Y%m%d_%H%M%S'` 5 | 6 | 7 | dataset="obqa" 8 | model='roberta-large' 9 | shift 10 | shift 11 | args=$@ 12 | 13 | 14 | echo "******************************" 15 | echo "dataset: $dataset" 16 | echo "******************************" 17 | 18 | save_dir_pref='saved_models' 19 | mkdir -p $save_dir_pref 20 | 21 | ###### Eval ###### 22 | python3 -u qagnn.py --dataset $dataset \ 23 | --train_adj data/${dataset}/graph/train.graph.adj.pk \ 24 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \ 25 | --test_adj data/${dataset}/graph/test.graph.adj.pk \ 26 | --train_statements data/${dataset}/statement/train.statement.jsonl \ 27 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \ 28 | --test_statements data/${dataset}/statement/test.statement.jsonl \ 29 | --save_model \ 30 | --save_dir saved_models \ 31 | --mode eval_detail \ 32 | --load_model_path saved_models/obqa_model_hf3.4.0.pt \ 33 | $args 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Michihiro Yasunaga 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /eval_qagnn__medqa_usmle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | dt=`date '+%Y%m%d_%H%M%S'` 5 | 6 | 7 | dataset="medqa_usmle" 8 | model='cambridgeltl/SapBERT-from-PubMedBERT-fulltext' 9 | ent_emb='ddb' 10 | shift 11 | shift 12 | args=$@ 13 | 14 | 15 | echo "******************************" 16 | echo "dataset: $dataset" 17 | echo "******************************" 18 | 19 | save_dir_pref='saved_models' 20 | mkdir -p $save_dir_pref 21 | 22 | ###### Eval ###### 23 | python3 -u qagnn.py --dataset $dataset \ 24 | --train_adj data/${dataset}/graph/dev.graph.adj.pk \ 25 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \ 26 | --test_adj data/${dataset}/graph/test.graph.adj.pk \ 27 | --train_statements data/${dataset}/statement/dev.statement.jsonl \ 28 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \ 29 | --test_statements data/${dataset}/statement/test.statement.jsonl \ 30 | --ent_emb ${ent_emb} \ 31 | --save_model \ 32 | --save_dir saved_models \ 33 | --mode eval_detail \ 34 | --load_model_path saved_models/medqa_usmle_model_hf3.4.0.pt \ 35 | $args 36 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | import argparse 5 | 6 | 7 | def bool_flag(v): 8 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 9 | return True 10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 11 | return False 12 | else: 13 | raise argparse.ArgumentTypeError('Boolean value expected.') 14 | 15 | 16 | def check_path(path): 17 | d = os.path.dirname(path) 18 | if not os.path.exists(d): 19 | os.makedirs(d) 20 | 21 | 22 | def check_file(file): 23 | return os.path.isfile(file) 24 | 25 | 26 | def export_config(config, path): 27 | param_dict = dict(vars(config)) 28 | check_path(path) 29 | with open(path, 'w') as fout: 30 | json.dump(param_dict, fout, indent=4) 31 | 32 | 33 | def freeze_net(module): 34 | for p in module.parameters(): 35 | p.requires_grad = False 36 | 37 | 38 | def unfreeze_net(module): 39 | for p in module.parameters(): 40 | p.requires_grad = True 41 | 42 | 43 | def test_data_loader_ms_per_batch(data_loader, max_steps=10000): 44 | start = time.time() 45 | n_batch = sum(1 for batch, _ in zip(data_loader, range(max_steps))) 46 | return (time.time() - start) * 1000 / n_batch 47 | -------------------------------------------------------------------------------- /download_raw_data.sh: -------------------------------------------------------------------------------- 1 | # download ConceptNet 2 | mkdir -p data/ 3 | mkdir -p data/cpnet/ 4 | wget -nc -P data/cpnet/ https://s3.amazonaws.com/conceptnet/downloads/2018/edges/conceptnet-assertions-5.6.0.csv.gz 5 | cd data/cpnet/ 6 | yes n | gzip -d conceptnet-assertions-5.6.0.csv.gz 7 | # download ConceptNet entity embedding 8 | wget https://csr.s3-us-west-1.amazonaws.com/tzw.ent.npy 9 | cd ../../ 10 | 11 | 12 | 13 | 14 | # download CommensenseQA dataset 15 | mkdir -p data/csqa/ 16 | wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl 17 | wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl 18 | wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl 19 | 20 | # create output folders 21 | mkdir -p 
data/csqa/grounded/ 22 | mkdir -p data/csqa/graph/ 23 | mkdir -p data/csqa/statement/ 24 | 25 | 26 | 27 | # download OpenBookQA dataset 28 | wget -nc -P data/obqa/ https://s3-us-west-2.amazonaws.com/ai2-website/data/OpenBookQA-V1-Sep2018.zip 29 | yes n | unzip data/obqa/OpenBookQA-V1-Sep2018.zip -d data/obqa/ 30 | 31 | # create output folders 32 | mkdir -p data/obqa/fairseq/official/ 33 | mkdir -p data/obqa/grounded/ 34 | mkdir -p data/obqa/graph/ 35 | mkdir -p data/obqa/statement/ 36 | -------------------------------------------------------------------------------- /run_qagnn__csqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1 4 | dt=`date '+%Y%m%d_%H%M%S'` 5 | 6 | 7 | dataset="csqa" 8 | model='roberta-large' 9 | shift 10 | shift 11 | args=$@ 12 | 13 | 14 | elr="1e-5" 15 | dlr="1e-3" 16 | bs=64 17 | mbs=2 18 | n_epochs=15 19 | num_relation=38 #(17 +2) * 2: originally 17, add 2 relation types (QA context -> Q node; QA context -> A node), and double because we add reverse edges 20 | 21 | 22 | k=5 #num of gnn layers 23 | gnndim=200 24 | 25 | echo "***** hyperparameters *****" 26 | echo "dataset: $dataset" 27 | echo "enc_name: $model" 28 | echo "batch_size: $bs" 29 | echo "learning_rate: elr $elr dlr $dlr" 30 | echo "gnn: dim $gnndim layer $k" 31 | echo "******************************" 32 | 33 | save_dir_pref='saved_models' 34 | mkdir -p $save_dir_pref 35 | mkdir -p logs 36 | 37 | ###### Training ###### 38 | for seed in 0; do 39 | python3 -u qagnn.py --dataset $dataset \ 40 | --encoder $model -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs -mbs $mbs --fp16 true --seed $seed \ 41 | --num_relation $num_relation \ 42 | --n_epochs $n_epochs --max_epochs_before_stop 10 \ 43 | --train_adj data/${dataset}/graph/train.graph.adj.pk \ 44 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \ 45 | --test_adj data/${dataset}/graph/test.graph.adj.pk \ 46 | --train_statements data/${dataset}/statement/train.statement.jsonl \ 47 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \ 48 | --test_statements data/${dataset}/statement/test.statement.jsonl \ 49 | --save_model \ 50 | --save_dir ${save_dir_pref}/${dataset}/enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt} $args \ 51 | > logs/train_${dataset}__enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt}.log.txt 52 | done 53 | -------------------------------------------------------------------------------- /run_qagnn__obqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0,1 4 | dt=`date '+%Y%m%d_%H%M%S'` 5 | 6 | 7 | dataset="obqa" 8 | model='roberta-large' 9 | shift 10 | shift 11 | args=$@ 12 | 13 | 14 | elr="1e-5" 15 | dlr="1e-3" 16 | bs=128 17 | mbs=1 18 | n_epochs=100 19 | num_relation=38 #(17 +2) * 2: originally 17, add 2 relation types (QA context -> Q node; QA context -> A node), and double because we add reverse edges 20 | 21 | 22 | k=5 #num of gnn layers 23 | gnndim=200 24 | 25 | echo "***** hyperparameters *****" 26 | echo "dataset: $dataset" 27 | echo "enc_name: $model" 28 | echo "batch_size: $bs" 29 | echo "learning_rate: elr $elr dlr $dlr" 30 | echo "gnn: dim $gnndim layer $k" 31 | echo "******************************" 32 | 33 | save_dir_pref='saved_models' 34 | mkdir -p $save_dir_pref 35 | mkdir -p logs 36 | 37 | ###### Training ###### 38 | for seed in 0; do 39 | python3 -u qagnn.py --dataset $dataset \ 40 | --encoder 
$model -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs -mbs $mbs --fp16 true --seed $seed \ 41 | --num_relation $num_relation \ 42 | --n_epochs $n_epochs --max_epochs_before_stop 50 \ 43 | --train_adj data/${dataset}/graph/train.graph.adj.pk \ 44 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \ 45 | --test_adj data/${dataset}/graph/test.graph.adj.pk \ 46 | --train_statements data/${dataset}/statement/train.statement.jsonl \ 47 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \ 48 | --test_statements data/${dataset}/statement/test.statement.jsonl \ 49 | --save_model \ 50 | --save_dir ${save_dir_pref}/${dataset}/enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt} $args \ 51 | > logs/train_${dataset}__enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt}.log.txt 52 | done 53 | -------------------------------------------------------------------------------- /utils/convert_obqa.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import sys 4 | from tqdm import tqdm 5 | 6 | __all__ = ['convert_to_obqa_statement'] 7 | 8 | # String used to indicate a blank 9 | BLANK_STR = "___" 10 | 11 | 12 | def convert_to_obqa_statement(qa_file: str, output_file1: str, output_file2: str): 13 | print(f'converting {qa_file} to entailment dataset...') 14 | nrow = sum(1 for _ in open(qa_file, 'r')) 15 | with open(output_file1, 'w') as output_handle1, open(output_file2, 'w') as output_handle2, open(qa_file, 'r') as qa_handle: 16 | # print("Writing to {} from {}".format(output_file, qa_file)) 17 | for line in tqdm(qa_handle, total=nrow): 18 | json_line = json.loads(line) 19 | output_dict = convert_qajson_to_entailment(json_line) 20 | output_handle1.write(json.dumps(output_dict)) 21 | output_handle1.write("\n") 22 | output_handle2.write(json.dumps(output_dict)) 23 | output_handle2.write("\n") 24 | print(f'converted statements saved to {output_file1}, {output_file2}') 25 | print() 26 | 27 | 28 | # Convert the QA file json to output dictionary containing premise and hypothesis 29 | def convert_qajson_to_entailment(qa_json: dict): 30 | question_text = qa_json["question"]["stem"] 31 | choices = qa_json["question"]["choices"] 32 | for choice in choices: 33 | choice_text = choice["text"] 34 | statement = question_text + ' ' + choice_text 35 | create_output_dict(qa_json, statement, choice["label"] == qa_json.get("answerKey", "A")) 36 | 37 | return qa_json 38 | 39 | 40 | # Create the output json dictionary from the input json, premise and hypothesis statement 41 | def create_output_dict(input_json: dict, statement: str, label: bool) -> dict: 42 | if "statements" not in input_json: 43 | input_json["statements"] = [] 44 | input_json["statements"].append({"label": label, "statement": statement}) 45 | return input_json 46 | -------------------------------------------------------------------------------- /run_qagnn__medqa_usmle.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export CUDA_VISIBLE_DEVICES=0 4 | dt=`date '+%Y%m%d_%H%M%S'` 5 | 6 | 7 | dataset="medqa_usmle" 8 | model='cambridgeltl/SapBERT-from-PubMedBERT-fulltext' 9 | shift 10 | shift 11 | args=$@ 12 | 13 | 14 | elr="5e-5" 15 | dlr="1e-3" 16 | bs=128 17 | mbs=2 18 | sl=512 19 | n_epochs=15 20 | ent_emb='ddb' 21 | num_relation=34 #(15 +2) * 2: originally 15, add 2 relation types (QA context -> Q node; QA context -> A node), and double because we add reverse edges 22 | 23 | 24 | k=5 #num of gnn 
layers 25 | gnndim=200 26 | unfrz=0 27 | 28 | 29 | echo "***** hyperparameters *****" 30 | echo "dataset: $dataset" 31 | echo "enc_name: $model" 32 | echo "batch_size: $bs" 33 | echo "learning_rate: elr $elr dlr $dlr" 34 | echo "gnn: dim $gnndim layer $k" 35 | echo "******************************" 36 | 37 | save_dir_pref='saved_models' 38 | mkdir -p $save_dir_pref 39 | mkdir -p logs 40 | 41 | ###### Training ###### 42 | for seed in 0; do 43 | python3 -u qagnn.py --dataset $dataset \ 44 | --encoder $model -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs -mbs $mbs -sl $sl --fp16 true --seed $seed \ 45 | --num_relation $num_relation \ 46 | --n_epochs $n_epochs --max_epochs_before_stop 10 --unfreeze_epoch $unfrz \ 47 | --train_adj data/${dataset}/graph/train.graph.adj.pk \ 48 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \ 49 | --test_adj data/${dataset}/graph/test.graph.adj.pk \ 50 | --train_statements data/${dataset}/statement/train.statement.jsonl \ 51 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \ 52 | --test_statements data/${dataset}/statement/test.statement.jsonl \ 53 | --ent_emb ${ent_emb} \ 54 | --save_model \ 55 | --save_dir ${save_dir_pref}/${dataset}/enc-sapbert__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt} $args \ 56 | > logs/train_${dataset}__enc-sapbert__k${k}__gnndim${gnndim}__bs${bs}__sl${sl}__unfrz${unfrz}__seed${seed}__${dt}.log.txt 57 | done 58 | -------------------------------------------------------------------------------- /utils/optimization_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from transformers import AdamW 4 | from torch.optim import SGD, Adam 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | class RAdam(Optimizer): 9 | 10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 11 | if not 0.0 <= lr: 12 | raise ValueError("Invalid learning rate: {}".format(lr)) 13 | if not 0.0 <= eps: 14 | raise ValueError("Invalid epsilon value: {}".format(eps)) 15 | if not 0.0 <= betas[0] < 1.0: 16 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 17 | if not 0.0 <= betas[1] < 1.0: 18 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 19 | 20 | self.degenerated_to_sgd = degenerated_to_sgd 21 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 22 | for param in params: 23 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 24 | param['buffer'] = [[None, None, None] for _ in range(10)] 25 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) 26 | super(RAdam, self).__init__(params, defaults) 27 | 28 | def __setstate__(self, state): 29 | super(RAdam, self).__setstate__(state) 30 | 31 | def step(self, closure=None): 32 | 33 | loss = None 34 | if closure is not None: 35 | loss = closure() 36 | 37 | for group in self.param_groups: 38 | 39 | for p in group['params']: 40 | if p.grad is None: 41 | continue 42 | grad = p.grad.data.float() 43 | if grad.is_sparse: 44 | raise RuntimeError('RAdam does not support sparse gradients') 45 | 46 | p_data_fp32 = p.data.float() 47 | 48 | state = self.state[p] 49 | 50 | if len(state) == 0: 51 | state['step'] = 0 52 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 53 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 54 | else: 55 | state['exp_avg'] = 
state['exp_avg'].type_as(p_data_fp32) 56 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 57 | 58 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 59 | beta1, beta2 = group['betas'] 60 | 61 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 62 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 63 | 64 | state['step'] += 1 65 | buffered = group['buffer'][int(state['step'] % 10)] 66 | if state['step'] == buffered[0]: 67 | N_sma, step_size = buffered[1], buffered[2] 68 | else: 69 | buffered[0] = state['step'] 70 | beta2_t = beta2 ** state['step'] 71 | N_sma_max = 2 / (1 - beta2) - 1 72 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 73 | buffered[1] = N_sma 74 | 75 | # more conservative since it's an approximated value 76 | if N_sma >= 5: 77 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) 78 | elif self.degenerated_to_sgd: 79 | step_size = 1.0 / (1 - beta1 ** state['step']) 80 | else: 81 | step_size = -1 82 | buffered[2] = step_size 83 | 84 | # more conservative since it's an approximated value 85 | if N_sma >= 5: 86 | if group['weight_decay'] != 0: 87 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 88 | denom = exp_avg_sq.sqrt().add_(group['eps']) 89 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 90 | p.data.copy_(p_data_fp32) 91 | elif step_size > 0: 92 | if group['weight_decay'] != 0: 93 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 94 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 95 | p.data.copy_(p_data_fp32) 96 | 97 | return loss 98 | 99 | 100 | OPTIMIZER_CLASSES = { 101 | 'sgd': SGD, 102 | 'adam': Adam, 103 | 'adamw': AdamW, 104 | 'radam': RAdam, 105 | } 106 | 107 | 108 | def run_test(): 109 | import torch.nn as nn 110 | model = nn.Sequential(*[nn.Linear(100, 10), nn.ReLU(), nn.Linear(10, 2)]) 111 | x = torch.randn(10, 100).repeat(100, 1) 112 | y = torch.randint(0, 2, (10,)).repeat(100) 113 | crit = nn.CrossEntropyLoss() 114 | optim = RAdam(model.parameters(), lr=1e-2, weight_decay=0.01) 115 | model.train() 116 | for a in range(0, 1000, 10): 117 | b = a + 10 118 | loss = crit(model(x[a:b]), y[a:b]) 119 | loss.backward() 120 | optim.step() 121 | print('| loss: {:.4f} |'.format(loss.item())) 122 | 123 | 124 | if __name__ == '__main__': 125 | run_test() 126 | -------------------------------------------------------------------------------- /utils/parser_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from utils.utils import * 3 | from modeling.modeling_encoder import MODEL_NAME_TO_CLASS 4 | 5 | ENCODER_DEFAULT_LR = { 6 | 'default': 1e-3, 7 | 'csqa': { 8 | 'lstm': 3e-4, 9 | 'openai-gpt': 1e-4, 10 | 'bert-base-uncased': 3e-5, 11 | 'bert-large-uncased': 2e-5, 12 | 'roberta-large': 1e-5, 13 | }, 14 | 'obqa': { 15 | 'lstm': 3e-4, 16 | 'openai-gpt': 3e-5, 17 | 'bert-base-cased': 1e-4, 18 | 'bert-large-cased': 1e-4, 19 | 'roberta-large': 1e-5, 20 | }, 21 | 'medqa_usmle': { 22 | 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext': 5e-5, 23 | }, 24 | } 25 | 26 | DATASET_LIST = ['csqa', 'obqa', 'socialiqa', 'medqa_usmle'] 27 | 28 | DATASET_SETTING = { 29 | 'csqa': 'inhouse', 30 | 'obqa': 'official', 31 | 'socialiqa': 'official', 32 | 'medqa_usmle': 'official', 33 | } 34 | 35 | DATASET_NO_TEST = ['socialiqa'] 36 | 37 | EMB_PATHS = { 38 | 'transe': 'data/transe/glove.transe.sgd.ent.npy', 39 | 'lm': 
'data/transe/glove.transe.sgd.ent.npy', 40 | 'numberbatch': 'data/transe/concept.nb.npy', 41 | 'tzw': 'data/cpnet/tzw.ent.npy', 42 | 'ddb': 'data/ddb/ent_emb.npy', 43 | } 44 | 45 | 46 | def add_data_arguments(parser): 47 | # arguments that all datasets share 48 | parser.add_argument('--ent_emb', default=['tzw'], nargs='+', help='sources for entity embeddings') 49 | # dataset specific 50 | parser.add_argument('-ds', '--dataset', default='csqa', choices=DATASET_LIST, help='dataset name') 51 | parser.add_argument('-ih', '--inhouse', type=bool_flag, nargs='?', const=True, help='run in-house setting') 52 | parser.add_argument('--inhouse_train_qids', default='data/{dataset}/inhouse_split_qids.txt', help='qids of the in-house training set') 53 | # statements 54 | parser.add_argument('--train_statements', default='data/{dataset}/statement/train.statement.jsonl') 55 | parser.add_argument('--dev_statements', default='data/{dataset}/statement/dev.statement.jsonl') 56 | parser.add_argument('--test_statements', default='data/{dataset}/statement/test.statement.jsonl') 57 | # preprocessing options 58 | parser.add_argument('-sl', '--max_seq_len', default=100, type=int) 59 | # set dataset defaults 60 | args, _ = parser.parse_known_args() 61 | parser.set_defaults(ent_emb_paths=[EMB_PATHS.get(s) for s in args.ent_emb], 62 | inhouse=(DATASET_SETTING[args.dataset] == 'inhouse'), 63 | inhouse_train_qids=args.inhouse_train_qids.format(dataset=args.dataset)) 64 | data_splits = ('train', 'dev') if args.dataset in DATASET_NO_TEST else ('train', 'dev', 'test') 65 | for split in data_splits: 66 | for attribute in ('statements',): 67 | attr_name = f'{split}_{attribute}' 68 | parser.set_defaults(**{attr_name: getattr(args, attr_name).format(dataset=args.dataset)}) 69 | if 'test' not in data_splits: 70 | parser.set_defaults(test_statements=None) 71 | 72 | 73 | def add_encoder_arguments(parser): 74 | parser.add_argument('-enc', '--encoder', default='bert-large-uncased', help='encoder type') 75 | parser.add_argument('--encoder_layer', default=-1, type=int, help='encoder layer ID to use as features (used only by non-LSTM encoders)') 76 | parser.add_argument('-elr', '--encoder_lr', default=2e-5, type=float, help='learning rate') 77 | args, _ = parser.parse_known_args() 78 | parser.set_defaults(encoder_lr=ENCODER_DEFAULT_LR[args.dataset].get(args.encoder, ENCODER_DEFAULT_LR['default'])) 79 | 80 | 81 | def add_optimization_arguments(parser): 82 | parser.add_argument('--loss', default='cross_entropy', choices=['margin_rank', 'cross_entropy'], help='model type') 83 | parser.add_argument('--optim', default='radam', choices=['sgd', 'adam', 'adamw', 'radam'], help='learning rate scheduler') 84 | parser.add_argument('--lr_schedule', default='fixed', choices=['fixed', 'warmup_linear', 'warmup_constant'], help='learning rate scheduler') 85 | parser.add_argument('-bs', '--batch_size', default=32, type=int) 86 | parser.add_argument('--warmup_steps', type=float, default=150) 87 | parser.add_argument('--max_grad_norm', default=1.0, type=float, help='max grad norm (0 to disable)') 88 | parser.add_argument('--weight_decay', default=1e-2, type=float, help='l2 weight decay strength') 89 | parser.add_argument('--n_epochs', default=100, type=int, help='total number of training epochs to perform.') 90 | parser.add_argument('-me', '--max_epochs_before_stop', default=10, type=int, help='stop training if dev does not increase for N epochs') 91 | 92 | 93 | def add_additional_arguments(parser): 94 | parser.add_argument('--log_interval', default=10, 
type=int) 95 | parser.add_argument('--cuda', default=True, type=bool_flag, nargs='?', const=True, help='use GPU') 96 | parser.add_argument('--seed', default=0, type=int, help='random seed') 97 | parser.add_argument('--debug', default=False, type=bool_flag, nargs='?', const=True, help='run in debug mode') 98 | args, _ = parser.parse_known_args() 99 | if args.debug: 100 | parser.set_defaults(batch_size=1, log_interval=1, eval_interval=5) 101 | 102 | 103 | def get_parser(): 104 | """A helper function that handles the arguments that all models share""" 105 | parser = argparse.ArgumentParser(add_help=False) 106 | add_data_arguments(parser) 107 | add_encoder_arguments(parser) 108 | add_optimization_arguments(parser) 109 | add_additional_arguments(parser) 110 | return parser 111 | 112 | 113 | def get_lstm_config_from_args(args): 114 | lstm_config = { 115 | 'hidden_size': args.encoder_dim, 116 | 'output_size': args.encoder_dim, 117 | 'num_layers': args.encoder_layer_num, 118 | 'bidirectional': args.encoder_bidir, 119 | 'emb_p': args.encoder_dropoute, 120 | 'input_p': args.encoder_dropouti, 121 | 'hidden_p': args.encoder_dropouth, 122 | 'pretrained_emb_or_path': args.encoder_pretrained_emb, 123 | 'freeze_emb': args.encoder_freeze_emb, 124 | 'pool_function': args.encoder_pooler, 125 | } 126 | return lstm_config 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # QA-GNN: Question Answering using Language Models and Knowledge Graphs 2 | 3 | This repo provides the source code & data of our paper: [QA-GNN: Reasoning with Language Models and Knowledge Graphs for Question Answering](https://arxiv.org/abs/2104.06378) (NAACL 2021). 4 | ```bib 5 | @InProceedings{yasunaga2021qagnn, 6 | author = {Michihiro Yasunaga and Hongyu Ren and Antoine Bosselut and Percy Liang and Jure Leskovec}, 7 | title = {QA-GNN: Reasoning with Language Models and Knowledge Graphs for Question Answering}, 8 | year = {2021}, 9 | booktitle = {North American Chapter of the Association for Computational Linguistics (NAACL)}, 10 | } 11 | ``` 12 | Webpage: [https://snap.stanford.edu/qagnn](https://snap.stanford.edu/qagnn) 13 |

14 | <img src="figs/task.png" alt="QA-GNN task"> 15 | 16 | <img src="figs/overview.png" alt="QA-GNN overview"> 17 | 18 |

19 | 20 | 21 | ## Usage 22 | ### 0. Dependencies 23 | Run the following commands to create a conda environment (assuming CUDA 10.1): 24 | ```bash 25 | conda create -n qagnn python=3.7 26 | source activate qagnn 27 | pip install torch==1.8.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 28 | pip install transformers==3.4.0 29 | pip install nltk spacy==2.1.6 30 | python -m spacy download en 31 | 32 | # for torch-geometric 33 | pip install torch-scatter==2.0.7 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html 34 | pip install torch-sparse==0.6.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html 35 | pip install torch-geometric==1.7.0 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html 36 | ```
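Once the environment is built, a quick sanity check can catch version mismatches early. The snippet below is an illustrative sketch (ours, not part of the repo) and assumes the versions pinned above:

```python
# Quick environment sanity check (illustrative; not part of the repo).
import torch
import transformers
import torch_geometric

print(torch.__version__)            # expect 1.8.0+cu101
print(transformers.__version__)     # expect 3.4.0
print(torch_geometric.__version__)  # expect 1.7.0
print(torch.cuda.is_available())    # should be True on a CUDA 10.1 machine
```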
37 | 38 | 39 | ### 1. Download data 40 | We use the question answering datasets (*CommonsenseQA*, *OpenBookQA*) and the ConceptNet knowledge graph. 41 | Download all the raw data by running 42 | ``` 43 | ./download_raw_data.sh 44 | ``` 45 | 46 | Preprocess the raw data by running 47 | ``` 48 | python preprocess.py -p <num_processes> 49 | ``` 50 | The script will: 51 | * Set up ConceptNet (e.g., extract the English relations from ConceptNet and merge the original 42 relation types into 17 types) 52 | * Convert the QA datasets into .jsonl files (e.g., stored in `data/csqa/statement/`; a sample entry is shown below) 53 | * Identify all mentioned concepts in the questions and answers 54 | * Extract subgraphs for each q-a pair
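For concreteness, here is a minimal sketch (our illustration, not part of the repo) that peeks at one converted entry, assuming the CommonsenseQA dev split has been preprocessed; the keys follow the statement format documented in `utils/convert_csqa.py`:

```python
# Inspect one converted statement entry (illustrative sketch).
import json

with open('data/csqa/statement/dev.statement.jsonl') as f:
    entry = json.loads(f.readline())

print(entry['question']['stem'])                          # question text
print([c['text'] for c in entry['question']['choices']])  # answer choices
print(entry['statements'][0])                             # {'label': ..., 'statement': ...}
```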
55 | 56 | **TL;DR (Skip the above steps and just get the preprocessed data)**. The preprocessing may take a long time. For your convenience, you can download all the processed data by running 57 | ``` 58 | ./download_preprocessed_data.sh 59 | ``` 60 | 61 | **🔴 NEWS (Added MedQA-USMLE)**. Besides the commonsense QA datasets (*CommonsenseQA*, *OpenBookQA*) with the ConceptNet knowledge graph, we added a biomedical QA dataset ([*MedQA-USMLE*](https://github.com/jind11/MedQA)) with a biomedical knowledge graph based on the Disease Database and DrugBank. You can download all the data for this from [**[here]**](https://nlp.stanford.edu/projects/myasu/QAGNN/data_preprocessed_biomed.zip). Unzip it and put the `medqa_usmle` and `ddb` folders inside the `data/` directory. While this data is already preprocessed, we also provide the preprocessing scripts we used in `utils_biomed/`. 62 | 63 | 64 | The resulting file structure will look like: 65 | 66 | ```plain 67 | . 68 | ├── README.md 69 | ├── data/ 70 | ├── cpnet/ (preprocessed ConceptNet) 71 | ├── csqa/ 72 | ├── train_rand_split.jsonl 73 | ├── dev_rand_split.jsonl 74 | ├── test_rand_split_no_answers.jsonl 75 | ├── statement/ (converted statements) 76 | ├── grounded/ (grounded entities) 77 | ├── graph/ (extracted subgraphs) 78 | ├── ... 79 | ├── obqa/ 80 | ├── medqa_usmle/ 81 | └── ddb/ 82 | ``` 83 | 84 | ### 2. Train QA-GNN 85 | For CommonsenseQA, run 86 | ``` 87 | ./run_qagnn__csqa.sh 88 | ``` 89 | For OpenBookQA, run 90 | ``` 91 | ./run_qagnn__obqa.sh 92 | ``` 93 | For MedQA-USMLE, run 94 | ``` 95 | ./run_qagnn__medqa_usmle.sh 96 | ``` 97 | As configured in these scripts, the model needs two types of input files: 98 | * `--{train,dev,test}_statements`: preprocessed question statements in jsonl format. This is mainly loaded by the `load_input_tensors` function in `utils/data_utils.py`. 99 | * `--{train,dev,test}_adj`: information about the KG subgraph extracted for each question. This is mainly loaded by the `load_sparse_adj_data_with_contextnode` function in `utils/data_utils.py` (see the sketch below).
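The sketch below (our illustration, not part of the repo) simply opens one such adjacency file; we assume only that it is a standard Python pickle, with the exact per-record structure defined by `generate_adj_data_from_grounded_concepts__use_LM` in `utils/graph.py`:

```python
# Peek at a preprocessed KG subgraph file (illustrative sketch; assumes a
# standard pickle produced by preprocess.py).
import pickle

with open('data/csqa/graph/dev.graph.adj.pk', 'rb') as f:
    adj_data = pickle.load(f)

print(type(adj_data))  # container type
print(len(adj_data))   # number of records (one per question-choice pair, we assume)
```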
100 | 101 | **Note**: We find that training for OpenBookQA is unstable (e.g., the best dev accuracy varies when using different seeds, different versions of the transformers / torch-geometric libraries, etc.), likely because the dataset is small. We suggest trying out different seeds. Another potential way to stabilize training is to initialize the model with one of the successful checkpoints provided below, e.g. by adding the argument `--load_model_path obqa_model.pt`. 102 | 103 | 104 | ### 3. Evaluate trained model 105 | For CommonsenseQA, run 106 | ``` 107 | ./eval_qagnn__csqa.sh 108 | ``` 109 | Similarly, for the other datasets (OpenBookQA, MedQA-USMLE), run `./eval_qagnn__obqa.sh` and `./eval_qagnn__medqa_usmle.sh`. 110 | You can download trained model checkpoints in the next section. 111 | 112 | 113 | ## Trained model examples 114 | CommonsenseQA

| Trained model | In-house Dev acc. | In-house Test acc. |
|---|---|---|
| RoBERTa-large + QA-GNN [link] | 0.7707 | 0.7405 |

OpenBookQA

| Trained model | Dev acc. | Test acc. |
|---|---|---|
| RoBERTa-large + QA-GNN [link] | 0.6960 | 0.6900 |

MedQA-USMLE

| Trained model | Dev acc. | Test acc. |
|---|---|---|
| SapBERT-base + QA-GNN [link] | 0.3789 | 0.3810 |
155 | 156 | **Note**: The models were trained and tested with HuggingFace transformers==3.4.0. 157 | 158 | 159 | ## Use your own dataset 160 | - Convert your dataset to `{train,dev,test}.statement.jsonl` in .jsonl format (see `data/csqa/statement/train.statement.jsonl`) 161 | - Create a directory in `data/{yourdataset}/` to store the .jsonl files 162 | - Modify `preprocess.py` and perform subgraph extraction for your data 163 | - Modify `utils/parser_utils.py` to support your own dataset 164 | 165 | 166 | ## Acknowledgment 167 | This repo is built upon the following work: 168 | ``` 169 | Scalable Multi-Hop Relational Reasoning for Knowledge-Aware Question Answering. Yanlin Feng*, Xinyue Chen*, Bill Yuchen Lin, Peifeng Wang, Jun Yan and Xiang Ren. EMNLP 2020. 170 | https://github.com/INK-USC/MHGRN 171 | ``` 172 | Many thanks to the authors and developers! 173 | -------------------------------------------------------------------------------- /modeling/modeling_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from transformers import (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, 6 | XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP) 7 | try: 8 | from transformers import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP 9 | except: 10 | pass 11 | from transformers import AutoModel, BertModel, BertConfig 12 | # from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP 13 | from utils.layers import * 14 | from utils.data_utils import get_gpt_token_num 15 | 16 | MODEL_CLASS_TO_NAME = { 17 | 'gpt': list(OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), 18 | 'bert': list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), 19 | 'xlnet': list(XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), 20 | 'roberta': list(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()), 21 | 'lstm': ['lstm'], 22 | } 23 | try: 24 | MODEL_CLASS_TO_NAME['albert'] = list(ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()) 25 | except: 26 | pass 27 | 28 | MODEL_NAME_TO_CLASS = {model_name: model_class for model_class, model_name_list in MODEL_CLASS_TO_NAME.items() for model_name in model_name_list} 29 | 30 | #Add SapBERT configuration 31 | model_name = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext' 32 | MODEL_NAME_TO_CLASS[model_name] = 'bert' 33 | 34 | 35 | class LSTMTextEncoder(nn.Module): 36 | pool_layer_classes = {'mean': MeanPoolLayer, 'max': MaxPoolLayer} 37 | 38 | def __init__(self, vocab_size=1, emb_size=300, hidden_size=300, output_size=300, num_layers=2, bidirectional=True, 39 | emb_p=0.0, input_p=0.0, hidden_p=0.0, pretrained_emb_or_path=None, freeze_emb=True, 40 | pool_function='max', output_hidden_states=False): 41 | super().__init__() 42 | self.output_size = output_size 43 | self.num_layers = num_layers 44 | self.output_hidden_states = output_hidden_states 45 | assert not bidirectional or hidden_size % 2 == 0 46 | 47 | if pretrained_emb_or_path is not None: 48 | if isinstance(pretrained_emb_or_path, str): # load pretrained embedding from a .npy file 49 | pretrained_emb_or_path = torch.tensor(np.load(pretrained_emb_or_path), dtype=torch.float) 50 | emb = nn.Embedding.from_pretrained(pretrained_emb_or_path, freeze=freeze_emb) 51 | emb_size = emb.weight.size(1) 52 | else: 53 | emb = nn.Embedding(vocab_size, emb_size) 54 | self.emb = EmbeddingDropout(emb, emb_p) 55 | self.rnns = nn.ModuleList([nn.LSTM(emb_size if l == 0 else 
hidden_size, 56 | (hidden_size if l != num_layers else output_size) // (2 if bidirectional else 1), 57 | 1, bidirectional=bidirectional, batch_first=True) for l in range(num_layers)]) 58 | self.pooler = self.pool_layer_classes[pool_function]() 59 | 60 | self.input_dropout = nn.Dropout(input_p) 61 | self.hidden_dropout = nn.ModuleList([RNNDropout(hidden_p) for _ in range(num_layers)]) 62 | 63 | def forward(self, inputs, lengths): 64 | """ 65 | inputs: tensor of shape (batch_size, seq_len) 66 | lengths: tensor of shape (batch_size) 67 | 68 | returns: tensor of shape (batch_size, hidden_size) 69 | """ 70 | assert (lengths > 0).all() 71 | batch_size, seq_len = inputs.size() 72 | hidden_states = self.input_dropout(self.emb(inputs)) 73 | all_hidden_states = [hidden_states] 74 | for l, (rnn, hid_dp) in enumerate(zip(self.rnns, self.hidden_dropout)): 75 | hidden_states = pack_padded_sequence(hidden_states, lengths, batch_first=True, enforce_sorted=False) 76 | hidden_states, _ = rnn(hidden_states) 77 | hidden_states, _ = pad_packed_sequence(hidden_states, batch_first=True, total_length=seq_len) 78 | all_hidden_states.append(hidden_states) 79 | if l != self.num_layers - 1: 80 | hidden_states = hid_dp(hidden_states) 81 | pooled = self.pooler(all_hidden_states[-1], lengths) 82 | assert len(all_hidden_states) == self.num_layers + 1 83 | outputs = (all_hidden_states[-1], pooled) 84 | if self.output_hidden_states: 85 | outputs = outputs + (all_hidden_states,) 86 | return outputs 87 | 88 | 89 | class TextEncoder(nn.Module): 90 | valid_model_types = set(MODEL_CLASS_TO_NAME.keys()) 91 | 92 | def __init__(self, model_name, output_token_states=False, from_checkpoint=None, **kwargs): 93 | super().__init__() 94 | self.model_type = MODEL_NAME_TO_CLASS[model_name] 95 | self.output_token_states = output_token_states 96 | assert not self.output_token_states or self.model_type in ('bert', 'roberta', 'albert') 97 | 98 | if self.model_type in ('lstm',): 99 | self.module = LSTMTextEncoder(**kwargs, output_hidden_states=True) 100 | self.sent_dim = self.module.output_size 101 | else: 102 | model_class = AutoModel 103 | self.module = model_class.from_pretrained(model_name, output_hidden_states=True) 104 | if from_checkpoint is not None: 105 | self.module = self.module.from_pretrained(from_checkpoint, output_hidden_states=True) 106 | if self.model_type in ('gpt',): 107 | self.module.resize_token_embeddings(get_gpt_token_num()) 108 | self.sent_dim = self.module.config.n_embd if self.model_type in ('gpt',) else self.module.config.hidden_size 109 | 110 | def forward(self, *inputs, layer_id=-1): 111 | ''' 112 | layer_id: only works for non-LSTM encoders 113 | output_token_states: if True, return hidden states of specific layer and attention masks 114 | ''' 115 | 116 | if self.model_type in ('lstm',): # lstm 117 | input_ids, lengths = inputs 118 | outputs = self.module(input_ids, lengths) 119 | elif self.model_type in ('gpt',): # gpt 120 | input_ids, cls_token_ids, lm_labels = inputs # lm_labels is not used 121 | outputs = self.module(input_ids) 122 | else: # bert / xlnet / roberta 123 | input_ids, attention_mask, token_type_ids, output_mask = inputs 124 | outputs = self.module(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) 125 | all_hidden_states = outputs[-1] 126 | hidden_states = all_hidden_states[layer_id] 127 | 128 | if self.model_type in ('lstm',): 129 | sent_vecs = outputs[1] 130 | elif self.model_type in ('gpt',): 131 | cls_token_ids = 
cls_token_ids.view(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, 1, hidden_states.size(-1)) 132 | sent_vecs = hidden_states.gather(1, cls_token_ids).squeeze(1) 133 | elif self.model_type in ('xlnet',): 134 | sent_vecs = hidden_states[:, -1] 135 | elif self.model_type in ('albert',): 136 | if self.output_token_states: 137 | return hidden_states, output_mask 138 | sent_vecs = hidden_states[:, 0] 139 | else: # bert / roberta 140 | if self.output_token_states: 141 | return hidden_states, output_mask 142 | sent_vecs = self.module.pooler(hidden_states) 143 | return sent_vecs, all_hidden_states 144 | 145 | 146 | def run_test(): 147 | encoder = TextEncoder('lstm', vocab_size=100, emb_size=100, hidden_size=200, num_layers=4) 148 | input_ids = torch.randint(0, 100, (30, 70)) 149 | lenghts = torch.randint(1, 70, (30,)) 150 | outputs = encoder(input_ids, lenghts) 151 | assert outputs[0].size() == (30, 200) 152 | assert len(outputs[1]) == 4 + 1 153 | assert all([x.size() == (30, 70, 100 if l == 0 else 200) for l, x in enumerate(outputs[1])]) 154 | print('all tests are passed') 155 | -------------------------------------------------------------------------------- /utils/convert_csqa.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to convert the retrieved HITS into an entailment dataset 3 | USAGE: 4 | python convert_csqa.py input_file output_file 5 | 6 | JSONL format of files 7 | 1. input_file: 8 | { 9 | "id": "d3b479933e716fb388dfb297e881054c", 10 | "question": { 11 | "stem": "If a lantern is not for sale, where is it likely to be?" 12 | "choices": [{"label": "A", "text": "antique shop"}, {"label": "B", "text": "house"}, {"label": "C", "text": "dark place"}] 13 | }, 14 | "answerKey":"B" 15 | } 16 | 17 | 2. output_file: 18 | { 19 | "id": "d3b479933e716fb388dfb297e881054c", 20 | "question": { 21 | "stem": "If a lantern is not for sale, where is it likely to be?" 
22 | "choices": [{"label": "A", "text": "antique shop"}, {"label": "B", "text": "house"}, {"label": "C", "text": "dark place"}] 23 | }, 24 | "answerKey":"B", 25 | 26 | "statements":[ 27 | {label:true, stem: "If a lantern is not for sale, it likely to be at house"}, 28 | {label:false, stem: "If a lantern is not for sale, it likely to be at antique shop"}, 29 | {label:false, stem: "If a lantern is not for sale, it likely to be at dark place"} 30 | ] 31 | } 32 | """ 33 | 34 | import json 35 | import re 36 | import sys 37 | from tqdm import tqdm 38 | 39 | __all__ = ['convert_to_entailment'] 40 | 41 | # String used to indicate a blank 42 | BLANK_STR = "___" 43 | 44 | 45 | def convert_to_entailment(qa_file: str, output_file: str, ans_pos: bool=False): 46 | print(f'converting {qa_file} to entailment dataset...') 47 | nrow = sum(1 for _ in open(qa_file, 'r')) 48 | with open(output_file, 'w') as output_handle, open(qa_file, 'r') as qa_handle: 49 | # print("Writing to {} from {}".format(output_file, qa_file)) 50 | for line in tqdm(qa_handle, total=nrow): 51 | json_line = json.loads(line) 52 | output_dict = convert_qajson_to_entailment(json_line, ans_pos) 53 | output_handle.write(json.dumps(output_dict)) 54 | output_handle.write("\n") 55 | print(f'converted statements saved to {output_file}') 56 | print() 57 | 58 | 59 | # Convert the QA file json to output dictionary containing premise and hypothesis 60 | def convert_qajson_to_entailment(qa_json: dict, ans_pos: bool): 61 | question_text = qa_json["question"]["stem"] 62 | choices = qa_json["question"]["choices"] 63 | for choice in choices: 64 | choice_text = choice["text"] 65 | pos = None 66 | if not ans_pos: 67 | statement = create_hypothesis(get_fitb_from_question(question_text), choice_text, ans_pos) 68 | else: 69 | statement, pos = create_hypothesis(get_fitb_from_question(question_text), choice_text, ans_pos) 70 | create_output_dict(qa_json, statement, choice["label"] == qa_json.get("answerKey", "A"), ans_pos, pos) 71 | 72 | return qa_json 73 | 74 | 75 | # Get a Fill-In-The-Blank (FITB) statement from the question text. E.g. "George wants to warm his 76 | # hands quickly by rubbing them. Which skin surface will produce the most heat?" -> 77 | # "George wants to warm his hands quickly by rubbing them. ___ skin surface will produce the most 78 | # heat? 79 | def get_fitb_from_question(question_text: str) -> str: 80 | fitb = replace_wh_word_with_blank(question_text) 81 | if not re.match(".*_+.*", fitb): 82 | # print("Can't create hypothesis from: '{}'. Appending {} !".format(question_text, BLANK_STR)) 83 | # Strip space, period and question mark at the end of the question and add a blank 84 | fitb = re.sub(r"[\.\? ]*$", "", question_text.strip()) + " " + BLANK_STR 85 | return fitb 86 | 87 | 88 | # Create a hypothesis statement from the the input fill-in-the-blank statement and answer choice. 89 | def create_hypothesis(fitb: str, choice: str, ans_pos: bool) -> str: 90 | 91 | if ". 
" + BLANK_STR in fitb or fitb.startswith(BLANK_STR): 92 | choice = choice[0].upper() + choice[1:] 93 | else: 94 | choice = choice.lower() 95 | # Remove period from the answer choice, if the question doesn't end with the blank 96 | if not fitb.endswith(BLANK_STR): 97 | choice = choice.rstrip(".") 98 | # Some questions already have blanks indicated with 2+ underscores 99 | if not ans_pos: 100 | try: 101 | hypothesis = re.sub("__+", choice, fitb) 102 | except: 103 | print (choice, fitb) 104 | return hypothesis 105 | choice = choice.strip() 106 | m = re.search("__+", fitb) 107 | start = m.start() 108 | 109 | length = (len(choice) - 1) if fitb.endswith(BLANK_STR) and choice[-1] in ['.', '?', '!'] else len(choice) 110 | hypothesis = re.sub("__+", choice, fitb) 111 | 112 | return hypothesis, (start, start + length) 113 | 114 | 115 | # Identify the wh-word in the question and replace with a blank 116 | def replace_wh_word_with_blank(question_str: str): 117 | # if "What is the name of the government building that houses the U.S. Congress?" in question_str: 118 | # print() 119 | question_str = question_str.replace("What's", "What is") 120 | question_str = question_str.replace("whats", "what") 121 | question_str = question_str.replace("U.S.", "US") 122 | wh_word_offset_matches = [] 123 | wh_words = ["which", "what", "where", "when", "how", "who", "why"] 124 | for wh in wh_words: 125 | # Some Turk-authored SciQ questions end with wh-word 126 | # E.g. The passing of traits from parents to offspring is done through what? 127 | 128 | if wh == "who" and "people who" in question_str: 129 | continue 130 | 131 | m = re.search(wh + r"\?[^\.]*[\. ]*$", question_str.lower()) 132 | if m: 133 | wh_word_offset_matches = [(wh, m.start())] 134 | break 135 | else: 136 | # Otherwise, find the wh-word in the last sentence 137 | m = re.search(wh + r"[ ,][^\.]*[\. ]*$", question_str.lower()) 138 | if m: 139 | wh_word_offset_matches.append((wh, m.start())) 140 | # else: 141 | # wh_word_offset_matches.append((wh, question_str.index(wh))) 142 | 143 | # If a wh-word is found 144 | if len(wh_word_offset_matches): 145 | # Pick the first wh-word as the word to be replaced with BLANK 146 | # E.g. Which is most likely needed when describing the change in position of an object? 147 | wh_word_offset_matches.sort(key=lambda x: x[1]) 148 | wh_word_found = wh_word_offset_matches[0][0] 149 | wh_word_start_offset = wh_word_offset_matches[0][1] 150 | # Replace the last question mark with period. 151 | question_str = re.sub(r"\?$", ".", question_str.strip()) 152 | # Introduce the blank in place of the wh-word 153 | fitb_question = (question_str[:wh_word_start_offset] + BLANK_STR + 154 | question_str[wh_word_start_offset + len(wh_word_found):]) 155 | # Drop "of the following" as it doesn't make sense in the absence of a multiple-choice 156 | # question. E.g. "Which of the following force ..." -> "___ force ..." 157 | final = fitb_question.replace(BLANK_STR + " of the following", BLANK_STR) 158 | final = final.replace(BLANK_STR + " of these", BLANK_STR) 159 | return final 160 | 161 | elif " them called?" in question_str: 162 | return question_str.replace(" them called?", " " + BLANK_STR + ".") 163 | elif " meaning he was not?" in question_str: 164 | return question_str.replace(" meaning he was not?", " he was not " + BLANK_STR + ".") 165 | elif " one of these?" in question_str: 166 | return question_str.replace(" one of these?", " " + BLANK_STR + ".") 167 | elif re.match(r".*[^\.\?] 
*$", question_str): 168 | # If no wh-word is found and the question ends without a period/question, introduce a 169 | # blank at the end. e.g. The gravitational force exerted by an object depends on its 170 | return question_str + " " + BLANK_STR 171 | else: 172 | # If all else fails, assume "this ?" indicates the blank. Used in Turk-authored questions 173 | # e.g. Virtually every task performed by living organisms requires this? 174 | return re.sub(r" this[ \?]", " ___ ", question_str) 175 | 176 | 177 | # Create the output json dictionary from the input json, premise and hypothesis statement 178 | def create_output_dict(input_json: dict, statement: str, label: bool, ans_pos: bool, pos=None) -> dict: 179 | if "statements" not in input_json: 180 | input_json["statements"] = [] 181 | if not ans_pos: 182 | input_json["statements"].append({"label": label, "statement": statement}) 183 | else: 184 | input_json["statements"].append({"label": label, "statement": statement, "ans_pos": pos}) 185 | return input_json 186 | 187 | 188 | if __name__ == "__main__": 189 | if len(sys.argv) < 3: 190 | raise ValueError("Provide at least two arguments: " 191 | "json file with hits, output file name") 192 | convert_to_entailment(sys.argv[1], sys.argv[2]) 193 | -------------------------------------------------------------------------------- /utils/tokenization_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import PreTrainedTokenizer 2 | import os 3 | import nltk 4 | import json 5 | from tqdm import tqdm 6 | import spacy 7 | 8 | EOS_TOK = '' 9 | UNK_TOK = '' 10 | PAD_TOK = '' 11 | SEP_TOK = '' 12 | EXTRA_TOKS = [EOS_TOK, UNK_TOK, PAD_TOK, SEP_TOK] 13 | 14 | 15 | class WordTokenizer(PreTrainedTokenizer): 16 | vocab_files_names = {'vocab_file': 'vocab.txt'} 17 | pretrained_vocab_files_map = {'vocab_file': {'lstm': './data/glove/glove.vocab'}} 18 | max_model_input_sizes = {'lstm': None} 19 | """ 20 | vocab_file: Path to a json file that contains token-to-id mapping 21 | """ 22 | 23 | def __init__(self, vocab_file, unk_token=UNK_TOK, sep_token=SEP_TOK, pad_token=PAD_TOK, eos_token=EOS_TOK, **kwargs): 24 | super(WordTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token, 25 | pad_token=pad_token, eos_token=eos_token, **kwargs) 26 | with open(vocab_file, 'r', encoding='utf-8') as fin: 27 | self.vocab = {line.rstrip('\n'): i for i, line in enumerate(fin)} 28 | self.ids_to_tokens = {ids: tok for tok, ids in self.vocab.items()} 29 | self.spacy_tokenizer = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner', 'textcat']) 30 | 31 | @property 32 | def vocab_size(self): 33 | return len(self.vocab) 34 | 35 | def _tokenize(self, text): 36 | return tokenize_sentence_spacy(self.spacy_tokenizer, text, lower_case=True, convert_num=False) 37 | 38 | def _convert_token_to_id(self, token): 39 | """ Converts a token (str/unicode) in an id using the vocab. """ 40 | return self.vocab.get(token, self.vocab.get(self.unk_token)) 41 | 42 | def _convert_id_to_token(self, index): 43 | """Converts an index (integer) in a token (string/unicode) using the vocab.""" 44 | return self.ids_to_tokens.get(index, self.unk_token) 45 | 46 | def convert_tokens_to_string(self, tokens): 47 | """ Converts a sequence of tokens (string) in a single string. 
""" 48 | out_string = ' '.join(tokens).strip() 49 | return out_string 50 | 51 | def add_special_tokens_single_sequence(self, token_ids): 52 | return token_ids + [self.eos_token_id] 53 | 54 | def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1): 55 | return token_ids_0 + [self.sep_token_id] + token_ids_1 56 | 57 | def save_vocabulary(self, vocab_path): 58 | """Save the tokenizer vocabulary to a directory or file.""" 59 | if os.path.isdir(vocab_path): 60 | vocab_file = os.path.join(vocab_path, self.vocab_files_names['vocab_file']) 61 | else: 62 | vocab_file = vocab_path 63 | with open(vocab_file, "w", encoding="utf-8") as fout: 64 | for i in range(len(self.vocab)): 65 | fout.write(self.ids_to_tokens[i] + '\n') 66 | return (vocab_file,) 67 | 68 | 69 | class WordVocab(object): 70 | 71 | def __init__(self, sents=None, path=None, freq_cutoff=5, encoding='utf-8', verbose=True): 72 | """ 73 | sents: list[str] (optional, default None) 74 | path: str (optional, default None) 75 | freq_cutoff: int (optional, default 5, 0 to disable) 76 | encoding: str (optional, default utf-8) 77 | """ 78 | if sents is not None: 79 | counts = {} 80 | for text in sents: 81 | for w in text.split(): 82 | counts[w] = counts.get(w, 0) + 1 83 | self._idx2w = [t[0] for t in sorted(counts.items(), key=lambda x: -x[1])] 84 | self._w2idx = {w: i for i, w in enumerate(self._idx2w)} 85 | self._counts = counts 86 | 87 | elif path is not None: 88 | self._idx2w = [] 89 | self._counts = {} 90 | with open(path, 'r', encoding=encoding) as fin: 91 | for line in fin: 92 | w, c = line.rstrip().split(' ') 93 | self._idx2w.append(w) 94 | self._counts[w] = c 95 | self._w2idx = {w: i for i, w in enumerate(self._idx2w)} 96 | 97 | else: 98 | self._idx2w = [] 99 | self._w2idx = {} 100 | self._counts = {} 101 | 102 | if freq_cutoff > 1: 103 | self._idx2w = [w for w in self._idx2w if self._counts[w] >= freq_cutoff] 104 | 105 | in_sum = sum([self._counts[w] for w in self._idx2w]) 106 | total_sum = sum([self._counts[w] for w in self._counts]) 107 | if verbose: 108 | print('vocab oov rate: {:.4f}'.format(1 - in_sum / total_sum)) 109 | 110 | self._w2idx = {w: i for i, w in enumerate(self._idx2w)} 111 | self._counts = {w: self._counts[w] for w in self._idx2w} 112 | 113 | def add_word(self, w, count=1): 114 | if w not in self.w2idx: 115 | self._w2idx[w] = len(self._idx2w) 116 | self._idx2w.append(w) 117 | self._counts[w] = count 118 | else: 119 | self._counts[w] += count 120 | return self 121 | 122 | def top_k_cutoff(self, size): 123 | if size < len(self._idx2w): 124 | for w in self._idx2w[size:]: 125 | self._w2idx.pop(w) 126 | self._counts.pop(w) 127 | self._idx2w = self._idx2w[:size] 128 | 129 | assert len(self._idx2w) == len(self._w2idx) == len(self._counts) 130 | return self 131 | 132 | def save(self, path, encoding='utf-8'): 133 | with open(path, 'w', encoding=encoding) as fout: 134 | for w in self._idx2w: 135 | fout.write(w + ' ' + str(self._counts[w]) + '\n') 136 | 137 | def __len__(self): 138 | return len(self._idx2w) 139 | 140 | def __contains__(self, word): 141 | return word in self._w2idx 142 | 143 | def __iter__(self): 144 | for word in self._idx2w: 145 | yield word 146 | 147 | @property 148 | def w2idx(self): 149 | return self._w2idx 150 | 151 | @property 152 | def idx2w(self): 153 | return self._idx2w 154 | 155 | @property 156 | def counts(self): 157 | return self._counts 158 | 159 | 160 | def tokenize_sentence_nltk(sent, lower_case=True, convert_num=False): 161 | tokens = nltk.word_tokenize(sent) 162 | if lower_case: 163 
| tokens = [t.lower() for t in tokens] 164 | if convert_num: 165 | tokens = ['' if t.isdigit() else t for t in tokens] 166 | return tokens 167 | 168 | 169 | def tokenize_sentence_spacy(nlp, sent, lower_case=True, convert_num=False): 170 | tokens = [tok.text for tok in nlp(sent)] 171 | if lower_case: 172 | tokens = [t.lower() for t in tokens] 173 | if convert_num: 174 | tokens = ['' if t.isdigit() else t for t in tokens] 175 | return tokens 176 | 177 | 178 | def tokenize_statement_file(statement_path, output_path, lower_case=True, convert_num=False): 179 | nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner', 'textcat']) 180 | nrow = sum(1 for _ in open(statement_path, 'r')) 181 | with open(statement_path, 'r') as fin, open(output_path, 'w') as fout: 182 | for line in tqdm(fin, total=nrow, desc='tokenizing'): 183 | data = json.loads(line) 184 | for statement in data['statements']: 185 | tokens = tokenize_sentence_spacy(nlp, statement['statement'], lower_case=lower_case, convert_num=convert_num) 186 | fout.write(' '.join(tokens) + '\n') 187 | 188 | 189 | def make_word_vocab(statement_path_list, output_path, lower_case=True, convert_num=True, freq_cutoff=5): 190 | """save the vocab to the output_path in json format""" 191 | nlp = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner', 'textcat']) 192 | 193 | docs = [] 194 | for path in statement_path_list: 195 | with open(path, 'r', encoding='utf-8') as fin: 196 | for line in fin: 197 | json_dic = json.loads(line) 198 | docs += [json_dic['question']['stem']] + [s['text'] for s in json_dic['question']['choices']] 199 | 200 | counts = {} 201 | for doc in tqdm(docs, desc='making word vocab'): 202 | for w in tokenize_sentence_spacy(nlp, doc, lower_case=lower_case, convert_num=convert_num): 203 | counts[w] = counts.get(w, 0) + 1 204 | idx2w = [t[0] for t in sorted(counts.items(), key=lambda x: -x[1])] 205 | idx2w = [w for w in idx2w if counts[w] >= freq_cutoff] 206 | idx2w += EXTRA_TOKS 207 | w2idx = {w: i for i, w in enumerate(idx2w)} 208 | with open(output_path, 'w', encoding='utf-8') as fout: 209 | json.dump(w2idx, fout) 210 | 211 | 212 | def run_test(): 213 | # tokenize_statement_file('data/csqa/statement/dev.statement.jsonl', '/tmp/tokenized.txt', True, True) 214 | # make_word_vocab(['data/csqa/statement/dev.statement.jsonl', 'data/csqa/statement/train.statement.jsonl'], '/tmp/vocab.txt', True, True) 215 | tokenizer = WordTokenizer.from_pretrained('lstm') 216 | print(tokenizer.tokenize('I love NLP since 1998DEC')) 217 | print(tokenizer.tokenize('CXY loves NLP since 1998')) 218 | print(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('CXY loves NLP since 1998'))) 219 | print(tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('CXY loves NLP since 1998')))) 220 | tokenizer.save_pretrained('/tmp/') 221 | tokenizer = WordTokenizer.from_pretrained('/tmp/') 222 | print('vocab size = {}'.format(tokenizer.vocab_size)) 223 | 224 | 225 | if __name__ == '__main__': 226 | run_test() 227 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from multiprocessing import cpu_count 4 | from utils.convert_csqa import convert_to_entailment 5 | from utils.convert_obqa import convert_to_obqa_statement 6 | from utils.conceptnet import extract_english, construct_graph 7 | from utils.grounding import create_matcher_patterns, ground 8 | from utils.graph 
import generate_adj_data_from_grounded_concepts__use_LM 9 | 10 | input_paths = { 11 | 'csqa': { 12 | 'train': './data/csqa/train_rand_split.jsonl', 13 | 'dev': './data/csqa/dev_rand_split.jsonl', 14 | 'test': './data/csqa/test_rand_split_no_answers.jsonl', 15 | }, 16 | 'obqa': { 17 | 'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/train.jsonl', 18 | 'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/dev.jsonl', 19 | 'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/test.jsonl', 20 | }, 21 | 'obqa-fact': { 22 | 'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/train_complete.jsonl', 23 | 'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/dev_complete.jsonl', 24 | 'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/test_complete.jsonl', 25 | }, 26 | 'cpnet': { 27 | 'csv': './data/cpnet/conceptnet-assertions-5.6.0.csv', 28 | }, 29 | } 30 | 31 | output_paths = { 32 | 'cpnet': { 33 | 'csv': './data/cpnet/conceptnet.en.csv', 34 | 'vocab': './data/cpnet/concept.txt', 35 | 'patterns': './data/cpnet/matcher_patterns.json', 36 | 'unpruned-graph': './data/cpnet/conceptnet.en.unpruned.graph', 37 | 'pruned-graph': './data/cpnet/conceptnet.en.pruned.graph', 38 | }, 39 | 'csqa': { 40 | 'statement': { 41 | 'train': './data/csqa/statement/train.statement.jsonl', 42 | 'dev': './data/csqa/statement/dev.statement.jsonl', 43 | 'test': './data/csqa/statement/test.statement.jsonl', 44 | }, 45 | 'grounded': { 46 | 'train': './data/csqa/grounded/train.grounded.jsonl', 47 | 'dev': './data/csqa/grounded/dev.grounded.jsonl', 48 | 'test': './data/csqa/grounded/test.grounded.jsonl', 49 | }, 50 | 'graph': { 51 | 'adj-train': './data/csqa/graph/train.graph.adj.pk', 52 | 'adj-dev': './data/csqa/graph/dev.graph.adj.pk', 53 | 'adj-test': './data/csqa/graph/test.graph.adj.pk', 54 | }, 55 | }, 56 | 'obqa': { 57 | 'statement': { 58 | 'train': './data/obqa/statement/train.statement.jsonl', 59 | 'dev': './data/obqa/statement/dev.statement.jsonl', 60 | 'test': './data/obqa/statement/test.statement.jsonl', 61 | 'train-fairseq': './data/obqa/fairseq/official/train.jsonl', 62 | 'dev-fairseq': './data/obqa/fairseq/official/valid.jsonl', 63 | 'test-fairseq': './data/obqa/fairseq/official/test.jsonl', 64 | }, 65 | 'grounded': { 66 | 'train': './data/obqa/grounded/train.grounded.jsonl', 67 | 'dev': './data/obqa/grounded/dev.grounded.jsonl', 68 | 'test': './data/obqa/grounded/test.grounded.jsonl', 69 | }, 70 | 'graph': { 71 | 'adj-train': './data/obqa/graph/train.graph.adj.pk', 72 | 'adj-dev': './data/obqa/graph/dev.graph.adj.pk', 73 | 'adj-test': './data/obqa/graph/test.graph.adj.pk', 74 | }, 75 | }, 76 | 'obqa-fact': { 77 | 'statement': { 78 | 'train': './data/obqa/statement/train-fact.statement.jsonl', 79 | 'dev': './data/obqa/statement/dev-fact.statement.jsonl', 80 | 'test': './data/obqa/statement/test-fact.statement.jsonl', 81 | 'train-fairseq': './data/obqa/fairseq/official/train-fact.jsonl', 82 | 'dev-fairseq': './data/obqa/fairseq/official/valid-fact.jsonl', 83 | 'test-fairseq': './data/obqa/fairseq/official/test-fact.jsonl', 84 | }, 85 | }, 86 | } 87 | 88 | 89 | def main(): 90 | parser = argparse.ArgumentParser() 91 | parser.add_argument('--run', default=['common'], choices=['common', 'csqa', 'hswag', 'anli', 'exp', 'scitail', 'phys', 'socialiqa', 'obqa', 'obqa-fact', 'make_word_vocab'], nargs='+') 92 | parser.add_argument('--path_prune_threshold', type=float, default=0.12, help='threshold for pruning paths') 93 | parser.add_argument('--max_node_num', type=int, default=200, help='maximum number 
of nodes per graph') 94 | parser.add_argument('-p', '--nprocs', type=int, default=cpu_count(), help='number of processes to use') 95 | parser.add_argument('--seed', type=int, default=0, help='random seed') 96 | parser.add_argument('--debug', action='store_true', help='enable debug mode') 97 | 98 | args = parser.parse_args() 99 | if args.debug: 100 | raise NotImplementedError() 101 | 102 | routines = { 103 | 'common': [ 104 | {'func': extract_english, 'args': (input_paths['cpnet']['csv'], output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'])}, 105 | {'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'], 106 | output_paths['cpnet']['unpruned-graph'], False)}, 107 | {'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'], 108 | output_paths['cpnet']['pruned-graph'], True)}, 109 | {'func': create_matcher_patterns, 'args': (output_paths['cpnet']['vocab'], output_paths['cpnet']['patterns'])}, 110 | ], 111 | 'csqa': [ 112 | {'func': convert_to_entailment, 'args': (input_paths['csqa']['train'], output_paths['csqa']['statement']['train'])}, 113 | {'func': convert_to_entailment, 'args': (input_paths['csqa']['dev'], output_paths['csqa']['statement']['dev'])}, 114 | {'func': convert_to_entailment, 'args': (input_paths['csqa']['test'], output_paths['csqa']['statement']['test'])}, 115 | {'func': ground, 'args': (output_paths['csqa']['statement']['train'], output_paths['cpnet']['vocab'], 116 | output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['train'], args.nprocs)}, 117 | {'func': ground, 'args': (output_paths['csqa']['statement']['dev'], output_paths['cpnet']['vocab'], 118 | output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['dev'], args.nprocs)}, 119 | {'func': ground, 'args': (output_paths['csqa']['statement']['test'], output_paths['cpnet']['vocab'], 120 | output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['test'], args.nprocs)}, 121 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-train'], args.nprocs)}, 122 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-dev'], args.nprocs)}, 123 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-test'], args.nprocs)}, 124 | ], 125 | 126 | 'obqa': [ 127 | {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['train'], output_paths['obqa']['statement']['train'], output_paths['obqa']['statement']['train-fairseq'])}, 128 | {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['dev'], output_paths['obqa']['statement']['dev'], output_paths['obqa']['statement']['dev-fairseq'])}, 129 | {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['test'], output_paths['obqa']['statement']['test'], output_paths['obqa']['statement']['test-fairseq'])}, 130 | {'func': ground, 'args': (output_paths['obqa']['statement']['train'], output_paths['cpnet']['vocab'], 131 | output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['train'], args.nprocs)}, 132 | {'func': ground, 'args': 
(output_paths['obqa']['statement']['dev'], output_paths['cpnet']['vocab'], 133 | output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['dev'], args.nprocs)}, 134 | {'func': ground, 'args': (output_paths['obqa']['statement']['test'], output_paths['cpnet']['vocab'], 135 | output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['test'], args.nprocs)}, 136 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-train'], args.nprocs)}, 137 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-dev'], args.nprocs)}, 138 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-test'], args.nprocs)}, 139 | ], 140 | } 141 | 142 | for rt in args.run: 143 | for rt_dic in routines[rt]: 144 | rt_dic['func'](*rt_dic['args']) 145 | 146 | print('Successfully run {}'.format(' '.join(args.run))) 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | # pass 152 | -------------------------------------------------------------------------------- /utils/grounding.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Pool 2 | import spacy 3 | from spacy.matcher import Matcher 4 | from tqdm import tqdm 5 | import nltk 6 | import json 7 | import string 8 | 9 | 10 | __all__ = ['create_matcher_patterns', 'ground'] 11 | 12 | 13 | # the lemma of it/them/mine/.. is -PRON- 14 | 15 | blacklist = set(["-PRON-", "actually", "likely", "possibly", "want", 16 | "make", "my", "someone", "sometimes_people", "sometimes", "would", "want_to", 17 | "one", "something", "sometimes", "everybody", "somebody", "could", "could_be" 18 | ]) 19 | 20 | 21 | nltk.download('stopwords', quiet=True) 22 | nltk_stopwords = nltk.corpus.stopwords.words('english') 23 | 24 | # CHUNK_SIZE = 1 25 | 26 | CPNET_VOCAB = None 27 | PATTERN_PATH = None 28 | nlp = None 29 | matcher = None 30 | 31 | 32 | def load_cpnet_vocab(cpnet_vocab_path): 33 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin: 34 | cpnet_vocab = [l.strip() for l in fin] 35 | cpnet_vocab = [c.replace("_", " ") for c in cpnet_vocab] 36 | return cpnet_vocab 37 | 38 | 39 | def create_pattern(nlp, doc, debug=False): 40 | pronoun_list = set(["my", "you", "it", "its", "your", "i", "he", "she", "his", "her", "they", "them", "their", "our", "we"]) 41 | # Filtering concepts consisting of all stop words and longer than four words. 
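    # A concept is kept as a Matcher pattern only if it survives the guard below:
    # it is dropped when it has five or more tokens, starts or ends with a pronoun,
    # or consists entirely of stopwords / blacklisted lemmas; surviving concepts
    # become lemma-based spaCy patterns (one {"LEMMA": ...} item per token).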
42 | if len(doc) >= 5 or doc[0].text in pronoun_list or doc[-1].text in pronoun_list or \ 43 | all([(token.text in nltk_stopwords or token.lemma_ in nltk_stopwords or token.lemma_ in blacklist) for token in doc]): 44 | if debug: 45 | return False, doc.text 46 | return None # ignore this concept as pattern 47 | 48 | pattern = [] 49 | for token in doc: # a doc is a concept 50 | pattern.append({"LEMMA": token.lemma_}) 51 | if debug: 52 | return True, doc.text 53 | return pattern 54 | 55 | 56 | def create_matcher_patterns(cpnet_vocab_path, output_path, debug=False): 57 | cpnet_vocab = load_cpnet_vocab(cpnet_vocab_path) 58 | nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner', 'textcat']) 59 | docs = nlp.pipe(cpnet_vocab) 60 | all_patterns = {} 61 | 62 | if debug: 63 | f = open("filtered_concept.txt", "w") 64 | 65 | for doc in tqdm(docs, total=len(cpnet_vocab)): 66 | 67 | pattern = create_pattern(nlp, doc, debug) 68 | if debug: 69 | if not pattern[0]: 70 | f.write(pattern[1] + '\n') 71 | 72 | if pattern is None: 73 | continue 74 | all_patterns["_".join(doc.text.split(" "))] = pattern 75 | 76 | print("Created " + str(len(all_patterns)) + " patterns.") 77 | with open(output_path, "w", encoding="utf8") as fout: 78 | json.dump(all_patterns, fout) 79 | if debug: 80 | f.close() 81 | 82 | 83 | def lemmatize(nlp, concept): 84 | 85 | doc = nlp(concept.replace("_", " ")) 86 | lcs = set() 87 | # for i in range(len(doc)): 88 | # lemmas = [] 89 | # for j, token in enumerate(doc): 90 | # if j == i: 91 | # lemmas.append(token.lemma_) 92 | # else: 93 | # lemmas.append(token.text) 94 | # lc = "_".join(lemmas) 95 | # lcs.add(lc) 96 | lcs.add("_".join([token.lemma_ for token in doc])) # all lemma 97 | return lcs 98 | 99 | 100 | def load_matcher(nlp, pattern_path): 101 | with open(pattern_path, "r", encoding="utf8") as fin: 102 | all_patterns = json.load(fin) 103 | 104 | matcher = Matcher(nlp.vocab) 105 | for concept, pattern in all_patterns.items(): 106 | matcher.add(concept, None, pattern) 107 | return matcher 108 | 109 | 110 | def ground_qa_pair(qa_pair): 111 | global nlp, matcher 112 | if nlp is None or matcher is None: 113 | nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat']) 114 | nlp.add_pipe(nlp.create_pipe('sentencizer')) 115 | matcher = load_matcher(nlp, PATTERN_PATH) 116 | 117 | s, a = qa_pair 118 | all_concepts = ground_mentioned_concepts(nlp, matcher, s, a) 119 | answer_concepts = ground_mentioned_concepts(nlp, matcher, a) 120 | question_concepts = all_concepts - answer_concepts 121 | if len(question_concepts) == 0: 122 | question_concepts = hard_ground(nlp, s, CPNET_VOCAB) # not very possible 123 | 124 | if len(answer_concepts) == 0: 125 | answer_concepts = hard_ground(nlp, a, CPNET_VOCAB) # some case 126 | 127 | # question_concepts = question_concepts - answer_concepts 128 | question_concepts = sorted(list(question_concepts)) 129 | answer_concepts = sorted(list(answer_concepts)) 130 | return {"sent": s, "ans": a, "qc": question_concepts, "ac": answer_concepts} 131 | 132 | 133 | def ground_mentioned_concepts(nlp, matcher, s, ans=None): 134 | 135 | s = s.lower() 136 | doc = nlp(s) 137 | matches = matcher(doc) 138 | 139 | mentioned_concepts = set() 140 | span_to_concepts = {} 141 | 142 | if ans is not None: 143 | ans_matcher = Matcher(nlp.vocab) 144 | ans_words = nlp(ans) 145 | # print(ans_words) 146 | ans_matcher.add(ans, None, [{'TEXT': token.text.lower()} for token in ans_words]) 147 | 148 | ans_match = ans_matcher(doc) 149 | ans_mentions = set() 150 | for _, ans_start, 
ans_end in ans_match: 151 | ans_mentions.add((ans_start, ans_end)) 152 | 153 | for match_id, start, end in matches: 154 | if ans is not None: 155 | if (start, end) in ans_mentions: 156 | continue 157 | 158 | span = doc[start:end].text # the matched span 159 | 160 | # a word that appears in answer is not considered as a mention in the question 161 | # if len(set(span.split(" ")).intersection(set(ans.split(" ")))) > 0: 162 | # continue 163 | original_concept = nlp.vocab.strings[match_id] 164 | original_concept_set = set() 165 | original_concept_set.add(original_concept) 166 | 167 | # print("span", span) 168 | # print("concept", original_concept) 169 | # print("Matched '" + span + "' to the rule '" + string_id) 170 | 171 | # why do you lemmatize a mention whose len == 1? 172 | 173 | if len(original_concept.split("_")) == 1: 174 | # tag = doc[start].tag_ 175 | # if tag in ['VBN', 'VBG']: 176 | 177 | original_concept_set.update(lemmatize(nlp, nlp.vocab.strings[match_id])) 178 | 179 | if span not in span_to_concepts: 180 | span_to_concepts[span] = set() 181 | 182 | span_to_concepts[span].update(original_concept_set) 183 | 184 | for span, concepts in span_to_concepts.items(): 185 | concepts_sorted = list(concepts) 186 | # print("span:") 187 | # print(span) 188 | # print("concept_sorted:") 189 | # print(concepts_sorted) 190 | concepts_sorted.sort(key=len) 191 | 192 | # mentioned_concepts.update(concepts_sorted[0:2]) 193 | 194 | shortest = concepts_sorted[0:3] 195 | 196 | for c in shortest: 197 | if c in blacklist: 198 | continue 199 | 200 | # a set with one string like: set("like_apples") 201 | lcs = lemmatize(nlp, c) 202 | intersect = lcs.intersection(shortest) 203 | if len(intersect) > 0: 204 | mentioned_concepts.add(list(intersect)[0]) 205 | else: 206 | mentioned_concepts.add(c) 207 | 208 | # if a mention exactly matches with a concept 209 | 210 | exact_match = set([concept for concept in concepts_sorted if concept.replace("_", " ").lower() == span.lower()]) 211 | # print("exact match:") 212 | # print(exact_match) 213 | assert len(exact_match) < 2 214 | mentioned_concepts.update(exact_match) 215 | 216 | return mentioned_concepts 217 | 218 | 219 | def hard_ground(nlp, sent, cpnet_vocab): 220 | sent = sent.lower() 221 | doc = nlp(sent) 222 | res = set() 223 | for t in doc: 224 | if t.lemma_ in cpnet_vocab: 225 | res.add(t.lemma_) 226 | sent = " ".join([t.text for t in doc]) 227 | if sent in cpnet_vocab: 228 | res.add(sent) 229 | try: 230 | assert len(res) > 0 231 | except Exception: 232 | print(f"for {sent}, concept not found in hard grounding.") 233 | return res 234 | 235 | 236 | def match_mentioned_concepts(sents, answers, num_processes): 237 | res = [] 238 | with Pool(num_processes) as p: 239 | res = list(tqdm(p.imap(ground_qa_pair, zip(sents, answers)), total=len(sents))) 240 | return res 241 | 242 | 243 | # To-do: examine prune 244 | def prune(data, cpnet_vocab_path): 245 | # reload cpnet_vocab 246 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin: 247 | cpnet_vocab = [l.strip() for l in fin] 248 | 249 | prune_data = [] 250 | for item in tqdm(data): 251 | qc = item["qc"] 252 | prune_qc = [] 253 | for c in qc: 254 | if c[-2:] == "er" and c[:-2] in qc: 255 | continue 256 | if c[-1:] == "e" and c[:-1] in qc: 257 | continue 258 | have_stop = False 259 | # remove all concepts having stopwords, including hard-grounded ones 260 | for t in c.split("_"): 261 | if t in nltk_stopwords: 262 | have_stop = True 263 | if not have_stop and c in cpnet_vocab: 264 | prune_qc.append(c) 265 | 266 | ac = 
item["ac"] 267 | prune_ac = [] 268 | for c in ac: 269 | if c[-2:] == "er" and c[:-2] in ac: 270 | continue 271 | if c[-1:] == "e" and c[:-1] in ac: 272 | continue 273 | all_stop = True 274 | for t in c.split("_"): 275 | if t not in nltk_stopwords: 276 | all_stop = False 277 | if not all_stop and c in cpnet_vocab: 278 | prune_ac.append(c) 279 | 280 | try: 281 | assert len(prune_ac) > 0 and len(prune_qc) > 0 282 | except Exception as e: 283 | pass 284 | # print("In pruning") 285 | # print(prune_qc) 286 | # print(prune_ac) 287 | # print("original:") 288 | # print(qc) 289 | # print(ac) 290 | # print() 291 | item["qc"] = prune_qc 292 | item["ac"] = prune_ac 293 | 294 | prune_data.append(item) 295 | return prune_data 296 | 297 | 298 | def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False): 299 | global PATTERN_PATH, CPNET_VOCAB 300 | if PATTERN_PATH is None: 301 | PATTERN_PATH = pattern_path 302 | CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path) 303 | 304 | sents = [] 305 | answers = [] 306 | with open(statement_path, 'r') as fin: 307 | lines = [line for line in fin] 308 | 309 | if debug: 310 | lines = lines[192:195] 311 | print(len(lines)) 312 | for line in lines: 313 | if line == "": 314 | continue 315 | j = json.loads(line) 316 | # {'answerKey': 'B', 317 | # 'id': 'b8c0a4703079cf661d7261a60a1bcbff', 318 | # 'question': {'question_concept': 'magazines', 319 | # 'choices': [{'label': 'A', 'text': 'doctor'}, {'label': 'B', 'text': 'bookstore'}, {'label': 'C', 'text': 'market'}, {'label': 'D', 'text': 'train station'}, {'label': 'E', 'text': 'mortuary'}], 320 | # 'stem': 'Where would you find magazines along side many other printed works?'}, 321 | # 'statements': [{'label': False, 'statement': 'Doctor would you find magazines along side many other printed works.'}, {'label': True, 'statement': 'Bookstore would you find magazines along side many other printed works.'}, {'label': False, 'statement': 'Market would you find magazines along side many other printed works.'}, {'label': False, 'statement': 'Train station would you find magazines along side many other printed works.'}, {'label': False, 'statement': 'Mortuary would you find magazines along side many other printed works.'}]} 322 | 323 | for statement in j["statements"]: 324 | sents.append(statement["statement"]) 325 | 326 | for answer in j["question"]["choices"]: 327 | ans = answer['text'] 328 | # ans = " ".join(answer['text'].split("_")) 329 | try: 330 | assert all([i != "_" for i in ans]) 331 | except Exception: 332 | print(ans) 333 | answers.append(ans) 334 | 335 | res = match_mentioned_concepts(sents, answers, num_processes) 336 | res = prune(res, cpnet_vocab_path) 337 | 338 | # check_path(output_path) 339 | with open(output_path, 'w') as fout: 340 | for dic in res: 341 | fout.write(json.dumps(dic) + '\n') 342 | 343 | print(f'grounded concepts saved to {output_path}') 344 | print() 345 | 346 | 347 | if __name__ == "__main__": 348 | create_matcher_patterns("../data/cpnet/concept.txt", "./matcher_res.txt", True) 349 | # ground("../data/statement/dev.statement.jsonl", "../data/cpnet/concept.txt", "../data/cpnet/matcher_patterns.json", "./ground_res.jsonl", 10, True) 350 | 351 | # s = "a revolving door is convenient for two direction travel, but it also serves as a security measure at a bank." 
352 | # a = "bank" 353 | # nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat']) 354 | # nlp.add_pipe(nlp.create_pipe('sentencizer')) 355 | # ans_words = nlp(a) 356 | # doc = nlp(s) 357 | # ans_matcher = Matcher(nlp.vocab) 358 | # print([{'TEXT': token.text.lower()} for token in ans_words]) 359 | # ans_matcher.add("ok", None, [{'TEXT': token.text.lower()} for token in ans_words]) 360 | # 361 | # matches = ans_matcher(doc) 362 | # for a, b, c in matches: 363 | # print(a, b, c) 364 | -------------------------------------------------------------------------------- /utils/conceptnet.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import nltk 3 | import json 4 | import math 5 | from tqdm import tqdm 6 | import numpy as np 7 | import sys 8 | 9 | try: 10 | from .utils import check_file 11 | except ImportError: 12 | from utils import check_file 13 | 14 | __all__ = ['extract_english', 'construct_graph', 'merged_relations'] 15 | 16 | relation_groups = [ 17 | 'atlocation/locatednear', 18 | 'capableof', 19 | 'causes/causesdesire/*motivatedbygoal', 20 | 'createdby', 21 | 'desires', 22 | 'antonym/distinctfrom', 23 | 'hascontext', 24 | 'hasproperty', 25 | 'hassubevent/hasfirstsubevent/haslastsubevent/hasprerequisite/entails/mannerof', 26 | 'isa/instanceof/definedas', 27 | 'madeof', 28 | 'notcapableof', 29 | 'notdesires', 30 | 'partof/*hasa', 31 | 'relatedto/similarto/synonym', 32 | 'usedfor', 33 | 'receivesaction', 34 | ] 35 | 36 | merged_relations = [ 37 | 'antonym', 38 | 'atlocation', 39 | 'capableof', 40 | 'causes', 41 | 'createdby', 42 | 'isa', 43 | 'desires', 44 | 'hassubevent', 45 | 'partof', 46 | 'hascontext', 47 | 'hasproperty', 48 | 'madeof', 49 | 'notcapableof', 50 | 'notdesires', 51 | 'receivesaction', 52 | 'relatedto', 53 | 'usedfor', 54 | ] 55 | 56 | relation_text = [ 57 | 'is the antonym of', 58 | 'is at location of', 59 | 'is capable of', 60 | 'causes', 61 | 'is created by', 62 | 'is a kind of', 63 | 'desires', 64 | 'has subevent', 65 | 'is part of', 66 | 'has context', 67 | 'has property', 68 | 'is made of', 69 | 'is not capable of', 70 | 'does not desires', 71 | 'is', 72 | 'is related to', 73 | 'is used for', 74 | ] 75 | 76 | 77 | def load_merge_relation(): 78 | relation_mapping = dict() 79 | for line in relation_groups: 80 | ls = line.strip().split('/') 81 | rel = ls[0] 82 | for l in ls: 83 | if l.startswith("*"): 84 | relation_mapping[l[1:]] = "*" + rel 85 | else: 86 | relation_mapping[l] = rel 87 | return relation_mapping 88 | 89 | 90 | def del_pos(s): 91 | """ 92 | Deletes part-of-speech encoding from an entity string, if present. 93 | :param s: Entity string. 94 | :return: Entity string with part-of-speech encoding removed. 95 | """ 96 | if s.endswith("/n") or s.endswith("/a") or s.endswith("/v") or s.endswith("/r"): 97 | s = s[:-2] 98 | return s 99 | 100 | 101 | def extract_english(conceptnet_path, output_csv_path, output_vocab_path): 102 | """ 103 | Reads original conceptnet csv file and extracts all English relations (head and tail are both English entities) into 104 | a new file, with the following format for each line: . 
105 | :return: 106 | """ 107 | print('extracting English concepts and relations from ConceptNet...') 108 | relation_mapping = load_merge_relation() 109 | num_lines = sum(1 for line in open(conceptnet_path, 'r', encoding='utf-8')) 110 | cpnet_vocab = [] 111 | concepts_seen = set() 112 | with open(conceptnet_path, 'r', encoding="utf8") as fin, \ 113 | open(output_csv_path, 'w', encoding="utf8") as fout: 114 | for line in tqdm(fin, total=num_lines): 115 | toks = line.strip().split('\t') 116 | if toks[2].startswith('/c/en/') and toks[3].startswith('/c/en/'): 117 | """ 118 | Some preprocessing: 119 | - Remove part-of-speech encoding. 120 | - Split("/")[-1] to trim the "/c/en/" and just get the entity name, convert all to 121 | - Lowercase for uniformity. 122 | """ 123 | rel = toks[1].split("/")[-1].lower() 124 | head = del_pos(toks[2]).split("/")[-1].lower() 125 | tail = del_pos(toks[3]).split("/")[-1].lower() 126 | 127 | if not head.replace("_", "").replace("-", "").isalpha(): 128 | continue 129 | if not tail.replace("_", "").replace("-", "").isalpha(): 130 | continue 131 | if rel not in relation_mapping: 132 | continue 133 | 134 | rel = relation_mapping[rel] 135 | if rel.startswith("*"): 136 | head, tail, rel = tail, head, rel[1:] 137 | 138 | data = json.loads(toks[4]) 139 | 140 | fout.write('\t'.join([rel, head, tail, str(data["weight"])]) + '\n') 141 | 142 | for w in [head, tail]: 143 | if w not in concepts_seen: 144 | concepts_seen.add(w) 145 | cpnet_vocab.append(w) 146 | 147 | with open(output_vocab_path, 'w') as fout: 148 | for word in cpnet_vocab: 149 | fout.write(word + '\n') 150 | 151 | print(f'extracted ConceptNet csv file saved to {output_csv_path}') 152 | print(f'extracted concept vocabulary saved to {output_vocab_path}') 153 | print() 154 | 155 | 156 | def construct_graph(cpnet_csv_path, cpnet_vocab_path, output_path, prune=True): 157 | print('generating ConceptNet graph file...') 158 | 159 | nltk.download('stopwords', quiet=True) 160 | nltk_stopwords = nltk.corpus.stopwords.words('english') 161 | nltk_stopwords += ["like", "gone", "did", "going", "would", "could", 162 | "get", "in", "up", "may", "wanter"] # issue: mismatch with the stop words in grouding.py 163 | 164 | blacklist = set(["uk", "us", "take", "make", "object", "person", "people"]) # issue: mismatch with the blacklist in grouding.py 165 | 166 | concept2id = {} 167 | id2concept = {} 168 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin: 169 | id2concept = [w.strip() for w in fin] 170 | concept2id = {w: i for i, w in enumerate(id2concept)} 171 | 172 | id2relation = merged_relations 173 | relation2id = {r: i for i, r in enumerate(id2relation)} 174 | 175 | graph = nx.MultiDiGraph() 176 | nrow = sum(1 for _ in open(cpnet_csv_path, 'r', encoding='utf-8')) 177 | with open(cpnet_csv_path, "r", encoding="utf8") as fin: 178 | 179 | def not_save(cpt): 180 | if cpt in blacklist: 181 | return True 182 | '''originally phrases like "branch out" would not be kept in the graph''' 183 | # for t in cpt.split("_"): 184 | # if t in nltk_stopwords: 185 | # return True 186 | return False 187 | 188 | attrs = set() 189 | 190 | for line in tqdm(fin, total=nrow): 191 | ls = line.strip().split('\t') 192 | rel = relation2id[ls[0]] 193 | subj = concept2id[ls[1]] 194 | obj = concept2id[ls[2]] 195 | weight = float(ls[3]) 196 | if prune and (not_save(ls[1]) or not_save(ls[2]) or id2relation[rel] == "hascontext"): 197 | continue 198 | # if id2relation[rel] == "relatedto" or id2relation[rel] == "antonym": 199 | # weight -= 0.3 200 | # continue 
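    # Every surviving triple is inserted in both directions: the forward edge keeps
    # relation id `rel`, while the reverse edge is stored as `rel + len(relation2id)`,
    # so the 17 merged relations yield 34 distinct edge types in the resulting graph.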
201 | if subj == obj: # delete loops 202 | continue 203 | # weight = 1 + float(math.exp(1 - weight)) # issue: ??? 204 | 205 | if (subj, obj, rel) not in attrs: 206 | graph.add_edge(subj, obj, rel=rel, weight=weight) 207 | attrs.add((subj, obj, rel)) 208 | graph.add_edge(obj, subj, rel=rel + len(relation2id), weight=weight) 209 | attrs.add((obj, subj, rel + len(relation2id))) 210 | 211 | nx.write_gpickle(graph, output_path) 212 | print(f"graph file saved to {output_path}") 213 | print() 214 | 215 | 216 | def glove_init(input, output, concept_file): 217 | embeddings_file = output + '.npy' 218 | vocabulary_file = output.split('.')[0] + '.vocab.txt' 219 | output_dir = '/'.join(output.split('/')[:-1]) 220 | output_prefix = output.split('/')[-1] 221 | 222 | words = [] 223 | vectors = [] 224 | vocab_exist = check_file(vocabulary_file) 225 | print("loading embedding") 226 | with open(input, 'rb') as f: 227 | for line in f: 228 | fields = line.split() 229 | if len(fields) <= 2: 230 | continue 231 | if not vocab_exist: 232 | word = fields[0].decode('utf-8') 233 | words.append(word) 234 | vector = np.fromiter((float(x) for x in fields[1:]), 235 | dtype=np.float) 236 | 237 | vectors.append(vector) 238 | dim = vector.shape[0] 239 | print("converting") 240 | matrix = np.array(vectors, dtype="float32") 241 | print("writing") 242 | np.save(embeddings_file, matrix) 243 | text = '\n'.join(words) 244 | if not vocab_exist: 245 | with open(vocabulary_file, 'wb') as f: 246 | f.write(text.encode('utf-8')) 247 | 248 | def load_glove_from_npy(glove_vec_path, glove_vocab_path): 249 | vectors = np.load(glove_vec_path) 250 | with open(glove_vocab_path, "r", encoding="utf8") as f: 251 | vocab = [l.strip() for l in f.readlines()] 252 | 253 | assert (len(vectors) == len(vocab)) 254 | 255 | glove_embeddings = {} 256 | for i in range(0, len(vectors)): 257 | glove_embeddings[vocab[i]] = vectors[i] 258 | print("Read " + str(len(glove_embeddings)) + " glove vectors.") 259 | return glove_embeddings 260 | 261 | def weighted_average(avg, new, n): 262 | # TODO: maybe a better name for this function? 263 | return ((n - 1) / n) * avg + (new / n) 264 | 265 | def max_pooling(old, new): 266 | # TODO: maybe a better name for this function? 
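    # element-wise maximum of the running pooled embedding and the new vector
    # (weighted_average above is its counterpart for mean pooling, maintaining the
    # incremental mean avg_n = ((n - 1) / n) * avg_{n-1} + new / n)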
267 | return np.maximum(old, new) 268 | 269 | def write_embeddings_npy(embeddings, embeddings_cnt, npy_path, vocab_path): 270 | words = [] 271 | vectors = [] 272 | for key, vec in embeddings.items(): 273 | words.append(key) 274 | vectors.append(vec) 275 | 276 | matrix = np.array(vectors, dtype="float32") 277 | print(matrix.shape) 278 | 279 | print("Writing embeddings matrix to " + npy_path, flush=True) 280 | np.save(npy_path, matrix) 281 | print("Finished writing embeddings matrix to " + npy_path, flush=True) 282 | 283 | if not check_file(vocab_path): 284 | print("Writing vocab file to " + vocab_path, flush=True) 285 | to_write = ["\t".join([w, str(embeddings_cnt[w])]) for w in words] 286 | with open(vocab_path, "w", encoding="utf8") as f: 287 | f.write("\n".join(to_write)) 288 | print("Finished writing vocab file to " + vocab_path, flush=True) 289 | 290 | def create_embeddings_glove(pooling="max", dim=100): 291 | print("Pooling: " + pooling) 292 | 293 | with open(concept_file, "r", encoding="utf8") as f: 294 | triple_str_json = json.load(f) 295 | print("Loaded " + str(len(triple_str_json)) + " triple strings.") 296 | 297 | glove_embeddings = load_glove_from_npy(embeddings_file, vocabulary_file) 298 | print("Loaded glove.", flush=True) 299 | 300 | concept_embeddings = {} 301 | concept_embeddings_cnt = {} 302 | rel_embeddings = {} 303 | rel_embeddings_cnt = {} 304 | 305 | for i in tqdm(range(len(triple_str_json))): 306 | data = triple_str_json[i] 307 | 308 | words = data["string"].strip().split(" ") 309 | 310 | rel = data["rel"] 311 | subj_start = data["subj_start"] 312 | subj_end = data["subj_end"] 313 | obj_start = data["obj_start"] 314 | obj_end = data["obj_end"] 315 | 316 | subj_words = words[subj_start:subj_end] 317 | obj_words = words[obj_start:obj_end] 318 | 319 | subj = " ".join(subj_words) 320 | obj = " ".join(obj_words) 321 | 322 | # counting the frequency (only used for the avg pooling) 323 | if subj not in concept_embeddings: 324 | concept_embeddings[subj] = np.zeros((dim,)) 325 | concept_embeddings_cnt[subj] = 0 326 | concept_embeddings_cnt[subj] += 1 327 | 328 | if obj not in concept_embeddings: 329 | concept_embeddings[obj] = np.zeros((dim,)) 330 | concept_embeddings_cnt[obj] = 0 331 | concept_embeddings_cnt[obj] += 1 332 | 333 | if rel not in rel_embeddings: 334 | rel_embeddings[rel] = np.zeros((dim,)) 335 | rel_embeddings_cnt[rel] = 0 336 | rel_embeddings_cnt[rel] += 1 337 | 338 | if pooling == "avg": 339 | subj_encoding_sum = sum([glove_embeddings.get(word, np.zeros((dim,))) for word in subj]) 340 | obj_encoding_sum = sum([glove_embeddings.get(word, np.zeros((dim,))) for word in obj]) 341 | 342 | if rel in ["relatedto", "antonym"]: 343 | # Symmetric relation. 344 | rel_encoding_sum = sum([glove_embeddings.get(word, np.zeros((dim,))) for word in 345 | words]) - subj_encoding_sum - obj_encoding_sum 346 | else: 347 | # Asymmetrical relation. 
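                # assumption behind the original heuristic: the relation vector is
                # approximated by the offset between the argument encodings
                # (object minus subject), loosely in the spirit of translation-based
                # KG embeddings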
348 | rel_encoding_sum = obj_encoding_sum - subj_encoding_sum 349 | 350 | subj_len = subj_end - subj_start 351 | obj_len = obj_end - obj_start 352 | 353 | subj_encoding = subj_encoding_sum / subj_len 354 | obj_encoding = obj_encoding_sum / obj_len 355 | rel_encoding = rel_encoding_sum / (len(words) - subj_len - obj_len) 356 | 357 | concept_embeddings[subj] = subj_encoding 358 | concept_embeddings[obj] = obj_encoding 359 | rel_embeddings[rel] = weighted_average(rel_embeddings[rel], rel_encoding, rel_embeddings_cnt[rel]) 360 | 361 | elif pooling == "max": 362 | subj_encoding = np.amax([glove_embeddings.get(word, np.zeros((dim,))) for word in subj_words], axis=0) 363 | obj_encoding = np.amax([glove_embeddings.get(word, np.zeros((dim,))) for word in obj_words], axis=0) 364 | 365 | mask_rel = [] 366 | for j in range(len(words)): 367 | if subj_start <= j < subj_end or obj_start <= j < obj_end: 368 | continue 369 | mask_rel.append(j) 370 | rel_vecs = [glove_embeddings.get(words[i], np.zeros((dim,))) for i in mask_rel] 371 | rel_encoding = np.amax(rel_vecs, axis=0) 372 | 373 | # here it is actually avg over max for relation 374 | concept_embeddings[subj] = max_pooling(concept_embeddings[subj], subj_encoding) 375 | concept_embeddings[obj] = max_pooling(concept_embeddings[obj], obj_encoding) 376 | rel_embeddings[rel] = weighted_average(rel_embeddings[rel], rel_encoding, rel_embeddings_cnt[rel]) 377 | 378 | print(str(len(concept_embeddings)) + " concept embeddings") 379 | print(str(len(rel_embeddings)) + " relation embeddings") 380 | 381 | write_embeddings_npy(concept_embeddings, concept_embeddings_cnt, f'{output_dir}/concept.{output_prefix}.{pooling}.npy', 382 | f'{output_dir}/concept.glove.{pooling}.txt') 383 | write_embeddings_npy(rel_embeddings, rel_embeddings_cnt, f'{output_dir}/relation.{output_prefix}.{pooling}.npy', 384 | f'{output_dir}/relation.glove.{pooling}.txt') 385 | 386 | create_embeddings_glove(dim=dim) 387 | 388 | 389 | if __name__ == "__main__": 390 | glove_init("../data/glove/glove.6B.200d.txt", "../data/glove/glove.200d", '../data/glove/tp_str_corpus.json') 391 | -------------------------------------------------------------------------------- /utils_biomed/preprocess_medqa_usmle.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os\n", 10 | "import json\n", 11 | "import pickle\n", 12 | "import numpy as np\n", 13 | "from tqdm import tqdm\n", 14 | "from collections import defaultdict" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "repo_root = '..'" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "## Get MedQA-USMLE dataset" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "First, download the original MedQA dataset: https://github.com/jind11/MedQA. 
\n", 38 | "Put the unzipped folder in `data/medqa_usmle/raw`" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#Prepare `statement` data following CommonsenseQA, OpenBookQA\n", 48 | "medqa_root = f'{repo_root}/data/medqa_usmle'\n", 49 | "os.system(f'mkdir -p {medqa_root}/statement')\n", 50 | "\n", 51 | "for fname in [\"train\", \"dev\", \"test\"]:\n", 52 | " with open(f\"{medqa_root}/raw/questions/US/4_options/phrases_no_exclude_{fname}.jsonl\") as f:\n", 53 | " lines = f.readlines()\n", 54 | " examples = []\n", 55 | " for i in tqdm(range(len(lines))):\n", 56 | " line = json.loads(lines[i])\n", 57 | " _id = f\"train-{i:05d}\"\n", 58 | " answerKey = line[\"answer_idx\"]\n", 59 | " stem = line[\"question\"] \n", 60 | " choices = [{\"label\": k, \"text\": line[\"options\"][k]} for k in \"ABCD\"]\n", 61 | " stmts = [{\"statement\": stem +\" \"+ c[\"text\"]} for c in choices]\n", 62 | " ex_obj = {\"id\": _id, \n", 63 | " \"question\": {\"stem\": stem, \"choices\": choices}, \n", 64 | " \"answerKey\": answerKey, \n", 65 | " \"statements\": stmts\n", 66 | " }\n", 67 | " examples.append(ex_obj)\n", 68 | " with open(f\"{medqa_root}/statement/{fname}.statement.jsonl\", 'w') as fout:\n", 69 | " for dic in examples:\n", 70 | " print (json.dumps(dic), file=fout)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "## Link entities to KG" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "First, install the scispacy model:\n", 85 | "```\n", 86 | "pip install scispacy==0.3.0\n", 87 | "pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.3.0/en_core_sci_sm-0.3.0.tar.gz\n", 88 | "```" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": {}, 95 | "outputs": [], 96 | "source": [ 97 | "#Load scispacy entity linker\n", 98 | "import spacy\n", 99 | "import scispacy\n", 100 | "from scispacy.linking import EntityLinker\n", 101 | "\n", 102 | "def load_entity_linker(threshold=0.90):\n", 103 | " nlp = spacy.load(\"en_core_sci_sm\")\n", 104 | " linker = EntityLinker(\n", 105 | " resolve_abbreviations=True,\n", 106 | " name=\"umls\",\n", 107 | " threshold=threshold)\n", 108 | " nlp.add_pipe(linker)\n", 109 | " return nlp, linker\n", 110 | "\n", 111 | "nlp, linker = load_entity_linker()" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "def entity_linking_to_umls(sentence, nlp, linker):\n", 121 | " doc = nlp(sentence)\n", 122 | " entities = doc.ents\n", 123 | " all_entities_results = []\n", 124 | " for mm in range(len(entities)):\n", 125 | " entity_text = entities[mm].text\n", 126 | " entity_start = entities[mm].start\n", 127 | " entity_end = entities[mm].end\n", 128 | " all_linked_entities = entities[mm]._.kb_ents\n", 129 | " all_entity_results = []\n", 130 | " for ii in range(len(all_linked_entities)):\n", 131 | " curr_concept_id = all_linked_entities[ii][0]\n", 132 | " curr_score = all_linked_entities[ii][1]\n", 133 | " curr_scispacy_entity = linker.kb.cui_to_entity[all_linked_entities[ii][0]]\n", 134 | " curr_canonical_name = curr_scispacy_entity.canonical_name\n", 135 | " curr_TUIs = curr_scispacy_entity.types\n", 136 | " curr_entity_result = {\"Canonical Name\": curr_canonical_name, \"Concept ID\": curr_concept_id,\n", 137 | " \"TUIs\": curr_TUIs, \"Score\": curr_score}\n", 
138 | " all_entity_results.append(curr_entity_result)\n", 139 | " curr_entities_result = {\"text\": entity_text, \"start\": entity_start, \"end\": entity_end, \n", 140 | " \"start_char\": entities[mm].start_char, \"end_char\": entities[mm].end_char,\n", 141 | " \"linking_results\": all_entity_results}\n", 142 | " all_entities_results.append(curr_entities_result)\n", 143 | " return all_entities_results" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "#Example\n", 153 | "sent = \"A 5-year-old girl is brought to the emergency department by her mother because of multiple episodes of nausea and vomiting that last about 2 hours. During this period, she has had 6–8 episodes of bilious vomiting and abdominal pain. The vomiting was preceded by fatigue.\"\n", 154 | "ent_link_results = entity_linking_to_umls(sent, nlp, linker)\n", 155 | "ent_link_results" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": {}, 162 | "outputs": [], 163 | "source": [ 164 | "#Run entity linking to UMLS for all questions\n", 165 | "def process(input):\n", 166 | " nlp, linker = load_entity_linker()\n", 167 | " stmts = input\n", 168 | " for stmt in tqdm(stmts):\n", 169 | " stem = stmt['question']['stem']\n", 170 | " stem = stem[:3500]\n", 171 | " stmt['question']['stem_ents'] = entity_linking_to_umls(stem, nlp, linker)\n", 172 | " for ii, choice in enumerate(stmt['question']['choices']):\n", 173 | " text = stmt['question']['choices'][ii]['text']\n", 174 | " stmt['question']['choices'][ii]['text_ents'] = entity_linking_to_umls(text, nlp, linker)\n", 175 | " return stmts\n", 176 | "\n", 177 | "for fname in [\"dev\", \"test\", \"train\"]:\n", 178 | " with open(f\"{medqa_root}/statement/{fname}.statement.jsonl\") as fin:\n", 179 | " stmts = [json.loads(line) for line in fin]\n", 180 | " res = process(stmts) \n", 181 | " with open(f\"{medqa_root}/statement/{fname}.statement.umls_linked.jsonl\", 'w') as fout:\n", 182 | " for dic in res:\n", 183 | " print (json.dumps(dic), file=fout)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "#Convert UMLS entity linking to DDB entity linking (our KG)\n", 193 | "umls_to_ddb = {}\n", 194 | "with open(f'{repo_root}/data/ddb/ddb_to_umls_cui.txt') as f:\n", 195 | " for line in f.readlines()[1:]:\n", 196 | " elms = line.split(\"\\t\")\n", 197 | " umls_to_ddb[elms[2]] = elms[1]\n", 198 | "\n", 199 | "def map_to_ddb(ent_obj):\n", 200 | " res = []\n", 201 | " for ent_cand in ent_obj['linking_results']:\n", 202 | " CUI = ent_cand['Concept ID']\n", 203 | " name = ent_cand['Canonical Name']\n", 204 | " if CUI in umls_to_ddb:\n", 205 | " ddb_cid = umls_to_ddb[CUI]\n", 206 | " res.append((ddb_cid, name))\n", 207 | " return res\n", 208 | "\n", 209 | "def process(fname):\n", 210 | " with open(f\"{medqa_root}/statement/{fname}.statement.umls_linked.jsonl\") as fin:\n", 211 | " stmts = [json.loads(line) for line in fin]\n", 212 | " with open(f\"{medqa_root}/grounded/{fname}.grounded.jsonl\", 'w') as fout:\n", 213 | " for stmt in tqdm(stmts):\n", 214 | " sent = stmt['question']['stem']\n", 215 | " qc = []\n", 216 | " qc_names = []\n", 217 | " for ent_obj in stmt['question']['stem_ents']:\n", 218 | " res = map_to_ddb(ent_obj)\n", 219 | " for elm in res:\n", 220 | " ddb_cid, name = elm\n", 221 | " qc.append(ddb_cid)\n", 222 | " qc_names.append(name)\n", 223 | 
" for cid, choice in enumerate(stmt['question']['choices']):\n", 224 | " ans = choice['text']\n", 225 | " ac = []\n", 226 | " ac_names = []\n", 227 | " for ent_obj in choice['text_ents']:\n", 228 | " res = map_to_ddb(ent_obj)\n", 229 | " for elm in res:\n", 230 | " ddb_cid, name = elm\n", 231 | " ac.append(ddb_cid)\n", 232 | " ac_names.append(name)\n", 233 | " out = {'sent': sent, 'ans': ans, 'qc': qc, 'qc_names': qc_names, 'ac': ac, 'ac_names': ac_names}\n", 234 | " print (json.dumps(out), file=fout) \n", 235 | "\n", 236 | "os.system(f'mkdir -p {medqa_root}/grounded')\n", 237 | "for fname in [\"dev\", \"test\", \"train\"]:\n", 238 | " process(fname) " 239 | ] 240 | }, 241 | { 242 | "cell_type": "markdown", 243 | "metadata": {}, 244 | "source": [ 245 | "## Load knowledge graph (KG)" 246 | ] 247 | }, 248 | { 249 | "cell_type": "markdown", 250 | "metadata": {}, 251 | "source": [ 252 | "Load our KG, which is based on Disease Database + DrugBank." 253 | ] 254 | }, 255 | { 256 | "cell_type": "code", 257 | "execution_count": null, 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "def load_ddb(): \n", 262 | " with open(f'{repo_root}/data/ddb/ddb_names.json') as f:\n", 263 | " all_names = json.load(f)\n", 264 | " with open(f'{repo_root}/data/ddb/ddb_relas.json') as f:\n", 265 | " all_relas = json.load(f)\n", 266 | " relas_lst = []\n", 267 | " for key, val in all_relas.items():\n", 268 | " relas_lst.append(val)\n", 269 | " \n", 270 | " ddb_ptr_to_preferred_name = {}\n", 271 | " ddb_ptr_to_name = defaultdict(list)\n", 272 | " ddb_name_to_ptr = {}\n", 273 | " for key, val in all_names.items():\n", 274 | " item_name = key\n", 275 | " item_ptr = val[0]\n", 276 | " item_preferred = val[1]\n", 277 | " if item_preferred == \"1\":\n", 278 | " ddb_ptr_to_preferred_name[item_ptr] = item_name\n", 279 | " ddb_name_to_ptr[item_name] = item_ptr\n", 280 | " ddb_ptr_to_name[item_ptr].append(item_name)\n", 281 | " \n", 282 | " return (relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name)\n", 283 | "\n", 284 | "\n", 285 | "relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name = load_ddb()\n", 286 | "\n", 287 | "\n", 288 | "ddb_ptr_lst, ddb_names_lst = [], []\n", 289 | "for key, val in ddb_ptr_to_preferred_name.items():\n", 290 | " ddb_ptr_lst.append(key)\n", 291 | " ddb_names_lst.append(val)\n", 292 | "\n", 293 | "with open(f\"{repo_root}/data/ddb/vocab.txt\", \"w\") as fout:\n", 294 | " for ddb_name in ddb_names_lst:\n", 295 | " print (ddb_name, file=fout)\n", 296 | "\n", 297 | "with open(f\"{repo_root}/data/ddb/ptrs.txt\", \"w\") as fout:\n", 298 | " for ddb_ptr in ddb_ptr_lst:\n", 299 | " print (ddb_ptr, file=fout)\n", 300 | "\n", 301 | "id2concept = ddb_ptr_lst" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": null, 307 | "metadata": {}, 308 | "outputs": [], 309 | "source": [ 310 | "len(ddb_ptr_to_name), len(ddb_ptr_to_preferred_name), len(ddb_name_to_ptr)" 311 | ] 312 | }, 313 | { 314 | "cell_type": "code", 315 | "execution_count": null, 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "ddb_name_to_ptr['Ethanol'], ddb_name_to_ptr['Serine']" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": null, 325 | "metadata": {}, 326 | "outputs": [], 327 | "source": [ 328 | "merged_relations = [\n", 329 | " 'belongs_to_the_category_of',\n", 330 | " 'is_a_category',\n", 331 | " 'may_cause',\n", 332 | " 'is_a_subtype_of',\n", 333 | " 'is_a_risk_factor_of',\n", 334 | " 'is_associated_with',\n", 335 | " 
'may_contraindicate',\n", 336 | " 'interacts_with',\n", 337 | " 'belongs_to_the_drug_family_of',\n", 338 | " 'belongs_to_drug_super-family',\n", 339 | " 'is_a_vector_for',\n", 340 | " 'may_be_allelic_with',\n", 341 | " 'see_also',\n", 342 | " 'is_an_ingradient_of',\n", 343 | " 'may_treat'\n", 344 | "]\n", 345 | "\n", 346 | "relas_dict = {\"0\": 0, \"1\": 1, \"2\": 2, \"3\": 3, \"4\": 4, \"6\": 5, \"10\": 6, \"12\": 7, \"16\": 8, \"17\": 9, \"18\": 10,\n", 347 | " \"20\": 11, \"26\": 12, \"30\": 13, \"233\": 14}" 348 | ] 349 | }, 350 | { 351 | "cell_type": "code", 352 | "execution_count": null, 353 | "metadata": {}, 354 | "outputs": [], 355 | "source": [ 356 | "import networkx as nx\n", 357 | "\n", 358 | "def construct_graph():\n", 359 | " concept2id = {w: i for i, w in enumerate(id2concept)}\n", 360 | " id2relation = merged_relations\n", 361 | " relation2id = {r: i for i, r in enumerate(id2relation)}\n", 362 | " graph = nx.MultiDiGraph()\n", 363 | " attrs = set()\n", 364 | " for relation in relas_lst:\n", 365 | " subj = concept2id[relation[0]]\n", 366 | " obj = concept2id[relation[1]]\n", 367 | " rel = relas_dict[relation[2]]\n", 368 | " weight = 1.\n", 369 | " graph.add_edge(subj, obj, rel=rel, weight=weight)\n", 370 | " attrs.add((subj, obj, rel))\n", 371 | " graph.add_edge(obj, subj, rel=rel + len(relation2id), weight=weight)\n", 372 | " attrs.add((obj, subj, rel + len(relation2id)))\n", 373 | " output_path = f\"{repo_root}/data/ddb/ddb.graph\"\n", 374 | " nx.write_gpickle(graph, output_path)\n", 375 | " return concept2id, id2relation, relation2id, graph\n", 376 | "\n", 377 | "concept2id, id2relation, relation2id, KG = construct_graph()" 378 | ] 379 | }, 380 | { 381 | "cell_type": "markdown", 382 | "metadata": {}, 383 | "source": [ 384 | "## Get KG subgraph" 385 | ] 386 | }, 387 | { 388 | "cell_type": "markdown", 389 | "metadata": {}, 390 | "source": [ 391 | "We get KG subgraph for each question, following the method used for CommonsenseQA + ConceptNet." 
392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "metadata": {}, 398 | "outputs": [], 399 | "source": [ 400 | "def load_kg():\n", 401 | " global cpnet, cpnet_simple\n", 402 | " cpnet = KG\n", 403 | " cpnet_simple = nx.Graph()\n", 404 | " for u, v, data in cpnet.edges(data=True):\n", 405 | " w = data['weight'] if 'weight' in data else 1.0\n", 406 | " if cpnet_simple.has_edge(u, v):\n", 407 | " cpnet_simple[u][v]['weight'] += w\n", 408 | " else:\n", 409 | " cpnet_simple.add_edge(u, v, weight=w)\n", 410 | "\n", 411 | "load_kg()" 412 | ] 413 | }, 414 | { 415 | "cell_type": "code", 416 | "execution_count": null, 417 | "metadata": {}, 418 | "outputs": [], 419 | "source": [ 420 | "from scipy.sparse import csr_matrix, coo_matrix\n", 421 | "from multiprocessing import Pool\n", 422 | "\n", 423 | "def concepts2adj(node_ids):\n", 424 | " global id2relation\n", 425 | " cids = np.array(node_ids, dtype=np.int32)\n", 426 | " n_rel = len(id2relation)\n", 427 | " n_node = cids.shape[0]\n", 428 | " adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)\n", 429 | " for s in range(n_node):\n", 430 | " for t in range(n_node):\n", 431 | " s_c, t_c = cids[s], cids[t]\n", 432 | " if cpnet.has_edge(s_c, t_c):\n", 433 | " for e_attr in cpnet[s_c][t_c].values():\n", 434 | " if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:\n", 435 | " adj[e_attr['rel']][s][t] = 1\n", 436 | " adj = coo_matrix(adj.reshape(-1, n_node))\n", 437 | " return adj, cids\n", 438 | "\n", 439 | "def concepts_to_adj_matrices_2hop_all_pair(data):\n", 440 | " qc_ids, ac_ids = data\n", 441 | " qa_nodes = set(qc_ids) | set(ac_ids)\n", 442 | " extra_nodes = set()\n", 443 | " for qid in qa_nodes:\n", 444 | " for aid in qa_nodes:\n", 445 | " if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:\n", 446 | " extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])\n", 447 | " extra_nodes = extra_nodes - qa_nodes\n", 448 | " schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)\n", 449 | " arange = np.arange(len(schema_graph))\n", 450 | " qmask = arange < len(qc_ids)\n", 451 | " amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))\n", 452 | " adj, concepts = concepts2adj(schema_graph)\n", 453 | " return {'adj': adj, 'concepts': concepts, 'qmask': qmask, 'amask': amask, 'cid2score': None}" 454 | ] 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": null, 459 | "metadata": {}, 460 | "outputs": [], 461 | "source": [ 462 | "def generate_adj_data_from_grounded_concepts(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):\n", 463 | " global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet\n", 464 | "\n", 465 | " qa_data = []\n", 466 | " with open(grounded_path, 'r', encoding='utf-8') as fin:\n", 467 | " for line in fin:\n", 468 | " dic = json.loads(line)\n", 469 | " q_ids = set(concept2id[c] for c in dic['qc'])\n", 470 | " if not q_ids:\n", 471 | " q_ids = {concept2id['31770']} \n", 472 | " a_ids = set(concept2id[c] for c in dic['ac'])\n", 473 | " if not a_ids:\n", 474 | " a_ids = {concept2id['325']}\n", 475 | " q_ids = q_ids - a_ids\n", 476 | " qa_data.append((q_ids, a_ids))\n", 477 | "\n", 478 | " with Pool(num_processes) as p:\n", 479 | " res = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair, qa_data), total=len(qa_data)))\n", 480 | " \n", 481 | " lens = [len(e['concepts']) for e in res]\n", 482 | " print ('mean #nodes', int(np.mean(lens)), 'med', int(np.median(lens)), '5th', 
int(np.percentile(lens, 5)), '95th', int(np.percentile(lens, 95)))\n", 483 | "\n", 484 | "    with open(output_path, 'wb') as fout:\n", 485 | "        pickle.dump(res, fout)\n", 486 | "\n", 487 | "    print(f'adj data saved to {output_path}')\n", 488 | "    print()\n" 489 | ] 490 | }, 491 | { 492 | "cell_type": "code", 493 | "execution_count": null, 494 | "metadata": { 495 | "scrolled": true 496 | }, 497 | "outputs": [], 498 | "source": [ 499 | "os.system(f'mkdir -p {repo_root}/data/medqa_usmle/graph')\n", 500 | "\n", 501 | "for fname in [\"dev\", \"test\", \"train\"]:\n", 502 | "    grounded_path = f\"{repo_root}/data/medqa_usmle/grounded/{fname}.grounded.jsonl\"\n", 503 | "    kg_path = f\"{repo_root}/data/ddb/ddb.graph\"\n", 504 | "    kg_vocab_path = f\"{repo_root}/data/ddb/ddb_ptrs.txt\"\n", 505 | "    output_path = f\"{repo_root}/data/medqa_usmle/graph/{fname}.graph.adj.pk\"\n", 506 | "\n", 507 | "    generate_adj_data_from_grounded_concepts(grounded_path, kg_path, kg_vocab_path, output_path, 10)" 508 | ] 509 | }, 510 | { 511 | "cell_type": "markdown", 512 | "metadata": {}, 513 | "source": [ 514 | "## Get KG entity embedding" 515 | ] 516 | }, 517 | { 518 | "cell_type": "code", 519 | "execution_count": null, 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [ 523 | "import torch\n", 524 | "from transformers import AutoTokenizer, AutoModel, AutoConfig\n", 525 | "tokenizer = AutoTokenizer.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n", 526 | "bert_model = AutoModel.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n", 527 | "device = torch.device('cuda')\n", 528 | "bert_model.to(device)\n", 529 | "bert_model.eval()" 530 | ] 531 | }, 532 | { 533 | "cell_type": "code", 534 | "execution_count": null, 535 | "metadata": {}, 536 | "outputs": [], 537 | "source": [ 538 | "with open(f\"{repo_root}/data/ddb/vocab.txt\") as f:\n", 539 | "    names = [line.strip() for line in f]" 540 | ] 541 | }, 542 | { 543 | "cell_type": "code", 544 | "execution_count": null, 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "embs = []\n", 549 | "tensors = tokenizer(names, padding=True, truncation=True, return_tensors=\"pt\")\n", 550 | "with torch.no_grad():\n", 551 | "    for i, j in enumerate(tqdm(names)):\n", 552 | "        outputs = bert_model(input_ids=tensors[\"input_ids\"][i:i+1].to(device), \n", 553 | "                             attention_mask=tensors['attention_mask'][i:i+1].to(device))\n", 554 | "        out = np.array(outputs[1].squeeze().tolist()).reshape((1, -1))\n", 555 | "        embs.append(out)\n", 556 | "embs = np.concatenate(embs)\n", 557 | "np.save(f\"{repo_root}/data/ddb/ent_emb.npy\", embs)" 558 | ] 559 | }, 560 | { 561 | "cell_type": "code", 562 | "execution_count": null, 563 | "metadata": {}, 564 | "outputs": [], 565 | "source": [] 566 | } 567 | ], 568 | "metadata": { 569 | "kernelspec": { 570 | "display_name": "ct", 571 | "language": "python", 572 | "name": "ct" 573 | }, 574 | "language_info": { 575 | "codemirror_mode": { 576 | "name": "ipython", 577 | "version": 3 578 | }, 579 | "file_extension": ".py", 580 | "mimetype": "text/x-python", 581 | "name": "python", 582 | "nbconvert_exporter": "python", 583 | "pygments_lexer": "ipython3", 584 | "version": "3.7.10" 585 | } 586 | }, 587 | "nbformat": 4, 588 | "nbformat_minor": 2 589 | } 590 | -------------------------------------------------------------------------------- /qagnn.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | try: 4 | from transformers import (ConstantLRSchedule,
/qagnn.py:
--------------------------------------------------------------------------------
1 | import random
2 | 
3 | try:
4 |     from transformers import (ConstantLRSchedule, WarmupLinearSchedule, WarmupConstantSchedule)
5 | except ImportError:
6 |     from transformers import get_constant_schedule, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup
7 | 
8 | from modeling.modeling_qagnn import *
9 | from utils.optimization_utils import OPTIMIZER_CLASSES
10 | from utils.parser_utils import *
11 | 
12 | 
13 | DECODER_DEFAULT_LR = {
14 |     'csqa': 1e-3,
15 |     'obqa': 3e-4,
16 |     'medqa_usmle': 1e-3,
17 | }
18 | 
19 | from collections import defaultdict, OrderedDict
20 | import numpy as np
21 | 
22 | import socket, os, subprocess, datetime
23 | print(socket.gethostname())
24 | print ("pid:", os.getpid())
25 | print ("conda env:", os.environ.get('CONDA_DEFAULT_ENV'))
26 | print ("screen: %s" % subprocess.check_output('echo $STY', shell=True).decode('utf'))
27 | print ("gpu: %s" % subprocess.check_output('echo $CUDA_VISIBLE_DEVICES', shell=True).decode('utf'))
28 | 
29 | 
30 | def evaluate_accuracy(eval_set, model):
31 |     n_samples, n_correct = 0, 0
32 |     model.eval()
33 |     with torch.no_grad():
34 |         for qids, labels, *input_data in tqdm(eval_set):
35 |             logits, _ = model(*input_data)
36 |             n_correct += (logits.argmax(1) == labels).sum().item()
37 |             n_samples += labels.size(0)
38 |     return n_correct / n_samples
39 | 
40 | 
41 | def main():
42 |     parser = get_parser()
43 |     args, _ = parser.parse_known_args()
44 |     parser.add_argument('--mode', default='train', choices=['train', 'eval_detail'], help='run training or evaluation')
45 |     parser.add_argument('--save_dir', default=f'./saved_models/qagnn/', help='model output directory')
46 |     parser.add_argument('--save_model', dest='save_model', action='store_true')
47 |     parser.add_argument('--load_model_path', default=None)
48 | 
49 | 
50 |     # data
51 |     parser.add_argument('--num_relation', default=38, type=int, help='number of relations')
52 |     parser.add_argument('--train_adj', default=f'data/{args.dataset}/graph/train.graph.adj.pk')
53 |     parser.add_argument('--dev_adj', default=f'data/{args.dataset}/graph/dev.graph.adj.pk')
54 |     parser.add_argument('--test_adj', default=f'data/{args.dataset}/graph/test.graph.adj.pk')
55 |     parser.add_argument('--use_cache', default=True, type=bool_flag, nargs='?', const=True, help='use cached data to accelerate data loading')
56 | 
57 |     # model architecture
58 |     parser.add_argument('-k', '--k', default=5, type=int, help='perform k-layer message passing')
59 |     parser.add_argument('--att_head_num', default=2, type=int, help='number of attention heads')
60 |     parser.add_argument('--gnn_dim', default=100, type=int, help='dimension of the GNN layers')
61 |     parser.add_argument('--fc_dim', default=200, type=int, help='number of FC hidden units')
62 |     parser.add_argument('--fc_layer_num', default=0, type=int, help='number of FC layers')
63 |     parser.add_argument('--freeze_ent_emb', default=True, type=bool_flag, nargs='?', const=True, help='freeze entity embedding layer')
64 | 
65 |     parser.add_argument('--max_node_num', default=200, type=int)
66 |     parser.add_argument('--simple', default=False, type=bool_flag, nargs='?', const=True)
67 |     parser.add_argument('--subsample', default=1.0, type=float)
68 |     parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')
69 | 
70 | 
71 |     # regularization
72 |     parser.add_argument('--dropouti', type=float, default=0.2, help='dropout for embedding layer')
73 |     parser.add_argument('--dropoutg', type=float, default=0.2, help='dropout for GNN layers')
74 |     parser.add_argument('--dropoutf', type=float, default=0.2, help='dropout for fully-connected layers')
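    # ------------------------------------------------------------------
    # Hedged sketch (kept in comments so main() is unchanged): this file uses a
    # two-stage parse -- parse_known_args() above first recovers --dataset, so
    # that dataset-dependent defaults such as DECODER_DEFAULT_LR[args.dataset]
    # can be registered before the final parse_args(). Toy illustration:
    #
    #     import argparse
    #     p = argparse.ArgumentParser(add_help=False)
    #     p.add_argument('--dataset', default='csqa')
    #     known, _ = p.parse_known_args([])   # stage 1: only --dataset is known
    #     p.add_argument('-dlr', '--decoder_lr', type=float,
    #                    default=DECODER_DEFAULT_LR[known.dataset])
    #     assert p.parse_args([]).decoder_lr == 1e-3   # csqa default
    # ------------------------------------------------------------------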
75 | 
76 |     # optimization
77 |     parser.add_argument('-dlr', '--decoder_lr', default=DECODER_DEFAULT_LR[args.dataset], type=float, help='learning rate')
78 |     parser.add_argument('-mbs', '--mini_batch_size', default=1, type=int)
79 |     parser.add_argument('-ebs', '--eval_batch_size', default=2, type=int)
80 |     parser.add_argument('--unfreeze_epoch', default=4, type=int)
81 |     parser.add_argument('--refreeze_epoch', default=10000, type=int)
82 |     parser.add_argument('--fp16', default=False, type=bool_flag, help='use fp16 training. this requires torch>=1.6.0')
83 |     parser.add_argument('--drop_partial_batch', default=False, type=bool_flag, help='')
84 |     parser.add_argument('--fill_partial_batch', default=False, type=bool_flag, help='')
85 | 
86 |     parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, help='show this help message and exit')
87 |     args = parser.parse_args()
88 |     if args.simple:
89 |         parser.set_defaults(k=1)
90 |     args = parser.parse_args()
91 |     args.fp16 = args.fp16 and tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) >= (1, 6)  # torch>=1.6 is required for native amp; parse the version rather than comparing strings
92 | 
93 |     if args.mode == 'train':
94 |         train(args)
95 |     elif args.mode == 'eval_detail':
96 | 
97 |         eval_detail(args)
98 |     else:
99 |         raise ValueError('Invalid mode')
100 | 
101 | 
102 | 
103 | 
104 | def train(args):
105 |     print(args)
106 | 
107 |     random.seed(args.seed)
108 |     np.random.seed(args.seed)
109 |     torch.manual_seed(args.seed)
110 |     if torch.cuda.is_available() and args.cuda:
111 |         torch.cuda.manual_seed(args.seed)
112 | 
113 |     config_path = os.path.join(args.save_dir, 'config.json')
114 |     model_path = os.path.join(args.save_dir, 'model.pt')
115 |     log_path = os.path.join(args.save_dir, 'log.csv')
116 |     export_config(args, config_path)
117 |     check_path(model_path)
118 |     with open(log_path, 'w') as fout:
119 |         fout.write('step,dev_acc,test_acc\n')
120 | 
121 |     ###################################################################################################
122 |     #   Load data                                                                                     #
123 |     ###################################################################################################
124 |     cp_emb = [np.load(path) for path in args.ent_emb_paths]
125 |     cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)
126 | 
127 |     concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
128 |     print('| num_concepts: {} |'.format(concept_num))
129 | 
130 |     # try:
131 |     if True:
132 |         if torch.cuda.device_count() >= 2 and args.cuda:
133 |             device0 = torch.device("cuda:0")
134 |             device1 = torch.device("cuda:1")
135 |         elif torch.cuda.device_count() == 1 and args.cuda:
136 |             device0 = torch.device("cuda:0")
137 |             device1 = torch.device("cuda:0")
138 |         else:
139 |             device0 = torch.device("cpu")
140 |             device1 = torch.device("cpu")
141 |         dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj,
142 |                                       args.dev_statements, args.dev_adj,
143 |                                       args.test_statements, args.test_adj,
144 |                                       batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
145 |                                       device=(device0, device1),
146 |                                       model_name=args.encoder,
147 |                                       max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
148 |                                       is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
149 |                                       subsample=args.subsample, use_cache=args.use_cache)
150 | 
151 |     ###################################################################################################
152 |     #   Build model                                                                                   #
153 |     ###################################################################################################
154 |     print ('args.num_relation', args.num_relation)
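    # ------------------------------------------------------------------
    # Hedged sketch (comment-only): the encoder/decoder built below are placed
    # on device0/device1, and activations are handed over explicitly (see
    # sent_vecs.to(node_type_ids.device) in modeling_qagnn.py). Minimal pattern:
    #
    #     import torch, torch.nn as nn
    #     class TwoPart(nn.Module):
    #         def __init__(self, d0, d1):
    #             super().__init__()
    #             self.enc = nn.Linear(8, 8).to(d0)
    #             self.dec = nn.Linear(8, 2).to(d1)
    #             self.d1 = d1
    #         def forward(self, x):
    #             return self.dec(self.enc(x).to(self.d1))  # move the activation between devices
    #     m = TwoPart(torch.device('cpu'), torch.device('cpu'))  # cpu stands in for cuda:0/cuda:1
    #     _ = m(torch.randn(4, 8))
    # ------------------------------------------------------------------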
155 |     model = LM_QAGNN(args, args.encoder, k=args.k, n_ntype=4, n_etype=args.num_relation, n_concept=concept_num,
156 |                      concept_dim=args.gnn_dim,
157 |                      concept_in_dim=concept_dim,
158 |                      n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
159 |                      p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
160 |                      pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
161 |                      init_range=args.init_range,
162 |                      encoder_config={})
163 |     if args.load_model_path:
164 |         print (f'loading and initializing model from {args.load_model_path}')
165 |         model_state_dict, old_args = torch.load(args.load_model_path, map_location=torch.device('cpu'))
166 |         model.load_state_dict(model_state_dict)
167 | 
168 |     model.encoder.to(device0)
169 |     model.decoder.to(device1)
170 | 
171 | 
172 |     no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']  # parameter names excluded from weight decay
173 | 
174 |     grouped_parameters = [
175 |         {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
176 |         {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
177 |         {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
178 |         {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
179 |     ]
180 |     optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)
181 | 
182 |     if args.lr_schedule == 'fixed':
183 |         try:
184 |             scheduler = ConstantLRSchedule(optimizer)
185 |         except NameError:  # the old schedule classes are absent in newer transformers
186 |             scheduler = get_constant_schedule(optimizer)
187 |     elif args.lr_schedule == 'warmup_constant':
188 |         try:
189 |             scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
190 |         except NameError:
191 |             scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
192 |     elif args.lr_schedule == 'warmup_linear':
193 |         max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
194 |         try:
195 |             scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)
196 |         except NameError:
197 |             scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)
198 | 
199 |     print('parameters:')
200 |     for name, param in model.decoder.named_parameters():
201 |         if param.requires_grad:
202 |             print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device))
203 |         else:
204 |             print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device))
205 |     num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
206 |     print('\ttotal:', num_params)
207 | 
208 |     if args.loss == 'margin_rank':
209 |         loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
210 |     elif args.loss == 'cross_entropy':
211 |         loss_func = nn.CrossEntropyLoss(reduction='mean')
212 | 
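    # ------------------------------------------------------------------
    # Hedged sketch (comment-only): the grouped_parameters above exempt biases
    # and LayerNorm weights from weight decay. The same pattern on a toy module:
    #
    #     import torch, torch.nn as nn
    #     class Toy(nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.lin = nn.Linear(4, 4)
    #             self.LayerNorm = nn.LayerNorm(4)
    #     toy = Toy()
    #     nd = ['bias', 'LayerNorm.weight']
    #     groups = [
    #         {'params': [p for n, p in toy.named_parameters() if not any(x in n for x in nd)], 'weight_decay': 0.01},
    #         {'params': [p for n, p in toy.named_parameters() if any(x in n for x in nd)], 'weight_decay': 0.0},
    #     ]
    #     opt = torch.optim.AdamW(groups, lr=1e-3)  # lin.weight decays; the rest do not
    # ------------------------------------------------------------------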
213 |     def compute_loss(logits, labels):
214 |         if args.loss == 'margin_rank':
215 |             num_choice = logits.size(1)
216 |             flat_logits = logits.view(-1)
217 |             correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1)  # of length batch_size*num_choice
218 |             correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1)  # of length batch_size*(num_choice-1)
219 |             wrong_logits = flat_logits[correct_mask == 0]
220 |             y = wrong_logits.new_ones((wrong_logits.size(0),))
221 |             loss = loss_func(correct_logits, wrong_logits, y)  # margin ranking loss
222 |         elif args.loss == 'cross_entropy':
223 |             loss = loss_func(logits, labels)
224 |         return loss
225 | 
226 |     ###################################################################################################
227 |     #   Training                                                                                      #
228 |     ###################################################################################################
229 | 
230 |     print()
231 |     print('-' * 71)
232 |     if args.fp16:
233 |         print ('Using fp16 training')
234 |         scaler = torch.cuda.amp.GradScaler()
235 | 
236 |     global_step, best_dev_epoch = 0, 0
237 |     best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
238 |     start_time = time.time()
239 |     model.train()
240 |     freeze_net(model.encoder)
241 |     if True:
242 |     # try:
243 |         for epoch_id in range(args.n_epochs):
244 |             if epoch_id == args.unfreeze_epoch:
245 |                 unfreeze_net(model.encoder)
246 |             if epoch_id == args.refreeze_epoch:
247 |                 freeze_net(model.encoder)
248 |             model.train()
249 |             for qids, labels, *input_data in dataset.train():
250 |                 optimizer.zero_grad()
251 |                 bs = labels.size(0)
252 |                 for a in range(0, bs, args.mini_batch_size):
253 |                     b = min(a + args.mini_batch_size, bs)
254 |                     if args.fp16:
255 |                         with torch.cuda.amp.autocast():
256 |                             logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)
257 |                             loss = compute_loss(logits, labels[a:b])
258 |                     else:
259 |                         logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)
260 |                         loss = compute_loss(logits, labels[a:b])
261 |                     loss = loss * (b - a) / bs
262 |                     if args.fp16:
263 |                         scaler.scale(loss).backward()
264 |                     else:
265 |                         loss.backward()
266 |                     total_loss += loss.item()
267 |                 if args.max_grad_norm > 0:
268 |                     if args.fp16:
269 |                         scaler.unscale_(optimizer)
270 |                         nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
271 |                     else:
272 |                         nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
273 |                 if args.fp16:
274 |                     scaler.step(optimizer)
275 |                     scaler.update()
276 |                 else:
277 |                     optimizer.step()
278 |                 scheduler.step()  # step the LR scheduler after optimizer.step() (torch>=1.1 convention)
279 | 
280 |                 if (global_step + 1) % args.log_interval == 0:
281 |                     total_loss /= args.log_interval
282 |                     ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
283 |                     print('| step {:5} | lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
284 |                     total_loss = 0
285 |                     start_time = time.time()
286 |                 global_step += 1
287 | 
288 |             model.eval()
289 |             dev_acc = evaluate_accuracy(dataset.dev(), model)
290 |             save_test_preds = args.save_model
291 |             if not save_test_preds:
292 |                 test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
293 |             else:
294 |                 eval_set = dataset.test()
295 |                 total_acc = []
296 |                 count = 0
297 |                 preds_path = os.path.join(args.save_dir, 'test_e{}_preds.csv'.format(epoch_id))
298 |                 with open(preds_path, 'w') as f_preds:
299 |                     with torch.no_grad():
300 |                         for qids, labels, *input_data in tqdm(eval_set):
301 |                             count += 1
302 |                             logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
303 |                             predictions = logits.argmax(1) #[bsize, ]
304 |                             preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
305 |                             for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
306 |                                 acc = int(pred.item()==label.item())
307 |                                 print ('{},{}'.format(qid, chr(ord('A') + pred.item())), 
file=f_preds) 308 | f_preds.flush() 309 | total_acc.append(acc) 310 | test_acc = float(sum(total_acc))/len(total_acc) 311 | 312 | print('-' * 71) 313 | print('| epoch {:3} | step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, global_step, dev_acc, test_acc)) 314 | print('-' * 71) 315 | with open(log_path, 'a') as fout: 316 | fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc)) 317 | if dev_acc >= best_dev_acc: 318 | best_dev_acc = dev_acc 319 | final_test_acc = test_acc 320 | best_dev_epoch = epoch_id 321 | if args.save_model: 322 | torch.save([model.state_dict(), args], f"{model_path}.{epoch_id}") 323 | # with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f: 324 | # for p in model.named_parameters(): 325 | # print (p, file=f) 326 | print(f'model saved to {model_path}.{epoch_id}') 327 | else: 328 | if args.save_model: 329 | torch.save([model.state_dict(), args], f"{model_path}.{epoch_id}") 330 | # with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f: 331 | # for p in model.named_parameters(): 332 | # print (p, file=f) 333 | print(f'model saved to {model_path}.{epoch_id}') 334 | model.train() 335 | start_time = time.time() 336 | if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop: 337 | break 338 | # except (KeyboardInterrupt, RuntimeError) as e: 339 | # print(e) 340 | 341 | 342 | 343 | def eval_detail(args): 344 | assert args.load_model_path is not None 345 | model_path = args.load_model_path 346 | 347 | cp_emb = [np.load(path) for path in args.ent_emb_paths] 348 | cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float) 349 | concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1) 350 | print('| num_concepts: {} |'.format(concept_num)) 351 | 352 | model_state_dict, old_args = torch.load(model_path, map_location=torch.device('cpu')) 353 | model = LM_QAGNN(old_args, old_args.encoder, k=old_args.k, n_ntype=4, n_etype=old_args.num_relation, n_concept=concept_num, 354 | concept_dim=old_args.gnn_dim, 355 | concept_in_dim=concept_dim, 356 | n_attention_head=old_args.att_head_num, fc_dim=old_args.fc_dim, n_fc_layer=old_args.fc_layer_num, 357 | p_emb=old_args.dropouti, p_gnn=old_args.dropoutg, p_fc=old_args.dropoutf, 358 | pretrained_concept_emb=cp_emb, freeze_ent_emb=old_args.freeze_ent_emb, 359 | init_range=old_args.init_range, 360 | encoder_config={}) 361 | model.load_state_dict(model_state_dict) 362 | 363 | if torch.cuda.device_count() >= 2 and args.cuda: 364 | device0 = torch.device("cuda:0") 365 | device1 = torch.device("cuda:1") 366 | elif torch.cuda.device_count() == 1 and args.cuda: 367 | device0 = torch.device("cuda:0") 368 | device1 = torch.device("cuda:0") 369 | else: 370 | device0 = torch.device("cpu") 371 | device1 = torch.device("cpu") 372 | model.encoder.to(device0) 373 | model.decoder.to(device1) 374 | model.eval() 375 | 376 | statement_dic = {} 377 | for statement_path in (args.train_statements, args.dev_statements, args.test_statements): 378 | statement_dic.update(load_statement_dict(statement_path)) 379 | 380 | use_contextualized = 'lm' in old_args.ent_emb 381 | 382 | print ('inhouse?', args.inhouse) 383 | 384 | print ('args.train_statements', args.train_statements) 385 | print ('args.dev_statements', args.dev_statements) 386 | print ('args.test_statements', args.test_statements) 387 | print ('args.train_adj', args.train_adj) 388 | print ('args.dev_adj', args.dev_adj) 389 | print ('args.test_adj', args.test_adj) 390 | 391 | dataset = LM_QAGNN_DataLoader(args, 
args.train_statements, args.train_adj, 392 | args.dev_statements, args.dev_adj, 393 | args.test_statements, args.test_adj, 394 | batch_size=args.batch_size, eval_batch_size=args.eval_batch_size, 395 | device=(device0, device1), 396 | model_name=old_args.encoder, 397 | max_node_num=old_args.max_node_num, max_seq_length=old_args.max_seq_len, 398 | is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids, 399 | subsample=args.subsample, use_cache=args.use_cache) 400 | 401 | save_test_preds = args.save_model 402 | dev_acc = evaluate_accuracy(dataset.dev(), model) 403 | print('dev_acc {:7.4f}'.format(dev_acc)) 404 | if not save_test_preds: 405 | test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0 406 | else: 407 | eval_set = dataset.test() 408 | total_acc = [] 409 | count = 0 410 | dt = datetime.datetime.today().strftime('%Y%m%d%H%M%S') 411 | preds_path = os.path.join(args.save_dir, 'test_preds_{}.csv'.format(dt)) 412 | with open(preds_path, 'w') as f_preds: 413 | with torch.no_grad(): 414 | for qids, labels, *input_data in tqdm(eval_set): 415 | count += 1 416 | logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True) 417 | predictions = logits.argmax(1) #[bsize, ] 418 | preds_ranked = (-logits).argsort(1) #[bsize, n_choices] 419 | for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)): 420 | acc = int(pred.item()==label.item()) 421 | print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds) 422 | f_preds.flush() 423 | total_acc.append(acc) 424 | test_acc = float(sum(total_acc))/len(total_acc) 425 | 426 | print('-' * 71) 427 | print('test_acc {:7.4f}'.format(test_acc)) 428 | print('-' * 71) 429 | 430 | 431 | 432 | if __name__ == '__main__': 433 | main() 434 | -------------------------------------------------------------------------------- /utils/graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import networkx as nx 3 | import itertools 4 | import json 5 | from tqdm import tqdm 6 | from .conceptnet import merged_relations 7 | import numpy as np 8 | from scipy import sparse 9 | import pickle 10 | from scipy.sparse import csr_matrix, coo_matrix 11 | from multiprocessing import Pool 12 | from collections import OrderedDict 13 | 14 | 15 | from .maths import * 16 | 17 | __all__ = ['generate_graph'] 18 | 19 | concept2id = None 20 | id2concept = None 21 | relation2id = None 22 | id2relation = None 23 | 24 | cpnet = None 25 | cpnet_all = None 26 | cpnet_simple = None 27 | 28 | 29 | def load_resources(cpnet_vocab_path): 30 | global concept2id, id2concept, relation2id, id2relation 31 | 32 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin: 33 | id2concept = [w.strip() for w in fin] 34 | concept2id = {w: i for i, w in enumerate(id2concept)} 35 | 36 | id2relation = merged_relations 37 | relation2id = {r: i for i, r in enumerate(id2relation)} 38 | 39 | 40 | def load_cpnet(cpnet_graph_path): 41 | global cpnet, cpnet_simple 42 | cpnet = nx.read_gpickle(cpnet_graph_path) 43 | cpnet_simple = nx.Graph() 44 | for u, v, data in cpnet.edges(data=True): 45 | w = data['weight'] if 'weight' in data else 1.0 46 | if cpnet_simple.has_edge(u, v): 47 | cpnet_simple[u][v]['weight'] += w 48 | else: 49 | cpnet_simple.add_edge(u, v, weight=w) 50 | 51 | 52 | def relational_graph_generation(qcs, acs, paths, rels): 53 | raise 
NotImplementedError() # TODO
54 | 
55 | 
56 | # plain graph generation
57 | def plain_graph_generation(qcs, acs, paths, rels):
58 |     global cpnet, concept2id, relation2id, id2relation, id2concept, cpnet_simple
59 | 
60 |     graph = nx.Graph()
61 |     for p in paths:
62 |         for c_index in range(len(p) - 1):
63 |             h = p[c_index]
64 |             t = p[c_index + 1]
65 |             # TODO: the weight can be computed from the concept embeddings and relation embeddings of TransE
66 |             graph.add_edge(h, t, weight=1.0)
67 | 
68 |     for qc1, qc2 in list(itertools.combinations(qcs, 2)):
69 |         if cpnet_simple.has_edge(qc1, qc2):
70 |             graph.add_edge(qc1, qc2, weight=1.0)
71 | 
72 |     for ac1, ac2 in list(itertools.combinations(acs, 2)):
73 |         if cpnet_simple.has_edge(ac1, ac2):
74 |             graph.add_edge(ac1, ac2, weight=1.0)
75 | 
76 |     if len(qcs) == 0:
77 |         qcs.append(-1)
78 | 
79 |     if len(acs) == 0:
80 |         acs.append(-1)
81 | 
82 |     if len(paths) == 0:
83 |         for qc in qcs:
84 |             for ac in acs:
85 |                 graph.add_edge(qc, ac, rel=-1, weight=0.1)
86 | 
87 |     g = nx.convert_node_labels_to_integers(graph, label_attribute='cid')  # re-index
88 |     return nx.node_link_data(g)
89 | 
90 | 
91 | def generate_adj_matrix_per_inst(nxg_str):
92 |     global id2relation
93 |     n_rel = len(id2relation)
94 | 
95 |     nxg = nx.node_link_graph(json.loads(nxg_str))
96 |     n_node = len(nxg.nodes)
97 |     cids = np.zeros(n_node, dtype=np.int32)
98 |     for node_id, node_attr in nxg.nodes(data=True):
99 |         cids[node_id] = node_attr['cid']
100 | 
101 |     adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)  # one n_node x n_node slice per relation
102 |     for s in range(n_node):
103 |         for t in range(n_node):
104 |             s_c, t_c = cids[s], cids[t]
105 |             if cpnet_all.has_edge(s_c, t_c):
106 |                 for e_attr in cpnet_all[s_c][t_c].values():
107 |                     if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:
108 |                         adj[e_attr['rel']][s][t] = 1
109 |     cids += 1
110 |     adj = coo_matrix(adj.reshape(-1, n_node))  # flatten (n_rel, n_node, n_node) -> (n_rel*n_node, n_node)
111 |     return (adj, cids)
112 | 
113 | 
114 | def concepts2adj(node_ids):
115 |     global id2relation
116 |     cids = np.array(node_ids, dtype=np.int32)
117 |     n_rel = len(id2relation)
118 |     n_node = cids.shape[0]
119 |     adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)
120 |     for s in range(n_node):
121 |         for t in range(n_node):
122 |             s_c, t_c = cids[s], cids[t]
123 |             if cpnet.has_edge(s_c, t_c):
124 |                 for e_attr in cpnet[s_c][t_c].values():
125 |                     if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:
126 |                         adj[e_attr['rel']][s][t] = 1
127 |     # cids += 1 # note!!! 
index 0 is reserved for padding 128 | adj = coo_matrix(adj.reshape(-1, n_node)) 129 | return adj, cids 130 | 131 | 132 | def concepts_to_adj_matrices_1hop_neighbours(data): 133 | qc_ids, ac_ids = data 134 | qa_nodes = set(qc_ids) | set(ac_ids) 135 | extra_nodes = set() 136 | for u in set(qc_ids) | set(ac_ids): 137 | if u in cpnet.nodes: 138 | extra_nodes |= set(cpnet[u]) 139 | extra_nodes = extra_nodes - qa_nodes 140 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes) 141 | arange = np.arange(len(schema_graph)) 142 | qmask = arange < len(qc_ids) 143 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 144 | adj, concepts = concepts2adj(schema_graph) 145 | return adj, concepts, qmask, amask 146 | 147 | 148 | def concepts_to_adj_matrices_1hop_neighbours_without_relatedto(data): 149 | qc_ids, ac_ids = data 150 | qa_nodes = set(qc_ids) | set(ac_ids) 151 | extra_nodes = set() 152 | for u in set(qc_ids) | set(ac_ids): 153 | if u in cpnet.nodes: 154 | for v in cpnet[u]: 155 | for data in cpnet[u][v].values(): 156 | if data['rel'] not in (15, 32): 157 | extra_nodes.add(v) 158 | extra_nodes = extra_nodes - qa_nodes 159 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes) 160 | arange = np.arange(len(schema_graph)) 161 | qmask = arange < len(qc_ids) 162 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 163 | adj, concepts = concepts2adj(schema_graph) 164 | return adj, concepts, qmask, amask 165 | 166 | 167 | def concepts_to_adj_matrices_2hop_qa_pair(data): 168 | qc_ids, ac_ids = data 169 | qa_nodes = set(qc_ids) | set(ac_ids) 170 | extra_nodes = set() 171 | for qid in qc_ids: 172 | for aid in ac_ids: 173 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 174 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid]) 175 | extra_nodes = extra_nodes - qa_nodes 176 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes) 177 | arange = np.arange(len(schema_graph)) 178 | qmask = arange < len(qc_ids) 179 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 180 | adj, concepts = concepts2adj(schema_graph) 181 | return adj, concepts, qmask, amask 182 | 183 | 184 | def concepts_to_adj_matrices_2hop_all_pair(data): 185 | qc_ids, ac_ids = data 186 | qa_nodes = set(qc_ids) | set(ac_ids) 187 | extra_nodes = set() 188 | for qid in qa_nodes: 189 | for aid in qa_nodes: 190 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 191 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid]) 192 | extra_nodes = extra_nodes - qa_nodes 193 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes) 194 | arange = np.arange(len(schema_graph)) 195 | qmask = arange < len(qc_ids) 196 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 197 | adj, concepts = concepts2adj(schema_graph) 198 | return adj, concepts, qmask, amask 199 | 200 | 201 | def concepts_to_adj_matrices_2step_relax_all_pair(data): 202 | qc_ids, ac_ids = data 203 | qa_nodes = set(qc_ids) | set(ac_ids) 204 | extra_nodes = set() 205 | for qid in qc_ids: 206 | for aid in ac_ids: 207 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 208 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid]) 209 | intermediate_ids = extra_nodes - qa_nodes 210 | for qid in intermediate_ids: 211 | for aid in ac_ids: 212 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 213 | extra_nodes |= 
set(cpnet_simple[qid]) & set(cpnet_simple[aid]) 214 | for qid in qc_ids: 215 | for aid in intermediate_ids: 216 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 217 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid]) 218 | extra_nodes = extra_nodes - qa_nodes 219 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes) 220 | arange = np.arange(len(schema_graph)) 221 | qmask = arange < len(qc_ids) 222 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 223 | adj, concepts = concepts2adj(schema_graph) 224 | return adj, concepts, qmask, amask 225 | 226 | 227 | def concepts_to_adj_matrices_3hop_qa_pair(data): 228 | qc_ids, ac_ids = data 229 | qa_nodes = set(qc_ids) | set(ac_ids) 230 | extra_nodes = set() 231 | for qid in qc_ids: 232 | for aid in ac_ids: 233 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 234 | for u in cpnet_simple[qid]: 235 | for v in cpnet_simple[aid]: 236 | if cpnet_simple.has_edge(u, v): # ac is a 3-hop neighbour of qc 237 | extra_nodes.add(u) 238 | extra_nodes.add(v) 239 | if u == v: # ac is a 2-hop neighbour of qc 240 | extra_nodes.add(u) 241 | extra_nodes = extra_nodes - qa_nodes 242 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes) 243 | arange = np.arange(len(schema_graph)) 244 | qmask = arange < len(qc_ids) 245 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 246 | adj, concepts = concepts2adj(schema_graph) 247 | return adj, concepts, qmask, amask 248 | 249 | 250 | 251 | ###################################################################### 252 | from transformers import RobertaTokenizer, RobertaForMaskedLM 253 | 254 | class RobertaForMaskedLMwithLoss(RobertaForMaskedLM): 255 | # 256 | def __init__(self, config): 257 | super().__init__(config) 258 | # 259 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None): 260 | # 261 | assert attention_mask is not None 262 | outputs = self.roberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask) 263 | sequence_output = outputs[0] #hidden_states of final layer (batch_size, sequence_length, hidden_size) 264 | prediction_scores = self.lm_head(sequence_output) 265 | outputs = (prediction_scores, sequence_output) + outputs[2:] 266 | if masked_lm_labels is not None: 267 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none') 268 | bsize, seqlen = input_ids.size() 269 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)).view(bsize, seqlen) 270 | masked_lm_loss = (masked_lm_loss * attention_mask).sum(dim=1) 271 | outputs = (masked_lm_loss,) + outputs 272 | # (masked_lm_loss), prediction_scores, sequence_output, (hidden_states), (attentions) 273 | return outputs 274 | 275 | print ('loading pre-trained LM...') 276 | TOKENIZER = RobertaTokenizer.from_pretrained('roberta-large') 277 | LM_MODEL = RobertaForMaskedLMwithLoss.from_pretrained('roberta-large') 278 | LM_MODEL.cuda(); LM_MODEL.eval() 279 | print ('loading done') 280 | 281 | def get_LM_score(cids, question): 282 | cids = cids[:] 283 | cids.insert(0, -1) #QAcontext node 284 | sents, scores = [], [] 285 | for cid in cids: 286 | if cid==-1: 287 | sent = question.lower() 288 | else: 289 | sent = '{} {}.'.format(question.lower(), ' '.join(id2concept[cid].split('_'))) 290 | sent = TOKENIZER.encode(sent, add_special_tokens=True) 291 | 
sents.append(sent) 292 | n_cids = len(cids) 293 | cur_idx = 0 294 | batch_size = 50 295 | while cur_idx < n_cids: 296 | #Prepare batch 297 | input_ids = sents[cur_idx: cur_idx+batch_size] 298 | max_len = max([len(seq) for seq in input_ids]) 299 | for j, seq in enumerate(input_ids): 300 | seq += [TOKENIZER.pad_token_id] * (max_len-len(seq)) 301 | input_ids[j] = seq 302 | input_ids = torch.tensor(input_ids).cuda() #[B, seqlen] 303 | mask = (input_ids!=1).long() #[B, seq_len] 304 | #Get LM score 305 | with torch.no_grad(): 306 | outputs = LM_MODEL(input_ids, attention_mask=mask, masked_lm_labels=input_ids) 307 | loss = outputs[0] #[B, ] 308 | _scores = list(-loss.detach().cpu().numpy()) #list of float 309 | scores += _scores 310 | cur_idx += batch_size 311 | assert len(sents) == len(scores) == len(cids) 312 | cid2score = OrderedDict(sorted(list(zip(cids, scores)), key=lambda x: -x[1])) #score: from high to low 313 | return cid2score 314 | 315 | def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part1(data): 316 | qc_ids, ac_ids, question = data 317 | qa_nodes = set(qc_ids) | set(ac_ids) 318 | extra_nodes = set() 319 | for qid in qa_nodes: 320 | for aid in qa_nodes: 321 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes: 322 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid]) 323 | extra_nodes = extra_nodes - qa_nodes 324 | return (sorted(qc_ids), sorted(ac_ids), question, sorted(extra_nodes)) 325 | 326 | def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part2(data): 327 | qc_ids, ac_ids, question, extra_nodes = data 328 | cid2score = get_LM_score(qc_ids+ac_ids+extra_nodes, question) 329 | return (qc_ids, ac_ids, question, extra_nodes, cid2score) 330 | 331 | def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3(data): 332 | qc_ids, ac_ids, question, extra_nodes, cid2score = data 333 | schema_graph = qc_ids + ac_ids + sorted(extra_nodes, key=lambda x: -cid2score[x]) #score: from high to low 334 | arange = np.arange(len(schema_graph)) 335 | qmask = arange < len(qc_ids) 336 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids))) 337 | adj, concepts = concepts2adj(schema_graph) 338 | return {'adj': adj, 'concepts': concepts, 'qmask': qmask, 'amask': amask, 'cid2score': cid2score} 339 | 340 | ################################################################################ 341 | 342 | 343 | 344 | ##################################################################################################### 345 | # functions below this line will be called by preprocess.py # 346 | ##################################################################################################### 347 | 348 | 349 | def generate_graph(grounded_path, pruned_paths_path, cpnet_vocab_path, cpnet_graph_path, output_path): 350 | print(f'generating schema graphs for {grounded_path} and {pruned_paths_path}...') 351 | 352 | global concept2id, id2concept, relation2id, id2relation 353 | if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]): 354 | load_resources(cpnet_vocab_path) 355 | 356 | global cpnet, cpnet_simple 357 | if cpnet is None or cpnet_simple is None: 358 | load_cpnet(cpnet_graph_path) 359 | 360 | nrow = sum(1 for _ in open(grounded_path, 'r')) 361 | with open(grounded_path, 'r') as fin_gr, \ 362 | open(pruned_paths_path, 'r') as fin_pf, \ 363 | open(output_path, 'w') as fout: 364 | for line_gr, line_pf in tqdm(zip(fin_gr, fin_pf), total=nrow): 365 | mcp = json.loads(line_gr) 366 | qa_pairs = json.loads(line_pf) 367 | 368 | statement_paths 
= []
369 |         statement_rel_list = []
370 |         for qas in qa_pairs:
371 |             if qas["pf_res"] is None:
372 |                 cur_paths = []
373 |                 cur_rels = []
374 |             else:
375 |                 cur_paths = [item["path"] for item in qas["pf_res"]]
376 |                 cur_rels = [item["rel"] for item in qas["pf_res"]]
377 |             statement_paths.extend(cur_paths)
378 |             statement_rel_list.extend(cur_rels)
379 | 
380 |         qcs = [concept2id[c] for c in mcp["qc"]]
381 |         acs = [concept2id[c] for c in mcp["ac"]]
382 | 
383 |         gobj = plain_graph_generation(qcs=qcs, acs=acs,
384 |                                       paths=statement_paths,
385 |                                       rels=statement_rel_list)
386 |         fout.write(json.dumps(gobj) + '\n')
387 | 
388 |     print(f'schema graphs saved to {output_path}')
389 |     print()
390 | 
391 | 
392 | def generate_adj_matrices(ori_schema_graph_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes, num_rels=34, debug=False):
393 |     print(f'generating adjacency matrices for {ori_schema_graph_path} and {cpnet_graph_path}...')
394 | 
395 |     global concept2id, id2concept, relation2id, id2relation
396 |     if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
397 |         load_resources(cpnet_vocab_path)
398 | 
399 |     global cpnet_all
400 |     if cpnet_all is None:
401 |         cpnet_all = nx.read_gpickle(cpnet_graph_path)
402 | 
403 |     with open(ori_schema_graph_path, 'r') as fin:
404 |         nxg_strs = [line for line in fin]
405 | 
406 |     if debug:
407 |         nxg_strs = nxg_strs[:1]
408 | 
409 |     with Pool(num_processes) as p:
410 |         res = list(tqdm(p.imap(generate_adj_matrix_per_inst, nxg_strs), total=len(nxg_strs)))
411 | 
412 |     with open(output_path, 'wb') as fout:
413 |         pickle.dump(res, fout)
414 | 
415 |     print(f'adjacency matrices saved to {output_path}')
416 |     print()
417 | 
418 | 
419 | def generate_adj_data_from_grounded_concepts(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):
420 |     """
421 |     This function will save
422 |     (1) adjacency matrices (each in the form of a (R*N, N) coo sparse matrix)
423 |     (2) concepts ids
424 |     (3) qmask that specifies whether a node is a question concept
425 |     (4) amask that specifies whether a node is an answer concept
426 |     to the output path in python pickle format
427 | 
428 |     grounded_path: str
429 |     cpnet_graph_path: str
430 |     cpnet_vocab_path: str
431 |     output_path: str
432 |     num_processes: int
433 |     """
434 |     print(f'generating adj data for {grounded_path}...')
435 | 
436 |     global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
437 |     if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
438 |         load_resources(cpnet_vocab_path)
439 |     if cpnet is None or cpnet_simple is None:
440 |         load_cpnet(cpnet_graph_path)
441 | 
442 |     qa_data = []
443 |     with open(grounded_path, 'r', encoding='utf-8') as fin:
444 |         for line in fin:
445 |             dic = json.loads(line)
446 |             q_ids = set(concept2id[c] for c in dic['qc'])
447 |             a_ids = set(concept2id[c] for c in dic['ac'])
448 |             q_ids = q_ids - a_ids  # concepts grounded in both are kept as answer concepts
449 |             qa_data.append((q_ids, a_ids))
450 | 
451 |     with Pool(num_processes) as p:
452 |         res = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair, qa_data), total=len(qa_data)))
453 | 
454 |     # res is a list of tuples, each tuple consists of four elements (adj, concepts, qmask, amask)
455 |     with open(output_path, 'wb') as fout:
456 |         pickle.dump(res, fout)
457 | 
458 |     print(f'adj data saved to {output_path}')
459 |     print()
460 | 
461 | 
462 | 
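# ------------------------------------------------------------------
# Hedged sketch (illustrative, with a toy vocabulary standing in for
# concept.txt): one line of a *.grounded.jsonl file and the q/a concept-id
# derivation used in generate_adj_data_from_grounded_concepts above. A concept
# grounded in both the question and the answer is kept only as an answer concept.
def _grounded_line_example():
    toy_concept2id = {'sun': 0, 'heat': 1, 'warm': 2}
    line = json.dumps({'qc': ['sun', 'heat'], 'ac': ['heat', 'warm'], 'ans': 'warm'})
    dic = json.loads(line)
    q_ids = set(toy_concept2id[c] for c in dic['qc'])
    a_ids = set(toy_concept2id[c] for c in dic['ac'])
    q_ids = q_ids - a_ids
    assert (q_ids, a_ids) == ({0}, {1, 2})
# ------------------------------------------------------------------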
463 | def generate_adj_data_from_grounded_concepts__use_LM(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):
464 |     """
465 |     This function will save
466 |     (1) adjacency matrices (each in the form of a (R*N, N) coo sparse matrix)
467 |     (2) concepts ids
468 |     (3) qmask that specifies whether a node is a question concept
469 |     (4) amask that specifies whether a node is an answer concept
470 |     (5) cid2score that maps a concept id to its relevance score given the QA context
471 |     to the output path in python pickle format
472 | 
473 |     grounded_path: str
474 |     cpnet_graph_path: str
475 |     cpnet_vocab_path: str
476 |     output_path: str
477 |     num_processes: int
478 |     """
479 |     print(f'generating adj data for {grounded_path}...')
480 | 
481 |     global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
482 |     if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
483 |         load_resources(cpnet_vocab_path)
484 |     if cpnet is None or cpnet_simple is None:
485 |         load_cpnet(cpnet_graph_path)
486 | 
487 |     qa_data = []
488 |     statement_path = grounded_path.replace('grounded', 'statement')
489 |     with open(grounded_path, 'r', encoding='utf-8') as fin_ground, open(statement_path, 'r', encoding='utf-8') as fin_state:
490 |         lines_ground = fin_ground.readlines()
491 |         lines_state = fin_state.readlines()
492 |         assert len(lines_ground) % len(lines_state) == 0
493 |         n_choices = len(lines_ground) // len(lines_state)
494 |         for j, line in enumerate(lines_ground):
495 |             dic = json.loads(line)
496 |             q_ids = set(concept2id[c] for c in dic['qc'])
497 |             a_ids = set(concept2id[c] for c in dic['ac'])
498 |             q_ids = q_ids - a_ids
499 |             statement_obj = json.loads(lines_state[j//n_choices])
500 |             QAcontext = "{} {}.".format(statement_obj['question']['stem'], dic['ans'])
501 |             qa_data.append((q_ids, a_ids, QAcontext))
502 | 
503 |     with Pool(num_processes) as p:
504 |         res1 = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part1, qa_data), total=len(qa_data)))
505 | 
506 |     res2 = []
507 |     for j, _data in enumerate(res1):
508 |         if j % 100 == 0: print (j)
509 |         res2.append(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part2(_data))
510 | 
511 |     with Pool(num_processes) as p:
512 |         res3 = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3, res2), total=len(res2)))
513 | 
514 |     # res3 is a list of dicts with keys 'adj', 'concepts', 'qmask', 'amask', 'cid2score'
515 |     with open(output_path, 'wb') as fout:
516 |         pickle.dump(res3, fout)
517 | 
518 |     print(f'adj data saved to {output_path}')
519 |     print()
520 | 
521 | 
522 | 
523 | #################### adj to sparse ####################
524 | 
525 | def coo_to_normalized_per_inst(data):
526 |     adj, concepts, qm, am, max_node_num = data
527 |     ori_adj_len = len(concepts)
528 |     concepts = torch.tensor(concepts[:min(len(concepts), max_node_num)])
529 |     adj_len = len(concepts)
530 |     qm = torch.tensor(qm[:adj_len], dtype=torch.uint8)
531 |     am = torch.tensor(am[:adj_len], dtype=torch.uint8)
532 |     ij = adj.row
533 |     k = adj.col
534 |     n_node = adj.shape[1]
535 |     n_rel = 2 * adj.shape[0] // n_node
536 |     i, j = ij // n_node, ij % n_node
537 |     mask = (j < max_node_num) & (k < max_node_num)
538 |     i, j, k = i[mask], j[mask], k[mask]
539 |     i, j, k = np.concatenate((i, i + n_rel // 2), 0), np.concatenate((j, k), 0), np.concatenate((k, j), 0)  # add inverse relations
540 |     adj_list = []
541 |     for r in range(n_rel):
542 |         mask = i == r
543 |         ones = np.ones(mask.sum(), dtype=np.float32)
544 |         A = sparse.csr_matrix((ones, (k[mask], j[mask])), shape=(max_node_num, max_node_num))  # A is transposed by exchanging the order of j and k
545 |         adj_list.append(normalize_sparse_adj(A, 'coo'))
546 |     adj_list.append(sparse.identity(max_node_num, dtype=np.float32, 
format='coo')) 547 | return ori_adj_len, adj_len, concepts, adj_list, qm, am 548 | 549 | 550 | def coo_to_normalized(adj_path, output_path, max_node_num, num_processes): 551 | print(f'converting {adj_path} to normalized adj') 552 | 553 | with open(adj_path, 'rb') as fin: 554 | adj_data = pickle.load(fin) 555 | data = [(adj, concepts, qmask, amask, max_node_num) for adj, concepts, qmask, amask in adj_data] 556 | 557 | ori_adj_lengths = torch.zeros((len(data),), dtype=torch.int64) 558 | adj_lengths = torch.zeros((len(data),), dtype=torch.int64) 559 | concepts_ids = torch.zeros((len(data), max_node_num), dtype=torch.int64) 560 | qmask = torch.zeros((len(data), max_node_num), dtype=torch.uint8) 561 | amask = torch.zeros((len(data), max_node_num), dtype=torch.uint8) 562 | 563 | adj_data = [] 564 | with Pool(num_processes) as p: 565 | for i, (ori_adj_len, adj_len, concepts, adj_list, qm, am) in tqdm(enumerate(p.imap(coo_to_normalized_per_inst, data)), total=len(data)): 566 | ori_adj_lengths[i] = ori_adj_len 567 | adj_lengths[i] = adj_len 568 | concepts_ids[i][:adj_len] = concepts 569 | qmask[i][:adj_len] = qm 570 | amask[i][:adj_len] = am 571 | adj_list = [(torch.LongTensor(np.stack((adj.row, adj.col), 0)), 572 | torch.FloatTensor(adj.data)) for adj in adj_list] 573 | adj_data.append(adj_list) 574 | 575 | torch.save((ori_adj_lengths, adj_lengths, concepts_ids, adj_data), output_path) 576 | 577 | print(f'normalized adj saved to {output_path}') 578 | print() 579 | 580 | # if __name__ == '__main__': 581 | # generate_adj_matrices_from_grounded_concepts('./data/csqa/grounded/train.grounded.jsonl', 582 | # './data/cpnet/conceptnet.en.pruned.graph', 583 | # './data/cpnet/concept.txt', 584 | # '/tmp/asdf', 40) 585 | -------------------------------------------------------------------------------- /modeling/modeling_qagnn.py: -------------------------------------------------------------------------------- 1 | from modeling.modeling_encoder import TextEncoder, MODEL_NAME_TO_CLASS 2 | from utils.data_utils import * 3 | from utils.layers import * 4 | import torch.nn.functional as F 5 | 6 | 7 | class QAGNN_Message_Passing(nn.Module): 8 | def __init__(self, args, k, n_ntype, n_etype, input_size, hidden_size, output_size, 9 | dropout=0.1): 10 | super().__init__() 11 | assert input_size == output_size 12 | self.args = args 13 | self.n_ntype = n_ntype 14 | self.n_etype = n_etype 15 | 16 | assert input_size == hidden_size 17 | self.hidden_size = hidden_size 18 | 19 | self.emb_node_type = nn.Linear(self.n_ntype, hidden_size//2) 20 | 21 | self.basis_f = 'sin' #['id', 'linact', 'sin', 'none'] 22 | if self.basis_f in ['id']: 23 | self.emb_score = nn.Linear(1, hidden_size//2) 24 | elif self.basis_f in ['linact']: 25 | self.B_lin = nn.Linear(1, hidden_size//2) 26 | self.emb_score = nn.Linear(hidden_size//2, hidden_size//2) 27 | elif self.basis_f in ['sin']: 28 | self.emb_score = nn.Linear(hidden_size//2, hidden_size//2) 29 | 30 | self.edge_encoder = torch.nn.Sequential(torch.nn.Linear(n_etype +1 + n_ntype *2, hidden_size), torch.nn.BatchNorm1d(hidden_size), torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size)) 31 | 32 | 33 | self.k = k 34 | self.gnn_layers = nn.ModuleList([GATConvE(args, hidden_size, n_ntype, n_etype, self.edge_encoder) for _ in range(k)]) 35 | 36 | 37 | self.Vh = nn.Linear(input_size, output_size) 38 | self.Vx = nn.Linear(hidden_size, output_size) 39 | 40 | self.activation = GELU() 41 | self.dropout = nn.Dropout(dropout) 42 | self.dropout_rate = dropout 43 | 44 | 45 | def mp_helper(self, _X, 
edge_index, edge_type, _node_type, _node_feature_extra): 46 | for _ in range(self.k): 47 | _X = self.gnn_layers[_](_X, edge_index, edge_type, _node_type, _node_feature_extra) 48 | _X = self.activation(_X) 49 | _X = F.dropout(_X, self.dropout_rate, training = self.training) 50 | return _X 51 | 52 | 53 | def forward(self, H, A, node_type, node_score, cache_output=False): 54 | """ 55 | H: tensor of shape (batch_size, n_node, d_node) 56 | node features from the previous layer 57 | A: (edge_index, edge_type) 58 | node_type: long tensor of shape (batch_size, n_node) 59 | 0 == question entity; 1 == answer choice entity; 2 == other node; 3 == context node 60 | node_score: tensor of shape (batch_size, n_node, 1) 61 | """ 62 | _batch_size, _n_nodes = node_type.size() 63 | 64 | #Embed type 65 | T = make_one_hot(node_type.view(-1).contiguous(), self.n_ntype).view(_batch_size, _n_nodes, self.n_ntype) 66 | node_type_emb = self.activation(self.emb_node_type(T)) #[batch_size, n_node, dim/2] 67 | 68 | #Embed score 69 | if self.basis_f == 'sin': 70 | js = torch.arange(self.hidden_size//2).unsqueeze(0).unsqueeze(0).float().to(node_type.device) #[1,1,dim/2] 71 | js = torch.pow(1.1, js) #[1,1,dim/2] 72 | B = torch.sin(js * node_score) #[batch_size, n_node, dim/2] 73 | node_score_emb = self.activation(self.emb_score(B)) #[batch_size, n_node, dim/2] 74 | elif self.basis_f == 'id': 75 | B = node_score 76 | node_score_emb = self.activation(self.emb_score(B)) #[batch_size, n_node, dim/2] 77 | elif self.basis_f == 'linact': 78 | B = self.activation(self.B_lin(node_score)) #[batch_size, n_node, dim/2] 79 | node_score_emb = self.activation(self.emb_score(B)) #[batch_size, n_node, dim/2] 80 | 81 | 82 | X = H 83 | edge_index, edge_type = A #edge_index: [2, total_E] edge_type: [total_E, ] where total_E is for the batched graph 84 | _X = X.view(-1, X.size(2)).contiguous() #[`total_n_nodes`, d_node] where `total_n_nodes` = b_size * n_node 85 | _node_type = node_type.view(-1).contiguous() #[`total_n_nodes`, ] 86 | _node_feature_extra = torch.cat([node_type_emb, node_score_emb], dim=2).view(_node_type.size(0), -1).contiguous() #[`total_n_nodes`, dim] 87 | 88 | _X = self.mp_helper(_X, edge_index, edge_type, _node_type, _node_feature_extra) 89 | 90 | X = _X.view(node_type.size(0), node_type.size(1), -1) #[batch_size, n_node, dim] 91 | 92 | output = self.activation(self.Vh(H) + self.Vx(X)) 93 | output = self.dropout(output) 94 | 95 | return output 96 | 97 | 98 | 99 | class QAGNN(nn.Module): 100 | def __init__(self, args, k, n_ntype, n_etype, sent_dim, 101 | n_concept, concept_dim, concept_in_dim, n_attention_head, 102 | fc_dim, n_fc_layer, p_emb, p_gnn, p_fc, 103 | pretrained_concept_emb=None, freeze_ent_emb=True, 104 | init_range=0.02): 105 | super().__init__() 106 | self.init_range = init_range 107 | 108 | self.concept_emb = CustomizedEmbedding(concept_num=n_concept, concept_out_dim=concept_dim, 109 | use_contextualized=False, concept_in_dim=concept_in_dim, 110 | pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb) 111 | self.svec2nvec = nn.Linear(sent_dim, concept_dim) 112 | 113 | self.concept_dim = concept_dim 114 | 115 | self.activation = GELU() 116 | 117 | self.gnn = QAGNN_Message_Passing(args, k=k, n_ntype=n_ntype, n_etype=n_etype, 118 | input_size=concept_dim, hidden_size=concept_dim, output_size=concept_dim, dropout=p_gnn) 119 | 120 | self.pooler = MultiheadAttPoolLayer(n_attention_head, sent_dim, concept_dim) 121 | 122 | self.fc = MLP(concept_dim + sent_dim + concept_dim, fc_dim, 1, 
n_fc_layer, p_fc, layer_norm=True)
123 | 
124 |         self.dropout_e = nn.Dropout(p_emb)
125 |         self.dropout_fc = nn.Dropout(p_fc)
126 | 
127 |         if init_range > 0:
128 |             self.apply(self._init_weights)
129 | 
130 | 
131 |     def _init_weights(self, module):
132 |         if isinstance(module, (nn.Linear, nn.Embedding)):
133 |             module.weight.data.normal_(mean=0.0, std=self.init_range)
134 |             if hasattr(module, 'bias') and module.bias is not None:
135 |                 module.bias.data.zero_()
136 |         elif isinstance(module, nn.LayerNorm):
137 |             module.bias.data.zero_()
138 |             module.weight.data.fill_(1.0)
139 | 
140 | 
141 |     def forward(self, sent_vecs, concept_ids, node_type_ids, node_scores, adj_lengths, adj, emb_data=None, cache_output=False):
142 |         """
143 |         sent_vecs: (batch_size, dim_sent)
144 |         concept_ids: (batch_size, n_node)
145 |         adj: edge_index, edge_type
146 |         adj_lengths: (batch_size,)
147 |         node_type_ids: (batch_size, n_node)
148 |             0 == question entity; 1 == answer choice entity; 2 == other node; 3 == context node
149 |         node_scores: (batch_size, n_node, 1)
150 | 
151 |         returns: (batch_size, 1)
152 |         """
153 |         gnn_input0 = self.activation(self.svec2nvec(sent_vecs)).unsqueeze(1) #(batch_size, 1, dim_node)
154 |         gnn_input1 = self.concept_emb(concept_ids[:, 1:]-1, emb_data) #(batch_size, n_node-1, dim_node)
155 |         gnn_input1 = gnn_input1.to(node_type_ids.device)
156 |         gnn_input = self.dropout_e(torch.cat([gnn_input0, gnn_input1], dim=1)) #(batch_size, n_node, dim_node)
157 | 
158 | 
159 |         #Normalize node score (use norm from Z)
160 |         _mask = (torch.arange(node_scores.size(1), device=node_scores.device) < adj_lengths.unsqueeze(1)).float() #0 means masked out #[batch_size, n_node]
161 |         node_scores = -node_scores
162 |         node_scores = node_scores - node_scores[:, 0:1, :] #[batch_size, n_node, 1]
163 |         node_scores = node_scores.squeeze(2) #[batch_size, n_node]
164 |         node_scores = node_scores * _mask
165 |         mean_norm = (torch.abs(node_scores)).sum(dim=1) / adj_lengths #[batch_size, ]
166 |         node_scores = node_scores / (mean_norm.unsqueeze(1) + 1e-05) #[batch_size, n_node]
167 |         node_scores = node_scores.unsqueeze(2) #[batch_size, n_node, 1]
168 | 
169 | 
170 |         gnn_output = self.gnn(gnn_input, adj, node_type_ids, node_scores)
171 | 
172 |         Z_vecs = gnn_output[:,0] #(batch_size, dim_node)
173 | 
174 |         mask = torch.arange(node_type_ids.size(1), device=node_type_ids.device) >= adj_lengths.unsqueeze(1) #1 means masked out
175 | 
176 |         mask = mask | (node_type_ids == 3) #pool over all KG nodes
177 |         mask[mask.all(1), 0] = 0  # a temporary solution to avoid zero node
178 | 
179 |         sent_vecs_for_pooler = sent_vecs
180 |         graph_vecs, pool_attn = self.pooler(sent_vecs_for_pooler, gnn_output, mask)
181 | 
182 |         if cache_output:
183 |             self.concept_ids = concept_ids
184 |             self.adj = adj
185 |             self.pool_attn = pool_attn
186 | 
187 |         concat = self.dropout_fc(torch.cat((graph_vecs, sent_vecs, Z_vecs), 1))
188 |         logits = self.fc(concat)
189 |         return logits, pool_attn
190 | 
191 | 
192 | class LM_QAGNN(nn.Module):
193 |     def __init__(self, args, model_name, k, n_ntype, n_etype,
194 |                  n_concept, concept_dim, concept_in_dim, n_attention_head,
195 |                  fc_dim, n_fc_layer, p_emb, p_gnn, p_fc,
196 |                  pretrained_concept_emb=None, freeze_ent_emb=True,
197 |                  init_range=0.0, encoder_config={}):
198 |         super().__init__()
199 |         self.encoder = TextEncoder(model_name, **encoder_config)
200 |         self.decoder = QAGNN(args, k, n_ntype, n_etype, self.encoder.sent_dim,
201 |                              n_concept, concept_dim, concept_in_dim, n_attention_head,
202 |                              fc_dim, n_fc_layer, p_emb, p_gnn, p_fc,
203 |                              
pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb, 204 | init_range=init_range) 205 | 206 | 207 | def forward(self, *inputs, layer_id=-1, cache_output=False, detail=False): 208 | """ 209 | sent_vecs: (batch_size, num_choice, d_sent) -> (batch_size * num_choice, d_sent) 210 | concept_ids: (batch_size, num_choice, n_node) -> (batch_size * num_choice, n_node) 211 | node_type_ids: (batch_size, num_choice, n_node) -> (batch_size * num_choice, n_node) 212 | adj_lengths: (batch_size, num_choice) -> (batch_size * num_choice, ) 213 | adj -> edge_index, edge_type 214 | edge_index: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(2, E(variable)) 215 | -> (2, total E) 216 | edge_type: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(E(variable), ) 217 | -> (total E, ) 218 | returns: (batch_size, 1) 219 | """ 220 | bs, nc = inputs[0].size(0), inputs[0].size(1) 221 | 222 | #Here, merge the batch dimension and the num_choice dimension 223 | edge_index_orig, edge_type_orig = inputs[-2:] 224 | _inputs = [x.view(x.size(0) * x.size(1), *x.size()[2:]) for x in inputs[:-6]] + [x.view(x.size(0) * x.size(1), *x.size()[2:]) for x in inputs[-6:-2]] + [sum(x,[]) for x in inputs[-2:]] 225 | 226 | *lm_inputs, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type = _inputs 227 | edge_index, edge_type = self.batch_graph(edge_index, edge_type, concept_ids.size(1)) 228 | adj = (edge_index.to(node_type_ids.device), edge_type.to(node_type_ids.device)) #edge_index: [2, total_E] edge_type: [total_E, ] 229 | 230 | sent_vecs, all_hidden_states = self.encoder(*lm_inputs, layer_id=layer_id) 231 | logits, attn = self.decoder(sent_vecs.to(node_type_ids.device), 232 | concept_ids, 233 | node_type_ids, node_scores, adj_lengths, adj, 234 | emb_data=None, cache_output=cache_output) 235 | logits = logits.view(bs, nc) 236 | if not detail: 237 | return logits, attn 238 | else: 239 | return logits, attn, concept_ids.view(bs, nc, -1), node_type_ids.view(bs, nc, -1), edge_index_orig, edge_type_orig 240 | #edge_index_orig: list of (batch_size, num_choice). each entry is torch.tensor(2, E) 241 | #edge_type_orig: list of (batch_size, num_choice). each entry is torch.tensor(E, ) 242 | 243 | 244 | def batch_graph(self, edge_index_init, edge_type_init, n_nodes): 245 | #edge_index_init: list of (n_examples, ). each entry is torch.tensor(2, E) 246 | #edge_type_init: list of (n_examples, ). 
each entry is torch.tensor(E, ) 247 | n_examples = len(edge_index_init) 248 | edge_index = [edge_index_init[_i_] + _i_ * n_nodes for _i_ in range(n_examples)] 249 | edge_index = torch.cat(edge_index, dim=1) #[2, total_E] 250 | edge_type = torch.cat(edge_type_init, dim=0) #[total_E, ] 251 | return edge_index, edge_type 252 | 253 | 254 | 255 | class LM_QAGNN_DataLoader(object): 256 | 257 | def __init__(self, args, train_statement_path, train_adj_path, 258 | dev_statement_path, dev_adj_path, 259 | test_statement_path, test_adj_path, 260 | batch_size, eval_batch_size, device, model_name, max_node_num=200, max_seq_length=128, 261 | is_inhouse=False, inhouse_train_qids_path=None, 262 | subsample=1.0, use_cache=True): 263 | super().__init__() 264 | self.args = args 265 | self.batch_size = batch_size 266 | self.eval_batch_size = eval_batch_size 267 | self.device0, self.device1 = device 268 | self.is_inhouse = is_inhouse 269 | 270 | model_type = MODEL_NAME_TO_CLASS[model_name] 271 | print ('train_statement_path', train_statement_path) 272 | self.train_qids, self.train_labels, *self.train_encoder_data = load_input_tensors(train_statement_path, model_type, model_name, max_seq_length) 273 | self.dev_qids, self.dev_labels, *self.dev_encoder_data = load_input_tensors(dev_statement_path, model_type, model_name, max_seq_length) 274 | 275 | num_choice = self.train_encoder_data[0].size(1) 276 | self.num_choice = num_choice 277 | print ('num_choice', num_choice) 278 | *self.train_decoder_data, self.train_adj_data = load_sparse_adj_data_with_contextnode(train_adj_path, max_node_num, num_choice, args) 279 | 280 | *self.dev_decoder_data, self.dev_adj_data = load_sparse_adj_data_with_contextnode(dev_adj_path, max_node_num, num_choice, args) 281 | assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data) 282 | assert all(len(self.dev_qids) == len(self.dev_adj_data[0]) == x.size(0) for x in [self.dev_labels] + self.dev_encoder_data + self.dev_decoder_data) 283 | 284 | if test_statement_path is not None: 285 | self.test_qids, self.test_labels, *self.test_encoder_data = load_input_tensors(test_statement_path, model_type, model_name, max_seq_length) 286 | *self.test_decoder_data, self.test_adj_data = load_sparse_adj_data_with_contextnode(test_adj_path, max_node_num, num_choice, args) 287 | assert all(len(self.test_qids) == len(self.test_adj_data[0]) == x.size(0) for x in [self.test_labels] + self.test_encoder_data + self.test_decoder_data) 288 | 289 | 290 | if self.is_inhouse: 291 | with open(inhouse_train_qids_path, 'r') as fin: 292 | inhouse_qids = set(line.strip() for line in fin) 293 | self.inhouse_train_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid in inhouse_qids]) 294 | self.inhouse_test_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid not in inhouse_qids]) 295 | 296 | assert 0. < subsample <= 1. 
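        # ------------------------------------------------------------------
        # Hedged sketch (comment-only): with subsample=0.5 and 10 non-inhouse
        # training examples, the block below keeps the first 5:
        #     n_train = int(10 * 0.5)       # -> 5
        #     train_qids = train_qids[:5]   # labels/tensors are sliced the same way
        # Note the subset is a deterministic prefix, not a random draw.
        # ------------------------------------------------------------------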
297 | if subsample < 1.: 298 | n_train = int(self.train_size() * subsample) 299 | assert n_train > 0 300 | if self.is_inhouse: 301 | self.inhouse_train_indexes = self.inhouse_train_indexes[:n_train] 302 | else: 303 | self.train_qids = self.train_qids[:n_train] 304 | self.train_labels = self.train_labels[:n_train] 305 | self.train_encoder_data = [x[:n_train] for x in self.train_encoder_data] 306 | self.train_decoder_data = [x[:n_train] for x in self.train_decoder_data] 307 | self.train_adj_data = [x[:n_train] for x in self.train_adj_data] #adj_data is an (edge_index, edge_type) pair of per-example lists; slice each list, not the pair itself 308 | assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data) 309 | assert self.train_size() == n_train 310 | 311 | def train_size(self): 312 | return self.inhouse_train_indexes.size(0) if self.is_inhouse else len(self.train_qids) 313 | 314 | def dev_size(self): 315 | return len(self.dev_qids) 316 | 317 | def test_size(self): 318 | if self.is_inhouse: 319 | return self.inhouse_test_indexes.size(0) 320 | else: 321 | return len(self.test_qids) if hasattr(self, 'test_qids') else 0 322 | 323 | def train(self): 324 | if self.is_inhouse: 325 | n_train = self.inhouse_train_indexes.size(0) 326 | train_indexes = self.inhouse_train_indexes[torch.randperm(n_train)] 327 | else: 328 | train_indexes = torch.randperm(len(self.train_qids)) 329 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'train', self.device0, self.device1, self.batch_size, train_indexes, self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data) 330 | 331 | def train_eval(self): 332 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.train_qids)), self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data) 333 | 334 | def dev(self): 335 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.dev_qids)), self.dev_qids, self.dev_labels, tensors0=self.dev_encoder_data, tensors1=self.dev_decoder_data, adj_data=self.dev_adj_data) 336 | 337 | def test(self): 338 | if self.is_inhouse: 339 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, self.inhouse_test_indexes, self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data) 340 | else: 341 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.test_qids)), self.test_qids, self.test_labels, tensors0=self.test_encoder_data, tensors1=self.test_decoder_data, adj_data=self.test_adj_data) 342 | 343 | 344 | 345 | 346 | 347 | ############################################################################### 348 | ############################### GNN architecture ############################## 349 | ############################################################################### 350 | 351 | from torch.autograd import Variable 352 | def make_one_hot(labels, C): 353 | ''' 354 | Converts an integer label torch.autograd.Variable to a one-hot Variable. 355 | labels : torch.autograd.Variable of torch.cuda.LongTensor 356 | (N, ), where N is batch size. 357 | Each value is an integer representing correct classification. 358 | C : integer.
359 | number of classes in labels. 360 | Returns : torch.autograd.Variable of torch.cuda.FloatTensor 361 | N x C, where C is class number. One-hot encoded. 362 | ''' 363 | labels = labels.unsqueeze(1) 364 | one_hot = torch.FloatTensor(labels.size(0), C).zero_().to(labels.device) 365 | target = one_hot.scatter_(1, labels.data, 1) 366 | target = Variable(target) 367 | return target 368 | 369 | 370 | 371 | from torch_geometric.nn import MessagePassing 372 | from torch_geometric.utils import add_self_loops, degree, softmax 373 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set 374 | import torch.nn.functional as F 375 | from torch_scatter import scatter_add, scatter 376 | from torch_geometric.nn.inits import glorot, zeros 377 | 378 | 379 | 380 | class GATConvE(MessagePassing): 381 | """ 382 | Args: 383 | emb_dim (int): dimensionality of GNN hidden states 384 | n_ntype (int): number of node types (e.g. 4) 385 | n_etype (int): number of edge relation types (e.g. 38) 386 | """ 387 | def __init__(self, args, emb_dim, n_ntype, n_etype, edge_encoder, head_count=4, aggr="add"): 388 | super(GATConvE, self).__init__(aggr=aggr) 389 | self.args = args 390 | 391 | assert emb_dim % 2 == 0 392 | self.emb_dim = emb_dim 393 | 394 | self.n_ntype = n_ntype; self.n_etype = n_etype 395 | self.edge_encoder = edge_encoder 396 | 397 | #For attention 398 | self.head_count = head_count 399 | assert emb_dim % head_count == 0 400 | self.dim_per_head = emb_dim // head_count 401 | self.linear_key = nn.Linear(3*emb_dim, head_count * self.dim_per_head) 402 | self.linear_msg = nn.Linear(3*emb_dim, head_count * self.dim_per_head) 403 | self.linear_query = nn.Linear(2*emb_dim, head_count * self.dim_per_head) 404 | 405 | self._alpha = None 406 | 407 | #For final MLP 408 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(), torch.nn.Linear(emb_dim, emb_dim)) 409 | 410 | 411 | def forward(self, x, edge_index, edge_type, node_type, node_feature_extra, return_attention_weights=False): 412 | # x: [N, emb_dim] 413 | # edge_index: [2, E] 414 | # edge_type [E,] -> edge_attr: [E, 39] / self_edge_attr: [N, 39] 415 | # node_type [N,] -> headtail_attr [E, 8(=4+4)] / self_headtail_attr: [N, 8] 416 | # node_feature_extra [N, dim] 417 | 418 | #Prepare edge feature 419 | edge_vec = make_one_hot(edge_type, self.n_etype +1) #[E, 39] 420 | self_edge_vec = torch.zeros(x.size(0), self.n_etype +1).to(edge_vec.device) 421 | self_edge_vec[:,self.n_etype] = 1 422 | 423 | head_type = node_type[edge_index[0]] #[E,] #head=src 424 | tail_type = node_type[edge_index[1]] #[E,] #tail=tgt 425 | head_vec = make_one_hot(head_type, self.n_ntype) #[E,4] 426 | tail_vec = make_one_hot(tail_type, self.n_ntype) #[E,4] 427 | headtail_vec = torch.cat([head_vec, tail_vec], dim=1) #[E,8] 428 | self_head_vec = make_one_hot(node_type, self.n_ntype) #[N,4] 429 | self_headtail_vec = torch.cat([self_head_vec, self_head_vec], dim=1) #[N,8] 430 | 431 | edge_vec = torch.cat([edge_vec, self_edge_vec], dim=0) #[E+N, ?] 432 | headtail_vec = torch.cat([headtail_vec, self_headtail_vec], dim=0) #[E+N, ?] 
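# Shape sketch for the edge features assembled above (assuming the example
# values from the class docstring, n_ntype=4 and n_etype=38): edge_vec one-hot
# encodes 38 relation types plus 1 self-loop type -> [E+N, 39]; headtail_vec
# concatenates head and tail node-type one-hots -> [E+N, 8]. The concatenation
# below is therefore [E+N, 47]; e.g. an edge of type 5 from a question-concept
# node (type 0) to an answer-concept node (type 1) sets positions 5, 39+0=39,
# and 39+4+1=44 to 1, and edge_encoder then maps each row to emb_dim dims.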
433 | edge_embeddings = self.edge_encoder(torch.cat([edge_vec, headtail_vec], dim=1)) #[E+N, emb_dim] 434 | 435 | #Add self loops to edge_index 436 | loop_index = torch.arange(0, x.size(0), dtype=torch.long, device=edge_index.device) 437 | loop_index = loop_index.unsqueeze(0).repeat(2, 1) 438 | edge_index = torch.cat([edge_index, loop_index], dim=1) #[2, E+N] 439 | 440 | x = torch.cat([x, node_feature_extra], dim=1) 441 | x = (x, x) 442 | aggr_out = self.propagate(edge_index, x=x, edge_attr=edge_embeddings) #[N, emb_dim] 443 | out = self.mlp(aggr_out) 444 | 445 | alpha = self._alpha 446 | self._alpha = None 447 | 448 | if return_attention_weights: 449 | assert alpha is not None 450 | return out, (edge_index, alpha) 451 | else: 452 | return out 453 | 454 | 455 | def message(self, edge_index, x_i, x_j, edge_attr): #i: tgt, j:src 456 | # print ("edge_attr.size()", edge_attr.size()) #[E, emb_dim] 457 | # print ("x_j.size()", x_j.size()) #[E, emb_dim] 458 | # print ("x_i.size()", x_i.size()) #[E, emb_dim] 459 | assert len(edge_attr.size()) == 2 460 | assert edge_attr.size(1) == self.emb_dim 461 | assert x_i.size(1) == x_j.size(1) == 2*self.emb_dim 462 | assert x_i.size(0) == x_j.size(0) == edge_attr.size(0) == edge_index.size(1) 463 | 464 | key = self.linear_key(torch.cat([x_i, edge_attr], dim=1)).view(-1, self.head_count, self.dim_per_head) #[E, heads, _dim] 465 | msg = self.linear_msg(torch.cat([x_j, edge_attr], dim=1)).view(-1, self.head_count, self.dim_per_head) #[E, heads, _dim] 466 | query = self.linear_query(x_j).view(-1, self.head_count, self.dim_per_head) #[E, heads, _dim] 467 | 468 | 469 | query = query / math.sqrt(self.dim_per_head) 470 | scores = (query * key).sum(dim=2) #[E, heads] 471 | src_node_index = edge_index[0] #[E,] 472 | alpha = softmax(scores, src_node_index) #[E, heads] #group by src side node 473 | self._alpha = alpha 474 | 475 | #adjust by outgoing degree of src 476 | E = edge_index.size(1) #n_edges 477 | N = int(src_node_index.max()) + 1 #n_nodes 478 | ones = torch.full((E,), 1.0, dtype=torch.float).to(edge_index.device) 479 | src_node_edge_count = scatter(ones, src_node_index, dim=0, dim_size=N, reduce='sum')[src_node_index] #[E,] 480 | assert len(src_node_edge_count.size()) == 1 and len(src_node_edge_count) == E 481 | alpha = alpha * src_node_edge_count.unsqueeze(1) #[E, heads] 482 | 483 | out = msg * alpha.view(-1, self.head_count, 1) #[E, heads, _dim] 484 | return out.view(-1, self.head_count * self.dim_per_head) #[E, emb_dim] 485 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import os 3 | import numpy as np 4 | import torch 5 | from transformers import (OpenAIGPTTokenizer, BertTokenizer, XLNetTokenizer, RobertaTokenizer, AutoTokenizer) 6 | try: 7 | from transformers import AlbertTokenizer 8 | except: 9 | pass 10 | 11 | import json 12 | from tqdm import tqdm 13 | 14 | GPT_SPECIAL_TOKENS = ['_start_', '_delimiter_', '_classify_'] 15 | 16 | 17 | class MultiGPUSparseAdjDataBatchGenerator(object): 18 | def __init__(self, args, mode, device0, device1, batch_size, indexes, qids, labels, 19 | tensors0=[], lists0=[], tensors1=[], lists1=[], adj_data=None): 20 | self.args = args 21 | self.mode = mode 22 | self.device0 = device0 23 | self.device1 = device1 24 | self.batch_size = batch_size 25 | self.indexes = indexes 26 | self.qids = qids 27 | self.labels = labels 28 | self.tensors0 = tensors0 29 | self.lists0 = 
lists0 30 | self.tensors1 = tensors1 31 | self.lists1 = lists1 32 | # self.adj_empty = adj_empty.to(self.device1) 33 | self.adj_data = adj_data 34 | 35 | def __len__(self): 36 | return (self.indexes.size(0) - 1) // self.batch_size + 1 37 | 38 | def __iter__(self): 39 | bs = self.batch_size 40 | n = self.indexes.size(0) 41 | if self.mode=='train' and self.args.drop_partial_batch: 42 | print ('dropping partial batch') 43 | n = (n//bs) *bs 44 | elif self.mode=='train' and self.args.fill_partial_batch: 45 | print ('filling partial batch') 46 | remain = n % bs 47 | if remain > 0: 48 | extra = np.random.choice(self.indexes[:-remain], size=(bs-remain), replace=False) 49 | self.indexes = torch.cat([self.indexes, torch.tensor(extra)]) 50 | n = self.indexes.size(0) 51 | assert n % bs == 0 52 | 53 | for a in range(0, n, bs): 54 | b = min(n, a + bs) 55 | batch_indexes = self.indexes[a:b] 56 | batch_qids = [self.qids[idx] for idx in batch_indexes] 57 | batch_labels = self._to_device(self.labels[batch_indexes], self.device1) 58 | batch_tensors0 = [self._to_device(x[batch_indexes], self.device0) for x in self.tensors0] 59 | batch_tensors1 = [self._to_device(x[batch_indexes], self.device1) for x in self.tensors1] 60 | batch_lists0 = [self._to_device([x[i] for i in batch_indexes], self.device0) for x in self.lists0] 61 | batch_lists1 = [self._to_device([x[i] for i in batch_indexes], self.device1) for x in self.lists1] 62 | 63 | 64 | edge_index_all, edge_type_all = self.adj_data 65 | #edge_index_all: nested list of shape (n_samples, num_choice), where each entry is tensor[2, E] 66 | #edge_type_all: nested list of shape (n_samples, num_choice), where each entry is tensor[E, ] 67 | edge_index = self._to_device([edge_index_all[i] for i in batch_indexes], self.device1) 68 | edge_type = self._to_device([edge_type_all[i] for i in batch_indexes], self.device1) 69 | 70 | yield tuple([batch_qids, batch_labels, *batch_tensors0, *batch_lists0, *batch_tensors1, *batch_lists1, edge_index, edge_type]) 71 | 72 | def _to_device(self, obj, device): 73 | if isinstance(obj, (tuple, list)): 74 | return [self._to_device(item, device) for item in obj] 75 | else: 76 | return obj.to(device) 77 | 78 | 79 | def load_sparse_adj_data_with_contextnode(adj_pk_path, max_node_num, num_choice, args): 80 | cache_path = adj_pk_path +'.loaded_cache' 81 | use_cache = True 82 | 83 | if use_cache and not os.path.exists(cache_path): 84 | use_cache = False 85 | 86 | if use_cache: 87 | with open(cache_path, 'rb') as f: 88 | adj_lengths_ori, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type, half_n_rel = pickle.load(f) 89 | else: 90 | with open(adj_pk_path, 'rb') as fin: 91 | adj_concept_pairs = pickle.load(fin) 92 | 93 | n_samples = len(adj_concept_pairs) #this is actually n_questions x n_choices 94 | edge_index, edge_type = [], [] 95 | adj_lengths = torch.zeros((n_samples,), dtype=torch.long) 96 | concept_ids = torch.full((n_samples, max_node_num), 1, dtype=torch.long) 97 | node_type_ids = torch.full((n_samples, max_node_num), 2, dtype=torch.long) #default 2: "other node" 98 | node_scores = torch.zeros((n_samples, max_node_num, 1), dtype=torch.float) 99 | 100 | adj_lengths_ori = adj_lengths.clone() 101 | for idx, _data in tqdm(enumerate(adj_concept_pairs), total=n_samples, desc='loading adj matrices'): 102 | adj, concepts, qm, am, cid2score = _data['adj'], _data['concepts'], _data['qmask'], _data['amask'], _data['cid2score'] 103 | #adj: e.g. 
<4233x249 (n_nodes*half_n_rels x n_nodes) sparse matrix of type '' with 2905 stored elements in COOrdinate format> 104 | #concepts: np.array(num_nodes, ), where entry is concept id 105 | #qm: np.array(num_nodes, ), where entry is True/False 106 | #am: np.array(num_nodes, ), where entry is True/False 107 | assert len(concepts) == len(set(concepts)) 108 | qam = qm | am 109 | #sanity check: should be T,..,T,F,F,..F 110 | assert qam[0] == True 111 | F_start = False 112 | for TF in qam: 113 | if TF == False: 114 | F_start = True 115 | else: 116 | assert F_start == False 117 | num_concept = min(len(concepts), max_node_num-1) + 1 #this is the final number of nodes including contextnode but excluding PAD 118 | adj_lengths_ori[idx] = len(concepts) 119 | adj_lengths[idx] = num_concept 120 | 121 | #Prepare nodes 122 | concepts = concepts[:num_concept-1] 123 | concept_ids[idx, 1:num_concept] = torch.tensor(concepts +1) #To accommodate contextnode, original concept_ids incremented by 1 124 | concept_ids[idx, 0] = 0 #this is the "concept_id" for contextnode 125 | 126 | #Prepare node scores 127 | if (cid2score is not None): 128 | for _j_ in range(num_concept): 129 | _cid = int(concept_ids[idx, _j_]) - 1 130 | assert _cid in cid2score 131 | node_scores[idx, _j_, 0] = torch.tensor(cid2score[_cid]) 132 | 133 | #Prepare node types 134 | node_type_ids[idx, 0] = 3 #contextnode 135 | node_type_ids[idx, 1:num_concept][torch.tensor(qm, dtype=torch.bool)[:num_concept-1]] = 0 136 | node_type_ids[idx, 1:num_concept][torch.tensor(am, dtype=torch.bool)[:num_concept-1]] = 1 137 | 138 | #Load adj 139 | ij = torch.tensor(adj.row, dtype=torch.int64) #(num_matrix_entries, ), where each entry is coordinate 140 | k = torch.tensor(adj.col, dtype=torch.int64) #(num_matrix_entries, ), where each entry is coordinate 141 | n_node = adj.shape[1] 142 | half_n_rel = adj.shape[0] // n_node 143 | i, j = ij // n_node, ij % n_node 144 | 145 | #Prepare edges 146 | i += 2; j += 1; k += 1 # **** increment coordinate by 1, rel_id by 2 **** 147 | extra_i, extra_j, extra_k = [], [], [] 148 | for _coord, q_tf in enumerate(qm): 149 | _new_coord = _coord + 1 150 | if _new_coord > num_concept: 151 | break 152 | if q_tf: 153 | extra_i.append(0) #rel from contextnode to question concept 154 | extra_j.append(0) #contextnode coordinate 155 | extra_k.append(_new_coord) #question concept coordinate 156 | for _coord, a_tf in enumerate(am): 157 | _new_coord = _coord + 1 158 | if _new_coord > num_concept: 159 | break 160 | if a_tf: 161 | extra_i.append(1) #rel from contextnode to answer concept 162 | extra_j.append(0) #contextnode coordinate 163 | extra_k.append(_new_coord) #answer concept coordinate 164 | 165 | half_n_rel += 2 #should be 19 now 166 | if len(extra_i) > 0: 167 | i = torch.cat([i, torch.tensor(extra_i)], dim=0) 168 | j = torch.cat([j, torch.tensor(extra_j)], dim=0) 169 | k = torch.cat([k, torch.tensor(extra_k)], dim=0) 170 | ######################## 171 | 172 | mask = (j < max_node_num) & (k < max_node_num) 173 | i, j, k = i[mask], j[mask], k[mask] 174 | i, j, k = torch.cat((i, i + half_n_rel), 0), torch.cat((j, k), 0), torch.cat((k, j), 0) # add inverse relations 175 | edge_index.append(torch.stack([j,k], dim=0)) #each entry is [2, E] 176 | edge_type.append(i) #each entry is [E, ] 177 | 178 | with open(cache_path, 'wb') as f: 179 | pickle.dump([adj_lengths_ori, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type, half_n_rel], f) 180 | 181 | 182 | ori_adj_mean = adj_lengths_ori.float().mean().item() 183 |
ori_adj_sigma = np.sqrt(((adj_lengths_ori.float() - ori_adj_mean)**2).mean().item()) 184 | print('| ori_adj_len: mu {:.2f} sigma {:.2f} | adj_len: {:.2f} |'.format(ori_adj_mean, ori_adj_sigma, adj_lengths.float().mean().item()) + 185 | ' prune_rate: {:.2f} |'.format((adj_lengths_ori > adj_lengths).float().mean().item()) + 186 | ' qc_num: {:.2f} | ac_num: {:.2f} |'.format((node_type_ids == 0).float().sum(1).mean().item(), 187 | (node_type_ids == 1).float().sum(1).mean().item())) 188 | 189 | edge_index = list(map(list, zip(*(iter(edge_index),) * num_choice))) #list of size (n_questions, n_choices), where each entry is tensor[2, E] #this operation corresponds to .view(n_questions, n_choices) 190 | edge_type = list(map(list, zip(*(iter(edge_type),) * num_choice))) #list of size (n_questions, n_choices), where each entry is tensor[E, ] 191 | 192 | concept_ids, node_type_ids, node_scores, adj_lengths = [x.view(-1, num_choice, *x.size()[1:]) for x in (concept_ids, node_type_ids, node_scores, adj_lengths)] 193 | #concept_ids: (n_questions, num_choice, max_node_num) 194 | #node_type_ids: (n_questions, num_choice, max_node_num) 195 | #node_scores: (n_questions, num_choice, max_node_num) 196 | #adj_lengths: (n_questions, num_choice) 197 | return concept_ids, node_type_ids, node_scores, adj_lengths, (edge_index, edge_type) #, half_n_rel * 2 + 1 198 | 199 | 200 | 201 | 202 | 203 | def load_gpt_input_tensors(statement_jsonl_path, max_seq_length): 204 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 205 | """Truncates a sequence pair in place to the maximum length.""" 206 | while True: 207 | total_length = len(tokens_a) + len(tokens_b) 208 | if total_length <= max_length: 209 | break 210 | if len(tokens_a) > len(tokens_b): 211 | tokens_a.pop() 212 | else: 213 | tokens_b.pop() 214 | 215 | def load_qa_dataset(dataset_path): 216 | """ Output a list of tuples(story, 1st continuation, 2nd continuation, label) """ 217 | with open(dataset_path, "r", encoding="utf-8") as fin: 218 | output = [] 219 | for line in fin: 220 | input_json = json.loads(line) 221 | label = ord(input_json.get("answerKey", "A")) - ord("A") 222 | output.append((input_json['id'], input_json["question"]["stem"], *[ending["text"] for ending in input_json["question"]["choices"]], label)) 223 | return output 224 | 225 | def pre_process_datasets(encoded_datasets, num_choices, max_seq_length, start_token, delimiter_token, clf_token): 226 | """ Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label) 227 | 228 | To Transformer inputs of shape (n_batch, n_alternative, length) comprising for each batch, continuation: 229 | input_ids[batch, alternative, :] = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token] 230 | """ 231 | tensor_datasets = [] 232 | for dataset in encoded_datasets: 233 | n_batch = len(dataset) 234 | input_ids = np.zeros((n_batch, num_choices, max_seq_length), dtype=np.int64) 235 | mc_token_ids = np.zeros((n_batch, num_choices), dtype=np.int64) 236 | lm_labels = np.full((n_batch, num_choices, max_seq_length), fill_value=-1, dtype=np.int64) 237 | mc_labels = np.zeros((n_batch,), dtype=np.int64) 238 | for i, data, in enumerate(dataset): 239 | q, mc_label = data[0], data[-1] 240 | choices = data[1:-1] 241 | for j in range(len(choices)): 242 | _truncate_seq_pair(q, choices[j], max_seq_length - 3) 243 | qa = [start_token] + q + [delimiter_token] + choices[j] + [clf_token] 244 | input_ids[i, j, :len(qa)] = qa 245 | mc_token_ids[i, j] = len(qa) - 1 246 
| lm_labels[i, j, :len(qa) - 1] = qa[1:] 247 | mc_labels[i] = mc_label 248 | all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels) 249 | tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs)) 250 | return tensor_datasets 251 | 252 | def tokenize_and_encode(tokenizer, obj): 253 | """ Tokenize and encode a nested object """ 254 | if isinstance(obj, str): 255 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj)) 256 | elif isinstance(obj, int): 257 | return obj 258 | else: 259 | return list(tokenize_and_encode(tokenizer, o) for o in obj) 260 | 261 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 262 | tokenizer.add_tokens(GPT_SPECIAL_TOKENS) 263 | special_tokens_ids = tokenizer.convert_tokens_to_ids(GPT_SPECIAL_TOKENS) 264 | 265 | dataset = load_qa_dataset(statement_jsonl_path) 266 | examples_ids = [data[0] for data in dataset] 267 | dataset = [data[1:] for data in dataset] # discard example ids 268 | num_choices = len(dataset[0]) - 2 269 | 270 | encoded_dataset = tokenize_and_encode(tokenizer, dataset) 271 | 272 | (input_ids, mc_token_ids, lm_labels, mc_labels), = pre_process_datasets([encoded_dataset], num_choices, max_seq_length, *special_tokens_ids) 273 | return examples_ids, mc_labels, input_ids, mc_token_ids, lm_labels 274 | 275 | 276 | def get_gpt_token_num(): 277 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt') 278 | tokenizer.add_tokens(GPT_SPECIAL_TOKENS) 279 | return len(tokenizer) 280 | 281 | 282 | 283 | def load_bert_xlnet_roberta_input_tensors(statement_jsonl_path, model_type, model_name, max_seq_length): 284 | class InputExample(object): 285 | 286 | def __init__(self, example_id, question, contexts, endings, label=None): 287 | self.example_id = example_id 288 | self.question = question 289 | self.contexts = contexts 290 | self.endings = endings 291 | self.label = label 292 | 293 | class InputFeatures(object): 294 | 295 | def __init__(self, example_id, choices_features, label): 296 | self.example_id = example_id 297 | self.choices_features = [ 298 | { 299 | 'input_ids': input_ids, 300 | 'input_mask': input_mask, 301 | 'segment_ids': segment_ids, 302 | 'output_mask': output_mask, 303 | } 304 | for _, input_ids, input_mask, segment_ids, output_mask in choices_features 305 | ] 306 | self.label = label 307 | 308 | def read_examples(input_file): 309 | with open(input_file, "r", encoding="utf-8") as f: 310 | examples = [] 311 | for line in f.readlines(): 312 | json_dic = json.loads(line) 313 | label = ord(json_dic["answerKey"]) - ord("A") if 'answerKey' in json_dic else 0 314 | contexts = json_dic["question"]["stem"] 315 | if "para" in json_dic: 316 | contexts = json_dic["para"] + " " + contexts 317 | if "fact1" in json_dic: 318 | contexts = json_dic["fact1"] + " " + contexts 319 | examples.append( 320 | InputExample( 321 | example_id=json_dic["id"], 322 | contexts=[contexts] * len(json_dic["question"]["choices"]), 323 | question="", 324 | endings=[ending["text"] for ending in json_dic["question"]["choices"]], 325 | label=label 326 | )) 327 | return examples 328 | 329 | def convert_examples_to_features(examples, label_list, max_seq_length, 330 | tokenizer, 331 | cls_token_at_end=False, 332 | cls_token='[CLS]', 333 | cls_token_segment_id=1, 334 | sep_token='[SEP]', 335 | sequence_a_segment_id=0, 336 | sequence_b_segment_id=1, 337 | sep_token_extra=False, 338 | pad_token_segment_id=0, 339 | pad_on_left=False, 340 | pad_token=0, 341 | mask_padding_with_zero=True): 342 | """ Loads a data file into a list of `InputBatch`s 343 | 
`cls_token_at_end` defines the location of the CLS token: 344 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP] 345 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS] 346 | `cls_token_segment_id` defines the segment id associated with the CLS token (0 for BERT, 2 for XLNet) 347 | """ 348 | label_map = {label: i for i, label in enumerate(label_list)} 349 | 350 | features = [] 351 | for ex_index, example in enumerate(tqdm(examples)): 352 | choices_features = [] 353 | for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)): 354 | tokens_a = tokenizer.tokenize(context) 355 | tokens_b = tokenizer.tokenize(example.question + " " + ending) 356 | 357 | special_tokens_count = 4 if sep_token_extra else 3 358 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count) 359 | 360 | # The convention in BERT is: 361 | # (a) For sequence pairs: 362 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP] 363 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1 364 | # (b) For single sequences: 365 | # tokens: [CLS] the dog is hairy . [SEP] 366 | # type_ids: 0 0 0 0 0 0 0 367 | # 368 | # Where "type_ids" are used to indicate whether this is the first 369 | # sequence or the second sequence. The embedding vectors for `type=0` and 370 | # `type=1` were learned during pre-training and are added to the wordpiece 371 | # embedding vector (and position vector). This is not *strictly* necessary 372 | # since the [SEP] token unambiguously separates the sequences, but it makes 373 | # it easier for the model to learn the concept of sequences. 374 | # 375 | # For classification tasks, the first vector (corresponding to [CLS]) is 376 | # used as the "sentence vector". Note that this only makes sense because 377 | # the entire model is fine-tuned. 378 | tokens = tokens_a + [sep_token] 379 | if sep_token_extra: 380 | # roberta uses an extra separator b/w pairs of sentences 381 | tokens += [sep_token] 382 | 383 | segment_ids = [sequence_a_segment_id] * len(tokens) 384 | 385 | if tokens_b: 386 | tokens += tokens_b + [sep_token] 387 | segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1) 388 | 389 | if cls_token_at_end: 390 | tokens = tokens + [cls_token] 391 | segment_ids = segment_ids + [cls_token_segment_id] 392 | else: 393 | tokens = [cls_token] + tokens 394 | segment_ids = [cls_token_segment_id] + segment_ids 395 | 396 | input_ids = tokenizer.convert_tokens_to_ids(tokens) 397 | 398 | # The mask has 1 for real tokens and 0 for padding tokens. Only real 399 | # tokens are attended to. 400 | 401 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) 402 | special_token_id = tokenizer.convert_tokens_to_ids([cls_token, sep_token]) 403 | output_mask = [1 if id in special_token_id else 0 for id in input_ids] # 1 for mask 404 | 405 | # Zero-pad up to the sequence length.
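# Illustrative padding example (hypothetical token ids, not from the original
# file): with max_seq_length=8 and input_ids=[101, 2023, 2003, 102] (length 4),
# padding_length=4. With pad_on_left=False and the default pad_token=0 used by
# the caller further below, this yields
#   input_ids  -> [101, 2023, 2003, 102, 0, 0, 0, 0]
#   input_mask -> [1, 1, 1, 1, 0, 0, 0, 0]   (mask_padding_with_zero=True)
# whereas the XLNet-style pad_on_left=True branch prepends the padding instead.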
406 | padding_length = max_seq_length - len(input_ids) 407 | if pad_on_left: 408 | input_ids = ([pad_token] * padding_length) + input_ids 409 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask 410 | output_mask = ([1] * padding_length) + output_mask 411 | 412 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids 413 | else: 414 | input_ids = input_ids + ([pad_token] * padding_length) 415 | input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length) 416 | output_mask = output_mask + ([1] * padding_length) 417 | segment_ids = segment_ids + ([pad_token_segment_id] * padding_length) 418 | 419 | assert len(input_ids) == max_seq_length 420 | assert len(output_mask) == max_seq_length 421 | assert len(input_mask) == max_seq_length 422 | assert len(segment_ids) == max_seq_length 423 | choices_features.append((tokens, input_ids, input_mask, segment_ids, output_mask)) 424 | label = label_map[example.label] 425 | features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label)) 426 | 427 | return features 428 | 429 | def _truncate_seq_pair(tokens_a, tokens_b, max_length): 430 | """Truncates a sequence pair in place to the maximum length.""" 431 | 432 | # This is a simple heuristic which will always truncate the longer sequence 433 | # one token at a time. This makes more sense than truncating an equal percent 434 | # of tokens from each, since if one sequence is very short then each token 435 | # that's truncated likely contains more information than a longer sequence. 436 | while True: 437 | total_length = len(tokens_a) + len(tokens_b) 438 | if total_length <= max_length: 439 | break 440 | if len(tokens_a) > len(tokens_b): 441 | tokens_a.pop() 442 | else: 443 | tokens_b.pop() 444 | 445 | def select_field(features, field): 446 | return [[choice[field] for choice in feature.choices_features] for feature in features] 447 | 448 | def convert_features_to_tensors(features): 449 | all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long) 450 | all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long) 451 | all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long) 452 | all_output_mask = torch.tensor(select_field(features, 'output_mask'), dtype=torch.bool) 453 | all_label = torch.tensor([f.label for f in features], dtype=torch.long) 454 | return all_input_ids, all_input_mask, all_segment_ids, all_output_mask, all_label 455 | 456 | # try: 457 | # tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer, 'albert': AlbertTokenizer}.get(model_type) 458 | # except: 459 | # tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer}.get(model_type) 460 | tokenizer_class = AutoTokenizer 461 | tokenizer = tokenizer_class.from_pretrained(model_name) 462 | examples = read_examples(statement_jsonl_path) 463 | features = convert_examples_to_features(examples, list(range(len(examples[0].endings))), max_seq_length, tokenizer, 464 | cls_token_at_end=bool(model_type in ['xlnet']), # xlnet has a cls token at the end 465 | cls_token=tokenizer.cls_token, 466 | sep_token=tokenizer.sep_token, 467 | sep_token_extra=bool(model_type in ['roberta', 'albert']), 468 | cls_token_segment_id=2 if model_type in ['xlnet'] else 0, 469 | pad_on_left=bool(model_type in ['xlnet']), # pad on the left for xlnet 470 | pad_token_segment_id=4 if model_type in ['xlnet'] else 0, 
471 | sequence_b_segment_id=0 if model_type in ['roberta', 'albert'] else 1) 472 | example_ids = [f.example_id for f in features] 473 | *data_tensors, all_label = convert_features_to_tensors(features) 474 | return (example_ids, all_label, *data_tensors) 475 | 476 | 477 | 478 | def load_input_tensors(input_jsonl_path, model_type, model_name, max_seq_length): 479 | if model_type in ('lstm',): 480 | raise NotImplementedError 481 | elif model_type in ('gpt',): 482 | return load_gpt_input_tensors(input_jsonl_path, max_seq_length) 483 | elif model_type in ('bert', 'xlnet', 'roberta', 'albert'): 484 | return load_bert_xlnet_roberta_input_tensors(input_jsonl_path, model_type, model_name, max_seq_length) 485 | 486 | 487 | def load_info(statement_path: str): 488 | n = sum(1 for _ in open(statement_path, "r")) 489 | num_choice = None 490 | with open(statement_path, "r", encoding="utf-8") as fin: 491 | ids = [] 492 | labels = [] 493 | for line in fin: 494 | input_json = json.loads(line) 495 | labels.append(ord(input_json.get("answerKey", "A")) - ord("A")) 496 | ids.append(input_json['id']) 497 | if num_choice is None: 498 | num_choice = len(input_json["question"]["choices"]) 499 | labels = torch.tensor(labels, dtype=torch.long) 500 | 501 | return ids, labels, num_choice 502 | 503 | 504 | def load_statement_dict(statement_path): 505 | all_dict = {} 506 | with open(statement_path, 'r', encoding='utf-8') as fin: 507 | for line in fin: 508 | instance_dict = json.loads(line) 509 | qid = instance_dict['id'] 510 | all_dict[qid] = { 511 | 'question': instance_dict['question']['stem'], 512 | 'answers': [dic['text'] for dic in instance_dict['question']['choices']] 513 | } 514 | return all_dict 515 | --------------------------------------------------------------------------------
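A minimal usage sketch for the statement readers above (the file path and sample record are illustrative; the schema follows what read_examples, load_info, and load_statement_dict expect):

# Each line of a *.statement.jsonl file is one JSON record, e.g.:
# {"id": "q1", "answerKey": "B", "question": {"stem": "Where would you find a fox?",
#  "choices": [{"label": "A", "text": "zoo"}, {"label": "B", "text": "forest"}]}}
from utils.data_utils import load_info, load_statement_dict

ids, labels, num_choice = load_info('data/csqa/statement/dev.statement.jsonl')
statements = load_statement_dict('data/csqa/statement/dev.statement.jsonl')
print(num_choice)                      # number of answer choices per question
print(labels[0].item())                # gold answer index (0 for "A", 1 for "B", ...)
print(statements[ids[0]]['question'])  # question stem
print(statements[ids[0]]['answers'])   # list of answer texts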