├── .gitignore
├── ComputeBioSentVecAbstractEmbedding.py
├── FEVER_joint_paragraph_dynamic.py
├── FEVER_joint_paragraph_kgat.py
├── FEVER_stance_paragraph.py
├── FEVER_stance_paragraph_kgat.py
├── README.md
├── SentVecAbstractRetriaval.py
├── TFIDFabstractRetrieval.py
├── dataset.py
├── domain_adaptation_joint_paragraph_dynamic.py
├── domain_adaptation_joint_paragraph_fine_tune.py
├── domain_adaptation_joint_paragraph_kgat.py
├── domain_adaptation_joint_paragraph_kgat_prediction.py
├── domain_adaptation_joint_paragraph_predict.py
├── lib
│   ├── data.py
│   └── metrics.py
├── paragraph_model_dynamic.py
├── paragraph_model_kgat.py
├── requirements.txt
├── scifact_joint_paragraph_dynamic.py
├── scifact_joint_paragraph_dynamic_prediction.py
├── scifact_joint_paragraph_kgat.py
├── scifact_joint_paragraph_kgat_prediction.py
├── scifact_rationale_paragraph.py
├── scifact_stance_paragraph.py
└── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/ComputeBioSentVecAbstractEmbedding.py:
--------------------------------------------------------------------------------
1 | import sent2vec
2 | from nltk import word_tokenize
3 | from nltk.corpus import stopwords
4 | from string import punctuation
5 | from scipy.spatial import distance
6 | import glob
7 | import pickle
8 | import json
9 |
10 | import argparse
11 |
12 | def preprocess_sentence(text):
13 | text = text.replace('/', ' / ')
14 | text = text.replace('.-', ' .- ')
15 | text = text.replace('.', ' . ')
16 | text = text.replace('\'', ' \' ')
17 | text = text.lower()
18 |
19 | tokens = [token for token in word_tokenize(text) if token not in punctuation and token not in stop_words]
20 |
21 | return ' '.join(tokens)
22 |
23 | if __name__ == "__main__":
24 | argparser = argparse.ArgumentParser()
25 | argparser.add_argument('--claim_file', type=str)
26 | argparser.add_argument('--corpus_file', type=str)
27 | argparser.add_argument('--sentvec_path', type=str)
28 | argparser.add_argument('--corpus_embedding_pickle', type=str, default="corpus_paragraph_biosentvec.pkl")
29 | argparser.add_argument('--claim_embedding_pickle', type=str, default="claim_biosentvec.pkl")
30 |
31 | args = argparser.parse_args()
32 | claim_file = args.claim_file
33 | corpus_file = args.corpus_file
34 |
35 | corpus = {}
36 | with open(corpus_file) as f:
37 | for line in f:
38 | abstract = json.loads(line)
39 | corpus[str(abstract["doc_id"])] = abstract
40 |
41 | claims = []
42 | with open(claim_file) as f:
43 | for line in f:
44 | claim = json.loads(line)
45 | claims.append(claim)
46 |
47 | model_path = args.sentvec_path
48 | model = sent2vec.Sent2vecModel()
49 |     try:
50 |         model.load_model(model_path)
51 |         print('model successfully loaded')
52 |     except Exception as e:
53 |         print(e)
54 |
55 | stop_words = set(stopwords.words('english'))
56 |
57 | # By paragraph embedding
58 | corpus_embeddings = {}
59 | for k, v in corpus.items():
60 | original_sentences = [v['title']] + v['abstract']
61 | processed_paragraph = " ".join([preprocess_sentence(sentence) for sentence in original_sentences])
62 | sentence_vector = model.embed_sentence(processed_paragraph)
63 | corpus_embeddings[k] = sentence_vector
64 |
65 | with open(args.corpus_embedding_pickle,"wb") as f:
66 | pickle.dump(corpus_embeddings,f)
67 |
68 | claim_embeddings = {}
69 | for claim in claims:
70 | processed_sentence = preprocess_sentence(claim['claim'])
71 | sentence_vector = model.embed_sentence(processed_sentence)
72 | claim_embeddings[claim["id"]] = sentence_vector
73 |
74 | with open(args.claim_embedding_pickle,"wb") as f:
75 | pickle.dump(claim_embeddings,f)
--------------------------------------------------------------------------------
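A minimal sketch of how the two pickles written above could be consumed downstream, assuming cosine similarity as the ranking score; the file names are the defaults above, and top_k_abstracts is a hypothetical helper (the repository's actual ranking lives in SentVecAbstractRetriaval.py):

import pickle
from scipy.spatial import distance

with open("corpus_paragraph_biosentvec.pkl", "rb") as f:
    corpus_embeddings = pickle.load(f)   # {doc_id: (1, dim) array}
with open("claim_biosentvec.pkl", "rb") as f:
    claim_embeddings = pickle.load(f)    # {claim_id: (1, dim) array}

def top_k_abstracts(claim_id, k=3):
    # Rank abstracts by cosine similarity to the claim embedding.
    claim_vec = claim_embeddings[claim_id][0]
    scores = {doc_id: 1 - distance.cosine(claim_vec, vec[0])
              for doc_id, vec in corpus_embeddings.items()}
    return sorted(scores, key=scores.get, reverse=True)[:k]

--------------------------------------------------------------------------------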
/FEVER_joint_paragraph_dynamic.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import torch
5 | import jsonlines
6 | import os
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import math
20 | import random
21 | import numpy as np
22 |
23 | from tqdm import tqdm
24 | from util import arg2param, flatten, stance2json, rationale2json
25 | from paragraph_model_dynamic import JointParagraphClassifier
26 | from dataset import FEVERParagraphBatchDataset
27 |
28 | import logging
29 |
30 | def schedule_sample_p(epoch, total):
31 | return np.sin(0.5* np.pi* epoch / (total-1))
32 |
33 | def reset_random_seed(seed):
34 | np.random.seed(seed)
35 | random.seed(seed)
36 | torch.manual_seed(seed)
37 |
38 | def batch_rationale_label(labels, padding_idx = 2):
39 | max_sent_len = max([len(label) for label in labels])
40 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
41 | label_list = []
42 | for i, label in enumerate(labels):
43 | for j, evid in enumerate(label):
44 | label_matrix[i,j] = int(evid)
45 | label_list.append([int(evid) for evid in label])
46 | return label_matrix.long(), label_list
47 |
48 | def evaluation(model, dataset):
49 | model.eval()
50 | rationale_predictions = []
51 | rationale_labels = []
52 | stance_preds = []
53 | stance_labels = []
54 |
55 | def remove_dummy(rationale_out):
56 | return [out[1:] for out in rationale_out]
57 |
58 | with torch.no_grad():
59 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
60 | encoded_dict = encode(tokenizer, batch)
61 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
62 | tokenizer.sep_token_id, args.repfile)
63 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
64 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
65 | stance_label = batch["stance"].to(device)
66 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
67 | if padded_rationale_label.size(1) == transformation_indices[-1].size(1):
68 | rationale_out, stance_out, rationale_loss, stance_loss = \
69 | model(encoded_dict, transformation_indices, stance_label = stance_label,
70 | rationale_label = padded_rationale_label.to(device))
71 | stance_preds.extend(stance_out)
72 | stance_labels.extend(stance_label.cpu().numpy().tolist())
73 |
74 | rationale_predictions.extend(remove_dummy(rationale_out))
75 | rationale_labels.extend(remove_dummy(rationale_label))
76 |
77 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
78 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
79 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
80 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions))
81 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions))
82 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions))
83 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall
84 |
85 |
86 |
87 | def encode(tokenizer, batch, max_sent_len = 512):
88 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
89 | def longest_first_truncation(sentences, objective):
90 | sent_lens = [len(sent) for sent in sentences]
91 | while np.sum(sent_lens) > objective:
92 | max_position = np.argmax(sent_lens)
93 | sent_lens[max_position] -= 1
94 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
95 |
96 | all_paragraphs = []
97 | for paragraph in input_ids:
98 | valid_paragraph = paragraph[paragraph != pad_token_id]
99 | if valid_paragraph.size(0) <= max_length:
100 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
101 | else:
102 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
103 | idx_by_sentence = []
104 | prev_idx = 0
105 | for idx in sep_token_idx:
106 | idx_by_sentence.append(paragraph[prev_idx:idx])
107 | prev_idx = idx
108 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
109 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
110 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
111 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
112 |
113 | return torch.cat(all_paragraphs, 0)
114 |
115 | inputs = zip(batch["claim"], batch["paragraph"])
116 | encoded_dict = tokenizer.batch_encode_plus(
117 | inputs,
118 | pad_to_max_length=True,add_special_tokens=True,
119 | return_tensors='pt')
120 | if encoded_dict['input_ids'].size(1) > max_sent_len:
121 | if 'token_type_ids' in encoded_dict:
122 | encoded_dict = {
123 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
124 | tokenizer.sep_token_id, tokenizer.pad_token_id),
125 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
126 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
127 | }
128 | else:
129 | encoded_dict = {
130 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
131 | tokenizer.sep_token_id, tokenizer.pad_token_id),
132 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
133 | }
134 |
135 | return encoded_dict
136 |
137 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
138 | """
139 | Compute the token indices matrix of the BERT output.
140 | input_ids: (batch_size, paragraph_len)
141 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
142 | bert_out: (batch_size, paragraph_len,BERT_dim)
143 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
144 | """
145 | padding_idx = -1
146 | sep_tokens = (input_ids == sep_token_id).bool()
147 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
148 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
149 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
150 | paragraph_lens = []
151 | all_word_indices = []
152 | for paragraph in sep_indices:
153 | if "roberta" in model_name:
154 | paragraph = paragraph[1:]
155 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
156 | paragraph_lens.append(len(word_indices))
157 | all_word_indices.extend(word_indices)
158 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
159 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
160 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
161 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
162 | mask = (indices_by_batch>=0)
163 |
164 | return batch_indices.long(), indices_by_batch.long(), mask.long()
165 |
166 | if __name__ == "__main__":
167 |     argparser = argparse.ArgumentParser(description="Train and evaluate the joint rationale/stance paragraph model on FEVER")
168 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
169 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_train_retrieved.jsonl")
170 | argparser.add_argument('--pre_trained_model', type=str)
171 | #argparser.add_argument('--train_file', type=str)
172 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_dev_retrieved.jsonl")
173 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM")
174 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate")
175 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
176 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
177 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch")
178 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
179 | argparser.add_argument('--loss_ratio', type=float, default=5)
180 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_joint_paragraph_dynamic")
181 | argparser.add_argument('--log_file', type=str, default = "fever_joint_paragraph_performances.jsonl")
182 | argparser.add_argument('--update_step', type=int, default=10)
183 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
184 | argparser.add_argument('--k', type=int, default=0)
185 | argparser.add_argument('--evaluation_step', type=int, default=50000)
186 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
187 |
188 | reset_random_seed(12345)
189 |
190 | args = argparser.parse_args()
191 |
192 |     # Redirect stdout to the log file (keep the handle open for later prints).
193 |     sys.stdout = open(args.checkpoint+".log", 'w')
194 |
195 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
196 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
197 |
198 | if args.train_file:
199 | train = True
200 | #assert args.repfile is not None, "Word embedding file required for training."
201 | else:
202 | train = False
203 | if args.test_file:
204 | test = True
205 | else:
206 | test = False
207 |
208 | params = vars(args)
209 |
210 | for k,v in params.items():
211 | print(k,v)
212 |
213 | if train:
214 | train_set = FEVERParagraphBatchDataset(args.train_file,
215 | sep_token = tokenizer.sep_token, k=args.k)
216 | dev_set = FEVERParagraphBatchDataset(args.test_file,
217 | sep_token = tokenizer.sep_token, k=args.k)
218 |
219 | print("Loaded dataset!")
220 |
221 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
222 | args.dropout).to(device)
223 |
224 | if args.pre_trained_model is not None:
225 | model.load_state_dict(torch.load(args.pre_trained_model))
226 | print("Loaded pre-trained model.")
227 |
228 | if train:
229 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
230 | for module in model.extra_modules:
231 | settings.append({'params': module.parameters(), 'lr': args.lr})
232 | optimizer = torch.optim.Adam(settings)
233 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
234 |
235 | #if torch.cuda.device_count() > 1:
236 | # print("Let's use", torch.cuda.device_count(), "GPUs!")
237 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
238 | # model = nn.DataParallel(model)
239 |
240 | model.train()
241 |
242 | for epoch in range(args.epoch):
243 | error_count = 0
244 | sample_p = schedule_sample_p(epoch, args.epoch)
245 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
246 | for i, batch in enumerate(tq):
247 | encoded_dict = encode(tokenizer, batch)
248 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
249 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
250 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
251 | stance_label = batch["stance"].to(device)
252 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
253 | if padded_rationale_label.size(1) == transformation_indices[-1].size(1): # Skip some rare cases with inconsistent input size.
254 | rationale_out, stance_out, rationale_loss, stance_loss = \
255 | model(encoded_dict, transformation_indices, stance_label = stance_label,
256 | rationale_label = padded_rationale_label.to(device), sample_p = sample_p)
257 | rationale_loss *= args.loss_ratio
258 | loss = rationale_loss + stance_loss
259 | loss.sum().backward()
260 | else:
261 | error_count += 1
262 |
263 | if i % args.update_step == args.update_step - 1:
264 | optimizer.step()
265 | optimizer.zero_grad()
266 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}')
267 |
268 |
269 | if i % args.evaluation_step == args.evaluation_step-1:
270 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model")
271 |
272 | # Evaluation
273 | subset_train = Subset(train_set, range(0, 1000))
274 | train_score = evaluation(model, subset_train)
275 | print(f'Epoch {epoch}, step {i}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
276 |
277 | subset_dev = Subset(dev_set, range(0, 1000))
278 | dev_score = evaluation(model, subset_dev)
279 | print(f'Epoch {epoch}, step {i}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
280 | scheduler.step()
281 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model")
282 |         print(error_count, "batches skipped due to input-size mismatch.")
283 |
284 | # Evaluation
285 | subset_train = Subset(train_set, range(0, 10000))
286 | train_score = evaluation(model, subset_train)
287 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
288 |
289 | subset_dev = Subset(dev_set, range(0, 10000))
290 | dev_score = evaluation(model, subset_dev)
291 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
292 |
293 |
294 |
295 | if test:
296 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
297 | args.dropout).to(device)
298 | model.load_state_dict(torch.load(args.checkpoint))
299 |
300 |
301 | # Evaluation
302 | subset_dev = Subset(dev_set, range(0, 10000))
303 | dev_score = evaluation(model, subset_dev)
304 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
305 |
306 | if train:
307 | params["stance_f1"] = dev_score[0]
308 | params["stance_precision"] = dev_score[1]
309 | params["stance_recall"] = dev_score[2]
310 | params["rationale_f1"] = dev_score[3]
311 | params["rationale_precision"] = dev_score[4]
312 |         params["rationale_recall"] = dev_score[5]
313 |
314 | with jsonlines.open(args.log_file, mode='a') as writer:
315 | writer.write(params)
--------------------------------------------------------------------------------
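The token_idx_by_sentence docstring above promises that bert_out[batch_indices, indices_by_batch, :] has shape (batch_size, N_sentence, N_token, BERT_dim). A toy sketch of that gather with made-up shapes and indices (not real tokenizer output):

import torch

batch_size, paragraph_len, bert_dim = 2, 8, 4
bert_out = torch.randn(batch_size, paragraph_len, bert_dim)

# indices_by_batch[b, s, t] = position of token t of sentence s in paragraph b;
# -1 marks padding, matching padding_idx in token_idx_by_sentence.
indices_by_batch = torch.tensor([[[1, 2, 3], [4, 5, -1]],
                                 [[1, 2, -1], [-1, -1, -1]]])
batch_indices = torch.arange(batch_size).view(-1, 1, 1).expand_as(indices_by_batch)
mask = (indices_by_batch >= 0).long()

# Index -1 silently selects the last token, so padded slots must be zeroed via the mask.
gathered = bert_out[batch_indices, indices_by_batch, :] * mask.unsqueeze(-1)
print(gathered.shape)  # torch.Size([2, 2, 3, 4]) = (batch, N_sentence, N_token, dim)

--------------------------------------------------------------------------------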
/FEVER_joint_paragraph_kgat.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import torch
5 | import jsonlines
6 | import os
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import math
20 | import random
21 | import numpy as np
22 |
23 | from tqdm import tqdm
24 | from util import arg2param, flatten, stance2json, rationale2json
25 | from paragraph_model_kgat import JointParagraphKGATClassifier
26 | from dataset import FEVERParagraphBatchDataset
27 |
28 | import logging
29 |
30 | def schedule_sample_p(epoch, total):
31 | return np.sin(0.5* np.pi* epoch / (total-1))
32 |
33 | def reset_random_seed(seed):
34 | np.random.seed(seed)
35 | random.seed(seed)
36 | torch.manual_seed(seed)
37 |
38 | def batch_rationale_label(labels, padding_idx = 2):
39 | max_sent_len = max([len(label) for label in labels])
40 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
41 | label_list = []
42 | for i, label in enumerate(labels):
43 | for j, evid in enumerate(label):
44 | label_matrix[i,j] = int(evid)
45 | label_list.append([int(evid) for evid in label])
46 | return label_matrix.long(), label_list
47 |
48 | def evaluation(model, dataset):
49 | model.eval()
50 | rationale_predictions = []
51 | rationale_labels = []
52 | stance_preds = []
53 | stance_labels = []
54 |
55 | def remove_dummy(rationale_out):
56 | return [out[1:] for out in rationale_out]
57 |
58 | with torch.no_grad():
59 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
60 | encoded_dict = encode(tokenizer, batch)
61 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
62 | tokenizer.sep_token_id, args.repfile)
63 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
64 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
65 | stance_label = batch["stance"].to(device)
66 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
67 | rationale_out, stance_out, rationale_loss, stance_loss = \
68 | model(encoded_dict, transformation_indices, stance_label = stance_label,
69 | rationale_label = padded_rationale_label.to(device))
70 | stance_preds.extend(stance_out)
71 | stance_labels.extend(stance_label.cpu().numpy().tolist())
72 |
73 | rationale_predictions.extend(remove_dummy(rationale_out))
74 | rationale_labels.extend(remove_dummy(rationale_label))
75 |
76 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
77 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
78 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
79 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions))
80 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions))
81 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions))
82 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall
83 |
84 |
85 |
86 | def encode(tokenizer, batch, max_sent_len = 512):
87 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
88 | def longest_first_truncation(sentences, objective):
89 | sent_lens = [len(sent) for sent in sentences]
90 | while np.sum(sent_lens) > objective:
91 | max_position = np.argmax(sent_lens)
92 | sent_lens[max_position] -= 1
93 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
94 |
95 | all_paragraphs = []
96 | for paragraph in input_ids:
97 | valid_paragraph = paragraph[paragraph != pad_token_id]
98 | if valid_paragraph.size(0) <= max_length:
99 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
100 | else:
101 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
102 | idx_by_sentence = []
103 | prev_idx = 0
104 | for idx in sep_token_idx:
105 | idx_by_sentence.append(paragraph[prev_idx:idx])
106 | prev_idx = idx
107 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
108 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
109 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
110 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
111 |
112 | return torch.cat(all_paragraphs, 0)
113 |
114 | inputs = zip(batch["claim"], batch["paragraph"])
115 | encoded_dict = tokenizer.batch_encode_plus(
116 | inputs,
117 | pad_to_max_length=True,add_special_tokens=True,
118 | return_tensors='pt')
119 | if encoded_dict['input_ids'].size(1) > max_sent_len:
120 | if 'token_type_ids' in encoded_dict:
121 | encoded_dict = {
122 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
123 | tokenizer.sep_token_id, tokenizer.pad_token_id),
124 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
125 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
126 | }
127 | else:
128 | encoded_dict = {
129 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
130 | tokenizer.sep_token_id, tokenizer.pad_token_id),
131 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
132 | }
133 |
134 | return encoded_dict
135 |
136 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
137 | """
138 | Advanced indexing: Compute the token indices matrix of the BERT output.
139 | input_ids: (batch_size, paragraph_len)
140 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
141 | bert_out: (batch_size, paragraph_len,BERT_dim)
142 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
143 | """
144 | padding_idx = -1
145 | sep_tokens = (input_ids == sep_token_id).bool()
146 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph
147 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,3,....,511 for each sentence
148 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph
149 | paragraph_lens = []
150 | all_word_indices = []
151 | for paragraph in sep_indices:
152 | # claim sentence: [CLS] token1 token2 ... tokenk
153 | claim_word_indices = torch.arange(0, paragraph[0])
154 |         if "roberta" in model_name: # HuggingFace RoBERTa separates segments with two </s> tokens, so drop the first sep index.
155 | paragraph = paragraph[1:]
156 | # each sentence: [SEP] token1 token2 ... tokenk, the last [SEP] in the paragraph is ditched.
157 | sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)]
158 |
159 | # KGAT requires claim sentence, so add it back.
160 | word_indices = [claim_word_indices] + sentence_word_indices
161 |
162 | paragraph_lens.append(len(word_indices))
163 | all_word_indices.extend(word_indices)
164 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
165 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
166 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
167 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
168 | mask = (indices_by_batch>=0)
169 |
170 | return batch_indices.long(), indices_by_batch.long(), mask.long()
171 |
172 | if __name__ == "__main__":
173 |     argparser = argparse.ArgumentParser(description="Train and evaluate the KGAT joint rationale/stance paragraph model on FEVER")
174 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
175 | argparser.add_argument('--train_file', type=str, default="/home/xxl190027/scifact_data/fever_train_retrieved_15.jsonl")
176 | argparser.add_argument('--pre_trained_model', type=str)
177 | #argparser.add_argument('--train_file', type=str)
178 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/fever_dev_retrieved_15.jsonl")
179 | argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for BERT-like LM")
180 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate")
181 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
182 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
183 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch")
184 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
185 | argparser.add_argument('--loss_ratio', type=float, default=1)
186 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_joint_paragraph_kgat")
187 | argparser.add_argument('--log_file', type=str, default = "fever_joint_paragraph_performances.jsonl")
188 | argparser.add_argument('--update_step', type=int, default=10)
189 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
190 | argparser.add_argument('--k', type=int, default=0)
191 | argparser.add_argument('--kernel', type=int, default=6)
192 | argparser.add_argument('--evaluation_step', type=int, default=50000)
193 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
194 |
195 | reset_random_seed(12345)
196 |
197 | args = argparser.parse_args()
198 |
199 |     # Redirect stdout to the log file (keep the handle open for later prints).
200 |     sys.stdout = open(args.checkpoint+".log", 'w')
201 |
202 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
203 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
204 |
205 | if args.train_file:
206 | train = True
207 | #assert args.repfile is not None, "Word embedding file required for training."
208 | else:
209 | train = False
210 | if args.test_file:
211 | test = True
212 | else:
213 | test = False
214 |
215 | params = vars(args)
216 |
217 | for k,v in params.items():
218 | print(k,v)
219 |
220 | if train:
221 | train_set = FEVERParagraphBatchDataset(args.train_file,
222 | sep_token = tokenizer.sep_token, k=args.k)
223 | dev_set = FEVERParagraphBatchDataset(args.test_file,
224 | sep_token = tokenizer.sep_token, k=args.k)
225 |
226 | print("Loaded dataset!")
227 |
228 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim,
229 | args.dropout, kernel = args.kernel).to(device)
230 |
231 | if args.pre_trained_model is not None:
232 | model.load_state_dict(torch.load(args.pre_trained_model))
233 | print("Loaded pre-trained model.")
234 |
235 | if train:
236 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
237 | for module in model.extra_modules:
238 | settings.append({'params': module.parameters(), 'lr': args.lr})
239 | optimizer = torch.optim.Adam(settings)
240 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
241 |
242 | #if torch.cuda.device_count() > 1:
243 | # print("Let's use", torch.cuda.device_count(), "GPUs!")
244 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
245 | # model = nn.DataParallel(model)
246 |
247 | model.train()
248 |
249 | for epoch in range(args.epoch):
250 | sample_p = schedule_sample_p(epoch, args.epoch)
251 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
252 | for i, batch in enumerate(tq):
253 | encoded_dict = encode(tokenizer, batch)
254 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
255 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
256 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
257 | stance_label = batch["stance"].to(device)
258 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
259 |
260 | rationale_out, stance_out, rationale_loss, stance_loss = \
261 | model(encoded_dict, transformation_indices, stance_label = stance_label,
262 | rationale_label = padded_rationale_label.to(device), sample_p = sample_p)
263 | rationale_loss *= args.loss_ratio
264 | loss = rationale_loss + stance_loss
265 | loss.sum().backward()
266 |
267 | if i % args.update_step == args.update_step - 1:
268 | optimizer.step()
269 | optimizer.zero_grad()
270 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}')
271 |
272 |
273 | if i % args.evaluation_step == args.evaluation_step-1:
274 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model")
275 |
276 | subset_train = Subset(train_set, range(0, 1000))
277 | train_score = evaluation(model, subset_train)
278 | print(f'Epoch {epoch} Step {i}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
279 |
280 | subset_dev = Subset(dev_set, range(0, 1000))
281 | dev_score = evaluation(model, subset_dev)
282 | print(f'Epoch {epoch} Step {i}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
283 | scheduler.step()
284 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model")
285 |
286 | # Evaluation
287 | subset_train = Subset(train_set, range(0, 1000))
288 | train_score = evaluation(model, subset_train)
289 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
290 |
291 | subset_dev = Subset(dev_set, range(0, 1000))
292 | dev_score = evaluation(model, subset_dev)
293 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
294 |
295 |
296 |
297 | if test:
298 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim,
299 | args.dropout, kernel = args.kernel).to(device)
300 | model.load_state_dict(torch.load(args.checkpoint))
301 |
302 |
303 | # Evaluation
304 | reset_random_seed(12345)
305 | subset_dev = Subset(dev_set, range(0, 10000))
306 | dev_score = evaluation(model, subset_dev)
307 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
308 |
309 | if train:
310 | params["stance_f1"] = dev_score[0]
311 | params["stance_precision"] = dev_score[1]
312 | params["stance_recall"] = dev_score[2]
313 | params["rationale_f1"] = dev_score[3]
314 | params["rationale_precision"] = dev_score[4]
315 |         params["rationale_recall"] = dev_score[5]
316 |
317 | with jsonlines.open(args.log_file, mode='a') as writer:
318 | writer.write(params)
--------------------------------------------------------------------------------
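Both joint trainers ramp sample_p with schedule_sample_p, presumably the probability of feeding the model its own rationale predictions rather than gold labels (scheduled sampling). A quick sketch of the ramp under the default --epoch 10:

import numpy as np

def schedule_sample_p(epoch, total):
    # Same helper as above: a quarter-sine ramp from 0 (first epoch) to 1 (last).
    return np.sin(0.5 * np.pi * epoch / (total - 1))

print([round(float(schedule_sample_p(e, 10)), 3) for e in range(10)])
# [0.0, 0.174, 0.342, 0.5, 0.643, 0.766, 0.866, 0.94, 0.985, 1.0]

--------------------------------------------------------------------------------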
/FEVER_stance_paragraph.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import torch
5 | import jsonlines
6 | import os
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import random
20 | import numpy as np
21 |
22 | from tqdm import tqdm
23 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
24 | from paragraph_model_dynamic import StanceParagraphClassifier as JointParagraphClassifier
25 | from dataset import FEVERStanceDataset as FEVERParagraphBatchDataset
26 |
27 | import logging
28 |
29 | def reset_random_seed(seed):
30 | np.random.seed(seed)
31 | random.seed(seed)
32 | torch.manual_seed(seed)
33 |
34 | def evaluation(model, dataset):
35 | model.eval()
36 | stance_preds = []
37 | stance_labels = []
38 |
39 | with torch.no_grad():
40 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
41 | encoded_dict = encode(tokenizer, batch)
42 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
43 | tokenizer.sep_token_id, args.repfile)
44 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
45 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
46 | stance_label = batch["stance"].to(device)
47 | stance_out, stance_loss = \
48 | model(encoded_dict, transformation_indices, stance_label = stance_label)
49 | stance_preds.extend(stance_out)
50 | stance_labels.extend(stance_label.cpu().numpy().tolist())
51 |
52 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
53 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
54 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
55 | return stance_f1, stance_precision, stance_recall
56 |
57 |
58 | def encode(tokenizer, batch, max_sent_len = 512):
59 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
60 | def longest_first_truncation(sentences, objective):
61 | sent_lens = [len(sent) for sent in sentences]
62 | while np.sum(sent_lens) > objective:
63 | max_position = np.argmax(sent_lens)
64 | sent_lens[max_position] -= 1
65 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
66 |
67 | all_paragraphs = []
68 | for paragraph in input_ids:
69 | valid_paragraph = paragraph[paragraph != pad_token_id]
70 | if valid_paragraph.size(0) <= max_length:
71 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
72 | else:
73 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
74 | idx_by_sentence = []
75 | prev_idx = 0
76 | for idx in sep_token_idx:
77 | idx_by_sentence.append(paragraph[prev_idx:idx])
78 | prev_idx = idx
79 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
80 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
81 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
82 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
83 |
84 | return torch.cat(all_paragraphs, 0)
85 |
86 | inputs = zip(batch["claim"], batch["paragraph"])
87 | encoded_dict = tokenizer.batch_encode_plus(
88 | inputs,
89 | pad_to_max_length=True,add_special_tokens=True,
90 | return_tensors='pt')
91 | if encoded_dict['input_ids'].size(1) > max_sent_len:
92 | if 'token_type_ids' in encoded_dict:
93 | encoded_dict = {
94 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
95 | tokenizer.sep_token_id, tokenizer.pad_token_id),
96 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
97 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
98 | }
99 | else:
100 | encoded_dict = {
101 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
102 | tokenizer.sep_token_id, tokenizer.pad_token_id),
103 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
104 | }
105 |
106 | return encoded_dict
107 |
108 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
109 | """
110 | Compute the token indices matrix of the BERT output.
111 | input_ids: (batch_size, paragraph_len)
112 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
113 | bert_out: (batch_size, paragraph_len,BERT_dim)
114 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
115 | """
116 | padding_idx = -1
117 | sep_tokens = (input_ids == sep_token_id).bool()
118 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
119 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
120 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
121 | paragraph_lens = []
122 | all_word_indices = []
123 | for paragraph in sep_indices:
124 | if "roberta" in model_name:
125 | paragraph = paragraph[1:]
126 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
127 | paragraph_lens.append(len(word_indices))
128 | all_word_indices.extend(word_indices)
129 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
130 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
131 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
132 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
133 | mask = (indices_by_batch>=0)
134 |
135 | return batch_indices.long(), indices_by_batch.long(), mask.long()
136 |
137 | if __name__ == "__main__":
138 |     argparser = argparse.ArgumentParser(description="Train and evaluate the stance paragraph model on FEVER")
139 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
140 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_train_retrieved.jsonl")
141 | argparser.add_argument('--pre_trained_model', type=str)
142 | #argparser.add_argument('--train_file', type=str)
143 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/scifact/data/fever_dev_retrieved.jsonl")
144 | argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for BERT-like LM")
145 | argparser.add_argument('--lr', type=float, default=1e-6, help="Learning rate")
146 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
147 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
148 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch")
149 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
150 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_stance_paragraph")
151 | argparser.add_argument('--log_file', type=str, default = "fever_stance_paragraph_performances.jsonl")
152 | argparser.add_argument('--update_step', type=int, default=10)
153 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
154 | argparser.add_argument('--k', type=int, default=0)
155 | argparser.add_argument('--evaluation_step', type=int, default=50000)
156 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
157 |
158 | reset_random_seed(12345)
159 |
160 | args = argparser.parse_args()
161 |
162 |     # Redirect stdout to the log file (keep the handle open for later prints).
163 |     sys.stdout = open(args.checkpoint+".log", 'w')
164 |
165 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
166 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
167 |
168 | if args.train_file:
169 | train = True
170 | #assert args.repfile is not None, "Word embedding file required for training."
171 | else:
172 | train = False
173 | if args.test_file:
174 | test = True
175 | else:
176 | test = False
177 |
178 | params = vars(args)
179 |
180 | for k,v in params.items():
181 | print(k,v)
182 |
183 | if train:
184 | train_set = FEVERParagraphBatchDataset(args.train_file,
185 | sep_token = tokenizer.sep_token, k=args.k)
186 | dev_set = FEVERParagraphBatchDataset(args.test_file,
187 | sep_token = tokenizer.sep_token, k=args.k)
188 |
189 | print("Loaded dataset!")
190 |
191 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
192 | args.dropout).to(device)
193 |
194 | if args.pre_trained_model is not None:
195 | model.load_state_dict(torch.load(args.pre_trained_model))
196 | print("Loaded pre-trained model.")
197 |
198 | if train:
199 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
200 | for module in model.extra_modules:
201 | settings.append({'params': module.parameters(), 'lr': args.lr})
202 | optimizer = torch.optim.Adam(settings)
203 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
204 |
205 | #if torch.cuda.device_count() > 1:
206 | # print("Let's use", torch.cuda.device_count(), "GPUs!")
207 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
208 | # model = nn.DataParallel(model)
209 |
210 | model.train()
211 |
212 | for epoch in range(args.epoch):
213 | error_count = 0
214 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
215 | for i, batch in enumerate(tq):
216 | encoded_dict = encode(tokenizer, batch)
217 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
218 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
219 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
220 | stance_label = batch["stance"].to(device)
221 |
222 | stance_out, loss = \
223 | model(encoded_dict, transformation_indices, stance_label = stance_label)
224 | loss.sum().backward()
225 |
226 | if i % args.update_step == args.update_step - 1:
227 | optimizer.step()
228 | optimizer.zero_grad()
229 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}')
230 |
231 |
232 | if i % args.evaluation_step == args.evaluation_step-1:
233 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model")
234 |
235 | # Evaluation
236 | subset_train = Subset(train_set, range(0, 1000))
237 | train_score = evaluation(model, subset_train)
238 | print(f'Epoch {epoch}, step {i}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score)
239 |
240 | subset_dev = Subset(dev_set, range(0, 1000))
241 | dev_score = evaluation(model, subset_dev)
242 | print(f'Epoch {epoch}, step {i}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
243 | scheduler.step()
244 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model")
245 |         print(error_count, "mismatches occurred.")
246 |
247 | # Evaluation
248 | subset_train = Subset(train_set, range(0, 10000))
249 | train_score = evaluation(model, subset_train)
250 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score)
251 |
252 | subset_dev = Subset(dev_set, range(0, 10000))
253 | dev_score = evaluation(model, subset_dev)
254 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
255 |
256 |
257 |
258 | if test:
259 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
260 | args.dropout).to(device)
261 | model.load_state_dict(torch.load(args.checkpoint))
262 |
263 |
264 | # Evaluation
265 | subset_dev = Subset(dev_set, range(0, 10000))
266 | dev_score = evaluation(model, subset_dev)
267 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
268 |
269 | if train:
270 | params["stance_f1"] = dev_score[0]
271 | params["stance_precision"] = dev_score[1]
272 | params["stance_recall"] = dev_score[2]
273 |
274 | with jsonlines.open(args.log_file, mode='a') as writer:
275 | writer.write(params)
--------------------------------------------------------------------------------
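The stance metrics in all of these scripts micro-average over labels [1, 2] only, so the excluded class 0 (assumed here to be the NOT ENOUGH INFO label) cannot inflate the scores. A small check of that sklearn behavior on toy labels:

from sklearn.metrics import f1_score

y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 2, 2]

# The correct class-0 prediction contributes nothing; the remaining errors give
# micro precision 2/4 and recall 2/3, hence F1 = 4/7.
print(f1_score(y_true, y_pred, average="micro", labels=[1, 2]))  # 0.5714...

--------------------------------------------------------------------------------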
/FEVER_stance_paragraph_kgat.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import torch
5 | import jsonlines
6 | import os
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import random
20 | import numpy as np
21 |
22 | from tqdm import tqdm
23 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
24 | from paragraph_model_kgat import KGATClassifier as JointParagraphClassifier
25 | from dataset import FEVERStanceDataset as FEVERParagraphBatchDataset
26 |
27 | import logging
28 |
29 | def reset_random_seed(seed):
30 | np.random.seed(seed)
31 | random.seed(seed)
32 | torch.manual_seed(seed)
33 |
34 | def evaluation(model, dataset):
35 | model.eval()
36 | stance_preds = []
37 | stance_labels = []
38 |
39 | with torch.no_grad():
40 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
41 | encoded_dict = encode(tokenizer, batch)
42 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
43 | tokenizer.sep_token_id, args.repfile)
44 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
45 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
46 | stance_label = batch["stance"].to(device)
47 | stance_out, stance_loss = \
48 | model(encoded_dict, transformation_indices, stance_label = stance_label)
49 | stance_preds.extend(stance_out)
50 | stance_labels.extend(stance_label.cpu().numpy().tolist())
51 |
52 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
53 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
54 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
55 | return stance_f1, stance_precision, stance_recall
56 |
57 |
58 | def encode(tokenizer, batch, max_sent_len = 512):
59 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
60 | def longest_first_truncation(sentences, objective):
61 | sent_lens = [len(sent) for sent in sentences]
62 | while np.sum(sent_lens) > objective:
63 | max_position = np.argmax(sent_lens)
64 | sent_lens[max_position] -= 1
65 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
66 |
67 | all_paragraphs = []
68 | for paragraph in input_ids:
69 | valid_paragraph = paragraph[paragraph != pad_token_id]
70 | if valid_paragraph.size(0) <= max_length:
71 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
72 | else:
73 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
74 | idx_by_sentence = []
75 | prev_idx = 0
76 | for idx in sep_token_idx:
77 | idx_by_sentence.append(paragraph[prev_idx:idx])
78 | prev_idx = idx
79 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
80 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
81 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
82 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
83 |
84 | return torch.cat(all_paragraphs, 0)
85 |
86 | inputs = zip(batch["claim"], batch["paragraph"])
87 | encoded_dict = tokenizer.batch_encode_plus(
88 | inputs,
89 | pad_to_max_length=True,add_special_tokens=True,
90 | return_tensors='pt')
91 | if encoded_dict['input_ids'].size(1) > max_sent_len:
92 | if 'token_type_ids' in encoded_dict:
93 | encoded_dict = {
94 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
95 | tokenizer.sep_token_id, tokenizer.pad_token_id),
96 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
97 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
98 | }
99 | else:
100 | encoded_dict = {
101 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
102 | tokenizer.sep_token_id, tokenizer.pad_token_id),
103 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
104 | }
105 |
106 | return encoded_dict
107 |
108 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
109 | """
110 | Compute the token indices matrix of the BERT output.
111 | input_ids: (batch_size, paragraph_len)
112 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
113 | bert_out: (batch_size, paragraph_len,BERT_dim)
114 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
115 | """
116 | padding_idx = -1
117 | sep_tokens = (input_ids == sep_token_id).bool()
118 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
119 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
120 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
121 | paragraph_lens = []
122 | all_word_indices = []
123 | for paragraph in sep_indices:
124 | if "roberta" in model_name:
125 | paragraph = paragraph[1:]
126 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
127 | paragraph_lens.append(len(word_indices))
128 | all_word_indices.extend(word_indices)
129 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
130 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
131 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
132 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
133 | mask = (indices_by_batch>=0)
134 |
135 | return batch_indices.long(), indices_by_batch.long(), mask.long()
136 |
137 | if __name__ == "__main__":
138 |     argparser = argparse.ArgumentParser(description="Train and evaluate the KGAT stance paragraph model on FEVER")
139 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
140 | argparser.add_argument('--train_file', type=str, default="/home/xxl190027/scifact_data/fever_train_retrieved_15.jsonl")
141 | argparser.add_argument('--pre_trained_model', type=str)
142 | #argparser.add_argument('--train_file', type=str)
143 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/fever_dev_retrieved_15.jsonl")
144 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM")
145 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate")
146 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
147 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
148 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch")
149 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
150 | argparser.add_argument('--checkpoint', type=str, default = "fever_roberta_stance_paragraph_kgat")
151 | argparser.add_argument('--log_file', type=str, default = "fever_stance_paragraph_kgat_performances.jsonl")
152 | argparser.add_argument('--update_step', type=int, default=10)
153 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
154 | argparser.add_argument('--k', type=int, default=0)
155 | argparser.add_argument('--evaluation_step', type=int, default=50000)
156 | argparser.add_argument('--kernel', type=int, default=6)
157 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
158 |
159 | reset_random_seed(12345)
160 |
161 | args = argparser.parse_args()
162 |
163 |     # Redirect stdout to the log file (keep the handle open for later prints).
164 |     sys.stdout = open(args.checkpoint+".log", 'w')
165 |
166 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
167 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
168 |
169 | if args.train_file:
170 | train = True
171 | #assert args.repfile is not None, "Word embedding file required for training."
172 | else:
173 | train = False
174 | if args.test_file:
175 | test = True
176 | else:
177 | test = False
178 |
179 | params = vars(args)
180 |
181 | for k,v in params.items():
182 | print(k,v)
183 |
184 | if train:
185 | train_set = FEVERParagraphBatchDataset(args.train_file,
186 | sep_token = tokenizer.sep_token, k=args.k)
187 | # dev_set is needed by both the training loop and the test branch below
188 | dev_set = FEVERParagraphBatchDataset(args.test_file,
189 | sep_token = tokenizer.sep_token, k=args.k)
190 | print("Loaded dataset!")
191 |
192 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
193 | args.dropout, kernel = args.kernel).to(device)
194 |
195 | if args.pre_trained_model is not None:
196 | model.load_state_dict(torch.load(args.pre_trained_model))
197 | print("Loaded pre-trained model.")
198 |
199 | if train:
200 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
201 | for module in model.extra_modules:
202 | settings.append({'params': module.parameters(), 'lr': args.lr})
203 | optimizer = torch.optim.Adam(settings)
204 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
205 |
206 | #if torch.cuda.device_count() > 1:
207 | # print("Let's use", torch.cuda.device_count(), "GPUs!")
208 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
209 | # model = nn.DataParallel(model)
210 |
211 | model.train()
212 |
213 | for epoch in range(args.epoch):
214 | error_count = 0
215 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
216 | for i, batch in enumerate(tq):
217 | encoded_dict = encode(tokenizer, batch)
218 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
219 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
220 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
221 | stance_label = batch["stance"].to(device)
222 |
223 | stance_out, loss = \
224 | model(encoded_dict, transformation_indices, stance_label = stance_label)
225 | loss.sum().backward()
226 |
227 | if i % args.update_step == args.update_step - 1: # gradient accumulation: update weights every update_step batches
228 | optimizer.step()
229 | optimizer.zero_grad()
230 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}')
231 |
232 |
233 | if i % args.evaluation_step == args.evaluation_step-1:
234 | torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+"_"+str(i)+".model")
235 |
236 | # Evaluation
237 | subset_train = Subset(train_set, range(0, 1000))
238 | train_score = evaluation(model, subset_train)
239 | print(f'Epoch {epoch}, step {i}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score)
240 |
241 | subset_dev = Subset(dev_set, range(0, 1000))
242 | dev_score = evaluation(model, subset_dev)
243 | print(f'Epoch {epoch}, step {i}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
244 | scheduler.step()
245 | #torch.save(model.state_dict(), args.checkpoint+"_"+str(epoch)+".model") ##############
246 | print(error_count, "mismatches occurred.")
247 |
248 | # Evaluation
249 | subset_train = Subset(train_set, range(0, 100))
250 | train_score = evaluation(model, subset_train)
251 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score)
252 |
253 | subset_dev = Subset(dev_set, range(0, 100))
254 | dev_score = evaluation(model, subset_dev)
255 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
256 |
257 |
258 |
259 | if test:
260 | #model = JointParagraphClassifier(args.repfile, args.bert_dim,
261 | # args.dropout, kernel = args.kernel).to(device)
262 | #model.load_state_dict(torch.load(args.checkpoint))
263 |
264 |
265 | # Evaluation
266 | subset_dev = Subset(dev_set, range(0, 100))
267 | dev_score = evaluation(model, subset_dev)
268 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
269 |
270 | if train:
271 | params["stance_f1"] = dev_score[0]
272 | params["stance_precision"] = dev_score[1]
273 | params["stance_recall"] = dev_score[2]
274 |
275 | with jsonlines.open(args.log_file, mode='a') as writer:
276 | writer.write(params)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ParagraphJointModel
2 | Implementation of the AAAI-21 Workshop on Scientific Document Understanding paper [A Paragraph-level Multi-task Learning Model for Scientific Fact-Verification](https://ceur-ws.org/Vol-2831/paper8.pdf). A short [video](https://www.youtube.com/watch?v=YrpYAdNl05Y) about this work is available! As of March 30th, 2021, this work ranks in the top 2 on the [SciFact leaderboard](https://leaderboard.allenai.org/scifact/submissions/public).
3 |
4 | ## Reproducing SciFact Leaderboard Result
5 | ### Dependencies
6 |
7 | We recommend you create an anaconda environment:
8 | ```bash
9 | conda create --name scifact python=3.6 conda-build
10 | ```
11 | Then, from the `scifact` project root, run
12 | ```
13 | conda develop .
14 | ```
15 | which will add the scifact code to your `PYTHONPATH`.
16 |
17 | Then, install Python requirements:
18 | ```
19 | pip install -r requirements.txt
20 | ```
21 | If you encounter any installation problem regarding sent2vec, please check [their repo](https://github.com/epfml/sent2vec).
22 | The BioSentVec model is available [here](https://github.com/ncbi-nlp/BioSentVec#biosentvec).
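23 |
24 | As a quick sanity check of the sent2vec installation and the BioSentVec model before running the retrieval scripts, here is a minimal sketch (the `.bin` file name is the standard BioSentVec release; adjust the path to your download location):
25 | ```
26 | import sent2vec
27 |
28 | model = sent2vec.Sent2vecModel()
29 | model.load_model('/path/to/BioSentVec_PubMed_MIMICIII-bigram_d700.bin')
30 | embedding = model.embed_sentence('a preprocessed claim sentence')  # one 700-dimensional vector
31 | print(embedding.shape)
32 | ```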
23 |
24 | The SciFact claim files and corpus file are available at [SciFact repo](https://github.com/allenai/scifact).
25 | The checkpoint of Paragraph-Joint model used for the paper (trained on training set) is available [here](https://drive.google.com/file/d/1agyrkUGJ0lxTBJpdy1QCyaAAJyxBnoO2/view?usp=sharing).
26 | The checkpoint of Paragraph-Joint model used for leaderboard submission (trained on train+dev set) is available [here](https://drive.google.com/file/d/1hMrQzFe1EaJpCN9s3pF27Wu3amBbekiI/view?usp=sharing).
27 |
28 | ### Abstract Retrieval
29 | ```
30 | python ComputeBioSentVecAbstractEmbedding.py --claim_file /path/to/claims.jsonl --corpus_file /path/to/corpus.jsonl --sentvec_path /path/to/sentvec_model
31 |
32 | python SentVecAbstractRetriaval.py --claim_file /path/to/claims.jsonl --corpus_file /path/to/corpus.jsonl --k_retrieval 30 --claim_retrieved_file /output/path/of/retrieval_file.jsonl --scifact_abstract_retrieval_file /output/path/of/retrieval_file_scifact_format.jsonl
33 | ```
34 | The retrieved abstracts are available here: [train](https://drive.google.com/file/d/18yWhLP3n1OjT_XrUB3rJwNMnLRI3k8Ck/view?usp=sharing), [dev](https://drive.google.com/file/d/1fnfdOA2e3_U-kGavuhoyiYZUlDYWX9eM/view?usp=sharing), [test](https://drive.google.com/file/d/10Lh0aP06tGfZ-LlNGWnDtN0GM8M14z2q/view?usp=sharing).
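35 |
36 | Each line of the `--scifact_abstract_retrieval_file` output pairs a claim with its top-k retrieved abstract IDs, while the `--claim_retrieved_file` keeps the full claim record and adds a `retrieved_doc_ids` field. For example (the IDs below are made up):
37 | ```
38 | {"claim_id": 1, "doc_ids": [13734012, 44265107, 32587939]}
39 | ```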
35 | ### Training of the ParagraphJoint Model (Optional for Result Reproduction Purpose)
36 | #### FEVER Pre-training
37 | You need to retrieve negative samples for FEVER pre-training. We used the retrieval code from [here](https://github.com/sheffieldnlp/fever-naacl-2018). Empirically, retrieving 5 negative examples per claim is enough, while retrieving more can be prohibitively time-consuming. You then need to convert the output of the retrieval code into the SciFact input format.
38 |
39 | For your convenience, the converted retrieved FEVER examples with `k_retrieval=15` are available: [train](https://drive.google.com/file/d/1sS6mpaALuWnk6Pl2twIt_GcBs7ExRY2b/view?usp=sharing), [dev](https://drive.google.com/file/d/1sOfFL6fvK-AYjzcGPJ5KqcFPmAMvQJUi/view?usp=sharing).
40 |
41 | The checkpoint of the Paragraph-Joint model only pretrained on the retrieved FEVER examples shared above is available [here](https://drive.google.com/file/d/12u9glqoCBuhxnP9P8dM4HSncAIjHmQ_U/view?usp=sharing).
42 |
43 | Run `FEVER_joint_paragraph_dynamic.py` to pre-train the model on FEVER, using `--checkpoint` to specify the checkpoint path. Then run `scifact_joint_paragraph_dynamic.py` to fine-tune on the SciFact dataset, using `--pre_trained_model` to load the pre-trained model. Please check the other options in the source files. A typical sequence looks like the sketch below (flag names follow the analogous scripts in this repo; check each file's argparse block for the authoritative list):
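44 | ```
45 | # FEVER pre-training (paths are placeholders)
46 | python FEVER_joint_paragraph_dynamic.py --train_file /path/to/fever_train_retrieved_15.jsonl --test_file /path/to/fever_dev_retrieved_15.jsonl --checkpoint fever_pretrained
47 |
48 | # Fine-tune on SciFact, starting from the FEVER checkpoint
49 | python scifact_joint_paragraph_dynamic.py --corpus_file /path/to/corpus.jsonl --train_file /path/to/claims_train_retrieved.jsonl --test_file /path/to/claims_dev_retrieved.jsonl --pre_trained_model fever_pretrained.model --checkpoint scifact_model
50 | ```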
44 |
45 | ### Joint Prediction of Rationale Selection and Stance Prediction
46 | ```
47 | python scifact_joint_paragraph_dynamic_prediction.py --corpus_file /path/to/corpus.jsonl --test_file /path/to/retrieval_file.jsonl --dataset /path/to/scifact/claims_test.jsonl --batch_size 25 --k 30 --prediction /path/to/output.jsonl --evaluate --checkpoint /path/to/checkpoint
48 | ```
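49 |
50 | The `--prediction` output contains one JSON line per claim in the SciFact submission format, merging the rationale and stance predictions, roughly like this (IDs made up):
51 | ```
52 | {"id": 1, "evidence": {"13734012": {"sentences": [2, 5], "label": "SUPPORT"}}}
53 | ```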
49 |
50 | ## File naming conventions
51 | The file names should be self-explanatory, and most parameters come with sensible defaults; their meanings should be straightforward.
52 |
53 | ### Non-Joint Models
54 | File names with `rationale` and `stance` are those scripts for rationale selection and stance prediction models.
55 |
56 | ### FEVER Pretraining and Domain-Adaptation
57 | File names with `FEVER` are scripts for training on the FEVER dataset; likewise for `domain_adaptation`.
58 |
59 | ### Prediction
60 | File names with `prediction` are scripts that load trained models and perform inference.
61 |
62 | ### KGAT
63 | File names with `kgat` denote models that use [KGAT](https://github.com/xiangwang1223/knowledge_graph_attention_network) as the stance predictor.
64 |
65 | ### Fine-tuning
66 | You can use `--pre_trained_model path/to/pre_trained.model` to load a model trained on the FEVER dataset and fine-tune it on SciFact, for example:
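67 |
68 | With the domain-adaptation variant, the flags below match the script's argparse options (paths are placeholders):
69 | ```
70 | python domain_adaptation_joint_paragraph_fine_tune.py --scifact_corpus /path/to/corpus.jsonl --scifact_train /path/to/claims_train_retrieved.jsonl --scifact_test /path/to/claims_dev_retrieved.jsonl --dataset /path/to/claims_dev.jsonl --pre_trained_model /path/to/fever_pretrained.model
71 | ```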
67 |
68 | ## Cite our paper
69 | ```
70 | @inproceedings{li2021paragraph,
71 | title={A Paragraph-level Multi-task Learning Model for Scientific Fact-Verification.},
72 | author={Li, Xiangci and Burns, Gully A and Peng, Nanyun},
73 | booktitle={SDU@ AAAI},
74 | year={2021}
75 | }
76 | ```
77 |
78 |
--------------------------------------------------------------------------------
/SentVecAbstractRetriaval.py:
--------------------------------------------------------------------------------
1 | from sklearn.metrics.pairwise import cosine_similarity
2 | import pickle
3 | import json
4 | import numpy as np
5 | from tqdm import tqdm
6 | import jsonlines
7 | import argparse
8 | if __name__ == "__main__":
9 | argparser = argparse.ArgumentParser()
10 | argparser.add_argument('--claim_file', type=str)
11 | argparser.add_argument('--corpus_file', type=str)
12 | argparser.add_argument('--k_retrieval', type=int)
13 | argparser.add_argument('--claim_retrieved_file', type=str)
14 | argparser.add_argument('--scifact_abstract_retrieval_file', type=str, help="abstract retrieval in SciFact format")
15 | argparser.add_argument('--corpus_embedding_pickle', type=str, default="corpus_paragraph_biosentvec.pkl")
16 | argparser.add_argument('--claim_embedding_pickle', type=str, default="claim_biosentvec.pkl")
17 |
18 | args = argparser.parse_args()
19 |
20 | with open(args.corpus_embedding_pickle,"rb") as f:
21 | corpus_embeddings = pickle.load(f)
22 |
23 | with open(args.claim_embedding_pickle,"rb") as f:
24 | claim_embeddings = pickle.load(f)
25 |
26 | claim_file = args.claim_file
27 |
28 | claims = []
29 | with open(claim_file) as f:
30 | for line in f:
31 | claim = json.loads(line)
32 | claims.append(claim)
33 | claims_by_id = {claim['id']:claim for claim in claims}
34 |
35 | all_similarities = {}
36 | for claim_id, claim_embedding in tqdm(claim_embeddings.items()):
37 | this_claim = {}
38 | for abstract_id, abstract_embedding in corpus_embeddings.items():
39 | claim_similarity = cosine_similarity(claim_embedding,abstract_embedding)
40 | this_claim[abstract_id] = claim_similarity
41 | all_similarities[claim_id] = this_claim
42 |
43 | ordered_corpus = {}
44 | for claim_id, claim_similarities in tqdm(all_similarities.items()):
45 | corpus_ids = []
46 | max_similarity = []
47 | for abstract_id, similarity in claim_similarities.items():
48 | corpus_ids.append(abstract_id)
49 | max_similarity.append(np.max(similarity))
50 | corpus_ids = np.array(corpus_ids)
51 | sorted_order = np.argsort(max_similarity)[::-1]
52 | ordered_corpus[claim_id] = corpus_ids[sorted_order]
53 |
54 | k = args.k_retrieval
55 | retrieved_corpus = {ID:v[:k] for ID,v in ordered_corpus.items()}
56 |
57 | with jsonlines.open(args.claim_retrieved_file, 'w') as output:
58 | claim_ids = sorted(list(claims_by_id.keys()))
59 | for id in claim_ids:
60 | claims_by_id[id]["retrieved_doc_ids"] = retrieved_corpus[id].tolist()
61 | output.write(claims_by_id[id])
62 |
63 | with jsonlines.open(args.scifact_abstract_retrieval_file, 'w') as output:
64 | claim_ids = sorted(list(claims_by_id.keys()))
65 | for id in claim_ids:
66 | doc_ids = retrieved_corpus[id].tolist()
67 | doc_ids = [int(id) for id in doc_ids]
68 | output.write({"claim_id": id, "doc_ids": doc_ids})
69 |
--------------------------------------------------------------------------------
/TFIDFabstractRetrieval.py:
--------------------------------------------------------------------------------
1 | from sklearn.feature_extraction.text import TfidfVectorizer
2 | from nltk import word_tokenize
3 | from nltk.corpus import stopwords
4 | from string import punctuation
5 | from scipy.spatial import distance
6 | import json
7 | import numpy as np
8 | import jsonlines
9 |
10 | import argparse
11 |
12 | if __name__ == "__main__":
13 | argparser = argparse.ArgumentParser()
14 | argparser.add_argument('--claim_file', type=str)
15 | argparser.add_argument('--corpus_file', type=str)
16 | argparser.add_argument('--k_retrieval', type=int)
17 | argparser.add_argument('--claim_retrieved_file', type=str)
18 | args = argparser.parse_args()
19 | claim_file = args.claim_file
20 | corpus_file = args.corpus_file
21 |
22 | corpus = {}
23 | with open(corpus_file) as f:
24 | for line in f:
25 | abstract = json.loads(line)
26 | corpus[str(abstract["doc_id"])] = abstract
27 |
28 | claims = []
29 | with open(claim_file) as f:
30 | for line in f:
31 | claim = json.loads(line)
32 | claims.append(claim)
33 | claims_by_id = {claim['id']:claim for claim in claims}
34 |
35 | corpus_texts = []
36 | corpus_ids = []
37 | for k, v in corpus.items():
38 | original_sentences = [v['title']] + v['abstract']
39 | processed_paragraph = " ".join(original_sentences)
40 | corpus_texts.append(processed_paragraph)
41 | corpus_ids.append(k)
42 | vectorizer = TfidfVectorizer(stop_words='english',
43 | ngram_range=(1, 2))
44 | corpus_ids = np.array(corpus_ids)
45 | corpus_vectors = vectorizer.fit_transform(corpus_texts)
46 |
47 | claim_vectors = vectorizer.transform([claim['claim'] for claim in claims])
48 | similarity_matrix = np.dot(corpus_vectors, claim_vectors.T).todense()
49 |
50 | k = args.k_retrieval
51 | orders = np.argsort(similarity_matrix,axis=0)
52 | retrieved_corpus = {claim["id"]: corpus_ids[orders[:,i][::-1][:k]].squeeze() for i, claim in enumerate(claims)}
53 |
54 | with jsonlines.open(args.claim_retrieved_file, 'w') as output:
55 | claim_ids = sorted(list(claims_by_id.keys()))
56 | for id in claim_ids:
57 | claims_by_id[id]["retrieved_doc_ids"] = retrieved_corpus[id].tolist()
58 | output.write(claims_by_id[id])
--------------------------------------------------------------------------------
/domain_adaptation_joint_paragraph_fine_tune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 | import sys
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import math
20 | import random
21 | import numpy as np
22 |
23 |
24 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
25 | from paragraph_model_dynamic import DomainAdaptationJointParagraphClassifier
26 | from dataset import FEVERParagraphBatchDataset, SciFactParagraphBatchDataset, SciFact_FEVER_Dataset, Multiple_SciFact_Dataset
27 |
28 | import logging
29 |
30 | from lib.data import GoldDataset, PredictedDataset
31 | from lib import metrics
32 |
33 | def schedule_sample_p(epoch, total): # scheduled-sampling probability: 0 at the first epoch, ramping to 1 by the last epoch (sine ramp)
34 | return np.sin(0.5* np.pi* epoch / (total-1))
35 |
36 | def reset_random_seed(seed):
37 | np.random.seed(seed)
38 | random.seed(seed)
39 | torch.manual_seed(seed)
40 |
41 | def batch_rationale_label(labels, padding_idx = 2):
42 | max_sent_len = max([len(label) for label in labels])
43 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
44 | label_list = []
45 | for i, label in enumerate(labels):
46 | for j, evid in enumerate(label):
47 | label_matrix[i,j] = int(evid)
48 | label_list.append([int(evid) for evid in label])
49 | return label_matrix.long(), label_list
50 |
51 | def predict(model, dataset):
52 | model.eval()
53 | rationale_predictions = []
54 | stance_preds = []
55 |
56 | def remove_dummy(rationale_out):
57 | return [out[1:] for out in rationale_out]
58 |
59 | with torch.no_grad():
60 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
61 | encoded_dict = encode(tokenizer, batch)
62 | domain_indices = batch["dataset"].to(device)
63 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
64 | tokenizer.sep_token_id, args.repfile)
65 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
66 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
67 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices, domain_indices)
68 | stance_preds.extend(stance_out)
69 | rationale_predictions.extend(remove_dummy(rationale_out))
70 |
71 | return rationale_predictions, stance_preds
72 |
73 | def evaluation(model, dataset):
74 | model.eval()
75 | rationale_predictions = []
76 | rationale_labels = []
77 | stance_preds = []
78 | stance_labels = []
79 |
80 | def remove_dummy(rationale_out):
81 | return [out[1:] for out in rationale_out]
82 |
83 | with torch.no_grad():
84 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
85 | encoded_dict = encode(tokenizer, batch)
86 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
87 | tokenizer.sep_token_id, args.repfile)
88 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
89 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
90 | stance_label = batch["stance"].to(device)
91 | domain_indices = batch["dataset"].to(device)
92 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
93 | rationale_out, stance_out, rationale_loss, stance_loss = \
94 | model(encoded_dict, transformation_indices, domain_indices, stance_label = stance_label,
95 | rationale_label = padded_rationale_label.to(device))
96 | stance_preds.extend(stance_out)
97 | stance_labels.extend(stance_label.cpu().numpy().tolist())
98 |
99 | rationale_predictions.extend(remove_dummy(rationale_out))
100 | rationale_labels.extend(remove_dummy(rationale_label))
101 |
102 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
103 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
104 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
105 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions))
106 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions))
107 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions))
108 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall
109 |
110 | def encode(tokenizer, batch, max_sent_len = 512):
111 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
112 | def longest_first_truncation(sentences, objective):
113 | sent_lens = [len(sent) for sent in sentences]
114 | while np.sum(sent_lens) > objective:
115 | max_position = np.argmax(sent_lens)
116 | sent_lens[max_position] -= 1
117 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
118 |
119 | all_paragraphs = []
120 | for paragraph in input_ids:
121 | valid_paragraph = paragraph[paragraph != pad_token_id]
122 | if valid_paragraph.size(0) <= max_length:
123 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
124 | else:
125 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
126 | idx_by_sentence = []
127 | prev_idx = 0
128 | for idx in sep_token_idx:
129 | idx_by_sentence.append(paragraph[prev_idx:idx])
130 | prev_idx = idx
131 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
132 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
133 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
134 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
135 |
136 | return torch.cat(all_paragraphs, 0)
137 |
138 | inputs = zip(batch["claim"], batch["paragraph"])
139 | encoded_dict = tokenizer.batch_encode_plus(
140 | inputs,
141 | pad_to_max_length=True,add_special_tokens=True,
142 | return_tensors='pt')
143 | if encoded_dict['input_ids'].size(1) > max_sent_len:
144 | if 'token_type_ids' in encoded_dict:
145 | encoded_dict = {
146 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
147 | tokenizer.sep_token_id, tokenizer.pad_token_id),
148 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
149 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
150 | }
151 | else:
152 | encoded_dict = {
153 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
154 | tokenizer.sep_token_id, tokenizer.pad_token_id),
155 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
156 | }
157 |
158 | return encoded_dict
159 |
160 |
161 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
162 | """
163 | Compute the token indices matrix of the BERT output.
164 | input_ids: (batch_size, paragraph_len)
165 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
166 | bert_out: (batch_size, paragraph_len,BERT_dim)
167 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
168 | """
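169 | # Example (RoBERTa, one paragraph tokenized as "<s> c1 c2 </s> </s> s1 s2 </s> s3 </s>"):
170 | # </s> positions are [3,4,7,9]; after dropping the first one, the per-sentence token
171 | # indices become [[5,6,7],[8,9,-1]], where -1 is padding that the mask zeroes out.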
169 | padding_idx = -1
170 | sep_tokens = (input_ids == sep_token_id).bool()
171 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
172 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
173 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
174 | paragraph_lens = []
175 | all_word_indices = []
176 | for paragraph in sep_indices:
177 | if "roberta" in model_name:
178 | paragraph = paragraph[1:]
179 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
180 | paragraph_lens.append(len(word_indices))
181 | all_word_indices.extend(word_indices)
182 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
183 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
184 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
185 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
186 | mask = (indices_by_batch>=0)
187 |
188 | return batch_indices.long(), indices_by_batch.long(), mask.long()
189 |
190 | def post_process_stance(rationale_json, stance_json):
191 | assert(len(rationale_json) == len(stance_json))
192 | for stance_pred, rationale_pred in zip(stance_json, rationale_json):
193 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
194 | for doc_id, pred in rationale_pred["evidence"].items():
195 | if len(pred) == 0:
196 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
197 | return stance_json
198 |
199 | if __name__ == "__main__":
200 | argparser = argparse.ArgumentParser(description="Train, cross-validate and run sentence sequence tagger")
201 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
202 | argparser.add_argument('--scifact_corpus', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl")
203 | argparser.add_argument('--scifact_train', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl")
204 | #argparser.add_argument('--scifact_train', type=str)
205 | argparser.add_argument('--pre_trained_model', type=str)
206 | argparser.add_argument('--scifact_test', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl")
207 | argparser.add_argument('--dataset', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev.jsonl")
208 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM")
209 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate")
210 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
211 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
212 | argparser.add_argument('--epoch_start', type=int, default=0, help="Training epoch")
213 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch")
214 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
215 | argparser.add_argument('--loss_ratio', type=float, default=5)
216 | argparser.add_argument('--checkpoint', type=str, default = "domain_adaptation_roberta_joint_paragraph_fine_tune.model")
217 | argparser.add_argument('--log_file', type=str, default = "domain_adaptation_joint_paragraph_fine_tune_performances.jsonl")
218 | argparser.add_argument('--prediction', type=str, default = "prediction_domain_adaptation_roberta_joint_paragraph_dynamic_fine_tune.jsonl")
219 | argparser.add_argument('--update_step', type=int, default=10)
220 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
221 | argparser.add_argument('--scifact_k', type=int, default=0)
222 |
223 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
224 | reset_random_seed(12345)
225 | args = argparser.parse_args()
226 |
227 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
228 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
229 |
230 | if args.scifact_train:
231 | train = True
232 | #assert args.repfile is not None, "Word embedding file required for training."
233 | else:
234 | train = False
235 | if args.scifact_test:
236 | test = True
237 | else:
238 | test = False
239 |
240 | params = vars(args)
241 |
242 | for k,v in params.items():
243 | print(k,v)
244 |
245 | if train:
246 | train_set = SciFactParagraphBatchDataset(args.scifact_corpus, args.scifact_train,
247 | sep_token = tokenizer.sep_token, k = args.scifact_k, downsample_n = 2)
248 |
249 | # dev_set is needed by both training-time evaluation and the test branch below
250 | dev_set = SciFactParagraphBatchDataset(args.scifact_corpus, args.scifact_test,
251 | sep_token = tokenizer.sep_token, k = args.scifact_k, downsample_n = 0)
252 | print("Loaded dataset!")
253 |
254 | model = DomainAdaptationJointParagraphClassifier(args.repfile, args.bert_dim,
255 | args.dropout).to(device)
256 |
257 | if args.pre_trained_model is not None:
258 | model.load_state_dict(torch.load(args.pre_trained_model))
259 | print("Loaded pre-trained model.")
260 |
261 | if train:
262 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
263 | for module in model.extra_modules:
264 | settings.append({'params': module.parameters(), 'lr': args.lr})
265 | optimizer = torch.optim.Adam(settings)
266 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
267 |
268 | #if torch.cuda.device_count() > 1:
269 | # print("Let's use", torch.cuda.device_count(), "GPUs!")
270 | # # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
271 | # model = nn.DataParallel(model)
272 |
273 | model.train()
274 | prev_performance = 0
275 | for epoch in range(args.epoch_start, args.epoch):
276 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
277 | for i, batch in enumerate(tq):
278 | encoded_dict = encode(tokenizer, batch)
279 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
280 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
281 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
282 | stance_label = batch["stance"].to(device)
283 | domain_indices = batch["dataset"].to(device)
284 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
285 | rationale_out, stance_out, rationale_loss, stance_loss = \
286 | model(encoded_dict, transformation_indices, domain_indices, stance_label = stance_label,
287 | rationale_label = padded_rationale_label.to(device))
288 | rationale_loss *= args.loss_ratio
289 | loss = rationale_loss + stance_loss
290 | loss.sum().backward()
291 |
292 | if i % args.update_step == args.update_step - 1:
293 | optimizer.step()
294 | optimizer.zero_grad()
295 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}')
296 |
297 | scheduler.step()
298 |
299 | # Evaluation
300 | train_score = evaluation(model, train_set)
301 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
302 |
303 | dev_score = evaluation(model, dev_set)
304 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
305 |
306 | dev_perf = dev_score[0] * dev_score[3]
307 | if dev_perf >= prev_performance:
308 | torch.save(model.state_dict(), args.checkpoint)
309 | best_state_dict = model.state_dict()
310 | prev_performance = dev_perf
311 | print("New model saved!")
312 | else:
313 | print("Skip saving model.")
314 |
315 |
316 | if test:
317 | if train:
318 | del model
319 | model = DomainAdaptationJointParagraphClassifier(args.repfile, args.bert_dim,
320 | args.dropout).to(device)
321 | model.load_state_dict(best_state_dict)
322 | print("Testing on the new model.")
323 | else:
324 | model.load_state_dict(torch.load(args.checkpoint))
325 | print("Loaded saved model.")
326 |
327 | # Evaluation
328 | #dev_score = evaluation(model, dev_set)
329 | #print(f'Test stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
330 |
331 | rationale_predictions, stance_preds = predict(model, dev_set)
332 | rationale_json = rationale2json(dev_set.samples, rationale_predictions)
333 | stance_json = stance2json(dev_set.samples, stance_preds)
334 | stance_json = post_process_stance(rationale_json, stance_json)
335 | merged_json = merge_json(rationale_json, stance_json)
336 |
337 | with jsonlines.open(args.prediction, 'w') as output:
338 | for result in merged_json:
339 | output.write(result)
340 |
341 | data = GoldDataset(args.scifact_corpus, args.dataset)
342 | predictions = PredictedDataset(data, args.prediction)
343 | res = metrics.compute_metrics(predictions)
344 | params["evaluation"] = res
345 | with jsonlines.open(args.log_file, mode='a') as writer:
346 | writer.write(params)
--------------------------------------------------------------------------------
/domain_adaptation_joint_paragraph_kgat_prediction.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 |
4 | import torch
5 | import jsonlines
6 | import os
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import math
20 | import random
21 | import numpy as np
22 |
23 |
24 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
25 | from paragraph_model_kgat import DomainAdaptationJointParagraphKGATClassifier
26 | from dataset import SciFactParagraphBatchDataset, FEVERParagraphBatchDataset, SciFact_FEVER_Dataset, Multiple_SciFact_Dataset
27 |
28 | import logging
29 |
30 | from lib.data import GoldDataset, PredictedDataset
31 | from lib import metrics
32 |
33 | def reset_random_seed(seed):
34 | np.random.seed(seed)
35 | random.seed(seed)
36 | torch.manual_seed(seed)
37 |
38 | def batch_rationale_label(labels, padding_idx = 2):
39 | max_sent_len = max([len(label) for label in labels])
40 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
41 | label_list = []
42 | for i, label in enumerate(labels):
43 | for j, evid in enumerate(label):
44 | label_matrix[i,j] = int(evid)
45 | label_list.append([int(evid) for evid in label])
46 | return label_matrix.long(), label_list
47 |
48 | def predict(model, dataset):
49 | model.eval()
50 | rationale_predictions = []
51 | stance_preds = []
52 |
53 | def remove_dummy(rationale_out):
54 | return [out[1:] for out in rationale_out]
55 |
56 | with torch.no_grad():
57 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
58 | encoded_dict = encode(tokenizer, batch)
59 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
60 | tokenizer.sep_token_id, args.repfile)
61 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
62 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
63 | domain_indices = batch["dataset"].to(device)
64 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices, domain_indices)
65 | stance_preds.extend(stance_out)
66 | rationale_predictions.extend(remove_dummy(rationale_out))
67 |
68 | return rationale_predictions, stance_preds
69 |
70 | def encode(tokenizer, batch, max_sent_len = 512):
71 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
72 | def longest_first_truncation(sentences, objective):
73 | sent_lens = [len(sent) for sent in sentences]
74 | while np.sum(sent_lens) > objective:
75 | max_position = np.argmax(sent_lens)
76 | sent_lens[max_position] -= 1
77 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
78 |
79 | all_paragraphs = []
80 | for paragraph in input_ids:
81 | valid_paragraph = paragraph[paragraph != pad_token_id]
82 | if valid_paragraph.size(0) <= max_length:
83 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
84 | else:
85 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
86 | idx_by_sentence = []
87 | prev_idx = 0
88 | for idx in sep_token_idx:
89 | idx_by_sentence.append(paragraph[prev_idx:idx])
90 | prev_idx = idx
91 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
92 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
93 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
94 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
95 |
96 | return torch.cat(all_paragraphs, 0)
97 |
98 | inputs = zip(batch["claim"], batch["paragraph"])
99 | encoded_dict = tokenizer.batch_encode_plus(
100 | inputs,
101 | pad_to_max_length=True,add_special_tokens=True,
102 | return_tensors='pt')
103 | if encoded_dict['input_ids'].size(1) > max_sent_len:
104 | if 'token_type_ids' in encoded_dict:
105 | encoded_dict = {
106 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
107 | tokenizer.sep_token_id, tokenizer.pad_token_id),
108 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
109 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
110 | }
111 | else:
112 | encoded_dict = {
113 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
114 | tokenizer.sep_token_id, tokenizer.pad_token_id),
115 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
116 | }
117 |
118 | return encoded_dict
119 |
120 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
121 | """
122 | Advanced indexing: Compute the token indices matrix of the BERT output.
123 | input_ids: (batch_size, paragraph_len)
124 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
125 | bert_out: (batch_size, paragraph_len,BERT_dim)
126 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
127 | """
128 | padding_idx = -1
129 | sep_tokens = (input_ids == sep_token_id).bool()
130 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph
131 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,3,....,511 for each sentence
132 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph
133 | paragraph_lens = []
134 | all_word_indices = []
135 | for paragraph in sep_indices:
136 | # claim sentence: [CLS] token1 token2 ... tokenk
137 | claim_word_indices = torch.arange(0, paragraph[0])
138 | if "roberta" in model_name: # HuggingFace RoBERTa puts two consecutive </s> tokens between the claim and the paragraph; skip the first one
139 | paragraph = paragraph[1:]
140 | # each sentence: [SEP] token1 token2 ... tokenk, the last [SEP] in the paragraph is ditched.
141 | sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)]
142 |
143 | # KGAT requires claim sentence, so add it back.
144 | word_indices = [claim_word_indices] + sentence_word_indices
145 |
146 | paragraph_lens.append(len(word_indices))
147 | all_word_indices.extend(word_indices)
148 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
149 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
150 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
151 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
152 | mask = (indices_by_batch>=0)
153 |
154 | return batch_indices.long(), indices_by_batch.long(), mask.long()
155 |
156 | def post_process_stance(rationale_json, stance_json):
157 | assert(len(rationale_json) == len(stance_json))
158 | for stance_pred, rationale_pred in zip(stance_json, rationale_json):
159 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
160 | for doc_id, pred in rationale_pred["evidence"].items():
161 | if len(pred) == 0:
162 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
163 | return stance_json
164 |
165 | if __name__ == "__main__":
166 | argparser = argparse.ArgumentParser(description="Run joint rationale selection and stance prediction with a trained domain-adaptation KGAT model")
167 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
168 | argparser.add_argument('--scifact_corpus', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
169 | #argparser.add_argument('--fever_train', type=str)
170 | #argparser.add_argument('--scifact_train', type=str)
171 | argparser.add_argument('--scifact_test', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
172 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
173 | argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
174 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
175 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
176 | argparser.add_argument('--checkpoint', type=str, default = "domain_adaptation_roberta_joint_paragraph_kgat")
177 | argparser.add_argument('--log_file', type=str, default = "domain_adaptation_joint_paragraph_performances.log")
178 | argparser.add_argument('--update_step', type=int, default=10)
179 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
180 | argparser.add_argument('--scifact_k', type=int, default=0)
181 | argparser.add_argument('--kernel', type=int, default=6)
182 | argparser.add_argument('--prediction', type=str, default = "prediction_domain_adaptation_roberta_joint_paragraph_kgat.jsonl")
183 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
184 |
185 | reset_random_seed(12345)
186 |
187 | args = argparser.parse_args()
188 |
189 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
190 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
191 |
192 | params = vars(args)
193 |
194 | for k,v in params.items():
195 | print(k,v)
196 |
197 | dev_set = SciFactParagraphBatchDataset(args.scifact_corpus, args.scifact_test,
198 | sep_token = tokenizer.sep_token, k = args.scifact_k, train=False)
199 | print("Loaded dataset!")
200 |
201 | model = DomainAdaptationJointParagraphKGATClassifier(args.repfile, args.bert_dim,
202 | args.dropout, kernel = args.kernel).to(device)
203 |
204 | model.load_state_dict(torch.load(args.checkpoint))
205 |
206 | reset_random_seed(12345)
207 | rationale_predictions, stance_preds = predict(model, dev_set)
208 | rationale_json = rationale2json(dev_set.samples, rationale_predictions)
209 | stance_json = stance2json(dev_set.samples, stance_preds)
210 | stance_json = post_process_stance(rationale_json, stance_json)
211 | merged_json = merge_json(rationale_json, stance_json)
212 |
213 | with jsonlines.open(args.prediction, 'w') as output:
214 | for result in merged_json:
215 | output.write(result)
216 |
217 | data = GoldDataset(args.scifact_corpus, args.dataset)
218 | predictions = PredictedDataset(data, args.prediction)
219 | res = metrics.compute_metrics(predictions)
220 | params["evaluation"] = res
221 | with jsonlines.open(args.log_file, mode='a') as writer:
222 | writer.write(params)
--------------------------------------------------------------------------------
/domain_adaptation_joint_paragraph_predict.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 | import sys
7 |
8 | import functools
9 | print = functools.partial(print, flush=True)
10 |
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import Dataset, DataLoader, Subset
14 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
15 | from tqdm import tqdm
16 | from typing import List
17 | from sklearn.metrics import f1_score, precision_score, recall_score
18 |
19 | import math
20 | import random
21 | import numpy as np
22 |
23 |
24 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
25 | from paragraph_model_dynamic import DomainAdaptationJointParagraphClassifier
26 | from dataset import FEVERParagraphBatchDataset, SciFactParagraphBatchDataset, SciFact_FEVER_Dataset, Multiple_SciFact_Dataset
27 |
28 | import logging
29 |
30 | from lib.data import GoldDataset, PredictedDataset
31 | from lib import metrics
32 |
33 |
34 | def reset_random_seed(seed):
35 | np.random.seed(seed)
36 | random.seed(seed)
37 | torch.manual_seed(seed)
38 |
39 | def batch_rationale_label(labels, padding_idx = 2):
40 | max_sent_len = max([len(label) for label in labels])
41 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
42 | label_list = []
43 | for i, label in enumerate(labels):
44 | for j, evid in enumerate(label):
45 | label_matrix[i,j] = int(evid)
46 | label_list.append([int(evid) for evid in label])
47 | return label_matrix.long(), label_list
48 |
49 | def predict(model, dataset):
50 | model.eval()
51 | rationale_predictions = []
52 | stance_preds = []
53 |
54 | def remove_dummy(rationale_out):
55 | return [out[1:] for out in rationale_out]
56 |
57 | with torch.no_grad():
58 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
59 | encoded_dict = encode(tokenizer, batch)
60 | domain_indices = batch["dataset"].to(device)
61 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
62 | tokenizer.sep_token_id, args.repfile)
63 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
64 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
65 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices, domain_indices)
66 | stance_preds.extend(stance_out)
67 | rationale_predictions.extend(remove_dummy(rationale_out))
68 |
69 | return rationale_predictions, stance_preds
70 |
71 | def encode(tokenizer, batch, max_sent_len = 512):
72 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
73 | def longest_first_truncation(sentences, objective):
74 | sent_lens = [len(sent) for sent in sentences]
75 | while np.sum(sent_lens) > objective:
76 | max_position = np.argmax(sent_lens)
77 | sent_lens[max_position] -= 1
78 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
79 |
80 | all_paragraphs = []
81 | for paragraph in input_ids:
82 | valid_paragraph = paragraph[paragraph != pad_token_id]
83 | if valid_paragraph.size(0) <= max_length:
84 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
85 | else:
86 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
87 | idx_by_sentence = []
88 | prev_idx = 0
89 | for idx in sep_token_idx:
90 | idx_by_sentence.append(paragraph[prev_idx:idx])
91 | prev_idx = idx
92 | objective = max_length - 1 - len(idx_by_sentence[0]) # The last sep_token left out
93 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
94 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
95 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
96 |
97 | return torch.cat(all_paragraphs, 0)
98 |
99 | inputs = zip(batch["claim"], batch["paragraph"])
100 | encoded_dict = tokenizer.batch_encode_plus(
101 | inputs,
102 | pad_to_max_length=True,add_special_tokens=True,
103 | return_tensors='pt')
104 | if encoded_dict['input_ids'].size(1) > max_sent_len:
105 | if 'token_type_ids' in encoded_dict:
106 | encoded_dict = {
107 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
108 | tokenizer.sep_token_id, tokenizer.pad_token_id),
109 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
110 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
111 | }
112 | else:
113 | encoded_dict = {
114 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
115 | tokenizer.sep_token_id, tokenizer.pad_token_id),
116 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
117 | }
118 |
119 | return encoded_dict
120 |
121 |
122 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
123 | """
124 | Compute the token indices matrix of the BERT output.
125 | input_ids: (batch_size, paragraph_len)
126 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
127 | bert_out: (batch_size, paragraph_len,BERT_dim)
128 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
129 | """
130 | padding_idx = -1
131 | sep_tokens = (input_ids == sep_token_id).bool()
132 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
133 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
134 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
135 | paragraph_lens = []
136 | all_word_indices = []
137 | for paragraph in sep_indices:
138 | if "roberta" in model_name:
139 | paragraph = paragraph[1:]
140 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
141 | paragraph_lens.append(len(word_indices))
142 | all_word_indices.extend(word_indices)
143 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
144 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
145 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
146 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
147 | mask = (indices_by_batch>=0)
148 |
149 | return batch_indices.long(), indices_by_batch.long(), mask.long()
150 |
151 | def post_process_stance(rationale_json, stance_json):
152 | assert(len(rationale_json) == len(stance_json))
153 | for stance_pred, rationale_pred in zip(stance_json, rationale_json):
154 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
155 | for doc_id, pred in rationale_pred["evidence"].items():
156 | if len(pred) == 0:
157 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
158 | return stance_json
159 |
160 | if __name__ == "__main__":
161 | argparser = argparse.ArgumentParser(description="Run joint rationale selection and stance prediction with a trained domain-adaptation model")
162 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained transformer model name or path")
163 | argparser.add_argument('--corpus', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl")
164 | argparser.add_argument('--scifact_train', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl")
165 | #argparser.add_argument('--scifact_train', type=str)
166 | argparser.add_argument('--scifact_test', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl")
167 | argparser.add_argument('--dataset', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev.jsonl")
168 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
169 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
170 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
171 | argparser.add_argument('--checkpoint', type=str, default = "domain_adaptation_roberta_joint_paragraph_fine_tune.model")
172 | argparser.add_argument('--log_file', type=str, default = "domain_adaptation_joint_paragraph_prediction.log")
173 | argparser.add_argument('--prediction', type=str, default = "prediction_domain_adaptation_roberta_joint_paragraph_dynamic_fine_tune.jsonl")
174 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
175 | argparser.add_argument('--k', type=int, default=0)
176 |
177 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
178 | reset_random_seed(12345)
179 | args = argparser.parse_args()
180 |
181 | device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
182 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
183 |
184 | if args.scifact_train:
185 | train = True
186 | #assert args.repfile is not None, "Word embedding file required for training."
187 | else:
188 | train = False
189 | if args.scifact_test:
190 | test = True
191 | else:
192 | test = False
193 |
194 | params = vars(args)
195 |
196 | for k,v in params.items():
197 | print(k,v)
198 |
199 | dev_set = SciFactParagraphBatchDataset(args.corpus, args.scifact_test,
200 | sep_token = tokenizer.sep_token, k = args.k, downsample_n = 0, train = False)
201 |
202 | print("Loaded dataset!")
203 |
204 | model = DomainAdaptationJointParagraphClassifier(args.repfile, args.bert_dim,
205 | args.dropout).to(device)
206 | model.load_state_dict(torch.load(args.checkpoint))
207 | print("Loaded saved model.")
208 |
209 | rationale_predictions, stance_preds = predict(model, dev_set)
210 | rationale_json = rationale2json(dev_set.samples, rationale_predictions)
211 | stance_json = stance2json(dev_set.samples, stance_preds)
212 | stance_json = post_process_stance(rationale_json, stance_json)
213 | merged_json = merge_json(rationale_json, stance_json)
214 |
215 | with jsonlines.open(args.prediction, 'w') as output:
216 | for result in merged_json:
217 | output.write(result)
218 |
219 | data = GoldDataset(args.corpus, args.dataset)
220 | predictions = PredictedDataset(data, args.prediction)
221 | res = metrics.compute_metrics(predictions)
222 | params["evaluation"] = res
223 | with jsonlines.open(args.log_file, mode='a') as writer:
224 | writer.write(params)
--------------------------------------------------------------------------------
/lib/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Code to represent a dataset release.
3 | """
4 |
5 | from enum import Enum
6 | import json
7 | import copy
8 | from dataclasses import dataclass
9 | from typing import Dict, List, Tuple
10 |
11 | ####################
12 |
13 | # Utility functions and enums.
14 |
15 |
16 | def load_jsonl(fname):
17 | return [json.loads(line) for line in open(fname)]
18 |
19 |
20 | class Label(Enum):
21 | SUPPORTS = 1
22 | NEI = 0
23 | REFUTES = -1
24 |
25 |
26 | def make_label(label_str, allow_NEI=True):
27 | lookup = {"SUPPORT": Label.SUPPORTS,
28 | "NOT_ENOUGH_INFO": Label.NEI,
29 | "CONTRADICT": Label.REFUTES}
30 |
31 | res = lookup[label_str]
32 | if (not allow_NEI) and (res is Label.NEI):
33 | raise ValueError("An NEI was given.")
34 |
35 | return res
36 |
37 |
38 | ####################
39 |
40 | # Representations for the corpus and abstracts.
41 |
42 | @dataclass(repr=False, frozen=True)
43 | class Document:
44 | id: str
45 | title: str
46 | sentences: Tuple[str]
47 |
48 | def __repr__(self):
49 | return self.title.upper() + "\n" + "\n".join(["- " + entry for entry in self.sentences])
50 |
51 | def __lt__(self, other):
52 | return self.title.__lt__(other.title)
53 |
54 | def dump(self):
55 | res = {"doc_id": self.id,
56 | "title": self.title,
57 | "abstract": self.sentences,
58 | "structured": self.is_structured()}
59 | return json.dumps(res)
60 |
61 |
62 | @dataclass(repr=False, frozen=True)
63 | class Corpus:
64 | """
65 | A Corpus is just a collection of `Document` objects, with methods to look up
66 | a single document.
67 | """
68 | documents: List[Document]
69 |
70 | def __repr__(self):
71 | return f"Corpus of {len(self.documents)} documents."
72 |
73 | def __getitem__(self, i):
74 | "Get document by index in list."
75 | return self.documents[i]
76 |
77 | def get_document(self, doc_id):
78 | "Get document by ID."
79 | res = [x for x in self.documents if x.id == doc_id]
80 | assert len(res) == 1
81 | return res[0]
82 |
83 | @classmethod
84 | def from_jsonl(cls, corpus_file):
85 | corpus = load_jsonl(corpus_file)
86 | documents = []
87 | for entry in corpus:
88 | doc = Document(entry["doc_id"], entry["title"], entry["abstract"])
89 | documents.append(doc)
90 |
91 | return cls(documents)
92 |
93 |
94 | ####################
95 |
96 | # Gold dataset.
97 |
98 | class GoldDataset:
99 | """
100 | Class to represent a gold dataset, including corpus and claims.
101 | """
102 | def __init__(self, corpus_file, data_file):
103 | self.corpus = Corpus.from_jsonl(corpus_file)
104 | self.claims = self._read_claims(data_file)
105 |
106 | def __repr__(self):
107 | msg = f"{self.corpus.__repr__()} {len(self.claims)} claims."
108 | return msg
109 |
110 | def __getitem__(self, i):
111 | return self.claims[i]
112 |
113 | def _read_claims(self, data_file):
114 | "Read claims from file."
115 | examples = load_jsonl(data_file)
116 | res = []
117 | for this_example in examples:
118 | entry = copy.deepcopy(this_example)
119 | entry["release"] = self
120 | entry["cited_docs"] = [self.corpus.get_document(doc)
121 | for doc in entry["cited_doc_ids"]]
122 | assert len(entry["cited_docs"]) == len(entry["cited_doc_ids"])
123 | del entry["cited_doc_ids"]
124 | res.append(Claim(**entry))
125 |
126 | res = sorted(res, key=lambda x: x.id)
127 | return res
128 |
129 | def get_claim(self, example_id):
130 | "Get a single claim by ID."
131 | keep = [x for x in self.claims if x.id == example_id]
132 | assert len(keep) == 1
133 | return keep[0]
134 |
135 |
136 | @dataclass
137 | class EvidenceAbstract:
138 | "A single evidence abstract."
139 | id: int
140 | label: Label
141 | rationales: List[List[int]]
142 |
143 |
144 | @dataclass(repr=False)
145 | class Claim:
146 | """
147 | Class representing a single claim, with a pointer back to the dataset.
148 | """
149 | id: int
150 | claim: str
151 | evidence: Dict[int, EvidenceAbstract]
152 | cited_docs: List[Document]
153 | release: GoldDataset
154 |
155 | def __post_init__(self):
156 | self.evidence = self._format_evidence(self.evidence)
157 |
158 | @staticmethod
159 | def _format_evidence(evidence_dict):
160 | # This function is needed because the data schema is designed so that
161 | # each rationale can have its own support label. In this dataset, however,
162 | # all rationales for a given claim / abstract pair share the same label.
163 | # So, we store the label at the "abstract level" rather than the
164 | # "rationale level".
165 | res = {}
166 | for doc_id, rationales in evidence_dict.items():
167 | doc_id = int(doc_id)
168 | labels = [x["label"] for x in rationales]
169 | if len(set(labels)) > 1:
170 | msg = ("In this SciFact release, each claim / abstract pair "
171 | "should only have one label.")
172 | raise Exception(msg)
173 | label = make_label(labels[0])
174 | rationale_sents = [x["sentences"] for x in rationales]
175 | this_abstract = EvidenceAbstract(doc_id, label, rationale_sents)
176 | res[doc_id] = this_abstract
177 |
178 | return res
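   |
   | # Example (assumed input): {"3": [{"sentences": [0, 1], "label": "SUPPORT"},
   | # {"sentences": [4], "label": "SUPPORT"}]} collapses to
   | # {3: EvidenceAbstract(id=3, label=Label.SUPPORTS, rationales=[[0, 1], [4]])}.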
179 |
180 | def __repr__(self):
181 | msg = f"Example {self.id}: {self.claim}"
182 | return msg
183 |
184 | def pretty_print(self, evidence_doc_id=None, file=None):
185 | "Pretty-print the claim, together with all evidence."
186 | msg = self.__repr__()
187 | print(msg, file=file)
188 | # Print the evidence
189 | print("\nEvidence sets:", file=file)
190 | for doc_id, evidence in self.evidence.items():
191 | # If asked for a specific evidence doc, only show that one.
192 | if evidence_doc_id is not None and doc_id != evidence_doc_id:
193 | continue
194 | print("\n" + 20 * "#" + "\n", file=file)
195 | ev_doc = self.release.corpus.get_document(doc_id)
196 | print(f"{doc_id}: {evidence.label.name}", file=file)
197 | for i, sents in enumerate(evidence.rationales):
198 | print(f"Set {i}:", file=file)
199 | kept = [sent for i, sent in enumerate(ev_doc.sentences) if i in sents]
200 | for entry in kept:
201 | print(f"\t- {entry}", file=file)
202 |
203 |
204 | ####################
205 |
206 | # Predicted dataset.
207 |
208 | class PredictedDataset:
209 | """
210 | Class to handle predictions, with a pointer back to the gold data.
211 | """
212 | def __init__(self, gold, prediction_file):
213 | """
214 | Takes a GoldDataset, as well as files with rationale and label
215 | predictions.
216 | """
217 | self.gold = gold
218 | self.predictions = self._read_predictions(prediction_file)
219 |
220 | def __getitem__(self, i):
221 | return self.predictions[i]
222 |
223 | def __repr__(self):
224 | msg = f"Predictions for {len(self.predictions)} claims."
225 | return msg
226 |
227 | def _read_predictions(self, prediction_file):
228 | res = []
229 |
230 | predictions = load_jsonl(prediction_file)
231 | for pred in predictions:
232 | prediction = self._parse_prediction(pred)
233 | res.append(prediction)
234 |
235 | return res
236 |
237 | def _parse_prediction(self, pred_dict):
238 | claim_id = pred_dict["id"]
239 | predicted_evidence = pred_dict["evidence"]
240 |
241 | res = {}
242 |
243 | # Predictions should never be NEI; there should only be predictions for
244 | # the abstracts that contain evidence.
245 | for key, this_prediction in predicted_evidence.items():
246 | label = this_prediction["label"]
247 | evidence = this_prediction["sentences"]
248 | pred = PredictedAbstract(int(key),
249 | make_label(label, allow_NEI=False),
250 | evidence)
251 | res[int(key)] = pred
252 |
253 | gold_claim = self.gold.get_claim(claim_id)
254 | return ClaimPredictions(claim_id, res, gold_claim)
255 |
256 |
257 | @dataclass
258 | class PredictedAbstract:
259 | # For predictions, we have a single list of rationale sentences instead of a
260 | # list of separate rationales (see paper for details).
261 | abstract_id: int
262 | label: Label
263 | rationale: List
264 |
265 |
266 | @dataclass
267 | class ClaimPredictions:
268 | claim_id: int
269 | predictions: Dict[int, PredictedAbstract]
270 | gold: Claim = None # For backward compatibility, default this to None.
271 |
272 | def __repr__(self):
273 | msg = f"Predictions for {self.claim_id}: {self.gold.claim}"
274 | return msg
275 |
276 | def pretty_print(self, evidence_doc_id=None, file=None):
277 | msg = self.__repr__()
278 | print(msg, file=file)
279 | # Print the evidence
280 | print("\nEvidence sets:", file=file)
281 | for doc_id, prediction in self.predictions.items():
282 | # If asked for a specific evidence doc, only show that one.
283 | if evidence_doc_id is not None and doc_id != evidence_doc_id:
284 | continue
285 | print("\n" + 20 * "#" + "\n", file=file)
286 | ev_doc = self.gold.release.corpus.get_document(doc_id)
287 | print(f"{doc_id}: {prediction.label.name}", file=file)
288 | # Print the predicted rationale.
289 | sents = prediction.rationale
290 | kept = [sent for i, sent in enumerate(ev_doc.sentences) if i in sents]
291 | for entry in kept:
292 | print(f"\t- {entry}", file=file)
293 |
--------------------------------------------------------------------------------
/lib/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluating abstract-level and sentence-level performance as described in the
3 | paper.
4 | """
5 |
6 | import warnings
7 |
8 | from .data import Label
9 | from collections import Counter
10 | import pandas as pd
11 |
12 |
13 | # Cap on how many abstract sentences can be returned.
14 | MAX_ABSTRACT_SENTS = 3
15 |
16 |
17 | def safe_divide(num, denom):
18 | if denom == 0:
19 | return 0
20 | else:
21 | return num / denom
22 |
23 |
24 | def compute_f1(counts, difficulty=None):
25 | correct_key = "correct" if difficulty is None else f"correct_{difficulty}"
26 | precision = safe_divide(counts[correct_key], counts["retrieved"])
27 | recall = safe_divide(counts[correct_key], counts["relevant"])
28 | f1 = safe_divide(2 * precision * recall, precision + recall)
29 | return {"precision": precision, "recall": recall, "f1": f1}
30 |
31 |
32 | ####################
33 |
34 | # Abstract-level evaluation
35 |
36 | def contains_evidence(predicted, gold):
37 | # If any of gold are contained in predicted, we're good.
38 | for gold_rat in gold:
39 | if gold_rat.issubset(predicted):
40 | return True
41 | # If we get to the end, didn't find one.
42 | return False
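   |
   | # Example (assumed sets): contains_evidence({0, 1, 3}, [{0, 1}, {5}]) is True,
   | # since the gold rationale {0, 1} is fully contained in the prediction;
   | # contains_evidence({0, 3}, [{0, 1}]) is False.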
43 |
44 |
45 | def is_correct(doc_id, doc_pred, gold):
46 | pred_rationales = doc_pred.rationale[:MAX_ABSTRACT_SENTS]
47 |
48 | # If it's not an evidence document, we lose.
49 | if doc_id not in gold.evidence:
50 | return False, False
51 |
52 | # If the label's wrong, we lose.
53 | gold_label = gold.evidence[doc_id].label
54 | if doc_pred.label != gold_label:
55 | return False, False
56 |
57 | gold_rationales = [set(x) for x in gold.evidence[doc_id].rationales]
58 | good_rationalized = contains_evidence(set(pred_rationales), gold_rationales)
59 | good_label_only = True
60 | return good_label_only, good_rationalized
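   |
   | # In other words: a correctly-labeled evidence abstract always gets
   | # "label_only" credit; it gets "rationalized" credit only if the first
   | # MAX_ABSTRACT_SENTS predicted sentences cover some full gold rationale.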
61 |
62 |
63 | def update_counts_abstract(pred, gold, counts_abstract):
64 | counts_abstract["relevant"] += len(gold.evidence)
65 | for doc_id, doc_pred in pred.predictions.items():
66 | # If it's NEI, doesn't count one way or the other.
67 | if doc_pred.label == Label.NEI:
68 | continue
69 | counts_abstract["retrieved"] += 1
70 |
71 | good_label_only, good_rationalized = is_correct(doc_id, doc_pred, gold)
72 | if good_label_only:
73 | counts_abstract["correct_label_only"] += 1
74 | if good_rationalized:
75 | counts_abstract["correct_rationalized"] += 1
76 |
77 | return counts_abstract
78 |
79 |
80 | ####################
81 |
82 | # Sentence-level evaluation
83 |
84 | def count_rationale_sents(predicted, gold):
85 | n_correct = 0
86 |
87 | for ix in predicted:
88 | gold_sets = [entry for entry in gold if ix in entry]
89 | assert len(gold_sets) < 2 # A sentence can't belong to two gold rationales.
90 | # If it's not in a gold set, no dice.
91 | if len(gold_sets) == 0:
92 | continue
93 | # If it's in a gold set, make sure the rest got retrieved.
94 | gold_set = gold_sets[0]
95 | if gold_set.issubset(predicted):
96 | n_correct += 1
97 |
98 | return n_correct
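   |
   | # Example (assumed rationales): with gold = [{0, 1}, {4}], predicted = {0, 1, 4, 6}
   | # returns 3 (sentences 0, 1, and 4 all come with their full gold set retrieved),
   | # while predicted = {0, 4} returns 1 (sentence 0's partner 1 is missing).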
99 |
100 |
101 | def count_correct(doc_id, doc_pred, gold):
102 | # If not an evidence doc, no good.
103 | if doc_id not in gold.evidence:
104 | return 0, 0
105 |
106 | # Count the number of rationale sentences we get credit for.
107 | gold_rationales = [set(x) for x in gold.evidence[doc_id].rationales]
108 | n_correct = count_rationale_sents(set(doc_pred.rationale), gold_rationales)
109 |
110 | gold_label = gold.evidence[doc_id].label
111 |
112 | n_correct_selection = n_correct
113 | correct_label = int(doc_pred.label == gold_label)
114 | n_correct_label = correct_label * n_correct
115 |
116 | return n_correct_selection, n_correct_label
117 |
118 |
119 | def update_counts_sentence(pred, gold, counts_sentence):
120 | # Update the gold evidence sentences.
121 | for gold_doc in gold.evidence.values():
122 | counts_sentence["relevant"] += sum([len(x) for x in gold_doc.rationales])
123 |
124 | for doc_id, doc_pred in pred.predictions.items():
125 | # If it's NEI, skip it.
126 | if doc_pred.label == Label.NEI:
127 | continue
128 |
129 | counts_sentence["retrieved"] += len(doc_pred.rationale)
130 | n_correct_selection, n_correct_label = count_correct(doc_id, doc_pred, gold)
131 | counts_sentence["correct_selection"] += n_correct_selection
132 | counts_sentence["correct_label"] += n_correct_label
133 |
134 | return counts_sentence
135 |
136 |
137 | ####################
138 |
139 | # Make sure rationales aren't too long.
140 |
141 | def check_rationale_lengths(preds):
142 | bad = []
143 | for pred in preds:
144 | claim_id = pred.claim_id
145 | predictions = pred.predictions
146 | for doc_key, prediction in predictions.items():
147 | n_rationales = len(prediction.rationale)
148 | if n_rationales > MAX_ABSTRACT_SENTS:
149 | to_append = {"claim_id": claim_id, "abstract": doc_key, "n_rationales": n_rationales}
150 | bad.append(to_append)
151 | if bad:
152 | bad = pd.DataFrame(bad)
153 | msg = (f"\nRationales with more than {MAX_ABSTRACT_SENTS} sentences found.\n"
154 | f"The first 3 will be used for abstract-level evaluation\n\n"
155 | f"{bad.__repr__()}")
156 | warnings.warn(msg)
157 | print()
158 |
159 |
160 | ################################################################################
161 |
162 | def compute_metrics(preds):
163 | """
164 | Compute pipeline metrics based on dataset of predictions.
165 | """
166 | counts_abstract = Counter()
167 | counts_sentence = Counter()
168 |
169 | check_rationale_lengths(preds)
170 |
171 | for pred in preds:
172 | gold = preds.gold.get_claim(pred.claim_id)
173 | counts_abstract = update_counts_abstract(pred, gold, counts_abstract)
174 | counts_sentence = update_counts_sentence(pred, gold, counts_sentence)
175 |
176 | return {"sentence_selection": compute_f1(counts_sentence, "selection"),
177 | "sentence_label": compute_f1(counts_sentence, "label"),
178 | "abstract_label_only": compute_f1(counts_abstract, "label_only"),
179 | "abstract_rationalized": compute_f1(counts_abstract, "rationalized")
180 | }
181 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.5.3
2 | sent2vec==0.0.0
3 | tqdm==4.50.2
4 | jsonlines==1.2.0
5 | torch==1.6.0
6 | numpy==1.19.1
7 | transformers==2.6.0
8 | nltk==3.5
9 | pandas==1.1.3
10 | dataclasses==0.7
11 | scikit_learn==0.24.1
12 |
--------------------------------------------------------------------------------
/scifact_joint_paragraph_dynamic_prediction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
11 | from tqdm import tqdm
12 | from typing import List
13 | from sklearn.metrics import f1_score, precision_score, recall_score
14 |
15 | import math
16 | import random
17 | import numpy as np
18 |
19 |
20 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
21 | from paragraph_model_dynamic import JointParagraphClassifier
22 | from dataset import SciFactParagraphBatchDataset
23 |
24 | import logging
25 |
26 | from lib.data import GoldDataset, PredictedDataset
27 | from lib import metrics
28 |
29 | def reset_random_seed(seed):
30 | np.random.seed(seed)
31 | random.seed(seed)
32 | torch.manual_seed(seed)
33 |
34 | def predict(model, dataset):
35 | model.eval()
36 | rationale_predictions = []
37 | stance_preds = []
38 |
39 | def remove_dummy(rationale_out):
40 | return [out[1:] for out in rationale_out]
41 |
42 | with torch.no_grad():
43 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
44 | encoded_dict = encode(tokenizer, batch)
45 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
46 | tokenizer.sep_token_id, args.repfile)
47 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
48 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
49 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices)
50 | stance_preds.extend(stance_out)
51 | rationale_predictions.extend(remove_dummy(rationale_out))
52 |
53 | return rationale_predictions, stance_preds
54 |
55 |
56 | def encode(tokenizer, batch, max_sent_len = 512):
57 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
58 | def longest_first_truncation(sentences, objective):
59 | sent_lens = [len(sent) for sent in sentences]
60 | while np.sum(sent_lens) > objective:
61 | max_position = np.argmax(sent_lens)
62 | sent_lens[max_position] -= 1
63 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
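   |
   | # Example (assumed lengths): sentences of lengths [10, 50, 30] with objective 60
   | # are trimmed one token at a time from the current longest, ending at [10, 25, 25].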
64 |
65 | all_paragraphs = []
66 | for paragraph in input_ids:
67 | valid_paragraph = paragraph[paragraph != pad_token_id]
68 | if valid_paragraph.size(0) <= max_length:
69 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
70 | else:
71 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
72 | idx_by_sentence = []
73 | prev_idx = 0
74 | for idx in sep_token_idx:
75 | idx_by_sentence.append(paragraph[prev_idx:idx])
76 | prev_idx = idx
77 | objective = max_length - 1 - len(idx_by_sentence[0]) # Token budget for the sentences; one slot is reserved for the final sep_token appended back below.
78 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
79 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
80 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
81 |
82 | return torch.cat(all_paragraphs, 0)
83 |
84 | inputs = zip(batch["claim"], batch["paragraph"])
85 | encoded_dict = tokenizer.batch_encode_plus(
86 | inputs,
87 | pad_to_max_length=True,add_special_tokens=True,
88 | return_tensors='pt')
89 | if encoded_dict['input_ids'].size(1) > max_sent_len:
90 | if 'token_type_ids' in encoded_dict:
91 | encoded_dict = {
92 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
93 | tokenizer.sep_token_id, tokenizer.pad_token_id),
94 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
95 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
96 | }
97 | else:
98 | encoded_dict = {
99 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
100 | tokenizer.sep_token_id, tokenizer.pad_token_id),
101 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
102 | }
103 | return encoded_dict
104 |
105 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
106 | """
107 | Compute the token indices matrix of the BERT output.
108 | input_ids: (batch_size, paragraph_len)
109 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
110 | bert_out: (batch_size, paragraph_len,BERT_dim)
111 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
112 | """
113 | padding_idx = -1
114 | sep_tokens = (input_ids == sep_token_id).bool()
115 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
116 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
117 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
118 | paragraph_lens = []
119 | all_word_indices = []
120 | for paragraph in sep_indices:
121 | if "roberta" in model_name:
122 | paragraph = paragraph[1:]
123 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
124 | paragraph_lens.append(len(word_indices))
125 | all_word_indices.extend(word_indices)
126 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
127 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
128 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
129 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
130 | mask = (indices_by_batch>=0)
131 |
132 | return batch_indices.long(), indices_by_batch.long(), mask.long()
133 |
134 | def post_process_stance(rationale_json, stance_json):
135 | assert(len(rationale_json) == len(stance_json))
136 | for stance_pred, rationale_pred in zip(stance_json, rationale_json):
137 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
138 | for doc_id, pred in rationale_pred["evidence"].items():
139 | if len(pred) == 0:
140 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
141 | return stance_json
142 |
143 | def post_process_rationale_score(rationale_scores, max_positive=3): # Optional cap on positive rationale sentences per paragraph; did not seem to help in practice.
144 | def process_rationale_score(paragraph_rationale_scores):
145 | paragraph_rationale_scores = np.array(paragraph_rationale_scores)
146 | if np.sum(paragraph_rationale_scores > 0.5) > max_positive:
147 | output = np.zeros(paragraph_rationale_scores.shape)
148 | positive_indices = np.argsort(paragraph_rationale_scores)[::-1][:max_positive]
149 | output[positive_indices] = 1
150 | else:
151 | output = (paragraph_rationale_scores > 0.5).astype(int)
152 | return output.tolist()
153 | return [process_rationale_score(paragraph_rationale_scores) for paragraph_rationale_scores in rationale_scores]
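   |
   | # Example (assumed scores): [[0.9, 0.8, 0.7, 0.6, 0.2]] has four scores above
   | # 0.5, so with max_positive=3 only the top three survive -> [[1, 1, 1, 0, 0]].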
154 |
155 | if __name__ == "__main__":
156 | argparser = argparse.ArgumentParser(description="Predict rationales and stance labels with the joint paragraph model, and optionally evaluate them")
157 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
158 | argparser.add_argument('--corpus_file', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
159 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
160 | argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
161 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
162 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
163 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
164 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_joint_paragraph.model")
165 | argparser.add_argument('--batch_size', type=int, default=25)
166 | argparser.add_argument('--k', type=int, default=0)
167 | argparser.add_argument('--rationale_selection', type=str)
168 | argparser.add_argument('--label_prediction', type=str)
169 | argparser.add_argument('--prediction', type=str)
170 | argparser.add_argument('--log_file', type=str, default = "prediction_dynamic.log")
171 | argparser.add_argument('--evaluate', action='store_true')
172 |
173 |
174 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
175 |
176 | reset_random_seed(12345)
177 |
178 | args = argparser.parse_args()
179 | params = vars(args)
180 |
181 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
182 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
183 |
184 | dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file,
185 | sep_token = tokenizer.sep_token, k = args.k, train=False)
186 |
187 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
188 | args.dropout).to(device)
189 |
190 | model.load_state_dict(torch.load(args.checkpoint))
191 | print("Loaded saved model.")
192 |
193 | reset_random_seed(12345)
194 | rationale_predictions, stance_preds = predict(model, dev_set)
195 | rationale_json = rationale2json(dev_set.samples, rationale_predictions)
196 | stance_json = stance2json(dev_set.samples, stance_preds)
197 | stance_json = post_process_stance(rationale_json, stance_json)
198 | merged_json = merge_json(rationale_json, stance_json)
199 | if args.rationale_selection is not None:
200 | with jsonlines.open(args.rationale_selection, 'w') as output:
201 | for result in rationale_json:
202 | output.write(result)
203 | if args.label_prediction is not None:
204 | with jsonlines.open(args.label_prediction, 'w') as output:
205 | for result in stance_json:
206 | output.write(result)
207 | if args.prediction is not None:
208 | with jsonlines.open(args.prediction, 'w') as output:
209 | for result in merged_json:
210 | output.write(result)
211 | if args.evaluate:
212 | data = GoldDataset(args.corpus_file, args.dataset)
213 | predictions = PredictedDataset(data, args.prediction)
214 | res = metrics.compute_metrics(predictions)
215 | params["evaluation"] = res
216 | with jsonlines.open(args.log_file, mode='a') as writer:
217 | writer.write(params)
--------------------------------------------------------------------------------
/scifact_joint_paragraph_kgat.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
11 | from tqdm import tqdm
12 | from typing import List
13 | from sklearn.metrics import f1_score, precision_score, recall_score
14 |
15 | import math
16 | import random
17 | import numpy as np
18 |
19 |
20 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
21 | from paragraph_model_kgat import JointParagraphKGATClassifier
22 | from dataset import SciFactParagraphBatchDataset
23 |
24 | import logging
25 |
26 | from lib.data import GoldDataset, PredictedDataset
27 | from lib import metrics
28 |
29 | def schedule_sample_p(epoch, total):
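   | # Scheduled-sampling probability: ramps from 0 at the first epoch to 1 at the
   | # last epoch along a quarter sine wave (e.g. total=10 gives p(0)=0, p(9)=1).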
30 | return np.sin(0.5* np.pi* epoch / (total-1))
31 |
32 | def reset_random_seed(seed):
33 | np.random.seed(seed)
34 | random.seed(seed)
35 | torch.manual_seed(seed)
36 |
37 | def predict(model, dataset):
38 | model.eval()
39 | rationale_predictions = []
40 | stance_preds = []
41 |
42 | def remove_dummy(rationale_out):
43 | return [out[1:] for out in rationale_out]
44 |
45 | with torch.no_grad():
46 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
47 | encoded_dict = encode(tokenizer, batch)
48 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
49 | tokenizer.sep_token_id, args.repfile)
50 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
51 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
52 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices)
53 | stance_preds.extend(stance_out)
54 | rationale_predictions.extend(remove_dummy(rationale_out))
55 |
56 | return rationale_predictions, stance_preds
57 |
58 | def batch_rationale_label(labels, padding_idx = 2):
59 | max_sent_len = max([len(label) for label in labels])
60 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
61 | label_list = []
62 | for i, label in enumerate(labels):
63 | for j, evid in enumerate(label):
64 | label_matrix[i,j] = int(evid)
65 | label_list.append([int(evid) for evid in label])
66 | return label_matrix.long(), label_list
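   |
   | # Example (assumed labels): batch_rationale_label([[1, 0, 1], [0, 1]]) returns
   | # the padded matrix [[1, 0, 1], [0, 1, 2]] (2 = padding_idx) together with the
   | # unpadded list [[1, 0, 1], [0, 1]].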
67 |
68 | def evaluation(model, dataset):
69 | model.eval()
70 | rationale_predictions = []
71 | rationale_labels = []
72 | stance_preds = []
73 | stance_labels = []
74 |
75 | def remove_dummy(rationale_out):
76 | return [out[1:] for out in rationale_out]
77 |
78 | with torch.no_grad():
79 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size*5, shuffle=False)):
80 | encoded_dict = encode(tokenizer, batch)
81 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
82 | tokenizer.sep_token_id, args.repfile)
83 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
84 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
85 | stance_label = batch["stance"].to(device)
86 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
87 | rationale_out, stance_out, rationale_loss, stance_loss = \
88 | model(encoded_dict, transformation_indices, stance_label = stance_label,
89 | rationale_label = padded_rationale_label.to(device))
90 | stance_preds.extend(stance_out)
91 | stance_labels.extend(stance_label.cpu().numpy().tolist())
92 |
93 | rationale_predictions.extend(remove_dummy(rationale_out))
94 | rationale_labels.extend(remove_dummy(rationale_label))
95 |
96 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
97 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
98 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
99 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions))
100 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions))
101 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions))
102 | return stance_f1, stance_precision, stance_recall, rationale_f1, rationale_precision, rationale_recall
103 |
104 |
105 |
106 | def encode(tokenizer, batch, max_sent_len = 512):
107 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
108 | def longest_first_truncation(sentences, objective):
109 | sent_lens = [len(sent) for sent in sentences]
110 | while np.sum(sent_lens) > objective:
111 | max_position = np.argmax(sent_lens)
112 | sent_lens[max_position] -= 1
113 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
114 |
115 | all_paragraphs = []
116 | for paragraph in input_ids:
117 | valid_paragraph = paragraph[paragraph != pad_token_id]
118 | if valid_paragraph.size(0) <= max_length:
119 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
120 | else:
121 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
122 | idx_by_sentence = []
123 | prev_idx = 0
124 | for idx in sep_token_idx:
125 | idx_by_sentence.append(paragraph[prev_idx:idx])
126 | prev_idx = idx
127 | objective = max_length - 1 - len(idx_by_sentence[0]) # Token budget for the sentences; one slot is reserved for the final sep_token appended back below.
128 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
129 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
130 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
131 |
132 | return torch.cat(all_paragraphs, 0)
133 |
134 | inputs = zip(batch["claim"], batch["paragraph"])
135 | encoded_dict = tokenizer.batch_encode_plus(
136 | inputs,
137 | pad_to_max_length=True,add_special_tokens=True,
138 | return_tensors='pt')
139 | if encoded_dict['input_ids'].size(1) > max_sent_len:
140 | if 'token_type_ids' in encoded_dict:
141 | encoded_dict = {
142 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
143 | tokenizer.sep_token_id, tokenizer.pad_token_id),
144 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
145 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
146 | }
147 | else:
148 | encoded_dict = {
149 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
150 | tokenizer.sep_token_id, tokenizer.pad_token_id),
151 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
152 | }
153 |
154 | return encoded_dict
155 |
156 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
157 | """
158 | Advanced indexing: Compute the token indices matrix of the BERT output.
159 | input_ids: (batch_size, paragraph_len)
160 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
161 | bert_out: (batch_size, paragraph_len,BERT_dim)
162 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
163 | """
164 | padding_idx = -1
165 | sep_tokens = (input_ids == sep_token_id).bool()
166 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph
167 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,...,paragraph_len-1 for each paragraph in the batch
168 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph
169 | paragraph_lens = []
170 | all_word_indices = []
171 | for paragraph in sep_indices:
172 | # claim sentence: [CLS] token1 token2 ... tokenk
173 | claim_word_indices = torch.arange(0, paragraph[0])
174 | if "roberta" in model_name: # Huggingface Roberta has ......
175 | paragraph = paragraph[1:]
176 | # each sentence: [SEP] token1 token2 ... tokenk, the last [SEP] in the paragraph is ditched.
177 | sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)]
178 |
179 | # KGAT requires claim sentence, so add it back.
180 | word_indices = [claim_word_indices] + sentence_word_indices
181 |
182 | paragraph_lens.append(len(word_indices))
183 | all_word_indices.extend(word_indices)
184 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
185 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
186 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
187 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
188 | mask = (indices_by_batch>=0)
189 |
190 | return batch_indices.long(), indices_by_batch.long(), mask.long()
191 |
192 | def post_process_stance(rationale_json, stance_json):
193 | assert(len(rationale_json) == len(stance_json))
194 | for stance_pred, rationale_pred in zip(stance_json, rationale_json):
195 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
196 | for doc_id, pred in rationale_pred["evidence"].items():
197 | if len(pred) == 0:
198 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
199 | return stance_json
200 |
201 | if __name__ == "__main__":
202 | argparser = argparse.ArgumentParser(description="Train and evaluate the joint paragraph KGAT model")
203 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
204 | argparser.add_argument('--corpus_file', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
205 | argparser.add_argument('--train_file', type=str, default="/home/xxl190027/scifact_data/claims_train_retrieved.jsonl")
206 | argparser.add_argument('--pre_trained_model', type=str)
207 | #argparser.add_argument('--train_file', type=str)
208 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
209 | argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
210 | argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for BERT-like LM")
211 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate")
212 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
213 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
214 | argparser.add_argument('--epoch', type=int, default=10, help="Training epoch")
215 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
216 | argparser.add_argument('--loss_ratio', type=float, default=3)
217 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_joint_paragraph_kgat.model")
218 | argparser.add_argument('--log_file', type=str, default = "joint_paragraph_roberta_kgat_dynamic_performances.jsonl")
219 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_joint_paragraph_kgat_fine_tune.jsonl")
220 | argparser.add_argument('--update_step', type=int, default=10)
221 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
222 | argparser.add_argument('--k', type=int, default=0)
223 | argparser.add_argument('--kernel', type=int, default=6)
224 | argparser.add_argument('--fine_tune', action='store_true')
225 |
226 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
227 |
228 | reset_random_seed(12345)
229 |
230 | args = argparser.parse_args()
231 |
232 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
233 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
234 |
235 | if args.train_file:
236 | train = True
237 | #assert args.repfile is not None, "Word embedding file required for training."
238 | else:
239 | train = False
240 | if args.test_file:
241 | test = True
242 | else:
243 | test = False
244 |
245 | params = vars(args)
246 |
247 | for k,v in params.items():
248 | print(k,v)
249 |
250 | if train:
251 | train_set = SciFactParagraphBatchDataset(args.corpus_file, args.train_file,
252 | sep_token = tokenizer.sep_token, k = args.k, downsample_n = 2)
253 | dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file,
254 | sep_token = tokenizer.sep_token, k = args.k, downsample_n = 0)
255 |
256 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim,
257 | args.dropout, kernel = args.kernel)#.to(device)
258 |
259 | if args.pre_trained_model is not None:
260 | model.load_state_dict(torch.load(args.pre_trained_model))
261 | if not args.fine_tune:
262 | model.reinitialize()
263 | print("Reinitialized part of the model!")
264 |
265 | model = model.to(device)
266 |
267 | if train:
268 | if args.bert_lr == 0:
269 | print("Freezing BERT weights.")
270 | settings = []
271 | for param in model.bert.parameters():
272 | param.requires_grad = False
273 | else:
274 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
275 | for module in model.extra_modules:
276 | settings.append({'params': module.parameters(), 'lr': args.lr})
277 | optimizer = torch.optim.Adam(settings)
278 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
279 | model.train()
280 |
281 | prev_performance = 0
282 | for epoch in range(args.epoch):
283 | if args.fine_tune:
284 | sample_p = 1
285 | else:
286 | sample_p = schedule_sample_p(epoch, args.epoch)
287 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
288 | for i, batch in enumerate(tq):
289 | encoded_dict = encode(tokenizer, batch)
290 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
291 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
292 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
293 | stance_label = batch["stance"].to(device)
294 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
295 | rationale_out, stance_out, rationale_loss, stance_loss = \
296 | model(encoded_dict, transformation_indices, stance_label = stance_label,
297 | rationale_label = padded_rationale_label.to(device), sample_p = sample_p)
298 | rationale_loss *= args.loss_ratio
299 | loss = rationale_loss + stance_loss
300 | loss.backward()
301 |
302 | if i % args.update_step == args.update_step - 1:
303 | optimizer.step()
304 | optimizer.zero_grad()
305 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}, stance loss: {round(stance_loss.item(), 4)}, rationale loss: {round(rationale_loss.item(), 4)}')
306 | scheduler.step()
307 |
308 | # Evaluation
309 | train_score = evaluation(model, train_set)
310 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
311 |
312 | dev_score = evaluation(model, dev_set)
313 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f, rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
314 |
315 | dev_perf = dev_score[0] * dev_score[3]
316 | if dev_perf >= prev_performance:
317 | torch.save(model.state_dict(), args.checkpoint)
318 | best_state_dict = {k: v.detach().clone() for k, v in model.state_dict().items()} # Snapshot the weights; state_dict() alone returns live references that keep changing in later epochs.
319 | prev_performance = dev_perf
320 | print("New model saved!")
321 | else:
322 | print("Skip saving model.")
323 |
324 |
325 | if test:
326 | if train:
327 | del model
328 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim,
329 | args.dropout, kernel = args.kernel).to(device)
330 | model.load_state_dict(best_state_dict)
331 | print("Testing on the new model.")
332 | else:
333 | model.load_state_dict(torch.load(args.checkpoint))
334 | print("Loaded saved model.")
335 |
336 | rationale_predictions, stance_preds = predict(model, dev_set)
337 | rationale_json = rationale2json(dev_set.samples, rationale_predictions)
338 | stance_json = stance2json(dev_set.samples, stance_preds)
339 | stance_json = post_process_stance(rationale_json, stance_json)
340 | merged_json = merge_json(rationale_json, stance_json)
341 |
342 | with jsonlines.open(args.prediction, 'w') as output:
343 | for result in merged_json:
344 | output.write(result)
345 |
346 | data = GoldDataset(args.corpus_file, args.dataset)
347 | predictions = PredictedDataset(data, args.prediction)
348 | res = metrics.compute_metrics(predictions)
349 | params["evaluation"] = res
350 | with jsonlines.open(args.log_file, mode='a') as writer:
351 | writer.write(params)
--------------------------------------------------------------------------------
/scifact_joint_paragraph_kgat_prediction.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
11 | from tqdm import tqdm
12 | from typing import List
13 | from sklearn.metrics import f1_score, precision_score, recall_score
14 |
15 | import math
16 | import random
17 | import numpy as np
18 |
19 |
20 | from util import arg2param, flatten, stance2json, rationale2json, merge_json
21 | from paragraph_model_kgat import JointParagraphKGATClassifier
22 | from dataset import SciFactParagraphBatchDataset
23 |
24 | import logging
25 |
26 | from lib.data import GoldDataset, PredictedDataset
27 | from lib import metrics
28 |
29 | def reset_random_seed(seed):
30 | np.random.seed(seed)
31 | random.seed(seed)
32 | torch.manual_seed(seed)
33 |
34 | def predict(model, dataset):
35 | model.eval()
36 | rationale_predictions = []
37 | stance_preds = []
38 |
39 | def remove_dummy(rationale_out):
40 | return [out[1:] for out in rationale_out]
41 |
42 | with torch.no_grad():
43 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
44 | encoded_dict = encode(tokenizer, batch)
45 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
46 | tokenizer.sep_token_id, args.repfile)
47 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
48 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
49 | rationale_out, stance_out, _, _ = model(encoded_dict, transformation_indices)
50 | stance_preds.extend(stance_out)
51 | rationale_predictions.extend(remove_dummy(rationale_out))
52 |
53 | return rationale_predictions, stance_preds
54 |
55 |
56 |
57 | def encode(tokenizer, batch, max_sent_len = 512):
58 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
59 | def longest_first_truncation(sentences, objective):
60 | sent_lens = [len(sent) for sent in sentences]
61 | while np.sum(sent_lens) > objective:
62 | max_position = np.argmax(sent_lens)
63 | sent_lens[max_position] -= 1
64 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
65 |
66 | all_paragraphs = []
67 | for paragraph in input_ids:
68 | valid_paragraph = paragraph[paragraph != pad_token_id]
69 | if valid_paragraph.size(0) <= max_length:
70 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
71 | else:
72 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
73 | idx_by_sentence = []
74 | prev_idx = 0
75 | for idx in sep_token_idx:
76 | idx_by_sentence.append(paragraph[prev_idx:idx])
77 | prev_idx = idx
78 | objective = max_length - 1 - len(idx_by_sentence[0]) # Token budget for the sentences; one slot is reserved for the final sep_token appended back below.
79 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
80 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
81 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
82 |
83 | return torch.cat(all_paragraphs, 0)
84 |
85 | inputs = zip(batch["claim"], batch["paragraph"])
86 | encoded_dict = tokenizer.batch_encode_plus(
87 | inputs,
88 | pad_to_max_length=True,add_special_tokens=True,
89 | return_tensors='pt')
90 | if encoded_dict['input_ids'].size(1) > max_sent_len:
91 | if 'token_type_ids' in encoded_dict:
92 | encoded_dict = {
93 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
94 | tokenizer.sep_token_id, tokenizer.pad_token_id),
95 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
96 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
97 | }
98 | else:
99 | encoded_dict = {
100 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
101 | tokenizer.sep_token_id, tokenizer.pad_token_id),
102 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
103 | }
104 | return encoded_dict
105 |
106 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
107 | """
108 | Advanced indexing: Compute the token indices matrix of the BERT output.
109 | input_ids: (batch_size, paragraph_len)
110 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
111 | bert_out: (batch_size, paragraph_len,BERT_dim)
112 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
113 | """
114 | padding_idx = -1
115 | sep_tokens = (input_ids == sep_token_id).bool()
116 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist() # i.e. N_sentences per paragraph
117 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1) # 0,1,2,...,paragraph_len-1 for each paragraph in the batch
118 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens) # indices of SEP tokens per paragraph
119 | paragraph_lens = []
120 | all_word_indices = []
121 | for paragraph in sep_indices:
122 | # claim sentence: [CLS] token1 token2 ... tokenk
123 | claim_word_indices = torch.arange(0, paragraph[0])
124 | if "roberta" in model_name: # Huggingface Roberta has ......
125 | paragraph = paragraph[1:]
126 | # each sentence: [SEP] token1 token2 ... tokenk, the last [SEP] in the paragraph is ditched.
127 | sentence_word_indices = [torch.arange(paragraph[i], paragraph[i+1]) for i in range(paragraph.size(0)-1)]
128 |
129 | # KGAT requires claim sentence, so add it back.
130 | word_indices = [claim_word_indices] + sentence_word_indices
131 |
132 | paragraph_lens.append(len(word_indices))
133 | all_word_indices.extend(word_indices)
134 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
135 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
136 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
137 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
138 | mask = (indices_by_batch>=0)
139 |
140 | return batch_indices.long(), indices_by_batch.long(), mask.long()
141 |
142 | def post_process_stance(rationale_json, stance_json):
143 | assert(len(rationale_json) == len(stance_json))
144 | for stance_pred, rationale_pred in zip(stance_json, rationale_json):
145 | assert(stance_pred["claim_id"] == rationale_pred["claim_id"])
146 | for doc_id, pred in rationale_pred["evidence"].items():
147 | if len(pred) == 0:
148 | stance_pred["labels"][doc_id]["label"] = "NOT_ENOUGH_INFO"
149 | return stance_json
150 |
151 | if __name__ == "__main__":
152 | argparser = argparse.ArgumentParser(description="Predict rationales and stance labels with the joint paragraph KGAT model, and evaluate them")
153 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
154 | argparser.add_argument('--corpus_file', type=str, default="/home/xxl190027/scifact_data/corpus.jsonl")
155 | argparser.add_argument('--test_file', type=str, default="/home/xxl190027/scifact_data/claims_dev_retrieved.jsonl")
156 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
157 | argparser.add_argument('--dataset', type=str, default="/home/xxl190027/scifact_data/claims_dev.jsonl")
158 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
159 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
160 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_joint_paragraph.model")
161 | argparser.add_argument('--batch_size', type=int, default=25)
162 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_joint_paragraph_kgat_fine_tune.jsonl")
163 | argparser.add_argument('--k', type=int, default=0)
164 | argparser.add_argument('--kernel', type=int, default=6)
165 | argparser.add_argument('--log_file', type=str, default = "kgat_prediction.log")
166 |
167 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
168 |
169 | reset_random_seed(12345)
170 |
171 | args = argparser.parse_args()
172 | params = vars(args)
173 |
174 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
175 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
176 |
177 | dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file,
178 | sep_token = tokenizer.sep_token, k = args.k, train=False)
179 |
180 | model = JointParagraphKGATClassifier(args.repfile, args.bert_dim,
181 | args.dropout, kernel = args.kernel).to(device)
182 |
183 | model.load_state_dict(torch.load(args.checkpoint))
184 | print("Loaded saved model.")
185 |
186 | reset_random_seed(12345)
187 | rationale_predictions, stance_preds = predict(model, dev_set)
188 | rationale_json = rationale2json(dev_set.samples, rationale_predictions)
189 | stance_json = stance2json(dev_set.samples, stance_preds)
190 | stance_json = post_process_stance(rationale_json, stance_json)
191 | merged_json = merge_json(rationale_json, stance_json)
192 |
193 | with jsonlines.open(args.prediction, 'w') as output:
194 | for result in merged_json:
195 | output.write(result)
196 |
197 | data = GoldDataset(args.corpus_file, args.dataset)
198 | predictions = PredictedDataset(data, args.prediction)
199 | res = metrics.compute_metrics(predictions)
200 | params["evaluation"] = res
201 | with jsonlines.open(args.log_file, mode='a') as writer:
202 | writer.write(params)
--------------------------------------------------------------------------------
/scifact_rationale_paragraph.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
11 | from tqdm import tqdm
12 | from typing import List
13 | from sklearn.metrics import f1_score, precision_score, recall_score
14 |
15 | import random
16 | import numpy as np
17 |
18 |
19 | from util import arg2param, flatten, rationale2json
20 | from paragraph_model_dynamic import RationaleParagraphClassifier as JointParagraphClassifier
21 | from dataset import SciFactParagraphBatchDataset
22 |
23 | import logging
24 |
25 | def reset_random_seed(seed):
26 | np.random.seed(seed)
27 | random.seed(seed)
28 | torch.manual_seed(seed)
29 |
30 | def batch_rationale_label(labels, padding_idx = 2):
31 | max_sent_len = max([len(label) for label in labels])
32 | label_matrix = torch.ones(len(labels), max_sent_len) * padding_idx
33 | label_list = []
34 | for i, label in enumerate(labels):
35 | for j, evid in enumerate(label):
36 | label_matrix[i,j] = int(evid)
37 | label_list.append([int(evid) for evid in label])
38 | return label_matrix.long(), label_list
39 |
40 | def predict(model, dataset):
41 | model.eval()
42 | rationale_predictions = []
43 |
44 | with torch.no_grad():
45 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
46 | encoded_dict = encode(tokenizer, batch)
47 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
48 | tokenizer.sep_token_id, args.repfile)
49 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
50 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
51 | rationale_out, _= model(encoded_dict, transformation_indices)
52 | rationale_predictions.extend(rationale_out)
53 |
54 | return rationale_predictions
55 |
56 | def evaluation(model, dataset):
57 | model.eval()
58 | rationale_predictions = []
59 | rationale_labels = []
60 |
61 | with torch.no_grad():
62 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size*5, shuffle=False)):
63 | encoded_dict = encode(tokenizer, batch)
64 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
65 | tokenizer.sep_token_id, args.repfile)
66 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
67 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
68 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
69 | rationale_out, rationale_loss = \
70 | model(encoded_dict, transformation_indices,
71 | rationale_label = padded_rationale_label.to(device))
72 |
73 | rationale_predictions.extend(rationale_out)
74 | rationale_labels.extend(rationale_label)
75 |
76 | rationale_f1 = f1_score(flatten(rationale_labels),flatten(rationale_predictions))
77 | rationale_precision = precision_score(flatten(rationale_labels),flatten(rationale_predictions))
78 | rationale_recall = recall_score(flatten(rationale_labels),flatten(rationale_predictions))
79 | return rationale_f1, rationale_precision, rationale_recall
80 |
81 |
82 |
83 | def encode(tokenizer, batch, max_sent_len = 512):
84 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
85 | def longest_first_truncation(sentences, objective):
86 | sent_lens = [len(sent) for sent in sentences]
87 | while np.sum(sent_lens) > objective:
88 | max_position = np.argmax(sent_lens)
89 | sent_lens[max_position] -= 1
90 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
91 |
92 | all_paragraphs = []
93 | for paragraph in input_ids:
94 | valid_paragraph = paragraph[paragraph != pad_token_id]
95 | if valid_paragraph.size(0) <= max_length:
96 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
97 | else:
98 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
99 | idx_by_sentence = []
100 | prev_idx = 0
101 | for idx in sep_token_idx:
102 | idx_by_sentence.append(paragraph[prev_idx:idx])
103 | prev_idx = idx
104 | objective = max_length - 1 - len(idx_by_sentence[0]) # Token budget for the sentences; one slot is reserved for the final sep_token appended back below.
105 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
106 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
107 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
108 |
109 | return torch.cat(all_paragraphs, 0)
110 |
111 | inputs = zip(batch["claim"], batch["paragraph"])
112 | encoded_dict = tokenizer.batch_encode_plus(
113 | inputs,
114 | pad_to_max_length=True,add_special_tokens=True,
115 | return_tensors='pt')
116 | if encoded_dict['input_ids'].size(1) > max_sent_len:
117 | if 'token_type_ids' in encoded_dict:
118 | encoded_dict = {
119 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
120 | tokenizer.sep_token_id, tokenizer.pad_token_id),
121 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
122 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
123 | }
124 | else:
125 | encoded_dict = {
126 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
127 | tokenizer.sep_token_id, tokenizer.pad_token_id),
128 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
129 | }
130 |
131 | return encoded_dict
132 |
133 | def sent_rep_indices(input_ids, sep_token_id, model_name):
134 |
135 | """
136 | Compute the [SEP] indices matrix of the BERT output.
137 | input_ids: (batch_size, paragraph_len)
138 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence)
139 | bert_out: (batch_size, paragraph_len,BERT_dim)
140 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, BERT_dim)
141 | """
142 |
143 | sep_tokens = (input_ids == sep_token_id).bool()
144 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
145 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
146 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
147 | padded_sep_indices = nn.utils.rnn.pad_sequence(sep_indices, batch_first=True, padding_value=-1)
148 | batch_indices = torch.arange(padded_sep_indices.size(0)).unsqueeze(-1).expand(-1,padded_sep_indices.size(-1))
149 | mask = (padded_sep_indices>=0).long()
150 |
151 | if "roberta" in model_name:
152 | return batch_indices[:,2:], padded_sep_indices[:,2:], mask[:,2:]
153 | else:
154 | return batch_indices[:,1:], padded_sep_indices[:,1:], mask[:,1:]
155 |
156 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
157 | """
158 | Compute the token indices matrix of the BERT output.
159 | input_ids: (batch_size, paragraph_len)
160 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
161 | bert_out: (batch_size, paragraph_len,BERT_dim)
162 | bert_out[batch_indices,indices_by_batch,:]: (batch_size, N_sentence, N_token, BERT_dim)
163 | """
164 | padding_idx = -1
165 | sep_tokens = (input_ids == sep_token_id).bool()
166 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
167 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
168 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
169 | paragraph_lens = []
170 | all_word_indices = []
171 | for paragraph in sep_indices:
172 | if "large" in model_name:
173 | paragraph = paragraph[1:]
174 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
175 | paragraph_lens.append(len(word_indices))
176 | all_word_indices.extend(word_indices)
177 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
178 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
179 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
180 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
181 | mask = (indices_by_batch>=0)
182 |
183 | return batch_indices.long(), indices_by_batch.long(), mask.long()
184 |
185 | if __name__ == "__main__":
186 | argparser = argparse.ArgumentParser(description="Train and evaluate the rationale-selection paragraph model")
187 | argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pretrained language model name or path")
188 | argparser.add_argument('--corpus_file', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl")
189 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl")
190 | argparser.add_argument('--pre_trained_model', type=str)
191 | #argparser.add_argument('--train_file', type=str)
192 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl")
193 | argparser.add_argument('--bert_lr', type=float, default=1e-5, help="Learning rate for BERT-like LM")
194 | argparser.add_argument('--lr', type=float, default=5e-6, help="Learning rate")
195 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
196 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
197 | argparser.add_argument('--epoch', type=int, default=20, help="Training epoch")
198 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
199 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_rationale_paragraph.model")
200 | argparser.add_argument('--log_file', type=str, default = "rationale_paragraph_roberta_performances.jsonl")
201 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_rationale_paragraph.jsonl")
202 | argparser.add_argument('--update_step', type=int, default=10)
203 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
204 | argparser.add_argument('--k', type=int, default=0)
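# Note (annotation, not part of the original file): gradients are accumulated over
# update_step batches before each optimizer.step() in the training loop below, so
# the effective batch size is batch_size * update_step (1 * 10 = 10 by default).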
205 |
206 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
207 |
208 | reset_random_seed(12345)
209 |
210 | args = argparser.parse_args()
211 |
212 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
213 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
214 |
215 | if args.train_file:
216 | train = True
217 | #assert args.repfile is not None, "Word embedding file required for training."
218 | else:
219 | train = False
220 | if args.test_file:
221 | test = True
222 | else:
223 | test = False
224 |
225 | params = vars(args)
226 |
227 | for k,v in params.items():
228 | print(k,v)
229 |
230 |     if train:
231 |         train_set = SciFactParagraphBatchDataset(args.corpus_file, args.train_file,
232 |                                                  sep_token = tokenizer.sep_token, k = args.k, dummy=False)
233 |     if train or test:  # dev set is needed for per-epoch evaluation and for test-only runs
234 |         dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, sep_token = tokenizer.sep_token, k = args.k, dummy=False)
235 |
236 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
237 | args.dropout)#.to(device)
238 |
239 | if args.pre_trained_model is not None:
240 | model.load_state_dict(torch.load(args.pre_trained_model))
241 |         model.reinitialize()  # re-initialize the task-specific head when warm-starting from a pre-trained checkpoint
242 |
243 | model = model.to(device)
244 |
245 | if train:
246 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
247 | for module in model.extra_modules:
248 | settings.append({'params': module.parameters(), 'lr': args.lr})
249 | optimizer = torch.optim.Adam(settings)
250 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
251 | model.train()
252 |
253 | prev_performance = 0
254 | for epoch in range(args.epoch):
255 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
256 | for i, batch in enumerate(tq):
257 | encoded_dict = encode(tokenizer, batch)
258 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
259 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
260 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
261 | padded_rationale_label, rationale_label = batch_rationale_label(batch["label"], padding_idx = 2)
262 | rationale_out, loss = \
263 | model(encoded_dict, transformation_indices,
264 | rationale_label = padded_rationale_label.to(device))
265 | loss.backward()
266 |
267 | if i % args.update_step == args.update_step - 1:
268 | optimizer.step()
269 | optimizer.zero_grad()
270 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}')
271 | scheduler.step()
272 |
273 | # Evaluation
274 | train_score = evaluation(model, train_set)
275 | print(f'Epoch {epoch}, train rationale f1 p r: %.4f, %.4f, %.4f' % train_score)
276 |
277 | dev_score = evaluation(model, dev_set)
278 |         print(f'Epoch {epoch}, dev rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
279 |
280 | dev_perf = dev_score[0]
281 | if dev_perf >= prev_performance:
282 | torch.save(model.state_dict(), args.checkpoint)
283 | best_state_dict = model.state_dict()
284 | prev_performance = dev_perf
285 | print("New model saved!")
286 | else:
287 | print("Skip saving model.")
288 |
289 |
290 | if test:
291 | if train:
292 | del model
293 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
294 | args.dropout).to(device)
295 | model.load_state_dict(best_state_dict)
296 | print("Testing on the new model.")
297 | else:
298 | model.load_state_dict(torch.load(args.checkpoint))
299 | print("Loaded saved model.")
300 |
301 | # Evaluation
302 | dev_score = evaluation(model, dev_set)
303 | print(f'Test rationale f1 p r: %.4f, %.4f, %.4f' % dev_score)
304 |
305 | params["rationale_f1"] = dev_score[0]
306 | params["rationale_precision"] = dev_score[1]
307 | params["rationale_recall"] = dev_score[2]
308 |
309 | with jsonlines.open(args.log_file, mode='a') as writer:
310 | writer.write(params)
--------------------------------------------------------------------------------
/scifact_stance_paragraph.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import jsonlines
5 | import os
6 |
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.utils.data import Dataset, DataLoader
10 | from transformers import AutoModel, AutoTokenizer, get_cosine_schedule_with_warmup
11 | from tqdm import tqdm
12 | from typing import List
13 | from sklearn.metrics import f1_score, precision_score, recall_score
14 |
15 | import random
16 | import numpy as np
17 |
19 | from util import arg2param, flatten, stance2json, merge_json
20 | from paragraph_model_dynamic import StanceParagraphClassifier as JointParagraphClassifier
21 | from dataset import SciFactStanceDataset as SciFactParagraphBatchDataset
22 |
23 | import logging
24 |
25 | def reset_random_seed(seed):
26 | np.random.seed(seed)
27 | random.seed(seed)
28 | torch.manual_seed(seed)
29 |
30 | def predict(model, dataset):
31 | model.eval()
32 | stance_preds = []
33 | with torch.no_grad():
34 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size, shuffle=False)):
35 | encoded_dict = encode(tokenizer, batch)
36 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
37 | tokenizer.sep_token_id, args.repfile)
38 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
39 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
40 | stance_out, _ = model(encoded_dict, transformation_indices)
41 | stance_preds.extend(stance_out)
42 |
43 | return stance_preds
44 |
45 | def evaluation(model, dataset):
46 | model.eval()
47 | stance_preds = []
48 | stance_labels = []
49 |
50 | with torch.no_grad():
51 | for batch in tqdm(DataLoader(dataset, batch_size = args.batch_size*5, shuffle=False)):
52 | encoded_dict = encode(tokenizer, batch)
53 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"],
54 | tokenizer.sep_token_id, args.repfile)
55 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
56 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
57 | stance_label = batch["stance"].to(device)
58 | stance_out, loss = \
59 | model(encoded_dict, transformation_indices, stance_label = stance_label)
60 | stance_preds.extend(stance_out)
61 | stance_labels.extend(stance_label.cpu().numpy().tolist())
62 |
63 | stance_f1 = f1_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
64 | stance_precision = precision_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
65 | stance_recall = recall_score(stance_labels,stance_preds,average="micro",labels=[1, 2])
66 | return stance_f1, stance_precision, stance_recall
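# Note (annotation, not part of the original file): the micro-averaged scores above
# are restricted to labels [1, 2] (SUPPORT / CONTRADICT), so NOT_ENOUGH_INFO (0)
# is excluded from the stance metric.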
67 |
68 |
69 |
70 | def encode(tokenizer, batch, max_sent_len = 512):
71 | def truncate(input_ids, max_length, sep_token_id, pad_token_id):
72 | def longest_first_truncation(sentences, objective):
73 | sent_lens = [len(sent) for sent in sentences]
74 | while np.sum(sent_lens) > objective:
75 | max_position = np.argmax(sent_lens)
76 | sent_lens[max_position] -= 1
77 | return [sentence[:length] for sentence, length in zip(sentences, sent_lens)]
78 |
79 | all_paragraphs = []
80 | for paragraph in input_ids:
81 | valid_paragraph = paragraph[paragraph != pad_token_id]
82 | if valid_paragraph.size(0) <= max_length:
83 | all_paragraphs.append(paragraph[:max_length].unsqueeze(0))
84 | else:
85 | sep_token_idx = np.arange(valid_paragraph.size(0))[(valid_paragraph == sep_token_id).numpy()]
86 | idx_by_sentence = []
87 | prev_idx = 0
88 | for idx in sep_token_idx:
89 | idx_by_sentence.append(paragraph[prev_idx:idx])
90 | prev_idx = idx
91 |             objective = max_length - 1 - len(idx_by_sentence[0])  # token budget for the sentences: total minus the leading claim chunk and one slot for the trailing sep token
92 | truncated_sentences = longest_first_truncation(idx_by_sentence[1:], objective)
93 | truncated_paragraph = torch.cat([idx_by_sentence[0]] + truncated_sentences + [torch.tensor([sep_token_id])],0)
94 | all_paragraphs.append(truncated_paragraph.unsqueeze(0))
95 |
96 | return torch.cat(all_paragraphs, 0)
97 |
98 |     inputs = list(zip(batch["claim"], batch["paragraph"]))  # materialize the pairs; some tokenizer versions reject lazy iterators
99 | encoded_dict = tokenizer.batch_encode_plus(
100 | inputs,
101 | pad_to_max_length=True,add_special_tokens=True,
102 | return_tensors='pt')
103 | if encoded_dict['input_ids'].size(1) > max_sent_len:
104 | if 'token_type_ids' in encoded_dict:
105 | encoded_dict = {
106 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
107 | tokenizer.sep_token_id, tokenizer.pad_token_id),
108 | 'token_type_ids': encoded_dict['token_type_ids'][:,:max_sent_len],
109 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
110 | }
111 | else:
112 | encoded_dict = {
113 | "input_ids": truncate(encoded_dict['input_ids'], max_sent_len,
114 | tokenizer.sep_token_id, tokenizer.pad_token_id),
115 | 'attention_mask': encoded_dict['attention_mask'][:,:max_sent_len]
116 | }
117 |
118 | return encoded_dict
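# Usage sketch (annotation, not part of the original file): encode() tokenizes
# claim/paragraph pairs; when a paragraph exceeds max_sent_len, truncate() trims
# the longest sentences one token at a time ("longest-first"), so every sentence
# keeps some tokens instead of the paragraph tail being chopped off. Toy call:
#
#   batch = {"claim": ["Vitamin D prevents fractures."],
#            "paragraph": ["First sentence. </s> Second sentence. </s>"]}
#   encoded_dict = encode(tokenizer, batch, max_sent_len=512)
#   assert encoded_dict["input_ids"].size(1) <= 512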
119 |
120 | def sent_rep_indices(input_ids, sep_token_id, model_name):
121 |
122 | """
123 | Compute the [SEP] indices matrix of the BERT output.
124 | input_ids: (batch_size, paragraph_len)
125 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence)
126 |     bert_out: (batch_size, paragraph_len, BERT_dim)
127 |     bert_out[batch_indices, indices_by_batch, :]: (batch_size, N_sentence, BERT_dim)
128 | """
129 |
130 | sep_tokens = (input_ids == sep_token_id).bool()
131 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
132 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
133 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
134 | padded_sep_indices = nn.utils.rnn.pad_sequence(sep_indices, batch_first=True, padding_value=-1)
135 | batch_indices = torch.arange(padded_sep_indices.size(0)).unsqueeze(-1).expand(-1,padded_sep_indices.size(-1))
136 | mask = (padded_sep_indices>=0).long()
137 |
138 | if "roberta" in model_name:
139 | return batch_indices[:,2:], padded_sep_indices[:,2:], mask[:,2:]
140 | else:
141 | return batch_indices[:,1:], padded_sep_indices[:,1:], mask[:,1:]
142 |
143 | def token_idx_by_sentence(input_ids, sep_token_id, model_name):
144 | """
145 | Compute the token indices matrix of the BERT output.
146 | input_ids: (batch_size, paragraph_len)
147 | batch_indices, indices_by_batch, mask: (batch_size, N_sentence, N_token)
148 |     bert_out: (batch_size, paragraph_len, BERT_dim)
149 |     bert_out[batch_indices, indices_by_batch, :]: (batch_size, N_sentence, N_token, BERT_dim)
150 | """
151 | padding_idx = -1
152 | sep_tokens = (input_ids == sep_token_id).bool()
153 | paragraph_lens = torch.sum(sep_tokens,1).numpy().tolist()
154 | indices = torch.arange(sep_tokens.size(-1)).unsqueeze(0).expand(sep_tokens.size(0),-1)
155 | sep_indices = torch.split(indices[sep_tokens],paragraph_lens)
156 | paragraph_lens = []
157 | all_word_indices = []
158 | for paragraph in sep_indices:
159 | if "large" in model_name:
160 | paragraph = paragraph[1:]
161 | word_indices = [torch.arange(paragraph[i]+1, paragraph[i+1]+1) for i in range(paragraph.size(0)-1)]
162 | paragraph_lens.append(len(word_indices))
163 | all_word_indices.extend(word_indices)
164 | indices_by_sentence = nn.utils.rnn.pad_sequence(all_word_indices, batch_first=True, padding_value=padding_idx)
165 | indices_by_sentence_split = torch.split(indices_by_sentence,paragraph_lens)
166 | indices_by_batch = nn.utils.rnn.pad_sequence(indices_by_sentence_split, batch_first=True, padding_value=padding_idx)
167 | batch_indices = torch.arange(sep_tokens.size(0)).unsqueeze(-1).unsqueeze(-1).expand(-1,indices_by_batch.size(1),indices_by_batch.size(-1))
168 | mask = (indices_by_batch>=0)
169 |
170 | return batch_indices.long(), indices_by_batch.long(), mask.long()
171 |
172 | if __name__ == "__main__":
173 |     argparser = argparse.ArgumentParser(description="Train and evaluate the paragraph-level stance prediction model")
174 |     argparser.add_argument('--repfile', type=str, default = "roberta-large", help="Pre-trained language model name or path")
175 | argparser.add_argument('--corpus_file', type=str, default="/nas/home/xiangcil/scifact/data/corpus.jsonl")
176 | argparser.add_argument('--train_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_train_retrieved.jsonl")
177 | argparser.add_argument('--pre_trained_model', type=str)
178 | #argparser.add_argument('--train_file', type=str)
179 | argparser.add_argument('--test_file', type=str, default="/nas/home/xiangcil/CitationEvaluation/SciFact/claims_dev_retrieved.jsonl")
180 | argparser.add_argument('--bert_lr', type=float, default=5e-6, help="Learning rate for BERT-like LM")
181 | argparser.add_argument('--lr', type=float, default=1e-6, help="Learning rate")
182 | argparser.add_argument('--dropout', type=float, default=0, help="embedding_dropout rate")
183 | argparser.add_argument('--bert_dim', type=int, default=1024, help="bert_dimension")
184 | argparser.add_argument('--epoch', type=int, default=20, help="Training epoch")
185 | argparser.add_argument('--MAX_SENT_LEN', type=int, default=512)
186 | argparser.add_argument('--checkpoint', type=str, default = "scifact_roberta_stance_paragraph.model")
187 | argparser.add_argument('--log_file', type=str, default = "stance_paragraph_roberta_performances.jsonl")
188 | argparser.add_argument('--prediction', type=str, default = "prediction_scifact_roberta_stance_paragraph.jsonl")
189 | argparser.add_argument('--update_step', type=int, default=10)
190 | argparser.add_argument('--batch_size', type=int, default=1) # roberta-large: 2; bert: 8
191 | argparser.add_argument('--k', type=int, default=0)
192 |
193 | logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)
194 |
195 | reset_random_seed(12345)
196 |
197 | args = argparser.parse_args()
198 |
199 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
200 | tokenizer = AutoTokenizer.from_pretrained(args.repfile)
201 |
202 | if args.train_file:
203 | train = True
204 | #assert args.repfile is not None, "Word embedding file required for training."
205 | else:
206 | train = False
207 | if args.test_file:
208 | test = True
209 | else:
210 | test = False
211 |
212 | params = vars(args)
213 |
214 | for k,v in params.items():
215 | print(k,v)
216 |
217 |     if train:
218 |         train_set = SciFactParagraphBatchDataset(args.corpus_file, args.train_file,
219 |                                                  sep_token = tokenizer.sep_token, k = args.k)
220 |     if train or test:  # dev set is needed for per-epoch evaluation and for test-only runs
221 |         dev_set = SciFactParagraphBatchDataset(args.corpus_file, args.test_file, sep_token = tokenizer.sep_token, k = args.k)
222 |
223 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
224 | args.dropout)#.to(device)
225 |
226 | if args.pre_trained_model is not None:
227 | model.load_state_dict(torch.load(args.pre_trained_model))
228 | model.reinitialize()
229 | print("Reinitialized part of the model!")
230 |
231 | model = model.to(device)
232 |
233 | if train:
234 | settings = [{'params': model.bert.parameters(), 'lr': args.bert_lr}]
235 | for module in model.extra_modules:
236 | settings.append({'params': module.parameters(), 'lr': args.lr})
237 | optimizer = torch.optim.Adam(settings)
238 | scheduler = get_cosine_schedule_with_warmup(optimizer, 0, args.epoch)
239 | model.train()
240 |
241 | prev_performance = 0
242 | for epoch in range(args.epoch):
243 | tq = tqdm(DataLoader(train_set, batch_size = args.batch_size, shuffle=True))
244 | for i, batch in enumerate(tq):
245 | encoded_dict = encode(tokenizer, batch)
246 | transformation_indices = token_idx_by_sentence(encoded_dict["input_ids"], tokenizer.sep_token_id, args.repfile)
247 | encoded_dict = {key: tensor.to(device) for key, tensor in encoded_dict.items()}
248 | transformation_indices = [tensor.to(device) for tensor in transformation_indices]
249 | stance_label = batch["stance"].to(device)
250 | stance_out, loss = model(encoded_dict, transformation_indices, stance_label = stance_label)
251 | loss.backward()
252 |
253 | if i % args.update_step == args.update_step - 1:
254 | optimizer.step()
255 | optimizer.zero_grad()
256 | tq.set_description(f'Epoch {epoch}, iter {i}, loss: {round(loss.item(), 4)}')
257 | scheduler.step()
258 |
259 | # Evaluation
260 | train_score = evaluation(model, train_set)
261 | print(f'Epoch {epoch}, train stance f1 p r: %.4f, %.4f, %.4f' % train_score)
262 |
263 | dev_score = evaluation(model, dev_set)
264 | print(f'Epoch {epoch}, dev stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
265 |
266 | dev_perf = dev_score[0]
267 | if dev_perf >= prev_performance:
268 | torch.save(model.state_dict(), args.checkpoint)
269 | best_state_dict = model.state_dict()
270 | prev_performance = dev_perf
271 | print("New model saved!")
272 | else:
273 | print("Skip saving model.")
274 |
275 |
276 | if test:
277 | if train:
278 | del model
279 | model = JointParagraphClassifier(args.repfile, args.bert_dim,
280 | args.dropout).to(device)
281 | model.load_state_dict(best_state_dict)
282 | print("Testing on the new model.")
283 | else:
284 | model.load_state_dict(torch.load(args.checkpoint))
285 | print("Loaded saved model.")
286 |
287 | # Evaluation
288 | dev_score = evaluation(model, dev_set)
289 | print(f'Test stance f1 p r: %.4f, %.4f, %.4f' % dev_score)
290 |
291 | params["SciFact stance_f1"] = dev_score[0]
292 | params["SciFact stance_precision"] = dev_score[1]
293 | params["SciFact stance_recall"] = dev_score[2]
294 |
295 | with jsonlines.open(args.log_file, mode='a') as writer:
296 | writer.write(params)
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | import codecs
4 | import re
5 | import numpy as np
6 | from sklearn.metrics import f1_score
7 |
8 | def flatten(arrayOfArray):
9 | array = []
10 | for arr in arrayOfArray:
11 | try:
12 | array.extend(arr)
13 |         except TypeError:  # non-iterable element: keep it as a scalar
14 | array.append(arr)
15 | return array
16 |
17 | def read_passages(filename, is_labeled):
18 | str_seqs = []
19 | str_seq = []
20 | label_seqs = []
21 | label_seq = []
22 | for line in codecs.open(filename, "r", "utf-8"):
23 | lnstrp = line.strip()
24 | if lnstrp == "":
25 | if len(str_seq) != 0:
26 | str_seqs.append(str_seq)
27 | str_seq = []
28 | label_seqs.append(label_seq)
29 | label_seq = []
30 | else:
31 | if is_labeled:
32 | clause, label = lnstrp.split("\t")
33 | label_seq.append(label.strip())
34 | else:
35 | clause = lnstrp
36 | str_seq.append(clause)
37 | if len(str_seq) != 0:
38 | str_seqs.append(str_seq)
39 | str_seq = []
40 | label_seqs.append(label_seq)
41 | label_seq = []
42 | return str_seqs, label_seqs
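# Expected input format (annotation, not part of the original file): one clause
# per line, tab-separated from its label when is_labeled=True, with a blank line
# between paragraphs, e.g.
#
#   We measured the binding affinity.<TAB>method
#   Affinity increased by 40%.<TAB>result
#   <blank line separates paragraphs>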
43 |
44 | def from_BIO_ind(BIO_pred, BIO_target, indices):
45 |     table = {}  # maps each BIO label index to the index of its collapsed original label
46 | original_labels = []
47 | for BIO_label,BIO_index in indices.items():
48 | if BIO_label[:2] == "I_" or BIO_label[:2] == "B_":
49 | label = BIO_label[2:]
50 | else:
51 | label = BIO_label
52 | if label in original_labels:
53 | table[BIO_index] = original_labels.index(label)
54 | else:
55 | table[BIO_index] = len(original_labels)
56 | original_labels.append(label)
57 |
58 | original_pred = [table[label] for label in BIO_pred]
59 | original_target = [table[label] for label in BIO_target]
60 | return original_pred, original_target
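# Worked example (annotation, not part of the original file): with
# indices = {"B_method": 0, "I_method": 1, "none": 2}, the table collapses BIO
# indices to {0: 0, 1: 0, 2: 1} ("method" -> 0, "none" -> 1), so
# from_BIO_ind([0, 1, 2], [0, 0, 2], indices) returns ([0, 0, 1], [0, 0, 1]).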
61 |
62 | def to_BIO(label_seqs):
63 | new_label_seqs = []
64 | for label_para in label_seqs:
65 | new_label_para = []
66 | prev = ""
67 | for label in label_para:
68 | if label!="none": # "none" is O, remain unchanged.
69 | if label==prev:
70 | new_label = "I_"+label
71 | else:
72 | new_label = "B_"+label
73 | else:
74 | new_label = label # "none"
75 | prev = label
76 | new_label_para.append(new_label)
77 | new_label_seqs.append(new_label_para)
78 | return new_label_seqs
79 |
80 | def from_BIO(label_seqs):
81 | new_label_seqs = []
82 | for label_para in label_seqs:
83 | new_label_para = []
84 | for label in label_para:
85 | if label[:2] == "I_" or label[:2] == "B_":
86 | new_label = label[2:]
87 | else:
88 | new_label = label
89 | new_label_para.append(new_label)
90 | new_label_seqs.append(new_label_para)
91 | return new_label_seqs
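# Worked example (annotation, not part of the original file):
#   to_BIO([["none", "method", "method", "result"]])
#     -> [["none", "B_method", "I_method", "B_result"]]
# and from_BIO inverts this by stripping the B_/I_ prefixes:
#   from_BIO([["none", "B_method", "I_method", "B_result"]])
#     -> [["none", "method", "method", "result"]]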
92 |
93 | def clean_url(word):
94 | """
95 | Clean specific data format from social media
96 | """
97 | # clean urls
98 | word = re.sub(r'https? : \/\/.*[\r\n]*', '', word)
99 | word = re.sub(r'exlink', '', word)
100 | return word
101 |
102 | def clean_num(word):
103 | # check if the word contain number and no letters
104 | if any(char.isdigit() for char in word):
105 | try:
106 |             float(word.replace(',', ''))  # parses once thousands separators are removed
107 |             return '@'
108 |         except ValueError:  # contains digits but is not a parseable number
109 | if not any(char.isalpha() for char in word):
110 | return '@'
111 | return word
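# Worked examples (annotation, not part of the original file):
#   clean_num("1,234") -> "@"    (parses as a number once commas are removed)
#   clean_num("40%")   -> "@"    (contains digits and no letters)
#   clean_num("3mg")   -> "3mg"  (mixed digits and letters are kept)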
112 |
113 |
114 | def clean_words(str_seqs):
115 | processed_seqs = []
116 | for str_seq in str_seqs:
117 | processed_clauses = []
118 | for clause in str_seq:
119 | filtered = []
120 | tokens = clause.split()
121 | for word in tokens:
122 | word = clean_url(word)
123 | word = clean_num(word)
124 | filtered.append(word)
125 | filtered_clause = " ".join(filtered)
126 | processed_clauses.append(filtered_clause)
127 | processed_seqs.append(processed_clauses)
128 | return processed_seqs
129 |
130 | def test_f1(test_file,pred_label_seqs):
131 | def linearize(labels):
132 | linearized = []
133 | for paper in labels:
134 | for label in paper:
135 | linearized.append(label)
136 | return linearized
137 |     _, label_seqs = read_passages(test_file, True)
138 | true_label = linearize(label_seqs)
139 | pred_label = linearize(pred_label_seqs)
140 |
141 | f1 = f1_score(true_label,pred_label,average="weighted")
142 | print("F1 score:",f1)
143 | return f1
144 |
145 | def postprocess(dataset, raw_flattened_output, raw_flattened_labels, MAX_SEQ_LEN):
146 | ground_truth_labels = []
147 | paragraph_lens = []
148 | for para in dataset.true_pairs:
149 | paragraph_lens.append(len(para["paragraph"]))
150 | ground_truth_labels.append(para["label"])
151 |
152 | raw_flattened_output = raw_flattened_output.tolist()
153 | raw_flattened_labels = raw_flattened_labels.tolist()
154 | batch_i = 0
155 | predicted_tags = []
156 | gt_tags = []
157 | for length in paragraph_lens:
158 | remaining_len = length
159 | predict_idx = []
160 | gt_tag = []
161 | while remaining_len > MAX_SEQ_LEN:
162 | this_batch = raw_flattened_output[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
163 | this_batch_label = raw_flattened_labels[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
164 | predict_idx.extend(this_batch)
165 | gt_tag.extend(this_batch_label)
166 | batch_i += 1
167 | remaining_len -= MAX_SEQ_LEN
168 |
169 | this_batch = raw_flattened_output[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
170 | this_batch_label = raw_flattened_labels[batch_i*MAX_SEQ_LEN:(batch_i+1)*MAX_SEQ_LEN]
171 | predict_idx.extend(this_batch[:remaining_len])
172 | gt_tag.extend(this_batch_label[:remaining_len])
173 | predict_tag = [dataset.rev_label_ind[idx] for idx in predict_idx]
174 | gt_tag = [dataset.rev_label_ind[idx] for idx in gt_tag]
175 | batch_i += 1
176 | predicted_tags.append(predict_tag)
177 | gt_tags.append(gt_tag)
178 |
179 |
180 | predicted_tags = from_BIO(predicted_tags)
181 | final_gt = from_BIO(gt_tags)
182 |
183 | return predicted_tags, final_gt
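# Note (annotation, not part of the original file): postprocess() undoes the
# batching-time chunking. Paragraphs longer than MAX_SEQ_LEN were split into
# consecutive MAX_SEQ_LEN chunks, so the flattened predictions are reassembled
# chunk by chunk, the final partial chunk is cut to the paragraph's true length,
# and the BIO tags are mapped back to plain labels.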
184 |
185 | def stance_postprocess(dataset, raw_output, raw_labels, MAX_SEQ_LEN):
186 |
187 | def combine(candidates):
188 | assert(len(candidates)>0)
189 | types = set(candidates)
190 | if len(types) == 1:
191 | return list(types)[0]
192 | elif 2 in types:
193 | return 2
194 | else:
195 | return 1
196 |
197 |
198 | ground_truth_labels = []
199 | paragraph_lens = []
200 | for para in dataset.true_pairs:
201 | paragraph_lens.append(len(para["paragraph"]))
202 | ground_truth_labels.append(para["label"])
203 |
204 | raw_output = raw_output.tolist()
205 | raw_labels = raw_labels.tolist()
206 | batch_i = 0
207 | predicted_tags = []
208 | gt_tags = []
209 | for length in paragraph_lens:
210 | remaining_len = length
211 | predict_idx = []
212 | gt_tag = []
213 | while remaining_len > MAX_SEQ_LEN:
214 | this_batch = raw_output[batch_i]
215 | this_batch_label = raw_labels[batch_i]
216 | predict_idx.append(this_batch)
217 | gt_tag.append(this_batch_label)
218 | batch_i += 1
219 | remaining_len -= MAX_SEQ_LEN
220 |
221 | this_batch = raw_output[batch_i]
222 | this_batch_label = raw_labels[batch_i]
223 | predict_idx.append(this_batch)
224 | gt_tag.append(this_batch_label)
225 | predict_tag = combine(predict_idx)
226 | gt_tag = combine(gt_tag)
227 | batch_i += 1
228 | predicted_tags.append(predict_tag)
229 | gt_tags.append(gt_tag)
230 |
231 | return predicted_tags, gt_tags
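# Worked example (annotation, not part of the original file): combine() merges
# per-chunk stance votes for one paragraph with priority
# CONTRADICT (2) > SUPPORT (1) > NOT_ENOUGH_INFO (0):
#   combine([0, 1]) -> 1, combine([1, 2]) -> 2, combine([0, 0]) -> 0.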
232 |
233 | def rationale2json(true_pairs, predictions, excluded_pairs = None):
234 | claim_ids = []
235 | claims = {}
236 | assert(len(true_pairs) == len(predictions))
237 | for pair, prediction in zip(true_pairs, predictions):
238 | claim_id = pair["claim_id"]
239 | claim_ids.append(claim_id)
240 |
241 | predicted_sentences = []
242 | for i, pred in enumerate(prediction):
243 | if pred == "rationale" or pred == 1:
244 | predicted_sentences.append(i)
245 |
246 | this_claim = claims.get(claim_id, {"claim_id": claim_id, "evidence":{}})
247 | #if len(predicted_sentences) > 0:
248 | this_claim["evidence"][pair["doc_id"]] = predicted_sentences
249 | claims[claim_id] = this_claim
250 | if excluded_pairs is not None:
251 | for pair in excluded_pairs:
252 | claims[pair["claim_id"]] = {"claim_id": pair["claim_id"], "evidence":{}}
253 | claim_ids.append(pair["claim_id"])
254 | return [claims[claim_id] for claim_id in sorted(list(set(claim_ids)))]
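# Output shape (annotation, not part of the original file): one record per claim,
# e.g. [{"claim_id": 3, "evidence": {4983: [0, 2]}}], where each doc_id maps to
# the indices of its predicted rationale sentences.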
255 |
256 | def stance2json(true_pairs, predictions, excluded_pairs = None):
257 | claim_ids = []
258 | claims = {}
259 | idx2stance = ["NOT_ENOUGH_INFO", "SUPPORT", "CONTRADICT"]
260 | assert(len(true_pairs) == len(predictions))
261 | for pair, prediction in zip(true_pairs, predictions):
262 | claim_id = pair["claim_id"]
263 | claim_ids.append(claim_id)
264 |
265 | this_claim = claims.get(claim_id, {"claim_id": claim_id, "labels":{}})
266 | this_claim["labels"][pair["doc_id"]] = {"label": idx2stance[prediction], 'confidence': 1}
267 | claims[claim_id] = this_claim
268 | if excluded_pairs is not None:
269 | for pair in excluded_pairs:
270 | claims[pair["claim_id"]] = {"claim_id": pair["claim_id"], "labels":{}}
271 | claim_ids.append(pair["claim_id"])
272 | return [claims[claim_id] for claim_id in sorted(list(set(claim_ids)))]
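# Output shape (annotation, not part of the original file): one record per claim,
# e.g. [{"claim_id": 3, "labels": {4983: {"label": "SUPPORT", "confidence": 1}}}].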
273 |
274 | def merge_json(rationale_jsons, stance_jsons):
275 | stance_json_dict = {str(stance_json["claim_id"]): stance_json for stance_json in stance_jsons}
276 | jsons = []
277 |     for rationale_json in rationale_jsons:
278 |         claim_id = str(rationale_json["claim_id"])
279 |         result = {}
280 |         if claim_id in stance_json_dict:
281 |             for k, v in rationale_json["evidence"].items():
282 |                 if len(v) > 0 and stance_json_dict[claim_id]["labels"][int(k)]["label"] != "NOT_ENOUGH_INFO":
283 |                     result[k] = {
284 |                         "sentences": v,
285 |                         "label": stance_json_dict[claim_id]["labels"][int(k)]["label"]
286 |                     }
287 |         jsons.append({"id": int(claim_id), "evidence": result})
288 | return jsons
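# Output shape (annotation, not part of the original file): merge_json combines
# the two prediction streams into the final submission format, e.g.
# [{"id": 3, "evidence": {4983: {"sentences": [0, 2], "label": "SUPPORT"}}}],
# dropping documents with no rationale sentences or a NOT_ENOUGH_INFO stance.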
289 |
290 | def arg2param(args):
291 | params = vars(args)
292 | params["MAX_SEQ_LEN"]=params["CHUNK_SIZE"]*params["CHUNK_PER_SEQ"]
293 | params["MINIBATCH_SIZE"] = params["CHUNK_PER_SEQ"]
294 | params["SENTENCE_BATCH_SIZE"]=params["CHUNK_SIZE"]
295 | params["CHUNK_PER_STEP"]=params["PARAGRAPH_PER_STEP"]*params["CHUNK_PER_SEQ"]
296 |
297 | return params
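# Worked example (annotation, not part of the original file): with CHUNK_SIZE=30,
# CHUNK_PER_SEQ=10 and PARAGRAPH_PER_STEP=4, arg2param adds MAX_SEQ_LEN=300,
# MINIBATCH_SIZE=10, SENTENCE_BATCH_SIZE=30 and CHUNK_PER_STEP=40.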
298 |
--------------------------------------------------------------------------------