├── .gitignore ├── LICENSE ├── README.md └── hiexpl ├── __init__.py ├── algo ├── __init__.py ├── cd_func.py ├── scd_func.py ├── scd_lstm.py ├── scd_transformer.py ├── soc_lstm.py └── soc_transformer.py ├── bert ├── __init__.py ├── __main__.py ├── convert_gpt2_checkpoint_to_pytorch.py ├── convert_openai_checkpoint_to_pytorch.py ├── convert_tf_checkpoint_to_pytorch.py ├── convert_transfo_xl_checkpoint_to_pytorch.py ├── decomp_util.py ├── file_utils.py ├── filter_sentence.py ├── global_state.py ├── modeling.py ├── modeling_gpt2.py ├── modeling_openai.py ├── modeling_transfo_xl.py ├── modeling_transfo_xl_utilities.py ├── optimization.py ├── optimization_openai.py ├── run_classifier.py ├── run_lm_finetuning.py ├── tacred_f1.py ├── tokenization.py ├── tokenization_gpt2.py ├── tokenization_openai.py └── tokenization_transfo_xl.py ├── eval_explanations.py ├── explain.py ├── lm ├── __init__.py ├── lm.py └── lm_train.py ├── nns ├── __init__.py ├── layers.py ├── linear_model.py └── model.py ├── scripts ├── explanations │ └── explain_sst_lstm.sh └── train_model │ └── train_sst_lstm.sh ├── train.py ├── utils ├── __init__.py ├── agglomeration.py ├── args.py ├── parser.py ├── reader.py └── tacred_f1.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vector_cache/ 2 | *.pyc 3 | .data/ 4 | .ipynb_checkpoints/ 5 | .idea/ 6 | data/ 7 | ~$* 8 | *.txt 9 | *.ev 10 | figs/ 11 | .xlsx 12 | sst*/ 13 | results/ 14 | cd_results/ 15 | cd_sample/ 16 | cd_bert/ 17 | results_bert/ 18 | examples/ 19 | glue_data/ 20 | turk/ 21 | 22 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 INK Lab @ USC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Explanations for Neural Sequence Model Predictions 2 | 3 | This repo include implementation of SOC and SCD algorithm, scripts for visualization and evaluation. 4 | Paper: [Towards Hierarchical Importance Attribution: Explaining Compositional Semantics for Neural Sequence Models](https://openreview.net/pdf?id=BkxRRkSKwr), ICLR 2020. 
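As a quick orientation: the SOC score of a phrase is, roughly, the expected drop in the model's prediction score when the phrase is masked out, averaged over contexts in which the words surrounding the phrase are resampled from a language model. The sketch below only illustrates that idea under those assumptions; the helper names (`lm_sampler`, `classifier_logit`, `pad`, `soc_score`) are hypothetical placeholders rather than this repo's API — see `hiexpl/algo/soc_lstm.py` and `explain.py` for the actual implementation.

```python
# Conceptual sketch of the SOC (sampling-and-occlusion) phrase score.
# All helpers are hypothetical placeholders, not functions from this repo.
import numpy as np

def soc_score(tokens, span, lm_sampler, classifier_logit,
              nb_range=10, sample_n=20, pad="<pad>"):
    """Average change in the classifier logit when `span` (inclusive word
    indices) is padded out, marginalized over LM samples of its neighborhood."""
    start, end = span
    left, right = max(0, start - nb_range), min(len(tokens), end + 1 + nb_range)
    diffs = []
    for _ in range(sample_n):
        ctx = list(tokens)
        # resample the words adjacent to the span from a language model
        ctx[left:start] = lm_sampler(ctx, left, start)
        ctx[end + 1:right] = lm_sampler(ctx, end + 1, right)
        occluded = list(ctx)
        occluded[start:end + 1] = [pad] * (end + 1 - start)  # occlude the phrase
        diffs.append(classifier_logit(ctx) - classifier_logit(occluded))
    return float(np.mean(diffs))
```

SCD, roughly, keeps the same context sampling but replaces the occlusion step with a contextual-decomposition pass through the LSTM (see `hiexpl/algo/scd_lstm.py` and `hiexpl/algo/scd_func.py`).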
5 | 6 | ## Installation 7 | ```shell script 8 | conda create -n hiexpl-env python==3.7.4 9 | conda activate hiexpl-env 10 | # modify the CUDA version to match your system 11 | conda install pytorch=0.4.1 cuda90 -c pytorch 12 | pip install nltk numpy scikit-learn scikit-image matplotlib torchtext 13 | # requirements from pytorch-transformers 14 | pip install tokenizers==0.0.11 boto3 filelock requests tqdm regex sentencepiece sacremoses 15 | ``` 16 | 17 | ## Pipeline 18 | ### SST-2 (LSTM) 19 | 20 | Train an LSTM classifier. The SST-2 dataset will be downloaded automatically. 21 | 22 | ```shell script 23 | mkdir models 24 | export model_path=models/sst_lstm 25 | python train.py --task sst --save_path ${model_path} --no_subtrees --lr 0.0005 26 | ``` 27 | 28 | Pretrain a language model on the training set. 29 | ```shell script 30 | export lm_path=models/sst_lstm_lm 31 | python -m lm.lm_train --task sst --save_path ${lm_path} --no_subtrees --lr 0.0002 32 | ``` 33 | 34 | Use SOC/SCD to interpret the first 50 predictions on the dev set. `nb_range=10` and `sample_n=20` are recommended. 35 | ```shell script 36 | export algo=soc # or scd 37 | export exp_name=.sst_demo 38 | mkdir sst 39 | mkdir sst/results 40 | python explain.py --resume_snapshot ${model_path} --method ${algo} --lm_path ${lm_path} --batch_size 1 --start 0 --stop 50 --exp_name ${exp_name} --task sst --explain_model lstm --nb_range 10 --sample_n 20 41 | ``` 42 | Check `outputs/sst/soc_results/` or `outputs/sst/scd_results/` for explanation outputs. Each line contains one instance, where the score attribution for each word/phrase is tab-separated. For example: 43 | 44 | ``` 45 | it 0.142656 's 0.192409 the 0.175471 best 0.829247 film 0.095305 best film 0.805854 the best film 1.004583 ... 46 | ``` 47 | 48 | Or use the `--agg` flag to automatically construct hierarchical explanations, without the need for ground-truth parse trees. 49 | 50 | ```shell script 51 | python explain.py --resume_snapshot ${model_path} --method ${algo} --lm_path ${lm_path} --batch_size 1 --start 0 --stop 50 --exp_name ${exp_name} --task sst --explain_model lstm --nb_range 10 --sample_n 20 --agg 52 | ``` 53 | 54 | The output can be read by `visualize.py` to generate visualizations. 55 | 56 | ### SST-2 (BERT) 57 | Download the SST-2 dataset from https://gluebenchmark.com/ and unzip it at `bert/glue_data`. Then fine-tune the BERT model to build a classifier. 58 | 59 | ```shell script 60 | python -m bert.run_classifier \ 61 | --task_name SST-2 \ 62 | --do_train \ 63 | --do_eval \ 64 | --do_lower_case \ 65 | --data_dir bert/glue_data/SST-2 \ 66 | --bert_model bert-base-uncased \ 67 | --max_seq_length 128 \ 68 | --train_batch_size 32 \ 69 | --learning_rate 2e-5 \ 70 | --num_train_epochs 3 \ 71 | --output_dir bert/models_sst 72 | ``` 73 | 74 | Then train a language model on BERT-tokenized inputs and run explanations. Simply add the `--use_bert_tokenizer` and `--explain_model bert` flags to all the LSTM experiments above. 75 | 76 | Note that you need to filter out subtrees in `train.tsv` if you are interested in evaluating explanations. 77 | 78 | ## Evaluating explanations 79 | 80 | To evaluate word-level explanations, a BOW linear classifier is required. 81 | ```shell script 82 | python -m nns.linear_model --task sst --save_path ${model_path} 83 | ``` 84 | 85 | To evaluate phrase-level explanations, you also need to download the original SST dataset.
86 | ```shell script 87 | wget -P ./.data http://nlp.stanford.edu/~socherr/stanfordSentimentTreebank.zip 88 | unzip ./.data/stanfordSentimentTreebank.zip -d ./.data/ 89 | mv ./.data/stanfordSentimentTreebank ./.data/sst_raw 90 | ``` 91 | 92 | Then run the evaluation script: 93 | ```shell script 94 | python eval_explanations.py --eval_file outputs/soc${exp_name}.txt 95 | ``` 96 | 97 | 98 | ## Contact 99 | 100 | If you have any questions about the paper or the code, please feel free to contact Xisen Jin (xisenjin usc edu). 101 | 102 | -------------------------------------------------------------------------------- /hiexpl/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/INK-USC/hierarchical-explanation-neural-sequence-models/d797daee2cea327ff3a7fb5d9f077412861f834f/hiexpl/__init__.py -------------------------------------------------------------------------------- /hiexpl/algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/INK-USC/hierarchical-explanation-neural-sequence-models/d797daee2cea327ff3a7fb5d9f077412861f834f/hiexpl/algo/__init__.py -------------------------------------------------------------------------------- /hiexpl/algo/cd_func.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2018 William James Murdoch 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software.
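#
# Overview: this file implements Contextual Decomposition (CD) for a single-layer
# LSTM classifier, adapted from W. James Murdoch's original CD code (see the
# copyright above). Roughly, CD() splits each gate pre-activation at every timestep
# into a "relevant" part driven by the words inside `intervals`, an "irrelevant"
# part driven by the remaining words, and the bias term, then propagates the
# relevant/irrelevant parts separately through the cell and hidden states. The
# returned `scores` are W_out @ relevant_h[T-1]; `scores + irrel_scores` should
# approximately recover the model's logits minus the output-layer bias.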
14 | 15 | 16 | 17 | import os 18 | import pdb 19 | import torch 20 | import numpy as np 21 | from torchtext import data, datasets 22 | import random 23 | from math import e 24 | from utils.args import args 25 | 26 | 27 | def sigmoid(x): 28 | return 1 / (1 + e ** (-x)) 29 | 30 | 31 | def tanh(x): 32 | return (1 - e ** (-2 * x)) / (1 + e ** (-2 * x)) 33 | 34 | 35 | def get_model(snapshot_file): 36 | print('loading', snapshot_file) 37 | try: # load onto gpu 38 | model = torch.load(snapshot_file) 39 | print('loaded onto gpu...') 40 | except: # load onto cpu 41 | model = torch.load(snapshot_file, map_location=lambda storage, loc: storage) 42 | print('loaded onto cpu...') 43 | return model 44 | 45 | 46 | def get_sst(): 47 | inputs = data.Field(lower='preserve-case') 48 | answers = data.Field(sequential=False, unk_token=None) 49 | 50 | # build with subtrees so inputs are right 51 | train_s, dev_s, test_s = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=True, 52 | filter_pred=lambda ex: ex.label != 'neutral') 53 | inputs.build_vocab(train_s, dev_s, test_s) 54 | answers.build_vocab(train_s) 55 | 56 | # rebuild without subtrees to get longer sentences 57 | train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=False, 58 | filter_pred=lambda ex: ex.label != 'neutral') 59 | 60 | train_iter, dev_iter, test_iter = data.BucketIterator.splits( 61 | (train, dev, test), batch_size=1, device=0) 62 | 63 | return inputs, answers, train_iter, dev_iter 64 | 65 | 66 | def get_batches(batch_nums, train_iterator, dev_iterator, dset='train'): 67 | print('getting batches...') 68 | np.random.seed(0) 69 | random.seed(0) 70 | 71 | # pick data_iterator 72 | if dset == 'train': 73 | data_iterator = train_iterator 74 | elif dset == 'dev': 75 | data_iterator = dev_iterator 76 | 77 | # actually get batches 78 | num = 0 79 | batches = {} 80 | data_iterator.init_epoch() 81 | for batch_idx, batch in enumerate(data_iterator): 82 | if batch_idx == batch_nums[num]: 83 | batches[batch_idx] = batch 84 | num += 1 85 | 86 | if num == max(batch_nums): 87 | break 88 | elif num == len(batch_nums): 89 | print('found them all') 90 | break 91 | return batches 92 | 93 | 94 | def evaluate_predictions(snapshot_file): 95 | print('loading', snapshot_file) 96 | try: # load onto gpu 97 | model = torch.load(snapshot_file) 98 | except: # load onto cpu 99 | model = torch.load(snapshot_file, map_location=lambda storage, loc: storage) 100 | inputs = data.Field() 101 | answers = data.Field(sequential=False, unk_token=None) 102 | 103 | train, dev, test = datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=False, 104 | filter_pred=lambda ex: ex.label != 'neutral') 105 | inputs.build_vocab(train) 106 | answers.build_vocab(train) 107 | train_iter, dev_iter, test_iter = data.BucketIterator.splits( 108 | (train, dev, test), batch_size=1, device=0) 109 | train_iter.init_epoch() 110 | for batch_idx, batch in enumerate(train_iter): 111 | print('batch_idx', batch_idx) 112 | out = model(batch) 113 | target = batch.label 114 | break 115 | return batch, out, target 116 | 117 | 118 | def is_in_intervals(idx, intervals): 119 | for interval in intervals: 120 | if interval[0] <= idx <= interval[1]: 121 | return True 122 | return False 123 | 124 | 125 | def partition_bias(rel, irrel, bias): 126 | b_r = bias * abs(rel) / (abs(rel) + abs(irrel) + 1e-12) 127 | b_ir = bias * (abs(irrel) + 1e-12) / (abs(rel) + abs(irrel) + 1e-12) 128 | return b_r, b_ir 129 | 130 | 131 | def CD(batch, model, 
intervals): 132 | if not args.task == 'tacred': 133 | word_vecs = model.embed(batch.text)[:, 0] 134 | lstm_module = model.lstm 135 | else: 136 | token_vec = model.embed(batch.text) 137 | pos_vec = model.pos_embed(batch.pos) 138 | ner_vec = model.ner_embed(batch.ner) 139 | word_vecs = model.drop(torch.cat([token_vec, pos_vec, ner_vec], -1))[:, 0] 140 | lstm_module = model.lstm.rnn 141 | 142 | weights = lstm_module.state_dict() 143 | 144 | # Index one = word vector (i) or hidden state (h), index two = gate 145 | W_ii, W_if, W_ig, W_io = np.split(weights['weight_ih_l0'], 4, 0) 146 | W_hi, W_hf, W_hg, W_ho = np.split(weights['weight_hh_l0'], 4, 0) 147 | b_i, b_f, b_g, b_o = np.split(weights['bias_ih_l0'].cpu().numpy() + weights['bias_hh_l0'].cpu().numpy(), 4) 148 | # word_vecs = model.embed(batch.text)[:, 0].data 149 | T = word_vecs.size(0) 150 | relevant = np.zeros((T, model.hidden_dim)) 151 | irrelevant = np.zeros((T, model.hidden_dim)) 152 | relevant_h = np.zeros((T, model.hidden_dim)) 153 | irrelevant_h = np.zeros((T, model.hidden_dim)) 154 | for i in range(T): 155 | if i > 0: 156 | prev_rel_h = relevant_h[i - 1] 157 | prev_irrel_h = irrelevant_h[i - 1] 158 | else: 159 | prev_rel_h = np.zeros(model.hidden_dim) 160 | prev_irrel_h = np.zeros(model.hidden_dim) 161 | 162 | rel_i = np.dot(W_hi, prev_rel_h) 163 | rel_g = np.dot(W_hg, prev_rel_h) 164 | rel_f = np.dot(W_hf, prev_rel_h) 165 | rel_o = np.dot(W_ho, prev_rel_h) 166 | irrel_i = np.dot(W_hi, prev_irrel_h) 167 | irrel_g = np.dot(W_hg, prev_irrel_h) 168 | irrel_f = np.dot(W_hf, prev_irrel_h) 169 | irrel_o = np.dot(W_ho, prev_irrel_h) 170 | 171 | if is_in_intervals(i, intervals): 172 | rel_i = rel_i + np.dot(W_ii, word_vecs[i]) 173 | rel_g = rel_g + np.dot(W_ig, word_vecs[i]) 174 | rel_f = rel_f + np.dot(W_if, word_vecs[i]) 175 | rel_o = rel_o + np.dot(W_io, word_vecs[i]) 176 | else: 177 | irrel_i = irrel_i + np.dot(W_ii, word_vecs[i]) 178 | irrel_g = irrel_g + np.dot(W_ig, word_vecs[i]) 179 | irrel_f = irrel_f + np.dot(W_if, word_vecs[i]) 180 | irrel_o = irrel_o + np.dot(W_io, word_vecs[i]) 181 | 182 | rel_contrib_i, irrel_contrib_i, bias_contrib_i = decomp_three(rel_i, irrel_i, b_i, sigmoid) 183 | rel_contrib_g, irrel_contrib_g, bias_contrib_g = decomp_three(rel_g, irrel_g, b_g, np.tanh) 184 | 185 | # i_value = sum([rel_contrib_i, irrel_contrib_i, bias_contrib_i]) 186 | # g_value = sum([rel_contrib_g, irrel_contrib_g, bias_contrib_g]) 187 | # print(i_value[:10], g_value[:10]) 188 | 189 | relevant[i] = rel_contrib_i * (rel_contrib_g + bias_contrib_g) + bias_contrib_i * rel_contrib_g 190 | irrelevant[i] = irrel_contrib_i * (rel_contrib_g + irrel_contrib_g + bias_contrib_g) + ( 191 | rel_contrib_i + bias_contrib_i) * irrel_contrib_g 192 | 193 | if is_in_intervals(i, intervals): 194 | relevant[i] += bias_contrib_i * bias_contrib_g 195 | else: 196 | irrelevant[i] += bias_contrib_i * bias_contrib_g 197 | 198 | # c = relevant[i] + irrelevant[i] 199 | # print('_', c[:10]) 200 | 201 | if i > 0: 202 | rel_contrib_f, irrel_contrib_f, bias_contrib_f = decomp_three(rel_f, irrel_f, b_f, sigmoid) 203 | relevant[i] += (rel_contrib_f + bias_contrib_f) * relevant[i - 1] 204 | irrelevant[i] += (rel_contrib_f + irrel_contrib_f + bias_contrib_f) * irrelevant[i - 1] + irrel_contrib_f * \ 205 | relevant[i - 1] 206 | 207 | # c = relevant[i] + irrelevant[i] 208 | # print(c[:10]) 209 | 210 | o = sigmoid(np.dot(W_io, word_vecs[i]) + np.dot(W_ho, prev_rel_h + prev_irrel_h) + b_o) 211 | # print('o', o[:10]) 212 | # rel_contrib_o, irrel_contrib_o, bias_contrib_o = 
decomp_three(rel_o, irrel_o, b_o, sigmoid) 213 | new_rel_h, new_irrel_h = decomp_tanh_two(relevant[i], irrelevant[i]) 214 | # h = new_rel_h + new_irrel_h 215 | # print(h[:10]) 216 | # relevant_h[i] = new_rel_h * (rel_contrib_o + bias_contrib_o) 217 | # irrelevant_h[i] = new_rel_h * irrel_contrib_o + new_irrel_h * (rel_contrib_o + irrel_contrib_o + bias_contrib_o) 218 | relevant_h[i] = o * new_rel_h 219 | irrelevant_h[i] = o * new_irrel_h 220 | 221 | W_out = model.hidden_to_label.weight.data 222 | 223 | if args.task == 'tacred': 224 | relevant_h[T - 1] = model.drop(torch.from_numpy(relevant_h[T - 1]).view(1, -1).cuda()).view(-1).cpu().numpy() 225 | # Sanity check: scores + irrel_scores should equal the LSTM's output minus model.hidden_to_label.bias 226 | if not args.mean_hidden: 227 | scores = np.dot(W_out, relevant_h[T - 1]) 228 | irrel_scores = np.dot(W_out, irrelevant_h[T - 1]) 229 | else: 230 | scores = np.dot(W_out, np.mean(relevant_h, 0)) 231 | irrel_scores = np.dot(W_out, np.mean(irrelevant_h[T - 1])) 232 | 233 | return scores, irrel_scores 234 | 235 | 236 | def decomp_three(a, b, c, activation): 237 | a_contrib = 0.5 * (activation(a + c) - activation(c)) + 0.5 * (activation(a + b + c) - activation(b + c)) 238 | c_contrib = activation(c) 239 | b_contrib = activation(a + b + c) - a_contrib - c_contrib 240 | return a_contrib, b_contrib, c_contrib 241 | 242 | 243 | def decomp_tanh_two(a, b): 244 | return 0.5 * (np.tanh(a) + (np.tanh(a + b) - np.tanh(b))), 0.5 * (np.tanh(b) + (np.tanh(a + b) - np.tanh(a))) 245 | 246 | 247 | def decomp_activation_two_pad(a, b, activation): 248 | a_contrib = activation(a) 249 | b_contrib = activation(a + b) - a_contrib 250 | return a_contrib, b_contrib 251 | 252 | 253 | def decomp_three_pad(a, b, c, activation): 254 | a_contrib = 1 / 2 * (activation(a + c) - activation(c)) + 1 / 2 * (activation(a + b + c) - activation(b + c)) 255 | c_contrib = activation(c) 256 | b_contrib = activation(a + b + c) - a_contrib - c_contrib 257 | return a_contrib, b_contrib, c_contrib -------------------------------------------------------------------------------- /hiexpl/algo/scd_func.py: -------------------------------------------------------------------------------- 1 | from utils.args import * 2 | import torch.nn.functional as F 3 | import torch 4 | from .cd_func import is_in_intervals, sigmoid, tanh 5 | 6 | args = get_args() 7 | 8 | def CD_gpu(batch, model, intervals, hist_states, gpu): 9 | # if task is tacred, then the word vecs is more complicated 10 | if not args.task == 'tacred': 11 | word_vecs = model.embed(batch.text)[:,0] 12 | lstm_module = model.lstm 13 | else: 14 | token_vec = model.embed(batch.text) 15 | pos_vec = model.pos_embed(batch.pos) 16 | ner_vec = model.ner_embed(batch.ner) 17 | word_vecs = torch.cat([token_vec, pos_vec, ner_vec], -1)[:,0] 18 | lstm_module = model.lstm.rnn 19 | 20 | hidden_dim = model.hidden_dim 21 | T = word_vecs.size(0) 22 | relevant, irrelevant, relevant_h, irrelevant_h = torch.zeros(T, hidden_dim).to(gpu),torch.zeros(T, hidden_dim).to(gpu),\ 23 | torch.zeros(T, hidden_dim).to(gpu),torch.zeros(T, hidden_dim).to(gpu) 24 | W_ii, W_if, W_ig, W_io = torch.split(lstm_module.weight_ih_l0, lstm_module.hidden_size) 25 | W_hi, W_hf, W_hg, W_ho = torch.split(lstm_module.weight_hh_l0, lstm_module.hidden_size) 26 | b_i, b_f, b_g, b_o = torch.split(lstm_module.bias_ih_l0 + lstm_module.bias_hh_l0, lstm_module.hidden_size) 27 | 28 | i_states, g_states, f_states, c_states = [], [], [None], [] 29 | o_states = [] 30 | 31 | for i in range(T): 32 | 
bias_in_rel = is_in_intervals(i, intervals) 33 | 34 | if i > 0: 35 | prev_rel_h = relevant_h[i - 1] 36 | prev_irrel_h = irrelevant_h[i - 1] 37 | else: 38 | prev_rel_h = torch.zeros(hidden_dim).to(gpu) 39 | prev_irrel_h = torch.zeros(hidden_dim).to(gpu) 40 | rel_i = torch.matmul(W_hi, prev_rel_h) 41 | rel_g = torch.matmul(W_hg, prev_rel_h) 42 | rel_f = torch.matmul(W_hf, prev_rel_h) 43 | rel_o = torch.matmul(W_ho, prev_rel_h) 44 | irrel_i = torch.matmul(W_hi, prev_irrel_h) 45 | irrel_g = torch.matmul(W_hg, prev_irrel_h) 46 | irrel_f = torch.matmul(W_hf, prev_irrel_h) 47 | irrel_o = torch.matmul(W_ho, prev_irrel_h) 48 | if is_in_intervals(i, intervals): 49 | rel_i = rel_i + torch.matmul(W_ii, word_vecs[i]) 50 | rel_g = rel_g + torch.matmul(W_ig, word_vecs[i]) 51 | rel_f = rel_f + torch.matmul(W_if, word_vecs[i]) 52 | rel_o = rel_o + torch.matmul(W_io, word_vecs[i]) 53 | else: 54 | irrel_i = irrel_i + torch.matmul(W_ii, word_vecs[i]) 55 | irrel_g = irrel_g + torch.matmul(W_ig, word_vecs[i]) 56 | irrel_f = irrel_f + torch.matmul(W_if, word_vecs[i]) 57 | irrel_o = irrel_o + torch.matmul(W_io, word_vecs[i]) 58 | rel_contrib_i, irrel_contrib_i, bias_contrib_i = decomp_activation_three_with_states(rel_i, irrel_i, b_i, F.sigmoid, hist_states, 'i', i, bias_in_rel) 59 | rel_contrib_g, irrel_contrib_g, bias_contrib_g = decomp_activation_three_with_states(rel_g, irrel_g, b_g, F.tanh, hist_states, 'g', i, bias_in_rel) 60 | 61 | i_states.append(sum([rel_i, irrel_i, b_i])) 62 | g_states.append(sum([rel_g, irrel_g, b_g])) 63 | 64 | relevant[i], irrelevant[i] = mult_terms(rel_contrib_g, irrel_contrib_g, bias_contrib_g, rel_contrib_i, 65 | irrel_contrib_i, bias_contrib_i, hist_states, 'act_g', 'act_i', i, bias_in_rel) 66 | 67 | if i > 0: 68 | rel_contrib_f, irrel_contrib_f, bias_contrib_f = decomp_activation_three_with_states(rel_f, irrel_f, b_f, 69 | F.sigmoid, hist_states, 70 | 'f', i, bias_in_rel) 71 | 72 | rel_plus, irrel_plus = mult_terms(rel_contrib_f, irrel_contrib_f, bias_contrib_f, relevant[i-1], 73 | irrelevant[i-1], 0, hist_states, 'act_f', 'temp_c', i, bias_in_rel) 74 | relevant[i] += rel_plus 75 | irrelevant[i] += irrel_plus 76 | 77 | f_states.append(sum([rel_f, irrel_f, b_f])) 78 | 79 | o = sigmoid(torch.matmul(W_io, word_vecs[i]) + torch.matmul(W_ho, prev_rel_h + prev_irrel_h) + b_o) 80 | rel_contrib_o, irrel_contrib_o, bias_contrib_o = decomp_activation_three_with_states(rel_o, irrel_o, b_o, F.sigmoid, hist_states, 'o', i, bias_in_rel) 81 | new_rel_h, new_irrel_h = decomp_activation_two_with_states(relevant[i], irrelevant[i], F.tanh, hist_states, 'c', i) 82 | c_states.append(sum([relevant[i], irrelevant[i]])) 83 | o_states.append(o) 84 | relevant_h[i], irrelevant_h[i] = mult_terms(rel_contrib_o, irrel_contrib_o, bias_contrib_o, new_rel_h, 85 | new_irrel_h, 0, hist_states, 'o', 'tanhc', i, bias_in_rel) 86 | 87 | W_out = model.hidden_to_label.weight.data 88 | 89 | if hasattr(model, 'drop'): 90 | relevant_h[T-1], irrelevant_h[T-1] = model.drop(relevant_h[T-1]), model.drop(irrelevant_h[T-1]) 91 | if not args.mean_hidden: 92 | scores = torch.matmul(W_out, relevant_h[T - 1]) 93 | irrel_scores = torch.matmul(W_out, irrelevant_h[T - 1]) 94 | else: 95 | mean_hidden = torch.mean(relevant_h, 0) # [H] 96 | scores = torch.matmul(W_out, mean_hidden) 97 | irrel_scores = torch.matmul(W_out, torch.mean(irrelevant_h, 0)) 98 | 99 | states = { 100 | 'i': i_states, 101 | 'g': g_states, 102 | 'f': f_states, 103 | 'c': c_states, 104 | 'o': o_states 105 | } 106 | 107 | # if any(np.isnan(scores)): 108 | # print(1) 
109 | 110 | return scores, irrel_scores, states 111 | 112 | 113 | def torch_mul(param, h): 114 | param = param.unsqueeze(0) # [1, h1, h2] 115 | h = h.unsqueeze(-1) # [B, h2, 1] 116 | mult = torch.matmul(param, h) # [B, h1, 1] 117 | return mult.squeeze(-1) 118 | 119 | def get_lstm_states(batch, model, gpu): 120 | if not args.task == 'tacred': 121 | word_vecs = model.embed(batch.text) 122 | lstm_module = model.lstm 123 | else: 124 | token_vec = model.embed(batch.text) 125 | pos_vec = model.pos_embed(batch.pos) 126 | ner_vec = model.ner_embed(batch.ner) 127 | word_vecs = torch.cat([token_vec, pos_vec, ner_vec], -1) 128 | word_vecs = model.drop(word_vecs) 129 | lstm_module = model.lstm.rnn 130 | 131 | batch_size = word_vecs.size(1) 132 | hidden_dim = model.hidden_dim 133 | T = word_vecs.size(0) 134 | prev_c, prev_h = torch.zeros(batch_size, hidden_dim).to(gpu), torch.zeros(batch_size, hidden_dim).to(gpu) 135 | W_ii, W_if, W_ig, W_io = torch.split(lstm_module.weight_ih_l0, lstm_module.hidden_size) 136 | W_hi, W_hf, W_hg, W_ho = torch.split(lstm_module.weight_hh_l0, lstm_module.hidden_size) 137 | b_i, b_f, b_g, b_o = torch.split(lstm_module.bias_ih_l0 + lstm_module.bias_hh_l0, lstm_module.hidden_size) 138 | 139 | i_states, g_states, f_states, c_states = [], [], [], [] 140 | act_i_states, act_g_states, act_f_states = [], [], [] 141 | o_states = [] 142 | temp_c_states = [] 143 | tanhc_states = [] 144 | 145 | for ts in range(T): 146 | i = torch_mul(W_hi, prev_h) 147 | g = torch_mul(W_hg, prev_h) 148 | f = torch_mul(W_hf, prev_h) 149 | 150 | i += torch_mul(W_ii, word_vecs[ts]) + b_i.view(1,-1) 151 | g += torch_mul(W_ig, word_vecs[ts]) + b_g.view(1,-1) 152 | f += torch_mul(W_if, word_vecs[ts]) + b_f.view(1,-1) 153 | 154 | i_states.append(i) 155 | g_states.append(g) 156 | f_states.append(f) 157 | 158 | act_i_states.append(sigmoid(i)) 159 | act_g_states.append(tanh(g)) 160 | act_f_states.append(sigmoid(f)) 161 | temp_c_states.append(prev_c) 162 | 163 | c = sigmoid(f) * prev_c + sigmoid(i) * tanh(g) 164 | o = sigmoid(torch_mul(W_io, word_vecs[ts]) + torch_mul(W_ho, prev_h) + b_o.view(1,-1)) 165 | c_states.append(c) 166 | tanhc_states.append(tanh(c)) 167 | o_states.append(o) 168 | 169 | h = o * tanh(c) 170 | 171 | prev_h = h 172 | prev_c = c 173 | 174 | 175 | states = { 176 | 'i': i_states, 177 | 'g': g_states, 178 | 'f': f_states, 179 | 'c': c_states, 180 | 'o': o_states, 181 | 'act_i': act_i_states, 182 | 'act_g': act_g_states, 183 | 'act_f': act_f_states, 184 | 'temp_c': temp_c_states, 185 | 'tanhc': tanhc_states 186 | } 187 | 188 | return states 189 | 190 | 191 | def decomp_activation_two_with_states(a, b, activation, states, state_key, t): 192 | rel = 0 193 | if states is None or not states['c']: 194 | exemplar_size = 0 195 | else: 196 | exemplar_size = states['c'][0].size(0) 197 | for idx in range(exemplar_size): 198 | hs = states[state_key][t][idx] 199 | rel += activation(hs) - activation(hs - a) 200 | 201 | rel += activation(a + b) - activation(b) 202 | rel /= exemplar_size + 1 203 | 204 | irrel = activation(a + b) - rel 205 | return rel, irrel 206 | 207 | 208 | def decomp_activation_three_with_states(rel_x, irrel_x, bias_x, activation, states, state_key, t, bias_in_rel): 209 | rel = 0 210 | if states is None or not states['c']: 211 | exemplar_size = 0 212 | else: 213 | exemplar_size = states['c'][0].size(0) 214 | 215 | for idx in range(exemplar_size): 216 | hs = states[state_key][t][idx] 217 | rel += activation(hs) - activation(hs - rel_x) 218 | 219 | rel += activation(rel_x + irrel_x + bias_x) 
- activation(irrel_x + bias_x) 220 | rel /= exemplar_size + 1 221 | 222 | bias = activation(bias_x) 223 | irrel = activation(rel_x + irrel_x + bias_x) - rel - bias 224 | 225 | return rel, irrel, bias 226 | 227 | 228 | def mult_terms(rel_a, irrel_a, bias_a, rel_b, irrel_b, bias_b, states, state_key_a, state_key_b, t, bias_in_rel): 229 | rel = 0 230 | 231 | if states is None or not states['c']: 232 | exemplar_size = 0 233 | else: 234 | exemplar_size = states['c'][0].size(0) 235 | 236 | for idx in range(exemplar_size): 237 | hs_a = states[state_key_a][t][idx] 238 | hs_b = states[state_key_b][t][idx] 239 | rel += rel_a * (hs_b - rel_b - bias_b) + rel_b * (hs_a - rel_a - bias_a) 240 | rel += rel_a * irrel_b + rel_b * irrel_a 241 | rel /= exemplar_size + 1 242 | rel += rel_a * rel_b + rel_a * bias_b + rel_b * bias_a 243 | 244 | irrel = (rel_a + irrel_a + bias_a) * (rel_b + irrel_b + bias_b) - rel 245 | return rel, irrel -------------------------------------------------------------------------------- /hiexpl/algo/scd_lstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from algo.cd_func import CD 3 | from algo.scd_func import CD_gpu, get_lstm_states 4 | from algo.soc_lstm import ExplanationBase, SOCForLSTM, Batch, append_extra_input, normalize_logit 5 | import copy 6 | from utils.args import get_args 7 | 8 | args = get_args() 9 | 10 | 11 | class CDForLSTM(ExplanationBase): 12 | def __init__(self, target_model, data_iterator, vocab, pad_variation, tree_path, output_path, config, pad_idx=1): 13 | super().__init__(target_model, data_iterator, vocab, tree_path, output_path, config, pad_idx) 14 | self.pad_variation = pad_variation 15 | 16 | def explain_single(self, inp, inp_id, region, extra_input=None): 17 | """ 18 | :param region: the input region to be explained 19 | :param inp: numpy array 20 | :param inp_id: int 21 | :return: 22 | """ 23 | inp = torch.from_numpy(inp).long().view(-1,1) 24 | if self.gpu >= 0: 25 | inp = inp.to(self.gpu) 26 | batch = Batch(text=inp) 27 | if extra_input is not None: 28 | append_extra_input(batch, extra_input) 29 | rel_scores, irrel_score = CD(batch, self.model, [region]) 30 | if rel_scores.shape[0] == 2: 31 | score = rel_scores[0] - rel_scores[1] 32 | else: 33 | gt_label = extra_input['gt_label'] 34 | contrib_logits = torch.from_numpy(rel_scores).cuda().view(1,-1) 35 | score = normalize_logit(contrib_logits, gt_label) 36 | score = score.item() 37 | return score 38 | 39 | 40 | class SCDForLSTM(SOCForLSTM): 41 | def __init__(self, target_model, lm_model, data_iterator, vocab, tree_path, output_path, config, pad_idx=1): 42 | super().__init__(target_model, lm_model, data_iterator, vocab, tree_path, output_path, config, pad_idx) 43 | self.sample_num = config.sample_n if not args.cd_pad else 1 44 | 45 | def get_states(self, inp, x_regions, nb_regions, extra_input): 46 | # suppose only have one x_region and one nb_region 47 | x_region = x_regions[0] 48 | nb_region = nb_regions[0] 49 | 50 | inp_length = torch.LongTensor([len(inp)]) 51 | fw_pos = torch.LongTensor([min(x_region[1] + 1, len(inp))]) 52 | bw_pos = torch.LongTensor([max(x_region[0] - 1, -1)]) 53 | 54 | inp_lm = copy.copy(inp) 55 | for i in range(len(inp_lm)): 56 | if nb_region[0] <= i <= nb_region[1] and not x_region[0] <= i <= x_region[1]: 57 | inp_lm[i] = 1 58 | inp_th = torch.from_numpy(inp_lm).long().view(-1, 1) 59 | 60 | if self.gpu >= 0: 61 | inp_th = inp_th.to(self.gpu) 62 | inp_length = inp_length.to(self.gpu) 63 | fw_pos = fw_pos.to(self.gpu) 64 | 
bw_pos = bw_pos.to(self.gpu) 65 | 66 | batch = Batch(text=inp_th, length=inp_length, fw_pos=fw_pos, bw_pos=bw_pos) 67 | 68 | all_filled_inp = [] 69 | max_sample_length = (self.nb_range + 1) if self.nb_method == 'ngram' else (inp_th.size(0) + 1) 70 | 71 | if not args.cd_pad: 72 | fw_sample_outputs, bw_sample_outputs = self.lm_model.sample_n('random', batch, 73 | max_sample_length=max_sample_length, 74 | sample_num=self.sample_num) 75 | for sample_i in range(self.sample_num): 76 | fw_sample_seq, bw_sample_seq = fw_sample_outputs[:,sample_i].cpu().numpy(), \ 77 | bw_sample_outputs[:,sample_i].cpu().numpy() 78 | filled_inp = copy.copy(inp) 79 | len_bw = x_region[0] - nb_region[0] 80 | len_fw = nb_region[1] - x_region[1] 81 | if len_bw > 0: 82 | filled_inp[nb_region[0]:x_region[0]] = bw_sample_seq[-len_bw:] 83 | if len_fw > 0: 84 | filled_inp[x_region[1] + 1:nb_region[1] + 1] = fw_sample_seq[:len_fw] 85 | 86 | filled_inp = torch.from_numpy(filled_inp).long() 87 | if self.gpu >= 0: 88 | filled_inp = filled_inp.to(self.gpu) 89 | all_filled_inp.append(filled_inp) 90 | else: 91 | # pad the nb region to 1 92 | filled_inp = copy.copy(inp) 93 | for i in range(nb_region[0], nb_region[1] + 1): 94 | if not x_region[0] <= i <= x_region[1]: 95 | filled_inp[i] = 1 96 | filled_inp = torch.from_numpy(filled_inp).long() 97 | if self.gpu >= 0: 98 | filled_inp = filled_inp.to(self.gpu) 99 | all_filled_inp.append(filled_inp) 100 | 101 | all_filled_inp = torch.stack(all_filled_inp, -1) # [T,B] 102 | batch = Batch(text=all_filled_inp) 103 | 104 | if extra_input is not None: 105 | append_extra_input(batch, extra_input) 106 | all_states = get_lstm_states(batch, self.model, self.gpu) 107 | 108 | return all_states 109 | 110 | def explain_single(self, inp, inp_id, region, extra_input=None): 111 | """ 112 | :param region: the input region to be explained 113 | :param inp: numpy array 114 | :param inp_id: int 115 | :return: 116 | """ 117 | if self.nb_method == 'tree': 118 | tree = self.trees[inp_id] 119 | mask_regions = self.get_tree_mask_region(tree, region, inp) 120 | elif self.nb_method == 'ngram': 121 | mask_regions = self.get_ngram_mask_region(region, inp) 122 | else: 123 | raise NotImplementedError('unknown method %s' % self.nb_method) 124 | with torch.no_grad(): 125 | if self.sample_num > 0: 126 | states = self.get_states(inp, [region], mask_regions, extra_input) 127 | else: 128 | states = None 129 | inp = torch.from_numpy(inp).long().view(-1, 1) 130 | if self.gpu >= 0: 131 | inp = inp.to(self.gpu) 132 | batch = Batch(text=inp) 133 | if extra_input is not None: 134 | append_extra_input(batch, extra_input) 135 | rel_scores, irrel_scores, _ = CD_gpu(batch, self.model, [region], states, gpu=self.gpu) 136 | if rel_scores.shape[0] == 2: 137 | score = rel_scores[0] - rel_scores[1] 138 | else: 139 | gt_label = extra_input['gt_label'] 140 | contrib_logits = rel_scores.view(1,-1) 141 | score = normalize_logit(contrib_logits, gt_label) 142 | score = score.item() 143 | return score -------------------------------------------------------------------------------- /hiexpl/algo/scd_transformer.py: -------------------------------------------------------------------------------- 1 | from algo.soc_transformer import ExplanationBaseForTransformer, SOCForTransformer, normalize_logit 2 | from bert.run_classifier import BertTokenizer, predict_and_explain_wrapper_unbatched 3 | from algo.soc_transformer import get_data_iterator_bert, bert_id_to_lm_id, lm_id_to_bert_id, Batch 4 | from bert.modeling import global_state_dict 5 | import torch 
6 | import copy 7 | import numpy as np 8 | from utils.args import get_args 9 | 10 | args = get_args() 11 | 12 | 13 | class CDForTransformer(ExplanationBaseForTransformer): 14 | def __init__(self, model, tree_path, output_path, config): 15 | super().__init__(model, tree_path, output_path, config) 16 | self.model = model 17 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir='bert/cache') 18 | self.tree_path = tree_path 19 | self.max_seq_length = 128 20 | self.batch_size = config.batch_size 21 | self.iterator = get_data_iterator_bert(self.tree_path, self.tokenizer, self.max_seq_length, self.batch_size) 22 | 23 | def explain_single_transformer(self, input_ids, input_mask, segment_ids, region, label=None): 24 | if self.gpu >= 0: 25 | input_ids, input_mask, segment_ids = input_ids.to(self.gpu), input_mask.to(self.gpu), segment_ids.to( 26 | self.gpu) 27 | if not args.task == 'tacred': 28 | score = predict_and_explain_wrapper_unbatched(self.model, input_ids, segment_ids, input_mask, region) 29 | else: 30 | score = predict_and_explain_wrapper_unbatched(self.model, input_ids, segment_ids, input_mask, region, 31 | normalizer=normalize_logit, label=label) 32 | return score 33 | 34 | 35 | class SCDForTransformer(SOCForTransformer): 36 | def __init__(self, target_model, lm_model, vocab, tree_path, output_path, config): 37 | super().__init__(target_model, lm_model, vocab, tree_path, output_path, config) 38 | if args.cd_pad: 39 | self.sample_num = 1 40 | 41 | def get_states(self, inp, inp_mask, segment_ids, x_regions, nb_regions): 42 | global_state_dict.init_store_states() 43 | 44 | x_region = x_regions[0] 45 | nb_region = nb_regions[0] 46 | 47 | inp_length = 0 48 | for i in range(len(inp_mask)): 49 | if inp_mask[i] == 1: 50 | inp_length += 1 51 | else: 52 | break 53 | 54 | # mask everything outside the x_region and inside nb region 55 | inp_lm = copy.copy(inp) 56 | for i in range(len(inp_lm)): 57 | if nb_region[0] <= i <= nb_region[1] and not x_region[0] <= i <= x_region[1]: 58 | inp_lm[i] = self.tokenizer.vocab['[PAD]'] 59 | 60 | if not args.task == 'tacred': 61 | inp_th = torch.from_numpy( 62 | bert_id_to_lm_id(inp_lm[1:inp_length - 1], self.tokenizer, self.vocab)).long().view(-1, 1) 63 | inp_length = torch.LongTensor([inp_length - 2]) 64 | fw_pos = torch.LongTensor([min(x_region[1] + 1 - 1, len(inp) - 2)]) 65 | bw_pos = torch.LongTensor([max(x_region[0] - 1 - 1, -1)]) 66 | 67 | else: 68 | inp_th = torch.from_numpy( 69 | bert_id_to_lm_id(inp_lm[:inp_length], self.tokenizer, self.vocab)).long().view(-1, 1) 70 | inp_length = torch.LongTensor([inp_length]) 71 | fw_pos = torch.LongTensor([min(x_region[1] + 1, inp_length.item() - 1)]) 72 | bw_pos = torch.LongTensor([max(x_region[0] - 1, 0)]) 73 | 74 | if self.gpu >= 0: 75 | inp_th = inp_th.to(self.gpu) 76 | inp_length = inp_length.to(self.gpu) 77 | fw_pos = fw_pos.to(self.gpu) 78 | bw_pos = bw_pos.to(self.gpu) 79 | 80 | batch = Batch(text=inp_th, length=inp_length, fw_pos=fw_pos, bw_pos=bw_pos) 81 | 82 | inp_enb = [] 83 | 84 | max_sample_length = (self.nb_range + 1) if self.nb_method == 'ngram' else (inp_th.size(0) + 1) 85 | fw_sample_outputs, bw_sample_outputs = self.lm_model.sample_n('random', batch, 86 | max_sample_length=max_sample_length, 87 | sample_num=self.sample_num) 88 | 89 | if not args.cd_pad: 90 | for sample_i in range(self.sample_num): 91 | fw_sample_seq, bw_sample_seq = fw_sample_outputs[:, sample_i].cpu().numpy(), \ 92 | bw_sample_outputs[:, sample_i].cpu().numpy() 93 | filled_inp = 
copy.copy(inp) 94 | 95 | len_bw = x_region[0] - nb_region[0] 96 | len_fw = nb_region[1] - x_region[1] 97 | if len_bw > 0: 98 | filled_inp[nb_region[0]:x_region[0]] = lm_id_to_bert_id(bw_sample_seq[-len_bw:], self.tokenizer, 99 | self.vocab) 100 | if len_fw > 0: 101 | filled_inp[x_region[1] + 1:nb_region[1] + 1] = lm_id_to_bert_id(fw_sample_seq[:len_fw], 102 | self.tokenizer, 103 | self.vocab) 104 | inp_enb.append(filled_inp) 105 | 106 | else: 107 | filled_inp = copy.copy(inp) 108 | for i in range(nb_region[0], nb_region[1] + 1): 109 | if not x_region[0] <= i <= x_region[1]: 110 | filled_inp[i] = self.tokenizer.vocab['[PAD]'] 111 | inp_enb.append(filled_inp) 112 | inp_enb = np.stack(inp_enb) 113 | inp_enb = torch.from_numpy(inp_enb).long() 114 | inp_enb_mask = torch.from_numpy(inp_mask).long() 115 | 116 | if self.gpu >= 0: 117 | inp_enb = inp_enb.to(self.gpu) 118 | inp_enb_mask = inp_enb_mask.to(self.gpu) 119 | segment_ids = segment_ids.to(self.gpu) 120 | 121 | inp_enb_mask = inp_enb_mask.expand(inp_enb.size(0), -1) 122 | segment_ids = segment_ids.expand(inp_enb.size(0), -1) 123 | 124 | self.model.predict_and_explain(inp_enb, [[x_region]] * inp_enb.size(0), 125 | segment_ids[:, :inp_enb.size(1)], inp_enb_mask) 126 | 127 | global_state_dict.init_fetch_states() 128 | return 129 | 130 | def explain_single_transformer(self, input_ids, input_mask, segment_ids, region, label=None): 131 | inp_flatten = input_ids.view(-1).cpu().numpy() 132 | inp_mask_flatten = input_mask.view(-1).cpu().numpy() 133 | if self.nb_method == 'ngram': 134 | mask_regions = self.get_ngram_mask_region(region, inp_flatten) 135 | else: 136 | raise NotImplementedError('unknown method %s' % self.nb_method) 137 | 138 | total_len = int(inp_mask_flatten.sum()) 139 | span_len = region[1] - region[0] + 1 140 | 141 | global_state_dict.total_span_len = total_len 142 | global_state_dict.rel_span_len = span_len 143 | 144 | self.get_states(inp_flatten, inp_mask_flatten, segment_ids, [region], mask_regions) 145 | if self.gpu >= 0: 146 | input_ids, input_mask, segment_ids = input_ids.to(self.gpu), input_mask.to(self.gpu), segment_ids.to( 147 | self.gpu) 148 | if not args.task == 'tacred': 149 | score = predict_and_explain_wrapper_unbatched(self.model, input_ids, segment_ids, input_mask, region) 150 | else: 151 | score = predict_and_explain_wrapper_unbatched(self.model, input_ids, segment_ids, input_mask, region, 152 | normalizer=normalize_logit, label=label) 153 | return score 154 | -------------------------------------------------------------------------------- /hiexpl/bert/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.6.1" 2 | from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer 3 | from .tokenization_openai import OpenAIGPTTokenizer 4 | from .tokenization_transfo_xl import (TransfoXLTokenizer, TransfoXLCorpus) 5 | from .tokenization_gpt2 import GPT2Tokenizer 6 | 7 | from .modeling import (BertConfig, BertModel, BertForPreTraining, 8 | BertForMaskedLM, BertForNextSentencePrediction, 9 | BertForSequenceClassification, BertForMultipleChoice, 10 | BertForTokenClassification, BertForQuestionAnswering, 11 | load_tf_weights_in_bert) 12 | from .modeling_openai import (OpenAIGPTConfig, OpenAIGPTModel, 13 | OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, 14 | load_tf_weights_in_openai_gpt) 15 | from .modeling_transfo_xl import (TransfoXLConfig, TransfoXLModel, TransfoXLLMHeadModel, 16 | load_tf_weights_in_transfo_xl) 17 | from .modeling_gpt2 import 
(GPT2Config, GPT2Model, 18 | GPT2LMHeadModel, GPT2DoubleHeadsModel, 19 | load_tf_weights_in_gpt2) 20 | 21 | from .optimization import BertAdam 22 | from .optimization_openai import OpenAIAdam 23 | 24 | from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE, cached_path 25 | -------------------------------------------------------------------------------- /hiexpl/bert/__main__.py: -------------------------------------------------------------------------------- 1 | # coding: utf8 2 | def main(): 3 | import sys 4 | if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [ 5 | "convert_tf_checkpoint_to_pytorch", 6 | "convert_openai_checkpoint", 7 | "convert_transfo_xl_checkpoint", 8 | "convert_gpt2_checkpoint", 9 | ]: 10 | print( 11 | "Should be used as one of: \n" 12 | ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n" 13 | ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n" 14 | ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n" 15 | ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`") 16 | else: 17 | if sys.argv[1] == "convert_tf_checkpoint_to_pytorch": 18 | try: 19 | from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch 20 | except ImportError: 21 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 22 | "In that case, it requires TensorFlow to be installed. Please see " 23 | "https://www.tensorflow.org/install/ for installation instructions.") 24 | raise 25 | 26 | if len(sys.argv) != 5: 27 | # pylint: disable=line-too-long 28 | print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`") 29 | else: 30 | PYTORCH_DUMP_OUTPUT = sys.argv.pop() 31 | TF_CONFIG = sys.argv.pop() 32 | TF_CHECKPOINT = sys.argv.pop() 33 | convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 34 | elif sys.argv[1] == "convert_openai_checkpoint": 35 | from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch 36 | OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2] 37 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 38 | if len(sys.argv) == 5: 39 | OPENAI_GPT_CONFIG = sys.argv[4] 40 | else: 41 | OPENAI_GPT_CONFIG = "" 42 | convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH, 43 | OPENAI_GPT_CONFIG, 44 | PYTORCH_DUMP_OUTPUT) 45 | elif sys.argv[1] == "convert_transfo_xl_checkpoint": 46 | try: 47 | from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch 48 | except ImportError: 49 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 50 | "In that case, it requires TensorFlow to be installed. 
Please see " 51 | "https://www.tensorflow.org/install/ for installation instructions.") 52 | raise 53 | 54 | if 'ckpt' in sys.argv[2].lower(): 55 | TF_CHECKPOINT = sys.argv[2] 56 | TF_DATASET_FILE = "" 57 | else: 58 | TF_DATASET_FILE = sys.argv[2] 59 | TF_CHECKPOINT = "" 60 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 61 | if len(sys.argv) == 5: 62 | TF_CONFIG = sys.argv[4] 63 | else: 64 | TF_CONFIG = "" 65 | convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE) 66 | else: 67 | try: 68 | from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch 69 | except ImportError: 70 | print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, " 71 | "In that case, it requires TensorFlow to be installed. Please see " 72 | "https://www.tensorflow.org/install/ for installation instructions.") 73 | raise 74 | 75 | TF_CHECKPOINT = sys.argv[2] 76 | PYTORCH_DUMP_OUTPUT = sys.argv[3] 77 | if len(sys.argv) == 5: 78 | TF_CONFIG = sys.argv[4] 79 | else: 80 | TF_CONFIG = "" 81 | convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT) 82 | if __name__ == '__main__': 83 | main() 84 | -------------------------------------------------------------------------------- /hiexpl/bert/convert_gpt2_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_gpt2 import (CONFIG_NAME, WEIGHTS_NAME, 25 | GPT2Config, 26 | GPT2Model, 27 | load_tf_weights_in_gpt2) 28 | 29 | 30 | def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if gpt2_config_file == "": 33 | config = GPT2Config() 34 | else: 35 | config = GPT2Config(gpt2_config_file) 36 | model = GPT2Model(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_gpt2(model, gpt2_checkpoint_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--gpt2_checkpoint_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--gpt2_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_gpt2_checkpoint_to_pytorch(args.gpt2_checkpoint_path, 71 | args.gpt2_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /hiexpl/bert/convert_openai_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert OpenAI GPT checkpoint.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | from io import open 21 | 22 | import torch 23 | 24 | from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME, 25 | OpenAIGPTConfig, 26 | OpenAIGPTModel, 27 | load_tf_weights_in_openai_gpt) 28 | 29 | 30 | def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path): 31 | # Construct model 32 | if openai_config_file == "": 33 | config = OpenAIGPTConfig() 34 | else: 35 | config = OpenAIGPTConfig(openai_config_file) 36 | model = OpenAIGPTModel(config) 37 | 38 | # Load weights from numpy 39 | load_tf_weights_in_openai_gpt(model, openai_checkpoint_folder_path) 40 | 41 | # Save pytorch-model 42 | pytorch_weights_dump_path = pytorch_dump_folder_path + '/' + WEIGHTS_NAME 43 | pytorch_config_dump_path = pytorch_dump_folder_path + '/' + CONFIG_NAME 44 | print("Save PyTorch model to {}".format(pytorch_weights_dump_path)) 45 | torch.save(model.state_dict(), pytorch_weights_dump_path) 46 | print("Save configuration file to {}".format(pytorch_config_dump_path)) 47 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 48 | f.write(config.to_json_string()) 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | ## Required parameters 54 | parser.add_argument("--openai_checkpoint_folder_path", 55 | default = None, 56 | type = str, 57 | required = True, 58 | help = "Path the TensorFlow checkpoint path.") 59 | parser.add_argument("--pytorch_dump_folder_path", 60 | default = None, 61 | type = str, 62 | required = True, 63 | help = "Path to the output PyTorch model.") 64 | parser.add_argument("--openai_config_file", 65 | default = "", 66 | type = str, 67 | help = "An optional config json file corresponding to the pre-trained OpenAI model. \n" 68 | "This specifies the model architecture.") 69 | args = parser.parse_args() 70 | convert_openai_checkpoint_to_pytorch(args.openai_checkpoint_folder_path, 71 | args.openai_config_file, 72 | args.pytorch_dump_folder_path) 73 | -------------------------------------------------------------------------------- /hiexpl/bert/convert_tf_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | """Convert BERT checkpoint.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import os 22 | import re 23 | import argparse 24 | import tensorflow as tf 25 | import torch 26 | import numpy as np 27 | 28 | from pytorch_pretrained_bert.modeling import BertConfig, BertForPreTraining, load_tf_weights_in_bert 29 | 30 | def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path): 31 | # Initialise PyTorch model 32 | config = BertConfig.from_json_file(bert_config_file) 33 | print("Building PyTorch model from configuration: {}".format(str(config))) 34 | model = BertForPreTraining(config) 35 | 36 | # Load weights from tf checkpoint 37 | load_tf_weights_in_bert(model, tf_checkpoint_path) 38 | 39 | # Save pytorch-model 40 | print("Save PyTorch model to {}".format(pytorch_dump_path)) 41 | torch.save(model.state_dict(), pytorch_dump_path) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | ## Required parameters 47 | parser.add_argument("--tf_checkpoint_path", 48 | default = None, 49 | type = str, 50 | required = True, 51 | help = "Path the TensorFlow checkpoint path.") 52 | parser.add_argument("--bert_config_file", 53 | default = None, 54 | type = str, 55 | required = True, 56 | help = "The config json file corresponding to the pre-trained BERT model. \n" 57 | "This specifies the model architecture.") 58 | parser.add_argument("--pytorch_dump_path", 59 | default = None, 60 | type = str, 61 | required = True, 62 | help = "Path to the output PyTorch model.") 63 | args = parser.parse_args() 64 | convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, 65 | args.bert_config_file, 66 | args.pytorch_dump_path) 67 | -------------------------------------------------------------------------------- /hiexpl/bert/convert_transfo_xl_checkpoint_to_pytorch.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Convert Transformer XL checkpoint and datasets.""" 16 | 17 | from __future__ import absolute_import, division, print_function 18 | 19 | import argparse 20 | import os 21 | import sys 22 | from io import open 23 | 24 | import torch 25 | 26 | import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils 27 | from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME, 28 | WEIGHTS_NAME, 29 | TransfoXLConfig, 30 | TransfoXLLMHeadModel, 31 | load_tf_weights_in_transfo_xl) 32 | from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME, 33 | VOCAB_NAME) 34 | 35 | if sys.version_info[0] == 2: 36 | import cPickle as pickle 37 | else: 38 | import pickle 39 | 40 | # We do this to be able to load python 2 datasets pickles 41 | # See e.g. 
https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918 42 | data_utils.Vocab = data_utils.TransfoXLTokenizer 43 | data_utils.Corpus = data_utils.TransfoXLCorpus 44 | sys.modules['data_utils'] = data_utils 45 | sys.modules['vocabulary'] = data_utils 46 | 47 | def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path, 48 | transfo_xl_config_file, 49 | pytorch_dump_folder_path, 50 | transfo_xl_dataset_file): 51 | if transfo_xl_dataset_file: 52 | # Convert a pre-processed corpus (see original TensorFlow repo) 53 | with open(transfo_xl_dataset_file, "rb") as fp: 54 | corpus = pickle.load(fp, encoding="latin1") 55 | # Save vocabulary and dataset cache as Dictionaries (should be better than pickles for the long-term) 56 | pytorch_vocab_dump_path = pytorch_dump_folder_path + '/' + VOCAB_NAME 57 | print("Save vocabulary to {}".format(pytorch_vocab_dump_path)) 58 | corpus_vocab_dict = corpus.vocab.__dict__ 59 | torch.save(corpus_vocab_dict, pytorch_vocab_dump_path) 60 | 61 | corpus_dict_no_vocab = corpus.__dict__ 62 | corpus_dict_no_vocab.pop('vocab', None) 63 | pytorch_dataset_dump_path = pytorch_dump_folder_path + '/' + CORPUS_NAME 64 | print("Save dataset to {}".format(pytorch_dataset_dump_path)) 65 | torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path) 66 | 67 | if tf_checkpoint_path: 68 | # Convert a pre-trained TensorFlow model 69 | config_path = os.path.abspath(transfo_xl_config_file) 70 | tf_path = os.path.abspath(tf_checkpoint_path) 71 | 72 | print("Converting Transformer XL checkpoint from {} with config at {}".format(tf_path, config_path)) 73 | # Initialise PyTorch model 74 | if transfo_xl_config_file == "": 75 | config = TransfoXLConfig() 76 | else: 77 | config = TransfoXLConfig(transfo_xl_config_file) 78 | print("Building PyTorch model from configuration: {}".format(str(config))) 79 | model = TransfoXLLMHeadModel(config) 80 | 81 | model = load_tf_weights_in_transfo_xl(model, config, tf_path) 82 | # Save pytorch-model 83 | pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path, WEIGHTS_NAME) 84 | pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path, CONFIG_NAME) 85 | print("Save PyTorch model to {}".format(os.path.abspath(pytorch_weights_dump_path))) 86 | torch.save(model.state_dict(), pytorch_weights_dump_path) 87 | print("Save configuration file to {}".format(os.path.abspath(pytorch_config_dump_path))) 88 | with open(pytorch_config_dump_path, "w", encoding="utf-8") as f: 89 | f.write(config.to_json_string()) 90 | 91 | 92 | if __name__ == "__main__": 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--pytorch_dump_folder_path", 95 | default = None, 96 | type = str, 97 | required = True, 98 | help = "Path to the folder to store the PyTorch model or dataset/vocab.") 99 | parser.add_argument("--tf_checkpoint_path", 100 | default = "", 101 | type = str, 102 | help = "An optional path to a TensorFlow checkpoint path to be converted.") 103 | parser.add_argument("--transfo_xl_config_file", 104 | default = "", 105 | type = str, 106 | help = "An optional config json file corresponding to the pre-trained BERT model. 
\n" 107 | "This specifies the model architecture.") 108 | parser.add_argument("--transfo_xl_dataset_file", 109 | default = "", 110 | type = str, 111 | help = "An optional dataset file to be converted in a vocabulary.") 112 | args = parser.parse_args() 113 | convert_transfo_xl_checkpoint_to_pytorch(args.tf_checkpoint_path, 114 | args.transfo_xl_config_file, 115 | args.pytorch_dump_folder_path, 116 | args.transfo_xl_dataset_file) 117 | -------------------------------------------------------------------------------- /hiexpl/bert/decomp_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | -------------------------------------------------------------------------------- /hiexpl/bert/file_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for working with the local dataset cache. 3 | This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp 4 | Copyright by the AllenNLP authors. 5 | """ 6 | from __future__ import (absolute_import, division, print_function, unicode_literals) 7 | 8 | import json 9 | import logging 10 | import os 11 | import shutil 12 | import tempfile 13 | from functools import wraps 14 | from hashlib import sha256 15 | import sys 16 | from io import open 17 | 18 | import boto3 19 | import requests 20 | from botocore.exceptions import ClientError 21 | from tqdm import tqdm 22 | 23 | try: 24 | from urllib.parse import urlparse 25 | except ImportError: 26 | from urlparse import urlparse 27 | 28 | try: 29 | from pathlib import Path 30 | PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 31 | Path.home() / '.pytorch_pretrained_bert')) 32 | except (AttributeError, ImportError): 33 | PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE', 34 | os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert')) 35 | 36 | logger = logging.getLogger(__name__) # pylint: disable=invalid-name 37 | 38 | 39 | def url_to_filename(url, etag=None): 40 | """ 41 | Convert `url` into a hashed filename in a repeatable way. 42 | If `etag` is specified, append its hash to the url's, delimited 43 | by a period. 44 | """ 45 | url_bytes = url.encode('utf-8') 46 | url_hash = sha256(url_bytes) 47 | filename = url_hash.hexdigest() 48 | 49 | if etag: 50 | etag_bytes = etag.encode('utf-8') 51 | etag_hash = sha256(etag_bytes) 52 | filename += '.' + etag_hash.hexdigest() 53 | 54 | return filename 55 | 56 | 57 | def filename_to_url(filename, cache_dir=None): 58 | """ 59 | Return the url and etag (which may be ``None``) stored for `filename`. 60 | Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist. 
61 | """ 62 | if cache_dir is None: 63 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 64 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 65 | cache_dir = str(cache_dir) 66 | 67 | cache_path = os.path.join(cache_dir, filename) 68 | if not os.path.exists(cache_path): 69 | raise EnvironmentError("file {} not found".format(cache_path)) 70 | 71 | meta_path = cache_path + '.json' 72 | if not os.path.exists(meta_path): 73 | raise EnvironmentError("file {} not found".format(meta_path)) 74 | 75 | with open(meta_path, encoding="utf-8") as meta_file: 76 | metadata = json.load(meta_file) 77 | url = metadata['url'] 78 | etag = metadata['etag'] 79 | 80 | return url, etag 81 | 82 | 83 | def cached_path(url_or_filename, cache_dir=None): 84 | """ 85 | Given something that might be a URL (or might be a local path), 86 | determine which. If it's a URL, download the file and cache it, and 87 | return the path to the cached file. If it's already a local path, 88 | make sure the file exists and then return the path. 89 | """ 90 | if cache_dir is None: 91 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 92 | if sys.version_info[0] == 3 and isinstance(url_or_filename, Path): 93 | url_or_filename = str(url_or_filename) 94 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 95 | cache_dir = str(cache_dir) 96 | 97 | parsed = urlparse(url_or_filename) 98 | 99 | if parsed.scheme in ('http', 'https', 's3'): 100 | # URL, so get it from the cache (downloading if necessary) 101 | return get_from_cache(url_or_filename, cache_dir) 102 | elif os.path.exists(url_or_filename): 103 | # File, and it exists. 104 | return url_or_filename 105 | elif parsed.scheme == '': 106 | # File, but it doesn't exist. 107 | raise EnvironmentError("file {} not found".format(url_or_filename)) 108 | else: 109 | # Something unknown 110 | raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename)) 111 | 112 | 113 | def split_s3_path(url): 114 | """Split a full s3 path into the bucket name and path.""" 115 | parsed = urlparse(url) 116 | if not parsed.netloc or not parsed.path: 117 | raise ValueError("bad s3 path {}".format(url)) 118 | bucket_name = parsed.netloc 119 | s3_path = parsed.path 120 | # Remove '/' at beginning of path. 121 | if s3_path.startswith("/"): 122 | s3_path = s3_path[1:] 123 | return bucket_name, s3_path 124 | 125 | 126 | def s3_request(func): 127 | """ 128 | Wrapper function for s3 requests in order to create more helpful error 129 | messages. 
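For example (illustrative bucket/key), a 404 from S3 surfaces as ``EnvironmentError("file s3://bucket/key not found")``
instead of a raw ``botocore`` ``ClientError``.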
130 | """ 131 | 132 | @wraps(func) 133 | def wrapper(url, *args, **kwargs): 134 | try: 135 | return func(url, *args, **kwargs) 136 | except ClientError as exc: 137 | if int(exc.response["Error"]["Code"]) == 404: 138 | raise EnvironmentError("file {} not found".format(url)) 139 | else: 140 | raise 141 | 142 | return wrapper 143 | 144 | 145 | @s3_request 146 | def s3_etag(url): 147 | """Check ETag on S3 object.""" 148 | s3_resource = boto3.resource("s3") 149 | bucket_name, s3_path = split_s3_path(url) 150 | s3_object = s3_resource.Object(bucket_name, s3_path) 151 | return s3_object.e_tag 152 | 153 | 154 | @s3_request 155 | def s3_get(url, temp_file): 156 | """Pull a file directly from S3.""" 157 | s3_resource = boto3.resource("s3") 158 | bucket_name, s3_path = split_s3_path(url) 159 | s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file) 160 | 161 | 162 | def http_get(url, temp_file): 163 | req = requests.get(url, stream=True) 164 | content_length = req.headers.get('Content-Length') 165 | total = int(content_length) if content_length is not None else None 166 | progress = tqdm(unit="B", total=total) 167 | for chunk in req.iter_content(chunk_size=1024): 168 | if chunk: # filter out keep-alive new chunks 169 | progress.update(len(chunk)) 170 | temp_file.write(chunk) 171 | progress.close() 172 | 173 | 174 | def get_from_cache(url, cache_dir=None): 175 | """ 176 | Given a URL, look for the corresponding dataset in the local cache. 177 | If it's not there, download it. Then return the path to the cached file. 178 | """ 179 | if cache_dir is None: 180 | cache_dir = PYTORCH_PRETRAINED_BERT_CACHE 181 | if sys.version_info[0] == 3 and isinstance(cache_dir, Path): 182 | cache_dir = str(cache_dir) 183 | 184 | if not os.path.exists(cache_dir): 185 | os.makedirs(cache_dir) 186 | 187 | # Get eTag to add to filename, if it exists. 188 | if url.startswith("s3://"): 189 | etag = s3_etag(url) 190 | else: 191 | etag_file_path = os.path.join(cache_dir, 'etag.json') 192 | etag = '' 193 | if os.path.isfile(etag_file_path): 194 | f = open(etag_file_path) 195 | etag_dict = json.load(f) 196 | etag = etag_dict.get(url,'') 197 | if not etag: 198 | response = requests.head(url, allow_redirects=True) 199 | if response.status_code != 200: 200 | raise IOError("HEAD request failed for url {} with status code {}" 201 | .format(url, response.status_code)) 202 | etag = response.headers.get("ETag") 203 | if not os.path.isfile(etag_file_path): 204 | etag_dict = {} 205 | else: 206 | f = open(etag_file_path, 'r') 207 | etag_dict = json.load(f) 208 | f.close() 209 | f = open(etag_file_path, 'w') 210 | etag_dict[url] = etag 211 | json.dump(etag_dict, f) 212 | f.close() 213 | 214 | filename = url_to_filename(url, etag) 215 | 216 | # get cache path to put the file 217 | cache_path = os.path.join(cache_dir, filename) 218 | 219 | if not os.path.exists(cache_path): 220 | # Download to temporary file, then copy to cache dir once finished. 221 | # Otherwise you get corrupt cache entries if the download gets interrupted. 
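# Illustrative cache layout produced by the code below (hash values hypothetical, shortened):
#   <cache_dir>/<sha256(url)>.<sha256(etag)>        <- the downloaded file
#   <cache_dir>/<sha256(url)>.<sha256(etag)>.json   <- metadata: {"url": ..., "etag": ...}
# url_to_filename() above builds the hashed name; filename_to_url() reads the .json metadata back.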
222 | with tempfile.NamedTemporaryFile() as temp_file: 223 | logger.info("%s not found in cache, downloading to %s", url, temp_file.name) 224 | 225 | # GET file object 226 | if url.startswith("s3://"): 227 | s3_get(url, temp_file) 228 | else: 229 | http_get(url, temp_file) 230 | 231 | # we are copying the file before closing it, so flush to avoid truncation 232 | temp_file.flush() 233 | # shutil.copyfileobj() starts at the current position, so go to the start 234 | temp_file.seek(0) 235 | 236 | logger.info("copying %s to cache at %s", temp_file.name, cache_path) 237 | with open(cache_path, 'wb') as cache_file: 238 | shutil.copyfileobj(temp_file, cache_file) 239 | 240 | logger.info("creating metadata file for %s", cache_path) 241 | meta = {'url': url, 'etag': etag} 242 | meta_path = cache_path + '.json' 243 | with open(meta_path, 'w', encoding="utf-8") as meta_file: 244 | json.dump(meta, meta_file) 245 | 246 | logger.info("removing temp file %s", temp_file.name) 247 | 248 | return cache_path 249 | 250 | 251 | def read_set_from_file(filename): 252 | ''' 253 | Extract a de-duped collection (set) of text from a file. 254 | Expected file format is one item per line. 255 | ''' 256 | collection = set() 257 | with open(filename, 'r', encoding='utf-8') as file_: 258 | for line in file_: 259 | collection.add(line.rstrip()) 260 | return collection 261 | 262 | 263 | def get_file_extension(path, dot=True, lower=True): 264 | ext = os.path.splitext(path)[1] 265 | ext = ext if dot else ext[1:] 266 | return ext.lower() if lower else ext 267 | -------------------------------------------------------------------------------- /hiexpl/bert/filter_sentence.py: -------------------------------------------------------------------------------- 1 | all_sent_file = 'glue_data/SST-2/original/datasetSentences.txt' 2 | train_file = 'glue_data/SST-2/train_all.tsv' 3 | 4 | f1 = open(all_sent_file) 5 | f2 = open(train_file) 6 | fw = open('glue_data/SST-2/train.tsv', 'w') 7 | 8 | lines1, lines2 = f1.readlines(), f2.readlines() 9 | fw.write(lines1[0]) 10 | 11 | hash_set = set() 12 | for line in lines1: 13 | sent = line.split('\t')[-1] 14 | hash_set.add(sent.strip().lower()) 15 | 16 | for line in lines2: 17 | sent = line.split('\t')[0] 18 | if sent.strip() in hash_set: 19 | fw.write(line) 20 | 21 | -------------------------------------------------------------------------------- /hiexpl/bert/global_state.py: -------------------------------------------------------------------------------- 1 | class _GlobalStateDict: 2 | def __init__(self): 3 | self.states = [] 4 | self.current_layer_id = 0 5 | self.store_flag = True 6 | self.activated = False 7 | 8 | self.rel_span_len = 0 9 | self.total_span_len = 1 10 | 11 | def store_state(self, value): 12 | if self.activated: 13 | self.states.append(value) 14 | self.current_layer_id += 1 15 | 16 | def get_states(self): 17 | states = self.states[self.current_layer_id] # [B * H] 18 | states = states.split(1, 0) 19 | self.current_layer_id += 1 20 | return states 21 | 22 | def init_store_states(self): 23 | self.activated = True 24 | self.states = [] 25 | self.current_layer_id = 0 26 | self.store_flag = True 27 | 28 | def init_fetch_states(self): 29 | self.current_layer_id = 0 30 | self.store_flag = False 31 | 32 | global_state_dict = _GlobalStateDict() -------------------------------------------------------------------------------- /hiexpl/bert/optimization.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Google AI 
Language Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for BERT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to BertAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to BertAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine': warmup_cosine, 47 | 'warmup_constant': warmup_constant, 48 | 'warmup_linear': warmup_linear, 49 | } 50 | 51 | 52 | class BertAdam(Optimizer): 53 | """Implements BERT version of Adam algorithm with weight decay fix. 54 | Params: 55 | lr: learning rate 56 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1 57 | t_total: total number of training steps for the learning 58 | rate schedule, -1 means constant learning rate. Default: -1 59 | schedule: schedule to use for the warmup (see above). Default: 'warmup_linear' 60 | b1: Adams b1. Default: 0.9 61 | b2: Adams b2. Default: 0.999 62 | e: Adams epsilon. Default: 1e-6 63 | weight_decay: Weight decay. Default: 0.01 64 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). 
Default: 1.0 65 | """ 66 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear', 67 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, 68 | max_grad_norm=1.0): 69 | if lr is not required and lr < 0.0: 70 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 71 | if schedule not in SCHEDULES: 72 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 73 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 74 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 75 | if not 0.0 <= b1 < 1.0: 76 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1)) 77 | if not 0.0 <= b2 < 1.0: 78 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2)) 79 | if not e >= 0.0: 80 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e)) 81 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 82 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, 83 | max_grad_norm=max_grad_norm) 84 | super(BertAdam, self).__init__(params, defaults) 85 | 86 | def get_lr(self): 87 | lr = [] 88 | for group in self.param_groups: 89 | for p in group['params']: 90 | state = self.state[p] 91 | if len(state) == 0: 92 | return [0] 93 | if group['t_total'] != -1: 94 | schedule_fct = SCHEDULES[group['schedule']] 95 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 96 | else: 97 | lr_scheduled = group['lr'] 98 | lr.append(lr_scheduled) 99 | return lr 100 | 101 | def step(self, closure=None): 102 | """Performs a single optimization step. 103 | 104 | Arguments: 105 | closure (callable, optional): A closure that reevaluates the model 106 | and returns the loss. 107 | """ 108 | loss = None 109 | if closure is not None: 110 | loss = closure() 111 | 112 | warned_for_t_total = False 113 | 114 | for group in self.param_groups: 115 | for p in group['params']: 116 | if p.grad is None: 117 | continue 118 | grad = p.grad.data 119 | if grad.is_sparse: 120 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 121 | 122 | state = self.state[p] 123 | 124 | # State initialization 125 | if len(state) == 0: 126 | state['step'] = 0 127 | # Exponential moving average of gradient values 128 | state['next_m'] = torch.zeros_like(p.data) 129 | # Exponential moving average of squared gradient values 130 | state['next_v'] = torch.zeros_like(p.data) 131 | 132 | next_m, next_v = state['next_m'], state['next_v'] 133 | beta1, beta2 = group['b1'], group['b2'] 134 | 135 | # Add grad clipping 136 | if group['max_grad_norm'] > 0: 137 | clip_grad_norm_(p, group['max_grad_norm']) 138 | 139 | # Decay the first and second moment running average coefficient 140 | # In-place operations to update the averages at the same time 141 | next_m.mul_(beta1).add_(1 - beta1, grad) 142 | next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad) 143 | update = next_m / (next_v.sqrt() + group['e']) 144 | 145 | # Just adding the square of the weights to the loss function is *not* 146 | # the correct way of using L2 regularization/weight decay with Adam, 147 | # since that will interact with the m and v parameters in strange ways. 148 | # 149 | # Instead we want to decay the weights in a manner that doesn't interact 150 | # with the m/v parameters. This is equivalent to adding the square 151 | # of the weights to the loss with plain (non-momentum) SGD. 
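# Sketch of the effective update applied below (note BertAdam applies no bias correction):
#   update = next_m / (sqrt(next_v) + e) + weight_decay * p
#   p      = p - lr_scheduled * update
# i.e. the decay term acts on the parameter directly, decoupled from the Adam moment estimates.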
152 | if group['weight_decay'] > 0.0: 153 | update += group['weight_decay'] * p.data 154 | 155 | if group['t_total'] != -1: 156 | schedule_fct = SCHEDULES[group['schedule']] 157 | progress = state['step']/group['t_total'] 158 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 159 | # warning for exceeding t_total (only active with warmup_linear 160 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 161 | logger.warning( 162 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. " 163 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 164 | warned_for_t_total = True 165 | # end warning 166 | else: 167 | lr_scheduled = group['lr'] 168 | 169 | update_with_lr = lr_scheduled * update 170 | p.data.add_(-update_with_lr) 171 | 172 | state['step'] += 1 173 | 174 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 175 | # No bias correction 176 | # bias_correction1 = 1 - beta1 ** state['step'] 177 | # bias_correction2 = 1 - beta2 ** state['step'] 178 | 179 | return loss 180 | -------------------------------------------------------------------------------- /hiexpl/bert/optimization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """PyTorch optimization for OpenAI GPT model.""" 16 | 17 | import math 18 | import torch 19 | from torch.optim import Optimizer 20 | from torch.optim.optimizer import required 21 | from torch.nn.utils import clip_grad_norm_ 22 | import logging 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | def warmup_cosine(x, warmup=0.002): 27 | if x < warmup: 28 | return x/warmup 29 | return 0.5 * (1.0 + torch.cos(math.pi * x)) 30 | 31 | def warmup_constant(x, warmup=0.002): 32 | """ Linearly increases learning rate over `warmup`*`t_total` (as provided to OpenAIAdam) training steps. 33 | Learning rate is 1. afterwards. """ 34 | if x < warmup: 35 | return x/warmup 36 | return 1.0 37 | 38 | def warmup_linear(x, warmup=0.002): 39 | """ Specifies a triangular learning rate schedule where peak is reached at `warmup`*`t_total`-th (as provided to OpenAIAdam) training step. 40 | After `t_total`-th training step, learning rate is zero. """ 41 | if x < warmup: 42 | return x/warmup 43 | return max((x-1.)/(warmup-1.), 0) 44 | 45 | SCHEDULES = { 46 | 'warmup_cosine':warmup_cosine, 47 | 'warmup_constant':warmup_constant, 48 | 'warmup_linear':warmup_linear, 49 | } 50 | 51 | 52 | class OpenAIAdam(Optimizer): 53 | """Implements Open AI version of Adam algorithm with weight decay fix. 
54 | """ 55 | def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1, 56 | b1=0.9, b2=0.999, e=1e-8, weight_decay=0, 57 | vector_l2=False, max_grad_norm=-1, **kwargs): 58 | if lr is not required and lr < 0.0: 59 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr)) 60 | if schedule not in SCHEDULES: 61 | raise ValueError("Invalid schedule parameter: {}".format(schedule)) 62 | if not 0.0 <= warmup < 1.0 and not warmup == -1: 63 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup)) 64 | if not 0.0 <= b1 < 1.0: 65 | raise ValueError("Invalid b1 parameter: {}".format(b1)) 66 | if not 0.0 <= b2 < 1.0: 67 | raise ValueError("Invalid b2 parameter: {}".format(b2)) 68 | if not e >= 0.0: 69 | raise ValueError("Invalid epsilon value: {}".format(e)) 70 | defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total, 71 | b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2, 72 | max_grad_norm=max_grad_norm) 73 | super(OpenAIAdam, self).__init__(params, defaults) 74 | 75 | def get_lr(self): 76 | lr = [] 77 | for group in self.param_groups: 78 | for p in group['params']: 79 | state = self.state[p] 80 | if len(state) == 0: 81 | return [0] 82 | if group['t_total'] != -1: 83 | schedule_fct = SCHEDULES[group['schedule']] 84 | lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup']) 85 | else: 86 | lr_scheduled = group['lr'] 87 | lr.append(lr_scheduled) 88 | return lr 89 | 90 | def step(self, closure=None): 91 | """Performs a single optimization step. 92 | 93 | Arguments: 94 | closure (callable, optional): A closure that reevaluates the model 95 | and returns the loss. 96 | """ 97 | loss = None 98 | if closure is not None: 99 | loss = closure() 100 | 101 | warned_for_t_total = False 102 | 103 | for group in self.param_groups: 104 | for p in group['params']: 105 | if p.grad is None: 106 | continue 107 | grad = p.grad.data 108 | if grad.is_sparse: 109 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead') 110 | 111 | state = self.state[p] 112 | 113 | # State initialization 114 | if len(state) == 0: 115 | state['step'] = 0 116 | # Exponential moving average of gradient values 117 | state['exp_avg'] = torch.zeros_like(p.data) 118 | # Exponential moving average of squared gradient values 119 | state['exp_avg_sq'] = torch.zeros_like(p.data) 120 | 121 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 122 | beta1, beta2 = group['b1'], group['b2'] 123 | 124 | state['step'] += 1 125 | 126 | # Add grad clipping 127 | if group['max_grad_norm'] > 0: 128 | clip_grad_norm_(p, group['max_grad_norm']) 129 | 130 | # Decay the first and second moment running average coefficient 131 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 132 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 133 | denom = exp_avg_sq.sqrt().add_(group['e']) 134 | 135 | bias_correction1 = 1 - beta1 ** state['step'] 136 | bias_correction2 = 1 - beta2 ** state['step'] 137 | 138 | if group['t_total'] != -1: 139 | schedule_fct = SCHEDULES[group['schedule']] 140 | progress = state['step']/group['t_total'] 141 | lr_scheduled = group['lr'] * schedule_fct(progress, group['warmup']) 142 | # warning for exceeding t_total (only active with warmup_linear 143 | if group['schedule'] == "warmup_linear" and progress > 1. and not warned_for_t_total: 144 | logger.warning( 145 | "Training beyond specified 't_total' steps with schedule '{}'. Learning rate set to {}. 
" 146 | "Please set 't_total' of {} correctly.".format(group['schedule'], lr_scheduled, self.__class__.__name__)) 147 | warned_for_t_total = True 148 | # end warning 149 | else: 150 | lr_scheduled = group['lr'] 151 | 152 | step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1 153 | 154 | p.data.addcdiv_(-step_size, exp_avg, denom) 155 | 156 | # Add weight decay at the end (fixed version) 157 | if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0: 158 | p.data.add_(-lr_scheduled * group['weight_decay'], p.data) 159 | 160 | return loss 161 | -------------------------------------------------------------------------------- /hiexpl/bert/tacred_f1.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The Board of Trustees of The Leland Stanford Junior University 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from collections import Counter 16 | import sys 17 | NO_RELATION = 0 18 | 19 | def score(key, prediction, verbose=False): 20 | correct_by_relation = Counter() 21 | guessed_by_relation = Counter() 22 | gold_by_relation = Counter() 23 | 24 | # Loop over the data to compute a score 25 | for row in range(len(key)): 26 | gold = key[row] 27 | guess = prediction[row] 28 | 29 | if gold == NO_RELATION and guess == NO_RELATION: 30 | pass 31 | elif gold == NO_RELATION and guess != NO_RELATION: 32 | guessed_by_relation[guess] += 1 33 | elif gold != NO_RELATION and guess == NO_RELATION: 34 | gold_by_relation[gold] += 1 35 | elif gold != NO_RELATION and guess != NO_RELATION: 36 | guessed_by_relation[guess] += 1 37 | gold_by_relation[gold] += 1 38 | if gold == guess: 39 | correct_by_relation[guess] += 1 40 | 41 | # Print verbose information 42 | if verbose: 43 | print("Per-relation statistics:") 44 | relations = gold_by_relation.keys() 45 | longest_relation = 0 46 | for relation in sorted(relations): 47 | longest_relation = max(len(relation), longest_relation) 48 | for relation in sorted(relations): 49 | # (compute the score) 50 | correct = correct_by_relation[relation] 51 | guessed = guessed_by_relation[relation] 52 | gold = gold_by_relation[relation] 53 | prec = 1.0 54 | if guessed > 0: 55 | prec = float(correct) / float(guessed) 56 | recall = 0.0 57 | if gold > 0: 58 | recall = float(correct) / float(gold) 59 | f1 = 0.0 60 | if prec + recall > 0: 61 | f1 = 2.0 * prec * recall / (prec + recall) 62 | # (print the score) 63 | sys.stdout.write(("{:<" + str(longest_relation) + "}").format(relation)) 64 | sys.stdout.write(" P: ") 65 | if prec < 0.1: sys.stdout.write(' ') 66 | if prec < 1.0: sys.stdout.write(' ') 67 | sys.stdout.write("{:.2%}".format(prec)) 68 | sys.stdout.write(" R: ") 69 | if recall < 0.1: sys.stdout.write(' ') 70 | if recall < 1.0: sys.stdout.write(' ') 71 | sys.stdout.write("{:.2%}".format(recall)) 72 | sys.stdout.write(" F1: ") 73 | if f1 < 0.1: sys.stdout.write(' ') 74 | if f1 < 1.0: sys.stdout.write(' ') 75 | sys.stdout.write("{:.2%}".format(f1)) 
76 | sys.stdout.write(" #: %d" % gold) 77 | sys.stdout.write("\n") 78 | print("") 79 | 80 | # Print the aggregate score 81 | if verbose: 82 | print("Final Score:") 83 | prec_micro = 1.0 84 | if sum(guessed_by_relation.values()) > 0: 85 | prec_micro = float(sum(correct_by_relation.values())) / float(sum(guessed_by_relation.values())) 86 | recall_micro = 0.0 87 | if sum(gold_by_relation.values()) > 0: 88 | recall_micro = float(sum(correct_by_relation.values())) / float(sum(gold_by_relation.values())) 89 | f1_micro = 0.0 90 | if prec_micro + recall_micro > 0.0: 91 | f1_micro = 2.0 * prec_micro * recall_micro / (prec_micro + recall_micro) 92 | #print("Precision (micro): {:.3%}".format(prec_micro)) 93 | #print(" Recall (micro): {:.3%}".format(recall_micro)) 94 | #print(" F1 (micro): {:.3%}".format(f1_micro)) 95 | return prec_micro, recall_micro, f1_micro -------------------------------------------------------------------------------- /hiexpl/bert/tokenization_gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import regex as re 23 | from io import open 24 | 25 | try: 26 | from functools import lru_cache 27 | except ImportError: 28 | # Just a dummy decorator to get the checks to run on python2 29 | # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. 30 | def lru_cache(): 31 | return lambda func: func 32 | 33 | from .file_utils import cached_path 34 | 35 | logger = logging.getLogger(__name__) 36 | 37 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 38 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json", 39 | } 40 | PRETRAINED_MERGES_ARCHIVE_MAP = { 41 | 'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt", 42 | } 43 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 44 | 'gpt2': 1024, 45 | } 46 | VOCAB_NAME = 'vocab.json' 47 | MERGES_NAME = 'merges.txt' 48 | 49 | @lru_cache() 50 | def bytes_to_unicode(): 51 | """ 52 | Returns list of utf-8 byte and a corresponding list of unicode strings. 53 | The reversible bpe codes work on unicode strings. 54 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 55 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 56 | This is a signficant percentage of your normal, say, 32K bpe vocab. 57 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 58 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
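For example, printable ASCII bytes map to themselves ('a' -> 'a'), while the space byte (0x20)
is shifted into the printable range as chr(256 + 32) = 'Ġ', which is why word-initial GPT-2
tokens appear with a leading 'Ġ' in the vocabulary.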
59 | """ 60 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 61 | cs = bs[:] 62 | n = 0 63 | for b in range(2**8): 64 | if b not in bs: 65 | bs.append(b) 66 | cs.append(2**8+n) 67 | n += 1 68 | cs = [chr(n) for n in cs] 69 | return dict(zip(bs, cs)) 70 | 71 | def get_pairs(word): 72 | """Return set of symbol pairs in a word. 73 | 74 | Word is represented as tuple of symbols (symbols being variable-length strings). 75 | """ 76 | pairs = set() 77 | prev_char = word[0] 78 | for char in word[1:]: 79 | pairs.add((prev_char, char)) 80 | prev_char = char 81 | return pairs 82 | 83 | class GPT2Tokenizer(object): 84 | """ 85 | GPT-2 BPE tokenizer. Peculiarities: 86 | - Byte-level BPE 87 | """ 88 | @classmethod 89 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 90 | """ 91 | Instantiate a PreTrainedBertModel from a pre-trained model file. 92 | Download and cache the pre-trained model file if needed. 93 | """ 94 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 95 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 96 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 97 | else: 98 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 99 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 100 | # redirect to the cache, if necessary 101 | try: 102 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 103 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 104 | except EnvironmentError: 105 | logger.error( 106 | "Model name '{}' was not found in model name list ({}). " 107 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 108 | "at this path or url.".format( 109 | pretrained_model_name_or_path, 110 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 111 | pretrained_model_name_or_path, 112 | vocab_file, merges_file)) 113 | return None 114 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 115 | logger.info("loading vocabulary file {}".format(vocab_file)) 116 | logger.info("loading merges file {}".format(merges_file)) 117 | else: 118 | logger.info("loading vocabulary file {} from cache at {}".format( 119 | vocab_file, resolved_vocab_file)) 120 | logger.info("loading merges file {} from cache at {}".format( 121 | merges_file, resolved_merges_file)) 122 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 123 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 124 | # than the number of positional embeddings 125 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 126 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 127 | # Instantiate tokenizer. 
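# Typical call (illustrative): GPT2Tokenizer.from_pretrained('gpt2') resolves the vocab/merges URLs
# from the archive maps above, downloads them through cached_path(), and caps max_len at 1024 via
# PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP.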
128 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 129 | return tokenizer 130 | 131 | def __init__(self, vocab_file, merges_file, errors='replace', max_len=None): 132 | self.max_len = max_len if max_len is not None else int(1e12) 133 | self.encoder = json.load(open(vocab_file)) 134 | self.decoder = {v:k for k,v in self.encoder.items()} 135 | self.errors = errors # how to handle errors in decoding 136 | self.byte_encoder = bytes_to_unicode() 137 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 138 | bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | bpe_merges = [tuple(merge.split()) for merge in bpe_data] 140 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 141 | self.cache = {} 142 | 143 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions 144 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 145 | 146 | def __len__(self): 147 | return len(self.encoder) 148 | 149 | def bpe(self, token): 150 | if token in self.cache: 151 | return self.cache[token] 152 | word = tuple(token) 153 | pairs = get_pairs(word) 154 | 155 | if not pairs: 156 | return token 157 | 158 | while True: 159 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 160 | if bigram not in self.bpe_ranks: 161 | break 162 | first, second = bigram 163 | new_word = [] 164 | i = 0 165 | while i < len(word): 166 | try: 167 | j = word.index(first, i) 168 | new_word.extend(word[i:j]) 169 | i = j 170 | except: 171 | new_word.extend(word[i:]) 172 | break 173 | 174 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 175 | new_word.append(first+second) 176 | i += 2 177 | else: 178 | new_word.append(word[i]) 179 | i += 1 180 | new_word = tuple(new_word) 181 | word = new_word 182 | if len(word) == 1: 183 | break 184 | else: 185 | pairs = get_pairs(word) 186 | word = ' '.join(word) 187 | self.cache[token] = word 188 | return word 189 | 190 | def encode(self, text): 191 | bpe_tokens = [] 192 | for token in re.findall(self.pat, text): 193 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 194 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 195 | if len(bpe_tokens) > self.max_len: 196 | logger.warning( 197 | "Token indices sequence length is longer than the specified maximum " 198 | " sequence length for this OpenAI GPT-2 model ({} > {}). Running this" 199 | " sequence through the model will result in indexing errors".format(len(bpe_tokens), self.max_len) 200 | ) 201 | return bpe_tokens 202 | 203 | def decode(self, tokens): 204 | text = ''.join([self.decoder[token] for token in tokens]) 205 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) 206 | return text 207 | -------------------------------------------------------------------------------- /hiexpl/bert/tokenization_openai.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """Tokenization classes for OpenAI GPT.""" 16 | from __future__ import (absolute_import, division, print_function, 17 | unicode_literals) 18 | 19 | import json 20 | import logging 21 | import os 22 | import re 23 | import sys 24 | from io import open 25 | 26 | from tqdm import tqdm 27 | 28 | from .file_utils import cached_path 29 | from .tokenization import BasicTokenizer 30 | 31 | logger = logging.getLogger(__name__) 32 | 33 | PRETRAINED_VOCAB_ARCHIVE_MAP = { 34 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json", 35 | } 36 | PRETRAINED_MERGES_ARCHIVE_MAP = { 37 | 'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt", 38 | } 39 | PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = { 40 | 'openai-gpt': 512, 41 | } 42 | VOCAB_NAME = 'vocab.json' 43 | MERGES_NAME = 'merges.txt' 44 | 45 | def get_pairs(word): 46 | """ 47 | Return set of symbol pairs in a word. 48 | word is represented as tuple of symbols (symbols being variable-length strings) 49 | """ 50 | pairs = set() 51 | prev_char = word[0] 52 | for char in word[1:]: 53 | pairs.add((prev_char, char)) 54 | prev_char = char 55 | return pairs 56 | 57 | def text_standardize(text): 58 | """ 59 | fixes some issues the spacy tokenizer had on books corpus 60 | also does some whitespace standardization 61 | """ 62 | text = text.replace('—', '-') 63 | text = text.replace('–', '-') 64 | text = text.replace('―', '-') 65 | text = text.replace('…', '...') 66 | text = text.replace('´', "'") 67 | text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text) 68 | text = re.sub(r'\s*\n\s*', ' \n ', text) 69 | text = re.sub(r'[^\S\n]+', ' ', text) 70 | return text.strip() 71 | 72 | class OpenAIGPTTokenizer(object): 73 | """ 74 | BPE tokenizer. Peculiarities: 75 | - lower case all inputs 76 | - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, fallback to BERT's BasicTokenizer if not. 77 | - argument special_tokens and function set_special_tokens: 78 | can be used to add additional symbols (ex: "__classify__") to a vocabulary. 79 | """ 80 | @classmethod 81 | def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs): 82 | """ 83 | Instantiate a PreTrainedBertModel from a pre-trained model file. 84 | Download and cache the pre-trained model file if needed. 
85 | """ 86 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP: 87 | vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path] 88 | merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path] 89 | else: 90 | vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME) 91 | merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME) 92 | # redirect to the cache, if necessary 93 | try: 94 | resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir) 95 | resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir) 96 | except EnvironmentError: 97 | logger.error( 98 | "Model name '{}' was not found in model name list ({}). " 99 | "We assumed '{}' was a path or url but couldn't find files {} and {} " 100 | "at this path or url.".format( 101 | pretrained_model_name_or_path, 102 | ', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()), 103 | pretrained_model_name_or_path, 104 | vocab_file, merges_file)) 105 | return None 106 | if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file: 107 | logger.info("loading vocabulary file {}".format(vocab_file)) 108 | logger.info("loading merges file {}".format(merges_file)) 109 | else: 110 | logger.info("loading vocabulary file {} from cache at {}".format( 111 | vocab_file, resolved_vocab_file)) 112 | logger.info("loading merges file {} from cache at {}".format( 113 | merges_file, resolved_merges_file)) 114 | if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP: 115 | # if we're using a pretrained model, ensure the tokenizer wont index sequences longer 116 | # than the number of positional embeddings 117 | max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path] 118 | kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) 119 | # Instantiate tokenizer. 120 | tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs) 121 | return tokenizer 122 | 123 | def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None): 124 | try: 125 | import ftfy 126 | import spacy 127 | self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat']) 128 | self.fix_text = ftfy.fix_text 129 | except ImportError: 130 | logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.") 131 | self.nlp = BasicTokenizer(do_lower_case=True, 132 | never_split=special_tokens if special_tokens is not None else []) 133 | self.fix_text = None 134 | 135 | self.max_len = max_len if max_len is not None else int(1e12) 136 | self.encoder = json.load(open(vocab_file, encoding="utf-8")) 137 | self.decoder = {v:k for k,v in self.encoder.items()} 138 | merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] 139 | merges = [tuple(merge.split()) for merge in merges] 140 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 141 | self.cache = {} 142 | self.set_special_tokens(special_tokens) 143 | 144 | def __len__(self): 145 | return len(self.encoder) + len(self.special_tokens) 146 | 147 | def set_special_tokens(self, special_tokens): 148 | """ Add a list of additional tokens to the encoder. 149 | The additional tokens are indexed starting from the last index of the 150 | current vocabulary in the order of the `special_tokens` list. 
151 | """ 152 | if not special_tokens: 153 | self.special_tokens = {} 154 | self.special_tokens_decoder = {} 155 | return 156 | self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) 157 | self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} 158 | if self.fix_text is None: 159 | # Using BERT's BasicTokenizer: we can update the tokenizer 160 | self.nlp.never_split = special_tokens 161 | logger.info("Special tokens {}".format(self.special_tokens)) 162 | 163 | def bpe(self, token): 164 | word = tuple(token[:-1]) + (token[-1] + '</w>',) 165 | if token in self.cache: 166 | return self.cache[token] 167 | pairs = get_pairs(word) 168 | 169 | if not pairs: 170 | return token+'</w>' 171 | 172 | while True: 173 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 174 | if bigram not in self.bpe_ranks: 175 | break 176 | first, second = bigram 177 | new_word = [] 178 | i = 0 179 | while i < len(word): 180 | try: 181 | j = word.index(first, i) 182 | new_word.extend(word[i:j]) 183 | i = j 184 | except: 185 | new_word.extend(word[i:]) 186 | break 187 | 188 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 189 | new_word.append(first+second) 190 | i += 2 191 | else: 192 | new_word.append(word[i]) 193 | i += 1 194 | new_word = tuple(new_word) 195 | word = new_word 196 | if len(word) == 1: 197 | break 198 | else: 199 | pairs = get_pairs(word) 200 | word = ' '.join(word) 201 | if word == '\n  </w>': 202 | word = '\n</w>' 203 | self.cache[token] = word 204 | return word 205 | 206 | def tokenize(self, text): 207 | """ Tokenize a string. """ 208 | split_tokens = [] 209 | if self.fix_text is None: 210 | # Using BERT's BasicTokenizer 211 | text = self.nlp.tokenize(text) 212 | for token in text: 213 | split_tokens.extend([t for t in self.bpe(token).split(' ')]) 214 | else: 215 | # Using SpaCy & ftfy (original tokenization process of OpenAI GPT) 216 | text = self.nlp(text_standardize(self.fix_text(text))) 217 | for token in text: 218 | split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')]) 219 | return split_tokens 220 | 221 | def convert_tokens_to_ids(self, tokens): 222 | """ Converts a sequence of tokens into ids using the vocab. """ 223 | ids = [] 224 | if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)): 225 | if tokens in self.special_tokens: 226 | return self.special_tokens[tokens] 227 | else: 228 | return self.encoder.get(tokens, 0) 229 | for token in tokens: 230 | if token in self.special_tokens: 231 | ids.append(self.special_tokens[token]) 232 | else: 233 | ids.append(self.encoder.get(token, 0)) 234 | if len(ids) > self.max_len: 235 | logger.warning( 236 | "Token indices sequence length is longer than the specified maximum " 237 | " sequence length for this OpenAI GPT model ({} > {}). 
Running this" 238 | " sequence through the model will result in indexing errors".format(len(ids), self.max_len) 239 | ) 240 | return ids 241 | 242 | def convert_ids_to_tokens(self, ids, skip_special_tokens=False): 243 | """Converts a sequence of ids into BPE tokens using the vocab.""" 244 | tokens = [] 245 | for i in ids: 246 | if i in self.special_tokens_decoder: 247 | if not skip_special_tokens: 248 | tokens.append(self.special_tokens_decoder[i]) 249 | else: 250 | tokens.append(self.decoder[i]) 251 | return tokens 252 | 253 | def decode(self, ids, skip_special_tokens=False, clean_up_tokenization_spaces=False): 254 | """Converts a sequence of ids into a string.""" 255 | tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) 256 | out_string = ''.join(tokens).replace('</w>', ' ').strip() 257 | if clean_up_tokenization_spaces: 258 | out_string = out_string.replace('<unk>', '') 259 | out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ',').replace(' ,', ',' 260 | ).replace(" n't", "n't").replace(" 'm", "'m").replace(" 're", "'re").replace(" do not", " don't" 261 | ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m " 262 | ).replace(" 've", "'ve") 263 | return out_string 264 | -------------------------------------------------------------------------------- /hiexpl/eval_explanations.py: -------------------------------------------------------------------------------- 1 | from utils.args import get_args 2 | import torch 3 | import numpy as np 4 | from utils.reader import load_vocab 5 | from bert.tokenization import BertTokenizer 6 | from utils.parser import get_span_to_node_mapping, parse_tree 7 | import csv, pickle 8 | from collections import defaultdict 9 | from utils.args import get_best_snapshot 10 | from nns.linear_model import BOWRegression, BOWRegressionMulti 11 | import argparse 12 | 13 | args = get_args() 14 | tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir='bert/cache') 15 | 16 | def unigram_linear_pearson(filename): 17 | f = open(filename) 18 | model = torch.load(args.bow_snapshot, map_location='cpu') 19 | vocab = load_vocab(VOCAB) 20 | out, truth = [], [] 21 | coeff_dict = {} 22 | scores_dict = defaultdict(list) 23 | valid, total = 0, 0 24 | for lidx, line in enumerate(f.readlines()): 25 | if lidx < MINLINE: continue 26 | if lidx == MAXLINE: break 27 | l = line.lower().strip().split('\t') 28 | for entry in l: 29 | items = entry.strip().split(' ') 30 | if len(items) > 2: 31 | continue 32 | score = float(items[1]) 33 | word = items[0] 34 | if word in vocab.stoi: 35 | coeff = -model.get_coefficient(vocab.stoi[word]) 36 | out.append(score) 37 | truth.append(coeff) 38 | scores_dict[word].append(score) 39 | coeff_dict[word] = coeff 40 | valid += 1 41 | total += 1 42 | p = np.corrcoef(out, truth) 43 | print('word_corr', p[1,0]) 44 | 45 | out_2, truth_2 = [], [] 46 | for k in scores_dict: 47 | out_2.extend([np.mean(scores_dict[k])]) 48 | truth_2.extend([coeff_dict[k]]) 49 | p2 = np.corrcoef(out_2, truth_2) 50 | print('average word_corr', p2[1,0]) 51 | return p[1,0] 52 | 53 | def unigram_linear_pearson_multiclass_agg(filename): 54 | f = open(filename,'rb') 55 | data = pickle.load(f) 56 | 57 | model = torch.load(args.bow_snapshot, map_location='cpu') 58 | vocab = load_vocab(VOCAB) 59 | out, truth = [], [] 60 | 61 | valid, total = 0, 0 62 | for lidx, entry in enumerate(data): 63 | if lidx < MINLINE: continue 64 | if lidx == MAXLINE: break 65 | 66 | sent_words 
= entry['text'].split() 67 | score_array = entry['tab'] 68 | label_name = entry['label'] 69 | if score_array.ndim == 1: 70 | score_array = score_array.reshape(1,-1) 71 | for word, score in zip(sent_words, score_array[0].tolist()): 72 | if word in vocab.stoi: 73 | coeff = model.get_margin_coefficient(label_name, vocab.stoi[word]) 74 | out.append(score) 75 | truth.append(coeff) 76 | valid += 1 77 | total += 1 78 | p = np.corrcoef(out, truth) 79 | print('word_corr', p[1,0]) 80 | return p[1,0] 81 | 82 | def unigram_linear_pearson_multiclass(filename): 83 | f = open(filename) 84 | model = torch.load(args.bow_snapshot, map_location='cpu') 85 | vocab = load_vocab(VOCAB) 86 | out, truth = [], [] 87 | scores_dict, coeff_dict = defaultdict(list), defaultdict(float) 88 | 89 | valid, total = 0, 0 90 | for lidx, line in enumerate(f.readlines()): 91 | if lidx < MINLINE: continue 92 | if lidx == MAXLINE: break 93 | l = line.strip().split('\t') 94 | class_name = l[0] 95 | for entry in l[1:]: 96 | items = entry.strip().split(' ') 97 | if len(items) > 2: 98 | continue 99 | score = float(items[1]) 100 | word = items[0] 101 | if word in vocab.stoi: 102 | coeff = model.get_label_coefficient(class_name, vocab.stoi[word]) 103 | out.append(score) 104 | truth.append(coeff) 105 | scores_dict[word].append(score) 106 | coeff_dict[word] = coeff 107 | valid += 1 108 | total += 1 109 | p = np.corrcoef(out, truth) 110 | print('word_corr', p[1,0]) 111 | 112 | def unigram_linear_pearson_bert_tree(filename, dataset='dev'): 113 | f = open(filename) 114 | f2 = open('.data/sst/trees/%s.txt' % dataset) 115 | model = torch.load(args.bow_snapshot, map_location='cpu') 116 | vocab = load_vocab(VOCAB) 117 | out, truth = [], [] 118 | valid, total = 0, 0 119 | scores_dict = defaultdict(list) 120 | coeff_dict = defaultdict(int) 121 | coeff_not, coeff_not_cnt = 0,0 122 | for lidx, (line, gt) in enumerate(zip(f.readlines(), f2.readlines()[OFFSET:])): 123 | if lidx < MINLINE: continue 124 | if lidx == MAXLINE: break 125 | entries = line.lower().strip().split('\t') 126 | span2node, node2span = get_span_to_node_mapping(parse_tree(gt)) 127 | spans = list(span2node.keys()) 128 | for idx, span in enumerate(spans): 129 | if type(span) is int: 130 | word = span2node[span].leaves()[0].lower() 131 | score = float(entries[idx].split()[-1]) 132 | if word in vocab.stoi: 133 | coeff = -model.get_coefficient(vocab.stoi[word]) 134 | out.append(score) 135 | truth.append(coeff) 136 | valid += 1 137 | scores_dict[word].append(score) 138 | coeff_dict[word] = coeff 139 | total += 1 140 | p = np.corrcoef(out, truth) 141 | print('word_corr', p[1,0]) 142 | 143 | out_2, truth_2 = [], [] 144 | for k in scores_dict: 145 | out_2.extend([np.mean(scores_dict[k])]) 146 | truth_2.extend([coeff_dict[k]]) 147 | p2 = np.corrcoef(out_2, truth_2) 148 | print('avg word_corr', p2[1,0]) 149 | 150 | return p[1,0] 151 | 152 | def lr_gt_pearson(): 153 | vocab = load_vocab('vocab/vocab_sst.pkl') 154 | model = torch.load(args.bow_snapshot, map_location='cpu') 155 | dict_path, label_path = '.data/sst_raw/dictionary.txt', '.data/sst_raw/sentiment_labels.txt' 156 | if BERT: 157 | phrase2id = load_txt_to_dict_hashed(dict_path) 158 | else: 159 | phrase2id = load_txt_to_dict(dict_path) 160 | id2label = load_txt_to_dict(label_path) 161 | 162 | a, b = [], [] 163 | for word in vocab.stoi: 164 | if word in phrase2id: 165 | a.append(model.get_coefficient(vocab.stoi[word])) 166 | b.append(float(id2label[phrase2id[word]])) 167 | 168 | print(len(a)) 169 | print(len(vocab.stoi)) 170 | 
print(np.corrcoef(a,b)[1,0]) 171 | 172 | def token2key(words): 173 | assert type(words) is list 174 | if len(words) > 4: 175 | return ' '.join(words[:2] + words[-2:]), len(words) 176 | else: 177 | return ' '.join(words), len(words) 178 | 179 | def load_txt_to_dict_hashed(filename): 180 | f = open(filename) 181 | dic = {} 182 | for line in f.readlines(): 183 | tup = line.lower().strip().split('|') 184 | if len(tup) != 2: 185 | continue 186 | key = tup[0].replace(' ','') 187 | dic[key] = tup[1] 188 | return dic 189 | 190 | def load_txt_to_dict(filename): 191 | f = open(filename) 192 | dic = {} 193 | for line in f.readlines(): 194 | tup = line.lower().strip().split('|') 195 | if len(tup) != 2: 196 | continue 197 | dic[tup[0]] = tup[1] 198 | return dic 199 | 200 | def phrase_gt_pearson(filename, dataset='dev'): 201 | f = open(filename) 202 | f2 = open('.data/sst/trees/%s.txt' % dataset) 203 | f3 = open('ground_truth.tmp', 'w') 204 | dict_path, label_path = '.data/sst_raw/dictionary.txt', '.data/sst_raw/sentiment_labels.txt' 205 | if BERT: 206 | phrase2id = load_txt_to_dict_hashed(dict_path) 207 | else: 208 | phrase2id = load_txt_to_dict(dict_path) 209 | id2label = load_txt_to_dict(label_path) 210 | out, truth = [], [] 211 | out_map, truth_map = {}, {} 212 | bucket_out, bucket_truth = {}, {} 213 | valid, total = 0, 0 214 | 215 | for idx, (line, line2) in enumerate(zip(f.readlines(), f2.readlines()[OFFSET:])): 216 | if idx < MINLINE: continue 217 | if idx == MAXLINE: break 218 | l = line.strip().split('\t') 219 | for entry in l: 220 | items = entry.strip().split(' ') 221 | score = float(items[-1]) 222 | if BERT: 223 | key = ''.join(items[:-1]).replace(' ','').replace('##','') 224 | else: 225 | key = ' '.join(items[:-1]) 226 | if key in phrase2id: 227 | phrase_id = phrase2id[key] 228 | gt_score = float(id2label[phrase_id]) 229 | out.append(score) 230 | truth.append(gt_score) 231 | 232 | if phrase_id not in out_map: 233 | out_map[phrase_id] = [] 234 | out_map[phrase_id].append(score) 235 | truth_map[phrase_id] = gt_score 236 | bucket_key = len(items) - 1 237 | if bucket_key not in bucket_out: 238 | bucket_out[bucket_key] = [] 239 | bucket_truth[bucket_key] = [] 240 | bucket_out[bucket_key].append(score) 241 | bucket_truth[bucket_key].append(gt_score) 242 | 243 | valid += 1 244 | total += 1 245 | p = np.corrcoef(out, truth) 246 | print('phrase_corr', p[1,0]) 247 | 248 | e_out, e_truth = [], [] 249 | for k in out_map: 250 | e_out.extend([np.mean(out_map[k])] * len(out_map[k])) 251 | e_truth.extend([truth_map[k]] * len(out_map[k])) 252 | p2 = np.corrcoef(e_out, e_truth) 253 | print('averaged phrase corr', p2[1,0]) 254 | 255 | print('(sanity check) matched: %d, total: %d' % (valid, total)) 256 | 257 | return out, truth, p[1,0], p2[1,0] 258 | 259 | 260 | def run_multiple(path): 261 | template = path 262 | nb_ranges, hists = [1,2,3,4], [5,10,20] 263 | postfix = '.bert' if BERT else '' 264 | f = open('analysis/%s_%s' % (TASK,template.split('/')[-1].replace('{','').replace('}','') + postfix),'w') 265 | writer = csv.writer(f,delimiter='\t') 266 | for nb_range in nb_ranges: 267 | for h in hists: 268 | print(nb_range, h) 269 | path = template.format(**{'nb':nb_range, 'h':h}) 270 | if BERT and TASK == 'sst': 271 | word_score = unigram_linear_pearson_bert_tree(path, dataset='test') 272 | elif TASK == 'tacred': 273 | word_score = unigram_linear_pearson_multiclass(path) 274 | else: 275 | word_score = unigram_linear_pearson(path) 276 | if TASK == 'sst': 277 | _, _, phrase_score, _ = phrase_gt_pearson(path, 
dataset='test') 278 | else: 279 | phrase_score = 0 280 | writer.writerow([nb_range, h, word_score, phrase_score]) 281 | f.close() 282 | 283 | if __name__ == '__main__': 284 | MINLINE = 0 285 | MAXLINE = 50 286 | OFFSET = 0 287 | DATASET = 'test' 288 | path = args.eval_file 289 | BERT = 'bert' in path 290 | TASK = '' 291 | 292 | for possible_task in ['sst','yelp','tacred']: 293 | if possible_task in path: 294 | TASK = possible_task 295 | break 296 | VOCAB = None 297 | 298 | 299 | if TASK == 'sst': 300 | args.bow_snapshot = get_best_snapshot('models/sst_bow') 301 | VOCAB = 'vocab/vocab_sst.pkl' 302 | elif TASK == 'yelp': 303 | args.bow_snapshot = get_best_snapshot('models/yelp_bow') 304 | VOCAB = 'vocab/vocab_yelp.pkl' 305 | elif TASK == 'tacred': 306 | args.bow_snapshot = get_best_snapshot('models/tacred_bow') 307 | VOCAB = 'vocab/vocab_tacred.pkl' 308 | if BERT: 309 | if TASK == 'sst': 310 | unigram_linear_pearson_bert_tree(path, DATASET) 311 | elif TASK == 'yelp': 312 | unigram_linear_pearson(path) 313 | else: 314 | unigram_linear_pearson_multiclass(path) 315 | else: 316 | if TASK in ['sst','yelp']: 317 | unigram_linear_pearson(path) 318 | else: 319 | unigram_linear_pearson_multiclass(path) 320 | if TASK == 'sst': 321 | out1, _, _, _= phrase_gt_pearson(path, dataset=DATASET) 322 | -------------------------------------------------------------------------------- /hiexpl/explain.py: -------------------------------------------------------------------------------- 1 | from algo.soc_lstm import SOCForLSTM 2 | from algo.scd_lstm import CDForLSTM, SCDForLSTM 3 | from algo.soc_transformer import SOCForTransformer 4 | from algo.scd_transformer import CDForTransformer, SCDForTransformer 5 | import torch 6 | import argparse 7 | from utils.args import get_args 8 | from utils.reader import get_data_iterators_sst_flatten, get_data_iterators_yelp, get_data_iterators_tacred 9 | import random, os 10 | from bert.run_classifier import BertConfig, BertForSequenceClassification 11 | from nns.model import LSTMMeanRE, LSTMMeanSentiment, LSTMSentiment 12 | 13 | def get_args_exp(): 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--method') 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | args = get_args() 21 | 22 | if __name__ == '__main__': 23 | seed = 0 24 | random.seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.manual_seed(seed) 27 | 28 | if args.task == 'sst' or args.task == 'sst_async': 29 | text_field, length_field, train_iter, dev_iter, test_iter, train, dev = \ 30 | get_data_iterators_sst_flatten(map_cpu=False) 31 | elif args.task == 'yelp': 32 | text_field, label_field, train_iter, dev_iter, test_iter, train, dev = \ 33 | get_data_iterators_yelp(map_cpu=False) 34 | elif args.task == 'tacred': 35 | text_field, label_field, subj_offset, obj_offset, pos, ner, train_iter, dev_iter, test_iter, train, dev = \ 36 | get_data_iterators_tacred() 37 | else: 38 | raise ValueError('unknown task') 39 | 40 | iter_map = {'train': train_iter, 'dev': dev_iter, 'test': test_iter} 41 | if args.task == 'sst': 42 | tree_path = '.data/sst/trees/%s.txt' 43 | elif args.task == 'yelp': 44 | tree_path = '.data/yelp_review_polarity_csv/%s.csv' 45 | elif args.task == 'tacred': 46 | tree_path = '.data/TACRED/data/json/%s.json' 47 | else: 48 | raise ValueError 49 | 50 | if args.task == 'tacred': 51 | args.label_vocab = label_field.vocab 52 | args.pos_size = len(pos.vocab) 53 | args.ner_size = len(ner.vocab) 54 | args.offset_emb_dim = 30 55 | 56 | 57 | args.n_embed = len(text_field.vocab) 58 | 
args.d_out = 2 if args.task in ['sst','yelp'] else len(label_field.vocab) 59 | args.n_cells = args.n_layers 60 | args.use_gpu = args.gpu >= 0 61 | 62 | if args.explain_model == 'lstm': 63 | cls = {'sst': LSTMSentiment, 'yelp': LSTMMeanSentiment, 'tacred': LSTMMeanRE} 64 | model = cls[args.task](args) 65 | model.load_state_dict(torch.load(args.resume_snapshot)) 66 | model = model.to(args.gpu) 67 | model.gpu = args.gpu 68 | model.use_gpu = args.gpu >= 0 69 | model.eval() 70 | print(model) 71 | if args.method == 'soc': 72 | lm_model = torch.load(args.lm_path, map_location=lambda storage, location: storage.cuda(args.gpu)) 73 | lm_model.gpu = args.gpu 74 | lm_model.encoder.gpu = args.gpu 75 | algo = SOCForLSTM(model, lm_model, iter_map[args.dataset], 76 | tree_path=tree_path % args.dataset, config=args, 77 | vocab=text_field.vocab, 78 | output_path='outputs/' + args.task + '/soc_results/soc%s.txt' % args.exp_name) 79 | elif args.method == 'scd': 80 | lm_model = torch.load(args.lm_path, map_location=lambda storage, location: storage.cuda(args.gpu)) 81 | lm_model.gpu = args.gpu 82 | lm_model.encoder.gpu = args.gpu 83 | algo = SCDForLSTM(model, lm_model, iter_map[args.dataset], 84 | tree_path=tree_path % args.dataset, config=args, 85 | vocab=text_field.vocab, 86 | output_path='outputs/' + args.task + '/scd_results/scd%s.txt' % args.exp_name) 87 | else: 88 | raise ValueError('unknown method') 89 | elif args.explain_model == 'bert': 90 | CONFIG_NAME = 'bert_config.json' 91 | WEIGHTS_NAME = 'pytorch_model.bin' 92 | output_model_file = os.path.join('bert/%s' % args.resume_snapshot, WEIGHTS_NAME) 93 | output_config_file = os.path.join('bert/%s' % args.resume_snapshot, CONFIG_NAME) 94 | # Load a trained model and config that you have fine-tuned 95 | config = BertConfig(output_config_file) 96 | model = BertForSequenceClassification(config, num_labels=2 if args.task != 'tacred' else 42) 97 | model.load_state_dict(torch.load(output_model_file)) 98 | model.eval() 99 | if args.gpu >= 0: 100 | model = model.to(args.gpu) 101 | if args.method == 'soc': 102 | lm_model = torch.load(args.lm_path, map_location=lambda storage, location: storage.cuda(args.gpu)) 103 | lm_model.gpu = args.gpu 104 | lm_model.encoder.gpu = args.gpu 105 | algo = SOCForTransformer(model, lm_model, 106 | tree_path=tree_path % args.dataset, 107 | output_path='outputs/' + args.task + '/soc_bert_results/soc%s.txt' % args.exp_name, 108 | config=args, vocab=text_field.vocab) 109 | elif args.method == 'scd': 110 | lm_model = torch.load(args.lm_path, map_location=lambda storage, location: storage.cuda(args.gpu)) 111 | lm_model.gpu = args.gpu 112 | lm_model.encoder.gpu = args.gpu 113 | algo = SCDForTransformer(model, lm_model, tree_path=tree_path % args.dataset, 114 | output_path='outputs/' + args.task + '/scd_bert_results/scd%s.txt' % args.exp_name, 115 | config=args, vocab=text_field.vocab) 116 | else: 117 | raise ValueError('unknown method') 118 | else: 119 | raise ValueError('unknown model') 120 | with torch.no_grad(): 121 | if args.task == 'sst': 122 | with torch.cuda.device(args.gpu): 123 | if args.agg: 124 | algo.explain_agg('sst') 125 | else: 126 | algo.explain_sst() 127 | elif args.task == 'yelp': 128 | with torch.cuda.device(args.gpu): 129 | if args.agg: 130 | algo.explain_agg('sst') 131 | algo.explain_token('yelp') 132 | elif args.task == 'tacred': 133 | with torch.cuda.device(args.gpu): 134 | algo.label_vocab = label_field.vocab 135 | algo.ner_vocab = ner.vocab 136 | algo.pos_vocab = pos.vocab 137 | if args.agg: 138 | 
algo.explain_agg('tacred') 139 | else: 140 | algo.explain_token('tacred') 141 | 142 | -------------------------------------------------------------------------------- /hiexpl/lm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/INK-USC/hierarchical-explanation-neural-sequence-models/d797daee2cea327ff3a7fb5d9f077412861f834f/hiexpl/lm/__init__.py -------------------------------------------------------------------------------- /hiexpl/lm/lm.py: -------------------------------------------------------------------------------- 1 | from torch.distributions import Categorical 2 | from nns.layers import * 3 | from utils.args import get_args 4 | from torch.nn import functional as F 5 | 6 | args = get_args() 7 | 8 | 9 | class LSTMLanguageModel(nn.Module): 10 | def __init__(self, config, vocab): 11 | super().__init__() 12 | self.hidden_size = config.lm_d_hidden 13 | self.embed_size = config.lm_d_embed 14 | self.n_vocab = config.n_embed 15 | self.gpu = args.gpu 16 | 17 | self.encoder = DynamicEncoder(self.n_vocab, self.embed_size, self.hidden_size, self.gpu) 18 | self.fw_proj = nn.Linear(self.hidden_size, self.n_vocab) 19 | self.bw_proj = nn.Linear(self.hidden_size, self.n_vocab) 20 | 21 | self.loss = nn.CrossEntropyLoss(ignore_index=1) 22 | self.vocab = vocab 23 | 24 | self.warning_flag = False 25 | 26 | def forward(self, batch): 27 | inp = batch.text 28 | inp_len_np = batch.length.cpu().numpy() 29 | if self.gpu >= 0: 30 | inp = inp.to(self.gpu) 31 | output = self.encoder(inp, inp_len_np) 32 | fw_output, bw_output = output[:,:,:self.hidden_size], output[:,:,self.hidden_size:] 33 | fw_proj, bw_proj = self.fw_proj(fw_output), self.bw_proj(bw_output) 34 | 35 | fw_loss = self.loss(fw_proj[:-1].view(-1,fw_proj.size(2)).contiguous(), inp[1:].view(-1).contiguous()) 36 | bw_loss = self.loss(bw_proj[1:].view(-1,bw_proj.size(2)).contiguous(), inp[:-1].view(-1).contiguous()) 37 | return fw_loss, bw_loss 38 | 39 | def _sample_n_sequences(self, method, direction, token_inp, hidden, length, sample_num): 40 | outputs = [] 41 | token_inp = token_inp.repeat(1, sample_num) # [1, N] 42 | hidden = hidden[0].repeat(1, sample_num, 1), hidden[1].repeat(1, sample_num, 1) # [x, N, H] 43 | for t in range(length): 44 | output, hidden = self.encoder.rollout(token_inp, hidden, direction=direction) 45 | if direction == 'fw': 46 | proj = self.fw_proj(output[:, :, :self.hidden_size]) 47 | elif direction == 'bw': 48 | proj = self.bw_proj(output[:, :, self.hidden_size:]) 49 | proj = proj.squeeze(0) 50 | if method == 'max': 51 | _, token_inp = torch.max(proj,-1) 52 | outputs.append(token_inp.view(-1)) 53 | elif method == 'random': 54 | dist = Categorical(F.softmax(proj,-1)) 55 | token_inp = dist.sample() 56 | outputs.append(token_inp) 57 | token_inp = token_inp.view(1, -1) 58 | if direction == 'bw': 59 | outputs = list(reversed(outputs)) 60 | outputs = torch.stack(outputs) 61 | return outputs 62 | 63 | def sample_n(self, method, batch, max_sample_length, sample_num): 64 | inp = batch.text 65 | inp_len_np = batch.length.cpu().numpy() 66 | batch_size = inp.size(1) 67 | assert batch_size == 1 68 | 69 | pad_inp1 = torch.LongTensor([self.vocab.stoi['']] * inp.size(1)).view(1, -1) 70 | pad_inp2 = torch.LongTensor([self.vocab.stoi['']] * inp.size(1)).view(1,-1) 71 | 72 | if self.gpu >= 0: 73 | inp = inp.to(self.gpu) 74 | pad_inp1 = pad_inp1.to(self.gpu) 75 | pad_inp2 = pad_inp2.to(self.gpu) 76 | 77 | padded_inp = torch.cat([pad_inp1, inp, pad_inp2], 0) 78 | assert 
padded_inp.max().item() < self.n_vocab 79 | assert inp_len_np[0] + 2 <= padded_inp.size(0) 80 | padded_enc_out, (padded_hidden_states, padded_cell_states) = self.encoder(padded_inp, inp_len_np + 2, 81 | return_all_states=True) # [T+2,B,H] 82 | 83 | # extract forward hidden state 84 | assert 0 <= batch.fw_pos.item() - 1 <= padded_enc_out.size(0) - 1 85 | assert 0 <= batch.fw_pos.item() <= padded_enc_out.size(0) - 1 86 | 87 | fw_hidden_state = padded_hidden_states.index_select(0, batch.fw_pos - 1)[0] 88 | fw_cell_state = padded_cell_states.index_select(0, batch.fw_pos - 1)[0] 89 | fw_next_token = padded_inp.index_select(0,batch.fw_pos).view(1,-1) 90 | 91 | # extract backward hidden state 92 | assert 0 <= batch.bw_pos.item() + 3 <= padded_enc_out.size(0) - 1 93 | assert 0 <= batch.bw_pos.item() + 2 <= padded_enc_out.size(0) - 1 94 | # batch 95 | bw_hidden_state = padded_hidden_states.index_select(0,batch.bw_pos + 3)[0] 96 | bw_cell_state = padded_cell_states.index_select(0, batch.bw_pos + 3)[0] 97 | # torch.cat([bw_hidden[:,:,:self.hidden_size], bw_hidden[:,:,self.hidden_size:]], 0) 98 | bw_next_token = padded_inp.index_select(0,batch.bw_pos + 2).view(1,-1) 99 | 100 | fw_sample_outputs = self._sample_n_sequences(method, 'fw', fw_next_token, (fw_hidden_state, fw_cell_state), 101 | max_sample_length, sample_num) 102 | bw_sample_outputs = self._sample_n_sequences(method, 'bw', bw_next_token, (bw_hidden_state, bw_cell_state), 103 | max_sample_length, sample_num) 104 | 105 | self.filter_special_tokens(fw_sample_outputs) 106 | self.filter_special_tokens(bw_sample_outputs) 107 | 108 | return fw_sample_outputs, bw_sample_outputs 109 | 110 | def filter_special_tokens(self, m): 111 | for i in range(m.size(0)): 112 | for j in range(m.size(1)): 113 | if m[i,j] >= self.n_vocab - 2: 114 | m[i,j] = 0 -------------------------------------------------------------------------------- /hiexpl/lm/lm_train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import glob 3 | 4 | import torch.optim as O 5 | 6 | from .lm import LSTMLanguageModel 7 | from utils.reader import * 8 | 9 | 10 | import random 11 | 12 | random.seed(0) 13 | 14 | args = get_args() 15 | try: 16 | torch.cuda.set_device(args.gpu) 17 | except AttributeError: 18 | pass 19 | 20 | 21 | def do_train(model): 22 | opt = O.Adam(filter(lambda x: x.requires_grad, model.parameters())) 23 | 24 | iterations = 0 25 | start = time.time() 26 | best_dev_nll = 1e10 27 | train_iter.repeat = False 28 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss' 29 | dev_log_template = ' '.join( 30 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:>8.6f},{:>8.6f}'.split(',')) 31 | log_template = ' '.join( 32 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{}'.split(',')) 33 | makedirs(args.save_path) 34 | print(header) 35 | 36 | all_break = False 37 | print(model) 38 | 39 | for epoch in range(args.epochs): 40 | if all_break: 41 | break 42 | train_iter.init_epoch() 43 | train_loss = 0 44 | for batch_idx, batch in enumerate(train_iter): 45 | # switch model to training mode, clear gradient accumulators 46 | model.train() 47 | opt.zero_grad() 48 | 49 | iterations += 1 50 | print(('epoch %d iter %d' + ' ' * 10) % (epoch, batch_idx), end='\r') 51 | 52 | # forward pass 53 | fw_loss, bw_loss = model(batch) 54 | 55 | loss = fw_loss + bw_loss 56 | # backpropagate and update optimizer learning rate 57 | loss.backward() 58 | opt.step() 59 | 60 | train_loss += loss.item() 61 | 
62 | # checkpoint model periodically 63 | if iterations % args.save_every == 0: 64 | snapshot_prefix = os.path.join(args.save_path, 'snapshot') 65 | snapshot_path = snapshot_prefix + 'loss_{:.6f}_iter_{}_model.pt'.format(loss.item(), iterations) 66 | torch.save(model, snapshot_path) 67 | for f in glob.glob(snapshot_prefix + '*'): 68 | if f != snapshot_path: 69 | os.remove(f) 70 | 71 | # evaluate performance on validation set periodically 72 | if iterations % args.dev_every == 0: 73 | 74 | # switch model to evaluation mode 75 | model.eval() 76 | dev_iter.init_epoch() 77 | 78 | # calculate accuracy on validation set 79 | 80 | cnt, dev_loss = 0, 0 81 | dev_fw_loss, dev_bw_loss = 0,0 82 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 83 | fw_loss, bw_loss = model(dev_batch) 84 | loss = fw_loss + bw_loss 85 | cnt += 1 86 | dev_loss += loss.item() 87 | dev_fw_loss += fw_loss.item() 88 | dev_bw_loss += bw_loss.item() 89 | dev_loss /= cnt 90 | dev_fw_loss /= cnt 91 | dev_bw_loss /= cnt 92 | print(dev_log_template.format(time.time() - start, 93 | epoch, iterations, 1 + batch_idx, len(train_iter), 94 | 100. * (1 + batch_idx) / len(train_iter), train_loss / (batch_idx + 1), 95 | dev_loss, dev_fw_loss, dev_bw_loss)) 96 | 97 | # update best valiation set accuracy 98 | if dev_loss < best_dev_nll: 99 | best_dev_nll = dev_loss 100 | snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') 101 | snapshot_path = snapshot_prefix + '_devloss_{}_iter_{}_model.pt'.format(dev_loss, 102 | iterations) 103 | 104 | # save model, delete previous 'best_snapshot' files 105 | torch.save(model, snapshot_path) 106 | for f in glob.glob(snapshot_prefix + '*'): 107 | if f != snapshot_path: 108 | os.remove(f) 109 | 110 | elif iterations % args.log_every == 0: 111 | # print progress message 112 | print(log_template.format(time.time() - start, 113 | epoch, iterations, 1 + batch_idx, len(train_iter), 114 | 100. 
* (1 + batch_idx) / len(train_iter), loss.item(), ' ' * 8)) 115 | 116 | 117 | if __name__ == '__main__': 118 | if args.task == 'sst': 119 | inputs, lengths, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_sst_flatten(train_lm=True) 120 | elif args.task == 'yelp': 121 | inputs, labels, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_yelp( 122 | train_lm=True) 123 | elif args.task == 'tacred': 124 | inputs, labels, subj_offset, obj_offset, pos, ner, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_tacred(train_lm=True) 125 | else: 126 | raise ValueError('unknown task') 127 | 128 | config = args 129 | config.n_embed = len(inputs.vocab) 130 | config.n_cells = config.n_layers 131 | config.use_gpu = args.gpu >= 0 132 | 133 | model = LSTMLanguageModel(config, inputs.vocab) 134 | if args.word_vectors: 135 | model.encoder.embedding.weight.data = inputs.vocab.vectors 136 | if args.fix_emb: 137 | model.encoder.embedding.weight.requires_grad = False 138 | if config.use_gpu: 139 | model = model.cuda() 140 | 141 | do_train(model) 142 | 143 | -------------------------------------------------------------------------------- /hiexpl/nns/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/INK-USC/hierarchical-explanation-neural-sequence-models/d797daee2cea327ff3a7fb5d9f077412861f834f/hiexpl/nns/__init__.py -------------------------------------------------------------------------------- /hiexpl/nns/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | 5 | 6 | class DynamicEncoder(nn.Module): 7 | def __init__(self, input_size, embed_size, hidden_size, gpu, n_layers=1, dropout=0.1): 8 | super().__init__() 9 | self.input_size = input_size 10 | self.hidden_size = hidden_size 11 | self.embed_size = embed_size 12 | self.n_layers = n_layers 13 | self.dropout = dropout 14 | self.embedding = nn.Embedding(input_size, embed_size) 15 | self.lstm = nn.LSTM(embed_size, hidden_size, n_layers, bidirectional=True) 16 | self.gpu = gpu 17 | 18 | def forward(self, input_seqs, input_lens, hidden=None, return_all_states=False): 19 | batch_size = input_seqs.size(1) 20 | embedded = self.embedding(input_seqs) 21 | if not return_all_states: 22 | embedded = embedded.transpose(0, 1) # [B,T,E] 23 | sort_idx = np.argsort(-input_lens) 24 | unsort_idx = torch.LongTensor(np.argsort(sort_idx)) 25 | if self.gpu >= 0: 26 | unsort_idx = unsort_idx.to(self.gpu) 27 | input_lens = input_lens[sort_idx] 28 | sort_idx = torch.LongTensor(sort_idx) 29 | if self.gpu >= 0: 30 | sort_idx = sort_idx.to(self.gpu) 31 | embedded = embedded[sort_idx].transpose(0, 1) # [T,B,E] 32 | packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_lens) 33 | outputs, hidden = self.lstm(packed, hidden) 34 | outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs) 35 | # outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:] 36 | outputs = outputs.transpose(0, 1)[unsort_idx].transpose(0, 1).contiguous() 37 | return outputs 38 | else: 39 | hidden = None 40 | hidden_states, cell_states = [], [] 41 | outputs = [] 42 | for t in range(input_seqs.size(0)): 43 | output, hidden = self.lstm(embedded[t].unsqueeze(0), hidden) 44 | hidden_states.append(hidden[0]) 45 | cell_states.append(hidden[1]) 46 | outputs.append(output) 47 | outputs = torch.cat(outputs, 0) 48 | hidden_states = 
torch.stack(hidden_states, 0) 49 | cell_states = torch.stack(cell_states, 0) 50 | return outputs, (hidden_states, cell_states) 51 | 52 | def rollout(self, input_word, prev_hidden, direction): 53 | embed = self.embedding(input_word) 54 | output, hidden = self.lstm(embed, prev_hidden) 55 | return output, hidden 56 | -------------------------------------------------------------------------------- /hiexpl/nns/linear_model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import time 3 | import glob 4 | 5 | import torch.optim as O 6 | from utils.reader import * 7 | from utils.tacred_f1 import score as f1_score 8 | from sklearn.metrics import accuracy_score 9 | 10 | 11 | class BOWRegression(nn.Module): 12 | def __init__(self, vocab, config): 13 | super().__init__() 14 | self.vocab_size = len(vocab.itos) 15 | self.weight = nn.Linear(self.vocab_size, 1) 16 | self.sigmoid = nn.Sigmoid() 17 | self.loss = nn.BCELoss() 18 | 19 | self.gpu = config.gpu 20 | 21 | def forward(self, batch): 22 | text = batch.text.cpu().numpy() # [T,B] 23 | bow = [] 24 | for b in range(text.shape[1]): 25 | v = torch.zeros(self.vocab_size) 26 | for t in range(text.shape[0]): 27 | if text[t,b] != 1: # [pad] 28 | v[text[t,b]] = 1 29 | bow.append(v) 30 | bow = torch.stack(bow) # [B, V] 31 | bow = bow.to(self.gpu) 32 | score = self.weight(bow) # [B, 1] 33 | #loss = self.loss(score, batch.label.unsqueeze(-1)) 34 | return self.sigmoid(score) 35 | 36 | def get_coefficient(self, idx): 37 | return self.weight.weight[0,idx].item() 38 | 39 | class BOWRegressionMulti(nn.Module): 40 | def __init__(self, vocab, config, label_vocab): 41 | super().__init__() 42 | self.label_vocab = label_vocab 43 | self.vocab_size = len(vocab.itos) 44 | self.label_size = len(label_vocab.itos) 45 | self.weight = nn.Linear(self.vocab_size, self.label_size) 46 | self.softmax = nn.Softmax(-1) 47 | self.loss = nn.CrossEntropyLoss() 48 | 49 | self.gpu = config.gpu 50 | 51 | def forward(self, batch): 52 | text = batch.text.cpu().numpy() # [T,B] 53 | bow = [] 54 | for b in range(text.shape[1]): 55 | v = torch.zeros(self.vocab_size) 56 | for t in range(text.shape[0]): 57 | if text[t,b] != 1: # [pad] 58 | v[text[t,b]] = 1 59 | bow.append(v) 60 | bow = torch.stack(bow) # [B, V] 61 | bow = bow.to(self.gpu) 62 | score = self.weight(bow) # [B, C] 63 | #loss = self.loss(score, batch.label.unsqueeze(-1)) 64 | return score 65 | 66 | def get_label_coefficient(self, class_idx_or_name, word_idx): 67 | if type(class_idx_or_name) is str: 68 | class_idx = self.label_vocab.stoi[class_idx_or_name] 69 | else: 70 | class_idx = class_idx_or_name 71 | return self.weight.weight[class_idx,word_idx].item() 72 | 73 | def get_margin_coefficient(self, class_idx_or_name, word_idx): 74 | if type(class_idx_or_name) is str: 75 | class_idx = self.label_vocab.stoi[class_idx_or_name] 76 | else: 77 | class_idx = class_idx_or_name 78 | weight_vec = self.weight.weight[:,word_idx] 79 | weight_vec_bak = weight_vec.clone() 80 | weight_vec_bak[class_idx] = -1000 81 | margin = weight_vec[class_idx] - weight_vec_bak.max() 82 | return margin.item() 83 | 84 | def do_train(): 85 | if args.task != 'tacred': 86 | criterion = nn.BCELoss() 87 | else: 88 | criterion = nn.CrossEntropyLoss() 89 | opt = O.Adam(filter(lambda x: x.requires_grad, model.parameters())) 90 | 91 | iterations = 0 92 | start = time.time() 93 | best_dev_acc = -1 94 | best_dev_loss = 10000 95 | train_iter.repeat = False 96 | header = ' Time Epoch Iteration Progress (%Epoch) Loss 
Dev/Loss Accuracy Dev/Accuracy' 97 | dev_log_template = ' '.join( 98 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) 99 | log_template = ' '.join( 100 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(',')) 101 | makedirs(args.save_path) 102 | print(header) 103 | 104 | all_break = False 105 | print(model) 106 | 107 | for epoch in range(args.epochs): 108 | if all_break: 109 | break 110 | train_iter.init_epoch() 111 | n_correct, n_total = 0, 0 112 | for batch_idx, batch in enumerate(train_iter): 113 | # switch model to training mode, clear gradient accumulators 114 | model.train() 115 | opt.zero_grad() 116 | 117 | iterations += 1 118 | print('epoch %d iter %d' % (epoch, batch_idx), end='\r') 119 | 120 | # forward pass 121 | answer = model(batch) 122 | 123 | # calculate loss of the network output with respect to training labels 124 | train_acc = 0 125 | if args.task != 'tacred': 126 | batch.label = batch.label.float() 127 | loss = criterion(answer, batch.label.to(args.gpu)) 128 | 129 | # backpropagate and update optimizer learning rate 130 | loss.backward() 131 | opt.step() 132 | 133 | # checkpoint model periodically 134 | if iterations % args.save_every == 0: 135 | snapshot_prefix = os.path.join(args.save_path, 'snapshot') 136 | snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, 137 | loss.item(), 138 | iterations) 139 | torch.save(model, snapshot_path) 140 | for f in glob.glob(snapshot_prefix + '*'): 141 | if f != snapshot_path: 142 | os.remove(f) 143 | 144 | # evaluate performance on validation set periodically 145 | if iterations % args.dev_every == 0: 146 | 147 | # switch model to evaluation mode 148 | model.eval() 149 | dev_iter.init_epoch() 150 | avg_dev_loss = 0 151 | # calculate accuracy on validation set 152 | n_dev_correct, dev_loss = 0, 0 153 | truth_dev, pred_dev = [], [] 154 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 155 | answer = model(dev_batch) 156 | dev_label = dev_batch.label if not config.use_gpu else dev_batch.label.cuda() 157 | if args.task != 'tacred': 158 | pred = (answer > 0.5).long() 159 | else: 160 | pred = torch.max(answer, 1)[1] 161 | if args.task != 'tacred': 162 | dev_label = dev_label.float() 163 | dev_loss = criterion(answer, dev_label) 164 | avg_dev_loss += dev_loss.item() * dev_label.size(0) 165 | for l_i in range(dev_label.size(0)): 166 | pred_dev.append(pred.view(-1)[l_i].item()) 167 | truth_dev.append(dev_label[l_i].item()) 168 | if args.task in ['sst', 'yelp']: 169 | dev_acc = 100. * accuracy_score(truth_dev, pred_dev) 170 | elif args.task == 'tacred': 171 | dev_acc = 100. * f1_score(truth_dev, pred_dev)[-1] 172 | else: 173 | raise ValueError 174 | avg_dev_loss /= len(dev_set) 175 | print(dev_log_template.format(time.time() - start, 176 | epoch, iterations, 1 + batch_idx, len(train_iter), 177 | 100. 
* (1 + batch_idx) / len(train_iter), loss.item(), 178 | avg_dev_loss, 179 | train_acc, dev_acc)) 180 | 181 | # update best valiation set accuracy 182 | if dev_acc > best_dev_acc: 183 | #if avg_dev_loss < best_dev_loss: 184 | best_dev_acc = dev_acc 185 | best_dev_loss = avg_dev_loss 186 | snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') 187 | snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}_iter_{}_model.pt'.format(dev_acc, 188 | dev_loss.item(), 189 | iterations) 190 | 191 | # save model, delete previous 'best_snapshot' files 192 | torch.save(model, snapshot_path) 193 | for f in glob.glob(snapshot_prefix + '*'): 194 | if f != snapshot_path: 195 | os.remove(f) 196 | 197 | elif iterations % args.log_every == 0: 198 | # print progress message 199 | print(log_template.format(time.time() - start, 200 | epoch, iterations, 1 + batch_idx, len(train_iter), 201 | 100. * (1 + batch_idx) / len(train_iter), loss.item(), ' ' * 8, 202 | n_correct / n_total * 100, ' ' * 12)) 203 | 204 | 205 | if __name__ == '__main__': 206 | if args.task == 'sst': 207 | inputs, labels, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_sst() 208 | elif args.task == 'yelp': 209 | inputs, labels, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_yelp() 210 | elif args.task == 'tacred': 211 | inputs, labels, subj_offset, obj_offset, pos, ner, train_iter, dev_iter, test_iter, train_set, dev_set = \ 212 | get_data_iterators_tacred() 213 | else: 214 | raise ValueError('unknown task') 215 | 216 | config = args 217 | config.n_embed = len(inputs.vocab) 218 | config.d_out = len(labels.vocab) 219 | config.n_cells = config.n_layers 220 | config.use_gpu = args.gpu >= 0 221 | 222 | if args.task != 'tacred': 223 | model = BOWRegression(inputs.vocab, config) 224 | else: 225 | model = BOWRegressionMulti(inputs.vocab, config, labels.vocab) 226 | if config.use_gpu: 227 | model = model.cuda() 228 | 229 | #if not args.bow_snapshot: 230 | do_train() 231 | 232 | -------------------------------------------------------------------------------- /hiexpl/nns/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.nn import init 5 | 6 | 7 | class LSTMSentiment(nn.Module): 8 | def __init__(self, config, match_length=False): 9 | super(LSTMSentiment, self).__init__() 10 | self.hidden_dim = config.d_hidden 11 | self.vocab_size = config.n_embed 12 | self.emb_dim = config.d_embed 13 | self.batch_size = config.batch_size 14 | self.use_gpu = config.gpu >= 0 15 | self.num_labels = config.d_out 16 | self.embed = nn.Embedding(self.vocab_size, self.emb_dim, padding_idx=1) 17 | self.match_length = match_length 18 | 19 | self.lstm = nn.LSTM(input_size = self.emb_dim, hidden_size = self.hidden_dim) 20 | self.hidden_to_label = nn.Linear(self.hidden_dim, self.num_labels) 21 | 22 | def forward(self, batch): 23 | if hasattr(batch, 'vec'): 24 | vecs = batch.vec 25 | else: 26 | if self.use_gpu: 27 | inp = batch.text.long().cuda() 28 | else: 29 | inp = batch.text.long() 30 | vecs = self.embed(inp) 31 | lstm_out, hidden = self.lstm(vecs) # [T, B, H] 32 | 33 | if hasattr(self, 'match_length') and self.match_length: 34 | if hasattr(batch, 'length'): 35 | length = batch.length.cpu().numpy() 36 | elif vecs.size(1) == 1: # batch size is 1 37 | length = np.array([vecs.size(0)]) 38 | else: 39 | length = np.array([vecs.size(0)] * vecs.size(1)) 40 | print('Warning: length is missing') 41 | 
hidden_state = [] 42 | for i in range(length.shape[0]): 43 | hidden_state.append(lstm_out[length[i]-1,i]) 44 | hidden_state = torch.stack(hidden_state) # [B,H] 45 | else: 46 | hidden_state = lstm_out[-1] 47 | logits = self.hidden_to_label(hidden_state) 48 | return logits 49 | 50 | 51 | class LSTMMeanSentiment(LSTMSentiment): 52 | def __init__(self, config, match_length=False): 53 | super().__init__(config, match_length) 54 | 55 | def forward(self, batch): 56 | if hasattr(batch, 'vec'): 57 | vecs = batch.vec 58 | else: 59 | if self.use_gpu: 60 | inp = batch.text.long().cuda() 61 | else: 62 | inp = batch.text.long() 63 | vecs = self.embed(inp) 64 | lstm_out, hidden = self.lstm(vecs) # [T, B, H] 65 | 66 | if self.match_length: 67 | if hasattr(batch, 'length'): 68 | length = batch.length.cpu().numpy() 69 | elif vecs.size(1) == 1: # batch size is 1 70 | length = np.array([vecs.size(0)]) 71 | else: 72 | length = np.array([vecs.size(0)] * vecs.size(1)) 73 | #print('Warning: length is missing') 74 | hidden_state = [] 75 | for i in range(length.shape[0]): 76 | hidden_state.append(torch.mean(lstm_out[:length[i],i], 0)) 77 | hidden_state = torch.stack(hidden_state) # [B,H] 78 | else: 79 | hidden_state = torch.mean(lstm_out, 0) 80 | logits = self.hidden_to_label(hidden_state) 81 | return logits 82 | 83 | 84 | class LSTMMeanRE(LSTMSentiment): 85 | def __init__(self, config, match_length=False): 86 | super().__init__(config, match_length) 87 | self.drop = nn.Dropout(0.5) 88 | self.pos_embed = nn.Embedding(config.pos_size, config.offset_emb_dim, padding_idx=1) 89 | self.ner_embed = nn.Embedding(config.ner_size, config.offset_emb_dim, padding_idx=1) 90 | self.lstm = LSTMLayer(self.emb_dim + config.offset_emb_dim * 2, self.hidden_dim, 1, 0.5, True) 91 | self.hidden_to_label = nn.Linear(self.hidden_dim, config.d_out) 92 | 93 | def forward(self, batch): 94 | if hasattr(batch, 'vec'): 95 | vecs = batch.vec 96 | else: 97 | inp_tokens = batch.text.long().cuda() 98 | inp_pos = batch.pos.long().cuda() 99 | inp_ner = batch.ner.long().cuda() 100 | 101 | token_vecs = self.embed(inp_tokens) 102 | pos_vecs = self.pos_embed(inp_pos) 103 | ner_vecs = self.ner_embed(inp_ner) 104 | vecs = torch.cat([token_vecs, pos_vecs, ner_vecs], -1) # [T, B, H] 105 | vecs = self.drop(vecs) 106 | h0, c0 = self.zero_state(vecs.size(1)) 107 | if not hasattr(batch, 'length'): 108 | batch.length = torch.LongTensor([vecs.size(0)] * vecs.size(1)).cuda() 109 | lstm_out, (ht,ct) = self.lstm(vecs, batch.length, (h0, c0)) 110 | 111 | hidden = self.drop(ht[-1,:,:]) 112 | logits = self.hidden_to_label(hidden) 113 | return logits 114 | 115 | def init_weights(self): 116 | self.pos_embed.weight.data[2:,:].uniform_(-1.0, 1.0) 117 | self.ner_embed.weight.data[2:,:].uniform_(-1.0, 1.0) 118 | 119 | self.hidden_to_label.bias.data.fill_(0) 120 | init.xavier_uniform_(self.hidden_to_label.weight, gain=1) # initialize linear layer 121 | 122 | def zero_state(self, batch_size): 123 | state_shape = (1, batch_size, self.hidden_dim) 124 | h0 = c0 = torch.zeros(*state_shape, requires_grad=False) 125 | return h0.cuda(), c0.cuda() 126 | 127 | class LSTMLayer(nn.Module): 128 | # Copyright 2017 The Board of Trustees of The Leland Stanford Junior University 129 | # 130 | # Licensed under the Apache License, Version 2.0 (the "License"); 131 | # you may not use this file except in compliance with the License. 
132 | # You may obtain a copy of the License at 133 | # 134 | # http://www.apache.org/licenses/LICENSE-2.0 135 | # 136 | # Unless required by applicable law or agreed to in writing, software 137 | # distributed under the License is distributed on an "AS IS" BASIS, 138 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 139 | # See the License for the specific language governing permissions and 140 | # limitations under the License. 141 | 142 | """A wrapper for LSTM with sequence packing.""" 143 | 144 | def __init__(self, emb_dim, hidden_dim, num_layers, dropout, use_cuda): 145 | super(LSTMLayer, self).__init__() 146 | self.rnn = nn.LSTM(emb_dim, hidden_dim, num_layers, batch_first=False, dropout=dropout) 147 | self.use_cuda = use_cuda 148 | 149 | def forward(self, x, x_lens, init_state): 150 | """ 151 | x: batch_size * feature_size * seq_len 152 | x_mask : batch_size * seq_len 153 | """ 154 | x_lens = x_lens.cuda() 155 | _, idx_sort = torch.sort(x_lens, dim=0, descending=True) 156 | _, idx_unsort = torch.sort(idx_sort, dim=0) 157 | 158 | lens = list(x_lens[idx_sort]) 159 | 160 | # sort by seq lens 161 | x = x.index_select(1, idx_sort) 162 | rnn_input = nn.utils.rnn.pack_padded_sequence(x, lens, batch_first=False) 163 | rnn_output, (ht, ct) = self.rnn(rnn_input, init_state) 164 | rnn_output = nn.utils.rnn.pad_packed_sequence(rnn_output, batch_first=False)[0] 165 | 166 | # unsort 167 | rnn_output = rnn_output.index_select(1, idx_unsort) 168 | ht = ht.index_select(1, idx_unsort) 169 | ct = ct.index_select(1, idx_unsort) 170 | return rnn_output, (ht, ct) 171 | -------------------------------------------------------------------------------- /hiexpl/scripts/explanations/explain_sst_lstm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | model_path=${1} 4 | lm_path=${2} 5 | python explain.py --resume_snapshot ${1} --task sst --method cd --batch_size 1 --exp_name .cd2019.3 --nb_range 10 --lm_path models/sst_lm_2/best_snapshot_devloss_11.634532430897588_iter_1300_model.pt --nb_method ngram --gpu 0 --sample_n 20 --start 0 --stop 100 --dataset test -------------------------------------------------------------------------------- /hiexpl/scripts/train_model/train_sst_lstm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python train.py --task sst --save_path sst_lstm -------------------------------------------------------------------------------- /hiexpl/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import glob 3 | 4 | import torch.optim as O 5 | 6 | from nns.model import * 7 | from utils.reader import * 8 | from utils.tacred_f1 import score as tacred_f1_score 9 | 10 | import random 11 | 12 | random.seed(0) 13 | 14 | args = get_args() 15 | try: 16 | torch.cuda.set_device(args.gpu) 17 | except AttributeError: 18 | pass 19 | 20 | def change_lr(optimizer, new_lr): 21 | for param_group in optimizer.param_groups: 22 | param_group['lr'] = new_lr 23 | 24 | def word_dropout(train_batch): 25 | for b in range(train_batch.length.size(0)): 26 | for t in range(train_batch.length[b].item()): 27 | if random.random() < 0.04: 28 | train_batch.text[t,b] = 0 # unk 29 | 30 | def do_train(): 31 | criterion = nn.CrossEntropyLoss() 32 | if args.optim == 'adam': 33 | opt = O.Adam(filter(lambda x: x.requires_grad, model.parameters()), weight_decay=1e-6, lr=args.lr) 34 | else: 35 | opt = O.SGD(filter(lambda x: 
x.requires_grad, model.parameters()), weight_decay=1e-6, lr=args.lr) 36 | iterations = 0 37 | start = time.time() 38 | best_dev_acc = -1 39 | prev_dev_acc = -1 40 | train_iter.repeat = False 41 | header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy' 42 | dev_log_template = ' '.join( 43 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(',')) 44 | log_template = ' '.join( 45 | '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.split(',')) 46 | makedirs(args.save_path) 47 | print(header) 48 | 49 | all_break = False 50 | print(model) 51 | for epoch in range(args.epochs): 52 | if all_break: 53 | break 54 | train_iter.init_epoch() 55 | 56 | n_correct, n_total = 0, 0 57 | for batch_idx, batch in enumerate(train_iter): 58 | # switch model to training mode, clear gradient accumulators 59 | model.train() 60 | opt.zero_grad() 61 | if args.word_dropout: 62 | word_dropout(batch) 63 | iterations += 1 64 | print('epoch %d iter %d' % (epoch, batch_idx), end='\r') 65 | 66 | # forward pass 67 | answer = model(batch) 68 | 69 | # calculate accuracy of predictions in the current batch 70 | label = batch.label if not config.use_gpu else batch.label.cuda() 71 | n_correct += (torch.max(answer, 1)[1].view(label.size()).data == label.data).sum() 72 | n_total += batch.batch_size 73 | train_acc = 100. * n_correct / n_total 74 | 75 | # calculate loss of the network output with respect to training labels 76 | loss = criterion(answer, label) 77 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) 78 | # backpropagate and update optimizer learning rate 79 | loss.backward() 80 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0) 81 | opt.step() 82 | 83 | # checkpoint model periodically 84 | if iterations % args.save_every == 0: 85 | snapshot_prefix = os.path.join(args.save_path, 'snapshot') 86 | snapshot_path = snapshot_prefix + '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, 87 | loss.item(), 88 | iterations) 89 | torch.save(model.state_dict(), snapshot_path) 90 | for f in glob.glob(snapshot_prefix + '*'): 91 | if f != snapshot_path: 92 | os.remove(f) 93 | 94 | # evaluate performance on validation set periodically 95 | if iterations % args.dev_every == 0: 96 | 97 | # switch model to evaluation mode 98 | model.eval() 99 | dev_iter.init_epoch() 100 | 101 | # calculate accuracy on validation set 102 | n_dev_correct, dev_loss_avg = 0, 0 103 | truth_dev, pred_dev = [], [] 104 | with torch.no_grad(): 105 | for dev_batch_idx, dev_batch in enumerate(dev_iter): 106 | answer = model(dev_batch) 107 | dev_label = dev_batch.label if not config.use_gpu else dev_batch.label.cuda() 108 | pred = torch.max(answer, 1)[1] 109 | n_dev_correct += (pred.view(dev_label.size()).data == dev_label.data).sum() 110 | dev_loss = criterion(answer, dev_label) 111 | for l_i in range(dev_label.size(0)): 112 | pred_dev.append(pred.view(-1)[l_i].item()) 113 | truth_dev.append(dev_label[l_i].item()) 114 | dev_loss_avg += dev_loss.item() 115 | if args.metrics == 'tacred_f1': 116 | dev_acc = 100. * tacred_f1_score(truth_dev, pred_dev)[-1] 117 | else: 118 | dev_acc = 100. * n_dev_correct / len(dev_set) 119 | print(dev_log_template.format(time.time() - start, 120 | epoch, iterations, 1 + batch_idx, len(train_iter), 121 | 100. 
* (1 + batch_idx) / len(train_iter), loss.item(), 122 | dev_loss_avg / len(dev_iter), 123 | train_acc, dev_acc)) 124 | 125 | # update best valiation set accuracy 126 | if dev_acc > best_dev_acc: 127 | best_dev_acc = dev_acc 128 | snapshot_prefix = os.path.join(args.save_path, 'best_snapshot') 129 | snapshot_path = snapshot_prefix + '_devacc_{}_devloss_{}_iter_{}_model.pt'.format(dev_acc, 130 | dev_loss.item(), 131 | iterations) 132 | 133 | # save model, delete previous 'best_snapshot' files 134 | torch.save(model.state_dict(), snapshot_path) 135 | for f in glob.glob(snapshot_prefix + '*'): 136 | if f != snapshot_path: 137 | os.remove(f) 138 | 139 | #if iterations > 15000 and dev_acc < prev_dev_acc: 140 | # args.lr *= 0.9 141 | # change_lr(opt, args.lr) 142 | # prev_dev_acc = dev_acc 143 | elif iterations % args.log_every == 0: 144 | # print progress message 145 | print(log_template.format(time.time() - start, 146 | epoch, iterations, 1 + batch_idx, len(train_iter), 147 | 100. * (1 + batch_idx) / len(train_iter), loss.item(), ' ' * 8, 148 | n_correct / n_total * 100, ' ' * 12)) 149 | 150 | 151 | def do_test(): 152 | iterations = 0 153 | n_correct, n_total = 0,0 154 | for batch_idx, batch in enumerate(test_iter): 155 | # switch model to training mode, clear gradient accumulators 156 | model.train() 157 | iterations += 1 158 | 159 | # forward pass 160 | answer = model(batch) 161 | 162 | # calculate accuracy of predictions in the current batch 163 | label = batch.label if not config.use_gpu else batch.label.cuda() 164 | n_correct += (torch.max(answer, 1)[1].view(label.size()).data == label.data).sum() 165 | n_total += batch.batch_size 166 | acc = 100. * n_correct / n_total 167 | if batch_idx % 100 == 0: 168 | print('evaluating %d\tacc: %f, %d, %d' % (batch_idx, acc, n_correct, n_total)) 169 | print('Acc: %f' % (n_correct / n_total)) 170 | 171 | 172 | if __name__ == '__main__': 173 | SEED = 1 174 | random.seed(SEED) 175 | torch.manual_seed(SEED) 176 | torch.cuda.manual_seed(SEED) 177 | config = args 178 | if args.task == 'sst': 179 | inputs, labels, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_sst() 180 | elif args.task == 'yelp': 181 | inputs, labels, train_iter, dev_iter, test_iter, train_set, dev_set = \ 182 | get_data_iterators_yelp(map_cpu=False) 183 | elif args.task == 'tacred': 184 | inputs, labels, subj_offset, obj_offset, pos, ner, train_iter, dev_iter, test_iter, train_set, dev_set = get_data_iterators_tacred() 185 | config.subj_offset_size = len(subj_offset.vocab) 186 | config.obj_offset_size = len(obj_offset.vocab) 187 | config.pos_size = len(pos.vocab) 188 | config.ner_size = len(ner.vocab) 189 | config.offset_emb_dim = 30 190 | 191 | else: 192 | raise ValueError('unknown task') 193 | 194 | config.n_embed = len(inputs.vocab) 195 | config.d_out = len(labels.vocab) if hasattr(labels, 'vocab') else 2 196 | config.n_cells = config.n_layers 197 | config.use_gpu = args.gpu >= 0 198 | 199 | if args.task == 'sst': 200 | model = LSTMSentiment(config) 201 | elif args.task == 'yelp': 202 | model = LSTMMeanSentiment(config, match_length=True) 203 | elif args.task == 'tacred': 204 | model = LSTMMeanRE(config, match_length=True) 205 | model.init_weights() 206 | if args.word_vectors: 207 | model.embed.weight.data.copy_(inputs.vocab.vectors) 208 | if config.use_gpu: 209 | model = model.to(args.gpu) 210 | 211 | do_train() 212 | 213 | -------------------------------------------------------------------------------- /hiexpl/utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/INK-USC/hierarchical-explanation-neural-sequence-models/d797daee2cea327ff3a7fb5d9f077412861f834f/hiexpl/utils/__init__.py -------------------------------------------------------------------------------- /hiexpl/utils/agglomeration.py: -------------------------------------------------------------------------------- 1 | # MIT License 2 | # 3 | # Copyright (c) 2019 Chandan Singh 4 | # 5 | # Permission is hereby granted, free of charge, to any person obtaining a copy 6 | # of this software and associated documentation files (the "Software"), to deal 7 | # in the Software without restriction, including without limitation the rights 8 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | # copies of the Software, and to permit persons to whom the Software is 10 | # furnished to do so, subject to the following conditions: 11 | # 12 | # The above copyright notice and this permission notice shall be included in all 13 | # copies or substantial portions of the Software. 14 | 15 | # Code from https://github.com/csinva/hierarchical-dnn-interpretations/ 16 | 17 | 18 | import numpy as np 19 | 20 | def collapse_tree(lists): 21 | num_iters = len(lists['comps_list']) 22 | num_words = len(lists['comps_list'][0]) 23 | 24 | # need to update comp_scores_list, comps_list 25 | comps_list = [np.zeros(num_words, dtype=np.int) for i in range(num_iters)] 26 | comp_scores_list = [{0: 0} for i in range(num_iters)] 27 | comp_levels_list = [{0: 0} for i in range(num_iters)] # use this to determine what level to put things at 28 | 29 | # initialize first level 30 | comps_list[0] = np.arange(num_words) 31 | comp_levels_list[0] = {i: 0 for i in range(num_words)} 32 | 33 | # iterate over levels 34 | for i in range(1, num_iters): 35 | comps = lists['comps_list'][i] 36 | comps_old = lists['comps_list'][i - 1] 37 | comp_scores = lists['comp_scores_list'][i] 38 | 39 | for comp_num in range(1, np.max(comps) + 1): 40 | comp = comps == comp_num 41 | comp_size = np.sum(comp) 42 | if comp_size == 1: 43 | comp_levels_list[i][comp_num] = 0 # set level to 0 44 | else: 45 | # check for matches 46 | matches = np.unique(comps_old[comp]) 47 | num_matches = matches.size 48 | 49 | # if 0 matches, level is 1 50 | if num_matches == 0: 51 | level = 1 52 | comp_levels_list[i][comp_num] = level # set level to level 1 53 | 54 | # if 1 match, maintain level 55 | elif num_matches == 1: 56 | level = comp_levels_list[i - 1][matches[0]] 57 | 58 | 59 | # if >1 match, take highest level + 1 60 | else: 61 | level = np.max([comp_levels_list[i - 1][match] for match in matches]) + 1 62 | 63 | comp_levels_list[i][comp_num] = level 64 | new_comp_num = int(np.max(comps_list[level]) + 1) 65 | comps_list[level][comp] = new_comp_num # update comp 66 | comp_scores_list[level][new_comp_num] = comp_scores[comp_num] # update comp score 67 | 68 | # remove unnecessary iters 69 | num_iters = 0 70 | while np.sum(comps_list[num_iters] > 0) and num_iters < len(comps_list): 71 | num_iters += 1 72 | 73 | # populate lists 74 | lists['comps_list'] = comps_list[:num_iters] 75 | lists['comp_scores_list'] = comp_scores_list[:num_iters] 76 | return lists 77 | 78 | # threshold scores at a specific percentile 79 | def threshold_scores(scores, percentile_include, absolute): 80 | # pick based on abs value? 
81 | if absolute: 82 | scores = np.absolute(scores) 83 | 84 | # last 5 always pick 2 85 | num_left = scores.size - np.sum(np.isnan(scores)) 86 | if num_left <= 5: 87 | pass 88 | thresh = np.nanpercentile(scores, percentile_include) 89 | mask = scores >= thresh 90 | return mask 91 | 92 | # pytorch needs to return each input as a column 93 | # return batch_size x L tensor 94 | def gen_tiles(text, fill=0, 95 | method='occlusion', prev_text=None, sweep_dim=1): 96 | L = text.shape[0] 97 | texts = np.zeros((L - sweep_dim + 1, L), dtype=np.int) 98 | for start in range(L - sweep_dim + 1): 99 | end = start + sweep_dim 100 | if method == 'occlusion': 101 | text_new = np.copy(text).flatten() 102 | text_new[start:end] = fill 103 | else: 104 | text_new = np.zeros(L) 105 | text_new[start:end] = text[start:end] 106 | texts[start] = np.copy(text_new) 107 | return texts 108 | 109 | 110 | # return tile representing component 111 | def gen_tile_from_comp(text_orig, comp_tile, method, fill=0): 112 | if method == 'occlusion': 113 | tile_new = np.copy(text_orig).flatten() 114 | tile_new[comp_tile] = fill 115 | elif method == 'build_up' or method == 'cd': 116 | tile_new = np.zeros(text_orig.shape) 117 | tile_new[comp_tile] = text_orig[comp_tile] 118 | return tile_new 119 | 120 | 121 | # generate tiles around component 122 | def gen_tiles_around_baseline(text_orig, comp_tile, method='build_up', sweep_dim=1, fill=0): 123 | L = text_orig.shape[0] 124 | left = 0 125 | right = L - 1 126 | while not comp_tile[left]: 127 | left += 1 128 | while not comp_tile[right]: 129 | right -= 1 130 | left = max(0, left - sweep_dim) 131 | right = min(L - 1, right + sweep_dim) 132 | tiles = [] 133 | for x in [left, right]: 134 | if method == 'occlusion': 135 | tile_new = np.copy(text_orig).flatten() 136 | tile_new[comp_tile] = fill 137 | tile_new[x] = fill 138 | elif method == 'build_up' or method == 'cd': 139 | tile_new = np.zeros(text_orig.shape) 140 | tile_new[comp_tile] = text_orig[comp_tile] 141 | tile_new[x] = text_orig[x] 142 | tiles.append(tile_new) 143 | return np.array(tiles), [left, right] 144 | 145 | def tiles_to_cd(tiles): 146 | starts, stops = [], [] 147 | #tiles = batch.text.data.cpu().numpy() 148 | L = tiles.shape[0] 149 | for c in range(tiles.shape[1]): 150 | text = tiles[:, c] 151 | start = 0 152 | stop = L - 1 153 | while text[start] == 0: 154 | start += 1 155 | while text[stop] == 0: 156 | stop -= 1 157 | starts.append(start) 158 | stops.append(stop) 159 | return starts, stops 160 | 161 | def lists_to_tabs(lists, num_words): 162 | 163 | num_iters = len(lists['comps_list']) 164 | data = np.empty(shape=(num_iters, num_words)) 165 | data[:] = np.nan 166 | data[0, :] = lists['scores_list'][0] 167 | for i in range(1, num_iters): 168 | comps = lists['comps_list'][i] 169 | comp_scores_list = lists['comp_scores_list'][i] 170 | 171 | for comp_num in range(1, np.max(comps) + 1): 172 | idxs = comps == comp_num 173 | data[i][idxs] = comp_scores_list[comp_num] 174 | 175 | 176 | data[np.isnan(data)] = 0 # np.nanmin(data) - 0.001 177 | return data 178 | 179 | -------------------------------------------------------------------------------- /hiexpl/utils/args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | 4 | def makedirs(name): 5 | """helper function for python 2 and 3 to call os.makedirs() 6 | avoiding an error if the directory to be created already exists""" 7 | 8 | import os, errno 9 | 10 | try: 11 | os.makedirs(name) 12 | except 
OSError as ex: 13 | if ex.errno == errno.EEXIST and os.path.isdir(name): 14 | # ignore existing directory 15 | pass 16 | else: 17 | # a different error happened 18 | raise 19 | 20 | def get_best_snapshot(dir): 21 | if os.path.isdir(dir): 22 | files = os.listdir(dir) 23 | for file in files: 24 | if file.startswith('best_'): 25 | return os.path.join(dir, file) 26 | return None 27 | 28 | def get_args(): 29 | parser = ArgumentParser(description='PyTorch/torchtext SST') 30 | 31 | # model parameters 32 | parser.add_argument('--optim', type=str, default='adam', choices=['adam', 'sgd']) 33 | parser.add_argument('--metrics', default='accuracy', choices=['accuracy', 'tacred_f1']) 34 | parser.add_argument('--epochs', type=int, default=20) 35 | parser.add_argument('--task', type=str, default='sst') 36 | parser.add_argument('--batch_size', type=int, default=50) 37 | parser.add_argument('--d_embed', type=int, default=300) 38 | parser.add_argument('--d_proj', type=int, default=300) 39 | parser.add_argument('--d_hidden', type=int, default=128) 40 | parser.add_argument('--n_layers', type=int, default=1) 41 | parser.add_argument('--dropout', type=float, default=0.0) 42 | parser.add_argument('--log_every', type=int, default=10000) 43 | parser.add_argument('--lr', type=float, default=.0005) 44 | parser.add_argument('--weight_decay', type=float, default=1e-6) 45 | parser.add_argument('--dev_every', type=int, default=100) 46 | parser.add_argument('--save_every', type=int, default=100) 47 | parser.add_argument('--no-bidirectional', action='store_false', dest='birnn') 48 | parser.add_argument('--preserve-case', action='store_false', dest='lower') 49 | parser.add_argument('--no-projection', action='store_false', dest='projection') 50 | parser.add_argument('--fix_emb', action='store_true') 51 | parser.add_argument('--gpu', default=0) 52 | parser.add_argument('--save_path', type=str, default='results') 53 | parser.add_argument('--vector_cache', type=str, default=os.path.join(os.getcwd(), '.vector_cache/input_vectors.pt')) 54 | parser.add_argument('--word_vectors', type=str, default='glove.6B.300d') 55 | parser.add_argument('--resume_snapshot', type=str, default='') 56 | parser.add_argument('--word_dropout', action='store_true') 57 | 58 | parser.add_argument('--lm_d_embed', type=int, default=300) 59 | parser.add_argument('--lm_d_hidden', type=int, default=128) 60 | 61 | parser.add_argument('--method', nargs='?') 62 | parser.add_argument('--nb_method', default='ngram') 63 | parser.add_argument('--nb_range', type=int, default=3) 64 | parser.add_argument('--exp_name', default='') 65 | parser.add_argument('--lm_dir', nargs='?', default='') 66 | parser.add_argument('--lm_path', nargs='?', default='') 67 | parser.add_argument('--start', type=int, default=0) 68 | parser.add_argument('--stop', type=int, default=10000000000) 69 | parser.add_argument('--sample_n', type=int, default=5) 70 | 71 | parser.add_argument('--explain_model', default='lstm') 72 | parser.add_argument('--demo', action='store_true') 73 | 74 | parser.add_argument('--dataset', default='dev') 75 | parser.add_argument('--use_bert_tokenizer', action='store_true') 76 | parser.add_argument('--no_subtrees', action='store_true') 77 | 78 | parser.add_argument('--use_bert_lm', action='store_true') 79 | parser.add_argument('--fix_test_vocab', action='store_true') 80 | 81 | parser.add_argument('--include_noise_labels', action='store_true') 82 | parser.add_argument('--filter_length_gt', type=int, default=-1) 83 | parser.add_argument('--add_itself', action='store_true') 
84 | 85 | parser.add_argument('--mean_hidden', action='store_true') 86 | parser.add_argument('--agg', action='store_true') 87 | parser.add_argument('--class_score', action='store_true') 88 | 89 | parser.add_argument('--cd_pad', action='store_true') 90 | parser.add_argument('--eval_file', default='') 91 | 92 | args = parser.parse_args() 93 | 94 | try: 95 | args.gpu = int(args.gpu) 96 | except ValueError: 97 | args.gpu = 'cpu' 98 | 99 | if os.path.isdir(args.resume_snapshot): 100 | args.resume_snapshot = get_best_snapshot(args.resume_snapshot) 101 | if os.path.isdir(args.lm_path): 102 | args.lm_path = get_best_snapshot(args.lm_path) 103 | return args 104 | 105 | 106 | args = get_args() 107 | -------------------------------------------------------------------------------- /hiexpl/utils/parser.py: -------------------------------------------------------------------------------- 1 | from nltk import ParentedTree 2 | 3 | 4 | def parse_tree(s): 5 | tree = ParentedTree.fromstring(s) 6 | return tree 7 | 8 | 9 | def read_trees_from_corpus(path): 10 | f = open(path) 11 | rows = f.readlines() 12 | trees = [] 13 | for row in rows: 14 | row = row.lower() 15 | tree = parse_tree(row) 16 | trees.append(tree) 17 | return trees 18 | 19 | 20 | def is_leaf(node): 21 | if type(node[0]) == str and len(node) != 1: 22 | print(1) 23 | return type(node[0]) == str 24 | 25 | 26 | def get_span_to_node_mapping(tree): 27 | def dfs(node, span_to_node, node_to_span, idx): 28 | if is_leaf(node): 29 | span_to_node[idx] = node 30 | node_to_span[id(node)] = idx 31 | return idx + 1 32 | prev_idx = idx 33 | for child in node: 34 | idx = dfs(child, span_to_node, node_to_span, idx) 35 | span_to_node[(prev_idx, idx-1)] = node 36 | node_to_span[id(node)] = (prev_idx, idx-1) 37 | return idx 38 | span2node, node2span = {}, {} 39 | dfs(tree, span2node, node2span, 0) 40 | return span2node, node2span 41 | 42 | 43 | def get_siblings_idx(node, node2span): 44 | parent = node.parent() 45 | if parent is None: # root 46 | return node2span[id(node)] 47 | return node2span[id(parent)] 48 | 49 | 50 | def find_region_neighbourhood(s_or_tree, region): 51 | if type(s_or_tree) is str: 52 | tree = parse_tree(s_or_tree) 53 | else: 54 | tree = s_or_tree 55 | 56 | if type(region) is tuple and region[0] == region[1]: 57 | region = region[0] 58 | 59 | span2node, node2span = get_span_to_node_mapping(tree) 60 | node = span2node[region] 61 | sibling_idx = get_siblings_idx(node, node2span) 62 | return sibling_idx 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /hiexpl/utils/reader.py: -------------------------------------------------------------------------------- 1 | import torchtext as tt 2 | from nltk import Tree 3 | import pickle, random 4 | import torch 5 | from utils.args import get_args, makedirs 6 | import os 7 | from bert.tokenization import BertTokenizer 8 | import csv, json 9 | 10 | args = get_args() 11 | 12 | 13 | def save_vocab(path, vocab): 14 | f = open(path, 'wb') 15 | pickle.dump(vocab, f) 16 | f.close() 17 | 18 | 19 | def load_vocab(path): 20 | f = open(path, 'rb') 21 | obj = pickle.load(f) 22 | return obj 23 | 24 | 25 | def handle_vocab(vocab_path, field, datasets, vector_cache='', train_lm=False, max_size=None): 26 | create_vocab = False 27 | if os.path.isfile(vocab_path): 28 | print('loading vocab from %s' % vocab_path) 29 | vocab = load_vocab(vocab_path) 30 | field.vocab = vocab 31 | else: 32 | print('creating vocab') 33 | makedirs('vocab') 34 | field.build_vocab(*datasets, 
max_size=max_size) 35 | vocab = field.vocab 36 | if '' not in field.vocab.stoi: 37 | field.vocab.itos.append('') 38 | field.vocab.stoi[''] = len(field.vocab.itos) - 1 39 | if '' not in field.vocab.stoi: 40 | field.vocab.itos.append('') 41 | field.vocab.stoi[''] = len(field.vocab.itos) - 1 42 | save_vocab(vocab_path, vocab) 43 | create_vocab = True 44 | 45 | if vector_cache != '' and not vector_cache.startswith('none'): 46 | if args.word_vectors or create_vocab: 47 | if os.path.isfile(vector_cache): 48 | field.vocab.vectors = torch.load(vector_cache) 49 | else: 50 | field.vocab.load_vectors(args.word_vectors) 51 | for i in range(field.vocab.vectors.size(0)): 52 | if field.vocab.vectors[i,0].item() == 0 and field.vocab.vectors[i,1].item() == 0: 53 | field.vocab.vectors[i].uniform_(-1,1) 54 | makedirs(os.path.dirname(vector_cache)) 55 | torch.save(field.vocab.vectors, vector_cache) 56 | 57 | if train_lm: 58 | v = torch.zeros(2, field.vocab.vectors.size(-1)) 59 | field.vocab.vectors = torch.cat([field.vocab.vectors, v], 0) 60 | 61 | 62 | def get_data_iterators_sst(): 63 | inputs = tt.data.Field(lower=args.lower) 64 | answers = tt.data.Field(sequential=False, unk_token=None) 65 | train, dev, test = tt.datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=not args.no_subtrees, 66 | filter_pred=lambda ex: ex.label != 'neutral') 67 | 68 | vocab_path = 'vocab/vocab_sst.pkl' if not args.use_bert_tokenizer else 'vocab/vocab_sst_bert.pkl' 69 | if args.fix_test_vocab and not args.use_bert_tokenizer: 70 | vocab_path = 'vocab/vocab_sst_fix.pkl' 71 | c_postfix = '.sst' 72 | if args.use_bert_tokenizer: 73 | c_postfix += '.bert' 74 | if args.fix_test_vocab: 75 | c_postfix += '.fix' 76 | handle_vocab(vocab_path, inputs, (train, dev, test), args.vector_cache + c_postfix) 77 | answers.build_vocab(train) 78 | 79 | if args.include_noise_labels: 80 | train_2, _, _ = tt.datasets.SST.splits(inputs, answers, fine_grained=False, train_subtrees=True) 81 | texts = set() 82 | for example in train_2.examples: 83 | if len(example.text) == 1 and example.text[0] not in texts: 84 | texts.add(example.text[0]) 85 | if example.label == 'positive': 86 | example.label = 'negative' 87 | elif example.label == 'negative': 88 | example.label = 'positive' 89 | elif example.label == 'neutral': 90 | example.label = random.choice(['positive','negative']) 91 | else: 92 | raise ValueError 93 | train.examples.append(example) 94 | train_iter, dev_iter, test_iter = tt.data.BucketIterator.splits( 95 | (train, dev, test), batch_size=args.batch_size, device=args.gpu, sort=True, shuffle=False) 96 | else: 97 | train_iter, dev_iter, test_iter = tt.data.BucketIterator.splits( 98 | (train, dev, test), batch_size=args.batch_size, device=args.gpu) 99 | return inputs, answers, train_iter, dev_iter, test_iter, train, dev 100 | 101 | 102 | def compute_mapping(tokens, bert_tokens): 103 | mapping = [] 104 | i, j = 0, 0 105 | while i < len(tokens): 106 | t = '' 107 | while len(t) < len(tokens[i]): 108 | t += bert_tokens[j].replace('##','') 109 | j += 1 110 | if len(t) > len(tokens[i]): 111 | print('warning: mapping mismatch') 112 | break 113 | mapping.append(j) 114 | i += 1 115 | return mapping 116 | 117 | def convert_to_bert_tokenization(tokens, bert_tokenizer, return_mapping=False): 118 | text = ' '.join(tokens) 119 | bert_tokens = bert_tokenizer.tokenize(text) 120 | # compute mapping 121 | if return_mapping: 122 | mapping = compute_mapping(tokens, bert_tokens) 123 | return bert_tokens, mapping 124 | else: 125 | return bert_tokens 126 | 
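# --------------------------------------------------------------------------
# Illustrative sketch, not part of reader.py: how the span utilities from
# hiexpl/utils/parser.py (above) are typically used. The tree string is a
# made-up SST-style example; requires nltk.
from utils.parser import parse_tree, get_span_to_node_mapping, find_region_neighbourhood

tree_str = '(3 (2 (2 a) (2 very)) (3 (3 good) (2 movie)))'
tree = parse_tree(tree_str)
span2node, node2span = get_span_to_node_mapping(tree)
# leaves map to integer positions, phrases to closed intervals, e.g.
#   0 -> 'a', 1 -> 'very', (0, 1) -> 'a very', (2, 3) -> 'good movie',
#   (0, 3) -> the full sentence
print(list(span2node.keys()))  # depth-first order: [0, 1, (0, 1), 2, 3, (2, 3), (0, 3)]
# the "neighbourhood" of a region is the span of its parent constituent,
# e.g. the word 'good' (index 2) sits inside the phrase 'good movie':
print(find_region_neighbourhood(tree_str, 2))  # -> (2, 3)
# --------------------------------------------------------------------------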
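# --------------------------------------------------------------------------
# Illustrative sketch, not part of reader.py: compute_mapping() above aligns
# whitespace tokens with BERT wordpieces by accumulating '##' pieces until
# the character lengths match; mapping[i] is the index just past the last
# wordpiece of original token i. The wordpiece split below is hand-made, so
# no tokenizer download is needed; assumes compute_mapping is in scope.
tokens = ['the', 'greatest', 'film']
bert_tokens = ['the', 'great', '##est', 'film']   # hypothetical wordpiece split
print(compute_mapping(tokens, bert_tokens))        # -> [1, 3, 4]
# i.e. tokens[0] ends before wordpiece 1, tokens[1] before 3, tokens[2] before 4,
# which downstream code can use to map word/phrase spans onto wordpiece spans.
# --------------------------------------------------------------------------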
127 | def get_examples_sst(path, train_lm, bert_tokenizer=None): 128 | f = open(path) 129 | examples = [] 130 | for i, line in enumerate(f.readlines()): 131 | line = line.lower() 132 | tree = Tree.fromstring(line) 133 | 134 | tokens = tree.leaves() 135 | if bert_tokenizer is not None: 136 | tokens, mapping = convert_to_bert_tokenization(tokens, bert_tokenizer, return_mapping=True) 137 | 138 | if train_lm: 139 | tokens = [''] + tokens + [''] 140 | 141 | example = tt.data.Example() 142 | example.text = tokens 143 | example.length = len(tokens) 144 | examples.append(example) 145 | example.offset = i 146 | 147 | if bert_tokenizer is not None: 148 | example.mapping = mapping 149 | 150 | if int(tree.label()) >= 3: 151 | example.label = 0 152 | elif int(tree.label()) <= 2: 153 | example.label = 1 154 | else: 155 | example.label = 2 156 | 157 | return examples 158 | 159 | 160 | def get_examples_yelp(path, train_lm, bert_tokenizer=None): 161 | f = open(path) 162 | reader = csv.reader(f) 163 | examples = [] 164 | for i, line in enumerate(reader): 165 | tokens = line[1].split() 166 | label = int(line[0]) 167 | 168 | if bert_tokenizer is not None: 169 | tokens, mapping = convert_to_bert_tokenization(tokens[:100], bert_tokenizer, return_mapping=True) 170 | 171 | if train_lm: 172 | tokens = [''] + tokens[:50] + [''] 173 | 174 | if args.filter_length_gt != -1 and len(tokens) >= args.filter_length_gt: 175 | continue 176 | 177 | example = tt.data.Example() 178 | example.text = tokens 179 | example.length = len(tokens) 180 | example.offset = i 181 | 182 | if bert_tokenizer is not None: 183 | example.mapping = mapping 184 | 185 | if label == 2: 186 | example.label = 0 187 | elif label == 1: 188 | example.label = 1 189 | else: 190 | raise ValueError 191 | 192 | examples.append(example) 193 | if args.explain_model == 'bert': 194 | if i == args.stop: break 195 | 196 | return examples 197 | 198 | def filter_tacred_special_token(entry): 199 | text = entry['token'] 200 | text[entry['subj_start']] = 'SUBJ%s' % entry['subj_type'] 201 | text[entry['obj_start']] = 'OBJ %s' % entry['obj_type'] 202 | for idx in range(entry['subj_start'] + 1, entry['subj_end'] + 1): 203 | text[idx] = '-!REMOVED!-' 204 | 205 | for idx in range(entry['obj_start'] + 1, entry['obj_end'] + 1): 206 | text[idx] = '-!REMOVED!-' 207 | text = [_ for _ in filter(lambda x: x != '-!REMOVED!-', text)] 208 | full_tokens = [] 209 | for token in text: 210 | full_tokens.extend(token.split()) 211 | return full_tokens 212 | 213 | 214 | def get_examples_tacred(path, train_lm, bert_tokenizer=None): 215 | f = open(path) 216 | #reader = csv.reader(f) 217 | examples = [] 218 | json_data = json.load(f) 219 | 220 | for ji, entry in enumerate(json_data): 221 | tokens = entry['token'] 222 | #tokens = [_.lower() for _ in tokens] 223 | if bert_tokenizer is not None: 224 | tokens = filter_tacred_special_token(entry) 225 | tokens = [_.lower() for _ in tokens] 226 | tokens, mapping = convert_to_bert_tokenization(tokens, bert_tokenizer, return_mapping=True) 227 | 228 | if args.filter_length_gt != -1 and len(tokens) >= args.filter_length_gt: 229 | continue 230 | 231 | example = tt.data.Example() 232 | example.text = tokens 233 | example.label = entry['relation'] 234 | example.offset = ji 235 | 236 | example.subj_offset, example.obj_offset = [], [] 237 | example.pos, example.ner = entry['stanford_pos'], entry['stanford_ner'] 238 | 239 | if bert_tokenizer is None: 240 | for idx in range(entry['subj_start'], entry['subj_end'] + 1): 241 | example.text[idx] = 'SUBJ-%s' % 
entry['subj_type'] 242 | for idx in range(entry['obj_start'], entry['obj_end'] + 1): 243 | example.text[idx] = 'OBJ-%s' % entry['obj_type'] 244 | 245 | for idx in range(len(tokens)): 246 | if idx < entry['subj_start']: 247 | example.subj_offset.append(idx - entry['subj_start']) 248 | elif idx > entry['subj_end']: 249 | example.subj_offset.append(idx - entry['subj_end']) 250 | else: 251 | example.subj_offset.append(0) 252 | if idx < entry['obj_start']: 253 | example.obj_offset.append(idx - entry['obj_start']) 254 | elif idx > entry['obj_end']: 255 | example.obj_offset.append(idx - entry['obj_end']) 256 | else: 257 | example.obj_offset.append(0) 258 | 259 | if bert_tokenizer is not None: 260 | example.mapping = mapping 261 | 262 | if train_lm: 263 | example.text = [''] + example.text[:50] + [''] 264 | example.length = len(example.text) 265 | examples.append(example) 266 | if args.explain_model == 'bert': 267 | if ji == args.stop: break 268 | 269 | return examples 270 | 271 | 272 | def get_data_iterators_sst_flatten(train_lm=False, map_cpu=False): 273 | text_field = tt.data.Field(lower=args.lower) 274 | length_field = tt.data.Field(sequential=False, use_vocab=False) 275 | offset_field = tt.data.Field(sequential=False, use_vocab=False) 276 | _, _, _ = tt.datasets.SST.splits(text_field, length_field, fine_grained=False, train_subtrees=False, 277 | filter_pred=lambda ex: ex.label != 'neutral') 278 | 279 | path_format = './.data/sst/trees/%s.txt' 280 | 281 | bert_tokenizer = None 282 | if args.use_bert_tokenizer: 283 | bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir='bert/cache') 284 | 285 | train_ex, dev_ex, test_ex = (get_examples_sst(path_format % ds, train_lm, bert_tokenizer=bert_tokenizer) 286 | for ds in ['train', 'dev', 'test']) 287 | train, dev, test = (tt.data.Dataset(ex, [('text', text_field), ('length', length_field), ('offset', offset_field)]) 288 | for ex in [train_ex, dev_ex, test_ex]) 289 | 290 | vocab_path = 'vocab/vocab_sst.pkl' if not args.use_bert_tokenizer else 'vocab/vocab_sst_bert.pkl' 291 | c_postfix = '.sst' 292 | if args.use_bert_tokenizer: 293 | c_postfix += '.bert' 294 | handle_vocab(vocab_path, text_field, (train, dev, test), args.vector_cache + c_postfix, train_lm) 295 | 296 | train_iter, dev_iter, test_iter = ( 297 | tt.data.BucketIterator(x, batch_size=args.batch_size, device=args.gpu if not map_cpu else 'cpu', shuffle=False) 298 | for x in (train, dev, test)) 299 | return text_field, length_field, train_iter, dev_iter, test_iter, train, dev 300 | 301 | 302 | def get_data_iterators_yelp(train_lm=False, map_cpu=False): 303 | text_field = tt.data.Field(lower=args.lower) 304 | label_field = tt.data.LabelField(sequential=False, unk_token=None) 305 | length_field = tt.data.Field(sequential=False, use_vocab=False) 306 | offset_field = tt.data.Field(sequential=False, use_vocab=False) 307 | 308 | path_format = './.data/yelp_review_polarity_csv/%s.csv.token' 309 | bert_tokenizer = None 310 | if args.use_bert_tokenizer: 311 | bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir='bert/cache') 312 | train_examples, test_examples = (get_examples_yelp(path_format % ds, train_lm, bert_tokenizer=bert_tokenizer) for 313 | ds in ['train', 'test']) 314 | dev_examples = test_examples[:500] 315 | train, dev, test = (tt.data.Dataset(ex, [('text', text_field), ('length', length_field), ('offset', offset_field), ('label',label_field)]) 316 | for ex in [train_examples, dev_examples, 
test_examples]) 317 | 318 | vocab_path = 'vocab/vocab_yelp.pkl' if not args.use_bert_tokenizer else 'vocab/vocab_yelp_bert.pkl' 319 | if args.fix_test_vocab and not args.use_bert_tokenizer: 320 | vocab_path = 'vocab/vocab_yelp_fix.pkl' 321 | 322 | c_postfix = '.yelp' 323 | if args.use_bert_tokenizer: 324 | c_postfix += '.bert' 325 | if args.fix_test_vocab: 326 | c_postfix += '.fix' 327 | handle_vocab(vocab_path, text_field, (train, test), args.vector_cache + c_postfix, train_lm, max_size=20000) 328 | label_field.build_vocab(train) 329 | train_iter, dev_iter, test_iter = ( 330 | tt.data.BucketIterator(x, batch_size=args.batch_size, device=args.gpu if not map_cpu else 'cpu', shuffle=False) 331 | for x in (train, dev, test)) 332 | return text_field, label_field, train_iter, dev_iter, test_iter, train, dev 333 | 334 | def get_data_iterators_tacred(train_lm=False, map_cpu=False): 335 | text_field = tt.data.Field(lower=False) 336 | label_field = tt.data.LabelField() 337 | length_field = tt.data.Field(sequential=False, use_vocab=False) 338 | offset_field = tt.data.Field(sequential=False, use_vocab=False) 339 | pos_field = tt.data.Field() 340 | ner_field = tt.data.Field() 341 | subj_offset_field = tt.data.Field() 342 | obj_offset_field = tt.data.Field() 343 | 344 | path_format = './.data/TACRED/data/json/%s.json' 345 | bert_tokenizer = None 346 | if args.use_bert_tokenizer: 347 | bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True, cache_dir='bert/cache') 348 | train_examples, dev_examples, test_examples = (get_examples_tacred(path_format % ds, train_lm, bert_tokenizer=bert_tokenizer) for 349 | ds in ['train', 'dev','test']) 350 | train, dev, test = (tt.data.Dataset(ex, [('text', text_field), ('length', length_field), ('offset', offset_field), 351 | ('label', label_field), ('subj_offset', subj_offset_field), 352 | ('obj_offset', obj_offset_field), ('ner', ner_field), ('pos', pos_field)]) 353 | for ex in [train_examples, dev_examples, test_examples]) 354 | 355 | vocab_path = 'vocab/vocab_tacred.pkl' if not args.use_bert_tokenizer else 'vocab/vocab_tacred_bert.pkl' 356 | if args.fix_test_vocab and not args.use_bert_tokenizer: 357 | vocab_path = 'vocab/vocab_tacred_fix.pkl' 358 | 359 | c_postfix = '.tacred' 360 | if args.use_bert_tokenizer: 361 | c_postfix += '.bert' 362 | if args.fix_test_vocab: 363 | c_postfix += '.fix' 364 | handle_vocab(vocab_path, text_field, (train, dev, test), args.vector_cache + c_postfix, train_lm, max_size=100000) 365 | handle_vocab(vocab_path, text_field, (train, dev, test), args.vector_cache + c_postfix, train_lm, max_size=100000) 366 | handle_vocab(vocab_path + '.relation', label_field, (train, dev, test), '', False, None) 367 | handle_vocab(vocab_path + '.subj_offset', subj_offset_field, (train, dev, test), '', False, None) 368 | handle_vocab(vocab_path + '.obj_offset', obj_offset_field, (train, dev, test), '', False, None) 369 | handle_vocab(vocab_path + '.pos', pos_field, (train, dev, test), '', False, None) 370 | handle_vocab(vocab_path + '.ner', ner_field, (train, dev, test), '', False, None) 371 | 372 | train_iter, dev_iter, test_iter = ( 373 | tt.data.BucketIterator(x, batch_size=args.batch_size, device=args.gpu if not map_cpu else 'cpu') 374 | for x in (train, dev, test)) 375 | return text_field, label_field, subj_offset_field, obj_offset_field, pos_field, ner_field, train_iter, dev_iter, test_iter, train, dev 376 | 377 | -------------------------------------------------------------------------------- /hiexpl/utils/tacred_f1.py: 
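# --------------------------------------------------------------------------
# Illustrative sketch, not part of the repository: the position features that
# get_examples_tacred() above attaches to each example. With the subject span
# at token positions 1-2 and the object at position 4 (entity types made up),
# the per-token offsets are signed distances to the entity spans, 0 inside:
tokens      = ['He', 'SUBJ-PERSON', 'SUBJ-PERSON', 'visited', 'OBJ-CITY', 'yesterday']
subj_offset = [-1, 0, 0, 1, 2, 3]   # idx - subj_start before the span, 0 inside, idx - subj_end after
obj_offset  = [-4, -3, -2, -1, 0, 1]
# These sequences get their own vocabularies (see get_data_iterators_tacred
# above) and are presumably embedded as position features by the classifier.
# --------------------------------------------------------------------------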
-------------------------------------------------------------------------------- 1 | # Copyright 2017 The Board of Trustees of The Leland Stanford Junior University 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from collections import Counter 17 | import sys 18 | NO_RELATION = 0 19 | 20 | def score(key, prediction, verbose=False): 21 | correct_by_relation = Counter() 22 | guessed_by_relation = Counter() 23 | gold_by_relation = Counter() 24 | 25 | # Loop over the data to compute a score 26 | for row in range(len(key)): 27 | gold = key[row] 28 | guess = prediction[row] 29 | 30 | if gold == NO_RELATION and guess == NO_RELATION: 31 | pass 32 | elif gold == NO_RELATION and guess != NO_RELATION: 33 | guessed_by_relation[guess] += 1 34 | elif gold != NO_RELATION and guess == NO_RELATION: 35 | gold_by_relation[gold] += 1 36 | elif gold != NO_RELATION and guess != NO_RELATION: 37 | guessed_by_relation[guess] += 1 38 | gold_by_relation[gold] += 1 39 | if gold == guess: 40 | correct_by_relation[guess] += 1 41 | 42 | # Print verbose information 43 | if verbose: 44 | print("Per-relation statistics:") 45 | relations = gold_by_relation.keys() 46 | longest_relation = 0 47 | for relation in sorted(relations): 48 | longest_relation = max(len(relation), longest_relation) 49 | for relation in sorted(relations): 50 | # (compute the score) 51 | correct = correct_by_relation[relation] 52 | guessed = guessed_by_relation[relation] 53 | gold = gold_by_relation[relation] 54 | prec = 1.0 55 | if guessed > 0: 56 | prec = float(correct) / float(guessed) 57 | recall = 0.0 58 | if gold > 0: 59 | recall = float(correct) / float(gold) 60 | f1 = 0.0 61 | if prec + recall > 0: 62 | f1 = 2.0 * prec * recall / (prec + recall) 63 | # (print the score) 64 | sys.stdout.write(("{:<" + str(longest_relation) + "}").format(relation)) 65 | sys.stdout.write(" P: ") 66 | if prec < 0.1: sys.stdout.write(' ') 67 | if prec < 1.0: sys.stdout.write(' ') 68 | sys.stdout.write("{:.2%}".format(prec)) 69 | sys.stdout.write(" R: ") 70 | if recall < 0.1: sys.stdout.write(' ') 71 | if recall < 1.0: sys.stdout.write(' ') 72 | sys.stdout.write("{:.2%}".format(recall)) 73 | sys.stdout.write(" F1: ") 74 | if f1 < 0.1: sys.stdout.write(' ') 75 | if f1 < 1.0: sys.stdout.write(' ') 76 | sys.stdout.write("{:.2%}".format(f1)) 77 | sys.stdout.write(" #: %d" % gold) 78 | sys.stdout.write("\n") 79 | print("") 80 | 81 | # Print the aggregate score 82 | if verbose: 83 | print("Final Score:") 84 | prec_micro = 1.0 85 | if sum(guessed_by_relation.values()) > 0: 86 | prec_micro = float(sum(correct_by_relation.values())) / float(sum(guessed_by_relation.values())) 87 | recall_micro = 0.0 88 | if sum(gold_by_relation.values()) > 0: 89 | recall_micro = float(sum(correct_by_relation.values())) / float(sum(gold_by_relation.values())) 90 | f1_micro = 0.0 91 | if prec_micro + recall_micro > 0.0: 92 | f1_micro = 2.0 * prec_micro * recall_micro / (prec_micro + recall_micro) 93 | #print("Precision (micro): 
{:.3%}".format(prec_micro)) 94 | #print(" Recall (micro): {:.3%}".format(recall_micro)) 95 | #print(" F1 (micro): {:.3%}".format(f1_micro)) 96 | return prec_micro, recall_micro, f1_micro -------------------------------------------------------------------------------- /hiexpl/visualize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import csv 4 | from utils.parser import get_span_to_node_mapping, parse_tree 5 | import pickle 6 | import argparse 7 | import os 8 | 9 | def len_span(span): 10 | return 1 if type(span) is int else span[1] - span[0] + 1 11 | 12 | def compact_layer(layers): 13 | # collapse unimportant hierarchies 14 | items = sorted(list(layers.items()), key=lambda x: x[0]) 15 | layers = [_[1] for _ in items] 16 | idx = 0 17 | compact_layers = {} 18 | skip_flg = False 19 | for i, layer in enumerate(layers): 20 | if i != len(layers) - 1 and len(layers[i]) == 1 and len(layers[i+1]) == 1 and not skip_flg: 21 | entry1, entry2 = layers[i][0], layers[i+1][0] 22 | score1, score2 = entry1[3], entry2[3] 23 | start1, stop1 = entry1[2] - len(entry1[1]) + 1, entry1[2] 24 | start2, stop2 = entry2[2] - len(entry2[1]) + 1, entry2[2] 25 | if (stop2 - start2) - (stop1 - start1) == 1 and not score1 * score2 < 0: 26 | continue 27 | compact_layers[idx] = layer 28 | idx += 1 29 | skip_flg = False 30 | return compact_layers 31 | 32 | def plot_score_array(layers, score_array, sent_words): 33 | max_abs = abs(score_array).max() 34 | width = max(10, score_array.shape[1]) 35 | height = max(5, score_array.shape[0]) 36 | fig, ax = plt.subplots(figsize=(width, height)) 37 | 38 | vmin, vmax = -max_abs * 1.2, max_abs * 1.2 39 | im = ax.imshow(score_array, cmap='coolwarm', aspect=0.5, vmin=vmin, vmax=vmax) 40 | #y_ticks = [''] * score_array.shape[0] 41 | #y_ticks[0] = 'Atomics words' 42 | #y_ticks[-1] = 'Full sentence' 43 | ax.set_yticks([]) 44 | ax.set_xticks([]) 45 | #ax.set_yticks(np.arange(len(y_ticks))) 46 | #ax.set_yticklabels(y_ticks) 47 | rgba = im.cmap(im.norm(im.get_array())) 48 | cnt = 0 49 | if layers is not None: 50 | for idx, i in enumerate(sorted(layers.keys())): 51 | for entry in layers[i]: 52 | start, stop = entry[2] - len(entry[1]) + 1, entry[2] 53 | for j in range(start, stop + 1): 54 | color = (0.0, 0.0, 0.0) 55 | ax.text(j, cnt, sent_words[j], ha='center', va='center', fontsize=11 if len(sent_words[j]) < 10 else 8, 56 | color=color) 57 | cnt += 1 58 | else: 59 | for i in range(score_array.shape[0]): 60 | for j in range(score_array.shape[1]): 61 | if score_array[i,j] != 0: 62 | if sent_words[j].startswith('SUBJ') or sent_words[j].startswith('OBJ'): 63 | sent_words[j] = sent_words[j].replace('SUBJ-','S-').replace('OBJ-','O-').upper() 64 | fontsize = 12 65 | if len(sent_words[j]) >= 8: 66 | fontsize = 8 67 | if len(sent_words[j]) >= 12: 68 | fontsize = 6 69 | ax.text(j, i, sent_words[j], ha='center', va='center', 70 | fontsize=fontsize) 71 | return im 72 | 73 | def draw_tree_from_line(s, tree_s): 74 | stack = [] 75 | layers = {} 76 | phrase_scores = s.strip().split('\t') 77 | word_offset = 0 78 | sent_words = [] 79 | 80 | span2node, node2span = get_span_to_node_mapping(parse_tree(tree_s)) 81 | spans = list(span2node.keys()) 82 | 83 | for idx, span in enumerate(spans): 84 | items = phrase_scores[idx] 85 | words = span2node[span].leaves() 86 | score = float(items.split()[-1]) 87 | if len(words) == 1: 88 | layer = 0 89 | entry = (layer, words, word_offset, score) 90 | word_offset += 1 91 | 
sent_words.append(words[0]) 92 | stack.append(entry) 93 | else: 94 | len_sum, max_layer = 0, -1 95 | while len_sum < len(words): 96 | popped = stack[-1] 97 | max_layer = max(popped[0], max_layer) 98 | len_sum += len(popped[1]) 99 | stack.pop(-1) 100 | layer = max_layer + 1 101 | entry = (layer, words, word_offset - 1, score) 102 | stack.append(entry) 103 | 104 | if layer not in layers: 105 | layers[layer] = [] 106 | layers[layer].append(entry) 107 | 108 | #layers = compact_layer(layers) 109 | 110 | score_array = [] 111 | 112 | for layer in sorted(layers.keys()): 113 | arr = np.zeros(len(sent_words)) 114 | for entry in layers[layer]: 115 | start, stop = entry[2] - len(entry[1]) + 1, entry[2] # closed interval 116 | arr[start:stop + 1] = entry[3] 117 | score_array.append(arr) 118 | 119 | score_array = np.stack(score_array) 120 | 121 | im = plot_score_array(layers, score_array, sent_words) 122 | return im, score_array 123 | 124 | def visualize_tree(result_path, model_name, method_name, tree_path): 125 | f = open(result_path) 126 | f2 = open(tree_path) 127 | os.makedirs('figs/fig{}_{}'.format(model_name, method_name), exist_ok=True) 128 | for i, (line, tree_str) in enumerate(zip(f.readlines(), f2.readlines())): 129 | im, score_array = draw_tree_from_line(line, tree_str) 130 | plt.savefig('figs/fig{}_{}/{}'.format(model_name, method_name, i), bbox_inches='tight') 131 | plt.close() 132 | 133 | 134 | def visualize_tab(tab_file_dir, model_name, method_name): 135 | f = open(tab_file_dir, 'rb') 136 | data = pickle.load(f) 137 | os.makedirs('figs/{}_{}'.format(model_name, method_name), exist_ok=True) 138 | for i,entry in enumerate(data): 139 | sent_words = entry['text'].split() 140 | score_array = entry['tab'] 141 | label_name = entry['label'] 142 | pred_score = entry['pred'] 143 | if score_array.ndim == 1: 144 | score_array = score_array.reshape(1,-1) 145 | # append full prediction 146 | row_pred_score = np.full((1,score_array.shape[1]),pred_score) 147 | score_array = np.concatenate([score_array, row_pred_score], 0) 148 | if score_array.shape[1] <= 400: 149 | im = plot_score_array(None, score_array, sent_words) 150 | plt.title(label_name) 151 | plt.savefig('figs/{}_{}/fig_{}'.format(model_name, method_name, i), bbox_inches='tight') 152 | plt.close() 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | parser.add_argument('--file', help='pkl or txt output by the explanation algorithm') 157 | parser.add_argument('--model', help='model name for saving images') 158 | parser.add_argument('--method', help='method name for saving images') 159 | parser.add_argument('--use_gt_trees', help='whether use ground truth parsing trees (sst)' 160 | 'if true, --file should be .txt', action='store_true') 161 | parser.add_argument('--gt_tree_path', default='.data/sst/trees/%s.txt' % 'test') 162 | args = parser.parse_args() 163 | if not args.use_gt_trees: 164 | visualize_tab(args.file, args.model, args.method) 165 | else: 166 | visualize_tree(args.file, args.model, args.method, args.gt_tree_path) --------------------------------------------------------------------------------
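# --------------------------------------------------------------------------
# Illustrative sketch, not part of visualize.py: rendering one explanation
# line against its ground-truth parse. The scores and the tree are made up;
# the tab-separated line must contain one "phrase score" entry per span of
# the tree, in the same depth-first order produced by
# get_span_to_node_mapping(). Requires matplotlib and nltk.
import matplotlib.pyplot as plt
from visualize import draw_tree_from_line

tree_line = '(3 (2 (2 a) (2 very)) (3 (3 good) (2 movie)))'
score_line = '\t'.join([
    'a 0.10', 'very 0.05', 'a very 0.20',
    'good 0.90', 'movie 0.15', 'good movie 1.10',
    'a very good movie 1.30',
])
im, score_array = draw_tree_from_line(score_line, tree_line)
plt.savefig('demo_hierarchy.png', bbox_inches='tight')   # hypothetical output path
plt.close()
# For batch use, the __main__ entry point above does the same thing, e.g.:
#   python visualize.py --file <explanation output .txt> --model lstm \
#       --method soc --use_gt_trees --gt_tree_path .data/sst/trees/test.txt
# --------------------------------------------------------------------------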