├── .gitignore
├── ComputeBioSentVecAbstractEmbedding.py
├── FEVER_joint_paragraph_dynamic.py
├── FEVER_joint_paragraph_kgat.py
├── FEVER_stance_paragraph.py
├── FEVER_stance_paragraph_kgat.py
├── README.md
├── SentVecAbstractRetriaval.py
├── TFIDFabstractRetrieval.py
├── dataset.py
├── domain_adaptation_joint_paragraph_dynamic.py
├── domain_adaptation_joint_paragraph_fine_tune.py
├── domain_adaptation_joint_paragraph_kgat.py
├── domain_adaptation_joint_paragraph_kgat_prediction.py
├── domain_adaptation_joint_paragraph_predict.py
├── lib
│   ├── data.py
│   └── metrics.py
├── paragraph_model_dynamic.py
├── paragraph_model_kgat.py
├── requirements.txt
├── scifact_joint_paragraph_dynamic.py
├── scifact_joint_paragraph_dynamic_prediction.py
├── scifact_joint_paragraph_kgat.py
├── scifact_joint_paragraph_kgat_prediction.py
├── scifact_rationale_paragraph.py
├── scifact_stance_paragraph.py
└── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /ComputeBioSentVecAbstractEmbedding.py: -------------------------------------------------------------------------------- 1 | import sent2vec 2 | from nltk import word_tokenize 3 | from nltk.corpus import stopwords 4 | from string import punctuation 5 | from scipy.spatial import distance 6 | import glob 7 | import pickle 8 | import json 9 | 10 | import argparse 11 | 12 | def preprocess_sentence(text): 13 | text = text.replace('/', ' / ') 14 | text = text.replace('.-', ' .- ') 15 | text = text.replace('.', ' . ') 16 | text = text.replace('\'', ' \' ') 17 | text = text.lower() 18 | 19 | tokens = [token for token in word_tokenize(text) if token not in punctuation and token not in stop_words] 20 | 21 | return ' '.join(tokens) 22 | 23 | if __name__ == "__main__": 24 | argparser = argparse.ArgumentParser() 25 | argparser.add_argument('--claim_file', type=str) 26 | argparser.add_argument('--corpus_file', type=str) 27 | argparser.add_argument('--sentvec_path', type=str) 28 | argparser.add_argument('--corpus_embedding_pickle', type=str, default="corpus_paragraph_biosentvec.pkl") 29 | argparser.add_argument('--claim_embedding_pickle', type=str, default="claim_biosentvec.pkl") 30 | 31 | args = argparser.parse_args() 32 | claim_file = args.claim_file 33 | corpus_file = args.corpus_file 34 | 35 | corpus = {} 36 | with open(corpus_file) as f: 37 | for line in f: 38 | abstract = json.loads(line) 39 | corpus[str(abstract["doc_id"])] = abstract 40 | 41 | claims = [] 42 | with open(claim_file) as f: 43 | for line in f: 44 | claim = json.loads(line) 45 | claims.append(claim) 46 | 47 | model_path = args.sentvec_path 48 | model = sent2vec.Sent2vecModel() 49 | try: 50 | model.load_model(model_path) 51 | except Exception as e: 52 | print(e) 53 | print('model successfully loaded') 54 | 55 | stop_words = set(stopwords.words('english')) 56 | 57 | # By paragraph embedding 58 | corpus_embeddings = {} 59 | for k, v in corpus.items(): 60 | original_sentences = [v['title']] + v['abstract'] 61 | processed_paragraph = " ".join([preprocess_sentence(sentence) for sentence in original_sentences]) 62 | sentence_vector = model.embed_sentence(processed_paragraph) 63 | corpus_embeddings[k] = sentence_vector 64 | 65 | with open(args.corpus_embedding_pickle,"wb") as f: 66 | pickle.dump(corpus_embeddings,f) 67 | 68 | claim_embeddings = {} 69 | for claim in claims: 70 | processed_sentence = preprocess_sentence(claim['claim']) 71 | sentence_vector = model.embed_sentence(processed_sentence) 72 | claim_embeddings[claim["id"]] = sentence_vector 73 | 74 | with open(args.claim_embedding_pickle,"wb") as f: 75 | pickle.dump(claim_embeddings,f) -------------------------------------------------------------------------------- /FEVER_joint_paragraph_dynamic.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | 
import torch 5 | import jsonlines 6 | import os 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import math 20 | import random 21 | import numpy as np 22 | 23 | from tqdm import tqdm 24 | from util import arg2param, flatten, stance2json, rationale2json 25 | from paragraph_model_dynamic import JointParagraphClassifier 26 | from dataset import FEVERParagraphBatchDataset 27 | 28 | import logging 29 | 30 | def schedule_sample_p(epoch, total): 31 | return np.sin(0.5* np.pi* epoch / (total-1)) 32 | 33 | def reset_random_seed(seed): 34 | np.random.seed(seed) 35 | random.seed(seed) 36 | torch.manual_seed(seed) 37 | 38 | def batch_rationale_label(labels, padding_idx = 2): 39 | max_sent_len = max([len(label) for label in labels]) 40 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 41 | label_list = [] 42 | for i, label in enumerate(labels): 43 | for j, evid in enumerate(label): 44 | label_matrix[i,j] = int(evid) 45 | label_list.append([int(evid) for evid in label]) 46 | return label_matrix.long(), label_list 47 | 48 | def evaluation(model, dataset): 49 | model.eval() 50 | rationale_predictions = [] 51 | rationale_labels = [] 52 | stance_preds = [] 53 | stance_labels = [] 54 | 55 | def remove_dummy(rationale_out): 56 | return [out[1:] for out in rationale_out] 57 | 58 | with torch.no_grad(): 59 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 60 | encoded_dict = encode(tokenizer, batch) 61 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 62 | tokenizer.sep_token_id, args.repfile) 63 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 64 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 65 | stance_label = batch["stance"].to(device) 66 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 67 | if padded_rationale_label.size(1) == transformation_indices[-1].size(1): 68 | rationale_out, stance_out, rationale_loss, stance_loss = \ 69 | model(encoded_dict, transformation_indices, stance_label = stance_label, 70 | rationale_label = padded_rationale_label.to(device)) 71 | stance_preds.extend(stance_out) 72 | stance_labels.extend(stance_label.cpu().numpy().tolist()) 73 | 74 | rationale_predictions.extend(remove_dummy(rationale_out)) 75 | rationale_labels.extend(remove_dummy(rationale_label)) 76 | 77 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 78 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 79 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 80 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions)) 81 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions)) 82 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions)) 83 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall 84 | 85 | 86 | 87 | def encode(tokenizer, batch, max_sent_len = 512): 88 | def 
truncate(input_ids, max_length, sep_token_id, pad_token_id): 89 | def longest_first_truncation(sentences, objective): 90 | sent_lens = [len(sent) for sent in sentences] 91 | while np.sum(sent_lens) > objective: 92 | max_position = np.argmax(sent_lens) 93 | sent_lens[max_position] -= 1 94 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 95 | 96 | all_paragraphs = [] 97 | for paragraph in input_ids: 98 | valid_paragraph = paragraph[paragraph != pad_token_id] 99 | if valid_paragraph.size(0) <= max_length: 100 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 101 | else: 102 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 103 | idx_by_sentence = [] 104 | prev_idx = 0 105 | for idx in sep_token_idx: 106 | idx_by_sentence.append(paragraph[prev_idx:idx]) 107 | prev_idx = idx 108 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 109 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 110 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 111 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 112 | 113 | return torch.cat(all_paragraphs, 0) 114 | 115 | inputs = zip(batch["claim"], batch["paragraph"]) 116 | encoded_dict = tokenizer.batch_encode_plus( 117 | inputs, 118 | pad_to_max_length=True,add_special_tokens=True, 119 | return_tensors='pt') 120 | if encoded_dict['input_ids'].size(1) > max_sent_len: 121 | if 'token_type_ids' in encoded_dict: 122 | encoded_dict = { 123 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 124 | tokenizer.sep_token_id, tokenizer.pad_token_id), 125 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 126 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 127 | } 128 | else: 129 | encoded_dict = { 130 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 131 | tokenizer.sep_token_id, tokenizer.pad_token_id), 132 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 133 | } 134 | 135 | return encoded_dict 136 | 137 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 138 | """ 139 | Compute the token indices matrix of the BERT output. 
140 | input_ids: (batch_size, paragraph_len) 141 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 142 | bert_out: (batch_size, paragraph_len,BERT_dim) 143 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 144 | """ 145 | padding_idx = -1 146 | sep_tokens = (input_ids == sep_token_id).bool() 147 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() 148 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) 149 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) 150 | paragraph_lens = [] 151 | all_word_indices = [] 152 | for paragraph in sep_indices: 153 | if "roberta" in model_name: 154 | paragraph = paragraph[1:] 155 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)] 156 | paragraph_lens.append(len(word_indices)) 157 | all_word_indices.extend(word_indices) 158 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 159 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 160 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 161 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 162 | mask = (indices_by_batch>=0) 163 | 164 | return batch_indices.long(), indices_by_batch.long(), mask.long() 165 | 166 | if __name__ == "__main__": 167 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 168 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 169 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_train_retrieved.jsonl") 170 | argparser.add_argument('--pre_trained_model', type=str) 171 | #argparser.add_argument('--train_file', type=str) 172 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_dev_retrieved.jsonl") 173 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM") 174 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate") 175 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 176 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 177 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch") 178 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 179 | argparser.add_argument('--loss_ratio', type=float, default=5) 180 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_joint_paragraph_dynamic") 181 | argparser.add_argument('--log_file', type=str, default = "fever_joint_paragraph_performances.jsonl") 182 | argparser.add_argument('--update_step', type=int, default=10) 183 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 184 | argparser.add_argument('--k', type=int, default=0) 185 | argparser.add_argument('--evaluation_step', type=int, default=50000) 186 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 187 | 188 | reset_random_seed(12345) 189 | 190 | args = argparser.parse_args() 191 | 192 | with open(args.checkpoint+".log", 'w') as f: 193 | sys.stdout = f 194 | 195 | device = torch.device("cuda:0" if 
torch.cuda.is_available() else 'cpu') 196 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 197 | 198 | if args.train_file: 199 | train = True 200 | #assert args.repfile is not None, "Word embedding file required for training." 201 | else: 202 | train = False 203 | if args.test_file: 204 | test = True 205 | else: 206 | test = False 207 | 208 | params = vars(args) 209 | 210 | for k,v in params.items(): 211 | print(k,v) 212 | 213 | if train: 214 | train_set = FEVERParagraphBatchDataset(args.train_file, 215 | sep_token = tokenizer.sep_token, k=args.k) 216 | dev_set = FEVERParagraphBatchDataset(args.test_file, 217 | sep_token = tokenizer.sep_token, k=args.k) 218 | 219 | print("Loaded dataset!") 220 | 221 | model = JointParagraphClassifier(args.repfile, args.bert_dim, 222 | args.dropout).to(device) 223 | 224 | if args.pre_trained_model is not None: 225 | model.load_state_dict(torch.load(args.pre_trained_model)) 226 | print("Loaded pre-trained model.") 227 | 228 | if train: 229 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}] 230 | for module in model.extra_modules: 231 | settings.append({'params': module.parameters(), 'lr': args.lr}) 232 | optimizer = torch.optim.Adam(settings) 233 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch) 234 | 235 | #if torch.cuda.device_count() > 1: 236 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 237 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs 238 | # model = nn.DataParallel(model) 239 | 240 | model.train() 241 | 242 | for epoch in range(args.epoch): 243 | error_count = 0 244 | sample_p = schedule_sample_p(epoch, args.epoch) 245 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True)) 246 | for i, batch in enumerate(tq): 247 | encoded_dict = encode(tokenizer, batch) 248 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile) 249 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 250 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 251 | stance_label = batch["stance"].to(device) 252 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 253 | if padded_rationale_label.size(1) == transformation_indices[-1].size(1): # Skip some rare cases with inconsistent input size. 
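# The guard above compares the sentence count implied by the rationale labels
# with the count recovered from [SEP] positions in the (possibly truncated)
# input_ids; mismatched batches (e.g., when sentence-aware truncation drops a
# whole sentence) are skipped and tallied in error_count below, instead of
# triggering a shape mismatch inside the model.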
254 | rationale_out, stance_out, rationale_loss, stance_loss = \ 255 | model(encoded_dict, transformation_indices, stance_label = stance_label, 256 | rationale_label = padded_rationale_label.to(device), sample_p = sample_p) 257 | rationale_loss *= args.loss_ratio 258 | loss = rationale_loss + stance_loss 259 | loss.sum().backward() 260 | else: 261 | error_count += 1 262 | 263 | if i % args.update_step == args.update_step - 1: 264 | optimizer.step() 265 | optimizer.zero_grad() 266 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}') 267 | 268 | 269 | if i % args.evaluation_step == args.evaluation_step-1: 270 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model") 271 | 272 | # Evaluation 273 | subset_train = Subset(train_set, range(0, 1000)) 274 | train_score = evaluation(model, subset_train) 275 | print(f'Epoch {epoch}, step {i}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score) 276 | 277 | subset_dev = Subset(dev_set, range(0, 1000)) 278 | dev_score = evaluation(model, subset_dev) 279 | print(f'Epoch {epoch}, step {i}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 280 | scheduler.step() 281 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model") 282 | print(error_count, "mismatch occurred.") 283 | 284 | # Evaluation 285 | subset_train = Subset(train_set, range(0, 10000)) 286 | train_score = evaluation(model, subset_train) 287 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score) 288 | 289 | subset_dev = Subset(dev_set, range(0, 10000)) 290 | dev_score = evaluation(model, subset_dev) 291 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 292 | 293 | 294 | 295 | if test: 296 | model = JointParagraphClassifier(args.repfile, args.bert_dim, 297 | args.dropout).to(device) 298 | model.load_state_dict(torch.load(args.checkpoint)) 299 | 300 | 301 | # Evaluation 302 | subset_dev = Subset(dev_set, range(0, 10000)) 303 | dev_score = evaluation(model, subset_dev) 304 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 305 | 306 | if train: 307 | params["stance_f1"] = dev_score[0] 308 | params["stance_precision"] = dev_score[1] 309 | params["stance_recall"] = dev_score[2] 310 | params["rationale_f1"] = dev_score[3] 311 | params["rationale_precision"] = dev_score[4] 312 | params["rationale_recalls"] = dev_score[5] 313 | 314 | with jsonlines.open(args.log_file, mode='a') as writer: 315 | writer.write(params) -------------------------------------------------------------------------------- /FEVER_joint_paragraph_kgat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import torch 5 | import jsonlines 6 | import os 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import math 20 | import random 21 | import numpy as np 22 | 23 | from tqdm import tqdm 
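# Note on the scheduled-sampling curve: schedule_sample_p() below maps epochs
# 0..(total-1) to sin(0.5 * pi * epoch / (total-1)), ramping sample_p from 0 to 1
# over training. It is passed to the model's forward() only during training,
# presumably as the probability of conditioning on predicted rather than gold
# rationale sentences (the forward() lives in paragraph_model_kgat.py, not shown here).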
24 | from util import arg2param, flatten, stance2json, rationale2json 25 | from paragraph_model_kgat import JointParagraphKGATClassifier 26 | from dataset import FEVERParagraphBatchDataset 27 | 28 | import logging 29 | 30 | def schedule_sample_p(epoch, total): 31 | return np.sin(0.5* np.pi* epoch / (total-1)) 32 | 33 | def reset_random_seed(seed): 34 | np.random.seed(seed) 35 | random.seed(seed) 36 | torch.manual_seed(seed) 37 | 38 | def batch_rationale_label(labels, padding_idx = 2): 39 | max_sent_len = max([len(label) for label in labels]) 40 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 41 | label_list = [] 42 | for i, label in enumerate(labels): 43 | for j, evid in enumerate(label): 44 | label_matrix[i,j] = int(evid) 45 | label_list.append([int(evid) for evid in label]) 46 | return label_matrix.long(), label_list 47 | 48 | def evaluation(model, dataset): 49 | model.eval() 50 | rationale_predictions = [] 51 | rationale_labels = [] 52 | stance_preds = [] 53 | stance_labels = [] 54 | 55 | def remove_dummy(rationale_out): 56 | return [out[1:] for out in rationale_out] 57 | 58 | with torch.no_grad(): 59 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 60 | encoded_dict = encode(tokenizer, batch) 61 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 62 | tokenizer.sep_token_id, args.repfile) 63 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 64 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 65 | stance_label = batch["stance"].to(device) 66 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 67 | rationale_out, stance_out, rationale_loss, stance_loss = \ 68 | model(encoded_dict, transformation_indices, stance_label = stance_label, 69 | rationale_label = padded_rationale_label.to(device)) 70 | stance_preds.extend(stance_out) 71 | stance_labels.extend(stance_label.cpu().numpy().tolist()) 72 | 73 | rationale_predictions.extend(remove_dummy(rationale_out)) 74 | rationale_labels.extend(remove_dummy(rationale_label)) 75 | 76 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 77 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 78 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 79 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions)) 80 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions)) 81 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions)) 82 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall 83 | 84 | 85 | 86 | def encode(tokenizer, batch, max_sent_len = 512): 87 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 88 | def longest_first_truncation(sentences, objective): 89 | sent_lens = [len(sent) for sent in sentences] 90 | while np.sum(sent_lens) > objective: 91 | max_position = np.argmax(sent_lens) 92 | sent_lens[max_position] -= 1 93 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 94 | 95 | all_paragraphs = [] 96 | for paragraph in input_ids: 97 | valid_paragraph = paragraph[paragraph != pad_token_id] 98 | if valid_paragraph.size(0) <= max_length: 99 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 100 | else: 101 | sep_token_idx = 
np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 102 | idx_by_sentence = [] 103 | prev_idx = 0 104 | for idx in sep_token_idx: 105 | idx_by_sentence.append(paragraph[prev_idx:idx]) 106 | prev_idx = idx 107 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 108 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 109 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 110 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 111 | 112 | return torch.cat(all_paragraphs, 0) 113 | 114 | inputs = zip(batch["claim"], batch["paragraph"]) 115 | encoded_dict = tokenizer.batch_encode_plus( 116 | inputs, 117 | pad_to_max_length=True,add_special_tokens=True, 118 | return_tensors='pt') 119 | if encoded_dict['input_ids'].size(1) > max_sent_len: 120 | if 'token_type_ids' in encoded_dict: 121 | encoded_dict = { 122 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 123 | tokenizer.sep_token_id, tokenizer.pad_token_id), 124 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 125 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 126 | } 127 | else: 128 | encoded_dict = { 129 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 130 | tokenizer.sep_token_id, tokenizer.pad_token_id), 131 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 132 | } 133 | 134 | return encoded_dict 135 | 136 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 137 | """ 138 | Advanced indexing: Compute the token indices matrix of the BERT output. 139 | input_ids: (batch_size, paragraph_len) 140 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 141 | bert_out: (batch_size, paragraph_len,BERT_dim) 142 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 143 | """ 144 | padding_idx = -1 145 | sep_tokens = (input_ids == sep_token_id).bool() 146 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph 147 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,3,....,511 for each sentence 148 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph 149 | paragraph_lens = [] 150 | all_word_indices = [] 151 | for paragraph in sep_indices: 152 | # claim sentence: [CLS] token1 token2 ... tokenk 153 | claim_word_indices = torch.arange(0, paragraph[0]) 154 | if "roberta" in model_name: # Huggingface Roberta has ...... 155 | paragraph = paragraph[1:] 156 | # each sentence: [SEP] token1 token2 ... tokenk, the last [SEP] in the paragraph is ditched. 157 | sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)] 158 | 159 | # KGAT requires claim sentence, so add it back. 
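# Illustration of the advanced indexing (hypothetical sizes): HuggingFace RoBERTa
# encodes the pair as "<s> claim </s></s> sent1 </s> sent2 </s>", which is why the
# first separator index is dropped above. For that input, word_indices becomes
# [claim_span, sent1_span, sent2_span], so N_sentence = 3. After padding with -1,
# bert_out[batch_indices, indices_by_batch, :] gathers a
# (batch_size, N_sentence, N_token, BERT_dim) tensor, and mask (indices >= 0) is
# meant to zero out the padded token slots downstream.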
160 | word_indices = [claim_word_indices] + sentence_word_indices 161 | 162 | paragraph_lens.append(len(word_indices)) 163 | all_word_indices.extend(word_indices) 164 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 165 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 166 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 167 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 168 | mask = (indices_by_batch>=0) 169 | 170 | return batch_indices.long(), indices_by_batch.long(), mask.long() 171 | 172 | if __name__ == "__main__": 173 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 174 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 175 | argparser.add_argument('--train_file', type=str, default="/home/xxl190027/scifact_data/fever_train_retrieved_15.jsonl") 176 | argparser.add_argument('--pre_trained_model', type=str) 177 | #argparser.add_argument('--train_file', type=str) 178 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/fever_dev_retrieved_15.jsonl") 179 | argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for BERT-like LM") 180 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate") 181 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 182 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 183 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch") 184 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 185 | argparser.add_argument('--loss_ratio', type=float, default=1) 186 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_joint_paragraph_kgat") 187 | argparser.add_argument('--log_file', type=str, default = "fever_joint_paragraph_performances.jsonl") 188 | argparser.add_argument('--update_step', type=int, default=10) 189 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 190 | argparser.add_argument('--k', type=int, default=0) 191 | argparser.add_argument('--kernel', type=int, default=6) 192 | argparser.add_argument('--evaluation_step', type=int, default=50000) 193 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 194 | 195 | reset_random_seed(12345) 196 | 197 | args = argparser.parse_args() 198 | 199 | with open(args.checkpoint+".log", 'w') as f: 200 | sys.stdout = f 201 | 202 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 203 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 204 | 205 | if args.train_file: 206 | train = True 207 | #assert args.repfile is not None, "Word embedding file required for training." 
208 | else: 209 | train = False 210 | if args.test_file: 211 | test = True 212 | else: 213 | test = False 214 | 215 | params = vars(args) 216 | 217 | for k,v in params.items(): 218 | print(k,v) 219 | 220 | if train: 221 | train_set = FEVERParagraphBatchDataset(args.train_file, 222 | sep_token = tokenizer.sep_token, k=args.k) 223 | dev_set = FEVERParagraphBatchDataset(args.test_file, 224 | sep_token = tokenizer.sep_token, k=args.k) 225 | 226 | print("Loaded dataset!") 227 | 228 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim, 229 | args.dropout, kernel = args.kernel).to(device) 230 | 231 | if args.pre_trained_model is not None: 232 | model.load_state_dict(torch.load(args.pre_trained_model)) 233 | print("Loaded pre-trained model.") 234 | 235 | if train: 236 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}] 237 | for module in model.extra_modules: 238 | settings.append({'params': module.parameters(), 'lr': args.lr}) 239 | optimizer = torch.optim.Adam(settings) 240 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch) 241 | 242 | #if torch.cuda.device_count() > 1: 243 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 244 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs 245 | # model = nn.DataParallel(model) 246 | 247 | model.train() 248 | 249 | for epoch in range(args.epoch): 250 | sample_p = schedule_sample_p(epoch, args.epoch) 251 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True)) 252 | for i, batch in enumerate(tq): 253 | encoded_dict = encode(tokenizer, batch) 254 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile) 255 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 256 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 257 | stance_label = batch["stance"].to(device) 258 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 259 | 260 | rationale_out, stance_out, rationale_loss, stance_loss = \ 261 | model(encoded_dict, transformation_indices, stance_label = stance_label, 262 | rationale_label = padded_rationale_label.to(device), sample_p = sample_p) 263 | rationale_loss *= args.loss_ratio 264 | loss = rationale_loss + stance_loss 265 | loss.sum().backward() 266 | 267 | if i % args.update_step == args.update_step - 1: 268 | optimizer.step() 269 | optimizer.zero_grad() 270 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}') 271 | 272 | 273 | if i % args.evaluation_step == args.evaluation_step-1: 274 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model") 275 | 276 | subset_train = Subset(train_set, range(0, 1000)) 277 | train_score = evaluation(model, subset_train) 278 | print(f'Epoch {epoch} Step {i}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score) 279 | 280 | subset_dev = Subset(dev_set, range(0, 1000)) 281 | dev_score = evaluation(model, subset_dev) 282 | print(f'Epoch {epoch} Step {i}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 283 | scheduler.step() 284 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model") 285 | 286 | # Evaluation 287 | subset_train = Subset(train_set, range(0, 1000)) 288 | train_score = evaluation(model, subset_train) 289 | 
print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score) 290 | 291 | subset_dev = Subset(dev_set, range(0, 1000)) 292 | dev_score = evaluation(model, subset_dev) 293 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 294 | 295 | 296 | 297 | if test: 298 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim, 299 | args.dropout, kernel = args.kernel).to(device) 300 | model.load_state_dict(torch.load(args.checkpoint)) 301 | 302 | 303 | # Evaluation 304 | reset_random_seed(12345) 305 | subset_dev = Subset(dev_set, range(0, 10000)) 306 | dev_score = evaluation(model, subset_dev) 307 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 308 | 309 | if train: 310 | params["stance_f1"] = dev_score[0] 311 | params["stance_precision"] = dev_score[1] 312 | params["stance_recall"] = dev_score[2] 313 | params["rationale_f1"] = dev_score[3] 314 | params["rationale_precision"] = dev_score[4] 315 | params["rationale_recalls"] = dev_score[5] 316 | 317 | with jsonlines.open(args.log_file, mode='a') as writer: 318 | writer.write(params) -------------------------------------------------------------------------------- /FEVER_stance_paragraph.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import torch 5 | import jsonlines 6 | import os 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import random 20 | import numpy as np 21 | 22 | from tqdm import tqdm 23 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 24 | from paragraph_model_dynamic import StanceParagraphClassifier as JointParagraphClassifier 25 | from dataset import FEVERStanceDataset as FEVERParagraphBatchDataset 26 | 27 | import logging 28 | 29 | def reset_random_seed(seed): 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | torch.manual_seed(seed) 33 | 34 | def evaluation(model, dataset): 35 | model.eval() 36 | stance_preds = [] 37 | stance_labels = [] 38 | 39 | with torch.no_grad(): 40 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 41 | encoded_dict = encode(tokenizer, batch) 42 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 43 | tokenizer.sep_token_id, args.repfile) 44 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 45 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 46 | stance_label = batch["stance"].to(device) 47 | stance_out, stance_loss = \ 48 | model(encoded_dict, transformation_indices, stance_label = stance_label) 49 | stance_preds.extend(stance_out) 50 | stance_labels.extend(stance_label.cpu().numpy().tolist()) 51 | 52 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 53 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 54 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 55 | return stance_f1, stance_precision, stance_recall 56 | 57 | 58 
| def encode(tokenizer, batch, max_sent_len = 512): 59 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 60 | def longest_first_truncation(sentences, objective): 61 | sent_lens = [len(sent) for sent in sentences] 62 | while np.sum(sent_lens) > objective: 63 | max_position = np.argmax(sent_lens) 64 | sent_lens[max_position] -= 1 65 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 66 | 67 | all_paragraphs = [] 68 | for paragraph in input_ids: 69 | valid_paragraph = paragraph[paragraph != pad_token_id] 70 | if valid_paragraph.size(0) <= max_length: 71 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 72 | else: 73 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 74 | idx_by_sentence = [] 75 | prev_idx = 0 76 | for idx in sep_token_idx: 77 | idx_by_sentence.append(paragraph[prev_idx:idx]) 78 | prev_idx = idx 79 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 80 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 81 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 82 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 83 | 84 | return torch.cat(all_paragraphs, 0) 85 | 86 | inputs = zip(batch["claim"], batch["paragraph"]) 87 | encoded_dict = tokenizer.batch_encode_plus( 88 | inputs, 89 | pad_to_max_length=True,add_special_tokens=True, 90 | return_tensors='pt') 91 | if encoded_dict['input_ids'].size(1) > max_sent_len: 92 | if 'token_type_ids' in encoded_dict: 93 | encoded_dict = { 94 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 95 | tokenizer.sep_token_id, tokenizer.pad_token_id), 96 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 97 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 98 | } 99 | else: 100 | encoded_dict = { 101 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 102 | tokenizer.sep_token_id, tokenizer.pad_token_id), 103 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 104 | } 105 | 106 | return encoded_dict 107 | 108 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 109 | """ 110 | Compute the token indices matrix of the BERT output. 
111 | input_ids: (batch_size, paragraph_len) 112 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 113 | bert_out: (batch_size, paragraph_len,BERT_dim) 114 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 115 | """ 116 | padding_idx = -1 117 | sep_tokens = (input_ids == sep_token_id).bool() 118 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() 119 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) 120 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) 121 | paragraph_lens = [] 122 | all_word_indices = [] 123 | for paragraph in sep_indices: 124 | if "roberta" in model_name: 125 | paragraph = paragraph[1:] 126 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)] 127 | paragraph_lens.append(len(word_indices)) 128 | all_word_indices.extend(word_indices) 129 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 130 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 131 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 132 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 133 | mask = (indices_by_batch>=0) 134 | 135 | return batch_indices.long(), indices_by_batch.long(), mask.long() 136 | 137 | if __name__ == "__main__": 138 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 139 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 140 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_train_retrieved.jsonl") 141 | argparser.add_argument('--pre_trained_model', type=str) 142 | #argparser.add_argument('--train_file', type=str) 143 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_dev_retrieved.jsonl") 144 | argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for BERT-like LM") 145 | argparser.add_argument('--lr', type=float, default=1e-6, help="Learning rate") 146 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 147 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 148 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch") 149 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 150 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_stance_paragraph") 151 | argparser.add_argument('--log_file', type=str, default = "fever_stance_paragraph_performances.jsonl") 152 | argparser.add_argument('--update_step', type=int, default=10) 153 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 154 | argparser.add_argument('--k', type=int, default=0) 155 | argparser.add_argument('--evaluation_step', type=int, default=50000) 156 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 157 | 158 | reset_random_seed(12345) 159 | 160 | args = argparser.parse_args() 161 | 162 | with open(args.checkpoint+".log", 'w') as f: 163 | sys.stdout = f 164 | 165 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 166 | tokenizer = 
AutoTokenizer.from_pretrained(args.repfile) 167 | 168 | if args.train_file: 169 | train = True 170 | #assert args.repfile is not None, "Word embedding file required for training." 171 | else: 172 | train = False 173 | if args.test_file: 174 | test = True 175 | else: 176 | test = False 177 | 178 | params = vars(args) 179 | 180 | for k,v in params.items(): 181 | print(k,v) 182 | 183 | if train: 184 | train_set = FEVERParagraphBatchDataset(args.train_file, 185 | sep_token = tokenizer.sep_token, k=args.k) 186 | dev_set = FEVERParagraphBatchDataset(args.test_file, 187 | sep_token = tokenizer.sep_token, k=args.k) 188 | 189 | print("Loaded dataset!") 190 | 191 | model = JointParagraphClassifier(args.repfile, args.bert_dim, 192 | args.dropout).to(device) 193 | 194 | if args.pre_trained_model is not None: 195 | model.load_state_dict(torch.load(args.pre_trained_model)) 196 | print("Loaded pre-trained model.") 197 | 198 | if train: 199 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}] 200 | for module in model.extra_modules: 201 | settings.append({'params': module.parameters(), 'lr': args.lr}) 202 | optimizer = torch.optim.Adam(settings) 203 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch) 204 | 205 | #if torch.cuda.device_count() > 1: 206 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 207 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs 208 | # model = nn.DataParallel(model) 209 | 210 | model.train() 211 | 212 | for epoch in range(args.epoch): 213 | error_count = 0 214 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True)) 215 | for i, batch in enumerate(tq): 216 | encoded_dict = encode(tokenizer, batch) 217 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile) 218 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 219 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 220 | stance_label = batch["stance"].to(device) 221 | 222 | stance_out, loss = \ 223 | model(encoded_dict, transformation_indices, stance_label = stance_label) 224 | loss.sum().backward() 225 | 226 | if i % args.update_step == args.update_step - 1: 227 | optimizer.step() 228 | optimizer.zero_grad() 229 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}') 230 | 231 | 232 | if i % args.evaluation_step == args.evaluation_step-1: 233 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model") 234 | 235 | # Evaluation 236 | subset_train = Subset(train_set, range(0, 1000)) 237 | train_score = evaluation(model, subset_train) 238 | print(f'Epoch {epoch}, step {i}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score) 239 | 240 | subset_dev = Subset(dev_set, range(0, 1000)) 241 | dev_score = evaluation(model, subset_dev) 242 | print(f'Epoch {epoch}, step {i}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score) 243 | scheduler.step() 244 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model") 245 | print(error_count, "mismatch occurred.") 246 | 247 | # Evaluation 248 | subset_train = Subset(train_set, range(0, 10000)) 249 | train_score = evaluation(model, subset_train) 250 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score) 251 | 252 | subset_dev = Subset(dev_set, range(0, 10000)) 253 | dev_score = evaluation(model, subset_dev) 254 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score) 255 | 256 | 257 | 258 | if 
test: 259 | model = JointParagraphClassifier(args.repfile, args.bert_dim, 260 | args.dropout).to(device) 261 | model.load_state_dict(torch.load(args.checkpoint)) 262 | 263 | 264 | # Evaluation 265 | subset_dev = Subset(dev_set, range(0, 10000)) 266 | dev_score = evaluation(model, subset_dev) 267 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f' % dev_score) 268 | 269 | if train: 270 | params["stance_f1"] = dev_score[0] 271 | params["stance_precision"] = dev_score[1] 272 | params["stance_recall"] = dev_score[2] 273 | 274 | with jsonlines.open(args.log_file, mode='a') as writer: 275 | writer.write(params) -------------------------------------------------------------------------------- /FEVER_stance_paragraph_kgat.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import torch 5 | import jsonlines 6 | import os 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import random 20 | import numpy as np 21 | 22 | from tqdm import tqdm 23 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 24 | from paragraph_model_kgat import KGATClassifier as JointParagraphClassifier 25 | from dataset import FEVERStanceDataset as FEVERParagraphBatchDataset 26 | 27 | import logging 28 | 29 | def reset_random_seed(seed): 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | torch.manual_seed(seed) 33 | 34 | def evaluation(model, dataset): 35 | model.eval() 36 | stance_preds = [] 37 | stance_labels = [] 38 | 39 | with torch.no_grad(): 40 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 41 | encoded_dict = encode(tokenizer, batch) 42 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 43 | tokenizer.sep_token_id, args.repfile) 44 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 45 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 46 | stance_label = batch["stance"].to(device) 47 | stance_out, stance_loss = \ 48 | model(encoded_dict, transformation_indices, stance_label = stance_label) 49 | stance_preds.extend(stance_out) 50 | stance_labels.extend(stance_label.cpu().numpy().tolist()) 51 | 52 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 53 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 54 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 55 | return stance_f1, stance_precision, stance_recall 56 | 57 | 58 | def encode(tokenizer, batch, max_sent_len = 512): 59 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 60 | def longest_first_truncation(sentences, objective): 61 | sent_lens = [len(sent) for sent in sentences] 62 | while np.sum(sent_lens) > objective: 63 | max_position = np.argmax(sent_lens) 64 | sent_lens[max_position] -= 1 65 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 66 | 67 | all_paragraphs = [] 68 | for paragraph in input_ids: 69 | valid_paragraph = paragraph[paragraph != pad_token_id] 70 | if valid_paragraph.size(0) <= max_length: 71 | 
all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 72 | else: 73 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 74 | idx_by_sentence = [] 75 | prev_idx = 0 76 | for idx in sep_token_idx: 77 | idx_by_sentence.append(paragraph[prev_idx:idx]) 78 | prev_idx = idx 79 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 80 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 81 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 82 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 83 | 84 | return torch.cat(all_paragraphs, 0) 85 | 86 | inputs = zip(batch["claim"], batch["paragraph"]) 87 | encoded_dict = tokenizer.batch_encode_plus( 88 | inputs, 89 | pad_to_max_length=True,add_special_tokens=True, 90 | return_tensors='pt') 91 | if encoded_dict['input_ids'].size(1) > max_sent_len: 92 | if 'token_type_ids' in encoded_dict: 93 | encoded_dict = { 94 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 95 | tokenizer.sep_token_id, tokenizer.pad_token_id), 96 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 97 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 98 | } 99 | else: 100 | encoded_dict = { 101 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 102 | tokenizer.sep_token_id, tokenizer.pad_token_id), 103 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 104 | } 105 | 106 | return encoded_dict 107 | 108 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 109 | """ 110 | Compute the token indices matrix of the BERT output. 111 | input_ids: (batch_size, paragraph_len) 112 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 113 | bert_out: (batch_size, paragraph_len,BERT_dim) 114 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 115 | """ 116 | padding_idx = -1 117 | sep_tokens = (input_ids == sep_token_id).bool() 118 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() 119 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) 120 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) 121 | paragraph_lens = [] 122 | all_word_indices = [] 123 | for paragraph in sep_indices: 124 | if "roberta" in model_name: 125 | paragraph = paragraph[1:] 126 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)] 127 | paragraph_lens.append(len(word_indices)) 128 | all_word_indices.extend(word_indices) 129 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 130 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 131 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 132 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 133 | mask = (indices_by_batch>=0) 134 | 135 | return batch_indices.long(), indices_by_batch.long(), mask.long() 136 | 137 | if __name__ == "__main__": 138 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 139 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 140 | argparser.add_argument('--train_file', type=str, 
default="/home/xxl190027/scifact_data/fever_train_retrieved_15.jsonl") 141 | argparser.add_argument('--pre_trained_model', type=str) 142 | #argparser.add_argument('--train_file', type=str) 143 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/fever_dev_retrieved_15.jsonl") 144 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM") 145 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate") 146 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 147 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 148 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch") 149 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 150 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_stance_paragraph_kgat") 151 | argparser.add_argument('--log_file', type=str, default = "fever_stance_paragraph_kgat_performances.jsonl") 152 | argparser.add_argument('--update_step', type=int, default=10) 153 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 154 | argparser.add_argument('--k', type=int, default=0) 155 | argparser.add_argument('--evaluation_step', type=int, default=50000) 156 | argparser.add_argument('--kernel', type=int, default=6) 157 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 158 | 159 | reset_random_seed(12345) 160 | 161 | args = argparser.parse_args() 162 | 163 | with open(args.checkpoint+".log", 'w') as f: 164 | sys.stdout = f 165 | 166 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 167 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 168 | 169 | if args.train_file: 170 | train = True 171 | #assert args.repfile is not None, "Word embedding file required for training." 172 | else: 173 | train = False 174 | if args.test_file: 175 | test = True 176 | else: 177 | test = False 178 | 179 | params = vars(args) 180 | 181 | for k,v in params.items(): 182 | print(k,v) 183 | 184 | if train: 185 | train_set = FEVERParagraphBatchDataset(args.train_file, 186 | sep_token = tokenizer.sep_token, k=args.k) 187 | dev_set = FEVERParagraphBatchDataset(args.test_file, 188 | sep_token = tokenizer.sep_token, k=args.k) 189 | 190 | print("Loaded dataset!") 191 | 192 | model = JointParagraphClassifier(args.repfile, args.bert_dim, 193 | args.dropout, kernel = args.kernel).to(device) 194 | 195 | if args.pre_trained_model is not None: 196 | model.load_state_dict(torch.load(args.pre_trained_model)) 197 | print("Loaded pre-trained model.") 198 | 199 | if train: 200 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}] 201 | for module in model.extra_modules: 202 | settings.append({'params': module.parameters(), 'lr': args.lr}) 203 | optimizer = torch.optim.Adam(settings) 204 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch) 205 | 206 | #if torch.cuda.device_count() > 1: 207 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 208 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] 
on 3 GPUs 209 | # model = nn.DataParallel(model) 210 | 211 | model.train() 212 | 213 | for epoch in range(args.epoch): 214 | error_count = 0 215 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True)) 216 | for i, batch in enumerate(tq): 217 | encoded_dict = encode(tokenizer, batch) 218 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile) 219 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 220 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 221 | stance_label = batch["stance"].to(device) 222 | 223 | stance_out, loss = \ 224 | model(encoded_dict, transformation_indices, stance_label = stance_label) 225 | loss.sum().backward() 226 | 227 | if i % args.update_step == args.update_step - 1: 228 | optimizer.step() 229 | optimizer.zero_grad() 230 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}') 231 | 232 | 233 | if i % args.evaluation_step == args.evaluation_step-1: 234 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model") 235 | 236 | # Evaluation 237 | subset_train = Subset(train_set, range(0, 1000)) 238 | train_score = evaluation(model, subset_train) 239 | print(f'Epoch {epoch}, step {i}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score) 240 | 241 | subset_dev = Subset(dev_set, range(0, 1000)) 242 | dev_score = evaluation(model, subset_dev) 243 | print(f'Epoch {epoch}, step {i}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score) 244 | scheduler.step() 245 | #torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model") ############## 246 | print(error_count, "mismatch occurred.") 247 | 248 | # Evaluation 249 | subset_train = Subset(train_set, range(0, 100)) 250 | train_score = evaluation(model, subset_train) 251 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score) 252 | 253 | subset_dev = Subset(dev_set, range(0, 100)) 254 | dev_score = evaluation(model, subset_dev) 255 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score) 256 | 257 | 258 | 259 | if test: 260 | #model = JointParagraphClassifier(args.repfile, args.bert_dim, 261 | # args.dropout, kernel = args.kernel).to(device) 262 | #model.load_state_dict(torch.load(args.checkpoint)) 263 | 264 | 265 | # Evaluation 266 | subset_dev = Subset(dev_set, range(0, 100)) 267 | dev_score = evaluation(model, subset_dev) 268 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f' % dev_score) 269 | 270 | if train: 271 | params["stance_f1"] = dev_score[0] 272 | params["stance_precision"] = dev_score[1] 273 | params["stance_recall"] = dev_score[2] 274 | 275 | with jsonlines.open(args.log_file, mode='a') as writer: 276 | writer.write(params) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ParagraphJointModel 2 | Implementation of The AAAI-21 Workshop on Scientific Document Understanding paper [A Paragraph-level Multi-task Learning Model for Scientific Fact-Verification](https://ceur-ws.org/Vol-2831/paper8.pdf). There is a short [video](https://www.youtube.com/watch?v=YrpYAdNl05Y) available for this work! This work is at top 2 of [SciFact leaderboard](https://leaderboard.allenai.org/scifact/submissions/public) as of March 30th, 2021. 
3 | 4 | ## Reproducing SciFact Leaderboard Result 5 | ### Dependencies 6 | 7 | We recommend creating an anaconda environment: 8 | ```bash 9 | conda create --name scifact python=3.6 conda-build 10 | ``` 11 | Then, from the `scifact` project root, run 12 | ``` 13 | conda develop . 14 | ``` 15 | which will add the scifact code to your `PYTHONPATH`. 16 | 17 | Then, install Python requirements: 18 | ``` 19 | pip install -r requirements.txt 20 | ``` 21 | If you encounter any installation problems with sent2vec, please check [their repo](https://github.com/epfml/sent2vec). 22 | The BioSentVec model is available [here](https://github.com/ncbi-nlp/BioSentVec#biosentvec). 23 | 24 | The SciFact claim files and corpus file are available at the [SciFact repo](https://github.com/allenai/scifact). 25 | The checkpoint of the Paragraph-Joint model used for the paper (trained on the training set) is available [here](https://drive.google.com/file/d/1agyrkUGJ0lxTBJpdy1QCyaAAJyxBnoO2/view?usp=sharing). 26 | The checkpoint of the Paragraph-Joint model used for the leaderboard submission (trained on the train+dev set) is available [here](https://drive.google.com/file/d/1hMrQzFe1EaJpCN9s3pF27Wu3amBbekiI/view?usp=sharing). 27 | 28 | ### Abstract Retrieval 29 | ``` 30 | python ComputeBioSentVecAbstractEmbedding.py --claim_file /path/to/claims.jsonl --corpus_file /path/to/corpus.jsonl --sentvec_path /path/to/sentvec_model 31 | 32 | python SentVecAbstractRetriaval.py --claim_file /path/to/claims.jsonl --corpus_file /path/to/corpus.jsonl --k_retrieval 30 --claim_retrieved_file /output/path/of/retrieval_file.jsonl --scifact_abstract_retrieval_file /output/path/of/retrieval_file_scifact_format.jsonl 33 | ``` 34 | The retrieved abstracts are available here: [train](https://drive.google.com/file/d/18yWhLP3n1OjT_XrUB3rJwNMnLRI3k8Ck/view?usp=sharing), [dev](https://drive.google.com/file/d/1fnfdOA2e3_U-kGavuhoyiYZUlDYWX9eM/view?usp=sharing), [test](https://drive.google.com/file/d/10Lh0aP06tGfZ-LlNGWnDtN0GM8M14z2q/view?usp=sharing). 35 | ### Training of the ParagraphJoint Model (Optional for Result Reproduction) 36 | #### FEVER Pre-training 37 | You need to retrieve negative samples for FEVER pre-training. We used the retrieval code from [here](https://github.com/sheffieldnlp/fever-naacl-2018). Empirically, retrieving 5 negative examples per claim is enough; retrieving more can be prohibitively time-consuming. You then need to convert the output of the retrieval code to the SciFact input format. 38 | 39 | For your convenience, the converted retrieved FEVER examples with `k_retrieval=15` are available: [train](https://drive.google.com/file/d/1sS6mpaALuWnk6Pl2twIt_GcBs7ExRY2b/view?usp=sharing), [dev](https://drive.google.com/file/d/1sOfFL6fvK-AYjzcGPJ5KqcFPmAMvQJUi/view?usp=sharing). 40 | 41 | The checkpoint of the Paragraph-Joint model pre-trained only on the retrieved FEVER examples shared above is available [here](https://drive.google.com/file/d/12u9glqoCBuhxnP9P8dM4HSncAIjHmQ_U/view?usp=sharing). 42 | 43 | Run `FEVER_joint_paragraph_dynamic.py` to pre-train the model on FEVER, using `--checkpoint` to specify the checkpoint path. Then run `scifact_joint_paragraph_dynamic.py` to fine-tune on the SciFact dataset, using `--pre_trained_model` to load the pre-trained model. Please check the other options in the source files; example invocations are shown below.
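For reference, a typical pre-train-then-fine-tune run might look like the following (the paths and checkpoint names are placeholders, and the exact flags can differ slightly between scripts, so double-check the `argparse` options in each file):
```
python FEVER_joint_paragraph_dynamic.py --train_file /path/to/fever_train_retrieved_15.jsonl --test_file /path/to/fever_dev_retrieved_15.jsonl --checkpoint fever_pretrained

python scifact_joint_paragraph_dynamic.py --pre_trained_model /path/to/fever_pretrained.model --checkpoint scifact_joint_model
```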
44 | 45 | ### Joint Prediction of Rationale Selection and Stance Prediction 46 | ``` 47 | python scifact_joint_paragraph_dynamic_prediction.py --corpus_file /path/to/corpus.jsonl --test_file /path/to/retrieval_file.jsonl --dataset /path/to/scifact/claims_test.jsonl --batch_size 25 --k 30 --prediction /path/to/output.jsonl --evaluate --checkpoint /path/to/checkpoint 48 | ``` 49 | 50 | ## File naming conventions 51 | The file names should be self-explanatory. Most parameters are set with sensible default values and should be straightforward. 52 | 53 | ### Non-Joint Models 54 | File names containing `rationale` or `stance` are the scripts for the standalone rationale selection and stance prediction models. 55 | 56 | ### FEVER Pretraining and Domain-Adaptation 57 | File names containing `FEVER` are scripts for training on the FEVER dataset; the same goes for `domain_adaptation`. 58 | 59 | ### Prediction 60 | File names containing `prediction` are scripts that load trained models and perform inference. 61 | 62 | ### KGAT 63 | File names containing `kgat` denote models that use [KGAT](https://github.com/xiangwang1223/knowledge_graph_attention_network) as the stance predictor. 64 | 65 | ### Fine-tuning 66 | You can use `--pre_trained_model path/to/pre_trained.model` to load a model trained on the FEVER dataset and fine-tune it on SciFact. 67 | 68 | ## Cite our paper 69 | ``` 70 | @inproceedings{li2021paragraph, 71 | title={A Paragraph-level Multi-task Learning Model for Scientific Fact-Verification.}, 72 | author={Li, Xiangci and Burns, Gully A and Peng, Nanyun}, 73 | booktitle={SDU@ AAAI}, 74 | year={2021} 75 | } 76 | ``` 77 | 78 | -------------------------------------------------------------------------------- /SentVecAbstractRetriaval.py: -------------------------------------------------------------------------------- 1 | from sklearn.metrics.pairwise import cosine_similarity 2 | import pickle 3 | import json 4 | import numpy as np 5 | from tqdm import tqdm 6 | import jsonlines 7 | import argparse 8 | if __name__ == "__main__": 9 | argparser = argparse.ArgumentParser() 10 | argparser.add_argument('--claim_file', type=str) 11 | argparser.add_argument('--corpus_file', type=str) 12 | argparser.add_argument('--k_retrieval', type=int) 13 | argparser.add_argument('--claim_retrieved_file', type=str) 14 | argparser.add_argument('--scifact_abstract_retrieval_file', type=str, help="abstract retrieval in SciFact format") 15 | argparser.add_argument('--corpus_embedding_pickle', type=str, default="corpus_paragraph_biosentvec.pkl") 16 | argparser.add_argument('--claim_embedding_pickle', type=str, default="claim_biosentvec.pkl") 17 | 18 | args = argparser.parse_args() 19 | 20 | with open(args.corpus_embedding_pickle,"rb") as f: 21 | corpus_embeddings = pickle.load(f) 22 | 23 | with open(args.claim_embedding_pickle,"rb") as f: 24 | claim_embeddings = pickle.load(f) 25 | 26 | claim_file = args.claim_file 27 | 28 | claims = [] 29 | with open(claim_file) as f: 30 | for line in f: 31 | claim = json.loads(line) 32 | claims.append(claim) 33 | claims_by_id = {claim['id']:claim for claim in claims} 34 | 35 | all_similarities = {} 36 | for claim_id, claim_embedding in tqdm(claim_embeddings.items()): 37 | this_claim = {} 38 | for abstract_id, abstract_embedding in corpus_embeddings.items(): 39 | claim_similarity = cosine_similarity(claim_embedding,abstract_embedding) 40 | this_claim[abstract_id] = claim_similarity 41 | all_similarities[claim_id] = this_claim 42 | 43 | ordered_corpus = {} 44 | for claim_id, claim_similarities in tqdm(all_similarities.items()): 45 |
corpus_ids = [] 46 | max_similarity = [] 47 | for abstract_id, similarity in claim_similarities.items(): 48 | corpus_ids.append(abstract_id) 49 | max_similarity.append(np.max(similarity)) 50 | corpus_ids = np.array(corpus_ids) 51 | sorted_order = np.argsort(max_similarity)[::-1] 52 | ordered_corpus[claim_id] = corpus_ids[sorted_order] 53 | 54 | k = args.k_retrieval 55 | retrieved_corpus = {ID:v[:k] for ID,v in ordered_corpus.items()} 56 | 57 | with jsonlines.open(args.claim_retrieved_file, 'w') as output: 58 | claim_ids = sorted(list(claims_by_id.keys())) 59 | for id in claim_ids: 60 | claims_by_id[id]["retrieved_doc_ids"] = retrieved_corpus[id].tolist() 61 | output.write(claims_by_id[id]) 62 | 63 | with jsonlines.open(args.scifact_abstract_retrieval_file, 'w') as output: 64 | claim_ids = sorted(list(claims_by_id.keys())) 65 | for id in claim_ids: 66 | doc_ids = retrieved_corpus[id].tolist() 67 | doc_ids = [int(id) for id in doc_ids] 68 | output.write({"claim_id": id, "doc_ids": doc_ids}) 69 | -------------------------------------------------------------------------------- /TFIDFabstractRetrieval.py: -------------------------------------------------------------------------------- 1 | from sklearn.feature_extraction.text import TfidfVectorizer 2 | from nltk import word_tokenize 3 | from nltk.corpus import stopwords 4 | from string import punctuation 5 | from scipy.spatial import distance 6 | import json 7 | import numpy as np 8 | import jsonlines 9 | 10 | import argparse 11 | 12 | if __name__ == "__main__": 13 | argparser = argparse.ArgumentParser() 14 | argparser.add_argument('--claim_file', type=str) 15 | argparser.add_argument('--corpus_file', type=str) 16 | argparser.add_argument('--k_retrieval', type=int) 17 | argparser.add_argument('--claim_retrieved_file', type=str) 18 | args = argparser.parse_args() 19 | claim_file = args.claim_file 20 | corpus_file = args.corpus_file 21 | 22 | corpus = {} 23 | with open(corpus_file) as f: 24 | for line in f: 25 | abstract = json.loads(line) 26 | corpus[str(abstract["doc_id"])] = abstract 27 | 28 | claims = [] 29 | with open(claim_file) as f: 30 | for line in f: 31 | claim = json.loads(line) 32 | claims.append(claim) 33 | claims_by_id = {claim['id']:claim for claim in claims} 34 | 35 | corpus_texts = [] 36 | corpus_ids = [] 37 | for k, v in corpus.items(): 38 | original_sentences = [v['title']] + v['abstract'] 39 | processed_paragraph = " ".join(original_sentences) 40 | corpus_texts.append(processed_paragraph) 41 | corpus_ids.append(k) 42 | vectorizer = TfidfVectorizer(stop_words='english', 43 | ngram_range=(1, 2)) 44 | corpus_ids = np.array(corpus_ids) 45 | corpus_vectors = vectorizer.fit_transform(corpus_texts) 46 | 47 | claim_vectors = vectorizer.transform([claim['claim'] for claim in claims]) 48 | similarity_matrix = np.dot(corpus_vectors, claim_vectors.T).todense() 49 | 50 | k = args.k_retrieval 51 | orders = np.argsort(similarity_matrix,axis=0) 52 | retrieved_corpus = {claim["id"]: corpus_ids[orders[:,i][::-1][:k]].squeeze() for i, claim in enumerate(claims)} 53 | 54 | with jsonlines.open(args.claim_retrieved_file, 'w') as output: 55 | claim_ids = sorted(list(claims_by_id.keys())) 56 | for id in claim_ids: 57 | claims_by_id[id]["retrieved_doc_ids"] = retrieved_corpus[id].tolist() 58 | output.write(claims_by_id[id]) -------------------------------------------------------------------------------- /domain_adaptation_joint_paragraph_fine_tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import
jsonlines 5 | import os 6 | import sys 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import math 20 | import random 21 | import numpy as np 22 | 23 | from tqdm import tqdm 24 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 25 | from paragraph_model_dynamic import DomainAdaptationJointParagraphClassifier 26 | from dataset import FEVERParagraphBatchDataset, SciFactParagraphBatchDataset, SciFact_FEVER_Dataset, Multiple_SciFact_Dataset 27 | 28 | import logging 29 | 30 | from lib.data import GoldDataset, PredictedDataset 31 | from lib import metrics 32 | 33 | def schedule_sample_p(epoch, total): 34 | return np.sin(0.5* np.pi* epoch / (total-1)) 35 | 36 | def reset_random_seed(seed): 37 | np.random.seed(seed) 38 | random.seed(seed) 39 | torch.manual_seed(seed) 40 | 41 | def batch_rationale_label(labels, padding_idx = 2): 42 | max_sent_len = max([len(label) for label in labels]) 43 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 44 | label_list = [] 45 | for i, label in enumerate(labels): 46 | for j, evid in enumerate(label): 47 | label_matrix[i,j] = int(evid) 48 | label_list.append([int(evid) for evid in label]) 49 | return label_matrix.long(), label_list 50 | 51 | def predict(model, dataset): 52 | model.eval() 53 | rationale_predictions = [] 54 | stance_preds = [] 55 | 56 | def remove_dummy(rationale_out): 57 | return [out[1:] for out in rationale_out] 58 | 59 | with torch.no_grad(): 60 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 61 | encoded_dict = encode(tokenizer, batch) 62 | domain_indices = batch["dataset"].to(device) 63 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 64 | tokenizer.sep_token_id, args.repfile) 65 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 66 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 67 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices, domain_indices) 68 | stance_preds.extend(stance_out) 69 | rationale_predictions.extend(remove_dummy(rationale_out)) 70 | 71 | return rationale_predictions, stance_preds 72 | 73 | def evaluation(model, dataset): 74 | model.eval() 75 | rationale_predictions = [] 76 | rationale_labels = [] 77 | stance_preds = [] 78 | stance_labels = [] 79 | 80 | def remove_dummy(rationale_out): 81 | return [out[1:] for out in rationale_out] 82 | 83 | with torch.no_grad(): 84 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 85 | encoded_dict = encode(tokenizer, batch) 86 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 87 | tokenizer.sep_token_id, args.repfile) 88 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 89 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 90 | stance_label = batch["stance"].to(device) 91 | domain_indices = batch["dataset"].to(device) 92 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 93 | rationale_out, stance_out, rationale_loss, stance_loss = \ 94 
| model(encoded_dict, transformation_indices, domain_indices, stance_label = stance_label, 95 | rationale_label = padded_rationale_label.to(device)) 96 | stance_preds.extend(stance_out) 97 | stance_labels.extend(stance_label.cpu().numpy().tolist()) 98 | 99 | rationale_predictions.extend(remove_dummy(rationale_out)) 100 | rationale_labels.extend(remove_dummy(rationale_label)) 101 | 102 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 103 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 104 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 105 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions)) 106 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions)) 107 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions)) 108 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall 109 | 110 | def encode(tokenizer, batch, max_sent_len = 512): 111 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 112 | def longest_first_truncation(sentences, objective): 113 | sent_lens = [len(sent) for sent in sentences] 114 | while np.sum(sent_lens) > objective: 115 | max_position = np.argmax(sent_lens) 116 | sent_lens[max_position] -= 1 117 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 118 | 119 | all_paragraphs = [] 120 | for paragraph in input_ids: 121 | valid_paragraph = paragraph[paragraph != pad_token_id] 122 | if valid_paragraph.size(0) <= max_length: 123 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 124 | else: 125 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 126 | idx_by_sentence = [] 127 | prev_idx = 0 128 | for idx in sep_token_idx: 129 | idx_by_sentence.append(paragraph[prev_idx:idx]) 130 | prev_idx = idx 131 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 132 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 133 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 134 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 135 | 136 | return torch.cat(all_paragraphs, 0) 137 | 138 | inputs = zip(batch["claim"], batch["paragraph"]) 139 | encoded_dict = tokenizer.batch_encode_plus( 140 | inputs, 141 | pad_to_max_length=True,add_special_tokens=True, 142 | return_tensors='pt') 143 | if encoded_dict['input_ids'].size(1) > max_sent_len: 144 | if 'token_type_ids' in encoded_dict: 145 | encoded_dict = { 146 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 147 | tokenizer.sep_token_id, tokenizer.pad_token_id), 148 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 149 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 150 | } 151 | else: 152 | encoded_dict = { 153 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 154 | tokenizer.sep_token_id, tokenizer.pad_token_id), 155 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 156 | } 157 | 158 | return encoded_dict 159 | 160 | 161 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 162 | """ 163 | Compute the token indices matrix of the BERT output. 
164 | input_ids: (batch_size, paragraph_len) 165 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 166 | bert_out: (batch_size, paragraph_len,BERT_dim) 167 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 168 | """ 169 | padding_idx = -1 170 | sep_tokens = (input_ids == sep_token_id).bool() 171 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() 172 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) 173 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) 174 | paragraph_lens = [] 175 | all_word_indices = [] 176 | for paragraph in sep_indices: 177 | if "roberta" in model_name: 178 | paragraph = paragraph[1:] 179 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)] 180 | paragraph_lens.append(len(word_indices)) 181 | all_word_indices.extend(word_indices) 182 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 183 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 184 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 185 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 186 | mask = (indices_by_batch>=0) 187 | 188 | return batch_indices.long(), indices_by_batch.long(), mask.long() 189 | 190 | def post_process_stance(rationale_json, stance_json): 191 | assert(len(rationale_json) == len(stance_json)) 192 | for stance_pred, rationale_pred in zip(stance_json, rationale_json): 193 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"]) 194 | for doc_id, pred in rationale_pred["evidence"].items(): 195 | if len(pred) == 0: 196 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO" 197 | return stance_json 198 | 199 | if __name__ == "__main__": 200 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 201 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 202 | argparser.add_argument('--scifact_corpus', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl") 203 | argparser.add_argument('--scifact_train', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl") 204 | #argparser.add_argument('--scifact_train', type=str) 205 | argparser.add_argument('--pre_trained_model', type=str) 206 | argparser.add_argument('--scifact_test', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl") 207 | argparser.add_argument('--dataset', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev.jsonl") 208 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM") 209 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate") 210 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 211 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 212 | argparser.add_argument('--epoch_start', type=int, default=0, help="Training epoch") 213 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch") 214 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 215 | argparser.add_argument('--loss_ratio', type=float, default=5) 
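# --loss_ratio weights the rationale-selection loss relative to the stance loss;
# the training loop below optimizes loss_ratio * rationale_loss + stance_loss.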
216 | argparser.add_argument('--checkpoint', type=str, default = "domain_adaptation_roberta_joint_paragraph_fine_tune.model") 217 | argparser.add_argument('--log_file', type=str, default = "domain_adaptation_joint_paragraph_fine_tune_performances.jsonl") 218 | argparser.add_argument('--prediction', type=str, default = "prediction_domain_adaptation_roberta_joint_paragraph_dynamic_fine_tune.jsonl") 219 | argparser.add_argument('--update_step', type=int, default=10) 220 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 221 | argparser.add_argument('--scifact_k', type=int, default=0) 222 | 223 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 224 | reset_random_seed(12345) 225 | args = argparser.parse_args() 226 | 227 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 228 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 229 | 230 | if args.scifact_train: 231 | train = True 232 | #assert args.repfile is not None, "Word embedding file required for training." 233 | else: 234 | train = False 235 | if args.scifact_test: 236 | test = True 237 | else: 238 | test = False 239 | 240 | params = vars(args) 241 | 242 | for k,v in params.items(): 243 | print(k,v) 244 | 245 | if train: 246 | train_set = SciFactParagraphBatchDataset(args.scifact_corpus, args.scifact_train, 247 | sep_token = tokenizer.sep_token, k = args.scifact_k, downsample_n = 2) 248 | 249 | dev_set = SciFactParagraphBatchDataset(args.scifact_corpus, args.scifact_test, 250 | sep_token = tokenizer.sep_token, k = args.scifact_k, downsample_n = 0) 251 | 252 | print("Loaded dataset!") 253 | 254 | model = DomainAdaptationJointParagraphClassifier(args.repfile, args.bert_dim, 255 | args.dropout).to(device) 256 | 257 | if args.pre_trained_model is not None: 258 | model.load_state_dict(torch.load(args.pre_trained_model)) 259 | print("Loaded pre-trained model.") 260 | 261 | if train: 262 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}] 263 | for module in model.extra_modules: 264 | settings.append({'params': module.parameters(), 'lr': args.lr}) 265 | optimizer = torch.optim.Adam(settings) 266 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch) 267 | 268 | #if torch.cuda.device_count() > 1: 269 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 270 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] 
on 3 GPUs 271 | # model = nn.DataParallel(model) 272 | 273 | model.train() 274 | prev_performance = 0 275 | for epoch in range(args.epoch_start, args.epoch): 276 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True)) 277 | for i, batch in enumerate(tq): 278 | encoded_dict = encode(tokenizer, batch) 279 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile) 280 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 281 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 282 | stance_label = batch["stance"].to(device) 283 | domain_indices = batch["dataset"].to(device) 284 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 285 | rationale_out, stance_out, rationale_loss, stance_loss = \ 286 | model(encoded_dict, transformation_indices, domain_indices, stance_label = stance_label, 287 | rationale_label = padded_rationale_label.to(device)) 288 | rationale_loss *= args.loss_ratio 289 | loss = rationale_loss + stance_loss 290 | loss.sum().backward() 291 | 292 | if i % args.update_step == args.update_step - 1: 293 | optimizer.step() 294 | optimizer.zero_grad() 295 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}') 296 | 297 | scheduler.step() 298 | 299 | # Evaluation 300 | train_score = evaluation(model, train_set) 301 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score) 302 | 303 | dev_score = evaluation(model, dev_set) 304 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 305 | 306 | dev_perf = dev_score[0] * dev_score[3] 307 | if dev_perf >= prev_performance: 308 | torch.save(model.state_dict(), args.checkpoint) 309 | best_state_dict = model.state_dict() 310 | prev_performance = dev_perf 311 | print("New model saved!") 312 | else: 313 | print("Skip saving model.") 314 | 315 | 316 | if test: 317 | if train: 318 | del model 319 | model = DomainAdaptationJointParagraphClassifier(args.repfile, args.bert_dim, 320 | args.dropout).to(device) 321 | model.load_state_dict(best_state_dict) 322 | print("Testing on the new model.") 323 | else: 324 | model.load_state_dict(torch.load(args.checkpoint)) 325 | print("Loaded saved model.") 326 | 327 | # Evaluation 328 | #dev_score = evaluation(model, dev_set) 329 | #print(f'Test stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 330 | 331 | rationale_predictions, stance_preds = predict(model, dev_set) 332 | rationale_json = rationale2json(dev_set.samples, rationale_predictions) 333 | stance_json = stance2json(dev_set.samples, stance_preds) 334 | stance_json = post_process_stance(rationale_json, stance_json) 335 | merged_json = merge_json(rationale_json, stance_json) 336 | 337 | with jsonlines.open(args.prediction, 'w') as output: 338 | for result in merged_json: 339 | output.write(result) 340 | 341 | data = GoldDataset(args.scifact_corpus, args.dataset) 342 | predictions = PredictedDataset(data, args.prediction) 343 | res = metrics.compute_metrics(predictions) 344 | params["evaluation"] = res 345 | with jsonlines.open(args.log_file, mode='a') as writer: 346 | writer.write(params) -------------------------------------------------------------------------------- 
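The `token_idx_by_sentence` helper above returns index tensors that the models consume via advanced indexing, as its docstring describes. The following standalone sketch (not part of the repository; shapes and values are toy assumptions) shows the mechanics:

```python
# Minimal sketch of how (batch_indices, indices_by_batch, mask) regroup the flat
# BERT output into per-sentence token embeddings. All values here are toy data.
import torch

batch_size, paragraph_len, bert_dim = 1, 12, 8
bert_out = torch.randn(batch_size, paragraph_len, bert_dim)  # stand-in for the LM output

# Hypothetical index tensors in the documented shape (batch_size, N_sentence, N_token):
# two sentences with up to three tokens each; -1 marks padding, matching
# padding_idx = -1 in token_idx_by_sentence.
indices_by_batch = torch.tensor([[[1, 2, 3], [5, 6, -1]]])
batch_indices = torch.zeros_like(indices_by_batch)
mask = (indices_by_batch >= 0).long()

# Advanced indexing yields (batch_size, N_sentence, N_token, bert_dim).
# Index -1 silently wraps to the last token, so the mask must zero those
# positions before pooling over the token dimension.
sentence_tokens = bert_out[batch_indices, indices_by_batch, :]
masked = sentence_tokens * mask.unsqueeze(-1)
pooled = masked.sum(2) / mask.sum(2, keepdim=True).clamp(min=1)
print(pooled.shape)  # torch.Size([1, 2, 8])
```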
/domain_adaptation_joint_paragraph_kgat_prediction.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | 4 | import torch 5 | import jsonlines 6 | import os 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import math 20 | import random 21 | import numpy as np 22 | 23 | from tqdm import tqdm 24 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 25 | from paragraph_model_kgat import DomainAdaptationJointParagraphKGATClassifier 26 | from dataset import SciFactParagraphBatchDataset, FEVERParagraphBatchDataset, SciFact_FEVER_Dataset, Multiple_SciFact_Dataset 27 | 28 | import logging 29 | 30 | from lib.data import GoldDataset, PredictedDataset 31 | from lib import metrics 32 | 33 | def reset_random_seed(seed): 34 | np.random.seed(seed) 35 | random.seed(seed) 36 | torch.manual_seed(seed) 37 | 38 | def batch_rationale_label(labels, padding_idx = 2): 39 | max_sent_len = max([len(label) for label in labels]) 40 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 41 | label_list = [] 42 | for i, label in enumerate(labels): 43 | for j, evid in enumerate(label): 44 | label_matrix[i,j] = int(evid) 45 | label_list.append([int(evid) for evid in label]) 46 | return label_matrix.long(), label_list 47 | 48 | def predict(model, dataset): 49 | model.eval() 50 | rationale_predictions = [] 51 | stance_preds = [] 52 | 53 | def remove_dummy(rationale_out): 54 | return [out[1:] for out in rationale_out] 55 | 56 | with torch.no_grad(): 57 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 58 | encoded_dict = encode(tokenizer, batch) 59 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 60 | tokenizer.sep_token_id, args.repfile) 61 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 62 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 63 | domain_indices = batch["dataset"].to(device) 64 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices, domain_indices) 65 | stance_preds.extend(stance_out) 66 | rationale_predictions.extend(remove_dummy(rationale_out)) 67 | 68 | return rationale_predictions, stance_preds 69 | 70 | def encode(tokenizer, batch, max_sent_len = 512): 71 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 72 | def longest_first_truncation(sentences, objective): 73 | sent_lens = [len(sent) for sent in sentences] 74 | while np.sum(sent_lens) > objective: 75 | max_position = np.argmax(sent_lens) 76 | sent_lens[max_position] -= 1 77 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 78 | 79 | all_paragraphs = [] 80 | for paragraph in input_ids: 81 | valid_paragraph = paragraph[paragraph != pad_token_id] 82 | if valid_paragraph.size(0) <= max_length: 83 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 84 | else: 85 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 86 | idx_by_sentence = [] 87 | prev_idx = 0 88 | for idx in sep_token_idx: 89 | 
idx_by_sentence.append(paragraph[prev_idx:idx]) 90 | prev_idx = idx 91 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 92 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 93 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 94 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 95 | 96 | return torch.cat(all_paragraphs, 0) 97 | 98 | inputs = zip(batch["claim"], batch["paragraph"]) 99 | encoded_dict = tokenizer.batch_encode_plus( 100 | inputs, 101 | pad_to_max_length=True,add_special_tokens=True, 102 | return_tensors='pt') 103 | if encoded_dict['input_ids'].size(1) > max_sent_len: 104 | if 'token_type_ids' in encoded_dict: 105 | encoded_dict = { 106 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 107 | tokenizer.sep_token_id, tokenizer.pad_token_id), 108 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 109 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 110 | } 111 | else: 112 | encoded_dict = { 113 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 114 | tokenizer.sep_token_id, tokenizer.pad_token_id), 115 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 116 | } 117 | 118 | return encoded_dict 119 | 120 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 121 | """ 122 | Advanced indexing: Compute the token indices matrix of the BERT output. 123 | input_ids: (batch_size, paragraph_len) 124 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 125 | bert_out: (batch_size, paragraph_len,BERT_dim) 126 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 127 | """ 128 | padding_idx = -1 129 | sep_tokens = (input_ids == sep_token_id).bool() 130 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph 131 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,3,....,511 for each sentence 132 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph 133 | paragraph_lens = [] 134 | all_word_indices = [] 135 | for paragraph in sep_indices: 136 | # claim sentence: [CLS] token1 token2 ... tokenk 137 | claim_word_indices = torch.arange(0, paragraph[0]) 138 | if "roberta" in model_name: # Huggingface Roberta encodes pairs as <s> claim </s></s> paragraph </s>, so skip the first </s> index. 139 | paragraph = paragraph[1:] 140 | # each sentence: [SEP] token1 token2 ... tokenk, the last [SEP] in the paragraph is ditched. 141 | sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)] 142 | 143 | # KGAT requires claim sentence, so add it back.
144 | word_indices = [claim_word_indices] + sentence_word_indices 145 | 146 | paragraph_lens.append(len(word_indices)) 147 | all_word_indices.extend(word_indices) 148 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 149 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 150 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 151 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 152 | mask = (indices_by_batch>=0) 153 | 154 | return batch_indices.long(), indices_by_batch.long(), mask.long() 155 | 156 | def post_process_stance(rationale_json, stance_json): 157 | assert(len(rationale_json) == len(stance_json)) 158 | for stance_pred, rationale_pred in zip(stance_json, rationale_json): 159 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"]) 160 | for doc_id, pred in rationale_pred["evidence"].items(): 161 | if len(pred) == 0: 162 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO" 163 | return stance_json 164 | 165 | if __name__ == "__main__": 166 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 167 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 168 | argparser.add_argument('--scifact_corpus', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl") 169 | #argparser.add_argument('--fever_train', type=str) 170 | #argparser.add_argument('--scifact_train', type=str) 171 | argparser.add_argument('--scifact_test', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl") 172 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 173 | argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl") 174 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 175 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 176 | argparser.add_argument('--checkpoint', type=str, default = "domain_adaptation_roberta_joint_paragraph_kgat") 177 | argparser.add_argument('--log_file', type=str, default = "domain_adaptation_joint_paragraph_performances.log") 178 | argparser.add_argument('--update_step', type=int, default=10) 179 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 180 | argparser.add_argument('--scifact_k', type=int, default=0) 181 | argparser.add_argument('--kernel', type=int, default=6) 182 | argparser.add_argument('--prediction', type=str, default = "prediction_domain_adaptation_roberta_joint_paragraph_kgat.jsonl") 183 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 184 | 185 | reset_random_seed(12345) 186 | 187 | args = argparser.parse_args() 188 | 189 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 190 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 191 | 192 | params = vars(args) 193 | 194 | for k,v in params.items(): 195 | print(k,v) 196 | 197 | dev_set = SciFactParagraphBatchDataset(args.scifact_corpus, args.scifact_test, 198 | sep_token = tokenizer.sep_token, k = args.scifact_k, train=False) 199 | print("Loaded dataset!") 200 | 201 | model = DomainAdaptationJointParagraphKGATClassifier(args.repfile, args.bert_dim, 202 | args.dropout, kernel = args.kernel).to(device) 
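# Note: the architecture constructed above (repfile, bert_dim, kernel) must match
# the checkpoint saved by the corresponding domain-adaptation KGAT training script,
# otherwise load_state_dict below will fail.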
203 | 204 | model.load_state_dict(torch.load(args.checkpoint)) 205 | 206 | reset_random_seed(12345) 207 | rationale_predictions, stance_preds = predict(model, dev_set) 208 | rationale_json = rationale2json(dev_set.samples, rationale_predictions) 209 | stance_json = stance2json(dev_set.samples, stance_preds) 210 | stance_json = post_process_stance(rationale_json, stance_json) 211 | merged_json = merge_json(rationale_json, stance_json) 212 | 213 | with jsonlines.open(args.prediction, 'w') as output: 214 | for result in merged_json: 215 | output.write(result) 216 | 217 | data = GoldDataset(args.scifact_corpus, args.dataset) 218 | predictions = PredictedDataset(data, args.prediction) 219 | res = metrics.compute_metrics(predictions) 220 | params["evaluation"] = res 221 | with jsonlines.open(args.log_file, mode='a') as writer: 222 | writer.write(params) -------------------------------------------------------------------------------- /domain_adaptation_joint_paragraph_predict.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import jsonlines 5 | import os 6 | import sys 7 | 8 | import functools 9 | print = functools.partial(print, flush=True) 10 | 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import Dataset, DataLoader, Subset 14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 15 | from tqdm import tqdm 16 | from typing import List 17 | from sklearn.metrics import f1_score, precision_score, recall_score 18 | 19 | import math 20 | import random 21 | import numpy as np 22 | 23 | from tqdm import tqdm 24 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 25 | from paragraph_model_dynamic import DomainAdaptationJointParagraphClassifier 26 | from dataset import FEVERParagraphBatchDataset, SciFactParagraphBatchDataset, SciFact_FEVER_Dataset, Multiple_SciFact_Dataset 27 | 28 | import logging 29 | 30 | from lib.data import GoldDataset, PredictedDataset 31 | from lib import metrics 32 | 33 | 34 | def reset_random_seed(seed): 35 | np.random.seed(seed) 36 | random.seed(seed) 37 | torch.manual_seed(seed) 38 | 39 | def batch_rationale_label(labels, padding_idx = 2): 40 | max_sent_len = max([len(label) for label in labels]) 41 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 42 | label_list = [] 43 | for i, label in enumerate(labels): 44 | for j, evid in enumerate(label): 45 | label_matrix[i,j] = int(evid) 46 | label_list.append([int(evid) for evid in label]) 47 | return label_matrix.long(), label_list 48 | 49 | def predict(model, dataset): 50 | model.eval() 51 | rationale_predictions = [] 52 | stance_preds = [] 53 | 54 | def remove_dummy(rationale_out): 55 | return [out[1:] for out in rationale_out] 56 | 57 | with torch.no_grad(): 58 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 59 | encoded_dict = encode(tokenizer, batch) 60 | domain_indices = batch["dataset"].to(device) 61 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 62 | tokenizer.sep_token_id, args.repfile) 63 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 64 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 65 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices, domain_indices) 66 | stance_preds.extend(stance_out) 67 | rationale_predictions.extend(remove_dummy(rationale_out)) 68 | 69 
| return rationale_predictions, stance_preds 70 | 71 | def encode(tokenizer, batch, max_sent_len = 512): 72 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 73 | def longest_first_truncation(sentences, objective): 74 | sent_lens = [len(sent) for sent in sentences] 75 | while np.sum(sent_lens) > objective: 76 | max_position = np.argmax(sent_lens) 77 | sent_lens[max_position] -= 1 78 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 79 | 80 | all_paragraphs = [] 81 | for paragraph in input_ids: 82 | valid_paragraph = paragraph[paragraph != pad_token_id] 83 | if valid_paragraph.size(0) <= max_length: 84 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 85 | else: 86 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 87 | idx_by_sentence = [] 88 | prev_idx = 0 89 | for idx in sep_token_idx: 90 | idx_by_sentence.append(paragraph[prev_idx:idx]) 91 | prev_idx = idx 92 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 93 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 94 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 95 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 96 | 97 | return torch.cat(all_paragraphs, 0) 98 | 99 | inputs = zip(batch["claim"], batch["paragraph"]) 100 | encoded_dict = tokenizer.batch_encode_plus( 101 | inputs, 102 | pad_to_max_length=True,add_special_tokens=True, 103 | return_tensors='pt') 104 | if encoded_dict['input_ids'].size(1) > max_sent_len: 105 | if 'token_type_ids' in encoded_dict: 106 | encoded_dict = { 107 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 108 | tokenizer.sep_token_id, tokenizer.pad_token_id), 109 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 110 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 111 | } 112 | else: 113 | encoded_dict = { 114 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 115 | tokenizer.sep_token_id, tokenizer.pad_token_id), 116 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 117 | } 118 | 119 | return encoded_dict 120 | 121 | 122 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 123 | """ 124 | Compute the token indices matrix of the BERT output. 
125 | input_ids: (batch_size, paragraph_len) 126 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 127 | bert_out: (batch_size, paragraph_len,BERT_dim) 128 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 129 | """ 130 | padding_idx = -1 131 | sep_tokens = (input_ids == sep_token_id).bool() 132 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() 133 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) 134 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) 135 | paragraph_lens = [] 136 | all_word_indices = [] 137 | for paragraph in sep_indices: 138 | if "roberta" in model_name: 139 | paragraph = paragraph[1:] 140 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)] 141 | paragraph_lens.append(len(word_indices)) 142 | all_word_indices.extend(word_indices) 143 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 144 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 145 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 146 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 147 | mask = (indices_by_batch>=0) 148 | 149 | return batch_indices.long(), indices_by_batch.long(), mask.long() 150 | 151 | def post_process_stance(rationale_json, stance_json): 152 | assert(len(rationale_json) == len(stance_json)) 153 | for stance_pred, rationale_pred in zip(stance_json, rationale_json): 154 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"]) 155 | for doc_id, pred in rationale_pred["evidence"].items(): 156 | if len(pred) == 0: 157 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO" 158 | return stance_json 159 | 160 | if __name__ == "__main__": 161 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger") 162 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Word embedding file") 163 | argparser.add_argument('--corpus', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl") 164 | argparser.add_argument('--scifact_train', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl") 165 | #argparser.add_argument('--scifact_train', type=str) 166 | argparser.add_argument('--scifact_test', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl") 167 | argparser.add_argument('--dataset', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev.jsonl") 168 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 169 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 170 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 171 | argparser.add_argument('--checkpoint', type=str, default = "domain_adaptation_roberta_joint_paragraph_fine_tune.model") 172 | argparser.add_argument('--log_file', type=str, default = "domain_adaptation_joint_paragraph_prediction.log") 173 | argparser.add_argument('--prediction', type=str, default = "prediction_domain_adaptation_roberta_joint_paragraph_dynamic_fine_tune.jsonl") 174 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 175 | 
argparser.add_argument('--k', type=int, default=0) 176 | 177 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 178 | reset_random_seed(12345) 179 | args = argparser.parse_args() 180 | 181 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu') 182 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 183 | 184 | if args.scifact_train: 185 | train = True 186 | #assert args.repfile is not None, "Word embedding file required for training." 187 | else: 188 | train = False 189 | if args.scifact_test: 190 | test = True 191 | else: 192 | test = False 193 | 194 | params = vars(args) 195 | 196 | for k,v in params.items(): 197 | print(k,v) 198 | 199 | dev_set = SciFactParagraphBatchDataset(args.corpus, args.scifact_test, 200 | sep_token = tokenizer.sep_token, k = args.k, downsample_n = 0, train = False) 201 | 202 | print("Loaded dataset!") 203 | 204 | model = DomainAdaptationJointParagraphClassifier(args.repfile, args.bert_dim, 205 | args.dropout).to(device) 206 | model.load_state_dict(torch.load(args.checkpoint)) 207 | print("Loaded saved model.") 208 | 209 | rationale_predictions, stance_preds = predict(model, dev_set) 210 | rationale_json = rationale2json(dev_set.samples, rationale_predictions) 211 | stance_json = stance2json(dev_set.samples, stance_preds) 212 | stance_json = post_process_stance(rationale_json, stance_json) 213 | merged_json = merge_json(rationale_json, stance_json) 214 | 215 | with jsonlines.open(args.prediction, 'w') as output: 216 | for result in merged_json: 217 | output.write(result) 218 | 219 | data = GoldDataset(args.corpus, args.dataset) 220 | predictions = PredictedDataset(data, args.prediction) 221 | res = metrics.compute_metrics(predictions) 222 | params["evaluation"] = res 223 | with jsonlines.open(args.log_file, mode='a') as writer: 224 | writer.write(params) -------------------------------------------------------------------------------- /lib/data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code to represent a dataset release. 3 | """ 4 | 5 | from enum import Enum 6 | import json 7 | import copy 8 | from dataclasses import dataclass 9 | from typing import Dict, List, Tuple 10 | 11 | #################### 12 | 13 | # Utility functions and enums. 14 | 15 | 16 | def load_jsonl(fname): 17 | return [json.loads(line) for line in open(fname)] 18 | 19 | 20 | class Label(Enum): 21 | SUPPORTS = 1 22 | NEI = 0 23 | REFUTES = -1 24 | 25 | 26 | def make_label(label_str, allow_NEI=True): 27 | lookup = {"SUPPORT": Label.SUPPORTS, 28 | "NOT_ENOUGH_INFO": Label.NEI, 29 | "CONTRADICT": Label.REFUTES} 30 | 31 | res = lookup[label_str] 32 | if (not allow_NEI) and (res is Label.NEI): 33 | raise ValueError("An NEI was given.") 34 | 35 | return res 36 | 37 | 38 | #################### 39 | 40 | # Representations for the corpus and abstracts. 
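# A Document holds a single abstract (title plus sentences); a Corpus wraps the
# whole corpus.jsonl release and supports lookup by list index or by doc_id.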
41 | 42 | @dataclass(repr=False, frozen=True) 43 | class Document: 44 | id: str 45 | title: str 46 | sentences: Tuple[str] 47 | 48 | def __repr__(self): 49 | return self.title.upper() + "\n" + "\n".join(["- " + entry for entry in self.sentences]) 50 | 51 | def __lt__(self, other): 52 | return self.title.__lt__(other.title) 53 | 54 | def dump(self): 55 | res = {"doc_id": self.id, 56 | "title": self.title, 57 | "abstract": self.sentences, 58 | "structured": self.is_structured()} 59 | return json.dumps(res) 60 | 61 | 62 | @dataclass(repr=False, frozen=True) 63 | class Corpus: 64 | """ 65 | A Corpus is just a collection of `Document` objects, with methods to look up 66 | a single document. 67 | """ 68 | documents: List[Document] 69 | 70 | def __repr__(self): 71 | return f"Corpus of {len(self.documents)} documents." 72 | 73 | def __getitem__(self, i): 74 | "Get document by index in list." 75 | return self.documents[i] 76 | 77 | def get_document(self, doc_id): 78 | "Get document by ID." 79 | res = [x for x in self.documents if x.id == doc_id] 80 | assert len(res) == 1 81 | return res[0] 82 | 83 | @classmethod 84 | def from_jsonl(cls, corpus_file): 85 | corpus = load_jsonl(corpus_file) 86 | documents = [] 87 | for entry in corpus: 88 | doc = Document(entry["doc_id"], entry["title"], entry["abstract"]) 89 | documents.append(doc) 90 | 91 | return cls(documents) 92 | 93 | 94 | #################### 95 | 96 | # Gold dataset. 97 | 98 | class GoldDataset: 99 | """ 100 | Class to represent a gold dataset, include corpus and claims. 101 | """ 102 | def __init__(self, corpus_file, data_file): 103 | self.corpus = Corpus.from_jsonl(corpus_file) 104 | self.claims = self._read_claims(data_file) 105 | 106 | def __repr__(self): 107 | msg = f"{self.corpus.__repr__()} {len(self.claims)} claims." 108 | return msg 109 | 110 | def __getitem__(self, i): 111 | return self.claims[i] 112 | 113 | def _read_claims(self, data_file): 114 | "Read claims from file." 115 | examples = load_jsonl(data_file) 116 | res = [] 117 | for this_example in examples: 118 | entry = copy.deepcopy(this_example) 119 | entry["release"] = self 120 | entry["cited_docs"] = [self.corpus.get_document(doc) 121 | for doc in entry["cited_doc_ids"]] 122 | assert len(entry["cited_docs"]) == len(entry["cited_doc_ids"]) 123 | del entry["cited_doc_ids"] 124 | res.append(Claim(**entry)) 125 | 126 | res = sorted(res, key=lambda x: x.id) 127 | return res 128 | 129 | def get_claim(self, example_id): 130 | "Get a single claim by ID." 131 | keep = [x for x in self.claims if x.id == example_id] 132 | assert len(keep) == 1 133 | return keep[0] 134 | 135 | 136 | @dataclass 137 | class EvidenceAbstract: 138 | "A single evidence abstract." 139 | id: int 140 | label: Label 141 | rationales: List[List[int]] 142 | 143 | 144 | @dataclass(repr=False) 145 | class Claim: 146 | """ 147 | Class representing a single claim, with a pointer back to the dataset. 148 | """ 149 | id: int 150 | claim: str 151 | evidence: Dict[int, EvidenceAbstract] 152 | cited_docs: List[Document] 153 | release: GoldDataset 154 | 155 | def __post_init__(self): 156 | self.evidence = self._format_evidence(self.evidence) 157 | 158 | @staticmethod 159 | def _format_evidence(evidence_dict): 160 | # This function is needed because the data schema is designed so that 161 | # each rationale can have its own support label. But, in the dataset, 162 | # all rationales for a given claim / abstract pair all have the same 163 | # label. 
So, we store the label at the "abstract level" rather than the 164 | # "rationale level". 165 | res = {} 166 | for doc_id, rationales in evidence_dict.items(): 167 | doc_id = int(doc_id) 168 | labels = [x["label"] for x in rationales] 169 | if len(set(labels)) > 1: 170 | msg = ("In this SciFact release, each claim / abstract pair " 171 | "should only have one label.") 172 | raise Exception(msg) 173 | label = make_label(labels[0]) 174 | rationale_sents = [x["sentences"] for x in rationales] 175 | this_abstract = EvidenceAbstract(doc_id, label, rationale_sents) 176 | res[doc_id] = this_abstract 177 | 178 | return res 179 | 180 | def __repr__(self): 181 | msg = f"Example {self.id}: {self.claim}" 182 | return msg 183 | 184 | def pretty_print(self, evidence_doc_id=None, file=None): 185 | "Pretty-print the claim, together with all evidence." 186 | msg = self.__repr__() 187 | print(msg, file=file) 188 | # Print the evidence 189 | print("\nEvidence sets:", file=file) 190 | for doc_id, evidence in self.evidence.items(): 191 | # If asked for a specific evidence doc, only show that one. 192 | if evidence_doc_id is not None and doc_id != evidence_doc_id: 193 | continue 194 | print("\n" + 20 * "#" + "\n", file=file) 195 | ev_doc = self.release.corpus.get_document(doc_id) 196 | print(f"{doc_id}: {evidence.label.name}", file=file) 197 | for i, sents in enumerate(evidence.rationales): 198 | print(f"Set {i}:", file=file) 199 | kept = [sent for i, sent in enumerate(ev_doc.sentences) if i in sents] 200 | for entry in kept: 201 | print(f"\t- {entry}", file=file) 202 | 203 | 204 | #################### 205 | 206 | # Predicted dataset. 207 | 208 | class PredictedDataset: 209 | """ 210 | Class to handle predictions, with a pointer back to the gold data. 211 | """ 212 | def __init__(self, gold, prediction_file): 213 | """ 214 | Takes a GoldDataset, as well as files with rationale and label 215 | predictions. 216 | """ 217 | self.gold = gold 218 | self.predictions = self._read_predictions(prediction_file) 219 | 220 | def __getitem__(self, i): 221 | return self.predictions[i] 222 | 223 | def __repr__(self): 224 | msg = f"Predictions for {len(self.predictions)} claims." 225 | return msg 226 | 227 | def _read_predictions(self, prediction_file): 228 | res = [] 229 | 230 | predictions = load_jsonl(prediction_file) 231 | for pred in predictions: 232 | prediction = self._parse_prediction(pred) 233 | res.append(prediction) 234 | 235 | return res 236 | 237 | def _parse_prediction(self, pred_dict): 238 | claim_id = pred_dict["id"] 239 | predicted_evidence = pred_dict["evidence"] 240 | 241 | res = {} 242 | 243 | # Predictions should never be NEI; there should only be predictions for 244 | # the abstracts that contain evidence. 245 | for key, this_prediction in predicted_evidence.items(): 246 | label = this_prediction["label"] 247 | evidence = this_prediction["sentences"] 248 | pred = PredictedAbstract(int(key), 249 | make_label(label, allow_NEI=False), 250 | evidence) 251 | res[int(key)] = pred 252 | 253 | gold_claim = self.gold.get_claim(claim_id) 254 | return ClaimPredictions(claim_id, res, gold_claim) 255 | 256 | 257 | @dataclass 258 | class PredictedAbstract: 259 | # For predictions, we have a single list of rationale sentences instead of a 260 | # list of separate rationales (see paper for details). 
261 | abstract_id: int 262 | label: Label 263 | rationale: List 264 | 265 | 266 | @dataclass 267 | class ClaimPredictions: 268 | claim_id: int 269 | predictions: Dict[int, PredictedAbstract] 270 | gold: Claim = None # For backward compatibility, default this to None. 271 | 272 | def __repr__(self): 273 | msg = f"Predictions for {self.claim_id}: {self.gold.claim}" 274 | return msg 275 | 276 | def pretty_print(self, evidence_doc_id=None, file=None): 277 | msg = self.__repr__() 278 | print(msg, file=file) 279 | # Print the evidence 280 | print("\nEvidence sets:", file=file) 281 | for doc_id, prediction in self.predictions.items(): 282 | # If asked for a specific evidence doc, only show that one. 283 | if evidence_doc_id is not None and doc_id != evidence_doc_id: 284 | continue 285 | print("\n" + 20 * "#" + "\n", file=file) 286 | ev_doc = self.gold.release.corpus.get_document(doc_id) 287 | print(f"{doc_id}: {prediction.label.name}", file=file) 288 | # Print the predicted rationale. 289 | sents = prediction.rationale 290 | kept = [sent for i, sent in enumerate(ev_doc.sentences) if i in sents] 291 | for entry in kept: 292 | print(f"\t- {entry}", file=file) 293 | -------------------------------------------------------------------------------- /lib/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluating abstract-level and sentence-level performance as described in the 3 | paper. 4 | """ 5 | 6 | import warnings 7 | 8 | from .data import Label 9 | from collections import Counter 10 | import pandas as pd 11 | 12 | 13 | # Cap on how many abstract sentences can be returned. 14 | MAX_ABSTRACT_SENTS = 3 15 | 16 | 17 | def safe_divide(num, denom): 18 | if denom == 0: 19 | return 0 20 | else: 21 | return num / denom 22 | 23 | 24 | def compute_f1(counts, difficulty=None): 25 | correct_key = "correct" if difficulty is None else f"correct_{difficulty}" 26 | precision = safe_divide(counts[correct_key], counts["retrieved"]) 27 | recall = safe_divide(counts[correct_key], counts["relevant"]) 28 | f1 = safe_divide(2 * precision * recall, precision + recall) 29 | return {"precision": precision, "recall": recall, "f1": f1} 30 | 31 | 32 | #################### 33 | 34 | # Abstract-level evaluation 35 | 36 | def contains_evidence(predicted, gold): 37 | # If any of gold are contained in predicted, we're good. 38 | for gold_rat in gold: 39 | if gold_rat.issubset(predicted): 40 | return True 41 | # If we get to the end, didn't find one. 42 | return False 43 | 44 | 45 | def is_correct(doc_id, doc_pred, gold): 46 | pred_rationales = doc_pred.rationale[:MAX_ABSTRACT_SENTS] 47 | 48 | # If it's not an evidence document, we lose. 49 | if doc_id not in gold.evidence: 50 | return False, False 51 | 52 | # If the label's wrong, we lose. 53 | gold_label = gold.evidence[doc_id].label 54 | if doc_pred.label != gold_label: 55 | return False, False 56 | 57 | gold_rationales = [set(x) for x in gold.evidence[doc_id].rationales] 58 | good_rationalized = contains_evidence(set(pred_rationales), gold_rationales) 59 | good_label_only = True 60 | return good_label_only, good_rationalized 61 | 62 | 63 | def update_counts_abstract(pred, gold, counts_abstract): 64 | counts_abstract["relevant"] += len(gold.evidence) 65 | for doc_id, doc_pred in pred.predictions.items(): 66 | # If it's NEI, doesn't count one way or the other. 
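        # (a predicted NEI is an abstention: it is never counted as retrieved,
        # so abstaining can lower recall but cannot lower precision)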
67 | if doc_pred.label == Label.NEI: 68 | continue 69 | counts_abstract["retrieved"] += 1 70 | 71 | good_label_only, good_rationalized = is_correct(doc_id, doc_pred, gold) 72 | if good_label_only: 73 | counts_abstract["correct_label_only"] += 1 74 | if good_rationalized: 75 | counts_abstract["correct_rationalized"] += 1 76 | 77 | return counts_abstract 78 | 79 | 80 | #################### 81 | 82 | # Sentence-level evaluation 83 | 84 | def count_rationale_sents(predicted, gold): 85 | n_correct = 0 86 | 87 | for ix in predicted: 88 | gold_sets = [entry for entry in gold if ix in entry] 89 | assert len(gold_sets) < 2 # Can't be in two rationales. 90 | # If it's not in a gold set, no dice. 91 | if len(gold_sets) == 0: 92 | continue 93 | # If it's in a gold set, make sure the rest got retrieved. 94 | gold_set = gold_sets[0] 95 | if gold_set.issubset(predicted): 96 | n_correct += 1 97 | 98 | return n_correct 99 | 100 | 101 | def count_correct(doc_id, doc_pred, gold): 102 | # If not an evidence doc, no good. 103 | if doc_id not in gold.evidence: 104 | return 0, 0 105 | 106 | # Count the number of rationale sentences we get credit for. 107 | gold_rationales = [set(x) for x in gold.evidence[doc_id].rationales] 108 | n_correct = count_rationale_sents(set(doc_pred.rationale), gold_rationales) 109 | 110 | gold_label = gold.evidence[doc_id].label 111 | 112 | n_correct_selection = n_correct 113 | correct_label = int(doc_pred.label == gold_label) 114 | n_correct_label = correct_label * n_correct 115 | 116 | return n_correct_selection, n_correct_label 117 | 118 | 119 | def update_counts_sentence(pred, gold, counts_sentence): 120 | # Update the gold evidence sentences. 121 | for gold_doc in gold.evidence.values(): 122 | counts_sentence["relevant"] += sum([len(x) for x in gold_doc.rationales]) 123 | 124 | for doc_id, doc_pred in pred.predictions.items(): 125 | # If it's NEI, skip it. 126 | if doc_pred.label == Label.NEI: 127 | continue 128 | 129 | counts_sentence["retrieved"] += len(doc_pred.rationale) 130 | n_correct_selection, n_correct_label = count_correct(doc_id, doc_pred, gold) 131 | counts_sentence["correct_selection"] += n_correct_selection 132 | counts_sentence["correct_label"] += n_correct_label 133 | 134 | return counts_sentence 135 | 136 | 137 | #################### 138 | 139 | # Make sure rationales aren't too long. 140 | 141 | def check_rationale_lengths(preds): 142 | bad = [] 143 | for pred in preds: 144 | claim_id = pred.claim_id 145 | predictions = pred.predictions 146 | for doc_key, prediction in predictions.items(): 147 | n_rationales = len(prediction.rationale) 148 | if n_rationales > MAX_ABSTRACT_SENTS: 149 | to_append = {"claim_id": claim_id, "abstract": doc_key, "n_rationales": n_rationales} 150 | bad.append(to_append) 151 | if bad: 152 | bad = pd.DataFrame(bad) 153 | msg = (f"\nRationales with more than {MAX_ABSTRACT_SENTS} sentences found.\n" 154 | f"The first 3 will be used for abstract-level evaluation\n\n" 155 | f"{bad.__repr__()}") 156 | warnings.warn(msg) 157 | print() 158 | 159 | 160 | ################################################################################ 161 | 162 | def compute_metrics(preds): 163 | """ 164 | Compute pipeline metrics based on dataset of predictions. 
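
    A minimal usage sketch (the file names here are hypothetical):

        data = GoldDataset("corpus.jsonl", "claims_dev.jsonl")
        preds = PredictedDataset(data, "predictions.jsonl")
        results = compute_metrics(preds)  # P / R / F1 at sentence and abstract level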
165 | """ 166 | counts_abstract = Counter() 167 | counts_sentence = Counter() 168 | 169 | check_rationale_lengths(preds) 170 | 171 | for pred in preds: 172 | gold = preds.gold.get_claim(pred.claim_id) 173 | counts_abstract = update_counts_abstract(pred, gold, counts_abstract) 174 | counts_sentence = update_counts_sentence(pred, gold, counts_sentence) 175 | 176 | return {"sentence_selection": compute_f1(counts_sentence, "selection"), 177 | "sentence_label": compute_f1(counts_sentence, "label"), 178 | "abstract_label_only": compute_f1(counts_abstract, "label_only"), 179 | "abstract_rationalized": compute_f1(counts_abstract, "rationalized") 180 | } 181 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.5.3 2 | sent2vec==0.0.0 3 | tqdm==4.50.2 4 | jsonlines==1.2.0 5 | torch==1.6.0 6 | numpy==1.19.1 7 | transformers==2.6.0 8 | nltk==3.5 9 | pandas==1.1.3 10 | dataclasses==0.7 11 | scikit_learn==0.24.1 12 | -------------------------------------------------------------------------------- /scifact_joint_paragraph_dynamic_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import jsonlines 5 | import os 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 11 | from tqdm import tqdm 12 | from typing import List 13 | from sklearn.metrics import f1_score, precision_score, recall_score 14 | 15 | import math 16 | import random 17 | import numpy as np 18 | 19 | from tqdm import tqdm 20 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 21 | from paragraph_model_dynamic import JointParagraphClassifier 22 | from dataset import SciFactParagraphBatchDataset 23 | 24 | import logging 25 | 26 | from lib.data import GoldDataset, PredictedDataset 27 | from lib import metrics 28 | 29 | def reset_random_seed(seed): 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | torch.manual_seed(seed) 33 | 34 | def predict(model, dataset): 35 | model.eval() 36 | rationale_predictions = [] 37 | stance_preds = [] 38 | 39 | def remove_dummy(rationale_out): 40 | return [out[1:] for out in rationale_out] 41 | 42 | with torch.no_grad(): 43 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 44 | encoded_dict = encode(tokenizer, batch) 45 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 46 | tokenizer.sep_token_id, args.repfile) 47 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 48 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 49 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices) 50 | stance_preds.extend(stance_out) 51 | rationale_predictions.extend(remove_dummy(rationale_out)) 52 | 53 | return rationale_predictions, stance_preds 54 | 55 | 56 | def encode(tokenizer, batch, max_sent_len = 512): 57 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 58 | def longest_first_truncation(sentences, objective): 59 | sent_lens = [len(sent) for sent in sentences] 60 | while np.sum(sent_lens) > objective: 61 | max_position = np.argmax(sent_lens) 62 | sent_lens[max_position] -= 1 63 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 
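        # Illustrative example (hypothetical lengths): with sentences of
        # 300, 150 and 100 tokens and an objective of 500, the longest
        # sentence is trimmed one token at a time until the paragraph fits,
        # leaving lengths 250, 150 and 100.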
64 | 65 | all_paragraphs = [] 66 | for paragraph in input_ids: 67 | valid_paragraph = paragraph[paragraph != pad_token_id] 68 | if valid_paragraph.size(0) <= max_length: 69 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 70 | else: 71 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 72 | idx_by_sentence = [] 73 | prev_idx = 0 74 | for idx in sep_token_idx: 75 | idx_by_sentence.append(paragraph[prev_idx:idx]) 76 | prev_idx = idx 77 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 78 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 79 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 80 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 81 | 82 | return torch.cat(all_paragraphs, 0) 83 | 84 | inputs = zip(batch["claim"], batch["paragraph"]) 85 | encoded_dict = tokenizer.batch_encode_plus( 86 | inputs, 87 | pad_to_max_length=True,add_special_tokens=True, 88 | return_tensors='pt') 89 | if encoded_dict['input_ids'].size(1) > max_sent_len: 90 | if 'token_type_ids' in encoded_dict: 91 | encoded_dict = { 92 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 93 | tokenizer.sep_token_id, tokenizer.pad_token_id), 94 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 95 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 96 | } 97 | else: 98 | encoded_dict = { 99 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 100 | tokenizer.sep_token_id, tokenizer.pad_token_id), 101 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 102 | } 103 | return encoded_dict 104 | 105 | def token_idx_by_sentence(input_ids, sep_token_id, model_name): 106 | """ 107 | Compute the token indices matrix of the BERT output. 
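    The three returned index tensors are used for advanced indexing into the
    flat BERT output, gathering token representations sentence by sentence: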
108 | input_ids: (batch_size, paragraph_len) 109 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token) 110 | bert_out: (batch_size, paragraph_len,BERT_dim) 111 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim) 112 | """ 113 | padding_idx = -1 114 | sep_tokens = (input_ids == sep_token_id).bool() 115 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() 116 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) 117 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) 118 | paragraph_lens = [] 119 | all_word_indices = [] 120 | for paragraph in sep_indices: 121 | if "roberta" in model_name: 122 | paragraph = paragraph[1:] 123 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)] 124 | paragraph_lens.append(len(word_indices)) 125 | all_word_indices.extend(word_indices) 126 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx) 127 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens) 128 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx) 129 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1)) 130 | mask = (indices_by_batch>=0) 131 | 132 | return batch_indices.long(), indices_by_batch.long(), mask.long() 133 | 134 | def post_process_stance(rationale_json, stance_json): 135 | assert(len(rationale_json) == len(stance_json)) 136 | for stance_pred, rationale_pred in zip(stance_json, rationale_json): 137 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"]) 138 | for doc_id, pred in rationale_pred["evidence"].items(): 139 | if len(pred) == 0: 140 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO" 141 | return stance_json 142 | 143 | def post_process_rationale_score(rationale_scores, max_positive=3): ### Doesn't seem to be helpful? 
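    """
    Cap the number of positive rationale sentences per paragraph: if more than
    `max_positive` scores exceed the 0.5 threshold, keep only the
    `max_positive` highest-scoring sentences; otherwise simply threshold at 0.5.
    """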
144 |     def process_rationale_score(paragraph_rationale_scores):
145 |         paragraph_rationale_scores = np.array(paragraph_rationale_scores)
146 |         if np.sum(paragraph_rationale_scores > 0.5) > max_positive:
147 |             output = np.zeros(paragraph_rationale_scores.shape)
148 |             positive_indices = np.argsort(paragraph_rationale_scores)[::-1][:max_positive]
149 |             output[positive_indices] = 1
150 |         else:
151 |             output = (paragraph_rationale_scores > 0.5).astype(int)
152 |         return output.tolist()
153 |     return [process_rationale_score(paragraph_rationale_scores) for paragraph_rationale_scores in rationale_scores]
154 | 
155 | if __name__ == "__main__":
156 |     argparser = argparse.ArgumentParser(description="Predict rationales and stances on SciFact with the joint paragraph model")
157 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
158 |     argparser.add_argument('--corpus_file', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
159 |     argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
160 |     argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
161 |     argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
162 |     argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
163 |     argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
164 |     argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_joint_paragraph.model")
165 |     argparser.add_argument('--batch_size', type=int, default=25)
166 |     argparser.add_argument('--k', type=int, default=0)
167 |     argparser.add_argument('--rationale_selection', type=str)
168 |     argparser.add_argument('--label_prediction', type=str)
169 |     argparser.add_argument('--prediction', type=str)
170 |     argparser.add_argument('--log_file', type=str, default = "prediction_dynamic.log")
171 |     argparser.add_argument('--evaluate', action='store_true')
172 | 
173 | 
174 |     logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
175 | 
176 |     reset_random_seed(12345)
177 | 
178 |     args = argparser.parse_args()
179 |     params = vars(args)
180 | 
181 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
182 |     tokenizer = AutoTokenizer.from_pretrained(args.repfile)
183 | 
184 |     dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, 
185 |                                            sep_token = tokenizer.sep_token, k = args.k, train=False)
186 | 
187 |     model = JointParagraphClassifier(args.repfile, args.bert_dim, 
188 |                                      args.dropout).to(device)
189 | 
190 |     model.load_state_dict(torch.load(args.checkpoint))
191 |     print("Loaded saved model.")
192 | 
193 |     reset_random_seed(12345)
194 |     rationale_predictions, stance_preds = predict(model, dev_set)
195 |     rationale_json = rationale2json(dev_set.samples, rationale_predictions)
196 |     stance_json = stance2json(dev_set.samples, stance_preds)
197 |     stance_json = post_process_stance(rationale_json, stance_json)
198 |     merged_json = merge_json(rationale_json, stance_json)
199 |     if args.rationale_selection is not None:
200 |         with jsonlines.open(args.rationale_selection, 'w') as output:
201 |             for result in rationale_json:
202 |                 output.write(result)
203 |     if args.label_prediction is not None:
204 |         with jsonlines.open(args.label_prediction, 'w') as output:
205 |             for result in stance_json:
206 |                 output.write(result)
207 |     if args.prediction is not None:
208 |         with jsonlines.open(args.prediction, 'w') as output:
209 |             for result 
in merged_json: 210 | output.write(result) 211 | if args.evaluate: 212 | data = GoldDataset(args.corpus_file, args.dataset) 213 | predictions = PredictedDataset(data, args.prediction) 214 | res = metrics.compute_metrics(predictions) 215 | params["evaluation"] = res 216 | with jsonlines.open(args.log_file, mode='a') as writer: 217 | writer.write(params) -------------------------------------------------------------------------------- /scifact_joint_paragraph_kgat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import jsonlines 5 | import os 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 11 | from tqdm import tqdm 12 | from typing import List 13 | from sklearn.metrics import f1_score, precision_score, recall_score 14 | 15 | import math 16 | import random 17 | import numpy as np 18 | 19 | from tqdm import tqdm 20 | from util import arg2param, flatten, stance2json, rationale2json, merge_json 21 | from paragraph_model_kgat import JointParagraphKGATClassifier 22 | from dataset import SciFactParagraphBatchDataset 23 | 24 | import logging 25 | 26 | from lib.data import GoldDataset, PredictedDataset 27 | from lib import metrics 28 | 29 | def schedule_sample_p(epoch, total): 30 | return np.sin(0.5* np.pi* epoch / (total-1)) 31 | 32 | def reset_random_seed(seed): 33 | np.random.seed(seed) 34 | random.seed(seed) 35 | torch.manual_seed(seed) 36 | 37 | def predict(model, dataset): 38 | model.eval() 39 | rationale_predictions = [] 40 | stance_preds = [] 41 | 42 | def remove_dummy(rationale_out): 43 | return [out[1:] for out in rationale_out] 44 | 45 | with torch.no_grad(): 46 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 47 | encoded_dict = encode(tokenizer, batch) 48 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 49 | tokenizer.sep_token_id, args.repfile) 50 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 51 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 52 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices) 53 | stance_preds.extend(stance_out) 54 | rationale_predictions.extend(remove_dummy(rationale_out)) 55 | 56 | return rationale_predictions, stance_preds 57 | 58 | def batch_rationale_label(labels, padding_idx = 2): 59 | max_sent_len = max([len(label) for label in labels]) 60 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 61 | label_list = [] 62 | for i, label in enumerate(labels): 63 | for j, evid in enumerate(label): 64 | label_matrix[i,j] = int(evid) 65 | label_list.append([int(evid) for evid in label]) 66 | return label_matrix.long(), label_list 67 | 68 | def evaluation(model, dataset): 69 | model.eval() 70 | rationale_predictions = [] 71 | rationale_labels = [] 72 | stance_preds = [] 73 | stance_labels = [] 74 | 75 | def remove_dummy(rationale_out): 76 | return [out[1:] for out in rationale_out] 77 | 78 | with torch.no_grad(): 79 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size*5, shuffle=False)): 80 | encoded_dict = encode(tokenizer, batch) 81 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 82 | tokenizer.sep_token_id, args.repfile) 83 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 84 | 
transformation_indices = [tensor.to(device) for tensor in transformation_indices] 85 | stance_label = batch["stance"].to(device) 86 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 87 | rationale_out, stance_out, rationale_loss, stance_loss = \ 88 | model(encoded_dict, transformation_indices, stance_label = stance_label, 89 | rationale_label = padded_rationale_label.to(device)) 90 | stance_preds.extend(stance_out) 91 | stance_labels.extend(stance_label.cpu().numpy().tolist()) 92 | 93 | rationale_predictions.extend(remove_dummy(rationale_out)) 94 | rationale_labels.extend(remove_dummy(rationale_label)) 95 | 96 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 97 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 98 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) 99 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions)) 100 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions)) 101 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions)) 102 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall 103 | 104 | 105 | 106 | def encode(tokenizer, batch, max_sent_len = 512): 107 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 108 | def longest_first_truncation(sentences, objective): 109 | sent_lens = [len(sent) for sent in sentences] 110 | while np.sum(sent_lens) > objective: 111 | max_position = np.argmax(sent_lens) 112 | sent_lens[max_position] -= 1 113 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 114 | 115 | all_paragraphs = [] 116 | for paragraph in input_ids: 117 | valid_paragraph = paragraph[paragraph != pad_token_id] 118 | if valid_paragraph.size(0) <= max_length: 119 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 120 | else: 121 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 122 | idx_by_sentence = [] 123 | prev_idx = 0 124 | for idx in sep_token_idx: 125 | idx_by_sentence.append(paragraph[prev_idx:idx]) 126 | prev_idx = idx 127 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 128 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 129 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 130 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 131 | 132 | return torch.cat(all_paragraphs, 0) 133 | 134 | inputs = zip(batch["claim"], batch["paragraph"]) 135 | encoded_dict = tokenizer.batch_encode_plus( 136 | inputs, 137 | pad_to_max_length=True,add_special_tokens=True, 138 | return_tensors='pt') 139 | if encoded_dict['input_ids'].size(1) > max_sent_len: 140 | if 'token_type_ids' in encoded_dict: 141 | encoded_dict = { 142 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 143 | tokenizer.sep_token_id, tokenizer.pad_token_id), 144 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 145 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 146 | } 147 | else: 148 | encoded_dict = { 149 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 150 | tokenizer.sep_token_id, tokenizer.pad_token_id), 151 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 152 | } 153 | 154 | 
return encoded_dict
155 | 
156 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
157 |     """
158 |     Advanced indexing: Compute the token indices matrix of the BERT output.
159 |     input_ids: (batch_size, paragraph_len)
160 |     batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
161 |     bert_out: (batch_size, paragraph_len,BERT_dim)
162 |     bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
163 |     """
164 |     padding_idx = -1
165 |     sep_tokens = (input_ids == sep_token_id).bool()
166 |     paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph
167 |     indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,...,paragraph_len-1 for each paragraph in the batch
168 |     sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph
169 |     paragraph_lens = []
170 |     all_word_indices = []
171 |     for paragraph in sep_indices:
172 |         # claim sentence: [CLS] token1 token2 ... tokenk
173 |         claim_word_indices = torch.arange(0, paragraph[0])
174 |         if "roberta" in model_name: # HuggingFace RoBERTa puts two </s> tokens between the claim and the abstract, so skip the extra one
175 |             paragraph = paragraph[1:]
176 |         # each sentence: [SEP] token1 token2 ... tokenk; the last [SEP] in the paragraph is dropped.
177 |         sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)]
178 | 
179 |         # KGAT requires claim sentence, so add it back.
180 |         word_indices = [claim_word_indices] + sentence_word_indices
181 | 
182 |         paragraph_lens.append(len(word_indices))
183 |         all_word_indices.extend(word_indices)
184 |     indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
185 |     indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
186 |     indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
187 |     batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
188 |     mask = (indices_by_batch>=0)
189 | 
190 |     return batch_indices.long(), indices_by_batch.long(), mask.long()
191 | 
192 | def post_process_stance(rationale_json, stance_json):
193 |     assert(len(rationale_json) == len(stance_json))
194 |     for stance_pred, rationale_pred in zip(stance_json, rationale_json):
195 |         assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
196 |         for doc_id, pred in rationale_pred["evidence"].items():
197 |             if len(pred) == 0:
198 |                 stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
199 |     return stance_json
200 | 
201 | if __name__ == "__main__":
202 |     argparser = argparse.ArgumentParser(description="Train and evaluate the joint paragraph model with KGAT stance aggregation")
203 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
204 |     argparser.add_argument('--corpus_file', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
205 |     argparser.add_argument('--train_file', type=str, default="/home/xxl190027/scifact_data/claims_train_retrieved.jsonl")
206 |     argparser.add_argument('--pre_trained_model', type=str)
207 |     #argparser.add_argument('--train_file', type=str)
208 |     argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
209 |     argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
210 |     argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning 
rate for BERT-like LM") 211 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate") 212 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 213 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 214 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch") 215 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 216 | argparser.add_argument('--loss_ratio', type=float, default=3) 217 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_joint_paragraph_kgat.model") 218 | argparser.add_argument('--log_file', type=str, default = "joint_paragraph_roberta_kgat_dynamic_performances.jsonl") 219 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_joint_paragraph_kgat_fine_tune.jsonl") 220 | argparser.add_argument('--update_step', type=int, default=10) 221 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 222 | argparser.add_argument('--k', type=int, default=0) 223 | argparser.add_argument('--kernel', type=int, default=6) 224 | argparser.add_argument('--fine_tune', action='store_true') 225 | 226 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 227 | 228 | reset_random_seed(12345) 229 | 230 | args = argparser.parse_args() 231 | 232 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 233 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 234 | 235 | if args.train_file: 236 | train = True 237 | #assert args.repfile is not None, "Word embedding file required for training." 238 | else: 239 | train = False 240 | if args.test_file: 241 | test = True 242 | else: 243 | test = False 244 | 245 | params = vars(args) 246 | 247 | for k,v in params.items(): 248 | print(k,v) 249 | 250 | if train: 251 | train_set = SciFactParagraphBatchDataset(args.corpus_file, args.train_file, 252 | sep_token = tokenizer.sep_token, k = args.k, downsample_n = 2) 253 | dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, 254 | sep_token = tokenizer.sep_token, k = args.k, downsample_n = 0) 255 | 256 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim, 257 | args.dropout, kernel = args.kernel)#.to(device) 258 | 259 | if args.pre_trained_model is not None: 260 | model.load_state_dict(torch.load(args.pre_trained_model)) 261 | if not args.fine_tune: 262 | model.reinitialize() ############ 263 | print("Reinitialized part of the model!") 264 | 265 | model = model.to(device) 266 | 267 | if train: 268 | if args.bert_lr == 0: 269 | print("Freezing BERT weights.") 270 | settings = [] 271 | for param in model.bert.parameters(): 272 | param.requires_grad = False 273 | else: 274 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}] 275 | for module in model.extra_modules: 276 | settings.append({'params': module.parameters(), 'lr': args.lr}) 277 | optimizer = torch.optim.Adam(settings) 278 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch) 279 | model.train() 280 | 281 | prev_performance = 0 282 | for epoch in range(args.epoch): 283 | if args.fine_tune: 284 | sample_p = 1 285 | else: 286 | sample_p = schedule_sample_p(epoch, args.epoch) 287 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True)) 288 | for i, batch in enumerate(tq): 289 | encoded_dict = encode(tokenizer, batch) 290 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 
tokenizer.sep_token_id, args.repfile) 291 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 292 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 293 | stance_label = batch["stance"].to(device) 294 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 295 | rationale_out, stance_out, rationale_loss, stance_loss = \ 296 | model(encoded_dict, transformation_indices, stance_label = stance_label, 297 | rationale_label = padded_rationale_label.to(device), sample_p = sample_p) 298 | rationale_loss *= args.loss_ratio 299 | loss = rationale_loss + stance_loss 300 | loss.backward() 301 | 302 | if i % args.update_step == args.update_step - 1: 303 | optimizer.step() 304 | optimizer.zero_grad() 305 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}') 306 | scheduler.step() 307 | 308 | # Evaluation 309 | train_score = evaluation(model, train_set) 310 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score) 311 | 312 | dev_score = evaluation(model, dev_set) 313 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score) 314 | 315 | dev_perf = dev_score[0] * dev_score[3] 316 | if dev_perf >= prev_performance: 317 | torch.save(model.state_dict(), args.checkpoint) 318 | best_state_dict = model.state_dict() 319 | prev_performance = dev_perf 320 | print("New model saved!") 321 | else: 322 | print("Skip saving model.") 323 | 324 | 325 | if test: 326 | if train: 327 | del model 328 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim, 329 | args.dropout, kernel = args.kernel).to(device) 330 | model.load_state_dict(best_state_dict) 331 | print("Testing on the new model.") 332 | else: 333 | model.load_state_dict(torch.load(args.checkpoint)) 334 | print("Loaded saved model.") 335 | 336 | rationale_predictions, stance_preds = predict(model, dev_set) 337 | rationale_json = rationale2json(dev_set.samples, rationale_predictions) 338 | stance_json = stance2json(dev_set.samples, stance_preds) 339 | stance_json = post_process_stance(rationale_json, stance_json) 340 | merged_json = merge_json(rationale_json, stance_json) 341 | 342 | with jsonlines.open(args.prediction, 'w') as output: 343 | for result in merged_json: 344 | output.write(result) 345 | 346 | data = GoldDataset(args.corpus_file, args.dataset) 347 | predictions = PredictedDataset(data, args.prediction) 348 | res = metrics.compute_metrics(predictions) 349 | params["evaluation"] = res 350 | with jsonlines.open(args.log_file, mode='a') as writer: 351 | writer.write(params) -------------------------------------------------------------------------------- /scifact_joint_paragraph_kgat_prediction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import jsonlines 5 | import os 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 11 | from tqdm import tqdm 12 | from typing import List 13 | from sklearn.metrics import f1_score, precision_score, recall_score 14 | 15 | import math 16 | import random 17 | import numpy as np 18 | 19 | from tqdm import tqdm 20 | from util import arg2param, 
flatten, stance2json, rationale2json, merge_json 21 | from paragraph_model_kgat import JointParagraphKGATClassifier 22 | from dataset import SciFactParagraphBatchDataset 23 | 24 | import logging 25 | 26 | from lib.data import GoldDataset, PredictedDataset 27 | from lib import metrics 28 | 29 | def reset_random_seed(seed): 30 | np.random.seed(seed) 31 | random.seed(seed) 32 | torch.manual_seed(seed) 33 | 34 | def predict(model, dataset): 35 | model.eval() 36 | rationale_predictions = [] 37 | stance_preds = [] 38 | 39 | def remove_dummy(rationale_out): 40 | return [out[1:] for out in rationale_out] 41 | 42 | with torch.no_grad(): 43 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 44 | encoded_dict = encode(tokenizer, batch) 45 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 46 | tokenizer.sep_token_id, args.repfile) 47 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 48 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 49 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices) 50 | stance_preds.extend(stance_out) 51 | rationale_predictions.extend(remove_dummy(rationale_out)) 52 | 53 | return rationale_predictions, stance_preds 54 | 55 | 56 | 57 | def encode(tokenizer, batch, max_sent_len = 512): 58 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 59 | def longest_first_truncation(sentences, objective): 60 | sent_lens = [len(sent) for sent in sentences] 61 | while np.sum(sent_lens) > objective: 62 | max_position = np.argmax(sent_lens) 63 | sent_lens[max_position] -= 1 64 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 65 | 66 | all_paragraphs = [] 67 | for paragraph in input_ids: 68 | valid_paragraph = paragraph[paragraph != pad_token_id] 69 | if valid_paragraph.size(0) <= max_length: 70 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 71 | else: 72 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 73 | idx_by_sentence = [] 74 | prev_idx = 0 75 | for idx in sep_token_idx: 76 | idx_by_sentence.append(paragraph[prev_idx:idx]) 77 | prev_idx = idx 78 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 79 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 80 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 81 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 82 | 83 | return torch.cat(all_paragraphs, 0) 84 | 85 | inputs = zip(batch["claim"], batch["paragraph"]) 86 | encoded_dict = tokenizer.batch_encode_plus( 87 | inputs, 88 | pad_to_max_length=True,add_special_tokens=True, 89 | return_tensors='pt') 90 | if encoded_dict['input_ids'].size(1) > max_sent_len: 91 | if 'token_type_ids' in encoded_dict: 92 | encoded_dict = { 93 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 94 | tokenizer.sep_token_id, tokenizer.pad_token_id), 95 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len], 96 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 97 | } 98 | else: 99 | encoded_dict = { 100 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 101 | tokenizer.sep_token_id, tokenizer.pad_token_id), 102 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len] 103 | } 104 | return encoded_dict 105 | 106 | def token_idx_by_sentence(input_ids, sep_token_id, 
model_name):
107 |     """
108 |     Advanced indexing: Compute the token indices matrix of the BERT output.
109 |     input_ids: (batch_size, paragraph_len)
110 |     batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
111 |     bert_out: (batch_size, paragraph_len,BERT_dim)
112 |     bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
113 |     """
114 |     padding_idx = -1
115 |     sep_tokens = (input_ids == sep_token_id).bool()
116 |     paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph
117 |     indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,...,paragraph_len-1 for each paragraph in the batch
118 |     sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph
119 |     paragraph_lens = []
120 |     all_word_indices = []
121 |     for paragraph in sep_indices:
122 |         # claim sentence: [CLS] token1 token2 ... tokenk
123 |         claim_word_indices = torch.arange(0, paragraph[0])
124 |         if "roberta" in model_name: # HuggingFace RoBERTa puts two </s> tokens between the claim and the abstract, so skip the extra one
125 |             paragraph = paragraph[1:]
126 |         # each sentence: [SEP] token1 token2 ... tokenk; the last [SEP] in the paragraph is dropped.
127 |         sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)]
128 | 
129 |         # KGAT requires claim sentence, so add it back.
130 |         word_indices = [claim_word_indices] + sentence_word_indices
131 | 
132 |         paragraph_lens.append(len(word_indices))
133 |         all_word_indices.extend(word_indices)
134 |     indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
135 |     indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
136 |     indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
137 |     batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
138 |     mask = (indices_by_batch>=0)
139 | 
140 |     return batch_indices.long(), indices_by_batch.long(), mask.long()
141 | 
142 | def post_process_stance(rationale_json, stance_json):
143 |     assert(len(rationale_json) == len(stance_json))
144 |     for stance_pred, rationale_pred in zip(stance_json, rationale_json):
145 |         assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
146 |         for doc_id, pred in rationale_pred["evidence"].items():
147 |             if len(pred) == 0:
148 |                 stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
149 |     return stance_json
150 | 
151 | if __name__ == "__main__":
152 |     argparser = argparse.ArgumentParser(description="Predict rationales and stances with the KGAT joint paragraph model")
153 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
154 |     argparser.add_argument('--corpus_file', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
155 |     argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
156 |     argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
157 |     argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
158 |     argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
159 |     argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
160 |     argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_joint_paragraph.model")
161 |     argparser.add_argument('--batch_size', 
type=int, default=25) 162 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_joint_paragraph_kgat_fine_tune.jsonl") 163 | argparser.add_argument('--k', type=int, default=0) 164 | argparser.add_argument('--kernel', type=int, default=6) 165 | argparser.add_argument('--log_file', type=str, default = "kgat_prediction.log") 166 | 167 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 168 | 169 | reset_random_seed(12345) 170 | 171 | args = argparser.parse_args() 172 | params = vars(args) 173 | 174 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 175 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 176 | 177 | dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, 178 | sep_token = tokenizer.sep_token, k = args.k, train=False) 179 | 180 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim, 181 | args.dropout, kernel = args.kernel).to(device) 182 | 183 | model.load_state_dict(torch.load(args.checkpoint)) 184 | print("Loaded saved model.") 185 | 186 | reset_random_seed(12345) 187 | rationale_predictions, stance_preds = predict(model, dev_set) 188 | rationale_json = rationale2json(dev_set.samples, rationale_predictions) 189 | stance_json = stance2json(dev_set.samples, stance_preds) 190 | stance_json = post_process_stance(rationale_json, stance_json) 191 | merged_json = merge_json(rationale_json, stance_json) 192 | 193 | with jsonlines.open(args.prediction, 'w') as output: 194 | for result in merged_json: 195 | output.write(result) 196 | 197 | data = GoldDataset(args.corpus_file, args.dataset) 198 | predictions = PredictedDataset(data, args.prediction) 199 | res = metrics.compute_metrics(predictions) 200 | params["evaluation"] = res 201 | with jsonlines.open(args.log_file, mode='a') as writer: 202 | writer.write(params) -------------------------------------------------------------------------------- /scifact_rationale_paragraph.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import jsonlines 5 | import os 6 | 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.utils.data import Dataset, DataLoader 10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup 11 | from tqdm import tqdm 12 | from typing import List 13 | from sklearn.metrics import f1_score, precision_score, recall_score 14 | 15 | import random 16 | import numpy as np 17 | 18 | from tqdm import tqdm 19 | from util import arg2param, flatten, rationale2json 20 | from paragraph_model_dynamic import RationaleParagraphClassifier as JointParagraphClassifier 21 | from dataset import SciFactParagraphBatchDataset 22 | 23 | import logging 24 | 25 | def reset_random_seed(seed): 26 | np.random.seed(seed) 27 | random.seed(seed) 28 | torch.manual_seed(seed) 29 | 30 | def batch_rationale_label(labels, padding_idx = 2): 31 | max_sent_len = max([len(label) for label in labels]) 32 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx 33 | label_list = [] 34 | for i, label in enumerate(labels): 35 | for j, evid in enumerate(label): 36 | label_matrix[i,j] = int(evid) 37 | label_list.append([int(evid) for evid in label]) 38 | return label_matrix.long(), label_list 39 | 40 | def predict(model, dataset): 41 | model.eval() 42 | rationale_predictions = [] 43 | 44 | with torch.no_grad(): 45 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)): 46 | 
encoded_dict = encode(tokenizer, batch) 47 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 48 | tokenizer.sep_token_id, args.repfile) 49 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 50 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 51 | rationale_out, _= model(encoded_dict, transformation_indices) 52 | rationale_predictions.extend(rationale_out) 53 | 54 | return rationale_predictions 55 | 56 | def evaluation(model, dataset): 57 | model.eval() 58 | rationale_predictions = [] 59 | rationale_labels = [] 60 | 61 | with torch.no_grad(): 62 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size*5, shuffle=False)): 63 | encoded_dict = encode(tokenizer, batch) 64 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 65 | tokenizer.sep_token_id, args.repfile) 66 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()} 67 | transformation_indices = [tensor.to(device) for tensor in transformation_indices] 68 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2) 69 | rationale_out, rationale_loss = \ 70 | model(encoded_dict, transformation_indices, 71 | rationale_label = padded_rationale_label.to(device)) 72 | 73 | rationale_predictions.extend(rationale_out) 74 | rationale_labels.extend(rationale_label) 75 | 76 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions)) 77 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions)) 78 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions)) 79 | return rationale_f1, rationale_precision, rationale_recall 80 | 81 | 82 | 83 | def encode(tokenizer, batch, max_sent_len = 512): 84 | def truncate(input_ids, max_length, sep_token_id, pad_token_id): 85 | def longest_first_truncation(sentences, objective): 86 | sent_lens = [len(sent) for sent in sentences] 87 | while np.sum(sent_lens) > objective: 88 | max_position = np.argmax(sent_lens) 89 | sent_lens[max_position] -= 1 90 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)] 91 | 92 | all_paragraphs = [] 93 | for paragraph in input_ids: 94 | valid_paragraph = paragraph[paragraph != pad_token_id] 95 | if valid_paragraph.size(0) <= max_length: 96 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0)) 97 | else: 98 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()] 99 | idx_by_sentence = [] 100 | prev_idx = 0 101 | for idx in sep_token_idx: 102 | idx_by_sentence.append(paragraph[prev_idx:idx]) 103 | prev_idx = idx 104 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out 105 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective) 106 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0) 107 | all_paragraphs.append(truncated_paragraph.unsqueeze(0)) 108 | 109 | return torch.cat(all_paragraphs, 0) 110 | 111 | inputs = zip(batch["claim"], batch["paragraph"]) 112 | encoded_dict = tokenizer.batch_encode_plus( 113 | inputs, 114 | pad_to_max_length=True,add_special_tokens=True, 115 | return_tensors='pt') 116 | if encoded_dict['input_ids'].size(1) > max_sent_len: 117 | if 'token_type_ids' in encoded_dict: 118 | encoded_dict = { 119 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 120 | tokenizer.sep_token_id, 
tokenizer.pad_token_id),
121 |                 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
122 |                 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
123 |             }
124 |         else:
125 |             encoded_dict = {
126 |                 "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 
127 |                                       tokenizer.sep_token_id, tokenizer.pad_token_id),
128 |                 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
129 |             }
130 | 
131 |     return encoded_dict
132 | 
133 | def sent_rep_indices(input_ids, sep_token_id, model_name):
134 | 
135 |     """
136 |     Compute the [SEP] indices matrix of the BERT output.
137 |     input_ids: (batch_size, paragraph_len)
138 |     batch_indices, indices_by_batch, mask: (batch_size, N_sentence)
139 |     bert_out: (batch_size, paragraph_len,BERT_dim)
140 |     bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, BERT_dim)
141 |     """
142 | 
143 |     sep_tokens = (input_ids == sep_token_id).bool()
144 |     paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
145 |     indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
146 |     sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
147 |     padded_sep_indices = nn.utils.rnn.pad_sequence(sep_indices, batch_first=True, padding_value=-1)
148 |     batch_indices = torch.arange(padded_sep_indices.size(0)).unsqueeze(-1).expand(-1,padded_sep_indices.size(-1))
149 |     mask = (padded_sep_indices>=0).long()
150 | 
151 |     if "roberta" in model_name:
152 |         return batch_indices[:,2:], padded_sep_indices[:,2:], mask[:,2:]
153 |     else:
154 |         return batch_indices[:,1:], padded_sep_indices[:,1:], mask[:,1:]
155 | 
156 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
157 |     """
158 |     Compute the token indices matrix of the BERT output.
159 |     input_ids: (batch_size, paragraph_len)
160 |     batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
161 |     bert_out: (batch_size, paragraph_len,BERT_dim)
162 |     bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
163 |     """
164 |     padding_idx = -1
165 |     sep_tokens = (input_ids == sep_token_id).bool()
166 |     paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
167 |     indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
168 |     sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
169 |     paragraph_lens = []
170 |     all_word_indices = []
171 |     for paragraph in sep_indices:
172 |         if "roberta" in model_name:
173 |             paragraph = paragraph[1:]
174 |         word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
175 |         paragraph_lens.append(len(word_indices))
176 |         all_word_indices.extend(word_indices)
177 |     indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
178 |     indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
179 |     indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
180 |     batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
181 |     mask = (indices_by_batch>=0)
182 | 
183 |     return batch_indices.long(), indices_by_batch.long(), mask.long()
184 | 
185 | if __name__ == "__main__":
186 |     argparser = argparse.ArgumentParser(description="Train and evaluate the paragraph-level rationale selection model")
187 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
188 | 
argparser.add_argument('--corpus_file', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl") 189 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl") 190 | argparser.add_argument('--pre_trained_model', type=str) 191 | #argparser.add_argument('--train_file', type=str) 192 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl") 193 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM") 194 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate") 195 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate") 196 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension") 197 | argparser.add_argument('--epoch', type=int, default=20, help="Training epoch") 198 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512) 199 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_rationale_paragraph.model") 200 | argparser.add_argument('--log_file', type=str, default = "rationale_paragraph_roberta_performances.jsonl") 201 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_rationale_paragraph.jsonl") 202 | argparser.add_argument('--update_step', type=int, default=10) 203 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8 204 | argparser.add_argument('--k', type=int, default=0) 205 | 206 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR) 207 | 208 | reset_random_seed(12345) 209 | 210 | args = argparser.parse_args() 211 | 212 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 213 | tokenizer = AutoTokenizer.from_pretrained(args.repfile) 214 | 215 | if args.train_file: 216 | train = True 217 | #assert args.repfile is not None, "Word embedding file required for training." 
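        # Training runs only when --train_file is given; otherwise the script
        # falls through to test-only mode and loads --checkpoint below.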
218 |     else:
219 |         train = False
220 |     if args.test_file:
221 |         test = True
222 |     else:
223 |         test = False
224 | 
225 |     params = vars(args)
226 | 
227 |     for k,v in params.items():
228 |         print(k,v)
229 | 
230 |     if train:
231 |         train_set = SciFactParagraphBatchDataset(args.corpus_file, args.train_file, 
232 |                                                  sep_token = tokenizer.sep_token, k = args.k, dummy=False)
233 |     dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, 
234 |                                            sep_token = tokenizer.sep_token, k = args.k, dummy=False)
235 | 
236 |     model = JointParagraphClassifier(args.repfile, args.bert_dim, 
237 |                                      args.dropout)#.to(device)
238 | 
239 |     if args.pre_trained_model is not None:
240 |         model.load_state_dict(torch.load(args.pre_trained_model))
241 |         model.reinitialize()  # re-initialize part of the model after loading the pre-trained weights
242 | 
243 |     model = model.to(device)
244 | 
245 |     if train:
246 |         settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
247 |         for module in model.extra_modules:
248 |             settings.append({'params': module.parameters(), 'lr': args.lr})
249 |         optimizer = torch.optim.Adam(settings)
250 |         scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
251 |         model.train()
252 | 
253 |         prev_performance = 0
254 |         for epoch in range(args.epoch):
255 |             tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
256 |             for i, batch in enumerate(tq):
257 |                 encoded_dict = encode(tokenizer, batch)
258 |                 transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
259 |                 encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
260 |                 transformation_indices = [tensor.to(device) for tensor in transformation_indices]
261 |                 padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
262 |                 rationale_out, loss = \
263 |                     model(encoded_dict, transformation_indices, 
264 |                           rationale_label = padded_rationale_label.to(device))
265 |                 loss.backward()
266 | 
267 |                 if i % args.update_step == args.update_step - 1:
268 |                     optimizer.step()
269 |                     optimizer.zero_grad()
270 |                     tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}')
271 |             scheduler.step()
272 | 
273 |             # Evaluation
274 |             train_score = evaluation(model, train_set)
275 |             print(f'Epoch {epoch}, train rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
276 | 
277 |             dev_score = evaluation(model, dev_set)
278 |             print(f'Epoch {epoch}, dev rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
279 | 
280 |             dev_perf = dev_score[0]
281 |             if dev_perf >= prev_performance:
282 |                 torch.save(model.state_dict(), args.checkpoint)
283 |                 best_state_dict = model.state_dict()
284 |                 prev_performance = dev_perf
285 |                 print("New model saved!")
286 |             else:
287 |                 print("Skip saving model.")
288 | 
289 | 
290 |     if test:
291 |         if train:
292 |             del model
293 |             model = JointParagraphClassifier(args.repfile, args.bert_dim, 
294 |                                              args.dropout).to(device)
295 |             model.load_state_dict(best_state_dict)
296 |             print("Testing on the new model.")
297 |         else:
298 |             model.load_state_dict(torch.load(args.checkpoint))
299 |             print("Loaded saved model.")
300 | 
301 |         # Evaluation
302 |         dev_score = evaluation(model, dev_set)
303 |         print(f'Test rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
304 | 
305 |         params["rationale_f1"] = dev_score[0]
306 |         params["rationale_precision"] = dev_score[1]
307 |         params["rationale_recall"] = dev_score[2]
308 | 
309 |         with jsonlines.open(args.log_file, mode='a') as writer:
310 |             writer.write(params) 
--------------------------------------------------------------------------------
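The rationale-selection script above and the stance script that follows are the two halves of the two-stage pipeline: the first picks rationale sentences, the second labels each claim / abstract pair. Their outputs can be combined into a single prediction file the same way the joint scripts do it, using the helpers from util.py. A minimal sketch, assuming `dev_set`, `rationale_preds`, and `stance_preds` come from the two scripts' predict() functions run over the same dataset (these variable names are illustrative, not defined in either script):

    from util import rationale2json, stance2json, merge_json

    # Per-claim JSON for each sub-task, keyed by claim_id / doc_id.
    rationale_json = rationale2json(dev_set.samples, rationale_preds)
    stance_json = stance2json(dev_set.samples, stance_preds)
    # One merged entry per claim, the format read by lib.data.PredictedDataset.
    merged_json = merge_json(rationale_json, stance_json)

The joint scripts additionally downgrade the stance to NOT_ENOUGH_INFO when no rationale sentence survives (see post_process_stance there); applying the same fix-up here keeps the two stages consistent.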
/scifact_stance_paragraph.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | 
3 | import torch
4 | import jsonlines
5 | import os
6 | 
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
11 | from tqdm import tqdm
12 | from typing import List
13 | from sklearn.metrics import f1_score, precision_score, recall_score
14 | 
15 | import random
16 | import numpy as np
17 | 
18 | # project-local helpers, model and dataset classes
19 | from util import arg2param, flatten, stance2json, merge_json
20 | from paragraph_model_dynamic import StanceParagraphClassifier as JointParagraphClassifier
21 | from dataset import SciFactStanceDataset as SciFactParagraphBatchDataset
22 | 
23 | import logging
24 | 
25 | def reset_random_seed(seed):
26 |     np.random.seed(seed)
27 |     random.seed(seed)
28 |     torch.manual_seed(seed)
29 | 
30 | def predict(model, dataset):
31 |     model.eval()
32 |     stance_preds = []
33 |     with torch.no_grad():
34 |         for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
35 |             encoded_dict = encode(tokenizer, batch)
36 |             transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 
37 |                                                            tokenizer.sep_token_id, args.repfile)
38 |             encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
39 |             transformation_indices = [tensor.to(device) for tensor in transformation_indices]
40 |             stance_out, _ = model(encoded_dict, transformation_indices)
41 |             stance_preds.extend(stance_out)
42 | 
43 |     return stance_preds
44 | 
45 | def evaluation(model, dataset, dummy=True):
46 |     model.eval()
47 |     stance_preds = []
48 |     stance_labels = []
49 | 
50 |     with torch.no_grad():
51 |         for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size*5, shuffle=False)):
52 |             encoded_dict = encode(tokenizer, batch)
53 |             transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], 
54 |                                                            tokenizer.sep_token_id, args.repfile)
55 |             encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
56 |             transformation_indices = [tensor.to(device) for tensor in transformation_indices]
57 |             stance_label = batch["stance"].to(device)
58 |             stance_out, loss = \
59 |                 model(encoded_dict, transformation_indices, stance_label = stance_label)
60 |             stance_preds.extend(stance_out)
61 |             stance_labels.extend(stance_label.cpu().numpy().tolist())
62 | 
63 |     stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2]) # micro-average over SUPPORT and CONTRADICT only
64 |     stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
65 |     stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
66 |     return stance_f1, stance_precision, stance_recall
67 | 
68 | 
69 | 
70 | def encode(tokenizer, batch, max_sent_len = 512):
71 |     def truncate(input_ids, max_length, sep_token_id, pad_token_id):
72 |         def longest_first_truncation(sentences, objective): # repeatedly trim one token off the longest sentence until the total fits
73 |             sent_lens = [len(sent) for sent in sentences]
74 |             while np.sum(sent_lens) > objective:
75 |                 max_position = np.argmax(sent_lens)
76 |                 sent_lens[max_position] -= 1
77 |             return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
78 | 
79 |         all_paragraphs = []
80 |         for paragraph in input_ids:
81 |             valid_paragraph = paragraph[paragraph != pad_token_id]
82 |             if valid_paragraph.size(0) <= max_length:
83 |                 all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
84 |             else:
85 |                 sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
86 |                 idx_by_sentence = []
87 |                 prev_idx = 0
88 |                 for idx in sep_token_idx:
89 |                     idx_by_sentence.append(paragraph[prev_idx:idx])
90 |                     prev_idx = idx
91 |                 objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
92 |                 truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
93 |                 truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
94 |                 all_paragraphs.append(truncated_paragraph.unsqueeze(0))
95 | 
96 |         return torch.cat(all_paragraphs, 0)
97 | 
98 |     inputs = zip(batch["claim"], batch["paragraph"])
99 |     encoded_dict = tokenizer.batch_encode_plus(
100 |         inputs,
101 |         pad_to_max_length=True,add_special_tokens=True,
102 |         return_tensors='pt')
103 |     if encoded_dict['input_ids'].size(1) > max_sent_len:
104 |         if 'token_type_ids' in encoded_dict:
105 |             encoded_dict = {
106 |                 "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 
107 |                                       tokenizer.sep_token_id, tokenizer.pad_token_id),
108 |                 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
109 |                 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
110 |             }
111 |         else:
112 |             encoded_dict = {
113 |                 "input_ids": truncate(encoded_dict['input_ids'], max_sent_len, 
114 |                                       tokenizer.sep_token_id, tokenizer.pad_token_id),
115 |                 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
116 |             }
117 | 
118 |     return encoded_dict
119 | 
120 | def sent_rep_indices(input_ids, sep_token_id, model_name):
121 | 
122 |     """
123 |     Compute the [SEP] indices matrix of the BERT output.
124 |     input_ids: (batch_size, paragraph_len)
125 |     batch_indices, indices_by_batch, mask: (batch_size, N_sentence)
126 |     bert_out: (batch_size, paragraph_len,BERT_dim)
127 |     bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, BERT_dim)
128 |     """
129 | 
130 |     sep_tokens = (input_ids == sep_token_id).bool()
131 |     paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
132 |     indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
133 |     sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
134 |     padded_sep_indices = nn.utils.rnn.pad_sequence(sep_indices, batch_first=True, padding_value=-1)
135 |     batch_indices = torch.arange(padded_sep_indices.size(0)).unsqueeze(-1).expand(-1,padded_sep_indices.size(-1))
136 |     mask = (padded_sep_indices>=0).long()
137 | 
138 |     if "roberta" in model_name:
139 |         return batch_indices[:,2:], padded_sep_indices[:,2:], mask[:,2:]
140 |     else:
141 |         return batch_indices[:,1:], padded_sep_indices[:,1:], mask[:,1:]
142 | 
143 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
144 |     """
145 |     Compute the token indices matrix of the BERT output.
146 |     input_ids: (batch_size, paragraph_len)
147 |     batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
148 |     bert_out: (batch_size, paragraph_len,BERT_dim)
149 |     bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
150 |     """
151 |     padding_idx = -1
152 |     sep_tokens = (input_ids == sep_token_id).bool()
153 |     paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
154 |     indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
155 |     sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
156 |     paragraph_lens = []
157 |     all_word_indices = []
158 |     for paragraph in sep_indices:
159 |         if "large" in model_name: # models with an extra leading separator (e.g. roberta-large): drop the first sep index
160 |             paragraph = paragraph[1:]
161 |         word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
162 |         paragraph_lens.append(len(word_indices))
163 |         all_word_indices.extend(word_indices)
164 |     indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
165 |     indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
166 |     indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
167 |     batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
168 |     mask = (indices_by_batch>=0)
169 | 
170 |     return batch_indices.long(), indices_by_batch.long(), mask.long()
171 | 
172 | if __name__ == "__main__":
173 |     argparser = argparse.ArgumentParser(description="Train and evaluate the paragraph-level stance classifier")
174 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
175 |     argparser.add_argument('--corpus_file', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl")
176 |     argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl")
177 |     argparser.add_argument('--pre_trained_model', type=str)
178 | 
179 |     argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl")
180 |     argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for the BERT-like LM")
181 |     argparser.add_argument('--lr', type=float, default=1e-6, help="Learning rate for the task-specific layers")
182 |     argparser.add_argument('--dropout', type=float, default=0, help="Dropout rate")
183 |     argparser.add_argument('--bert_dim', type=int, default=1024, help="Hidden dimension of the BERT-like encoder")
184 |     argparser.add_argument('--epoch', type=int, default=20, help="Number of training epochs")
185 |     argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
186 |     argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_stance_paragraph.model")
187 |     argparser.add_argument('--log_file', type=str, default = "stance_paragraph_roberta_performances.jsonl")
188 |     argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_stance_paragraph.jsonl")
189 |     argparser.add_argument('--update_step', type=int, default=10)
190 |     argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
191 |     argparser.add_argument('--k', type=int, default=0)
192 | 
193 |     logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
194 | 
195 |     reset_random_seed(12345)
196 | 
197 |     args = argparser.parse_args()
198 | 
199 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
200 |     tokenizer = AutoTokenizer.from_pretrained(args.repfile)
201 | 
202 |     if args.train_file:
203 |         train = True
204 |         #assert args.repfile is not None, "Word embedding file required for training."
205 |     else:
206 |         train = False
207 |     if args.test_file:
208 |         test = True
209 |     else:
210 |         test = False
211 | 
212 |     params = vars(args)
213 | 
214 |     for k,v in params.items():
215 |         print(k,v)
216 | 
217 |     if train:
218 |         train_set = SciFactParagraphBatchDataset(args.corpus_file, args.train_file, 
219 |             sep_token = tokenizer.sep_token, k = args.k)
220 |     dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, 
221 |         sep_token = tokenizer.sep_token, k = args.k)
222 | 
223 |     model = JointParagraphClassifier(args.repfile, args.bert_dim, 
224 |                                      args.dropout)#.to(device)
225 | 
226 |     if args.pre_trained_model is not None:
227 |         model.load_state_dict(torch.load(args.pre_trained_model))
228 |         model.reinitialize()
229 |         print("Reinitialized part of the model!")
230 | 
231 |     model = model.to(device)
232 | 
233 |     if train:
234 |         settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
235 |         for module in model.extra_modules:
236 |             settings.append({'params': module.parameters(), 'lr': args.lr})
237 |         optimizer = torch.optim.Adam(settings)
238 |         scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
239 |         model.train()
240 | 
241 |         prev_performance = 0
242 |         for epoch in range(args.epoch):
243 |             tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
244 |             for i, batch in enumerate(tq):
245 |                 encoded_dict = encode(tokenizer, batch)
246 |                 transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
247 |                 encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
248 |                 transformation_indices = [tensor.to(device) for tensor in transformation_indices]
249 |                 stance_label = batch["stance"].to(device)
250 |                 stance_out, loss = model(encoded_dict, transformation_indices, stance_label = stance_label)
251 |                 loss.backward()
252 | 
253 |                 if i % args.update_step == args.update_step - 1:
254 |                     optimizer.step()
255 |                     optimizer.zero_grad()
256 |                     tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}')
257 |             scheduler.step()
258 | 
259 |             # Evaluation
260 |             train_score = evaluation(model, train_set)
261 |             print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score)
262 | 
263 |             dev_score = evaluation(model, dev_set)
264 |             print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
265 | 
266 |             dev_perf = dev_score[0]
267 |             if dev_perf >= prev_performance:
268 |                 torch.save(model.state_dict(), args.checkpoint)
269 |                 best_state_dict = model.state_dict()
270 |                 prev_performance = dev_perf
271 |                 print("New model saved!")
272 |             else:
273 |                 print("Skip saving model.")
274 | 
275 | 
276 |     if test:
277 |         if train:
278 |             del model
279 |             model = JointParagraphClassifier(args.repfile, args.bert_dim, 
280 |                                              args.dropout).to(device)
281 |             model.load_state_dict(best_state_dict)
282 |             print("Testing on the new model.")
283 |         else:
284 |             model.load_state_dict(torch.load(args.checkpoint))
285 |             print("Loaded saved model.")
286 | 
287 |         # Evaluation
288 |         dev_score = evaluation(model, dev_set)
289 |         print(f'Test stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
290 | 
291 |         params["SciFact stance_f1"] = dev_score[0]
292 |         params["SciFact stance_precision"] = dev_score[1]
293 |         params["SciFact stance_recall"] = dev_score[2]
294 | 
295 |     with jsonlines.open(args.log_file, mode='a') as writer:
296 |         writer.write(params)
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import codecs
2 | import numpy
3 | import glob
4 | import re
5 | import numpy as np
6 | from sklearn.metrics import f1_score
7 | 
8 | def flatten(arrayOfArray):
9 |     array = []
10 |     for arr in arrayOfArray:
11 |         try:
12 |             array.extend(arr)
13 |         except TypeError: # arr is a scalar, not an iterable
14 |             array.append(arr)
15 |     return array
16 | 
17 | def read_passages(filename, is_labeled):
18 |     str_seqs = []
19 |     str_seq = []
20 |     label_seqs = []
21 |     label_seq = []
22 |     for line in codecs.open(filename, "r", "utf-8"):
23 |         lnstrp = line.strip()
24 |         if lnstrp == "":
25 |             if len(str_seq) != 0:
26 |                 str_seqs.append(str_seq)
27 |                 str_seq = []
28 |                 label_seqs.append(label_seq)
29 |                 label_seq = []
30 |         else:
31 |             if is_labeled:
32 |                 clause, label = lnstrp.split("\t")
33 |                 label_seq.append(label.strip())
34 |             else:
35 |                 clause = lnstrp
36 |             str_seq.append(clause)
37 |     if len(str_seq) != 0:
38 |         str_seqs.append(str_seq)
39 |         str_seq = []
40 |         label_seqs.append(label_seq)
41 |         label_seq = []
42 |     return str_seqs, label_seqs
43 | 
44 | def from_BIO_ind(BIO_pred, BIO_target, indices):
45 |     table = {} # Make a mapping between the indices of BIO_labels and temporary original label indices
46 |     original_labels = []
47 |     for BIO_label,BIO_index in indices.items():
48 |         if BIO_label[:2] == "I_" or BIO_label[:2] == "B_":
49 |             label = BIO_label[2:]
50 |         else:
51 |             label = BIO_label
52 |         if label in original_labels:
53 |             table[BIO_index] = original_labels.index(label)
54 |         else:
55 |             table[BIO_index] = len(original_labels)
56 |             original_labels.append(label)
57 | 
58 |     original_pred = [table[label] for label in BIO_pred]
59 |     original_target = [table[label] for label in BIO_target]
60 |     return original_pred, original_target
61 | 
62 | def to_BIO(label_seqs):
63 |     new_label_seqs = []
64 |     for label_para in label_seqs:
65 |         new_label_para = []
66 |         prev = ""
67 |         for label in label_para:
68 |             if label!="none": # "none" is O, remain unchanged.
69 |                 if label==prev:
70 |                     new_label = "I_"+label
71 |                 else:
72 |                     new_label = "B_"+label
73 |             else:
74 |                 new_label = label # "none"
75 |             prev = label
76 |             new_label_para.append(new_label)
77 |         new_label_seqs.append(new_label_para)
78 |     return new_label_seqs
79 | 
80 | def from_BIO(label_seqs):
81 |     new_label_seqs = []
82 |     for label_para in label_seqs:
83 |         new_label_para = []
84 |         for label in label_para:
85 |             if label[:2] == "I_" or label[:2] == "B_":
86 |                 new_label = label[2:]
87 |             else:
88 |                 new_label = label
89 |             new_label_para.append(new_label)
90 |         new_label_seqs.append(new_label_para)
91 |     return new_label_seqs
92 | 
93 | def clean_url(word):
94 |     """
95 |     Clean specific data format from social media
96 |     """
97 |     # clean urls
98 |     word = re.sub(r'https?:\/\/.*[\r\n]*', '', word)
99 |     word = re.sub(r'exlink', '', word)
100 |     return word
101 | 
102 | def clean_num(word):
103 |     # check if the word contains a number and no letters
104 |     if any(char.isdigit() for char in word):
105 |         try:
106 |             num = float(word.replace(',', ''))
107 |             return '@'
108 |         except ValueError:
109 |             if not any(char.isalpha() for char in word):
110 |                 return '@'
111 |     return word
112 | 
113 | 
114 | def clean_words(str_seqs):
115 |     processed_seqs = []
116 |     for str_seq in str_seqs:
117 |         processed_clauses = []
118 |         for clause in str_seq:
119 |             filtered = []
120 |             tokens = clause.split()
121 |             for word in tokens:
122 |                 word = clean_url(word)
123 |                 word = clean_num(word)
124 |                 filtered.append(word)
125 |             filtered_clause = " ".join(filtered)
126 |             processed_clauses.append(filtered_clause)
127 |         processed_seqs.append(processed_clauses)
128 |     return processed_seqs
129 | 
130 | def test_f1(test_file,pred_label_seqs):
131 |     def linearize(labels):
132 |         linearized = []
133 |         for paper in labels:
134 |             for label in paper:
135 |                 linearized.append(label)
136 |         return linearized
137 |     _, label_seqs = read_passages(test_file,True)
138 |     true_label = linearize(label_seqs)
139 |     pred_label = linearize(pred_label_seqs)
140 | 
141 |     f1 = f1_score(true_label,pred_label,average="weighted")
142 |     print("F1 score:",f1)
143 |     return f1
144 | 
145 | def postprocess(dataset, raw_flattened_output, raw_flattened_labels, MAX_SEQ_LEN):
146 |     ground_truth_labels = []
147 |     paragraph_lens = []
148 |     for para in dataset.true_pairs:
149 |         paragraph_lens.append(len(para["paragraph"]))
150 |         ground_truth_labels.append(para["label"])
151 | 
152 |     raw_flattened_output = raw_flattened_output.tolist()
153 |     raw_flattened_labels = raw_flattened_labels.tolist()
154 |     batch_i = 0
155 |     predicted_tags = []
156 |     gt_tags = []
157 |     for length in paragraph_lens:
158 |         remaining_len = length
159 |         predict_idx = []
160 |         gt_tag = []
161 |         while remaining_len > MAX_SEQ_LEN: # long paragraphs were split into MAX_SEQ_LEN chunks; stitch them back together
162 |             this_batch = raw_flattened_output[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
163 |             this_batch_label = raw_flattened_labels[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
164 |             predict_idx.extend(this_batch)
165 |             gt_tag.extend(this_batch_label)
166 |             batch_i += 1
167 |             remaining_len -= MAX_SEQ_LEN
168 | 
169 |         this_batch = raw_flattened_output[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
170 |         this_batch_label = raw_flattened_labels[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
171 |         predict_idx.extend(this_batch[:remaining_len])
172 |         gt_tag.extend(this_batch_label[:remaining_len])
173 |         predict_tag = [dataset.rev_label_ind[idx] for idx in predict_idx]
174 |         gt_tag = [dataset.rev_label_ind[idx] for idx in gt_tag]
175 |         batch_i += 1
176 |         predicted_tags.append(predict_tag)
177 |         gt_tags.append(gt_tag)
178 | 
179 | 
180 |     predicted_tags = from_BIO(predicted_tags)
181 |     final_gt = from_BIO(gt_tags)
182 | 
183 |     return predicted_tags, final_gt
184 | 
185 | def stance_postprocess(dataset, raw_output, raw_labels, MAX_SEQ_LEN):
186 | 
187 |     def combine(candidates):
188 |         assert(len(candidates)>0)
189 |         types = set(candidates)
190 |         if len(types) == 1:
191 |             return list(types)[0]
192 |         elif 2 in types: # when chunks disagree, CONTRADICT (2) wins over SUPPORT (1)
193 |             return 2
194 |         else:
195 |             return 1
196 | 
197 | 
198 |     ground_truth_labels = []
199 |     paragraph_lens = []
200 |     for para in dataset.true_pairs:
201 |         paragraph_lens.append(len(para["paragraph"]))
202 |         ground_truth_labels.append(para["label"])
203 | 
204 |     raw_output = raw_output.tolist()
205 |     raw_labels = raw_labels.tolist()
206 |     batch_i = 0
207 |     predicted_tags = []
208 |     gt_tags = []
209 |     for length in paragraph_lens:
210 |         remaining_len = length
211 |         predict_idx = []
212 |         gt_tag = []
213 |         while remaining_len > MAX_SEQ_LEN:
214 |             this_batch = raw_output[batch_i]
215 |             this_batch_label = raw_labels[batch_i]
216 |             predict_idx.append(this_batch)
217 |             gt_tag.append(this_batch_label)
218 |             batch_i += 1
219 |             remaining_len -= MAX_SEQ_LEN
220 | 
221 |         this_batch = raw_output[batch_i]
222 |         this_batch_label = raw_labels[batch_i]
223 |         predict_idx.append(this_batch)
224 |         gt_tag.append(this_batch_label)
225 |         predict_tag = combine(predict_idx)
226 |         gt_tag = combine(gt_tag)
227 |         batch_i += 1
228 |         predicted_tags.append(predict_tag)
229 |         gt_tags.append(gt_tag)
230 | 
231 |     return predicted_tags, gt_tags
232 | 
233 | def rationale2json(true_pairs, predictions, excluded_pairs = None):
234 |     claim_ids = []
235 |     claims = {}
236 |     assert(len(true_pairs) == len(predictions))
237 |     for pair, prediction in zip(true_pairs, predictions):
238 |         claim_id = pair["claim_id"]
239 |         claim_ids.append(claim_id)
240 | 
241 |         predicted_sentences = []
242 |         for i, pred in enumerate(prediction):
243 |             if pred == "rationale" or pred == 1:
244 |                 predicted_sentences.append(i)
245 | 
246 |         this_claim = claims.get(claim_id, {"claim_id": claim_id, "evidence":{}})
247 |         #if len(predicted_sentences) > 0:
248 |         this_claim["evidence"][pair["doc_id"]] = predicted_sentences
249 |         claims[claim_id] = this_claim
250 |     if excluded_pairs is not None:
251 |         for pair in excluded_pairs:
252 |             claims[pair["claim_id"]] = {"claim_id": pair["claim_id"], "evidence":{}}
253 |             claim_ids.append(pair["claim_id"])
254 |     return [claims[claim_id] for claim_id in sorted(list(set(claim_ids)))]
255 | 
256 | def stance2json(true_pairs, predictions, excluded_pairs = None):
257 |     claim_ids = []
258 |     claims = {}
259 |     idx2stance = ["NOT_ENOUGH_INFO", "SUPPORT", "CONTRADICT"]
260 |     assert(len(true_pairs) == len(predictions))
261 |     for pair, prediction in zip(true_pairs, predictions):
262 |         claim_id = pair["claim_id"]
263 |         claim_ids.append(claim_id)
264 | 
265 |         this_claim = claims.get(claim_id, {"claim_id": claim_id, "labels":{}})
266 |         this_claim["labels"][pair["doc_id"]] = {"label": idx2stance[prediction], 'confidence': 1}
267 |         claims[claim_id] = this_claim
268 |     if excluded_pairs is not None:
269 |         for pair in excluded_pairs:
270 |             claims[pair["claim_id"]] = {"claim_id": pair["claim_id"], "labels":{}}
271 |             claim_ids.append(pair["claim_id"])
272 |     return [claims[claim_id] for claim_id in sorted(list(set(claim_ids)))]
273 | 
274 | def merge_json(rationale_jsons, stance_jsons):
275 |     stance_json_dict = {str(stance_json["claim_id"]): stance_json for stance_json in stance_jsons}
276 |     jsons = []
277 |     for rationale_json in rationale_jsons:
278 |         id = str(rationale_json["claim_id"])
279 |         result = {}
280 |         if id in stance_json_dict:
281 |             for k, v in rationale_json["evidence"].items():
282 |                 if len(v) > 0 and stance_json_dict[id]["labels"][int(k)]["label"] != "NOT_ENOUGH_INFO":
283 |                     result[k] = {
284 |                         "sentences": v,
285 |                         "label": stance_json_dict[id]["labels"][int(k)]["label"]
286 |                     }
287 |         jsons.append({"id":int(id), "evidence": result})
288 |     return jsons
289 | 
290 | def arg2param(args):
291 |     params = vars(args)
292 |     params["MAX_SEQ_LEN"]=params["CHUNK_SIZE"]*params["CHUNK_PER_SEQ"]
293 |     params["MINIBATCH_SIZE"] = params["CHUNK_PER_SEQ"]
294 |     params["SENTENCE_BATCH_SIZE"]=params["CHUNK_SIZE"]
295 |     params["CHUNK_PER_STEP"]=params["PARAGRAPH_PER_STEP"]*params["CHUNK_PER_SEQ"]
296 | 
297 |     return params
298 | 
--------------------------------------------------------------------------------
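The three JSON helpers at the end of util.py compose into the final SciFact-style predictions: rationale2json collects per-sentence rationale decisions for each (claim, abstract) pair, stance2json maps stance indices through idx2stance = ["NOT_ENOUGH_INFO", "SUPPORT", "CONTRADICT"], and merge_json keeps only abstracts with at least one rationale sentence and a non-NEI stance label. A minimal sketch with hypothetical toy inputs (the helpers only read the claim_id and doc_id keys of each pair):

    from util import rationale2json, stance2json, merge_json

    # One (claim, abstract) pair: claim 1 scored against doc 42.
    true_pairs = [{"claim_id": 1, "doc_id": 42}]
    rationale_preds = [[0, 1, 1]]  # per-sentence 0/1 rationale decisions
    stance_preds = [1]             # index into idx2stance, i.e. "SUPPORT"

    rationale_jsons = rationale2json(true_pairs, rationale_preds)
    stance_jsons = stance2json(true_pairs, stance_preds)
    print(merge_json(rationale_jsons, stance_jsons))
    # [{'id': 1, 'evidence': {42: {'sentences': [1, 2], 'label': 'SUPPORT'}}}]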