├── utils
│   ├── __init__.py
│   ├── maths.py
│   ├── utils.py
│   ├── convert_obqa.py
│   ├── optimization_utils.py
│   ├── parser_utils.py
│   ├── convert_csqa.py
│   ├── tokenization_utils.py
│   ├── grounding.py
│   ├── conceptnet.py
│   ├── graph.py
│   └── data_utils.py
├── .gitignore
├── figs
│   ├── task.png
│   └── overview.png
├── download_preprocessed_data.sh
├── eval_qagnn__csqa.sh
├── eval_qagnn__obqa.sh
├── LICENSE
├── eval_qagnn__medqa_usmle.sh
├── download_raw_data.sh
├── run_qagnn__csqa.sh
├── run_qagnn__obqa.sh
├── run_qagnn__medqa_usmle.sh
├── README.md
├── modeling
│   ├── modeling_encoder.py
│   └── modeling_qagnn.py
├── preprocess.py
├── utils_biomed
│   └── preprocess_medqa_usmle.ipynb
└── qagnn.py
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
3 | .ipynb_checkpoints
4 |
5 | saved*
6 |
--------------------------------------------------------------------------------
/figs/task.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michiyasunaga/qagnn/HEAD/figs/task.png
--------------------------------------------------------------------------------
/figs/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/michiyasunaga/qagnn/HEAD/figs/overview.png
--------------------------------------------------------------------------------
/download_preprocessed_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | mv data data_old
4 |
5 | wget https://nlp.stanford.edu/projects/myasu/QAGNN/data_preprocessed_release.zip
6 | unzip data_preprocessed_release.zip
7 | mv data_preprocessed_release data
8 |
--------------------------------------------------------------------------------
/utils/maths.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy import sparse
3 |
4 |
5 | def normalize_sparse_adj(A, sparse_type='coo'):
6 |     """
7 |     Normalize A along the second axis (i.e. divide each row by its sum).
8 | 
9 |     A: scipy.sparse matrix
10 |     sparse_type: str (optional, default 'coo')
11 |     returns: scipy.sparse matrix in the format given by `sparse_type`
12 |     """
13 | in_degree = np.array(A.sum(1)).reshape(-1)
14 | in_degree[in_degree == 0] = 1e-5
15 | d_inv = sparse.diags(1 / in_degree)
16 | A = getattr(d_inv.dot(A), 'to' + sparse_type)()
17 | return A
18 |
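19 | # Minimal usage sketch (illustrative values only):
20 | # >>> A = sparse.coo_matrix(np.array([[1., 1., 0.], [0., 1., 1.], [0., 0., 1.]]))
21 | # >>> normalize_sparse_adj(A, 'csr').sum(1)  # each row now sums to ~1
22 | # matrix([[1.], [1.], [1.]])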
--------------------------------------------------------------------------------
/eval_qagnn__csqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0
4 | dt=`date '+%Y%m%d_%H%M%S'`
5 |
6 |
7 | dataset="csqa"
8 | model='roberta-large'
9 | shift
10 | shift
11 | args=$@
12 |
13 |
14 | echo "******************************"
15 | echo "dataset: $dataset"
16 | echo "******************************"
17 |
18 | save_dir_pref='saved_models'
19 | mkdir -p $save_dir_pref
20 |
21 | ###### Eval ######
22 | python3 -u qagnn.py --dataset $dataset \
23 | --train_adj data/${dataset}/graph/train.graph.adj.pk \
24 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \
25 | --test_adj data/${dataset}/graph/test.graph.adj.pk \
26 | --train_statements data/${dataset}/statement/train.statement.jsonl \
27 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \
28 | --test_statements data/${dataset}/statement/test.statement.jsonl \
29 | --save_model \
30 | --save_dir saved_models \
31 | --mode eval_detail \
32 | --load_model_path saved_models/csqa_model_hf3.4.0.pt \
33 | $args
34 |
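35 | # Usage sketch (inferred from the two `shift`s above): the first two positional
36 | # arguments are consumed and ignored, and anything after them is forwarded to
37 | # qagnn.py via $args, e.g.  ./eval_qagnn__csqa.sh _ _ -bs 16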
--------------------------------------------------------------------------------
/eval_qagnn__obqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0
4 | dt=`date '+%Y%m%d_%H%M%S'`
5 |
6 |
7 | dataset="obqa"
8 | model='roberta-large'
9 | shift
10 | shift
11 | args=$@
12 |
13 |
14 | echo "******************************"
15 | echo "dataset: $dataset"
16 | echo "******************************"
17 |
18 | save_dir_pref='saved_models'
19 | mkdir -p $save_dir_pref
20 |
21 | ###### Eval ######
22 | python3 -u qagnn.py --dataset $dataset \
23 | --train_adj data/${dataset}/graph/train.graph.adj.pk \
24 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \
25 | --test_adj data/${dataset}/graph/test.graph.adj.pk \
26 | --train_statements data/${dataset}/statement/train.statement.jsonl \
27 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \
28 | --test_statements data/${dataset}/statement/test.statement.jsonl \
29 | --save_model \
30 | --save_dir saved_models \
31 | --mode eval_detail \
32 | --load_model_path saved_models/obqa_model_hf3.4.0.pt \
33 | $args
34 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Michihiro Yasunaga
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/eval_qagnn__medqa_usmle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0
4 | dt=`date '+%Y%m%d_%H%M%S'`
5 |
6 |
7 | dataset="medqa_usmle"
8 | model='cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
9 | ent_emb='ddb'
10 | shift
11 | shift
12 | args=$@
13 |
14 |
15 | echo "******************************"
16 | echo "dataset: $dataset"
17 | echo "******************************"
18 |
19 | save_dir_pref='saved_models'
20 | mkdir -p $save_dir_pref
21 |
22 | ###### Eval ######
23 | python3 -u qagnn.py --dataset $dataset \
24 | --train_adj data/${dataset}/graph/dev.graph.adj.pk \
25 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \
26 | --test_adj data/${dataset}/graph/test.graph.adj.pk \
27 | --train_statements data/${dataset}/statement/dev.statement.jsonl \
28 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \
29 | --test_statements data/${dataset}/statement/test.statement.jsonl \
30 | --ent_emb ${ent_emb} \
31 | --save_model \
32 | --save_dir saved_models \
33 | --mode eval_detail \
34 | --load_model_path saved_models/medqa_usmle_model_hf3.4.0.pt \
35 | $args
36 |
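37 | # Note: --train_adj/--train_statements point at the dev split here; eval_detail
38 | # mode does no training, so this presumably just avoids loading the much larger
39 | # training graph.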
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import time
4 | import argparse
5 |
6 |
7 | def bool_flag(v):
8 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
9 | return True
10 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
11 | return False
12 | else:
13 | raise argparse.ArgumentTypeError('Boolean value expected.')
14 |
15 |
16 | def check_path(path):
17 | d = os.path.dirname(path)
18 | if not os.path.exists(d):
19 | os.makedirs(d)
20 |
21 |
22 | def check_file(file):
23 | return os.path.isfile(file)
24 |
25 |
26 | def export_config(config, path):
27 | param_dict = dict(vars(config))
28 | check_path(path)
29 | with open(path, 'w') as fout:
30 | json.dump(param_dict, fout, indent=4)
31 |
32 |
33 | def freeze_net(module):
34 | for p in module.parameters():
35 | p.requires_grad = False
36 |
37 |
38 | def unfreeze_net(module):
39 | for p in module.parameters():
40 | p.requires_grad = True
41 |
42 |
43 | def test_data_loader_ms_per_batch(data_loader, max_steps=10000):
44 | start = time.time()
45 | n_batch = sum(1 for batch, _ in zip(data_loader, range(max_steps)))
46 | return (time.time() - start) * 1000 / n_batch
47 |
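48 | # Sketch of how bool_flag plugs into argparse (cf. `--fp16 true` in the run scripts):
49 | # >>> parser = argparse.ArgumentParser()
50 | # >>> _ = parser.add_argument('--fp16', type=bool_flag, nargs='?', const=True, default=False)
51 | # >>> parser.parse_args(['--fp16', 'true']).fp16
52 | # True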
--------------------------------------------------------------------------------
/download_raw_data.sh:
--------------------------------------------------------------------------------
1 | # download ConceptNet
2 | mkdir -p data/
3 | mkdir -p data/cpnet/
4 | wget -nc -P data/cpnet/ https://s3.amazonaws.com/conceptnet/downloads/2018/edges/conceptnet-assertions-5.6.0.csv.gz
5 | cd data/cpnet/
6 | yes n | gzip -d conceptnet-assertions-5.6.0.csv.gz
7 | # download ConceptNet entity embedding
8 | wget https://csr.s3-us-west-1.amazonaws.com/tzw.ent.npy
9 | cd ../../
10 |
11 |
12 |
13 |
14 | # download CommonsenseQA dataset
15 | mkdir -p data/csqa/
16 | wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl
17 | wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl
18 | wget -nc -P data/csqa/ https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl
19 |
20 | # create output folders
21 | mkdir -p data/csqa/grounded/
22 | mkdir -p data/csqa/graph/
23 | mkdir -p data/csqa/statement/
24 |
25 |
26 |
27 | # download OpenBookQA dataset
28 | wget -nc -P data/obqa/ https://s3-us-west-2.amazonaws.com/ai2-website/data/OpenBookQA-V1-Sep2018.zip
29 | yes n | unzip data/obqa/OpenBookQA-V1-Sep2018.zip -d data/obqa/
30 |
31 | # create output folders
32 | mkdir -p data/obqa/fairseq/official/
33 | mkdir -p data/obqa/grounded/
34 | mkdir -p data/obqa/graph/
35 | mkdir -p data/obqa/statement/
36 |
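37 | # Note: tzw.ent.npy is saved into data/cpnet/ (we cd there before the wget above);
38 | # it is the default 'tzw' entity-embedding source listed in utils/parser_utils.py.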
--------------------------------------------------------------------------------
/run_qagnn__csqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1
4 | dt=`date '+%Y%m%d_%H%M%S'`
5 |
6 |
7 | dataset="csqa"
8 | model='roberta-large'
9 | shift
10 | shift
11 | args=$@
12 |
13 |
14 | elr="1e-5"
15 | dlr="1e-3"
16 | bs=64
17 | mbs=2
18 | n_epochs=15
19 | num_relation=38 #(17 +2) * 2: originally 17, add 2 relation types (QA context -> Q node; QA context -> A node), and double because we add reverse edges
20 |
21 |
22 | k=5 #num of gnn layers
23 | gnndim=200
24 |
25 | echo "***** hyperparameters *****"
26 | echo "dataset: $dataset"
27 | echo "enc_name: $model"
28 | echo "batch_size: $bs"
29 | echo "learning_rate: elr $elr dlr $dlr"
30 | echo "gnn: dim $gnndim layer $k"
31 | echo "******************************"
32 |
33 | save_dir_pref='saved_models'
34 | mkdir -p $save_dir_pref
35 | mkdir -p logs
36 |
37 | ###### Training ######
38 | for seed in 0; do
39 | python3 -u qagnn.py --dataset $dataset \
40 | --encoder $model -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs -mbs $mbs --fp16 true --seed $seed \
41 | --num_relation $num_relation \
42 | --n_epochs $n_epochs --max_epochs_before_stop 10 \
43 | --train_adj data/${dataset}/graph/train.graph.adj.pk \
44 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \
45 | --test_adj data/${dataset}/graph/test.graph.adj.pk \
46 | --train_statements data/${dataset}/statement/train.statement.jsonl \
47 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \
48 | --test_statements data/${dataset}/statement/test.statement.jsonl \
49 | --save_model \
50 | --save_dir ${save_dir_pref}/${dataset}/enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt} $args \
51 | > logs/train_${dataset}__enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt}.log.txt
52 | done
53 |
--------------------------------------------------------------------------------
/run_qagnn__obqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0,1
4 | dt=`date '+%Y%m%d_%H%M%S'`
5 |
6 |
7 | dataset="obqa"
8 | model='roberta-large'
9 | shift
10 | shift
11 | args=$@
12 |
13 |
14 | elr="1e-5"
15 | dlr="1e-3"
16 | bs=128
17 | mbs=1
18 | n_epochs=100
19 | num_relation=38 #(17 +2) * 2: originally 17, add 2 relation types (QA context -> Q node; QA context -> A node), and double because we add reverse edges
20 |
21 |
22 | k=5 #num of gnn layers
23 | gnndim=200
24 |
25 | echo "***** hyperparameters *****"
26 | echo "dataset: $dataset"
27 | echo "enc_name: $model"
28 | echo "batch_size: $bs"
29 | echo "learning_rate: elr $elr dlr $dlr"
30 | echo "gnn: dim $gnndim layer $k"
31 | echo "******************************"
32 |
33 | save_dir_pref='saved_models'
34 | mkdir -p $save_dir_pref
35 | mkdir -p logs
36 |
37 | ###### Training ######
38 | for seed in 0; do
39 | python3 -u qagnn.py --dataset $dataset \
40 | --encoder $model -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs -mbs $mbs --fp16 true --seed $seed \
41 | --num_relation $num_relation \
42 | --n_epochs $n_epochs --max_epochs_before_stop 50 \
43 | --train_adj data/${dataset}/graph/train.graph.adj.pk \
44 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \
45 | --test_adj data/${dataset}/graph/test.graph.adj.pk \
46 | --train_statements data/${dataset}/statement/train.statement.jsonl \
47 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \
48 | --test_statements data/${dataset}/statement/test.statement.jsonl \
49 | --save_model \
50 | --save_dir ${save_dir_pref}/${dataset}/enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt} $args \
51 | > logs/train_${dataset}__enc-${model}__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt}.log.txt
52 | done
53 |
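54 | # As noted in the README, OBQA training can be unstable across seeds; to sweep
55 | # several seeds, widen the loop above, e.g. `for seed in 0 1 2; do ... done`.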
--------------------------------------------------------------------------------
/utils/convert_obqa.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import sys
4 | from tqdm import tqdm
5 |
6 | __all__ = ['convert_to_obqa_statement']
7 |
8 | # String used to indicate a blank
9 | BLANK_STR = "___"
10 |
11 |
12 | def convert_to_obqa_statement(qa_file: str, output_file1: str, output_file2: str):
13 | print(f'converting {qa_file} to entailment dataset...')
14 | nrow = sum(1 for _ in open(qa_file, 'r'))
15 | with open(output_file1, 'w') as output_handle1, open(output_file2, 'w') as output_handle2, open(qa_file, 'r') as qa_handle:
16 | # print("Writing to {} from {}".format(output_file, qa_file))
17 | for line in tqdm(qa_handle, total=nrow):
18 | json_line = json.loads(line)
19 | output_dict = convert_qajson_to_entailment(json_line)
20 | output_handle1.write(json.dumps(output_dict))
21 | output_handle1.write("\n")
22 | output_handle2.write(json.dumps(output_dict))
23 | output_handle2.write("\n")
24 | print(f'converted statements saved to {output_file1}, {output_file2}')
25 | print()
26 |
27 |
28 | # Convert the QA file json to output dictionary containing premise and hypothesis
29 | def convert_qajson_to_entailment(qa_json: dict):
30 | question_text = qa_json["question"]["stem"]
31 | choices = qa_json["question"]["choices"]
32 | for choice in choices:
33 | choice_text = choice["text"]
34 | statement = question_text + ' ' + choice_text
35 | create_output_dict(qa_json, statement, choice["label"] == qa_json.get("answerKey", "A"))
36 |
37 | return qa_json
38 |
39 |
40 | # Create the output json dictionary from the input json, premise and hypothesis statement
41 | def create_output_dict(input_json: dict, statement: str, label: bool) -> dict:
42 | if "statements" not in input_json:
43 | input_json["statements"] = []
44 | input_json["statements"].append({"label": label, "statement": statement})
45 | return input_json
46 |
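47 | # Example call (sketch; paths are illustrative: the statement/ and fairseq/
48 | # output folders are created by download_raw_data.sh, while the raw input path
49 | # depends on how the OBQA zip unpacks):
50 | # convert_to_obqa_statement('data/obqa/OpenBookQA-V1-Sep2018/Data/Main/dev.jsonl',
51 | #                           'data/obqa/statement/dev.statement.jsonl',
52 | #                           'data/obqa/fairseq/official/dev.jsonl')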
--------------------------------------------------------------------------------
/run_qagnn__medqa_usmle.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export CUDA_VISIBLE_DEVICES=0
4 | dt=`date '+%Y%m%d_%H%M%S'`
5 |
6 |
7 | dataset="medqa_usmle"
8 | model='cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
9 | shift
10 | shift
11 | args=$@
12 |
13 |
14 | elr="5e-5"
15 | dlr="1e-3"
16 | bs=128
17 | mbs=2
18 | sl=512
19 | n_epochs=15
20 | ent_emb='ddb'
21 | num_relation=34 #(15 +2) * 2: originally 15, add 2 relation types (QA context -> Q node; QA context -> A node), and double because we add reverse edges
22 |
23 |
24 | k=5 #num of gnn layers
25 | gnndim=200
26 | unfrz=0
27 |
28 |
29 | echo "***** hyperparameters *****"
30 | echo "dataset: $dataset"
31 | echo "enc_name: $model"
32 | echo "batch_size: $bs"
33 | echo "learning_rate: elr $elr dlr $dlr"
34 | echo "gnn: dim $gnndim layer $k"
35 | echo "******************************"
36 |
37 | save_dir_pref='saved_models'
38 | mkdir -p $save_dir_pref
39 | mkdir -p logs
40 |
41 | ###### Training ######
42 | for seed in 0; do
43 | python3 -u qagnn.py --dataset $dataset \
44 | --encoder $model -k $k --gnn_dim $gnndim -elr $elr -dlr $dlr -bs $bs -mbs $mbs -sl $sl --fp16 true --seed $seed \
45 | --num_relation $num_relation \
46 | --n_epochs $n_epochs --max_epochs_before_stop 10 --unfreeze_epoch $unfrz \
47 | --train_adj data/${dataset}/graph/train.graph.adj.pk \
48 | --dev_adj data/${dataset}/graph/dev.graph.adj.pk \
49 | --test_adj data/${dataset}/graph/test.graph.adj.pk \
50 | --train_statements data/${dataset}/statement/train.statement.jsonl \
51 | --dev_statements data/${dataset}/statement/dev.statement.jsonl \
52 | --test_statements data/${dataset}/statement/test.statement.jsonl \
53 | --ent_emb ${ent_emb} \
54 | --save_model \
55 | --save_dir ${save_dir_pref}/${dataset}/enc-sapbert__k${k}__gnndim${gnndim}__bs${bs}__seed${seed}__${dt} $args \
56 | > logs/train_${dataset}__enc-sapbert__k${k}__gnndim${gnndim}__bs${bs}__sl${sl}__unfrz${unfrz}__seed${seed}__${dt}.log.txt
57 | done
58 |
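59 | # Note: ent_emb='ddb' resolves to data/ddb/ent_emb.npy via EMB_PATHS in
60 | # utils/parser_utils.py; the ddb/ folder ships with the biomed data download (see README).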
--------------------------------------------------------------------------------
/utils/optimization_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from transformers import AdamW
4 | from torch.optim import SGD, Adam
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | class RAdam(Optimizer):
9 |
10 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
11 | if not 0.0 <= lr:
12 | raise ValueError("Invalid learning rate: {}".format(lr))
13 | if not 0.0 <= eps:
14 | raise ValueError("Invalid epsilon value: {}".format(eps))
15 | if not 0.0 <= betas[0] < 1.0:
16 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
17 | if not 0.0 <= betas[1] < 1.0:
18 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
19 |
20 | self.degenerated_to_sgd = degenerated_to_sgd
21 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
22 | for param in params:
23 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
24 | param['buffer'] = [[None, None, None] for _ in range(10)]
25 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
26 | super(RAdam, self).__init__(params, defaults)
27 |
28 | def __setstate__(self, state):
29 | super(RAdam, self).__setstate__(state)
30 |
31 | def step(self, closure=None):
32 |
33 | loss = None
34 | if closure is not None:
35 | loss = closure()
36 |
37 | for group in self.param_groups:
38 |
39 | for p in group['params']:
40 | if p.grad is None:
41 | continue
42 | grad = p.grad.data.float()
43 | if grad.is_sparse:
44 | raise RuntimeError('RAdam does not support sparse gradients')
45 |
46 | p_data_fp32 = p.data.float()
47 |
48 | state = self.state[p]
49 |
50 | if len(state) == 0:
51 | state['step'] = 0
52 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
53 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
54 | else:
55 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
56 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
57 |
58 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
59 | beta1, beta2 = group['betas']
60 |
61 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
62 | exp_avg.mul_(beta1).add_(1 - beta1, grad)
63 |
64 | state['step'] += 1
65 | buffered = group['buffer'][int(state['step'] % 10)]
66 | if state['step'] == buffered[0]:
67 | N_sma, step_size = buffered[1], buffered[2]
68 | else:
69 | buffered[0] = state['step']
70 | beta2_t = beta2 ** state['step']
71 | N_sma_max = 2 / (1 - beta2) - 1
72 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
73 | buffered[1] = N_sma
74 |
75 | # more conservative since it's an approximated value
76 | if N_sma >= 5:
77 | step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
78 | elif self.degenerated_to_sgd:
79 | step_size = 1.0 / (1 - beta1 ** state['step'])
80 | else:
81 | step_size = -1
82 | buffered[2] = step_size
83 |
84 | # more conservative since it's an approximated value
85 | if N_sma >= 5:
86 | if group['weight_decay'] != 0:
87 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
88 | denom = exp_avg_sq.sqrt().add_(group['eps'])
89 | p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
90 | p.data.copy_(p_data_fp32)
91 | elif step_size > 0:
92 | if group['weight_decay'] != 0:
93 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
94 | p_data_fp32.add_(-step_size * group['lr'], exp_avg)
95 | p.data.copy_(p_data_fp32)
96 |
97 | return loss
98 |
99 |
100 | OPTIMIZER_CLASSES = {
101 | 'sgd': SGD,
102 | 'adam': Adam,
103 | 'adamw': AdamW,
104 | 'radam': RAdam,
105 | }
106 |
107 |
108 | def run_test():
109 | import torch.nn as nn
110 | model = nn.Sequential(*[nn.Linear(100, 10), nn.ReLU(), nn.Linear(10, 2)])
111 | x = torch.randn(10, 100).repeat(100, 1)
112 | y = torch.randint(0, 2, (10,)).repeat(100)
113 | crit = nn.CrossEntropyLoss()
114 | optim = RAdam(model.parameters(), lr=1e-2, weight_decay=0.01)
115 | model.train()
116 | for a in range(0, 1000, 10):
117 | b = a + 10
118 |         optim.zero_grad()  # clear gradients accumulated from the previous step
119 |         loss = crit(model(x[a:b]), y[a:b])
120 |         loss.backward()
121 |         optim.step()
122 |         print('| loss: {:.4f} |'.format(loss.item()))
123 | 
124 | if __name__ == '__main__':
125 |     run_test()
126 |
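127 | # Sketch: resolve an optimizer class by name; the keys mirror the --optim
128 | # choices declared in utils/parser_utils.py.
129 | # >>> OptimizerClass = OPTIMIZER_CLASSES['radam']
130 | # >>> optimizer = OptimizerClass(model.parameters(), lr=1e-3, weight_decay=0.01)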
--------------------------------------------------------------------------------
/utils/parser_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils.utils import *
3 | from modeling.modeling_encoder import MODEL_NAME_TO_CLASS
4 |
5 | ENCODER_DEFAULT_LR = {
6 | 'default': 1e-3,
7 | 'csqa': {
8 | 'lstm': 3e-4,
9 | 'openai-gpt': 1e-4,
10 | 'bert-base-uncased': 3e-5,
11 | 'bert-large-uncased': 2e-5,
12 | 'roberta-large': 1e-5,
13 | },
14 | 'obqa': {
15 | 'lstm': 3e-4,
16 | 'openai-gpt': 3e-5,
17 | 'bert-base-cased': 1e-4,
18 | 'bert-large-cased': 1e-4,
19 | 'roberta-large': 1e-5,
20 | },
21 | 'medqa_usmle': {
22 | 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext': 5e-5,
23 | },
24 | }
25 |
26 | DATASET_LIST = ['csqa', 'obqa', 'socialiqa', 'medqa_usmle']
27 |
28 | DATASET_SETTING = {
29 | 'csqa': 'inhouse',
30 | 'obqa': 'official',
31 | 'socialiqa': 'official',
32 | 'medqa_usmle': 'official',
33 | }
34 |
35 | DATASET_NO_TEST = ['socialiqa']
36 |
37 | EMB_PATHS = {
38 | 'transe': 'data/transe/glove.transe.sgd.ent.npy',
39 | 'lm': 'data/transe/glove.transe.sgd.ent.npy',
40 | 'numberbatch': 'data/transe/concept.nb.npy',
41 | 'tzw': 'data/cpnet/tzw.ent.npy',
42 | 'ddb': 'data/ddb/ent_emb.npy',
43 | }
44 |
45 |
46 | def add_data_arguments(parser):
47 | # arguments that all datasets share
48 | parser.add_argument('--ent_emb', default=['tzw'], nargs='+', help='sources for entity embeddings')
49 | # dataset specific
50 | parser.add_argument('-ds', '--dataset', default='csqa', choices=DATASET_LIST, help='dataset name')
51 | parser.add_argument('-ih', '--inhouse', type=bool_flag, nargs='?', const=True, help='run in-house setting')
52 | parser.add_argument('--inhouse_train_qids', default='data/{dataset}/inhouse_split_qids.txt', help='qids of the in-house training set')
53 | # statements
54 | parser.add_argument('--train_statements', default='data/{dataset}/statement/train.statement.jsonl')
55 | parser.add_argument('--dev_statements', default='data/{dataset}/statement/dev.statement.jsonl')
56 | parser.add_argument('--test_statements', default='data/{dataset}/statement/test.statement.jsonl')
57 | # preprocessing options
58 | parser.add_argument('-sl', '--max_seq_len', default=100, type=int)
59 | # set dataset defaults
60 | args, _ = parser.parse_known_args()
61 | parser.set_defaults(ent_emb_paths=[EMB_PATHS.get(s) for s in args.ent_emb],
62 | inhouse=(DATASET_SETTING[args.dataset] == 'inhouse'),
63 | inhouse_train_qids=args.inhouse_train_qids.format(dataset=args.dataset))
64 | data_splits = ('train', 'dev') if args.dataset in DATASET_NO_TEST else ('train', 'dev', 'test')
65 | for split in data_splits:
66 | for attribute in ('statements',):
67 | attr_name = f'{split}_{attribute}'
68 | parser.set_defaults(**{attr_name: getattr(args, attr_name).format(dataset=args.dataset)})
69 | if 'test' not in data_splits:
70 | parser.set_defaults(test_statements=None)
71 |
72 |
73 | def add_encoder_arguments(parser):
74 | parser.add_argument('-enc', '--encoder', default='bert-large-uncased', help='encoder type')
75 | parser.add_argument('--encoder_layer', default=-1, type=int, help='encoder layer ID to use as features (used only by non-LSTM encoders)')
76 | parser.add_argument('-elr', '--encoder_lr', default=2e-5, type=float, help='learning rate')
77 | args, _ = parser.parse_known_args()
78 | parser.set_defaults(encoder_lr=ENCODER_DEFAULT_LR[args.dataset].get(args.encoder, ENCODER_DEFAULT_LR['default']))
79 |
80 |
81 | def add_optimization_arguments(parser):
82 |     parser.add_argument('--loss', default='cross_entropy', choices=['margin_rank', 'cross_entropy'], help='loss function')
83 |     parser.add_argument('--optim', default='radam', choices=['sgd', 'adam', 'adamw', 'radam'], help='optimizer')
84 | parser.add_argument('--lr_schedule', default='fixed', choices=['fixed', 'warmup_linear', 'warmup_constant'], help='learning rate scheduler')
85 | parser.add_argument('-bs', '--batch_size', default=32, type=int)
86 | parser.add_argument('--warmup_steps', type=float, default=150)
87 | parser.add_argument('--max_grad_norm', default=1.0, type=float, help='max grad norm (0 to disable)')
88 | parser.add_argument('--weight_decay', default=1e-2, type=float, help='l2 weight decay strength')
89 | parser.add_argument('--n_epochs', default=100, type=int, help='total number of training epochs to perform.')
90 |     parser.add_argument('-me', '--max_epochs_before_stop', default=10, type=int, help='stop training if dev accuracy does not improve for N epochs')
91 |
92 |
93 | def add_additional_arguments(parser):
94 | parser.add_argument('--log_interval', default=10, type=int)
95 | parser.add_argument('--cuda', default=True, type=bool_flag, nargs='?', const=True, help='use GPU')
96 | parser.add_argument('--seed', default=0, type=int, help='random seed')
97 | parser.add_argument('--debug', default=False, type=bool_flag, nargs='?', const=True, help='run in debug mode')
98 | args, _ = parser.parse_known_args()
99 | if args.debug:
100 | parser.set_defaults(batch_size=1, log_interval=1, eval_interval=5)
101 |
102 |
103 | def get_parser():
104 | """A helper function that handles the arguments that all models share"""
105 | parser = argparse.ArgumentParser(add_help=False)
106 | add_data_arguments(parser)
107 | add_encoder_arguments(parser)
108 | add_optimization_arguments(parser)
109 | add_additional_arguments(parser)
110 | return parser
111 |
112 |
113 | def get_lstm_config_from_args(args):
114 | lstm_config = {
115 | 'hidden_size': args.encoder_dim,
116 | 'output_size': args.encoder_dim,
117 | 'num_layers': args.encoder_layer_num,
118 | 'bidirectional': args.encoder_bidir,
119 | 'emb_p': args.encoder_dropoute,
120 | 'input_p': args.encoder_dropouti,
121 | 'hidden_p': args.encoder_dropouth,
122 | 'pretrained_emb_or_path': args.encoder_pretrained_emb,
123 | 'freeze_emb': args.encoder_freeze_emb,
124 | 'pool_function': args.encoder_pooler,
125 | }
126 | return lstm_config
127 |
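128 | # Usage sketch: get_parser() is built with add_help=False, which suggests it is
129 | # meant to be composed as a parent parser (argument values here are illustrative):
130 | # >>> parser = argparse.ArgumentParser(parents=[get_parser()])
131 | # >>> _ = parser.add_argument('--save_dir', default='./saved_models/')
132 | # >>> args = parser.parse_args(['-ds', 'csqa', '-enc', 'roberta-large'])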
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # QA-GNN: Question Answering using Language Models and Knowledge Graphs
2 |
3 | This repo provides the source code & data of our paper: [QA-GNN: Reasoning with Language Models and Knowledge Graphs for Question Answering](https://arxiv.org/abs/2104.06378) (NAACL 2021).
4 | ```bib
5 | @InProceedings{yasunaga2021qagnn,
6 | author = {Michihiro Yasunaga and Hongyu Ren and Antoine Bosselut and Percy Liang and Jure Leskovec},
7 | title = {QA-GNN: Reasoning with Language Models and Knowledge Graphs for Question Answering},
8 | year = {2021},
9 | booktitle = {North American Chapter of the Association for Computational Linguistics (NAACL)},
10 | }
11 | ```
12 | Webpage: [https://snap.stanford.edu/qagnn](https://snap.stanford.edu/qagnn)
13 |
14 |
15 | <p align="center">
16 |   <img src="figs/task.png" alt="QA-GNN task" width="300">
17 |   <img src="figs/overview.png" alt="QA-GNN overview" width="500">
18 | </p>
19 |
20 |
21 | ## Usage
22 | ### 0. Dependencies
23 | Run the following commands to create a conda environment (assuming CUDA 10.1):
24 | ```bash
25 | conda create -n qagnn python=3.7
26 | source activate qagnn
27 | pip install torch==1.8.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
28 | pip install transformers==3.4.0
29 | pip install nltk spacy==2.1.6
30 | python -m spacy download en
31 |
32 | # for torch-geometric
33 | pip install torch-scatter==2.0.7 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
34 | pip install torch-sparse==0.6.9 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
35 | pip install torch-geometric==1.7.0 -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
36 | ```
37 |
38 |
39 | ### 1. Download data
40 | We use the question answering datasets (*CommonsenseQA*, *OpenBookQA*) and the ConceptNet knowledge graph.
41 | Download all the raw data by
42 | ```
43 | ./download_raw_data.sh
44 | ```
45 |
46 | Preprocess the raw data by running
47 | ```
48 | python preprocess.py -p
49 | ```
50 | The script will:
51 | * Set up ConceptNet (e.g., extract English relations from ConceptNet, merge the original 42 relation types into 17 types)
52 | * Convert the QA datasets into .jsonl files (e.g., stored in `data/csqa/statement/`)
53 | * Identify all mentioned concepts in the questions and answers
54 | * Extract subgraphs for each q-a pair
55 |
56 | **TL;DR (skip the steps above and just get the preprocessed data)**. The preprocessing may take a long time. For your convenience, you can download all the processed data by
57 | ```
58 | ./download_preprocessed_data.sh
59 | ```
60 |
61 | **🔴 NEWS (Add MedQA-USMLE)**. Besides the commonsense QA datasets (*CommonsenseQA*, *OpenBookQA*) with the ConceptNet knowledge graph, we added a biomedical QA dataset ([*MedQA-USMLE*](https://github.com/jind11/MedQA)) with a biomedical knowledge graph based on Disease Database and DrugBank. You can download all the data for this from [**[here]**](https://nlp.stanford.edu/projects/myasu/QAGNN/data_preprocessed_biomed.zip). Unzip it and put the `medqa_usmle` and `ddb` folders inside the `data/` directory. While this data is already preprocessed, we also provide the preprocessing scripts we used in `utils_biomed/`.
62 |
63 |
64 | The resulting file structure will look like:
65 |
66 | ```plain
67 | .
68 | ├── README.md
69 | └── data/
70 |     ├── cpnet/                 (preprocessed ConceptNet)
71 |     ├── csqa/
72 |     │   ├── train_rand_split.jsonl
73 |     │   ├── dev_rand_split.jsonl
74 |     │   ├── test_rand_split_no_answers.jsonl
75 |     │   ├── statement/         (converted statements)
76 |     │   ├── grounded/          (grounded entities)
77 |     │   ├── graph/             (extracted subgraphs)
78 |     │   └── ...
79 |     ├── obqa/
80 |     ├── medqa_usmle/
81 |     └── ddb/
82 | ```
83 |
84 | ### 2. Train QA-GNN
85 | For CommonsenseQA, run
86 | ```
87 | ./run_qagnn__csqa.sh
88 | ```
89 | For OpenBookQA, run
90 | ```
91 | ./run_qagnn__obqa.sh
92 | ```
93 | For MedQA-USMLE, run
94 | ```
95 | ./run_qagnn__medqa_usmle.sh
96 | ```
97 | As configured in these scripts, the model needs two types of input files
98 | * `--{train,dev,test}_statements`: preprocessed question statements in jsonl format. This is mainly loaded by `load_input_tensors` function in `utils/data_utils.py`.
99 | * `--{train,dev,test}_adj`: information of the KG subgraph extracted for each question. This is mainly loaded by `load_sparse_adj_data_with_contextnode` function in `utils/data_utils.py`.
100 |
101 | **Note**: We find that training for OpenBookQA is unstable (e.g. best dev accuracy varies when using different seeds, different versions of the transformers / torch-geometric libraries, etc.), likely because the dataset is small. We suggest trying out different seeds. Another potential way to stabilize training is to initialize the model with one of the successful checkpoints provided below, e.g. by adding an argument `--load_model_path obqa_model.pt`.
102 |
103 |
104 | ### 3. Evaluate trained model
105 | For CommonsenseQA, run
106 | ```
107 | ./eval_qagnn__csqa.sh
108 | ```
109 | Similarly, for other datasets (OpenBookQA, MedQA-USMLE), run `./eval_qagnn__obqa.sh` and `./eval_qagnn__medqa_usmle.sh`.
110 | You can download trained model checkpoints in the next section.
111 |
112 |
113 | ## Trained model examples
114 | CommonsenseQA
115 |
116 |
117 | | Trained model                 | In-house Dev acc. | In-house Test acc. |
118 | | ----------------------------- | ----------------- | ------------------ |
119 | | RoBERTa-large + QA-GNN [link] | 0.7707            | 0.7405             |
120 | 
121 | 
122 | 
123 | 
124 | 
125 |
126 |
127 |
128 | OpenBookQA
129 |
130 |
131 | | Trained model                 | Dev acc. | Test acc. |
132 | | ----------------------------- | -------- | --------- |
133 | | RoBERTa-large + QA-GNN [link] | 0.6960   | 0.6900    |
134 | 
135 | 
136 | 
137 | 
138 | 
139 |
140 |
141 |
142 | MedQA-USMLE
143 |
144 |
145 | | Trained model                | Dev acc. | Test acc. |
146 | | ---------------------------- | -------- | --------- |
147 | | SapBERT-base + QA-GNN [link] | 0.3789   | 0.3810    |
148 | 
149 | 
150 | 
151 | 
152 | 
153 |
154 |
155 |
156 | **Note**: The models were trained and tested with HuggingFace transformers==3.4.0.
157 |
158 |
159 | ## Use your own dataset
160 | - Convert your dataset to `{train,dev,test}.statement.jsonl` (one JSON object per line; see `data/csqa/statement/train.statement.jsonl`, with the format documented at the top of `utils/convert_csqa.py`)
161 | - Create a directory in `data/{yourdataset}/` to store the .jsonl files
162 | - Modify `preprocess.py` and perform subgraph extraction for your data
163 | - Modify `utils/parser_utils.py` to support your own dataset
164 |
165 |
166 | ## Acknowledgment
167 | This repo is built upon the following work:
168 | ```
169 | Scalable Multi-Hop Relational Reasoning for Knowledge-Aware Question Answering. Yanlin Feng*, Xinyue Chen*, Bill Yuchen Lin, Peifeng Wang, Jun Yan and Xiang Ren. EMNLP 2020.
170 | https://github.com/INK-USC/MHGRN
171 | ```
172 | Many thanks to the authors and developers!
173 |
--------------------------------------------------------------------------------
/modeling/modeling_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from transformers import (OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
6 | XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP)
7 | try:
8 | from transformers import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
9 | except ImportError:  # older transformers releases may not provide the ALBERT map
10 | pass
11 | from transformers import AutoModel, BertModel, BertConfig
12 | # from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
13 | from utils.layers import *
14 | from utils.data_utils import get_gpt_token_num
15 |
16 | MODEL_CLASS_TO_NAME = {
17 | 'gpt': list(OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
18 | 'bert': list(BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
19 | 'xlnet': list(XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
20 | 'roberta': list(ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP.keys()),
21 | 'lstm': ['lstm'],
22 | }
23 | try:
24 | MODEL_CLASS_TO_NAME['albert'] = list(ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())
25 | except NameError:  # ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP was not imported above
26 | pass
27 |
28 | MODEL_NAME_TO_CLASS = {model_name: model_class for model_class, model_name_list in MODEL_CLASS_TO_NAME.items() for model_name in model_name_list}
29 |
30 | #Add SapBERT configuration
31 | model_name = 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext'
32 | MODEL_NAME_TO_CLASS[model_name] = 'bert'
33 |
34 |
35 | class LSTMTextEncoder(nn.Module):
36 | pool_layer_classes = {'mean': MeanPoolLayer, 'max': MaxPoolLayer}
37 |
38 | def __init__(self, vocab_size=1, emb_size=300, hidden_size=300, output_size=300, num_layers=2, bidirectional=True,
39 | emb_p=0.0, input_p=0.0, hidden_p=0.0, pretrained_emb_or_path=None, freeze_emb=True,
40 | pool_function='max', output_hidden_states=False):
41 | super().__init__()
42 | self.output_size = output_size
43 | self.num_layers = num_layers
44 | self.output_hidden_states = output_hidden_states
45 | assert not bidirectional or hidden_size % 2 == 0
46 |
47 | if pretrained_emb_or_path is not None:
48 | if isinstance(pretrained_emb_or_path, str): # load pretrained embedding from a .npy file
49 | pretrained_emb_or_path = torch.tensor(np.load(pretrained_emb_or_path), dtype=torch.float)
50 | emb = nn.Embedding.from_pretrained(pretrained_emb_or_path, freeze=freeze_emb)
51 | emb_size = emb.weight.size(1)
52 | else:
53 | emb = nn.Embedding(vocab_size, emb_size)
54 | self.emb = EmbeddingDropout(emb, emb_p)
55 | self.rnns = nn.ModuleList([nn.LSTM(emb_size if l == 0 else hidden_size,
56 | (hidden_size if l != num_layers else output_size) // (2 if bidirectional else 1),
57 | 1, bidirectional=bidirectional, batch_first=True) for l in range(num_layers)])
58 | self.pooler = self.pool_layer_classes[pool_function]()
59 |
60 | self.input_dropout = nn.Dropout(input_p)
61 | self.hidden_dropout = nn.ModuleList([RNNDropout(hidden_p) for _ in range(num_layers)])
62 |
63 | def forward(self, inputs, lengths):
64 | """
65 | inputs: tensor of shape (batch_size, seq_len)
66 | lengths: tensor of shape (batch_size)
67 |
68 | returns: tensor of shape (batch_size, hidden_size)
69 | """
70 | assert (lengths > 0).all()
71 | batch_size, seq_len = inputs.size()
72 | hidden_states = self.input_dropout(self.emb(inputs))
73 | all_hidden_states = [hidden_states]
74 | for l, (rnn, hid_dp) in enumerate(zip(self.rnns, self.hidden_dropout)):
75 | hidden_states = pack_padded_sequence(hidden_states, lengths, batch_first=True, enforce_sorted=False)
76 | hidden_states, _ = rnn(hidden_states)
77 | hidden_states, _ = pad_packed_sequence(hidden_states, batch_first=True, total_length=seq_len)
78 | all_hidden_states.append(hidden_states)
79 | if l != self.num_layers - 1:
80 | hidden_states = hid_dp(hidden_states)
81 | pooled = self.pooler(all_hidden_states[-1], lengths)
82 | assert len(all_hidden_states) == self.num_layers + 1
83 | outputs = (all_hidden_states[-1], pooled)
84 | if self.output_hidden_states:
85 | outputs = outputs + (all_hidden_states,)
86 | return outputs
87 |
88 |
89 | class TextEncoder(nn.Module):
90 | valid_model_types = set(MODEL_CLASS_TO_NAME.keys())
91 |
92 | def __init__(self, model_name, output_token_states=False, from_checkpoint=None, **kwargs):
93 | super().__init__()
94 | self.model_type = MODEL_NAME_TO_CLASS[model_name]
95 | self.output_token_states = output_token_states
96 | assert not self.output_token_states or self.model_type in ('bert', 'roberta', 'albert')
97 |
98 | if self.model_type in ('lstm',):
99 | self.module = LSTMTextEncoder(**kwargs, output_hidden_states=True)
100 | self.sent_dim = self.module.output_size
101 | else:
102 | model_class = AutoModel
103 | self.module = model_class.from_pretrained(model_name, output_hidden_states=True)
104 | if from_checkpoint is not None:
105 | self.module = self.module.from_pretrained(from_checkpoint, output_hidden_states=True)
106 | if self.model_type in ('gpt',):
107 | self.module.resize_token_embeddings(get_gpt_token_num())
108 | self.sent_dim = self.module.config.n_embd if self.model_type in ('gpt',) else self.module.config.hidden_size
109 |
110 | def forward(self, *inputs, layer_id=-1):
111 | '''
112 | layer_id: only works for non-LSTM encoders
113 | output_token_states: if True, return hidden states of specific layer and attention masks
114 | '''
115 |
116 | if self.model_type in ('lstm',): # lstm
117 | input_ids, lengths = inputs
118 | outputs = self.module(input_ids, lengths)
119 | elif self.model_type in ('gpt',): # gpt
120 | input_ids, cls_token_ids, lm_labels = inputs # lm_labels is not used
121 | outputs = self.module(input_ids)
122 | else: # bert / xlnet / roberta
123 | input_ids, attention_mask, token_type_ids, output_mask = inputs
124 | outputs = self.module(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
125 | all_hidden_states = outputs[-1]
126 | hidden_states = all_hidden_states[layer_id]
127 |
128 | if self.model_type in ('lstm',):
129 | sent_vecs = outputs[1]
130 | elif self.model_type in ('gpt',):
131 | cls_token_ids = cls_token_ids.view(-1).unsqueeze(-1).unsqueeze(-1).expand(-1, 1, hidden_states.size(-1))
132 | sent_vecs = hidden_states.gather(1, cls_token_ids).squeeze(1)
133 | elif self.model_type in ('xlnet',):
134 | sent_vecs = hidden_states[:, -1]
135 | elif self.model_type in ('albert',):
136 | if self.output_token_states:
137 | return hidden_states, output_mask
138 | sent_vecs = hidden_states[:, 0]
139 | else: # bert / roberta
140 | if self.output_token_states:
141 | return hidden_states, output_mask
142 | sent_vecs = self.module.pooler(hidden_states)
143 | return sent_vecs, all_hidden_states
144 |
145 |
146 | def run_test():
147 | encoder = TextEncoder('lstm', vocab_size=100, emb_size=100, hidden_size=200, num_layers=4)
148 | input_ids = torch.randint(0, 100, (30, 70))
149 |     lengths = torch.randint(1, 70, (30,))
150 |     outputs = encoder(input_ids, lengths)
151 | assert outputs[0].size() == (30, 200)
152 | assert len(outputs[1]) == 4 + 1
153 | assert all([x.size() == (30, 70, 100 if l == 0 else 200) for l, x in enumerate(outputs[1])])
154 |     print('all tests passed')
155 |
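156 | 
157 | # Run the smoke test when this module is executed directly, mirroring
158 | # utils/optimization_utils.py and utils/tokenization_utils.py.
159 | if __name__ == '__main__':
160 |     run_test()
161 | 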
--------------------------------------------------------------------------------
/utils/convert_csqa.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to convert the retrieved HITS into an entailment dataset
3 | USAGE:
4 | python convert_csqa.py input_file output_file
5 |
6 | JSONL format of files
7 | 1. input_file:
8 |     {
9 |       "id": "d3b479933e716fb388dfb297e881054c",
10 |       "question": {
11 |         "stem": "If a lantern is not for sale, where is it likely to be?",
12 |         "choices": [{"label": "A", "text": "antique shop"}, {"label": "B", "text": "house"}, {"label": "C", "text": "dark place"}]
13 |       },
14 |       "answerKey": "B"
15 |     }
16 | 
17 | 2. output_file:
18 |     {
19 |       "id": "d3b479933e716fb388dfb297e881054c",
20 |       "question": {
21 |         "stem": "If a lantern is not for sale, where is it likely to be?",
22 |         "choices": [{"label": "A", "text": "antique shop"}, {"label": "B", "text": "house"}, {"label": "C", "text": "dark place"}]
23 |       },
24 |       "answerKey": "B",
25 | 
26 |       "statements": [
27 |         {"label": true, "statement": "If a lantern is not for sale, it likely to be at house"},
28 |         {"label": false, "statement": "If a lantern is not for sale, it likely to be at antique shop"},
29 |         {"label": false, "statement": "If a lantern is not for sale, it likely to be at dark place"}
30 |       ]
31 |     }
32 | """
33 |
34 | import json
35 | import re
36 | import sys
37 | from tqdm import tqdm
38 |
39 | __all__ = ['convert_to_entailment']
40 |
41 | # String used to indicate a blank
42 | BLANK_STR = "___"
43 |
44 |
45 | def convert_to_entailment(qa_file: str, output_file: str, ans_pos: bool=False):
46 | print(f'converting {qa_file} to entailment dataset...')
47 | nrow = sum(1 for _ in open(qa_file, 'r'))
48 | with open(output_file, 'w') as output_handle, open(qa_file, 'r') as qa_handle:
49 | # print("Writing to {} from {}".format(output_file, qa_file))
50 | for line in tqdm(qa_handle, total=nrow):
51 | json_line = json.loads(line)
52 | output_dict = convert_qajson_to_entailment(json_line, ans_pos)
53 | output_handle.write(json.dumps(output_dict))
54 | output_handle.write("\n")
55 | print(f'converted statements saved to {output_file}')
56 | print()
57 |
58 |
59 | # Convert the QA file json to output dictionary containing premise and hypothesis
60 | def convert_qajson_to_entailment(qa_json: dict, ans_pos: bool):
61 | question_text = qa_json["question"]["stem"]
62 | choices = qa_json["question"]["choices"]
63 | for choice in choices:
64 | choice_text = choice["text"]
65 | pos = None
66 | if not ans_pos:
67 | statement = create_hypothesis(get_fitb_from_question(question_text), choice_text, ans_pos)
68 | else:
69 | statement, pos = create_hypothesis(get_fitb_from_question(question_text), choice_text, ans_pos)
70 | create_output_dict(qa_json, statement, choice["label"] == qa_json.get("answerKey", "A"), ans_pos, pos)
71 |
72 | return qa_json
73 |
74 |
75 | # Get a Fill-In-The-Blank (FITB) statement from the question text. E.g. "George wants to warm his
76 | # hands quickly by rubbing them. Which skin surface will produce the most heat?" ->
77 | # "George wants to warm his hands quickly by rubbing them. ___ skin surface will produce the most
78 | # heat?"
79 | def get_fitb_from_question(question_text: str) -> str:
80 | fitb = replace_wh_word_with_blank(question_text)
81 | if not re.match(".*_+.*", fitb):
82 | # print("Can't create hypothesis from: '{}'. Appending {} !".format(question_text, BLANK_STR))
83 | # Strip space, period and question mark at the end of the question and add a blank
84 | fitb = re.sub(r"[\.\? ]*$", "", question_text.strip()) + " " + BLANK_STR
85 | return fitb
86 |
87 |
88 | # Create a hypothesis statement from the input fill-in-the-blank statement and answer choice.
89 | def create_hypothesis(fitb: str, choice: str, ans_pos: bool) -> str:
90 |
91 | if ". " + BLANK_STR in fitb or fitb.startswith(BLANK_STR):
92 | choice = choice[0].upper() + choice[1:]
93 | else:
94 | choice = choice.lower()
95 | # Remove period from the answer choice, if the question doesn't end with the blank
96 | if not fitb.endswith(BLANK_STR):
97 | choice = choice.rstrip(".")
98 | # Some questions already have blanks indicated with 2+ underscores
99 |     if not ans_pos:
100 |         try:
101 |             hypothesis = re.sub("__+", choice, fitb)
102 |         except re.error:  # e.g. backslashes in `choice` treated as regex escapes
103 |             hypothesis = re.sub("__+", lambda m: choice, fitb)
104 |         return hypothesis
105 | choice = choice.strip()
106 | m = re.search("__+", fitb)
107 | start = m.start()
108 |
109 | length = (len(choice) - 1) if fitb.endswith(BLANK_STR) and choice[-1] in ['.', '?', '!'] else len(choice)
110 |     hypothesis = re.sub("__+", lambda m: choice, fitb)  # callable repl treats `choice` literally
111 |
112 | return hypothesis, (start, start + length)
113 |
114 |
115 | # Identify the wh-word in the question and replace with a blank
116 | def replace_wh_word_with_blank(question_str: str):
117 | # if "What is the name of the government building that houses the U.S. Congress?" in question_str:
118 | # print()
119 | question_str = question_str.replace("What's", "What is")
120 | question_str = question_str.replace("whats", "what")
121 | question_str = question_str.replace("U.S.", "US")
122 | wh_word_offset_matches = []
123 | wh_words = ["which", "what", "where", "when", "how", "who", "why"]
124 | for wh in wh_words:
125 | # Some Turk-authored SciQ questions end with wh-word
126 | # E.g. The passing of traits from parents to offspring is done through what?
127 |
128 | if wh == "who" and "people who" in question_str:
129 | continue
130 |
131 | m = re.search(wh + r"\?[^\.]*[\. ]*$", question_str.lower())
132 | if m:
133 | wh_word_offset_matches = [(wh, m.start())]
134 | break
135 | else:
136 | # Otherwise, find the wh-word in the last sentence
137 | m = re.search(wh + r"[ ,][^\.]*[\. ]*$", question_str.lower())
138 | if m:
139 | wh_word_offset_matches.append((wh, m.start()))
140 | # else:
141 | # wh_word_offset_matches.append((wh, question_str.index(wh)))
142 |
143 | # If a wh-word is found
144 | if len(wh_word_offset_matches):
145 | # Pick the first wh-word as the word to be replaced with BLANK
146 | # E.g. Which is most likely needed when describing the change in position of an object?
147 | wh_word_offset_matches.sort(key=lambda x: x[1])
148 | wh_word_found = wh_word_offset_matches[0][0]
149 | wh_word_start_offset = wh_word_offset_matches[0][1]
150 | # Replace the last question mark with period.
151 | question_str = re.sub(r"\?$", ".", question_str.strip())
152 | # Introduce the blank in place of the wh-word
153 | fitb_question = (question_str[:wh_word_start_offset] + BLANK_STR +
154 | question_str[wh_word_start_offset + len(wh_word_found):])
155 | # Drop "of the following" as it doesn't make sense in the absence of a multiple-choice
156 | # question. E.g. "Which of the following force ..." -> "___ force ..."
157 | final = fitb_question.replace(BLANK_STR + " of the following", BLANK_STR)
158 | final = final.replace(BLANK_STR + " of these", BLANK_STR)
159 | return final
160 |
161 | elif " them called?" in question_str:
162 | return question_str.replace(" them called?", " " + BLANK_STR + ".")
163 | elif " meaning he was not?" in question_str:
164 | return question_str.replace(" meaning he was not?", " he was not " + BLANK_STR + ".")
165 | elif " one of these?" in question_str:
166 | return question_str.replace(" one of these?", " " + BLANK_STR + ".")
167 | elif re.match(r".*[^\.\?] *$", question_str):
168 | # If no wh-word is found and the question ends without a period/question, introduce a
169 | # blank at the end. e.g. The gravitational force exerted by an object depends on its
170 | return question_str + " " + BLANK_STR
171 | else:
172 | # If all else fails, assume "this ?" indicates the blank. Used in Turk-authored questions
173 | # e.g. Virtually every task performed by living organisms requires this?
174 | return re.sub(r" this[ \?]", " ___ ", question_str)
175 |
176 |
177 | # Create the output json dictionary from the input json, premise and hypothesis statement
178 | def create_output_dict(input_json: dict, statement: str, label: bool, ans_pos: bool, pos=None) -> dict:
179 | if "statements" not in input_json:
180 | input_json["statements"] = []
181 | if not ans_pos:
182 | input_json["statements"].append({"label": label, "statement": statement})
183 | else:
184 | input_json["statements"].append({"label": label, "statement": statement, "ans_pos": pos})
185 | return input_json
186 |
187 |
188 | if __name__ == "__main__":
189 | if len(sys.argv) < 3:
190 | raise ValueError("Provide at least two arguments: "
191 | "json file with hits, output file name")
192 | convert_to_entailment(sys.argv[1], sys.argv[2])
193 |
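194 | # Example invocation (paths as created by download_raw_data.sh):
195 | #   python utils/convert_csqa.py data/csqa/train_rand_split.jsonl data/csqa/statement/train.statement.jsonl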
--------------------------------------------------------------------------------
/utils/tokenization_utils.py:
--------------------------------------------------------------------------------
1 | from transformers import PreTrainedTokenizer
2 | import os
3 | import nltk
4 | import json
5 | from tqdm import tqdm
6 | import spacy
7 |
8 | EOS_TOK = '</s>'
9 | UNK_TOK = '<unk>'
10 | PAD_TOK = '<pad>'
11 | SEP_TOK = '<sep>'
12 | EXTRA_TOKS = [EOS_TOK, UNK_TOK, PAD_TOK, SEP_TOK]
13 |
14 |
15 | class WordTokenizer(PreTrainedTokenizer):
16 | vocab_files_names = {'vocab_file': 'vocab.txt'}
17 | pretrained_vocab_files_map = {'vocab_file': {'lstm': './data/glove/glove.vocab'}}
18 | max_model_input_sizes = {'lstm': None}
19 | """
20 | vocab_file: Path to a json file that contains token-to-id mapping
21 | """
22 |
23 | def __init__(self, vocab_file, unk_token=UNK_TOK, sep_token=SEP_TOK, pad_token=PAD_TOK, eos_token=EOS_TOK, **kwargs):
24 | super(WordTokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
25 | pad_token=pad_token, eos_token=eos_token, **kwargs)
26 | with open(vocab_file, 'r', encoding='utf-8') as fin:
27 | self.vocab = {line.rstrip('\n'): i for i, line in enumerate(fin)}
28 | self.ids_to_tokens = {ids: tok for tok, ids in self.vocab.items()}
29 | self.spacy_tokenizer = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner', 'textcat'])
30 |
31 | @property
32 | def vocab_size(self):
33 | return len(self.vocab)
34 |
35 | def _tokenize(self, text):
36 | return tokenize_sentence_spacy(self.spacy_tokenizer, text, lower_case=True, convert_num=False)
37 |
38 | def _convert_token_to_id(self, token):
39 | """ Converts a token (str/unicode) in an id using the vocab. """
40 | return self.vocab.get(token, self.vocab.get(self.unk_token))
41 |
42 | def _convert_id_to_token(self, index):
43 | """Converts an index (integer) in a token (string/unicode) using the vocab."""
44 | return self.ids_to_tokens.get(index, self.unk_token)
45 |
46 | def convert_tokens_to_string(self, tokens):
47 | """ Converts a sequence of tokens (string) in a single string. """
48 | out_string = ' '.join(tokens).strip()
49 | return out_string
50 |
51 | def add_special_tokens_single_sequence(self, token_ids):
52 | return token_ids + [self.eos_token_id]
53 |
54 | def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
55 | return token_ids_0 + [self.sep_token_id] + token_ids_1
56 |
57 | def save_vocabulary(self, vocab_path):
58 | """Save the tokenizer vocabulary to a directory or file."""
59 | if os.path.isdir(vocab_path):
60 | vocab_file = os.path.join(vocab_path, self.vocab_files_names['vocab_file'])
61 | else:
62 | vocab_file = vocab_path
63 | with open(vocab_file, "w", encoding="utf-8") as fout:
64 | for i in range(len(self.vocab)):
65 | fout.write(self.ids_to_tokens[i] + '\n')
66 | return (vocab_file,)
67 |
68 |
69 | class WordVocab(object):
70 |
71 | def __init__(self, sents=None, path=None, freq_cutoff=5, encoding='utf-8', verbose=True):
72 | """
73 | sents: list[str] (optional, default None)
74 | path: str (optional, default None)
75 | freq_cutoff: int (optional, default 5, 0 to disable)
76 | encoding: str (optional, default utf-8)
77 | """
78 | if sents is not None:
79 | counts = {}
80 | for text in sents:
81 | for w in text.split():
82 | counts[w] = counts.get(w, 0) + 1
83 | self._idx2w = [t[0] for t in sorted(counts.items(), key=lambda x: -x[1])]
84 | self._w2idx = {w: i for i, w in enumerate(self._idx2w)}
85 | self._counts = counts
86 |
87 | elif path is not None:
88 | self._idx2w = []
89 | self._counts = {}
90 | with open(path, 'r', encoding=encoding) as fin:
91 | for line in fin:
92 | w, c = line.rstrip().split(' ')
93 | self._idx2w.append(w)
94 |                 self._counts[w] = int(c)  # counts are stored as text in the vocab file
95 | self._w2idx = {w: i for i, w in enumerate(self._idx2w)}
96 |
97 | else:
98 | self._idx2w = []
99 | self._w2idx = {}
100 | self._counts = {}
101 |
102 | if freq_cutoff > 1:
103 | self._idx2w = [w for w in self._idx2w if self._counts[w] >= freq_cutoff]
104 |
105 | in_sum = sum([self._counts[w] for w in self._idx2w])
106 | total_sum = sum([self._counts[w] for w in self._counts])
107 | if verbose:
108 | print('vocab oov rate: {:.4f}'.format(1 - in_sum / total_sum))
109 |
110 | self._w2idx = {w: i for i, w in enumerate(self._idx2w)}
111 | self._counts = {w: self._counts[w] for w in self._idx2w}
112 |
113 | def add_word(self, w, count=1):
114 | if w not in self.w2idx:
115 | self._w2idx[w] = len(self._idx2w)
116 | self._idx2w.append(w)
117 | self._counts[w] = count
118 | else:
119 | self._counts[w] += count
120 | return self
121 |
122 | def top_k_cutoff(self, size):
123 | if size < len(self._idx2w):
124 | for w in self._idx2w[size:]:
125 | self._w2idx.pop(w)
126 | self._counts.pop(w)
127 | self._idx2w = self._idx2w[:size]
128 |
129 | assert len(self._idx2w) == len(self._w2idx) == len(self._counts)
130 | return self
131 |
132 | def save(self, path, encoding='utf-8'):
133 | with open(path, 'w', encoding=encoding) as fout:
134 | for w in self._idx2w:
135 | fout.write(w + ' ' + str(self._counts[w]) + '\n')
136 |
137 | def __len__(self):
138 | return len(self._idx2w)
139 |
140 | def __contains__(self, word):
141 | return word in self._w2idx
142 |
143 | def __iter__(self):
144 | for word in self._idx2w:
145 | yield word
146 |
147 | @property
148 | def w2idx(self):
149 | return self._w2idx
150 |
151 | @property
152 | def idx2w(self):
153 | return self._idx2w
154 |
155 | @property
156 | def counts(self):
157 | return self._counts
158 |
159 |
160 | def tokenize_sentence_nltk(sent, lower_case=True, convert_num=False):
161 | tokens = nltk.word_tokenize(sent)
162 | if lower_case:
163 | tokens = [t.lower() for t in tokens]
164 | if convert_num:
165 | tokens = ['' if t.isdigit() else t for t in tokens]
166 | return tokens
167 |
168 |
169 | def tokenize_sentence_spacy(nlp, sent, lower_case=True, convert_num=False):
170 | tokens = [tok.text for tok in nlp(sent)]
171 | if lower_case:
172 | tokens = [t.lower() for t in tokens]
173 | if convert_num:
174 | tokens = ['' if t.isdigit() else t for t in tokens]
175 | return tokens
176 |
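# Usage sketch for the two tokenizers above (hypothetical sentence); with
# convert_num=True, digit-only tokens are replaced by the empty-string placeholder:
#
#     nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner', 'textcat'])
#     tokenize_sentence_spacy(nlp, 'I bought 2 apples.', convert_num=True)
#     # -> ['i', 'bought', '', 'apples', '.']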
177 |
178 | def tokenize_statement_file(statement_path, output_path, lower_case=True, convert_num=False):
179 | nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner', 'textcat'])
180 | nrow = sum(1 for _ in open(statement_path, 'r'))
181 | with open(statement_path, 'r') as fin, open(output_path, 'w') as fout:
182 | for line in tqdm(fin, total=nrow, desc='tokenizing'):
183 | data = json.loads(line)
184 | for statement in data['statements']:
185 | tokens = tokenize_sentence_spacy(nlp, statement['statement'], lower_case=lower_case, convert_num=convert_num)
186 | fout.write(' '.join(tokens) + '\n')
187 |
188 |
189 | def make_word_vocab(statement_path_list, output_path, lower_case=True, convert_num=True, freq_cutoff=5):
190 | """save the vocab to the output_path in json format"""
191 | nlp = spacy.load('en_core_web_sm', disable=['tagger', 'parser', 'ner', 'textcat'])
192 |
193 | docs = []
194 | for path in statement_path_list:
195 | with open(path, 'r', encoding='utf-8') as fin:
196 | for line in fin:
197 | json_dic = json.loads(line)
198 | docs += [json_dic['question']['stem']] + [s['text'] for s in json_dic['question']['choices']]
199 |
200 | counts = {}
201 | for doc in tqdm(docs, desc='making word vocab'):
202 | for w in tokenize_sentence_spacy(nlp, doc, lower_case=lower_case, convert_num=convert_num):
203 | counts[w] = counts.get(w, 0) + 1
204 | idx2w = [t[0] for t in sorted(counts.items(), key=lambda x: -x[1])]
205 | idx2w = [w for w in idx2w if counts[w] >= freq_cutoff]
206 | idx2w += EXTRA_TOKS
207 | w2idx = {w: i for i, w in enumerate(idx2w)}
208 | with open(output_path, 'w', encoding='utf-8') as fout:
209 | json.dump(w2idx, fout)
210 |
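# The saved JSON is a word-to-index map ordered by descending corpus frequency,
# with the EXTRA_TOKS appended at the end of the index space, e.g. (hypothetical)
# {"the": 0, "to": 1, "a": 2, ...}.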
211 |
212 | def run_test():
213 | # tokenize_statement_file('data/csqa/statement/dev.statement.jsonl', '/tmp/tokenized.txt', True, True)
214 | # make_word_vocab(['data/csqa/statement/dev.statement.jsonl', 'data/csqa/statement/train.statement.jsonl'], '/tmp/vocab.txt', True, True)
215 | tokenizer = WordTokenizer.from_pretrained('lstm')
216 | print(tokenizer.tokenize('I love NLP since 1998DEC'))
217 | print(tokenizer.tokenize('CXY loves NLP since 1998'))
218 | print(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('CXY loves NLP since 1998')))
219 | print(tokenizer.convert_ids_to_tokens(tokenizer.convert_tokens_to_ids(tokenizer.tokenize('CXY loves NLP since 1998'))))
220 | tokenizer.save_pretrained('/tmp/')
221 | tokenizer = WordTokenizer.from_pretrained('/tmp/')
222 | print('vocab size = {}'.format(tokenizer.vocab_size))
223 |
224 |
225 | if __name__ == '__main__':
226 | run_test()
227 |
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from multiprocessing import cpu_count
4 | from utils.convert_csqa import convert_to_entailment
5 | from utils.convert_obqa import convert_to_obqa_statement
6 | from utils.conceptnet import extract_english, construct_graph
7 | from utils.grounding import create_matcher_patterns, ground
8 | from utils.graph import generate_adj_data_from_grounded_concepts__use_LM
9 |
10 | input_paths = {
11 | 'csqa': {
12 | 'train': './data/csqa/train_rand_split.jsonl',
13 | 'dev': './data/csqa/dev_rand_split.jsonl',
14 | 'test': './data/csqa/test_rand_split_no_answers.jsonl',
15 | },
16 | 'obqa': {
17 | 'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/train.jsonl',
18 | 'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/dev.jsonl',
19 | 'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Main/test.jsonl',
20 | },
21 | 'obqa-fact': {
22 | 'train': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/train_complete.jsonl',
23 | 'dev': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/dev_complete.jsonl',
24 | 'test': './data/obqa/OpenBookQA-V1-Sep2018/Data/Additional/test_complete.jsonl',
25 | },
26 | 'cpnet': {
27 | 'csv': './data/cpnet/conceptnet-assertions-5.6.0.csv',
28 | },
29 | }
30 |
31 | output_paths = {
32 | 'cpnet': {
33 | 'csv': './data/cpnet/conceptnet.en.csv',
34 | 'vocab': './data/cpnet/concept.txt',
35 | 'patterns': './data/cpnet/matcher_patterns.json',
36 | 'unpruned-graph': './data/cpnet/conceptnet.en.unpruned.graph',
37 | 'pruned-graph': './data/cpnet/conceptnet.en.pruned.graph',
38 | },
39 | 'csqa': {
40 | 'statement': {
41 | 'train': './data/csqa/statement/train.statement.jsonl',
42 | 'dev': './data/csqa/statement/dev.statement.jsonl',
43 | 'test': './data/csqa/statement/test.statement.jsonl',
44 | },
45 | 'grounded': {
46 | 'train': './data/csqa/grounded/train.grounded.jsonl',
47 | 'dev': './data/csqa/grounded/dev.grounded.jsonl',
48 | 'test': './data/csqa/grounded/test.grounded.jsonl',
49 | },
50 | 'graph': {
51 | 'adj-train': './data/csqa/graph/train.graph.adj.pk',
52 | 'adj-dev': './data/csqa/graph/dev.graph.adj.pk',
53 | 'adj-test': './data/csqa/graph/test.graph.adj.pk',
54 | },
55 | },
56 | 'obqa': {
57 | 'statement': {
58 | 'train': './data/obqa/statement/train.statement.jsonl',
59 | 'dev': './data/obqa/statement/dev.statement.jsonl',
60 | 'test': './data/obqa/statement/test.statement.jsonl',
61 | 'train-fairseq': './data/obqa/fairseq/official/train.jsonl',
62 | 'dev-fairseq': './data/obqa/fairseq/official/valid.jsonl',
63 | 'test-fairseq': './data/obqa/fairseq/official/test.jsonl',
64 | },
65 | 'grounded': {
66 | 'train': './data/obqa/grounded/train.grounded.jsonl',
67 | 'dev': './data/obqa/grounded/dev.grounded.jsonl',
68 | 'test': './data/obqa/grounded/test.grounded.jsonl',
69 | },
70 | 'graph': {
71 | 'adj-train': './data/obqa/graph/train.graph.adj.pk',
72 | 'adj-dev': './data/obqa/graph/dev.graph.adj.pk',
73 | 'adj-test': './data/obqa/graph/test.graph.adj.pk',
74 | },
75 | },
76 | 'obqa-fact': {
77 | 'statement': {
78 | 'train': './data/obqa/statement/train-fact.statement.jsonl',
79 | 'dev': './data/obqa/statement/dev-fact.statement.jsonl',
80 | 'test': './data/obqa/statement/test-fact.statement.jsonl',
81 | 'train-fairseq': './data/obqa/fairseq/official/train-fact.jsonl',
82 | 'dev-fairseq': './data/obqa/fairseq/official/valid-fact.jsonl',
83 | 'test-fairseq': './data/obqa/fairseq/official/test-fact.jsonl',
84 | },
85 | },
86 | }
87 |
88 |
89 | def main():
90 | parser = argparse.ArgumentParser()
91 | parser.add_argument('--run', default=['common'], choices=['common', 'csqa', 'hswag', 'anli', 'exp', 'scitail', 'phys', 'socialiqa', 'obqa', 'obqa-fact', 'make_word_vocab'], nargs='+')
92 | parser.add_argument('--path_prune_threshold', type=float, default=0.12, help='threshold for pruning paths')
93 | parser.add_argument('--max_node_num', type=int, default=200, help='maximum number of nodes per graph')
94 | parser.add_argument('-p', '--nprocs', type=int, default=cpu_count(), help='number of processes to use')
95 | parser.add_argument('--seed', type=int, default=0, help='random seed')
96 | parser.add_argument('--debug', action='store_true', help='enable debug mode')
97 |
98 | args = parser.parse_args()
99 | if args.debug:
100 | raise NotImplementedError()
101 |
102 | routines = {
103 | 'common': [
104 | {'func': extract_english, 'args': (input_paths['cpnet']['csv'], output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'])},
105 | {'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'],
106 | output_paths['cpnet']['unpruned-graph'], False)},
107 | {'func': construct_graph, 'args': (output_paths['cpnet']['csv'], output_paths['cpnet']['vocab'],
108 | output_paths['cpnet']['pruned-graph'], True)},
109 | {'func': create_matcher_patterns, 'args': (output_paths['cpnet']['vocab'], output_paths['cpnet']['patterns'])},
110 | ],
111 | 'csqa': [
112 | {'func': convert_to_entailment, 'args': (input_paths['csqa']['train'], output_paths['csqa']['statement']['train'])},
113 | {'func': convert_to_entailment, 'args': (input_paths['csqa']['dev'], output_paths['csqa']['statement']['dev'])},
114 | {'func': convert_to_entailment, 'args': (input_paths['csqa']['test'], output_paths['csqa']['statement']['test'])},
115 | {'func': ground, 'args': (output_paths['csqa']['statement']['train'], output_paths['cpnet']['vocab'],
116 | output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['train'], args.nprocs)},
117 | {'func': ground, 'args': (output_paths['csqa']['statement']['dev'], output_paths['cpnet']['vocab'],
118 | output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['dev'], args.nprocs)},
119 | {'func': ground, 'args': (output_paths['csqa']['statement']['test'], output_paths['cpnet']['vocab'],
120 | output_paths['cpnet']['patterns'], output_paths['csqa']['grounded']['test'], args.nprocs)},
121 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-train'], args.nprocs)},
122 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-dev'], args.nprocs)},
123 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['csqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['csqa']['graph']['adj-test'], args.nprocs)},
124 | ],
125 |
126 | 'obqa': [
127 | {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['train'], output_paths['obqa']['statement']['train'], output_paths['obqa']['statement']['train-fairseq'])},
128 | {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['dev'], output_paths['obqa']['statement']['dev'], output_paths['obqa']['statement']['dev-fairseq'])},
129 | {'func': convert_to_obqa_statement, 'args': (input_paths['obqa']['test'], output_paths['obqa']['statement']['test'], output_paths['obqa']['statement']['test-fairseq'])},
130 | {'func': ground, 'args': (output_paths['obqa']['statement']['train'], output_paths['cpnet']['vocab'],
131 | output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['train'], args.nprocs)},
132 | {'func': ground, 'args': (output_paths['obqa']['statement']['dev'], output_paths['cpnet']['vocab'],
133 | output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['dev'], args.nprocs)},
134 | {'func': ground, 'args': (output_paths['obqa']['statement']['test'], output_paths['cpnet']['vocab'],
135 | output_paths['cpnet']['patterns'], output_paths['obqa']['grounded']['test'], args.nprocs)},
136 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['train'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-train'], args.nprocs)},
137 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['dev'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-dev'], args.nprocs)},
138 | {'func': generate_adj_data_from_grounded_concepts__use_LM, 'args': (output_paths['obqa']['grounded']['test'], output_paths['cpnet']['pruned-graph'], output_paths['cpnet']['vocab'], output_paths['obqa']['graph']['adj-test'], args.nprocs)},
139 | ],
140 | }
141 |
142 | for rt in args.run:
143 | for rt_dic in routines[rt]:
144 | rt_dic['func'](*rt_dic['args'])
145 |
146 |     print('Successfully ran {}'.format(' '.join(args.run)))
147 |
148 |
149 | if __name__ == '__main__':
150 | main()
151 | # pass
152 |
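# Example invocation (assuming the raw files listed in `input_paths` have been
# downloaded, e.g. via download_raw_data.sh):
#
#     python preprocess.py --run common csqa -p 8
#
# 'common' builds the ConceptNet csv/vocab/graph/patterns once; dataset routines
# such as 'csqa' and 'obqa' then convert raw questions to statements, ground the
# concepts, and build the graph adjacency data. Note that only 'common', 'csqa'
# and 'obqa' have routines defined above; the other --run choices would raise a
# KeyError here.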
--------------------------------------------------------------------------------
/utils/grounding.py:
--------------------------------------------------------------------------------
1 | from multiprocessing import Pool
2 | import spacy
3 | from spacy.matcher import Matcher
4 | from tqdm import tqdm
5 | import nltk
6 | import json
7 | import string
8 |
9 |
10 | __all__ = ['create_matcher_patterns', 'ground']
11 |
12 |
13 | # the lemma of it/them/mine/.. is -PRON-
14 |
15 | blacklist = set(["-PRON-", "actually", "likely", "possibly", "want",
16 | "make", "my", "someone", "sometimes_people", "sometimes", "would", "want_to",
17 | "one", "something", "sometimes", "everybody", "somebody", "could", "could_be"
18 | ])
19 |
20 |
21 | nltk.download('stopwords', quiet=True)
22 | nltk_stopwords = nltk.corpus.stopwords.words('english')
23 |
24 | # CHUNK_SIZE = 1
25 |
26 | CPNET_VOCAB = None
27 | PATTERN_PATH = None
28 | nlp = None
29 | matcher = None
30 |
31 |
32 | def load_cpnet_vocab(cpnet_vocab_path):
33 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
34 | cpnet_vocab = [l.strip() for l in fin]
35 | cpnet_vocab = [c.replace("_", " ") for c in cpnet_vocab]
36 | return cpnet_vocab
37 |
38 |
39 | def create_pattern(nlp, doc, debug=False):
40 | pronoun_list = set(["my", "you", "it", "its", "your", "i", "he", "she", "his", "her", "they", "them", "their", "our", "we"])
41 |     # Filter out concepts that are five or more tokens long, that begin or end with a pronoun, or that consist entirely of stop words / blacklisted lemmas.
42 |     if len(doc) >= 5 or doc[0].text in pronoun_list or doc[-1].text in pronoun_list or \
43 |             all([(token.text in nltk_stopwords or token.lemma_ in nltk_stopwords or token.lemma_ in blacklist) for token in doc]):
44 |         if debug:
45 |             return None, doc.text
46 |         return None  # ignore this concept as pattern
47 |
48 | pattern = []
49 | for token in doc: # a doc is a concept
50 | pattern.append({"LEMMA": token.lemma_})
51 |     if debug:
52 |         return pattern, doc.text
53 | return pattern
54 |
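# For example, a (hypothetical) vocab entry "ice_cream" yields the spaCy Matcher
# pattern [{"LEMMA": "ice"}, {"LEMMA": "cream"}], so any surface form whose token
# lemmas match is treated as a mention of that concept.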
55 |
56 | def create_matcher_patterns(cpnet_vocab_path, output_path, debug=False):
57 | cpnet_vocab = load_cpnet_vocab(cpnet_vocab_path)
58 | nlp = spacy.load('en_core_web_sm', disable=['parser', 'ner', 'textcat'])
59 | docs = nlp.pipe(cpnet_vocab)
60 | all_patterns = {}
61 |
62 | if debug:
63 | f = open("filtered_concept.txt", "w")
64 |
65 | for doc in tqdm(docs, total=len(cpnet_vocab)):
66 |
67 |         pattern = create_pattern(nlp, doc, debug)
68 |         if debug:
69 |             pattern, concept_text = pattern  # in debug mode, a (pattern, text) pair is returned
70 |             if pattern is None:
71 |                 f.write(concept_text + '\n')
72 |         if pattern is None:
73 |             continue
74 |         all_patterns["_".join(doc.text.split(" "))] = pattern
75 |
76 | print("Created " + str(len(all_patterns)) + " patterns.")
77 | with open(output_path, "w", encoding="utf8") as fout:
78 | json.dump(all_patterns, fout)
79 | if debug:
80 | f.close()
81 |
82 |
83 | def lemmatize(nlp, concept):
84 |
85 | doc = nlp(concept.replace("_", " "))
86 | lcs = set()
87 | # for i in range(len(doc)):
88 | # lemmas = []
89 | # for j, token in enumerate(doc):
90 | # if j == i:
91 | # lemmas.append(token.lemma_)
92 | # else:
93 | # lemmas.append(token.text)
94 | # lc = "_".join(lemmas)
95 | # lcs.add(lc)
96 | lcs.add("_".join([token.lemma_ for token in doc])) # all lemma
97 | return lcs
98 |
99 |
100 | def load_matcher(nlp, pattern_path):
101 | with open(pattern_path, "r", encoding="utf8") as fin:
102 | all_patterns = json.load(fin)
103 |
104 | matcher = Matcher(nlp.vocab)
105 | for concept, pattern in all_patterns.items():
106 | matcher.add(concept, None, pattern)
107 | return matcher
108 |
109 |
110 | def ground_qa_pair(qa_pair):
111 | global nlp, matcher
112 | if nlp is None or matcher is None:
113 | nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
114 | nlp.add_pipe(nlp.create_pipe('sentencizer'))
115 | matcher = load_matcher(nlp, PATTERN_PATH)
116 |
117 | s, a = qa_pair
118 | all_concepts = ground_mentioned_concepts(nlp, matcher, s, a)
119 | answer_concepts = ground_mentioned_concepts(nlp, matcher, a)
120 | question_concepts = all_concepts - answer_concepts
121 | if len(question_concepts) == 0:
122 |         question_concepts = hard_ground(nlp, s, CPNET_VOCAB)  # rare, but possible
123 |
124 | if len(answer_concepts) == 0:
125 |         answer_concepts = hard_ground(nlp, a, CPNET_VOCAB)  # happens in some cases
126 |
127 | # question_concepts = question_concepts - answer_concepts
128 | question_concepts = sorted(list(question_concepts))
129 | answer_concepts = sorted(list(answer_concepts))
130 | return {"sent": s, "ans": a, "qc": question_concepts, "ac": answer_concepts}
131 |
132 |
133 | def ground_mentioned_concepts(nlp, matcher, s, ans=None):
134 |
135 | s = s.lower()
136 | doc = nlp(s)
137 | matches = matcher(doc)
138 |
139 | mentioned_concepts = set()
140 | span_to_concepts = {}
141 |
142 | if ans is not None:
143 | ans_matcher = Matcher(nlp.vocab)
144 | ans_words = nlp(ans)
145 | # print(ans_words)
146 | ans_matcher.add(ans, None, [{'TEXT': token.text.lower()} for token in ans_words])
147 |
148 | ans_match = ans_matcher(doc)
149 | ans_mentions = set()
150 | for _, ans_start, ans_end in ans_match:
151 | ans_mentions.add((ans_start, ans_end))
152 |
153 | for match_id, start, end in matches:
154 | if ans is not None:
155 | if (start, end) in ans_mentions:
156 | continue
157 |
158 | span = doc[start:end].text # the matched span
159 |
160 | # a word that appears in answer is not considered as a mention in the question
161 | # if len(set(span.split(" ")).intersection(set(ans.split(" ")))) > 0:
162 | # continue
163 | original_concept = nlp.vocab.strings[match_id]
164 | original_concept_set = set()
165 | original_concept_set.add(original_concept)
166 |
167 | # print("span", span)
168 | # print("concept", original_concept)
169 | # print("Matched '" + span + "' to the rule '" + string_id)
170 |
171 |         # single-word concepts are additionally lemmatized below (multi-word concepts are matched as-is)
172 |
173 | if len(original_concept.split("_")) == 1:
174 | # tag = doc[start].tag_
175 | # if tag in ['VBN', 'VBG']:
176 |
177 | original_concept_set.update(lemmatize(nlp, nlp.vocab.strings[match_id]))
178 |
179 | if span not in span_to_concepts:
180 | span_to_concepts[span] = set()
181 |
182 | span_to_concepts[span].update(original_concept_set)
183 |
184 | for span, concepts in span_to_concepts.items():
185 | concepts_sorted = list(concepts)
186 | # print("span:")
187 | # print(span)
188 | # print("concept_sorted:")
189 | # print(concepts_sorted)
190 | concepts_sorted.sort(key=len)
191 |
192 | # mentioned_concepts.update(concepts_sorted[0:2])
193 |
194 | shortest = concepts_sorted[0:3]
195 |
196 | for c in shortest:
197 | if c in blacklist:
198 | continue
199 |
200 | # a set with one string like: set("like_apples")
201 | lcs = lemmatize(nlp, c)
202 | intersect = lcs.intersection(shortest)
203 | if len(intersect) > 0:
204 | mentioned_concepts.add(list(intersect)[0])
205 | else:
206 | mentioned_concepts.add(c)
207 |
208 | # if a mention exactly matches with a concept
209 |
210 | exact_match = set([concept for concept in concepts_sorted if concept.replace("_", " ").lower() == span.lower()])
211 | # print("exact match:")
212 | # print(exact_match)
213 | assert len(exact_match) < 2
214 | mentioned_concepts.update(exact_match)
215 |
216 | return mentioned_concepts
217 |
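# Matching flow in brief: each Matcher hit maps a text span to a set of candidate
# concepts. For every span, only the three shortest candidates are considered;
# blacklisted ones are dropped, a candidate is replaced by its lemmatized form
# when that form is itself among the shortest candidates, and a candidate whose
# surface form exactly equals the span text is always kept.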
218 |
219 | def hard_ground(nlp, sent, cpnet_vocab):
220 | sent = sent.lower()
221 | doc = nlp(sent)
222 | res = set()
223 | for t in doc:
224 | if t.lemma_ in cpnet_vocab:
225 | res.add(t.lemma_)
226 | sent = " ".join([t.text for t in doc])
227 | if sent in cpnet_vocab:
228 | res.add(sent)
229 | try:
230 | assert len(res) > 0
231 | except Exception:
232 | print(f"for {sent}, concept not found in hard grounding.")
233 | return res
234 |
235 |
236 | def match_mentioned_concepts(sents, answers, num_processes):
237 | res = []
238 | with Pool(num_processes) as p:
239 | res = list(tqdm(p.imap(ground_qa_pair, zip(sents, answers)), total=len(sents)))
240 | return res
241 |
242 |
243 | # To-do: examine prune
244 | def prune(data, cpnet_vocab_path):
245 | # reload cpnet_vocab
246 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
247 | cpnet_vocab = [l.strip() for l in fin]
248 |
249 | prune_data = []
250 | for item in tqdm(data):
251 | qc = item["qc"]
252 | prune_qc = []
253 | for c in qc:
254 | if c[-2:] == "er" and c[:-2] in qc:
255 | continue
256 | if c[-1:] == "e" and c[:-1] in qc:
257 | continue
258 | have_stop = False
259 | # remove all concepts having stopwords, including hard-grounded ones
260 | for t in c.split("_"):
261 | if t in nltk_stopwords:
262 | have_stop = True
263 | if not have_stop and c in cpnet_vocab:
264 | prune_qc.append(c)
265 |
266 | ac = item["ac"]
267 | prune_ac = []
268 | for c in ac:
269 | if c[-2:] == "er" and c[:-2] in ac:
270 | continue
271 | if c[-1:] == "e" and c[:-1] in ac:
272 | continue
273 | all_stop = True
274 | for t in c.split("_"):
275 | if t not in nltk_stopwords:
276 | all_stop = False
277 | if not all_stop and c in cpnet_vocab:
278 | prune_ac.append(c)
279 |
280 | try:
281 | assert len(prune_ac) > 0 and len(prune_qc) > 0
282 | except Exception as e:
283 | pass
284 | # print("In pruning")
285 | # print(prune_qc)
286 | # print(prune_ac)
287 | # print("original:")
288 | # print(qc)
289 | # print(ac)
290 | # print()
291 | item["qc"] = prune_qc
292 | item["ac"] = prune_ac
293 |
294 | prune_data.append(item)
295 | return prune_data
296 |
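# Pruning heuristics in brief: drop an "...er"/"...e" variant when its base form
# is also grounded (e.g. a hypothetical pair "read"/"reader"); for question
# concepts, drop any concept containing a stop word; for answer concepts, drop
# only concepts made up entirely of stop words; finally, keep only concepts that
# are still in the ConceptNet vocabulary.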
297 |
298 | def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False):
299 | global PATTERN_PATH, CPNET_VOCAB
300 | if PATTERN_PATH is None:
301 | PATTERN_PATH = pattern_path
302 | CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
303 |
304 | sents = []
305 | answers = []
306 | with open(statement_path, 'r') as fin:
307 | lines = [line for line in fin]
308 |
309 | if debug:
310 | lines = lines[192:195]
311 | print(len(lines))
312 | for line in lines:
313 | if line == "":
314 | continue
315 | j = json.loads(line)
316 | # {'answerKey': 'B',
317 | # 'id': 'b8c0a4703079cf661d7261a60a1bcbff',
318 | # 'question': {'question_concept': 'magazines',
319 | # 'choices': [{'label': 'A', 'text': 'doctor'}, {'label': 'B', 'text': 'bookstore'}, {'label': 'C', 'text': 'market'}, {'label': 'D', 'text': 'train station'}, {'label': 'E', 'text': 'mortuary'}],
320 | # 'stem': 'Where would you find magazines along side many other printed works?'},
321 | # 'statements': [{'label': False, 'statement': 'Doctor would you find magazines along side many other printed works.'}, {'label': True, 'statement': 'Bookstore would you find magazines along side many other printed works.'}, {'label': False, 'statement': 'Market would you find magazines along side many other printed works.'}, {'label': False, 'statement': 'Train station would you find magazines along side many other printed works.'}, {'label': False, 'statement': 'Mortuary would you find magazines along side many other printed works.'}]}
322 |
323 | for statement in j["statements"]:
324 | sents.append(statement["statement"])
325 |
326 | for answer in j["question"]["choices"]:
327 | ans = answer['text']
328 | # ans = " ".join(answer['text'].split("_"))
329 | try:
330 | assert all([i != "_" for i in ans])
331 | except Exception:
332 | print(ans)
333 | answers.append(ans)
334 |
335 | res = match_mentioned_concepts(sents, answers, num_processes)
336 | res = prune(res, cpnet_vocab_path)
337 |
338 | # check_path(output_path)
339 | with open(output_path, 'w') as fout:
340 | for dic in res:
341 | fout.write(json.dumps(dic) + '\n')
342 |
343 | print(f'grounded concepts saved to {output_path}')
344 | print()
345 |
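# Each output line is a JSON record as returned by ground_qa_pair, e.g.
# (hypothetical concept values):
#     {"sent": "Bookstore would you find magazines along side many other printed works.",
#      "ans": "bookstore", "qc": ["magazine", "printed_work"], "ac": ["bookstore"]}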
346 |
347 | if __name__ == "__main__":
348 | create_matcher_patterns("../data/cpnet/concept.txt", "./matcher_res.txt", True)
349 | # ground("../data/statement/dev.statement.jsonl", "../data/cpnet/concept.txt", "../data/cpnet/matcher_patterns.json", "./ground_res.jsonl", 10, True)
350 |
351 | # s = "a revolving door is convenient for two direction travel, but it also serves as a security measure at a bank."
352 | # a = "bank"
353 | # nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat'])
354 | # nlp.add_pipe(nlp.create_pipe('sentencizer'))
355 | # ans_words = nlp(a)
356 | # doc = nlp(s)
357 | # ans_matcher = Matcher(nlp.vocab)
358 | # print([{'TEXT': token.text.lower()} for token in ans_words])
359 | # ans_matcher.add("ok", None, [{'TEXT': token.text.lower()} for token in ans_words])
360 | #
361 | # matches = ans_matcher(doc)
362 | # for a, b, c in matches:
363 | # print(a, b, c)
364 |
--------------------------------------------------------------------------------
/utils/conceptnet.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import nltk
3 | import json
4 | import math
5 | from tqdm import tqdm
6 | import numpy as np
7 | import sys
8 |
9 | try:
10 | from .utils import check_file
11 | except ImportError:
12 | from utils import check_file
13 |
14 | __all__ = ['extract_english', 'construct_graph', 'merged_relations']
15 |
16 | relation_groups = [
17 | 'atlocation/locatednear',
18 | 'capableof',
19 | 'causes/causesdesire/*motivatedbygoal',
20 | 'createdby',
21 | 'desires',
22 | 'antonym/distinctfrom',
23 | 'hascontext',
24 | 'hasproperty',
25 | 'hassubevent/hasfirstsubevent/haslastsubevent/hasprerequisite/entails/mannerof',
26 | 'isa/instanceof/definedas',
27 | 'madeof',
28 | 'notcapableof',
29 | 'notdesires',
30 | 'partof/*hasa',
31 | 'relatedto/similarto/synonym',
32 | 'usedfor',
33 | 'receivesaction',
34 | ]
35 |
36 | merged_relations = [
37 | 'antonym',
38 | 'atlocation',
39 | 'capableof',
40 | 'causes',
41 | 'createdby',
42 | 'isa',
43 | 'desires',
44 | 'hassubevent',
45 | 'partof',
46 | 'hascontext',
47 | 'hasproperty',
48 | 'madeof',
49 | 'notcapableof',
50 | 'notdesires',
51 | 'receivesaction',
52 | 'relatedto',
53 | 'usedfor',
54 | ]
55 |
56 | relation_text = [
57 | 'is the antonym of',
58 | 'is at location of',
59 | 'is capable of',
60 | 'causes',
61 | 'is created by',
62 | 'is a kind of',
63 | 'desires',
64 | 'has subevent',
65 | 'is part of',
66 | 'has context',
67 | 'has property',
68 | 'is made of',
69 | 'is not capable of',
70 |     'does not desire',
71 | 'is',
72 | 'is related to',
73 | 'is used for',
74 | ]
75 |
76 |
77 | def load_merge_relation():
78 | relation_mapping = dict()
79 | for line in relation_groups:
80 | ls = line.strip().split('/')
81 | rel = ls[0]
82 | for l in ls:
83 | if l.startswith("*"):
84 | relation_mapping[l[1:]] = "*" + rel
85 | else:
86 | relation_mapping[l] = rel
87 | return relation_mapping
88 |
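# For example, load_merge_relation() maps both 'atlocation' and 'locatednear' to
# 'atlocation', and maps 'hasa' to '*partof' -- the leading '*' marks relations
# whose head and tail get swapped during extraction (see extract_english below).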
89 |
90 | def del_pos(s):
91 | """
92 | Deletes part-of-speech encoding from an entity string, if present.
93 | :param s: Entity string.
94 | :return: Entity string with part-of-speech encoding removed.
95 | """
96 | if s.endswith("/n") or s.endswith("/a") or s.endswith("/v") or s.endswith("/r"):
97 | s = s[:-2]
98 | return s
99 |
100 |
101 | def extract_english(conceptnet_path, output_csv_path, output_vocab_path):
102 | """
103 | Reads original conceptnet csv file and extracts all English relations (head and tail are both English entities) into
104 |     a new file, with the following tab-separated format for each line: <relation>, <head>, <tail>, <weight>.
105 | :return:
106 | """
107 | print('extracting English concepts and relations from ConceptNet...')
108 | relation_mapping = load_merge_relation()
109 | num_lines = sum(1 for line in open(conceptnet_path, 'r', encoding='utf-8'))
110 | cpnet_vocab = []
111 | concepts_seen = set()
112 | with open(conceptnet_path, 'r', encoding="utf8") as fin, \
113 | open(output_csv_path, 'w', encoding="utf8") as fout:
114 | for line in tqdm(fin, total=num_lines):
115 | toks = line.strip().split('\t')
116 | if toks[2].startswith('/c/en/') and toks[3].startswith('/c/en/'):
117 | """
118 | Some preprocessing:
119 | - Remove part-of-speech encoding.
120 |                    - Split("/")[-1] to trim the "/c/en/" prefix and keep just the entity name.
121 |                    - Lowercase for uniformity.
122 | """
123 | rel = toks[1].split("/")[-1].lower()
124 | head = del_pos(toks[2]).split("/")[-1].lower()
125 | tail = del_pos(toks[3]).split("/")[-1].lower()
126 |
127 | if not head.replace("_", "").replace("-", "").isalpha():
128 | continue
129 | if not tail.replace("_", "").replace("-", "").isalpha():
130 | continue
131 | if rel not in relation_mapping:
132 | continue
133 |
134 | rel = relation_mapping[rel]
135 | if rel.startswith("*"):
136 | head, tail, rel = tail, head, rel[1:]
137 |
138 | data = json.loads(toks[4])
139 |
140 | fout.write('\t'.join([rel, head, tail, str(data["weight"])]) + '\n')
141 |
142 | for w in [head, tail]:
143 | if w not in concepts_seen:
144 | concepts_seen.add(w)
145 | cpnet_vocab.append(w)
146 |
147 | with open(output_vocab_path, 'w') as fout:
148 | for word in cpnet_vocab:
149 | fout.write(word + '\n')
150 |
151 | print(f'extracted ConceptNet csv file saved to {output_csv_path}')
152 | print(f'extracted concept vocabulary saved to {output_vocab_path}')
153 | print()
154 |
155 |
156 | def construct_graph(cpnet_csv_path, cpnet_vocab_path, output_path, prune=True):
157 | print('generating ConceptNet graph file...')
158 |
159 | nltk.download('stopwords', quiet=True)
160 | nltk_stopwords = nltk.corpus.stopwords.words('english')
161 | nltk_stopwords += ["like", "gone", "did", "going", "would", "could",
162 |                        "get", "in", "up", "may", "wanter"]  # issue: mismatch with the stop words in grounding.py
163 |
164 |     blacklist = set(["uk", "us", "take", "make", "object", "person", "people"])  # issue: mismatch with the blacklist in grounding.py
165 |
166 | concept2id = {}
167 | id2concept = {}
168 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
169 | id2concept = [w.strip() for w in fin]
170 | concept2id = {w: i for i, w in enumerate(id2concept)}
171 |
172 | id2relation = merged_relations
173 | relation2id = {r: i for i, r in enumerate(id2relation)}
174 |
175 | graph = nx.MultiDiGraph()
176 | nrow = sum(1 for _ in open(cpnet_csv_path, 'r', encoding='utf-8'))
177 | with open(cpnet_csv_path, "r", encoding="utf8") as fin:
178 |
179 | def not_save(cpt):
180 | if cpt in blacklist:
181 | return True
182 | '''originally phrases like "branch out" would not be kept in the graph'''
183 | # for t in cpt.split("_"):
184 | # if t in nltk_stopwords:
185 | # return True
186 | return False
187 |
188 | attrs = set()
189 |
190 | for line in tqdm(fin, total=nrow):
191 | ls = line.strip().split('\t')
192 | rel = relation2id[ls[0]]
193 | subj = concept2id[ls[1]]
194 | obj = concept2id[ls[2]]
195 | weight = float(ls[3])
196 | if prune and (not_save(ls[1]) or not_save(ls[2]) or id2relation[rel] == "hascontext"):
197 | continue
198 | # if id2relation[rel] == "relatedto" or id2relation[rel] == "antonym":
199 | # weight -= 0.3
200 | # continue
201 | if subj == obj: # delete loops
202 | continue
203 | # weight = 1 + float(math.exp(1 - weight)) # issue: ???
204 |
205 | if (subj, obj, rel) not in attrs:
206 | graph.add_edge(subj, obj, rel=rel, weight=weight)
207 | attrs.add((subj, obj, rel))
208 | graph.add_edge(obj, subj, rel=rel + len(relation2id), weight=weight)
209 | attrs.add((obj, subj, rel + len(relation2id)))
210 |
211 | nx.write_gpickle(graph, output_path)
212 | print(f"graph file saved to {output_path}")
213 | print()
214 |
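# Edge encoding note: every triple is stored in both directions -- a forward edge
# with relation id r in [0, 17) and a reverse edge with id r + 17 -- so the
# pickled MultiDiGraph uses 34 distinct relation ids for the 17 merged relations.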
215 |
216 | def glove_init(input, output, concept_file):
217 | embeddings_file = output + '.npy'
218 |     vocabulary_file = output.rsplit('.', 1)[0] + '.vocab.txt'  # split from the right so relative paths like '../data/...' keep their directory part
219 | output_dir = '/'.join(output.split('/')[:-1])
220 | output_prefix = output.split('/')[-1]
221 |
222 | words = []
223 | vectors = []
224 | vocab_exist = check_file(vocabulary_file)
225 | print("loading embedding")
226 | with open(input, 'rb') as f:
227 | for line in f:
228 | fields = line.split()
229 | if len(fields) <= 2:
230 | continue
231 | if not vocab_exist:
232 | word = fields[0].decode('utf-8')
233 | words.append(word)
234 | vector = np.fromiter((float(x) for x in fields[1:]),
235 |                                      dtype=np.float64)  # the 'np.float' alias was removed in NumPy 1.24
236 |
237 | vectors.append(vector)
238 | dim = vector.shape[0]
239 | print("converting")
240 | matrix = np.array(vectors, dtype="float32")
241 | print("writing")
242 | np.save(embeddings_file, matrix)
243 | text = '\n'.join(words)
244 | if not vocab_exist:
245 | with open(vocabulary_file, 'wb') as f:
246 | f.write(text.encode('utf-8'))
247 |
248 | def load_glove_from_npy(glove_vec_path, glove_vocab_path):
249 | vectors = np.load(glove_vec_path)
250 | with open(glove_vocab_path, "r", encoding="utf8") as f:
251 | vocab = [l.strip() for l in f.readlines()]
252 |
253 | assert (len(vectors) == len(vocab))
254 |
255 | glove_embeddings = {}
256 | for i in range(0, len(vectors)):
257 | glove_embeddings[vocab[i]] = vectors[i]
258 | print("Read " + str(len(glove_embeddings)) + " glove vectors.")
259 | return glove_embeddings
260 |
261 | def weighted_average(avg, new, n):
262 | # TODO: maybe a better name for this function?
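        # This is a running mean: after n observations,
        # avg_n = ((n - 1) / n) * avg_{n-1} + x_n / n = (x_1 + ... + x_n) / n,
        # so embeddings are averaged in a single streaming pass.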
263 | return ((n - 1) / n) * avg + (new / n)
264 |
265 | def max_pooling(old, new):
266 | # TODO: maybe a better name for this function?
267 | return np.maximum(old, new)
268 |
269 | def write_embeddings_npy(embeddings, embeddings_cnt, npy_path, vocab_path):
270 | words = []
271 | vectors = []
272 | for key, vec in embeddings.items():
273 | words.append(key)
274 | vectors.append(vec)
275 |
276 | matrix = np.array(vectors, dtype="float32")
277 | print(matrix.shape)
278 |
279 | print("Writing embeddings matrix to " + npy_path, flush=True)
280 | np.save(npy_path, matrix)
281 | print("Finished writing embeddings matrix to " + npy_path, flush=True)
282 |
283 | if not check_file(vocab_path):
284 | print("Writing vocab file to " + vocab_path, flush=True)
285 | to_write = ["\t".join([w, str(embeddings_cnt[w])]) for w in words]
286 | with open(vocab_path, "w", encoding="utf8") as f:
287 | f.write("\n".join(to_write))
288 | print("Finished writing vocab file to " + vocab_path, flush=True)
289 |
290 | def create_embeddings_glove(pooling="max", dim=100):
291 | print("Pooling: " + pooling)
292 |
293 | with open(concept_file, "r", encoding="utf8") as f:
294 | triple_str_json = json.load(f)
295 | print("Loaded " + str(len(triple_str_json)) + " triple strings.")
296 |
297 | glove_embeddings = load_glove_from_npy(embeddings_file, vocabulary_file)
298 | print("Loaded glove.", flush=True)
299 |
300 | concept_embeddings = {}
301 | concept_embeddings_cnt = {}
302 | rel_embeddings = {}
303 | rel_embeddings_cnt = {}
304 |
305 | for i in tqdm(range(len(triple_str_json))):
306 | data = triple_str_json[i]
307 |
308 | words = data["string"].strip().split(" ")
309 |
310 | rel = data["rel"]
311 | subj_start = data["subj_start"]
312 | subj_end = data["subj_end"]
313 | obj_start = data["obj_start"]
314 | obj_end = data["obj_end"]
315 |
316 | subj_words = words[subj_start:subj_end]
317 | obj_words = words[obj_start:obj_end]
318 |
319 | subj = " ".join(subj_words)
320 | obj = " ".join(obj_words)
321 |
322 | # counting the frequency (only used for the avg pooling)
323 | if subj not in concept_embeddings:
324 | concept_embeddings[subj] = np.zeros((dim,))
325 | concept_embeddings_cnt[subj] = 0
326 | concept_embeddings_cnt[subj] += 1
327 |
328 | if obj not in concept_embeddings:
329 | concept_embeddings[obj] = np.zeros((dim,))
330 | concept_embeddings_cnt[obj] = 0
331 | concept_embeddings_cnt[obj] += 1
332 |
333 | if rel not in rel_embeddings:
334 | rel_embeddings[rel] = np.zeros((dim,))
335 | rel_embeddings_cnt[rel] = 0
336 | rel_embeddings_cnt[rel] += 1
337 |
338 | if pooling == "avg":
339 |                 subj_encoding_sum = sum([glove_embeddings.get(word, np.zeros((dim,))) for word in subj_words])  # iterate over the token list, not the characters of the joined string
340 |                 obj_encoding_sum = sum([glove_embeddings.get(word, np.zeros((dim,))) for word in obj_words])
341 |
342 | if rel in ["relatedto", "antonym"]:
343 | # Symmetric relation.
344 | rel_encoding_sum = sum([glove_embeddings.get(word, np.zeros((dim,))) for word in
345 | words]) - subj_encoding_sum - obj_encoding_sum
346 | else:
347 | # Asymmetrical relation.
348 | rel_encoding_sum = obj_encoding_sum - subj_encoding_sum
349 |
350 | subj_len = subj_end - subj_start
351 | obj_len = obj_end - obj_start
352 |
353 | subj_encoding = subj_encoding_sum / subj_len
354 | obj_encoding = obj_encoding_sum / obj_len
355 | rel_encoding = rel_encoding_sum / (len(words) - subj_len - obj_len)
356 |
357 | concept_embeddings[subj] = subj_encoding
358 | concept_embeddings[obj] = obj_encoding
359 | rel_embeddings[rel] = weighted_average(rel_embeddings[rel], rel_encoding, rel_embeddings_cnt[rel])
360 |
361 | elif pooling == "max":
362 | subj_encoding = np.amax([glove_embeddings.get(word, np.zeros((dim,))) for word in subj_words], axis=0)
363 | obj_encoding = np.amax([glove_embeddings.get(word, np.zeros((dim,))) for word in obj_words], axis=0)
364 |
365 | mask_rel = []
366 | for j in range(len(words)):
367 | if subj_start <= j < subj_end or obj_start <= j < obj_end:
368 | continue
369 | mask_rel.append(j)
370 | rel_vecs = [glove_embeddings.get(words[i], np.zeros((dim,))) for i in mask_rel]
371 | rel_encoding = np.amax(rel_vecs, axis=0)
372 |
373 | # here it is actually avg over max for relation
374 | concept_embeddings[subj] = max_pooling(concept_embeddings[subj], subj_encoding)
375 | concept_embeddings[obj] = max_pooling(concept_embeddings[obj], obj_encoding)
376 | rel_embeddings[rel] = weighted_average(rel_embeddings[rel], rel_encoding, rel_embeddings_cnt[rel])
377 |
378 | print(str(len(concept_embeddings)) + " concept embeddings")
379 | print(str(len(rel_embeddings)) + " relation embeddings")
380 |
381 | write_embeddings_npy(concept_embeddings, concept_embeddings_cnt, f'{output_dir}/concept.{output_prefix}.{pooling}.npy',
382 | f'{output_dir}/concept.glove.{pooling}.txt')
383 | write_embeddings_npy(rel_embeddings, rel_embeddings_cnt, f'{output_dir}/relation.{output_prefix}.{pooling}.npy',
384 | f'{output_dir}/relation.glove.{pooling}.txt')
385 |
386 | create_embeddings_glove(dim=dim)
387 |
388 |
389 | if __name__ == "__main__":
390 | glove_init("../data/glove/glove.6B.200d.txt", "../data/glove/glove.200d", '../data/glove/tp_str_corpus.json')
391 |
--------------------------------------------------------------------------------
/utils_biomed/preprocess_medqa_usmle.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "import os\n",
10 | "import json\n",
11 | "import pickle\n",
12 | "import numpy as np\n",
13 | "from tqdm import tqdm\n",
14 | "from collections import defaultdict"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "repo_root = '..'"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "## Get MedQA-USMLE dataset"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "First, download the original MedQA dataset: https://github.com/jind11/MedQA. \n",
38 | "Put the unzipped folder in `data/medqa_usmle/raw`"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "#Prepare `statement` data following CommonsenseQA, OpenBookQA\n",
48 | "medqa_root = f'{repo_root}/data/medqa_usmle'\n",
49 | "os.system(f'mkdir -p {medqa_root}/statement')\n",
50 | "\n",
51 | "for fname in [\"train\", \"dev\", \"test\"]:\n",
52 | " with open(f\"{medqa_root}/raw/questions/US/4_options/phrases_no_exclude_{fname}.jsonl\") as f:\n",
53 | " lines = f.readlines()\n",
54 | " examples = []\n",
55 | " for i in tqdm(range(len(lines))):\n",
56 | " line = json.loads(lines[i])\n",
57 |     "        _id = f\"{fname}-{i:05d}\"\n",
58 | " answerKey = line[\"answer_idx\"]\n",
59 | " stem = line[\"question\"] \n",
60 | " choices = [{\"label\": k, \"text\": line[\"options\"][k]} for k in \"ABCD\"]\n",
61 | " stmts = [{\"statement\": stem +\" \"+ c[\"text\"]} for c in choices]\n",
62 | " ex_obj = {\"id\": _id, \n",
63 | " \"question\": {\"stem\": stem, \"choices\": choices}, \n",
64 | " \"answerKey\": answerKey, \n",
65 | " \"statements\": stmts\n",
66 | " }\n",
67 | " examples.append(ex_obj)\n",
68 | " with open(f\"{medqa_root}/statement/{fname}.statement.jsonl\", 'w') as fout:\n",
69 | " for dic in examples:\n",
70 | " print (json.dumps(dic), file=fout)"
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "## Link entities to KG"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "First, install the scispacy model:\n",
85 | "```\n",
86 | "pip install scispacy==0.3.0\n",
87 | "pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.3.0/en_core_sci_sm-0.3.0.tar.gz\n",
88 | "```"
89 | ]
90 | },
91 | {
92 | "cell_type": "code",
93 | "execution_count": null,
94 | "metadata": {},
95 | "outputs": [],
96 | "source": [
97 | "#Load scispacy entity linker\n",
98 | "import spacy\n",
99 | "import scispacy\n",
100 | "from scispacy.linking import EntityLinker\n",
101 | "\n",
102 | "def load_entity_linker(threshold=0.90):\n",
103 | " nlp = spacy.load(\"en_core_sci_sm\")\n",
104 | " linker = EntityLinker(\n",
105 | " resolve_abbreviations=True,\n",
106 | " name=\"umls\",\n",
107 | " threshold=threshold)\n",
108 | " nlp.add_pipe(linker)\n",
109 | " return nlp, linker\n",
110 | "\n",
111 | "nlp, linker = load_entity_linker()"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": null,
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "def entity_linking_to_umls(sentence, nlp, linker):\n",
121 | " doc = nlp(sentence)\n",
122 | " entities = doc.ents\n",
123 | " all_entities_results = []\n",
124 | " for mm in range(len(entities)):\n",
125 | " entity_text = entities[mm].text\n",
126 | " entity_start = entities[mm].start\n",
127 | " entity_end = entities[mm].end\n",
128 | " all_linked_entities = entities[mm]._.kb_ents\n",
129 | " all_entity_results = []\n",
130 | " for ii in range(len(all_linked_entities)):\n",
131 | " curr_concept_id = all_linked_entities[ii][0]\n",
132 | " curr_score = all_linked_entities[ii][1]\n",
133 | " curr_scispacy_entity = linker.kb.cui_to_entity[all_linked_entities[ii][0]]\n",
134 | " curr_canonical_name = curr_scispacy_entity.canonical_name\n",
135 | " curr_TUIs = curr_scispacy_entity.types\n",
136 | " curr_entity_result = {\"Canonical Name\": curr_canonical_name, \"Concept ID\": curr_concept_id,\n",
137 | " \"TUIs\": curr_TUIs, \"Score\": curr_score}\n",
138 | " all_entity_results.append(curr_entity_result)\n",
139 | " curr_entities_result = {\"text\": entity_text, \"start\": entity_start, \"end\": entity_end, \n",
140 | " \"start_char\": entities[mm].start_char, \"end_char\": entities[mm].end_char,\n",
141 | " \"linking_results\": all_entity_results}\n",
142 | " all_entities_results.append(curr_entities_result)\n",
143 | " return all_entities_results"
144 | ]
145 | },
146 | {
147 | "cell_type": "code",
148 | "execution_count": null,
149 | "metadata": {},
150 | "outputs": [],
151 | "source": [
152 | "#Example\n",
153 | "sent = \"A 5-year-old girl is brought to the emergency department by her mother because of multiple episodes of nausea and vomiting that last about 2 hours. During this period, she has had 6–8 episodes of bilious vomiting and abdominal pain. The vomiting was preceded by fatigue.\"\n",
154 | "ent_link_results = entity_linking_to_umls(sent, nlp, linker)\n",
155 | "ent_link_results"
156 | ]
157 | },
158 | {
159 | "cell_type": "code",
160 | "execution_count": null,
161 | "metadata": {},
162 | "outputs": [],
163 | "source": [
164 | "#Run entity linking to UMLS for all questions\n",
165 | "def process(input):\n",
166 | " nlp, linker = load_entity_linker()\n",
167 | " stmts = input\n",
168 | " for stmt in tqdm(stmts):\n",
169 | " stem = stmt['question']['stem']\n",
170 | " stem = stem[:3500]\n",
171 | " stmt['question']['stem_ents'] = entity_linking_to_umls(stem, nlp, linker)\n",
172 | " for ii, choice in enumerate(stmt['question']['choices']):\n",
173 | " text = stmt['question']['choices'][ii]['text']\n",
174 | " stmt['question']['choices'][ii]['text_ents'] = entity_linking_to_umls(text, nlp, linker)\n",
175 | " return stmts\n",
176 | "\n",
177 | "for fname in [\"dev\", \"test\", \"train\"]:\n",
178 | " with open(f\"{medqa_root}/statement/{fname}.statement.jsonl\") as fin:\n",
179 | " stmts = [json.loads(line) for line in fin]\n",
180 | " res = process(stmts) \n",
181 | " with open(f\"{medqa_root}/statement/{fname}.statement.umls_linked.jsonl\", 'w') as fout:\n",
182 | " for dic in res:\n",
183 | " print (json.dumps(dic), file=fout)"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": null,
189 | "metadata": {},
190 | "outputs": [],
191 | "source": [
192 | "#Convert UMLS entity linking to DDB entity linking (our KG)\n",
193 | "umls_to_ddb = {}\n",
194 | "with open(f'{repo_root}/data/ddb/ddb_to_umls_cui.txt') as f:\n",
195 | " for line in f.readlines()[1:]:\n",
196 | " elms = line.split(\"\\t\")\n",
197 | " umls_to_ddb[elms[2]] = elms[1]\n",
198 | "\n",
199 | "def map_to_ddb(ent_obj):\n",
200 | " res = []\n",
201 | " for ent_cand in ent_obj['linking_results']:\n",
202 | " CUI = ent_cand['Concept ID']\n",
203 | " name = ent_cand['Canonical Name']\n",
204 | " if CUI in umls_to_ddb:\n",
205 | " ddb_cid = umls_to_ddb[CUI]\n",
206 | " res.append((ddb_cid, name))\n",
207 | " return res\n",
208 | "\n",
209 | "def process(fname):\n",
210 | " with open(f\"{medqa_root}/statement/{fname}.statement.umls_linked.jsonl\") as fin:\n",
211 | " stmts = [json.loads(line) for line in fin]\n",
212 | " with open(f\"{medqa_root}/grounded/{fname}.grounded.jsonl\", 'w') as fout:\n",
213 | " for stmt in tqdm(stmts):\n",
214 | " sent = stmt['question']['stem']\n",
215 | " qc = []\n",
216 | " qc_names = []\n",
217 | " for ent_obj in stmt['question']['stem_ents']:\n",
218 | " res = map_to_ddb(ent_obj)\n",
219 | " for elm in res:\n",
220 | " ddb_cid, name = elm\n",
221 | " qc.append(ddb_cid)\n",
222 | " qc_names.append(name)\n",
223 | " for cid, choice in enumerate(stmt['question']['choices']):\n",
224 | " ans = choice['text']\n",
225 | " ac = []\n",
226 | " ac_names = []\n",
227 | " for ent_obj in choice['text_ents']:\n",
228 | " res = map_to_ddb(ent_obj)\n",
229 | " for elm in res:\n",
230 | " ddb_cid, name = elm\n",
231 | " ac.append(ddb_cid)\n",
232 | " ac_names.append(name)\n",
233 | " out = {'sent': sent, 'ans': ans, 'qc': qc, 'qc_names': qc_names, 'ac': ac, 'ac_names': ac_names}\n",
234 | " print (json.dumps(out), file=fout) \n",
235 | "\n",
236 | "os.system(f'mkdir -p {medqa_root}/grounded')\n",
237 | "for fname in [\"dev\", \"test\", \"train\"]:\n",
238 | " process(fname) "
239 | ]
240 | },
241 | {
242 | "cell_type": "markdown",
243 | "metadata": {},
244 | "source": [
245 | "## Load knowledge graph (KG)"
246 | ]
247 | },
248 | {
249 | "cell_type": "markdown",
250 | "metadata": {},
251 | "source": [
252 | "Load our KG, which is based on Disease Database + DrugBank."
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": null,
258 | "metadata": {},
259 | "outputs": [],
260 | "source": [
261 | "def load_ddb(): \n",
262 | " with open(f'{repo_root}/data/ddb/ddb_names.json') as f:\n",
263 | " all_names = json.load(f)\n",
264 | " with open(f'{repo_root}/data/ddb/ddb_relas.json') as f:\n",
265 | " all_relas = json.load(f)\n",
266 | " relas_lst = []\n",
267 | " for key, val in all_relas.items():\n",
268 | " relas_lst.append(val)\n",
269 | " \n",
270 | " ddb_ptr_to_preferred_name = {}\n",
271 | " ddb_ptr_to_name = defaultdict(list)\n",
272 | " ddb_name_to_ptr = {}\n",
273 | " for key, val in all_names.items():\n",
274 | " item_name = key\n",
275 | " item_ptr = val[0]\n",
276 | " item_preferred = val[1]\n",
277 | " if item_preferred == \"1\":\n",
278 | " ddb_ptr_to_preferred_name[item_ptr] = item_name\n",
279 | " ddb_name_to_ptr[item_name] = item_ptr\n",
280 | " ddb_ptr_to_name[item_ptr].append(item_name)\n",
281 | " \n",
282 | " return (relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name)\n",
283 | "\n",
284 | "\n",
285 | "relas_lst, ddb_ptr_to_name, ddb_name_to_ptr, ddb_ptr_to_preferred_name = load_ddb()\n",
286 | "\n",
287 | "\n",
288 | "ddb_ptr_lst, ddb_names_lst = [], []\n",
289 | "for key, val in ddb_ptr_to_preferred_name.items():\n",
290 | " ddb_ptr_lst.append(key)\n",
291 | " ddb_names_lst.append(val)\n",
292 | "\n",
293 | "with open(f\"{repo_root}/data/ddb/vocab.txt\", \"w\") as fout:\n",
294 | " for ddb_name in ddb_names_lst:\n",
295 | " print (ddb_name, file=fout)\n",
296 | "\n",
297 | "with open(f\"{repo_root}/data/ddb/ptrs.txt\", \"w\") as fout:\n",
298 | " for ddb_ptr in ddb_ptr_lst:\n",
299 | " print (ddb_ptr, file=fout)\n",
300 | "\n",
301 | "id2concept = ddb_ptr_lst"
302 | ]
303 | },
304 | {
305 | "cell_type": "code",
306 | "execution_count": null,
307 | "metadata": {},
308 | "outputs": [],
309 | "source": [
310 | "len(ddb_ptr_to_name), len(ddb_ptr_to_preferred_name), len(ddb_name_to_ptr)"
311 | ]
312 | },
313 | {
314 | "cell_type": "code",
315 | "execution_count": null,
316 | "metadata": {},
317 | "outputs": [],
318 | "source": [
319 | "ddb_name_to_ptr['Ethanol'], ddb_name_to_ptr['Serine']"
320 | ]
321 | },
322 | {
323 | "cell_type": "code",
324 | "execution_count": null,
325 | "metadata": {},
326 | "outputs": [],
327 | "source": [
328 | "merged_relations = [\n",
329 | " 'belongs_to_the_category_of',\n",
330 | " 'is_a_category',\n",
331 | " 'may_cause',\n",
332 | " 'is_a_subtype_of',\n",
333 | " 'is_a_risk_factor_of',\n",
334 | " 'is_associated_with',\n",
335 | " 'may_contraindicate',\n",
336 | " 'interacts_with',\n",
337 | " 'belongs_to_the_drug_family_of',\n",
338 | " 'belongs_to_drug_super-family',\n",
339 | " 'is_a_vector_for',\n",
340 | " 'may_be_allelic_with',\n",
341 | " 'see_also',\n",
342 |     "    'is_an_ingredient_of',\n",
343 | " 'may_treat'\n",
344 | "]\n",
345 | "\n",
346 | "relas_dict = {\"0\": 0, \"1\": 1, \"2\": 2, \"3\": 3, \"4\": 4, \"6\": 5, \"10\": 6, \"12\": 7, \"16\": 8, \"17\": 9, \"18\": 10,\n",
347 | " \"20\": 11, \"26\": 12, \"30\": 13, \"233\": 14}"
348 | ]
349 | },
350 | {
351 | "cell_type": "code",
352 | "execution_count": null,
353 | "metadata": {},
354 | "outputs": [],
355 | "source": [
356 | "import networkx as nx\n",
357 | "\n",
358 | "def construct_graph():\n",
359 | " concept2id = {w: i for i, w in enumerate(id2concept)}\n",
360 | " id2relation = merged_relations\n",
361 | " relation2id = {r: i for i, r in enumerate(id2relation)}\n",
362 | " graph = nx.MultiDiGraph()\n",
363 | " attrs = set()\n",
364 | " for relation in relas_lst:\n",
365 | " subj = concept2id[relation[0]]\n",
366 | " obj = concept2id[relation[1]]\n",
367 | " rel = relas_dict[relation[2]]\n",
368 | " weight = 1.\n",
369 | " graph.add_edge(subj, obj, rel=rel, weight=weight)\n",
370 | " attrs.add((subj, obj, rel))\n",
371 | " graph.add_edge(obj, subj, rel=rel + len(relation2id), weight=weight)\n",
372 | " attrs.add((obj, subj, rel + len(relation2id)))\n",
373 | " output_path = f\"{repo_root}/data/ddb/ddb.graph\"\n",
374 | " nx.write_gpickle(graph, output_path)\n",
375 | " return concept2id, id2relation, relation2id, graph\n",
376 | "\n",
377 | "concept2id, id2relation, relation2id, KG = construct_graph()"
378 | ]
379 | },
380 | {
381 | "cell_type": "markdown",
382 | "metadata": {},
383 | "source": [
384 | "## Get KG subgraph"
385 | ]
386 | },
387 | {
388 | "cell_type": "markdown",
389 | "metadata": {},
390 | "source": [
391 | "We get KG subgraph for each question, following the method used for CommonsenseQA + ConceptNet."
392 | ]
393 | },
394 | {
395 | "cell_type": "code",
396 | "execution_count": null,
397 | "metadata": {},
398 | "outputs": [],
399 | "source": [
400 | "def load_kg():\n",
401 | " global cpnet, cpnet_simple\n",
402 | " cpnet = KG\n",
403 | " cpnet_simple = nx.Graph()\n",
404 | " for u, v, data in cpnet.edges(data=True):\n",
405 | " w = data['weight'] if 'weight' in data else 1.0\n",
406 | " if cpnet_simple.has_edge(u, v):\n",
407 | " cpnet_simple[u][v]['weight'] += w\n",
408 | " else:\n",
409 | " cpnet_simple.add_edge(u, v, weight=w)\n",
410 | "\n",
411 | "load_kg()"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": null,
417 | "metadata": {},
418 | "outputs": [],
419 | "source": [
420 | "from scipy.sparse import csr_matrix, coo_matrix\n",
421 | "from multiprocessing import Pool\n",
422 | "\n",
423 | "def concepts2adj(node_ids):\n",
424 | " global id2relation\n",
425 | " cids = np.array(node_ids, dtype=np.int32)\n",
426 | " n_rel = len(id2relation)\n",
427 | " n_node = cids.shape[0]\n",
428 | " adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)\n",
429 | " for s in range(n_node):\n",
430 | " for t in range(n_node):\n",
431 | " s_c, t_c = cids[s], cids[t]\n",
432 | " if cpnet.has_edge(s_c, t_c):\n",
433 | " for e_attr in cpnet[s_c][t_c].values():\n",
434 | " if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:\n",
435 | " adj[e_attr['rel']][s][t] = 1\n",
436 | " adj = coo_matrix(adj.reshape(-1, n_node))\n",
437 | " return adj, cids\n",
438 | "\n",
439 | "def concepts_to_adj_matrices_2hop_all_pair(data):\n",
440 | " qc_ids, ac_ids = data\n",
441 | " qa_nodes = set(qc_ids) | set(ac_ids)\n",
442 | " extra_nodes = set()\n",
443 | " for qid in qa_nodes:\n",
444 | " for aid in qa_nodes:\n",
445 | " if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:\n",
446 | " extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])\n",
447 | " extra_nodes = extra_nodes - qa_nodes\n",
448 | " schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)\n",
449 | " arange = np.arange(len(schema_graph))\n",
450 | " qmask = arange < len(qc_ids)\n",
451 | " amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))\n",
452 | " adj, concepts = concepts2adj(schema_graph)\n",
453 | " return {'adj': adj, 'concepts': concepts, 'qmask': qmask, 'amask': amask, 'cid2score': None}"
454 | ]
455 | },
456 | {
457 | "cell_type": "code",
458 | "execution_count": null,
459 | "metadata": {},
460 | "outputs": [],
461 | "source": [
462 | "def generate_adj_data_from_grounded_concepts(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):\n",
463 | " global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet\n",
464 | "\n",
465 | " qa_data = []\n",
466 | " with open(grounded_path, 'r', encoding='utf-8') as fin:\n",
467 | " for line in fin:\n",
468 | " dic = json.loads(line)\n",
469 | " q_ids = set(concept2id[c] for c in dic['qc'])\n",
470 | " if not q_ids:\n",
471 | " q_ids = {concept2id['31770']} \n",
472 | " a_ids = set(concept2id[c] for c in dic['ac'])\n",
473 | " if not a_ids:\n",
474 | " a_ids = {concept2id['325']}\n",
475 | " q_ids = q_ids - a_ids\n",
476 | " qa_data.append((q_ids, a_ids))\n",
477 | "\n",
478 | " with Pool(num_processes) as p:\n",
479 | " res = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair, qa_data), total=len(qa_data)))\n",
480 | " \n",
481 | " lens = [len(e['concepts']) for e in res]\n",
482 | " print ('mean #nodes', int(np.mean(lens)), 'med', int(np.median(lens)), '5th', int(np.percentile(lens, 5)), '95th', int(np.percentile(lens, 95)))\n",
483 | "\n",
484 | " with open(output_path, 'wb') as fout:\n",
485 | " pickle.dump(res, fout)\n",
486 | "\n",
487 | " print(f'adj data saved to {output_path}')\n",
488 | " print()\n"
489 | ]
490 | },
491 | {
492 | "cell_type": "code",
493 | "execution_count": null,
494 | "metadata": {
495 | "scrolled": true
496 | },
497 | "outputs": [],
498 | "source": [
499 | "os.system(f'mkdir -p {repo_root}/data/medqa_usmle/graph')\n",
500 | "\n",
501 | "for fname in [\"dev\", \"test\", \"train\"]:\n",
502 | " grounded_path = f\"{repo_root}/data/medqa_usmle/grounded/{fname}.grounded.jsonl\"\n",
503 | " kg_path = f\"{repo_root}/data/ddb/ddb.graph\"\n",
504 |     "    kg_vocab_path = f\"{repo_root}/data/ddb/ptrs.txt\"\n",
505 | " output_path = f\"{repo_root}/data/medqa_usmle/graph/{fname}.graph.adj.pk\"\n",
506 | "\n",
507 | " generate_adj_data_from_grounded_concepts(grounded_path, kg_path, kg_vocab_path, output_path, 10)"
508 | ]
509 | },
510 | {
511 | "cell_type": "markdown",
512 | "metadata": {},
513 | "source": [
514 | "## Get KG entity embedding"
515 | ]
516 | },
517 | {
518 | "cell_type": "code",
519 | "execution_count": null,
520 | "metadata": {},
521 | "outputs": [],
522 | "source": [
523 | "import torch\n",
524 | "from transformers import AutoTokenizer, AutoModel, AutoConfig\n",
525 | "tokenizer = AutoTokenizer.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n",
526 |     "bert_model = AutoModel.from_pretrained(\"cambridgeltl/SapBERT-from-PubMedBERT-fulltext\")\n",
527 | "device = torch.device('cuda')\n",
528 | "bert_model.to(device)\n",
529 | "bert_model.eval()"
530 | ]
531 | },
532 | {
533 | "cell_type": "code",
534 | "execution_count": null,
535 | "metadata": {},
536 | "outputs": [],
537 | "source": [
538 | "with open(f\"{repo_root}/data/ddb/vocab.txt\") as f:\n",
539 | " names = [line.strip() for line in f]"
540 | ]
541 | },
542 | {
543 | "cell_type": "code",
544 | "execution_count": null,
545 | "metadata": {},
546 | "outputs": [],
547 | "source": [
548 | "embs = []\n",
549 | "tensors = tokenizer(names, padding=True, truncation=True, return_tensors=\"pt\")\n",
550 | "with torch.no_grad():\n",
551 | " for i, j in enumerate(tqdm(names)):\n",
552 | " outputs = bert_model(input_ids=tensors[\"input_ids\"][i:i+1].to(device), \n",
553 | " attention_mask=tensors['attention_mask'][i:i+1].to(device))\n",
554 | " out = np.array(outputs[1].squeeze().tolist()).reshape((1, -1))\n",
555 | " embs.append(out)\n",
556 | "embs = np.concatenate(embs)\n",
557 | "np.save(f\"{repo_root}/data/ddb/ent_emb.npy\", embs)"
558 | ]
559 | },
560 | {
561 | "cell_type": "code",
562 | "execution_count": null,
563 | "metadata": {},
564 | "outputs": [],
565 | "source": []
566 | }
567 | ],
568 | "metadata": {
569 | "kernelspec": {
570 | "display_name": "ct",
571 | "language": "python",
572 | "name": "ct"
573 | },
574 | "language_info": {
575 | "codemirror_mode": {
576 | "name": "ipython",
577 | "version": 3
578 | },
579 | "file_extension": ".py",
580 | "mimetype": "text/x-python",
581 | "name": "python",
582 | "nbconvert_exporter": "python",
583 | "pygments_lexer": "ipython3",
584 | "version": "3.7.10"
585 | }
586 | },
587 | "nbformat": 4,
588 | "nbformat_minor": 2
589 | }
590 |
--------------------------------------------------------------------------------
/qagnn.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | try:
4 | from transformers import (ConstantLRSchedule, WarmupLinearSchedule, WarmupConstantSchedule)
5 | except ImportError:
6 | from transformers import get_constant_schedule, get_constant_schedule_with_warmup, get_linear_schedule_with_warmup
7 |
8 | from modeling.modeling_qagnn import *
9 | from utils.optimization_utils import OPTIMIZER_CLASSES
10 | from utils.parser_utils import *
11 |
12 |
13 | DECODER_DEFAULT_LR = {
14 | 'csqa': 1e-3,
15 | 'obqa': 3e-4,
16 | 'medqa_usmle': 1e-3,
17 | }
18 |
19 | from collections import defaultdict, OrderedDict
20 | import numpy as np
21 |
22 | import socket, os, subprocess, datetime
23 | print(socket.gethostname())
24 | print ("pid:", os.getpid())
25 | print ("conda env:", os.environ['CONDA_DEFAULT_ENV'])
26 | print ("screen: %s" % subprocess.check_output('echo $STY', shell=True).decode('utf'))
27 | print ("gpu: %s" % subprocess.check_output('echo $CUDA_VISIBLE_DEVICES', shell=True).decode('utf'))
28 |
29 |
30 | def evaluate_accuracy(eval_set, model):
31 | n_samples, n_correct = 0, 0
32 | model.eval()
33 | with torch.no_grad():
34 | for qids, labels, *input_data in tqdm(eval_set):
35 | logits, _ = model(*input_data)
36 | n_correct += (logits.argmax(1) == labels).sum().item()
37 | n_samples += labels.size(0)
38 | return n_correct / n_samples
39 |
40 |
41 | def main():
42 | parser = get_parser()
43 | args, _ = parser.parse_known_args()
44 | parser.add_argument('--mode', default='train', choices=['train', 'eval_detail'], help='run training or evaluation')
45 | parser.add_argument('--save_dir', default=f'./saved_models/qagnn/', help='model output directory')
46 | parser.add_argument('--save_model', dest='save_model', action='store_true')
47 | parser.add_argument('--load_model_path', default=None)
48 |
49 |
50 | # data
51 | parser.add_argument('--num_relation', default=38, type=int, help='number of relations')
52 | parser.add_argument('--train_adj', default=f'data/{args.dataset}/graph/train.graph.adj.pk')
53 | parser.add_argument('--dev_adj', default=f'data/{args.dataset}/graph/dev.graph.adj.pk')
54 | parser.add_argument('--test_adj', default=f'data/{args.dataset}/graph/test.graph.adj.pk')
55 | parser.add_argument('--use_cache', default=True, type=bool_flag, nargs='?', const=True, help='use cached data to accelerate data loading')
56 |
57 | # model architecture
58 | parser.add_argument('-k', '--k', default=5, type=int, help='perform k-layer message passing')
59 | parser.add_argument('--att_head_num', default=2, type=int, help='number of attention heads')
60 | parser.add_argument('--gnn_dim', default=100, type=int, help='dimension of the GNN layers')
61 | parser.add_argument('--fc_dim', default=200, type=int, help='number of FC hidden units')
62 | parser.add_argument('--fc_layer_num', default=0, type=int, help='number of FC layers')
63 | parser.add_argument('--freeze_ent_emb', default=True, type=bool_flag, nargs='?', const=True, help='freeze entity embedding layer')
64 |
65 | parser.add_argument('--max_node_num', default=200, type=int)
66 | parser.add_argument('--simple', default=False, type=bool_flag, nargs='?', const=True)
67 | parser.add_argument('--subsample', default=1.0, type=float)
68 | parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')
69 |
70 |
71 | # regularization
72 | parser.add_argument('--dropouti', type=float, default=0.2, help='dropout for embedding layer')
73 | parser.add_argument('--dropoutg', type=float, default=0.2, help='dropout for GNN layers')
74 | parser.add_argument('--dropoutf', type=float, default=0.2, help='dropout for fully-connected layers')
75 |
76 | # optimization
77 | parser.add_argument('-dlr', '--decoder_lr', default=DECODER_DEFAULT_LR[args.dataset], type=float, help='learning rate')
78 | parser.add_argument('-mbs', '--mini_batch_size', default=1, type=int)
79 | parser.add_argument('-ebs', '--eval_batch_size', default=2, type=int)
80 | parser.add_argument('--unfreeze_epoch', default=4, type=int)
81 | parser.add_argument('--refreeze_epoch', default=10000, type=int)
82 | parser.add_argument('--fp16', default=False, type=bool_flag, help='use fp16 training. this requires torch>=1.6.0')
83 | parser.add_argument('--drop_partial_batch', default=False, type=bool_flag, help='')
84 | parser.add_argument('--fill_partial_batch', default=False, type=bool_flag, help='')
85 |
86 | parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, help='show this help message and exit')
87 | args = parser.parse_args()
88 | if args.simple:
89 | parser.set_defaults(k=1)
90 | args = parser.parse_args()
91 |     args.fp16 = args.fp16 and tuple(int(v) for v in torch.__version__.split('+')[0].split('.')[:2]) >= (1, 6)  # numeric compare; a plain string compare misorders e.g. '1.10.0'
92 |
93 | if args.mode == 'train':
94 | train(args)
95 | elif args.mode == 'eval_detail':
96 | # raise NotImplementedError
97 | eval_detail(args)
98 | else:
99 | raise ValueError('Invalid mode')
100 |
101 |
102 |
103 |
104 | def train(args):
105 | print(args)
106 |
107 | random.seed(args.seed)
108 | np.random.seed(args.seed)
109 | torch.manual_seed(args.seed)
110 | if torch.cuda.is_available() and args.cuda:
111 | torch.cuda.manual_seed(args.seed)
112 |
113 | config_path = os.path.join(args.save_dir, 'config.json')
114 | model_path = os.path.join(args.save_dir, 'model.pt')
115 | log_path = os.path.join(args.save_dir, 'log.csv')
116 | export_config(args, config_path)
117 | check_path(model_path)
118 | with open(log_path, 'w') as fout:
119 | fout.write('step,dev_acc,test_acc\n')
120 |
121 | ###################################################################################################
122 | # Load data #
123 | ###################################################################################################
124 | cp_emb = [np.load(path) for path in args.ent_emb_paths]
125 | cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)
126 |
127 | concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
128 | print('| num_concepts: {} |'.format(concept_num))
129 |
130 | # try:
131 | if True:
132 | if torch.cuda.device_count() >= 2 and args.cuda:
133 | device0 = torch.device("cuda:0")
134 | device1 = torch.device("cuda:1")
135 | elif torch.cuda.device_count() == 1 and args.cuda:
136 | device0 = torch.device("cuda:0")
137 | device1 = torch.device("cuda:0")
138 | else:
139 | device0 = torch.device("cpu")
140 | device1 = torch.device("cpu")
141 | dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj,
142 | args.dev_statements, args.dev_adj,
143 | args.test_statements, args.test_adj,
144 | batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
145 | device=(device0, device1),
146 | model_name=args.encoder,
147 | max_node_num=args.max_node_num, max_seq_length=args.max_seq_len,
148 | is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
149 | subsample=args.subsample, use_cache=args.use_cache)
150 |
151 | ###################################################################################################
152 | # Build model #
153 | ###################################################################################################
154 | print ('args.num_relation', args.num_relation)
155 | model = LM_QAGNN(args, args.encoder, k=args.k, n_ntype=4, n_etype=args.num_relation, n_concept=concept_num,
156 | concept_dim=args.gnn_dim,
157 | concept_in_dim=concept_dim,
158 | n_attention_head=args.att_head_num, fc_dim=args.fc_dim, n_fc_layer=args.fc_layer_num,
159 | p_emb=args.dropouti, p_gnn=args.dropoutg, p_fc=args.dropoutf,
160 | pretrained_concept_emb=cp_emb, freeze_ent_emb=args.freeze_ent_emb,
161 | init_range=args.init_range,
162 | encoder_config={})
163 | if args.load_model_path:
164 | print (f'loading and initializing model from {args.load_model_path}')
165 | model_state_dict, old_args = torch.load(args.load_model_path, map_location=torch.device('cpu'))
166 | model.load_state_dict(model_state_dict)
167 |
168 | model.encoder.to(device0)
169 | model.decoder.to(device1)
170 |
171 |
172 | no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
173 |
174 | grouped_parameters = [
175 | {'params': [p for n, p in model.encoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.encoder_lr},
176 | {'params': [p for n, p in model.encoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.encoder_lr},
177 | {'params': [p for n, p in model.decoder.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay, 'lr': args.decoder_lr},
178 | {'params': [p for n, p in model.decoder.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0, 'lr': args.decoder_lr},
179 | ]
180 | optimizer = OPTIMIZER_CLASSES[args.optim](grouped_parameters)
181 |
182 | if args.lr_schedule == 'fixed':
183 | try:
184 | scheduler = ConstantLRSchedule(optimizer)
185 |         except NameError:  # old-style schedule class was not imported
186 | scheduler = get_constant_schedule(optimizer)
187 | elif args.lr_schedule == 'warmup_constant':
188 | try:
189 | scheduler = WarmupConstantSchedule(optimizer, warmup_steps=args.warmup_steps)
190 |         except NameError:
191 | scheduler = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps)
192 | elif args.lr_schedule == 'warmup_linear':
193 | max_steps = int(args.n_epochs * (dataset.train_size() / args.batch_size))
194 | try:
195 | scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=max_steps)
196 |         except NameError:
197 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=max_steps)
198 |
199 | print('parameters:')
200 | for name, param in model.decoder.named_parameters():
201 | if param.requires_grad:
202 | print('\t{:45}\ttrainable\t{}\tdevice:{}'.format(name, param.size(), param.device))
203 | else:
204 | print('\t{:45}\tfixed\t{}\tdevice:{}'.format(name, param.size(), param.device))
205 | num_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
206 | print('\ttotal:', num_params)
207 |
208 | if args.loss == 'margin_rank':
209 | loss_func = nn.MarginRankingLoss(margin=0.1, reduction='mean')
210 | elif args.loss == 'cross_entropy':
211 | loss_func = nn.CrossEntropyLoss(reduction='mean')
212 |
213 | def compute_loss(logits, labels):
214 | if args.loss == 'margin_rank':
215 | num_choice = logits.size(1)
216 | flat_logits = logits.view(-1)
217 | correct_mask = F.one_hot(labels, num_classes=num_choice).view(-1) # of length batch_size*num_choice
218 | correct_logits = flat_logits[correct_mask == 1].contiguous().view(-1, 1).expand(-1, num_choice - 1).contiguous().view(-1) # of length batch_size*(num_choice-1)
219 | wrong_logits = flat_logits[correct_mask == 0]
220 | y = wrong_logits.new_ones((wrong_logits.size(0),))
221 | loss = loss_func(correct_logits, wrong_logits, y) # margin ranking loss
222 | elif args.loss == 'cross_entropy':
223 | loss = loss_func(logits, labels)
224 | return loss
225 |
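    # Worked example for the margin-rank branch above (illustrative values, not
    # part of the original code). With batch_size=2 and num_choice=3:
    #
    #   logits = torch.tensor([[0.9, 0.1, 0.3],
    #                          [0.2, 0.8, 0.5]])
    #   labels = torch.tensor([0, 1])
    #   # flat_logits    -> [0.9, 0.1, 0.3, 0.2, 0.8, 0.5]
    #   # correct_mask   -> [1, 0, 0, 0, 1, 0]
    #   # correct_logits -> [0.9, 0.9, 0.8, 0.8]  (each gold logit repeated num_choice-1 times)
    #   # wrong_logits   -> [0.1, 0.3, 0.2, 0.5]
    #
    # MarginRankingLoss(margin=0.1) then averages max(0, wrong - correct + 0.1)
    # over the four (correct, wrong) pairs.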
226 | ###################################################################################################
227 | # Training #
228 | ###################################################################################################
229 |
230 | print()
231 | print('-' * 71)
232 | if args.fp16:
233 | print ('Using fp16 training')
234 | scaler = torch.cuda.amp.GradScaler()
235 |
236 | global_step, best_dev_epoch = 0, 0
237 | best_dev_acc, final_test_acc, total_loss = 0.0, 0.0, 0.0
238 | start_time = time.time()
239 | model.train()
240 | freeze_net(model.encoder)
241 | if True:
242 | # try:
243 | for epoch_id in range(args.n_epochs):
244 | if epoch_id == args.unfreeze_epoch:
245 | unfreeze_net(model.encoder)
246 | if epoch_id == args.refreeze_epoch:
247 | freeze_net(model.encoder)
248 | model.train()
249 | for qids, labels, *input_data in dataset.train():
250 | optimizer.zero_grad()
251 | bs = labels.size(0)
252 | for a in range(0, bs, args.mini_batch_size):
253 | b = min(a + args.mini_batch_size, bs)
254 | if args.fp16:
255 | with torch.cuda.amp.autocast():
256 | logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)
257 | loss = compute_loss(logits, labels[a:b])
258 | else:
259 | logits, _ = model(*[x[a:b] for x in input_data], layer_id=args.encoder_layer)
260 | loss = compute_loss(logits, labels[a:b])
261 | loss = loss * (b - a) / bs
262 | if args.fp16:
263 | scaler.scale(loss).backward()
264 | else:
265 | loss.backward()
266 | total_loss += loss.item()
267 | if args.max_grad_norm > 0:
268 | if args.fp16:
269 | scaler.unscale_(optimizer)
270 | nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
271 | else:
272 | nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
273 | scheduler.step()
274 | if args.fp16:
275 | scaler.step(optimizer)
276 | scaler.update()
277 | else:
278 | optimizer.step()
279 |
280 | if (global_step + 1) % args.log_interval == 0:
281 | total_loss /= args.log_interval
282 | ms_per_batch = 1000 * (time.time() - start_time) / args.log_interval
283 | print('| step {:5} | lr: {:9.7f} | loss {:7.4f} | ms/batch {:7.2f} |'.format(global_step, scheduler.get_lr()[0], total_loss, ms_per_batch))
284 | total_loss = 0
285 | start_time = time.time()
286 | global_step += 1
287 |
288 | model.eval()
289 | dev_acc = evaluate_accuracy(dataset.dev(), model)
290 | save_test_preds = args.save_model
291 | if not save_test_preds:
292 | test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
293 | else:
294 | eval_set = dataset.test()
295 | total_acc = []
296 | count = 0
297 | preds_path = os.path.join(args.save_dir, 'test_e{}_preds.csv'.format(epoch_id))
298 | with open(preds_path, 'w') as f_preds:
299 | with torch.no_grad():
300 | for qids, labels, *input_data in tqdm(eval_set):
301 | count += 1
302 | logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
303 | predictions = logits.argmax(1) #[bsize, ]
304 | preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
305 | for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
306 | acc = int(pred.item()==label.item())
307 | print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds)
308 | f_preds.flush()
309 | total_acc.append(acc)
310 | test_acc = float(sum(total_acc))/len(total_acc)
311 |
312 | print('-' * 71)
313 | print('| epoch {:3} | step {:5} | dev_acc {:7.4f} | test_acc {:7.4f} |'.format(epoch_id, global_step, dev_acc, test_acc))
314 | print('-' * 71)
315 | with open(log_path, 'a') as fout:
316 | fout.write('{},{},{}\n'.format(global_step, dev_acc, test_acc))
317 | if dev_acc >= best_dev_acc:
318 | best_dev_acc = dev_acc
319 | final_test_acc = test_acc
320 | best_dev_epoch = epoch_id
321 | if args.save_model:
322 | torch.save([model.state_dict(), args], f"{model_path}.{epoch_id}")
323 | # with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
324 | # for p in model.named_parameters():
325 | # print (p, file=f)
326 | print(f'model saved to {model_path}.{epoch_id}')
327 | else:
328 | if args.save_model:
329 | torch.save([model.state_dict(), args], f"{model_path}.{epoch_id}")
330 | # with open(model_path +".{}.log.txt".format(epoch_id), 'w') as f:
331 | # for p in model.named_parameters():
332 | # print (p, file=f)
333 | print(f'model saved to {model_path}.{epoch_id}')
334 | model.train()
335 | start_time = time.time()
336 | if epoch_id > args.unfreeze_epoch and epoch_id - best_dev_epoch >= args.max_epochs_before_stop:
337 | break
338 | # except (KeyboardInterrupt, RuntimeError) as e:
339 | # print(e)
340 |
341 |
342 |
343 | def eval_detail(args):
344 | assert args.load_model_path is not None
345 | model_path = args.load_model_path
346 |
347 | cp_emb = [np.load(path) for path in args.ent_emb_paths]
348 | cp_emb = torch.tensor(np.concatenate(cp_emb, 1), dtype=torch.float)
349 | concept_num, concept_dim = cp_emb.size(0), cp_emb.size(1)
350 | print('| num_concepts: {} |'.format(concept_num))
351 |
352 | model_state_dict, old_args = torch.load(model_path, map_location=torch.device('cpu'))
353 | model = LM_QAGNN(old_args, old_args.encoder, k=old_args.k, n_ntype=4, n_etype=old_args.num_relation, n_concept=concept_num,
354 | concept_dim=old_args.gnn_dim,
355 | concept_in_dim=concept_dim,
356 | n_attention_head=old_args.att_head_num, fc_dim=old_args.fc_dim, n_fc_layer=old_args.fc_layer_num,
357 | p_emb=old_args.dropouti, p_gnn=old_args.dropoutg, p_fc=old_args.dropoutf,
358 | pretrained_concept_emb=cp_emb, freeze_ent_emb=old_args.freeze_ent_emb,
359 | init_range=old_args.init_range,
360 | encoder_config={})
361 | model.load_state_dict(model_state_dict)
362 |
363 | if torch.cuda.device_count() >= 2 and args.cuda:
364 | device0 = torch.device("cuda:0")
365 | device1 = torch.device("cuda:1")
366 | elif torch.cuda.device_count() == 1 and args.cuda:
367 | device0 = torch.device("cuda:0")
368 | device1 = torch.device("cuda:0")
369 | else:
370 | device0 = torch.device("cpu")
371 | device1 = torch.device("cpu")
372 | model.encoder.to(device0)
373 | model.decoder.to(device1)
374 | model.eval()
375 |
376 | statement_dic = {}
377 | for statement_path in (args.train_statements, args.dev_statements, args.test_statements):
378 | statement_dic.update(load_statement_dict(statement_path))
379 |
380 | use_contextualized = 'lm' in old_args.ent_emb
381 |
382 | print ('inhouse?', args.inhouse)
383 |
384 | print ('args.train_statements', args.train_statements)
385 | print ('args.dev_statements', args.dev_statements)
386 | print ('args.test_statements', args.test_statements)
387 | print ('args.train_adj', args.train_adj)
388 | print ('args.dev_adj', args.dev_adj)
389 | print ('args.test_adj', args.test_adj)
390 |
391 | dataset = LM_QAGNN_DataLoader(args, args.train_statements, args.train_adj,
392 | args.dev_statements, args.dev_adj,
393 | args.test_statements, args.test_adj,
394 | batch_size=args.batch_size, eval_batch_size=args.eval_batch_size,
395 | device=(device0, device1),
396 | model_name=old_args.encoder,
397 | max_node_num=old_args.max_node_num, max_seq_length=old_args.max_seq_len,
398 | is_inhouse=args.inhouse, inhouse_train_qids_path=args.inhouse_train_qids,
399 | subsample=args.subsample, use_cache=args.use_cache)
400 |
401 | save_test_preds = args.save_model
402 | dev_acc = evaluate_accuracy(dataset.dev(), model)
403 | print('dev_acc {:7.4f}'.format(dev_acc))
404 | if not save_test_preds:
405 | test_acc = evaluate_accuracy(dataset.test(), model) if args.test_statements else 0.0
406 | else:
407 | eval_set = dataset.test()
408 | total_acc = []
409 | count = 0
410 | dt = datetime.datetime.today().strftime('%Y%m%d%H%M%S')
411 | preds_path = os.path.join(args.save_dir, 'test_preds_{}.csv'.format(dt))
412 | with open(preds_path, 'w') as f_preds:
413 | with torch.no_grad():
414 | for qids, labels, *input_data in tqdm(eval_set):
415 | count += 1
416 | logits, _, concept_ids, node_type_ids, edge_index, edge_type = model(*input_data, detail=True)
417 | predictions = logits.argmax(1) #[bsize, ]
418 | preds_ranked = (-logits).argsort(1) #[bsize, n_choices]
419 | for i, (qid, label, pred, _preds_ranked, cids, ntype, edges, etype) in enumerate(zip(qids, labels, predictions, preds_ranked, concept_ids, node_type_ids, edge_index, edge_type)):
420 | acc = int(pred.item()==label.item())
421 | print ('{},{}'.format(qid, chr(ord('A') + pred.item())), file=f_preds)
422 | f_preds.flush()
423 | total_acc.append(acc)
424 | test_acc = float(sum(total_acc))/len(total_acc)
425 |
426 | print('-' * 71)
427 | print('test_acc {:7.4f}'.format(test_acc))
428 | print('-' * 71)
429 |
430 |
431 |
432 | if __name__ == '__main__':
433 | main()
434 |
--------------------------------------------------------------------------------
/utils/graph.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import networkx as nx
3 | import itertools
4 | import json
5 | from tqdm import tqdm
6 | from .conceptnet import merged_relations
7 | import numpy as np
8 | from scipy import sparse
9 | import pickle
10 | from scipy.sparse import csr_matrix, coo_matrix
11 | from multiprocessing import Pool
12 | from collections import OrderedDict
13 |
14 |
15 | from .maths import *
16 |
17 | __all__ = ['generate_graph']
18 |
19 | concept2id = None
20 | id2concept = None
21 | relation2id = None
22 | id2relation = None
23 |
24 | cpnet = None
25 | cpnet_all = None
26 | cpnet_simple = None
27 |
28 |
29 | def load_resources(cpnet_vocab_path):
30 | global concept2id, id2concept, relation2id, id2relation
31 |
32 | with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
33 | id2concept = [w.strip() for w in fin]
34 | concept2id = {w: i for i, w in enumerate(id2concept)}
35 |
36 | id2relation = merged_relations
37 | relation2id = {r: i for i, r in enumerate(id2relation)}
38 |
39 |
40 | def load_cpnet(cpnet_graph_path):
41 | global cpnet, cpnet_simple
42 | cpnet = nx.read_gpickle(cpnet_graph_path)
43 | cpnet_simple = nx.Graph()
44 | for u, v, data in cpnet.edges(data=True):
45 | w = data['weight'] if 'weight' in data else 1.0
46 | if cpnet_simple.has_edge(u, v):
47 | cpnet_simple[u][v]['weight'] += w
48 | else:
49 | cpnet_simple.add_edge(u, v, weight=w)
50 |
51 |
52 | def relational_graph_generation(qcs, acs, paths, rels):
53 | raise NotImplementedError() # TODO
54 |
55 |
56 | # plain graph generation
57 | def plain_graph_generation(qcs, acs, paths, rels):
58 | global cpnet, concept2id, relation2id, id2relation, id2concept, cpnet_simple
59 |
60 | graph = nx.Graph()
61 | for p in paths:
62 | for c_index in range(len(p) - 1):
63 | h = p[c_index]
64 | t = p[c_index + 1]
65 | # TODO: the weight can computed by concept embeddings and relation embeddings of TransE
66 | graph.add_edge(h, t, weight=1.0)
67 |
68 | for qc1, qc2 in list(itertools.combinations(qcs, 2)):
69 | if cpnet_simple.has_edge(qc1, qc2):
70 | graph.add_edge(qc1, qc2, weight=1.0)
71 |
72 | for ac1, ac2 in list(itertools.combinations(acs, 2)):
73 | if cpnet_simple.has_edge(ac1, ac2):
74 | graph.add_edge(ac1, ac2, weight=1.0)
75 |
76 | if len(qcs) == 0:
77 | qcs.append(-1)
78 |
79 | if len(acs) == 0:
80 | acs.append(-1)
81 |
82 | if len(paths) == 0:
83 | for qc in qcs:
84 | for ac in acs:
85 | graph.add_edge(qc, ac, rel=-1, weight=0.1)
86 |
87 | g = nx.convert_node_labels_to_integers(graph, label_attribute='cid') # re-index
88 | return nx.node_link_data(g)
89 |
90 |
91 | def generate_adj_matrix_per_inst(nxg_str):
92 | global id2relation
93 | n_rel = len(id2relation)
94 |
95 | nxg = nx.node_link_graph(json.loads(nxg_str))
96 | n_node = len(nxg.nodes)
97 | cids = np.zeros(n_node, dtype=np.int32)
98 | for node_id, node_attr in nxg.nodes(data=True):
99 | cids[node_id] = node_attr['cid']
100 |
101 | adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)
102 | for s in range(n_node):
103 | for t in range(n_node):
104 | s_c, t_c = cids[s], cids[t]
105 | if cpnet_all.has_edge(s_c, t_c):
106 | for e_attr in cpnet_all[s_c][t_c].values():
107 | if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:
108 | adj[e_attr['rel']][s][t] = 1
109 | cids += 1
110 | adj = coo_matrix(adj.reshape(-1, n_node))
111 | return (adj, cids)
112 |
113 |
114 | def concepts2adj(node_ids):
115 | global id2relation
116 | cids = np.array(node_ids, dtype=np.int32)
117 | n_rel = len(id2relation)
118 | n_node = cids.shape[0]
119 | adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)
120 | for s in range(n_node):
121 | for t in range(n_node):
122 | s_c, t_c = cids[s], cids[t]
123 | if cpnet.has_edge(s_c, t_c):
124 | for e_attr in cpnet[s_c][t_c].values():
125 | if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:
126 | adj[e_attr['rel']][s][t] = 1
127 | # cids += 1 # note!!! index 0 is reserved for padding
128 | adj = coo_matrix(adj.reshape(-1, n_node))
129 | return adj, cids
130 |
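# Layout note (added for clarity): concepts2adj flattens the (n_rel, n_node, n_node)
# tensor into a (n_rel * n_node, n_node) coo matrix, so a nonzero entry at
# (row, col) encodes the triple (rel, source, target) via
#   rel, source, target = row // n_node, row % n_node, col
# which is exactly how coo_to_normalized_per_inst below recovers it.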
131 |
132 | def concepts_to_adj_matrices_1hop_neighbours(data):
133 | qc_ids, ac_ids = data
134 | qa_nodes = set(qc_ids) | set(ac_ids)
135 | extra_nodes = set()
136 | for u in set(qc_ids) | set(ac_ids):
137 | if u in cpnet.nodes:
138 | extra_nodes |= set(cpnet[u])
139 | extra_nodes = extra_nodes - qa_nodes
140 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)
141 | arange = np.arange(len(schema_graph))
142 | qmask = arange < len(qc_ids)
143 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
144 | adj, concepts = concepts2adj(schema_graph)
145 | return adj, concepts, qmask, amask
146 |
147 |
148 | def concepts_to_adj_matrices_1hop_neighbours_without_relatedto(data):
149 | qc_ids, ac_ids = data
150 | qa_nodes = set(qc_ids) | set(ac_ids)
151 | extra_nodes = set()
152 | for u in set(qc_ids) | set(ac_ids):
153 | if u in cpnet.nodes:
154 | for v in cpnet[u]:
155 | for data in cpnet[u][v].values():
156 | if data['rel'] not in (15, 32):
157 | extra_nodes.add(v)
158 | extra_nodes = extra_nodes - qa_nodes
159 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)
160 | arange = np.arange(len(schema_graph))
161 | qmask = arange < len(qc_ids)
162 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
163 | adj, concepts = concepts2adj(schema_graph)
164 | return adj, concepts, qmask, amask
165 |
166 |
167 | def concepts_to_adj_matrices_2hop_qa_pair(data):
168 | qc_ids, ac_ids = data
169 | qa_nodes = set(qc_ids) | set(ac_ids)
170 | extra_nodes = set()
171 | for qid in qc_ids:
172 | for aid in ac_ids:
173 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
174 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])
175 | extra_nodes = extra_nodes - qa_nodes
176 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)
177 | arange = np.arange(len(schema_graph))
178 | qmask = arange < len(qc_ids)
179 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
180 | adj, concepts = concepts2adj(schema_graph)
181 | return adj, concepts, qmask, amask
182 |
183 |
184 | def concepts_to_adj_matrices_2hop_all_pair(data):
185 | qc_ids, ac_ids = data
186 | qa_nodes = set(qc_ids) | set(ac_ids)
187 | extra_nodes = set()
188 | for qid in qa_nodes:
189 | for aid in qa_nodes:
190 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
191 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])
192 | extra_nodes = extra_nodes - qa_nodes
193 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)
194 | arange = np.arange(len(schema_graph))
195 | qmask = arange < len(qc_ids)
196 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
197 | adj, concepts = concepts2adj(schema_graph)
198 | return adj, concepts, qmask, amask
199 |
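# Small example of the schema-graph ordering used above (hypothetical ids):
# with qc_ids={3, 7}, ac_ids={5} and extra_nodes={2, 9},
#   schema_graph = [3, 7, 5, 2, 9]
#   qmask        = [T, T, F, F, F]  (question concepts come first)
#   amask        = [F, F, T, F, F]  (then answer concepts, then common-neighbor bridge nodes)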
200 |
201 | def concepts_to_adj_matrices_2step_relax_all_pair(data):
202 | qc_ids, ac_ids = data
203 | qa_nodes = set(qc_ids) | set(ac_ids)
204 | extra_nodes = set()
205 | for qid in qc_ids:
206 | for aid in ac_ids:
207 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
208 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])
209 | intermediate_ids = extra_nodes - qa_nodes
210 | for qid in intermediate_ids:
211 | for aid in ac_ids:
212 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
213 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])
214 | for qid in qc_ids:
215 | for aid in intermediate_ids:
216 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
217 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])
218 | extra_nodes = extra_nodes - qa_nodes
219 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)
220 | arange = np.arange(len(schema_graph))
221 | qmask = arange < len(qc_ids)
222 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
223 | adj, concepts = concepts2adj(schema_graph)
224 | return adj, concepts, qmask, amask
225 |
226 |
227 | def concepts_to_adj_matrices_3hop_qa_pair(data):
228 | qc_ids, ac_ids = data
229 | qa_nodes = set(qc_ids) | set(ac_ids)
230 | extra_nodes = set()
231 | for qid in qc_ids:
232 | for aid in ac_ids:
233 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
234 | for u in cpnet_simple[qid]:
235 | for v in cpnet_simple[aid]:
236 | if cpnet_simple.has_edge(u, v): # ac is a 3-hop neighbour of qc
237 | extra_nodes.add(u)
238 | extra_nodes.add(v)
239 | if u == v: # ac is a 2-hop neighbour of qc
240 | extra_nodes.add(u)
241 | extra_nodes = extra_nodes - qa_nodes
242 | schema_graph = sorted(qc_ids) + sorted(ac_ids) + sorted(extra_nodes)
243 | arange = np.arange(len(schema_graph))
244 | qmask = arange < len(qc_ids)
245 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
246 | adj, concepts = concepts2adj(schema_graph)
247 | return adj, concepts, qmask, amask
248 |
249 |
250 |
251 | ######################################################################
252 | from transformers import RobertaTokenizer, RobertaForMaskedLM
253 |
254 | class RobertaForMaskedLMwithLoss(RobertaForMaskedLM):
255 | #
256 | def __init__(self, config):
257 | super().__init__(config)
258 | #
259 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, masked_lm_labels=None):
260 | #
261 | assert attention_mask is not None
262 | outputs = self.roberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask)
263 | sequence_output = outputs[0] #hidden_states of final layer (batch_size, sequence_length, hidden_size)
264 | prediction_scores = self.lm_head(sequence_output)
265 | outputs = (prediction_scores, sequence_output) + outputs[2:]
266 | if masked_lm_labels is not None:
267 | loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
268 | bsize, seqlen = input_ids.size()
269 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)).view(bsize, seqlen)
270 | masked_lm_loss = (masked_lm_loss * attention_mask).sum(dim=1)
271 | outputs = (masked_lm_loss,) + outputs
272 | # (masked_lm_loss), prediction_scores, sequence_output, (hidden_states), (attentions)
273 | return outputs
274 |
275 | print ('loading pre-trained LM...')
276 | TOKENIZER = RobertaTokenizer.from_pretrained('roberta-large')
277 | LM_MODEL = RobertaForMaskedLMwithLoss.from_pretrained('roberta-large')
278 | LM_MODEL.cuda(); LM_MODEL.eval()  # note: loaded eagerly at import time and requires a CUDA device
279 | print ('loading done')
280 |
281 | def get_LM_score(cids, question):
282 | cids = cids[:]
283 | cids.insert(0, -1) #QAcontext node
284 | sents, scores = [], []
285 | for cid in cids:
286 | if cid==-1:
287 | sent = question.lower()
288 | else:
289 | sent = '{} {}.'.format(question.lower(), ' '.join(id2concept[cid].split('_')))
290 | sent = TOKENIZER.encode(sent, add_special_tokens=True)
291 | sents.append(sent)
292 | n_cids = len(cids)
293 | cur_idx = 0
294 | batch_size = 50
295 | while cur_idx < n_cids:
296 | #Prepare batch
297 | input_ids = sents[cur_idx: cur_idx+batch_size]
298 | max_len = max([len(seq) for seq in input_ids])
299 | for j, seq in enumerate(input_ids):
300 | seq += [TOKENIZER.pad_token_id] * (max_len-len(seq))
301 | input_ids[j] = seq
302 | input_ids = torch.tensor(input_ids).cuda() #[B, seqlen]
303 | mask = (input_ids!=1).long() #[B, seq_len]
304 | #Get LM score
305 | with torch.no_grad():
306 | outputs = LM_MODEL(input_ids, attention_mask=mask, masked_lm_labels=input_ids)
307 | loss = outputs[0] #[B, ]
308 | _scores = list(-loss.detach().cpu().numpy()) #list of float
309 | scores += _scores
310 | cur_idx += batch_size
311 | assert len(sents) == len(scores) == len(cids)
312 | cid2score = OrderedDict(sorted(list(zip(cids, scores)), key=lambda x: -x[1])) #score: from high to low
313 | return cid2score
314 |
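# Reading of the scoring above: each concept c is appended to the question as
# "<question> c." and scored by the summed masked-LM loss of that sentence under
# RoBERTa; the stored scores are negated losses, so a higher score means the LM
# finds the question+concept pairing more plausible. cid -1 stands for the bare
# QA-context node, which is scored on the question alone.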
315 | def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part1(data):
316 | qc_ids, ac_ids, question = data
317 | qa_nodes = set(qc_ids) | set(ac_ids)
318 | extra_nodes = set()
319 | for qid in qa_nodes:
320 | for aid in qa_nodes:
321 | if qid != aid and qid in cpnet_simple.nodes and aid in cpnet_simple.nodes:
322 | extra_nodes |= set(cpnet_simple[qid]) & set(cpnet_simple[aid])
323 | extra_nodes = extra_nodes - qa_nodes
324 | return (sorted(qc_ids), sorted(ac_ids), question, sorted(extra_nodes))
325 |
326 | def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part2(data):
327 | qc_ids, ac_ids, question, extra_nodes = data
328 | cid2score = get_LM_score(qc_ids+ac_ids+extra_nodes, question)
329 | return (qc_ids, ac_ids, question, extra_nodes, cid2score)
330 |
331 | def concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3(data):
332 | qc_ids, ac_ids, question, extra_nodes, cid2score = data
333 | schema_graph = qc_ids + ac_ids + sorted(extra_nodes, key=lambda x: -cid2score[x]) #score: from high to low
334 | arange = np.arange(len(schema_graph))
335 | qmask = arange < len(qc_ids)
336 | amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
337 | adj, concepts = concepts2adj(schema_graph)
338 | return {'adj': adj, 'concepts': concepts, 'qmask': qmask, 'amask': amask, 'cid2score': cid2score}
339 |
340 | ################################################################################
341 |
342 |
343 |
344 | #####################################################################################################
345 | # functions below this line will be called by preprocess.py #
346 | #####################################################################################################
347 |
348 |
349 | def generate_graph(grounded_path, pruned_paths_path, cpnet_vocab_path, cpnet_graph_path, output_path):
350 | print(f'generating schema graphs for {grounded_path} and {pruned_paths_path}...')
351 |
352 | global concept2id, id2concept, relation2id, id2relation
353 | if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
354 | load_resources(cpnet_vocab_path)
355 |
356 | global cpnet, cpnet_simple
357 | if cpnet is None or cpnet_simple is None:
358 | load_cpnet(cpnet_graph_path)
359 |
360 | nrow = sum(1 for _ in open(grounded_path, 'r'))
361 | with open(grounded_path, 'r') as fin_gr, \
362 | open(pruned_paths_path, 'r') as fin_pf, \
363 | open(output_path, 'w') as fout:
364 | for line_gr, line_pf in tqdm(zip(fin_gr, fin_pf), total=nrow):
365 | mcp = json.loads(line_gr)
366 | qa_pairs = json.loads(line_pf)
367 |
368 | statement_paths = []
369 | statement_rel_list = []
370 | for qas in qa_pairs:
371 | if qas["pf_res"] is None:
372 | cur_paths = []
373 | cur_rels = []
374 | else:
375 | cur_paths = [item["path"] for item in qas["pf_res"]]
376 | cur_rels = [item["rel"] for item in qas["pf_res"]]
377 | statement_paths.extend(cur_paths)
378 | statement_rel_list.extend(cur_rels)
379 |
380 | qcs = [concept2id[c] for c in mcp["qc"]]
381 | acs = [concept2id[c] for c in mcp["ac"]]
382 |
383 | gobj = plain_graph_generation(qcs=qcs, acs=acs,
384 | paths=statement_paths,
385 | rels=statement_rel_list)
386 | fout.write(json.dumps(gobj) + '\n')
387 |
388 | print(f'schema graphs saved to {output_path}')
389 | print()
390 |
391 |
392 | def generate_adj_matrices(ori_schema_graph_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes, num_rels=34, debug=False):
393 | print(f'generating adjacency matrices for {ori_schema_graph_path} and {cpnet_graph_path}...')
394 |
395 | global concept2id, id2concept, relation2id, id2relation
396 | if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
397 | load_resources(cpnet_vocab_path)
398 |
399 | global cpnet_all
400 | if cpnet_all is None:
401 | cpnet_all = nx.read_gpickle(cpnet_graph_path)
402 |
403 | with open(ori_schema_graph_path, 'r') as fin:
404 | nxg_strs = [line for line in fin]
405 |
406 | if debug:
407 |         nxg_strs = nxg_strs[:1]
408 |
409 | with Pool(num_processes) as p:
410 | res = list(tqdm(p.imap(generate_adj_matrix_per_inst, nxg_strs), total=len(nxg_strs)))
411 |
412 | with open(output_path, 'wb') as fout:
413 | pickle.dump(res, fout)
414 |
415 | print(f'adjacency matrices saved to {output_path}')
416 | print()
417 |
418 |
419 | def generate_adj_data_from_grounded_concepts(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):
420 | """
421 | This function will save
422 |     (1) adjacency matrices (each in the form of a (R*N, N) coo sparse matrix)
423 |     (2) concept ids
424 |     (3) qmask that specifies whether a node is a question concept
425 |     (4) amask that specifies whether a node is an answer concept
426 | to the output path in python pickle format
427 |
428 | grounded_path: str
429 | cpnet_graph_path: str
430 | cpnet_vocab_path: str
431 | output_path: str
432 | num_processes: int
433 | """
434 | print(f'generating adj data for {grounded_path}...')
435 |
436 | global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
437 | if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
438 | load_resources(cpnet_vocab_path)
439 | if cpnet is None or cpnet_simple is None:
440 | load_cpnet(cpnet_graph_path)
441 |
442 | qa_data = []
443 | with open(grounded_path, 'r', encoding='utf-8') as fin:
444 | for line in fin:
445 | dic = json.loads(line)
446 | q_ids = set(concept2id[c] for c in dic['qc'])
447 | a_ids = set(concept2id[c] for c in dic['ac'])
448 | q_ids = q_ids - a_ids
449 | qa_data.append((q_ids, a_ids))
450 |
451 | with Pool(num_processes) as p:
452 | res = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair, qa_data), total=len(qa_data)))
453 |
454 | # res is a list of tuples, each tuple consists of four elements (adj, concepts, qmask, amask)
455 | with open(output_path, 'wb') as fout:
456 | pickle.dump(res, fout)
457 |
458 | print(f'adj data saved to {output_path}')
459 | print()
460 |
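# A minimal sketch (not from the original repo) of how the pickled output of the
# function above can be inspected; the path below is hypothetical:
#
#   import pickle
#   with open('data/csqa/graph/dev.graph.adj.pk', 'rb') as fin:
#       graphs = pickle.load(fin)
#   adj, concepts, qmask, amask = graphs[0]
#   n_node = adj.shape[1]
#   rel, src, tgt = adj.row // n_node, adj.row % n_node, adj.col  # KG triples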
461 |
462 |
463 | def generate_adj_data_from_grounded_concepts__use_LM(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):
464 | """
465 | This function will save
466 |     (1) adjacency matrices (each in the form of a (R*N, N) coo sparse matrix)
467 |     (2) concept ids
468 |     (3) qmask that specifies whether a node is a question concept
469 |     (4) amask that specifies whether a node is an answer concept
470 | (5) cid2score that maps a concept id to its relevance score given the QA context
471 | to the output path in python pickle format
472 |
473 | grounded_path: str
474 | cpnet_graph_path: str
475 | cpnet_vocab_path: str
476 | output_path: str
477 | num_processes: int
478 | """
479 | print(f'generating adj data for {grounded_path}...')
480 |
481 | global concept2id, id2concept, relation2id, id2relation, cpnet_simple, cpnet
482 | if any(x is None for x in [concept2id, id2concept, relation2id, id2relation]):
483 | load_resources(cpnet_vocab_path)
484 | if cpnet is None or cpnet_simple is None:
485 | load_cpnet(cpnet_graph_path)
486 |
487 | qa_data = []
488 | statement_path = grounded_path.replace('grounded', 'statement')
489 | with open(grounded_path, 'r', encoding='utf-8') as fin_ground, open(statement_path, 'r', encoding='utf-8') as fin_state:
490 | lines_ground = fin_ground.readlines()
491 | lines_state = fin_state.readlines()
492 | assert len(lines_ground) % len(lines_state) == 0
493 | n_choices = len(lines_ground) // len(lines_state)
494 | for j, line in enumerate(lines_ground):
495 | dic = json.loads(line)
496 | q_ids = set(concept2id[c] for c in dic['qc'])
497 | a_ids = set(concept2id[c] for c in dic['ac'])
498 | q_ids = q_ids - a_ids
499 | statement_obj = json.loads(lines_state[j//n_choices])
500 | QAcontext = "{} {}.".format(statement_obj['question']['stem'], dic['ans'])
501 | qa_data.append((q_ids, a_ids, QAcontext))
502 |
503 | with Pool(num_processes) as p:
504 | res1 = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part1, qa_data), total=len(qa_data)))
505 |
506 | res2 = []
507 | for j, _data in enumerate(res1):
508 | if j % 100 == 0: print (j)
509 | res2.append(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part2(_data))
510 |
511 | with Pool(num_processes) as p:
512 | res3 = list(tqdm(p.imap(concepts_to_adj_matrices_2hop_all_pair__use_LM__Part3, res2), total=len(res2)))
513 |
514 |     # res3 is a list of dicts with keys 'adj', 'concepts', 'qmask', 'amask', 'cid2score'
515 | with open(output_path, 'wb') as fout:
516 | pickle.dump(res3, fout)
517 |
518 | print(f'adj data saved to {output_path}')
519 | print()
520 |
521 |
522 |
523 | #################### adj to sparse ####################
524 |
525 | def coo_to_normalized_per_inst(data):
526 | adj, concepts, qm, am, max_node_num = data
527 | ori_adj_len = len(concepts)
528 | concepts = torch.tensor(concepts[:min(len(concepts), max_node_num)])
529 | adj_len = len(concepts)
530 | qm = torch.tensor(qm[:adj_len], dtype=torch.uint8)
531 | am = torch.tensor(am[:adj_len], dtype=torch.uint8)
532 | ij = adj.row
533 | k = adj.col
534 | n_node = adj.shape[1]
535 | n_rel = 2 * adj.shape[0] // n_node
536 | i, j = ij // n_node, ij % n_node
537 | mask = (j < max_node_num) & (k < max_node_num)
538 | i, j, k = i[mask], j[mask], k[mask]
539 | i, j, k = np.concatenate((i, i + n_rel // 2), 0), np.concatenate((j, k), 0), np.concatenate((k, j), 0) # add inverse relations
540 | adj_list = []
541 | for r in range(n_rel):
542 | mask = i == r
543 | ones = np.ones(mask.sum(), dtype=np.float32)
544 | A = sparse.csr_matrix((ones, (k[mask], j[mask])), shape=(max_node_num, max_node_num)) # A is transposed by exchanging the order of j and k
545 | adj_list.append(normalize_sparse_adj(A, 'coo'))
546 | adj_list.append(sparse.identity(max_node_num, dtype=np.float32, format='coo'))
547 | return ori_adj_len, adj_len, concepts, adj_list, qm, am
548 |
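# Note on the inverse-relation trick above: every stored triple (r, s, t) also
# contributes (r + n_rel//2, t, s), doubling the relation vocabulary so that
# normalized message passing can also flow against edge direction; the identity
# matrix appended last acts as a self-loop relation.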
549 |
550 | def coo_to_normalized(adj_path, output_path, max_node_num, num_processes):
551 | print(f'converting {adj_path} to normalized adj')
552 |
553 | with open(adj_path, 'rb') as fin:
554 | adj_data = pickle.load(fin)
555 | data = [(adj, concepts, qmask, amask, max_node_num) for adj, concepts, qmask, amask in adj_data]
556 |
557 | ori_adj_lengths = torch.zeros((len(data),), dtype=torch.int64)
558 | adj_lengths = torch.zeros((len(data),), dtype=torch.int64)
559 | concepts_ids = torch.zeros((len(data), max_node_num), dtype=torch.int64)
560 | qmask = torch.zeros((len(data), max_node_num), dtype=torch.uint8)
561 | amask = torch.zeros((len(data), max_node_num), dtype=torch.uint8)
562 |
563 | adj_data = []
564 | with Pool(num_processes) as p:
565 | for i, (ori_adj_len, adj_len, concepts, adj_list, qm, am) in tqdm(enumerate(p.imap(coo_to_normalized_per_inst, data)), total=len(data)):
566 | ori_adj_lengths[i] = ori_adj_len
567 | adj_lengths[i] = adj_len
568 | concepts_ids[i][:adj_len] = concepts
569 | qmask[i][:adj_len] = qm
570 | amask[i][:adj_len] = am
571 | adj_list = [(torch.LongTensor(np.stack((adj.row, adj.col), 0)),
572 | torch.FloatTensor(adj.data)) for adj in adj_list]
573 | adj_data.append(adj_list)
574 |
575 | torch.save((ori_adj_lengths, adj_lengths, concepts_ids, adj_data), output_path)
576 |
577 | print(f'normalized adj saved to {output_path}')
578 | print()
579 |
580 | # if __name__ == '__main__':
581 | # generate_adj_matrices_from_grounded_concepts('./data/csqa/grounded/train.grounded.jsonl',
582 | # './data/cpnet/conceptnet.en.pruned.graph',
583 | # './data/cpnet/concept.txt',
584 | # '/tmp/asdf', 40)
585 |
--------------------------------------------------------------------------------
/modeling/modeling_qagnn.py:
--------------------------------------------------------------------------------
1 | from modeling.modeling_encoder import TextEncoder, MODEL_NAME_TO_CLASS
2 | from utils.data_utils import *
3 | from utils.layers import *
4 | import torch.nn.functional as F
5 |
6 |
7 | class QAGNN_Message_Passing(nn.Module):
8 | def __init__(self, args, k, n_ntype, n_etype, input_size, hidden_size, output_size,
9 | dropout=0.1):
10 | super().__init__()
11 | assert input_size == output_size
12 | self.args = args
13 | self.n_ntype = n_ntype
14 | self.n_etype = n_etype
15 |
16 | assert input_size == hidden_size
17 | self.hidden_size = hidden_size
18 |
19 | self.emb_node_type = nn.Linear(self.n_ntype, hidden_size//2)
20 |
21 | self.basis_f = 'sin' #['id', 'linact', 'sin', 'none']
22 | if self.basis_f in ['id']:
23 | self.emb_score = nn.Linear(1, hidden_size//2)
24 | elif self.basis_f in ['linact']:
25 | self.B_lin = nn.Linear(1, hidden_size//2)
26 | self.emb_score = nn.Linear(hidden_size//2, hidden_size//2)
27 | elif self.basis_f in ['sin']:
28 | self.emb_score = nn.Linear(hidden_size//2, hidden_size//2)
29 |
30 |         self.edge_encoder = torch.nn.Sequential(torch.nn.Linear(n_etype + 1 + n_ntype * 2, hidden_size), torch.nn.BatchNorm1d(hidden_size), torch.nn.ReLU(), torch.nn.Linear(hidden_size, hidden_size))  # edge features: one-hot edge type (+1, presumably for self-loops) concatenated with head/tail node-type one-hots
31 |
32 |
33 | self.k = k
34 | self.gnn_layers = nn.ModuleList([GATConvE(args, hidden_size, n_ntype, n_etype, self.edge_encoder) for _ in range(k)])
35 |
36 |
37 | self.Vh = nn.Linear(input_size, output_size)
38 | self.Vx = nn.Linear(hidden_size, output_size)
39 |
40 | self.activation = GELU()
41 | self.dropout = nn.Dropout(dropout)
42 | self.dropout_rate = dropout
43 |
44 |
45 | def mp_helper(self, _X, edge_index, edge_type, _node_type, _node_feature_extra):
46 | for _ in range(self.k):
47 | _X = self.gnn_layers[_](_X, edge_index, edge_type, _node_type, _node_feature_extra)
48 | _X = self.activation(_X)
49 | _X = F.dropout(_X, self.dropout_rate, training = self.training)
50 | return _X
51 |
52 |
53 | def forward(self, H, A, node_type, node_score, cache_output=False):
54 | """
55 | H: tensor of shape (batch_size, n_node, d_node)
56 | node features from the previous layer
57 | A: (edge_index, edge_type)
58 | node_type: long tensor of shape (batch_size, n_node)
59 | 0 == question entity; 1 == answer choice entity; 2 == other node; 3 == context node
60 | node_score: tensor of shape (batch_size, n_node, 1)
61 | """
62 | _batch_size, _n_nodes = node_type.size()
63 |
64 | #Embed type
65 | T = make_one_hot(node_type.view(-1).contiguous(), self.n_ntype).view(_batch_size, _n_nodes, self.n_ntype)
66 | node_type_emb = self.activation(self.emb_node_type(T)) #[batch_size, n_node, dim/2]
67 |
68 | #Embed score
69 | if self.basis_f == 'sin':
70 | js = torch.arange(self.hidden_size//2).unsqueeze(0).unsqueeze(0).float().to(node_type.device) #[1,1,dim/2]
71 | js = torch.pow(1.1, js) #[1,1,dim/2]
72 | B = torch.sin(js * node_score) #[batch_size, n_node, dim/2]
73 | node_score_emb = self.activation(self.emb_score(B)) #[batch_size, n_node, dim/2]
74 | elif self.basis_f == 'id':
75 | B = node_score
76 | node_score_emb = self.activation(self.emb_score(B)) #[batch_size, n_node, dim/2]
77 | elif self.basis_f == 'linact':
78 | B = self.activation(self.B_lin(node_score)) #[batch_size, n_node, dim/2]
79 | node_score_emb = self.activation(self.emb_score(B)) #[batch_size, n_node, dim/2]
80 |
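        # Reading of the 'sin' basis above: component j of B is sin(1.1**j * score),
        # a geometric ladder of frequencies analogous to sinusoidal positional
        # encodings, so a single scalar relevance score is expanded into dim/2 features.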
81 |
82 | X = H
83 | edge_index, edge_type = A #edge_index: [2, total_E] edge_type: [total_E, ] where total_E is for the batched graph
84 | _X = X.view(-1, X.size(2)).contiguous() #[`total_n_nodes`, d_node] where `total_n_nodes` = b_size * n_node
85 | _node_type = node_type.view(-1).contiguous() #[`total_n_nodes`, ]
86 | _node_feature_extra = torch.cat([node_type_emb, node_score_emb], dim=2).view(_node_type.size(0), -1).contiguous() #[`total_n_nodes`, dim]
87 |
88 | _X = self.mp_helper(_X, edge_index, edge_type, _node_type, _node_feature_extra)
89 |
90 | X = _X.view(node_type.size(0), node_type.size(1), -1) #[batch_size, n_node, dim]
91 |
92 | output = self.activation(self.Vh(H) + self.Vx(X))
93 | output = self.dropout(output)
94 |
95 | return output
96 |
97 |
98 |
99 | class QAGNN(nn.Module):
100 | def __init__(self, args, k, n_ntype, n_etype, sent_dim,
101 | n_concept, concept_dim, concept_in_dim, n_attention_head,
102 | fc_dim, n_fc_layer, p_emb, p_gnn, p_fc,
103 | pretrained_concept_emb=None, freeze_ent_emb=True,
104 | init_range=0.02):
105 | super().__init__()
106 | self.init_range = init_range
107 |
108 | self.concept_emb = CustomizedEmbedding(concept_num=n_concept, concept_out_dim=concept_dim,
109 | use_contextualized=False, concept_in_dim=concept_in_dim,
110 | pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb)
111 | self.svec2nvec = nn.Linear(sent_dim, concept_dim)
112 |
113 | self.concept_dim = concept_dim
114 |
115 | self.activation = GELU()
116 |
117 | self.gnn = QAGNN_Message_Passing(args, k=k, n_ntype=n_ntype, n_etype=n_etype,
118 | input_size=concept_dim, hidden_size=concept_dim, output_size=concept_dim, dropout=p_gnn)
119 |
120 | self.pooler = MultiheadAttPoolLayer(n_attention_head, sent_dim, concept_dim)
121 |
122 | self.fc = MLP(concept_dim + sent_dim + concept_dim, fc_dim, 1, n_fc_layer, p_fc, layer_norm=True)
123 |
124 | self.dropout_e = nn.Dropout(p_emb)
125 | self.dropout_fc = nn.Dropout(p_fc)
126 |
127 | if init_range > 0:
128 | self.apply(self._init_weights)
129 |
130 |
131 | def _init_weights(self, module):
132 | if isinstance(module, (nn.Linear, nn.Embedding)):
133 | module.weight.data.normal_(mean=0.0, std=self.init_range)
134 | if hasattr(module, 'bias') and module.bias is not None:
135 | module.bias.data.zero_()
136 | elif isinstance(module, nn.LayerNorm):
137 | module.bias.data.zero_()
138 | module.weight.data.fill_(1.0)
139 |
140 |
141 | def forward(self, sent_vecs, concept_ids, node_type_ids, node_scores, adj_lengths, adj, emb_data=None, cache_output=False):
142 | """
143 | sent_vecs: (batch_size, dim_sent)
144 | concept_ids: (batch_size, n_node)
145 | adj: edge_index, edge_type
146 | adj_lengths: (batch_size,)
147 | node_type_ids: (batch_size, n_node)
148 | 0 == question entity; 1 == answer choice entity; 2 == other node; 3 == context node
149 | node_scores: (batch_size, n_node, 1)
150 |
151 | returns: (batch_size, 1)
152 | """
153 | gnn_input0 = self.activation(self.svec2nvec(sent_vecs)).unsqueeze(1) #(batch_size, 1, dim_node)
154 | gnn_input1 = self.concept_emb(concept_ids[:, 1:]-1, emb_data) #(batch_size, n_node-1, dim_node)
155 | gnn_input1 = gnn_input1.to(node_type_ids.device)
156 | gnn_input = self.dropout_e(torch.cat([gnn_input0, gnn_input1], dim=1)) #(batch_size, n_node, dim_node)
157 |
158 |
159 |         #Normalize node score (shift relative to the context node Z, then scale by mean magnitude)
160 | _mask = (torch.arange(node_scores.size(1), device=node_scores.device) < adj_lengths.unsqueeze(1)).float() #0 means masked out #[batch_size, n_node]
161 | node_scores = -node_scores
162 | node_scores = node_scores - node_scores[:, 0:1, :] #[batch_size, n_node, 1]
163 | node_scores = node_scores.squeeze(2) #[batch_size, n_node]
164 | node_scores = node_scores * _mask
165 | mean_norm = (torch.abs(node_scores)).sum(dim=1) / adj_lengths #[batch_size, ]
166 | node_scores = node_scores / (mean_norm.unsqueeze(1) + 1e-05) #[batch_size, n_node]
167 | node_scores = node_scores.unsqueeze(2) #[batch_size, n_node, 1]
168 |
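        # Net effect of the block above (added note): node i ends up carrying
        # (score_ctx - score_i) / mean_j |score_ctx - score_j|, i.e. LM relevance
        # measured relative to the context node (index 0) and rescaled so the mean
        # magnitude over each graph's valid nodes is 1.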
169 |
170 | gnn_output = self.gnn(gnn_input, adj, node_type_ids, node_scores)
171 |
172 | Z_vecs = gnn_output[:,0] #(batch_size, dim_node)
173 |
174 | mask = torch.arange(node_type_ids.size(1), device=node_type_ids.device) >= adj_lengths.unsqueeze(1) #1 means masked out
175 |
176 | mask = mask | (node_type_ids == 3) #pool over all KG nodes
177 | mask[mask.all(1), 0] = 0 # a temporary solution to avoid zero node
178 |
179 | sent_vecs_for_pooler = sent_vecs
180 | graph_vecs, pool_attn = self.pooler(sent_vecs_for_pooler, gnn_output, mask)
181 |
182 | if cache_output:
183 | self.concept_ids = concept_ids
184 | self.adj = adj
185 | self.pool_attn = pool_attn
186 |
187 | concat = self.dropout_fc(torch.cat((graph_vecs, sent_vecs, Z_vecs), 1))
188 | logits = self.fc(concat)
189 | return logits, pool_attn
190 |
191 |
192 | class LM_QAGNN(nn.Module):
193 | def __init__(self, args, model_name, k, n_ntype, n_etype,
194 | n_concept, concept_dim, concept_in_dim, n_attention_head,
195 | fc_dim, n_fc_layer, p_emb, p_gnn, p_fc,
196 | pretrained_concept_emb=None, freeze_ent_emb=True,
197 | init_range=0.0, encoder_config={}):
198 | super().__init__()
199 | self.encoder = TextEncoder(model_name, **encoder_config)
200 | self.decoder = QAGNN(args, k, n_ntype, n_etype, self.encoder.sent_dim,
201 | n_concept, concept_dim, concept_in_dim, n_attention_head,
202 | fc_dim, n_fc_layer, p_emb, p_gnn, p_fc,
203 | pretrained_concept_emb=pretrained_concept_emb, freeze_ent_emb=freeze_ent_emb,
204 | init_range=init_range)
205 |
206 |
207 | def forward(self, *inputs, layer_id=-1, cache_output=False, detail=False):
208 | """
209 | sent_vecs: (batch_size, num_choice, d_sent) -> (batch_size * num_choice, d_sent)
210 | concept_ids: (batch_size, num_choice, n_node) -> (batch_size * num_choice, n_node)
211 | node_type_ids: (batch_size, num_choice, n_node) -> (batch_size * num_choice, n_node)
212 | adj_lengths: (batch_size, num_choice) -> (batch_size * num_choice, )
213 | adj -> edge_index, edge_type
214 | edge_index: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(2, E(variable))
215 | -> (2, total E)
216 | edge_type: list of (batch_size, num_choice) -> list of (batch_size * num_choice, ); each entry is torch.tensor(E(variable), )
217 | -> (total E, )
218 | returns: (batch_size, 1)
219 | """
220 | bs, nc = inputs[0].size(0), inputs[0].size(1)
221 |
222 | #Here, merge the batch dimension and the num_choice dimension
223 | edge_index_orig, edge_type_orig = inputs[-2:]
224 | _inputs = [x.view(x.size(0) * x.size(1), *x.size()[2:]) for x in inputs[:-6]] + [x.view(x.size(0) * x.size(1), *x.size()[2:]) for x in inputs[-6:-2]] + [sum(x,[]) for x in inputs[-2:]]
225 |
226 | *lm_inputs, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type = _inputs
227 | edge_index, edge_type = self.batch_graph(edge_index, edge_type, concept_ids.size(1))
228 | adj = (edge_index.to(node_type_ids.device), edge_type.to(node_type_ids.device)) #edge_index: [2, total_E] edge_type: [total_E, ]
229 |
230 | sent_vecs, all_hidden_states = self.encoder(*lm_inputs, layer_id=layer_id)
231 | logits, attn = self.decoder(sent_vecs.to(node_type_ids.device),
232 | concept_ids,
233 | node_type_ids, node_scores, adj_lengths, adj,
234 | emb_data=None, cache_output=cache_output)
235 | logits = logits.view(bs, nc)
236 | if not detail:
237 | return logits, attn
238 | else:
239 | return logits, attn, concept_ids.view(bs, nc, -1), node_type_ids.view(bs, nc, -1), edge_index_orig, edge_type_orig
240 | #edge_index_orig: list of (batch_size, num_choice). each entry is torch.tensor(2, E)
241 | #edge_type_orig: list of (batch_size, num_choice). each entry is torch.tensor(E, )
242 |
243 |
244 | def batch_graph(self, edge_index_init, edge_type_init, n_nodes):
245 | #edge_index_init: list of (n_examples, ). each entry is torch.tensor(2, E)
246 | #edge_type_init: list of (n_examples, ). each entry is torch.tensor(E, )
247 | n_examples = len(edge_index_init)
248 | edge_index = [edge_index_init[_i_] + _i_ * n_nodes for _i_ in range(n_examples)]
249 | edge_index = torch.cat(edge_index, dim=1) #[2, total_E]
250 | edge_type = torch.cat(edge_type_init, dim=0) #[total_E, ]
251 | return edge_index, edge_type
252 |
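    # Example of the offsetting above: with n_nodes=200, node k of example i is
    # relabeled i*200 + k, so the per-example graphs become one disjoint batched
    # graph that the GNN layers can process with a single edge_index tensor.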
253 |
254 |
255 | class LM_QAGNN_DataLoader(object):
256 |
257 | def __init__(self, args, train_statement_path, train_adj_path,
258 | dev_statement_path, dev_adj_path,
259 | test_statement_path, test_adj_path,
260 | batch_size, eval_batch_size, device, model_name, max_node_num=200, max_seq_length=128,
261 | is_inhouse=False, inhouse_train_qids_path=None,
262 | subsample=1.0, use_cache=True):
263 | super().__init__()
264 | self.args = args
265 | self.batch_size = batch_size
266 | self.eval_batch_size = eval_batch_size
267 | self.device0, self.device1 = device
268 | self.is_inhouse = is_inhouse
269 |
270 | model_type = MODEL_NAME_TO_CLASS[model_name]
271 | print ('train_statement_path', train_statement_path)
272 | self.train_qids, self.train_labels, *self.train_encoder_data = load_input_tensors(train_statement_path, model_type, model_name, max_seq_length)
273 | self.dev_qids, self.dev_labels, *self.dev_encoder_data = load_input_tensors(dev_statement_path, model_type, model_name, max_seq_length)
274 |
275 | num_choice = self.train_encoder_data[0].size(1)
276 | self.num_choice = num_choice
277 | print ('num_choice', num_choice)
278 | *self.train_decoder_data, self.train_adj_data = load_sparse_adj_data_with_contextnode(train_adj_path, max_node_num, num_choice, args)
279 |
280 | *self.dev_decoder_data, self.dev_adj_data = load_sparse_adj_data_with_contextnode(dev_adj_path, max_node_num, num_choice, args)
281 | assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data)
282 | assert all(len(self.dev_qids) == len(self.dev_adj_data[0]) == x.size(0) for x in [self.dev_labels] + self.dev_encoder_data + self.dev_decoder_data)
283 |
284 | if test_statement_path is not None:
285 | self.test_qids, self.test_labels, *self.test_encoder_data = load_input_tensors(test_statement_path, model_type, model_name, max_seq_length)
286 | *self.test_decoder_data, self.test_adj_data = load_sparse_adj_data_with_contextnode(test_adj_path, max_node_num, num_choice, args)
287 | assert all(len(self.test_qids) == len(self.test_adj_data[0]) == x.size(0) for x in [self.test_labels] + self.test_encoder_data + self.test_decoder_data)
288 |
289 |
290 | if self.is_inhouse:
291 | with open(inhouse_train_qids_path, 'r') as fin:
292 | inhouse_qids = set(line.strip() for line in fin)
293 | self.inhouse_train_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid in inhouse_qids])
294 | self.inhouse_test_indexes = torch.tensor([i for i, qid in enumerate(self.train_qids) if qid not in inhouse_qids])
295 |
296 | assert 0. < subsample <= 1.
297 | if subsample < 1.:
298 | n_train = int(self.train_size() * subsample)
299 | assert n_train > 0
300 | if self.is_inhouse:
301 | self.inhouse_train_indexes = self.inhouse_train_indexes[:n_train]
302 | else:
303 | self.train_qids = self.train_qids[:n_train]
304 | self.train_labels = self.train_labels[:n_train]
305 | self.train_encoder_data = [x[:n_train] for x in self.train_encoder_data]
306 | self.train_decoder_data = [x[:n_train] for x in self.train_decoder_data]
307 | self.train_adj_data = [x[:n_train] for x in self.train_adj_data] #subsample each of (edge_index, edge_type); slicing the 2-tuple itself would leave the graph data unsubsampled and trip the assert below
308 | assert all(len(self.train_qids) == len(self.train_adj_data[0]) == x.size(0) for x in [self.train_labels] + self.train_encoder_data + self.train_decoder_data)
309 | assert self.train_size() == n_train
310 |
311 | def train_size(self):
312 | return self.inhouse_train_indexes.size(0) if self.is_inhouse else len(self.train_qids)
313 |
314 | def dev_size(self):
315 | return len(self.dev_qids)
316 |
317 | def test_size(self):
318 | if self.is_inhouse:
319 | return self.inhouse_test_indexes.size(0)
320 | else:
321 | return len(self.test_qids) if hasattr(self, 'test_qids') else 0
322 |
323 | def train(self):
324 | if self.is_inhouse:
325 | n_train = self.inhouse_train_indexes.size(0)
326 | train_indexes = self.inhouse_train_indexes[torch.randperm(n_train)]
327 | else:
328 | train_indexes = torch.randperm(len(self.train_qids))
329 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'train', self.device0, self.device1, self.batch_size, train_indexes, self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data)
330 |
331 | def train_eval(self):
332 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.train_qids)), self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data)
333 |
334 | def dev(self):
335 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.dev_qids)), self.dev_qids, self.dev_labels, tensors0=self.dev_encoder_data, tensors1=self.dev_decoder_data, adj_data=self.dev_adj_data)
336 |
337 | def test(self):
338 | if self.is_inhouse:
339 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, self.inhouse_test_indexes, self.train_qids, self.train_labels, tensors0=self.train_encoder_data, tensors1=self.train_decoder_data, adj_data=self.train_adj_data)
340 | else:
341 | return MultiGPUSparseAdjDataBatchGenerator(self.args, 'eval', self.device0, self.device1, self.eval_batch_size, torch.arange(len(self.test_qids)), self.test_qids, self.test_labels, tensors0=self.test_encoder_data, tensors1=self.test_decoder_data, adj_data=self.test_adj_data)
342 |
343 |
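#[Editor's note — illustrative, not part of the original file] With
#is_inhouse=True the official train set is partitioned by qid, and test()
#iterates over the train tensors restricted to inhouse_test_indexes.
#With hypothetical qids:
#   train_qids = ['q0', 'q1', 'q2', 'q3']; inhouse_qids = {'q0', 'q2'}
#   inhouse_train_indexes -> tensor([0, 2]); inhouse_test_indexes -> tensor([1, 3])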
344 |
345 |
346 |
347 | ###############################################################################
348 | ############################### GNN architecture ##############################
349 | ###############################################################################
350 |
351 | from torch.autograd import Variable
352 | def make_one_hot(labels, C):
353 | '''
354 | Converts an integer label torch.autograd.Variable to a one-hot Variable.
355 | labels : torch.autograd.Variable of torch.cuda.LongTensor
356 | (N, ), where N is batch size.
357 | Each value is an integer representing correct classification.
358 | C : integer.
359 | number of classes in labels.
360 | Returns : torch.autograd.Variable of torch.cuda.FloatTensor
361 | N x C, where C is the number of classes. One-hot encoded.
362 | '''
363 | labels = labels.unsqueeze(1)
364 | one_hot = torch.FloatTensor(labels.size(0), C).zero_().to(labels.device)
365 | target = one_hot.scatter_(1, labels.data, 1)
366 | target = Variable(target)
367 | return target
368 |
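#[Editor's note] A minimal modern equivalent of make_one_hot, not part of the
#original file; torch.autograd.Variable has been a no-op wrapper since PyTorch
#0.4, so the same encoding can be written directly (the name below is hypothetical):
def make_one_hot_modern(labels, C):
    #labels: LongTensor of shape (N,); returns FloatTensor of shape (N, C)
    return torch.nn.functional.one_hot(labels, num_classes=C).float()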
369 |
370 |
371 | from torch_geometric.nn import MessagePassing
372 | from torch_geometric.utils import add_self_loops, degree, softmax
373 | from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool, GlobalAttention, Set2Set
374 | import torch.nn.functional as F
375 | from torch_scatter import scatter_add, scatter
376 | from torch_geometric.nn.inits import glorot, zeros
377 |
378 |
379 |
380 | class GATConvE(MessagePassing):
381 | """
382 | Args:
383 | emb_dim (int): dimensionality of GNN hidden states
384 | n_ntype (int): number of node types (e.g. 4)
385 | n_etype (int): number of edge relation types (e.g. 38)
386 | """
387 | def __init__(self, args, emb_dim, n_ntype, n_etype, edge_encoder, head_count=4, aggr="add"):
388 | super(GATConvE, self).__init__(aggr=aggr)
389 | self.args = args
390 |
391 | assert emb_dim % 2 == 0
392 | self.emb_dim = emb_dim
393 |
394 | self.n_ntype = n_ntype; self.n_etype = n_etype
395 | self.edge_encoder = edge_encoder
396 |
397 | #For attention
398 | self.head_count = head_count
399 | assert emb_dim % head_count == 0
400 | self.dim_per_head = emb_dim // head_count
401 | self.linear_key = nn.Linear(3*emb_dim, head_count * self.dim_per_head)
402 | self.linear_msg = nn.Linear(3*emb_dim, head_count * self.dim_per_head)
403 | self.linear_query = nn.Linear(2*emb_dim, head_count * self.dim_per_head)
404 |
405 | self._alpha = None
406 |
407 | #For final MLP
408 | self.mlp = torch.nn.Sequential(torch.nn.Linear(emb_dim, emb_dim), torch.nn.BatchNorm1d(emb_dim), torch.nn.ReLU(), torch.nn.Linear(emb_dim, emb_dim))
409 |
410 |
411 | def forward(self, x, edge_index, edge_type, node_type, node_feature_extra, return_attention_weights=False):
412 | # x: [N, emb_dim]
413 | # edge_index: [2, E]
414 | # edge_type [E,] -> edge_attr: [E, 39] / self_edge_attr: [N, 39]
415 | # node_type [N,] -> headtail_attr [E, 8(=4+4)] / self_headtail_attr: [N, 8]
416 | # node_feature_extra [N, dim]
417 |
418 | #Prepare edge feature
419 | edge_vec = make_one_hot(edge_type, self.n_etype +1) #[E, 39]
420 | self_edge_vec = torch.zeros(x.size(0), self.n_etype +1).to(edge_vec.device)
421 | self_edge_vec[:,self.n_etype] = 1
422 |
423 | head_type = node_type[edge_index[0]] #[E,] #head=src
424 | tail_type = node_type[edge_index[1]] #[E,] #tail=tgt
425 | head_vec = make_one_hot(head_type, self.n_ntype) #[E,4]
426 | tail_vec = make_one_hot(tail_type, self.n_ntype) #[E,4]
427 | headtail_vec = torch.cat([head_vec, tail_vec], dim=1) #[E,8]
428 | self_head_vec = make_one_hot(node_type, self.n_ntype) #[N,4]
429 | self_headtail_vec = torch.cat([self_head_vec, self_head_vec], dim=1) #[N,8]
430 |
431 | edge_vec = torch.cat([edge_vec, self_edge_vec], dim=0) #[E+N, n_etype+1]
432 | headtail_vec = torch.cat([headtail_vec, self_headtail_vec], dim=0) #[E+N, 2*n_ntype]
433 | edge_embeddings = self.edge_encoder(torch.cat([edge_vec, headtail_vec], dim=1)) #[E+N, emb_dim]
434 |
435 | #Add self loops to edge_index
436 | loop_index = torch.arange(0, x.size(0), dtype=torch.long, device=edge_index.device)
437 | loop_index = loop_index.unsqueeze(0).repeat(2, 1)
438 | edge_index = torch.cat([edge_index, loop_index], dim=1) #[2, E+N]
439 |
440 | x = torch.cat([x, node_feature_extra], dim=1)
441 | x = (x, x)
442 | aggr_out = self.propagate(edge_index, x=x, edge_attr=edge_embeddings) #[N, emb_dim]
443 | out = self.mlp(aggr_out)
444 |
445 | alpha = self._alpha
446 | self._alpha = None
447 |
448 | if return_attention_weights:
449 | assert alpha is not None
450 | return out, (edge_index, alpha)
451 | else:
452 | return out
453 |
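#[Editor's note — illustrative, not part of the original file] The self-loops
#added above use the reserved extra relation id n_etype: with n_etype=38 the
#one-hot edge feature has 39 slots and self_edge_vec[:, 38] == 1 for every node,
#while a self-loop's head/tail type features duplicate the node's own type.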
454 |
455 | def message(self, edge_index, x_i, x_j, edge_attr): #i: tgt, j:src
456 | # print ("edge_attr.size()", edge_attr.size()) #[E, emb_dim]
457 | # print ("x_j.size()", x_j.size()) #[E, emb_dim]
458 | # print ("x_i.size()", x_i.size()) #[E, emb_dim]
459 | assert len(edge_attr.size()) == 2
460 | assert edge_attr.size(1) == self.emb_dim
461 | assert x_i.size(1) == x_j.size(1) == 2*self.emb_dim
462 | assert x_i.size(0) == x_j.size(0) == edge_attr.size(0) == edge_index.size(1)
463 |
464 | key = self.linear_key(torch.cat([x_i, edge_attr], dim=1)).view(-1, self.head_count, self.dim_per_head) #[E, heads, _dim]
465 | msg = self.linear_msg(torch.cat([x_j, edge_attr], dim=1)).view(-1, self.head_count, self.dim_per_head) #[E, heads, _dim]
466 | query = self.linear_query(x_j).view(-1, self.head_count, self.dim_per_head) #[E, heads, _dim]
467 |
468 |
469 | query = query / math.sqrt(self.dim_per_head)
470 | scores = (query * key).sum(dim=2) #[E, heads]
471 | src_node_index = edge_index[0] #[E,]
472 | alpha = softmax(scores, src_node_index) #[E, heads] #group by src side node
473 | self._alpha = alpha
474 |
475 | #adjust by outgoing degree of src
476 | E = edge_index.size(1) #n_edges
477 | N = int(src_node_index.max()) + 1 #n_nodes
478 | ones = torch.full((E,), 1.0, dtype=torch.float).to(edge_index.device)
479 | src_node_edge_count = scatter(ones, src_node_index, dim=0, dim_size=N, reduce='sum')[src_node_index] #[E,]
480 | assert len(src_node_edge_count.size()) == 1 and len(src_node_edge_count) == E
481 | alpha = alpha * src_node_edge_count.unsqueeze(1) #[E, heads]
482 |
483 | out = msg * alpha.view(-1, self.head_count, 1) #[E, heads, _dim]
484 | return out.view(-1, self.head_count * self.dim_per_head) #[E, emb_dim]
485 |
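#[Editor's note — illustrative sketch, not part of the original file] The
#scatter softmax in message() normalizes attention scores within each
#source-node group, e.g.:
#   from torch_geometric.utils import softmax
#   softmax(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0, 0, 1]))
#   # -> first two entries sum to 1 (node 0's two edges); last entry is 1.0
#The later multiplication by src_node_edge_count then rescales each group by
#its outgoing degree (see the '#adjust by outgoing degree of src' comment above).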
--------------------------------------------------------------------------------
/utils/data_utils.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import os
3 | import numpy as np
4 | import torch
5 | from transformers import (OpenAIGPTTokenizer, BertTokenizer, XLNetTokenizer, RobertaTokenizer, AutoTokenizer)
6 | try:
7 | from transformers import AlbertTokenizer
8 | except ImportError:
9 | pass
10 |
11 | import json
12 | from tqdm import tqdm
13 |
14 | GPT_SPECIAL_TOKENS = ['_start_', '_delimiter_', '_classify_']
15 |
16 |
17 | class MultiGPUSparseAdjDataBatchGenerator(object):
18 | def __init__(self, args, mode, device0, device1, batch_size, indexes, qids, labels,
19 | tensors0=[], lists0=[], tensors1=[], lists1=[], adj_data=None):
20 | self.args = args
21 | self.mode = mode
22 | self.device0 = device0
23 | self.device1 = device1
24 | self.batch_size = batch_size
25 | self.indexes = indexes
26 | self.qids = qids
27 | self.labels = labels
28 | self.tensors0 = tensors0
29 | self.lists0 = lists0
30 | self.tensors1 = tensors1
31 | self.lists1 = lists1
32 | # self.adj_empty = adj_empty.to(self.device1)
33 | self.adj_data = adj_data
34 |
35 | def __len__(self):
36 | return (self.indexes.size(0) - 1) // self.batch_size + 1
37 |
38 | def __iter__(self):
39 | bs = self.batch_size
40 | n = self.indexes.size(0)
41 | if self.mode=='train' and self.args.drop_partial_batch:
42 | print ('dropping partial batch')
43 | n = (n//bs) *bs
44 | elif self.mode=='train' and self.args.fill_partial_batch:
45 | print ('filling partial batch')
46 | remain = n % bs
47 | if remain > 0:
48 | extra = np.random.choice(self.indexes[:-remain], size=(bs-remain), replace=False)
49 | self.indexes = torch.cat([self.indexes, torch.tensor(extra)])
50 | n = self.indexes.size(0)
51 | assert n % bs == 0
52 |
53 | for a in range(0, n, bs):
54 | b = min(n, a + bs)
55 | batch_indexes = self.indexes[a:b]
56 | batch_qids = [self.qids[idx] for idx in batch_indexes]
57 | batch_labels = self._to_device(self.labels[batch_indexes], self.device1)
58 | batch_tensors0 = [self._to_device(x[batch_indexes], self.device0) for x in self.tensors0]
59 | batch_tensors1 = [self._to_device(x[batch_indexes], self.device1) for x in self.tensors1]
60 | batch_lists0 = [self._to_device([x[i] for i in batch_indexes], self.device0) for x in self.lists0]
61 | batch_lists1 = [self._to_device([x[i] for i in batch_indexes], self.device1) for x in self.lists1]
62 |
63 |
64 | edge_index_all, edge_type_all = self.adj_data
65 | #edge_index_all: nested list of shape (n_samples, num_choice), where each entry is tensor[2, E]
66 | #edge_type_all: nested list of shape (n_samples, num_choice), where each entry is tensor[E, ]
67 | edge_index = self._to_device([edge_index_all[i] for i in batch_indexes], self.device1)
68 | edge_type = self._to_device([edge_type_all[i] for i in batch_indexes], self.device1)
69 |
70 | yield tuple([batch_qids, batch_labels, *batch_tensors0, *batch_lists0, *batch_tensors1, *batch_lists1, edge_index, edge_type])
71 |
72 | def _to_device(self, obj, device):
73 | if isinstance(obj, (tuple, list)):
74 | return [self._to_device(item, device) for item in obj]
75 | else:
76 | return obj.to(device)
77 |
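#[Editor's note — illustrative, not part of the original file] How the tuples
#yielded above unpack in a training loop ('generator' is an instance of the
#class above; 'model' is a hypothetical LM_QAGNN-style callable):
def _editor_sketch_iterate_batches(generator, model):
    for qids, labels, *input_data in generator:
        logits, attn = model(*input_data)  #input_data ends with edge_index, edge_type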
78 |
79 | def load_sparse_adj_data_with_contextnode(adj_pk_path, max_node_num, num_choice, args):
80 | cache_path = adj_pk_path +'.loaded_cache'
81 | use_cache = True
82 |
83 | if use_cache and not os.path.exists(cache_path):
84 | use_cache = False
85 |
86 | if use_cache:
87 | with open(cache_path, 'rb') as f:
88 | adj_lengths_ori, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type, half_n_rel = pickle.load(f)
89 | else:
90 | with open(adj_pk_path, 'rb') as fin:
91 | adj_concept_pairs = pickle.load(fin)
92 |
93 | n_samples = len(adj_concept_pairs) #this is actually n_questions x n_choices
94 | edge_index, edge_type = [], []
95 | adj_lengths = torch.zeros((n_samples,), dtype=torch.long)
96 | concept_ids = torch.full((n_samples, max_node_num), 1, dtype=torch.long)
97 | node_type_ids = torch.full((n_samples, max_node_num), 2, dtype=torch.long) #default 2: "other node"
98 | node_scores = torch.zeros((n_samples, max_node_num, 1), dtype=torch.float)
99 |
100 | adj_lengths_ori = adj_lengths.clone()
101 | for idx, _data in tqdm(enumerate(adj_concept_pairs), total=n_samples, desc='loading adj matrices'):
102 | adj, concepts, qm, am, cid2score = _data['adj'], _data['concepts'], _data['qmask'], _data['amask'], _data['cid2score']
103 | #adj: e.g. <4233x249 (n_nodes*half_n_rels x n_nodes) sparse matrix of type '<class 'numpy.bool_'>' with 2905 stored elements in COOrdinate format>
104 | #concepts: np.array(num_nodes, ), where entry is concept id
105 | #qm: np.array(num_nodes, ), where entry is True/False
106 | #am: np.array(num_nodes, ), where entry is True/False
107 | assert len(concepts) == len(set(concepts))
108 | qam = qm | am
109 | #sanity check: should be T,..,T,F,F,..F
110 | assert qam[0]
111 | F_start = False
112 | for TF in qam:
113 | if not TF:
114 | F_start = True
115 | else:
116 | assert not F_start
117 | num_concept = min(len(concepts), max_node_num-1) + 1 #this is the final number of nodes including contextnode but excluding PAD
118 | adj_lengths_ori[idx] = len(concepts)
119 | adj_lengths[idx] = num_concept
120 |
121 | #Prepare nodes
122 | concepts = concepts[:num_concept-1]
123 | concept_ids[idx, 1:num_concept] = torch.tensor(concepts +1) #To accommodate the contextnode, original concept_ids are incremented by 1
124 | concept_ids[idx, 0] = 0 #this is the "concept_id" for contextnode
125 |
126 | #Prepare node scores
127 | if (cid2score is not None):
128 | for _j_ in range(num_concept):
129 | _cid = int(concept_ids[idx, _j_]) - 1
130 | assert _cid in cid2score
131 | node_scores[idx, _j_, 0] = torch.tensor(cid2score[_cid])
132 |
133 | #Prepare node types
134 | node_type_ids[idx, 0] = 3 #contextnode
135 | node_type_ids[idx, 1:num_concept][torch.tensor(qm, dtype=torch.bool)[:num_concept-1]] = 0
136 | node_type_ids[idx, 1:num_concept][torch.tensor(am, dtype=torch.bool)[:num_concept-1]] = 1
137 |
138 | #Load adj
139 | ij = torch.tensor(adj.row, dtype=torch.int64) #(num_matrix_entries, ), where each entry is coordinate
140 | k = torch.tensor(adj.col, dtype=torch.int64) #(num_matrix_entries, ), where each entry is coordinate
141 | n_node = adj.shape[1]
142 | half_n_rel = adj.shape[0] // n_node
143 | i, j = ij // n_node, ij % n_node
144 |
145 | #Prepare edges
146 | i += 2; j += 1; k += 1 # **** increment coordinate by 1, rel_id by 2 ****
147 | extra_i, extra_j, extra_k = [], [], []
148 | for _coord, q_tf in enumerate(qm):
149 | _new_coord = _coord + 1
150 | if _new_coord > num_concept:
151 | break
152 | if q_tf:
153 | extra_i.append(0) #rel from contextnode to question concept
154 | extra_j.append(0) #contextnode coordinate
155 | extra_k.append(_new_coord) #question concept coordinate
156 | for _coord, a_tf in enumerate(am):
157 | _new_coord = _coord + 1
158 | if _new_coord > num_concept:
159 | break
160 | if a_tf:
161 | extra_i.append(1) #rel from contextnode to answer concept
162 | extra_j.append(0) #contextnode coordinate
163 | extra_k.append(_new_coord) #answer concept coordinate
164 |
165 | half_n_rel += 2 #should be 19 now
166 | if len(extra_i) > 0:
167 | i = torch.cat([i, torch.tensor(extra_i)], dim=0)
168 | j = torch.cat([j, torch.tensor(extra_j)], dim=0)
169 | k = torch.cat([k, torch.tensor(extra_k)], dim=0)
170 | ########################
171 |
172 | mask = (j < max_node_num) & (k < max_node_num)
173 | i, j, k = i[mask], j[mask], k[mask]
174 | i, j, k = torch.cat((i, i + half_n_rel), 0), torch.cat((j, k), 0), torch.cat((k, j), 0) # add inverse relations
175 | edge_index.append(torch.stack([j,k], dim=0)) #each entry is [2, E]
176 | edge_type.append(i) #each entry is [E, ]
177 |
178 | with open(cache_path, 'wb') as f:
179 | pickle.dump([adj_lengths_ori, concept_ids, node_type_ids, node_scores, adj_lengths, edge_index, edge_type, half_n_rel], f)
180 |
181 |
182 | ori_adj_mean = adj_lengths_ori.float().mean().item()
183 | ori_adj_sigma = np.sqrt(((adj_lengths_ori.float() - ori_adj_mean)**2).mean().item())
184 | print('| ori_adj_len: mu {:.2f} sigma {:.2f} | adj_len: {:.2f} |'.format(ori_adj_mean, ori_adj_sigma, adj_lengths.float().mean().item()) +
185 | ' prune_rate: {:.2f} |'.format((adj_lengths_ori > adj_lengths).float().mean().item()) +
186 | ' qc_num: {:.2f} | ac_num: {:.2f} |'.format((node_type_ids == 0).float().sum(1).mean().item(),
187 | (node_type_ids == 1).float().sum(1).mean().item()))
188 |
189 | edge_index = list(map(list, zip(*(iter(edge_index),) * num_choice))) #list of size (n_questions, n_choices), where each entry is tensor[2, E] #this operation corresponds to .view(n_questions, n_choices)
190 | edge_type = list(map(list, zip(*(iter(edge_type),) * num_choice))) #list of size (n_questions, n_choices), where each entry is tensor[E, ]
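#(editor's note — illustrative, not part of the original file) the zip-chunking
#idiom above groups a flat list into rows of num_choice, e.g. with num_choice=3:
#   list(map(list, zip(*(iter([0, 1, 2, 3, 4, 5]),) * 3))) == [[0, 1, 2], [3, 4, 5]]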
191 |
192 | concept_ids, node_type_ids, node_scores, adj_lengths = [x.view(-1, num_choice, *x.size()[1:]) for x in (concept_ids, node_type_ids, node_scores, adj_lengths)]
193 | #concept_ids: (n_questions, num_choice, max_node_num)
194 | #node_type_ids: (n_questions, num_choice, max_node_num)
195 | #node_scores: (n_questions, num_choice, max_node_num)
196 | #adj_lengths: (n_questions, num_choice)
197 | return concept_ids, node_type_ids, node_scores, adj_lengths, (edge_index, edge_type) #, half_n_rel * 2 + 1
198 |
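#[Editor's note — illustrative sketch, not part of the original file; the
#function name and toy values below are hypothetical] The inverse-relation step
#above mirrors every (rel, head, tail) triple as (rel + half_n_rel, tail, head):
def _editor_sketch_inverse_relations():
    i, j, k = torch.tensor([2]), torch.tensor([0]), torch.tensor([4])
    half_n_rel = 19
    i, j, k = torch.cat((i, i + half_n_rel), 0), torch.cat((j, k), 0), torch.cat((k, j), 0)
    assert i.tolist() == [2, 21] and j.tolist() == [0, 4] and k.tolist() == [4, 0]
    return torch.stack([j, k], dim=0), i  #edge_index [2, E], edge_type [E,]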
199 |
200 |
201 |
202 |
203 | def load_gpt_input_tensors(statement_jsonl_path, max_seq_length):
204 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
205 | """Truncates a sequence pair in place to the maximum length."""
206 | while True:
207 | total_length = len(tokens_a) + len(tokens_b)
208 | if total_length <= max_length:
209 | break
210 | if len(tokens_a) > len(tokens_b):
211 | tokens_a.pop()
212 | else:
213 | tokens_b.pop()
214 |
215 | def load_qa_dataset(dataset_path):
216 | """ Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
217 | with open(dataset_path, "r", encoding="utf-8") as fin:
218 | output = []
219 | for line in fin:
220 | input_json = json.loads(line)
221 | label = ord(input_json.get("answerKey", "A")) - ord("A")
222 | output.append((input_json['id'], input_json["question"]["stem"], *[ending["text"] for ending in input_json["question"]["choices"]], label))
223 | return output
224 |
225 | def pre_process_datasets(encoded_datasets, num_choices, max_seq_length, start_token, delimiter_token, clf_token):
226 | """ Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
227 |
228 | To Transformer inputs of shape (n_batch, n_alternative, length) comprising for each batch, continuation:
229 | input_ids[batch, alternative, :] = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
230 | """
231 | tensor_datasets = []
232 | for dataset in encoded_datasets:
233 | n_batch = len(dataset)
234 | input_ids = np.zeros((n_batch, num_choices, max_seq_length), dtype=np.int64)
235 | mc_token_ids = np.zeros((n_batch, num_choices), dtype=np.int64)
236 | lm_labels = np.full((n_batch, num_choices, max_seq_length), fill_value=-1, dtype=np.int64)
237 | mc_labels = np.zeros((n_batch,), dtype=np.int64)
238 | for i, data in enumerate(dataset):
239 | q, mc_label = data[0], data[-1]
240 | choices = data[1:-1]
241 | for j in range(len(choices)):
242 | _truncate_seq_pair(q, choices[j], max_seq_length - 3)
243 | qa = [start_token] + q + [delimiter_token] + choices[j] + [clf_token]
244 | input_ids[i, j, :len(qa)] = qa
245 | mc_token_ids[i, j] = len(qa) - 1
246 | lm_labels[i, j, :len(qa) - 1] = qa[1:]
247 | mc_labels[i] = mc_label
248 | all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
249 | tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
250 | return tensor_datasets
251 |
252 | def tokenize_and_encode(tokenizer, obj):
253 | """ Tokenize and encode a nested object """
254 | if isinstance(obj, str):
255 | return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
256 | elif isinstance(obj, int):
257 | return obj
258 | else:
259 | return list(tokenize_and_encode(tokenizer, o) for o in obj)
260 |
261 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
262 | tokenizer.add_tokens(GPT_SPECIAL_TOKENS)
263 | special_tokens_ids = tokenizer.convert_tokens_to_ids(GPT_SPECIAL_TOKENS)
264 |
265 | dataset = load_qa_dataset(statement_jsonl_path)
266 | examples_ids = [data[0] for data in dataset]
267 | dataset = [data[1:] for data in dataset] # discard example ids
268 | num_choices = len(dataset[0]) - 2
269 |
270 | encoded_dataset = tokenize_and_encode(tokenizer, dataset)
271 |
272 | (input_ids, mc_token_ids, lm_labels, mc_labels), = pre_process_datasets([encoded_dataset], num_choices, max_seq_length, *special_tokens_ids)
273 | return examples_ids, mc_labels, input_ids, mc_token_ids, lm_labels
274 |
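#[Editor's note — hypothetical helper, not part of the original file] It
#reproduces the per-choice GPT layout built in pre_process_datasets above;
#mc_token_ids marks the final _classify_ token, whose hidden state scores the choice:
def _editor_sketch_gpt_layout(q, choice, start_token, delimiter_token, clf_token):
    qa = [start_token] + q + [delimiter_token] + choice + [clf_token]
    return qa, len(qa) - 1  #(input ids, mc_token_id)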
275 |
276 | def get_gpt_token_num():
277 | tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
278 | tokenizer.add_tokens(GPT_SPECIAL_TOKENS)
279 | return len(tokenizer)
280 |
281 |
282 |
283 | def load_bert_xlnet_roberta_input_tensors(statement_jsonl_path, model_type, model_name, max_seq_length):
284 | class InputExample(object):
285 |
286 | def __init__(self, example_id, question, contexts, endings, label=None):
287 | self.example_id = example_id
288 | self.question = question
289 | self.contexts = contexts
290 | self.endings = endings
291 | self.label = label
292 |
293 | class InputFeatures(object):
294 |
295 | def __init__(self, example_id, choices_features, label):
296 | self.example_id = example_id
297 | self.choices_features = [
298 | {
299 | 'input_ids': input_ids,
300 | 'input_mask': input_mask,
301 | 'segment_ids': segment_ids,
302 | 'output_mask': output_mask,
303 | }
304 | for _, input_ids, input_mask, segment_ids, output_mask in choices_features
305 | ]
306 | self.label = label
307 |
308 | def read_examples(input_file):
309 | with open(input_file, "r", encoding="utf-8") as f:
310 | examples = []
311 | for line in f.readlines():
312 | json_dic = json.loads(line)
313 | label = ord(json_dic["answerKey"]) - ord("A") if 'answerKey' in json_dic else 0
314 | contexts = json_dic["question"]["stem"]
315 | if "para" in json_dic:
316 | contexts = json_dic["para"] + " " + contexts
317 | if "fact1" in json_dic:
318 | contexts = json_dic["fact1"] + " " + contexts
319 | examples.append(
320 | InputExample(
321 | example_id=json_dic["id"],
322 | contexts=[contexts] * len(json_dic["question"]["choices"]),
323 | question="",
324 | endings=[ending["text"] for ending in json_dic["question"]["choices"]],
325 | label=label
326 | ))
327 | return examples
328 |
329 | def convert_examples_to_features(examples, label_list, max_seq_length,
330 | tokenizer,
331 | cls_token_at_end=False,
332 | cls_token='[CLS]',
333 | cls_token_segment_id=1,
334 | sep_token='[SEP]',
335 | sequence_a_segment_id=0,
336 | sequence_b_segment_id=1,
337 | sep_token_extra=False,
338 | pad_token_segment_id=0,
339 | pad_on_left=False,
340 | pad_token=0,
341 | mask_padding_with_zero=True):
342 | """ Loads a data file into a list of `InputBatch`s
343 | `cls_token_at_end` define the location of the CLS token:
344 | - False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
345 | - True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
346 | `cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
347 | """
348 | label_map = {label: i for i, label in enumerate(label_list)}
349 |
350 | features = []
351 | for ex_index, example in enumerate(tqdm(examples)):
352 | choices_features = []
353 | for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
354 | tokens_a = tokenizer.tokenize(context)
355 | tokens_b = tokenizer.tokenize(example.question + " " + ending)
356 |
357 | special_tokens_count = 4 if sep_token_extra else 3
358 | _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
359 |
360 | # The convention in BERT is:
361 | # (a) For sequence pairs:
362 | # tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
363 | # type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
364 | # (b) For single sequences:
365 | # tokens: [CLS] the dog is hairy . [SEP]
366 | # type_ids: 0 0 0 0 0 0 0
367 | #
368 | # Where "type_ids" are used to indicate whether this is the first
369 | # sequence or the second sequence. The embedding vectors for `type=0` and
370 | # `type=1` were learned during pre-training and are added to the wordpiece
371 | # embedding vector (and position vector). This is not *strictly* necessary
372 | # since the [SEP] token unambiguously separates the sequences, but it makes
373 | # it easier for the model to learn the concept of sequences.
374 | #
375 | # For classification tasks, the first vector (corresponding to [CLS]) is
376 | # used as the "sentence vector". Note that this only makes sense because
377 | # the entire model is fine-tuned.
378 | tokens = tokens_a + [sep_token]
379 | if sep_token_extra:
380 | # roberta uses an extra separator b/w pairs of sentences
381 | tokens += [sep_token]
382 |
383 | segment_ids = [sequence_a_segment_id] * len(tokens)
384 |
385 | if tokens_b:
386 | tokens += tokens_b + [sep_token]
387 | segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
388 |
389 | if cls_token_at_end:
390 | tokens = tokens + [cls_token]
391 | segment_ids = segment_ids + [cls_token_segment_id]
392 | else:
393 | tokens = [cls_token] + tokens
394 | segment_ids = [cls_token_segment_id] + segment_ids
395 |
396 | input_ids = tokenizer.convert_tokens_to_ids(tokens)
397 |
398 | # The mask has 1 for real tokens and 0 for padding tokens. Only real
399 | # tokens are attended to.
400 |
401 | input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
402 | special_token_id = tokenizer.convert_tokens_to_ids([cls_token, sep_token])
403 | output_mask = [1 if id in special_token_id else 0 for id in input_ids] # 1 for mask
404 |
405 | # Zero-pad up to the sequence length.
406 | padding_length = max_seq_length - len(input_ids)
407 | if pad_on_left:
408 | input_ids = ([pad_token] * padding_length) + input_ids
409 | input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
410 | output_mask = ([1] * padding_length) + output_mask
411 |
412 | segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
413 | else:
414 | input_ids = input_ids + ([pad_token] * padding_length)
415 | input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
416 | output_mask = output_mask + ([1] * padding_length)
417 | segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
418 |
419 | assert len(input_ids) == max_seq_length
420 | assert len(output_mask) == max_seq_length
421 | assert len(input_mask) == max_seq_length
422 | assert len(segment_ids) == max_seq_length
423 | choices_features.append((tokens, input_ids, input_mask, segment_ids, output_mask))
424 | label = label_map[example.label]
425 | features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label))
426 |
427 | return features
428 |
429 | def _truncate_seq_pair(tokens_a, tokens_b, max_length):
430 | """Truncates a sequence pair in place to the maximum length."""
431 |
432 | # This is a simple heuristic which will always truncate the longer sequence
433 | # one token at a time. This makes more sense than truncating an equal percent
434 | # of tokens from each, since if one sequence is very short then each token
435 | # that's truncated likely contains more information than one truncated from the longer sequence.
436 | while True:
437 | total_length = len(tokens_a) + len(tokens_b)
438 | if total_length <= max_length:
439 | break
440 | if len(tokens_a) > len(tokens_b):
441 | tokens_a.pop()
442 | else:
443 | tokens_b.pop()
444 |
445 | def select_field(features, field):
446 | return [[choice[field] for choice in feature.choices_features] for feature in features]
447 |
448 | def convert_features_to_tensors(features):
449 | all_input_ids = torch.tensor(select_field(features, 'input_ids'), dtype=torch.long)
450 | all_input_mask = torch.tensor(select_field(features, 'input_mask'), dtype=torch.long)
451 | all_segment_ids = torch.tensor(select_field(features, 'segment_ids'), dtype=torch.long)
452 | all_output_mask = torch.tensor(select_field(features, 'output_mask'), dtype=torch.bool)
453 | all_label = torch.tensor([f.label for f in features], dtype=torch.long)
454 | return all_input_ids, all_input_mask, all_segment_ids, all_output_mask, all_label
455 |
456 | # try:
457 | # tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer, 'albert': AlbertTokenizer}.get(model_type)
458 | # except:
459 | # tokenizer_class = {'bert': BertTokenizer, 'xlnet': XLNetTokenizer, 'roberta': RobertaTokenizer}.get(model_type)
460 | tokenizer_class = AutoTokenizer
461 | tokenizer = tokenizer_class.from_pretrained(model_name)
462 | examples = read_examples(statement_jsonl_path)
463 | features = convert_examples_to_features(examples, list(range(len(examples[0].endings))), max_seq_length, tokenizer,
464 | cls_token_at_end=bool(model_type in ['xlnet']), # xlnet has a cls token at the end
465 | cls_token=tokenizer.cls_token,
466 | sep_token=tokenizer.sep_token,
467 | sep_token_extra=bool(model_type in ['roberta', 'albert']),
468 | cls_token_segment_id=2 if model_type in ['xlnet'] else 0,
469 | pad_on_left=bool(model_type in ['xlnet']), # pad on the left for xlnet
470 | pad_token_segment_id=4 if model_type in ['xlnet'] else 0,
471 | sequence_b_segment_id=0 if model_type in ['roberta', 'albert'] else 1)
472 | example_ids = [f.example_id for f in features]
473 | *data_tensors, all_label = convert_features_to_tensors(features)
474 | return (example_ids, all_label, *data_tensors)
475 |
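#[Editor's note — illustrative sketch, not part of the original file]
#_truncate_seq_pair trims the longer sequence one token at a time, so a long
#context shrinks before a short ending does:
def _editor_sketch_truncate_longest_first():
    a, b = list(range(10)), list(range(3))
    while len(a) + len(b) > 8:
        (a if len(a) > len(b) else b).pop()
    assert (len(a), len(b)) == (5, 3)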
476 |
477 |
478 | def load_input_tensors(input_jsonl_path, model_type, model_name, max_seq_length):
479 | if model_type in ('lstm',):
480 | raise NotImplementedError
481 | elif model_type in ('gpt',):
482 | return load_gpt_input_tensors(input_jsonl_path, max_seq_length)
483 | elif model_type in ('bert', 'xlnet', 'roberta', 'albert'):
484 | return load_bert_xlnet_roberta_input_tensors(input_jsonl_path, model_type, model_name, max_seq_length)
485 |
486 |
487 | def load_info(statement_path: str):
488 | n = sum(1 for _ in open(statement_path, "r"))
489 | num_choice = None
490 | with open(statement_path, "r", encoding="utf-8") as fin:
491 | ids = []
492 | labels = []
493 | for line in fin:
494 | input_json = json.loads(line)
495 | labels.append(ord(input_json.get("answerKey", "A")) - ord("A"))
496 | ids.append(input_json['id'])
497 | if num_choice is None:
498 | num_choice = len(input_json["question"]["choices"])
499 | labels = torch.tensor(labels, dtype=torch.long)
500 |
501 | return ids, labels, num_choice
502 |
503 |
504 | def load_statement_dict(statement_path):
505 | all_dict = {}
506 | with open(statement_path, 'r', encoding='utf-8') as fin:
507 | for line in fin:
508 | instance_dict = json.loads(line)
509 | qid = instance_dict['id']
510 | all_dict[qid] = {
511 | 'question': instance_dict['question']['stem'],
512 | 'answers': [dic['text'] for dic in instance_dict['question']['choices']]
513 | }
514 | return all_dict
515 |
--------------------------------------------------------------------------------