├── .gitignore ├── README.md ├── evaluation.py ├── long_seq.py ├── losses.py ├── model.py ├── prepro.py ├── scripts │   ├── run_bert.sh │   ├── run_cdr.sh │   ├── run_gda.sh │   └── run_roberta.sh ├── train.py ├── train_bio.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.json 3 | *.pyc 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ATLOP 2 | Code for the AAAI 2021 paper [Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling](https://arxiv.org/abs/2010.11304). 3 | 4 | If you make use of this code in your work, please kindly cite the following paper: 5 | 6 | ```bibtex 7 | @inproceedings{zhou2021atlop, 8 | title={Document-Level Relation Extraction with Adaptive Thresholding and Localized Context Pooling}, 9 | author={Zhou, Wenxuan and Huang, Kevin and Ma, Tengyu and Huang, Jing}, 10 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 11 | year={2021} 12 | } 13 | ``` 14 | ## Requirements 15 | * Python (tested on 3.7.4) 16 | * CUDA (tested on 10.2) 17 | * [PyTorch](http://pytorch.org/) (tested on 1.7.0) 18 | * [Transformers](https://github.com/huggingface/transformers) (tested on 3.4.0) 19 | * numpy (tested on 1.19.4) 20 | * [apex](https://github.com/NVIDIA/apex) (tested on 0.1) 21 | * [opt-einsum](https://github.com/dgasmith/opt_einsum) (tested on 3.3.0) 22 | * wandb 23 | * ujson 24 | * tqdm 25 | 26 | ## Dataset 27 | The [DocRED](https://www.aclweb.org/anthology/P19-1074/) dataset can be downloaded by following the instructions at [link](https://github.com/thunlp/DocRED/tree/master/data). The CDR and GDA datasets can be obtained by following the instructions in [edge-oriented graph](https://github.com/fenchri/edge-oriented-graph). The expected structure of files is: 28 | ``` 29 | ATLOP 30 | |-- dataset 31 | | |-- docred 32 | | | |-- train_annotated.json 33 | | | |-- train_distant.json 34 | | | |-- dev.json 35 | | | |-- test.json 36 | | |-- cdr 37 | | | |-- train_filter.data 38 | | | |-- dev_filter.data 39 | | | |-- test_filter.data 40 | | |-- gda 41 | | | |-- train.data 42 | | | |-- dev.data 43 | | | |-- test.data 44 | |-- meta 45 | | |-- rel2id.json 46 | ``` 47 | 48 | ## Training and Evaluation 49 | ### DocRED 50 | Train the BERT and RoBERTa models on DocRED with the following commands: 51 | 52 | ```bash 53 | >> sh scripts/run_bert.sh # for BERT 54 | >> sh scripts/run_roberta.sh # for RoBERTa 55 | ``` 56 | 57 | The training loss and evaluation results on the dev set are synced to the wandb dashboard. 58 | 59 | The program will generate a test file `result.json` in the official evaluation format. You can compress and submit it to CodaLab for the official test score. 60 | 61 | ### CDR and GDA 62 | Train the CDR and GDA models with the following commands: 63 | ```bash 64 | >> sh scripts/run_cdr.sh # for CDR 65 | >> sh scripts/run_gda.sh # for GDA 66 | ``` 67 | 68 | The training loss and evaluation results on the dev and test sets are synced to the wandb dashboard. 69 | 70 | ## Saving and Evaluating Models 71 | You can save the model by setting the `--save_path` argument before training; the checkpoint corresponding to the best dev results will be saved. Afterwards, you can evaluate the saved model by setting the `--load_path` argument, and the code will skip training and evaluate the saved model on the benchmarks.
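For example (a minimal sketch; the checkpoint path is illustrative, and any flag that is left out falls back to the argparse defaults in `train.py`, which match `scripts/run_bert.sh`):

```bash
>> python train.py --save_path ./checkpoints/docred_bert.pt   # train and keep the best dev checkpoint
>> python train.py --load_path ./checkpoints/docred_bert.pt   # skip training and evaluate the saved checkpoint
```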
I've also released the trained `atlop-bert-base` and `atlop-roberta` models. 72 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import json 4 | import numpy as np 5 | 6 | rel2id = json.load(open('meta/rel2id.json', 'r')) 7 | id2rel = {value: key for key, value in rel2id.items()} 8 | 9 | 10 | def to_official(preds, features): 11 | h_idx, t_idx, title = [], [], [] 12 | 13 | for f in features: 14 | hts = f["hts"] 15 | h_idx += [ht[0] for ht in hts] 16 | t_idx += [ht[1] for ht in hts] 17 | title += [f["title"] for ht in hts] 18 | 19 | res = [] 20 | for i in range(preds.shape[0]): 21 | pred = preds[i] 22 | pred = np.nonzero(pred)[0].tolist() 23 | for p in pred: 24 | if p != 0: 25 | res.append( 26 | { 27 | 'title': title[i], 28 | 'h_idx': h_idx[i], 29 | 't_idx': t_idx[i], 30 | 'r': id2rel[p], 31 | } 32 | ) 33 | return res 34 | 35 | 36 | def gen_train_facts(data_file_name, truth_dir): 37 | fact_file_name = data_file_name[data_file_name.find("train_"):] 38 | fact_file_name = os.path.join(truth_dir, fact_file_name.replace(".json", ".fact")) 39 | 40 | if os.path.exists(fact_file_name): 41 | fact_in_train = set([]) 42 | triples = json.load(open(fact_file_name)) 43 | for x in triples: 44 | fact_in_train.add(tuple(x)) 45 | return fact_in_train 46 | 47 | fact_in_train = set([]) 48 | ori_data = json.load(open(data_file_name)) 49 | for data in ori_data: 50 | vertexSet = data['vertexSet'] 51 | for label in data['labels']: 52 | rel = label['r'] 53 | for n1 in vertexSet[label['h']]: 54 | for n2 in vertexSet[label['t']]: 55 | fact_in_train.add((n1['name'], n2['name'], rel)) 56 | 57 | json.dump(list(fact_in_train), open(fact_file_name, "w")) 58 | 59 | return fact_in_train 60 | 61 | 62 | def official_evaluate(tmp, path): 63 | ''' 64 | Adapted from the official evaluation code 65 | ''' 66 | truth_dir = os.path.join(path, 'ref') 67 | 68 | if not os.path.exists(truth_dir): 69 | os.makedirs(truth_dir) 70 | 71 | fact_in_train_annotated = gen_train_facts(os.path.join(path, "train_annotated.json"), truth_dir) 72 | fact_in_train_distant = gen_train_facts(os.path.join(path, "train_distant.json"), truth_dir) 73 | 74 | truth = json.load(open(os.path.join(path, "dev.json"))) 75 | 76 | std = {} 77 | tot_evidences = 0 78 | titleset = set([]) 79 | 80 | title2vectexSet = {} 81 | 82 | for x in truth: 83 | title = x['title'] 84 | titleset.add(title) 85 | 86 | vertexSet = x['vertexSet'] 87 | title2vectexSet[title] = vertexSet 88 | 89 | for label in x['labels']: 90 | r = label['r'] 91 | h_idx = label['h'] 92 | t_idx = label['t'] 93 | std[(title, r, h_idx, t_idx)] = set(label['evidence']) 94 | tot_evidences += len(label['evidence']) 95 | 96 | tot_relations = len(std) 97 | tmp.sort(key=lambda x: (x['title'], x['h_idx'], x['t_idx'], x['r'])) 98 | submission_answer = [tmp[0]] 99 | for i in range(1, len(tmp)): 100 | x = tmp[i] 101 | y = tmp[i - 1] 102 | if (x['title'], x['h_idx'], x['t_idx'], x['r']) != (y['title'], y['h_idx'], y['t_idx'], y['r']): 103 | submission_answer.append(tmp[i]) 104 | 105 | correct_re = 0 106 | correct_evidence = 0 107 | pred_evi = 0 108 | 109 | correct_in_train_annotated = 0 110 | correct_in_train_distant = 0 111 | titleset2 = set([]) 112 | for x in submission_answer: 113 | title = x['title'] 114 | h_idx = x['h_idx'] 115 | t_idx = x['t_idx'] 116 | r = x['r'] 117 | titleset2.add(title) 118 | if title not in title2vectexSet: 119 | continue 120 | 
vertexSet = title2vectexSet[title] 121 | 122 | if 'evidence' in x: 123 | evi = set(x['evidence']) 124 | else: 125 | evi = set([]) 126 | pred_evi += len(evi) 127 | 128 | if (title, r, h_idx, t_idx) in std: 129 | correct_re += 1 130 | stdevi = std[(title, r, h_idx, t_idx)] 131 | correct_evidence += len(stdevi & evi) 132 | in_train_annotated = in_train_distant = False 133 | for n1 in vertexSet[h_idx]: 134 | for n2 in vertexSet[t_idx]: 135 | if (n1['name'], n2['name'], r) in fact_in_train_annotated: 136 | in_train_annotated = True 137 | if (n1['name'], n2['name'], r) in fact_in_train_distant: 138 | in_train_distant = True 139 | 140 | if in_train_annotated: 141 | correct_in_train_annotated += 1 142 | if in_train_distant: 143 | correct_in_train_distant += 1 144 | 145 | re_p = 1.0 * correct_re / len(submission_answer) 146 | re_r = 1.0 * correct_re / tot_relations 147 | if re_p + re_r == 0: 148 | re_f1 = 0 149 | else: 150 | re_f1 = 2.0 * re_p * re_r / (re_p + re_r) 151 | 152 | evi_p = 1.0 * correct_evidence / pred_evi if pred_evi > 0 else 0 153 | evi_r = 1.0 * correct_evidence / tot_evidences 154 | if evi_p + evi_r == 0: 155 | evi_f1 = 0 156 | else: 157 | evi_f1 = 2.0 * evi_p * evi_r / (evi_p + evi_r) 158 | 159 | re_p_ignore_train_annotated = 1.0 * (correct_re - correct_in_train_annotated) / (len(submission_answer) - correct_in_train_annotated + 1e-5) 160 | re_p_ignore_train = 1.0 * (correct_re - correct_in_train_distant) / (len(submission_answer) - correct_in_train_distant + 1e-5) 161 | 162 | if re_p_ignore_train_annotated + re_r == 0: 163 | re_f1_ignore_train_annotated = 0 164 | else: 165 | re_f1_ignore_train_annotated = 2.0 * re_p_ignore_train_annotated * re_r / (re_p_ignore_train_annotated + re_r) 166 | 167 | if re_p_ignore_train + re_r == 0: 168 | re_f1_ignore_train = 0 169 | else: 170 | re_f1_ignore_train = 2.0 * re_p_ignore_train * re_r / (re_p_ignore_train + re_r) 171 | 172 | return re_f1, evi_f1, re_f1_ignore_train_annotated, re_f1_ignore_train 173 | -------------------------------------------------------------------------------- /long_seq.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | 6 | def process_long_input(model, input_ids, attention_mask, start_tokens, end_tokens): 7 | # Split the input to 2 overlapping chunks. Now BERT can encode inputs of which the length are up to 1024. 
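    # Concretely, each window is 512 tokens long: the first keeps the original start of the document and
    # ends with the special end token(s); the second begins with the special start token(s) and keeps the
    # original end. Overlapping positions are encoded twice; when the two windows are stitched back to the
    # full length below, their hidden states are averaged and the attention maps are renormalized.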
8 | n, c = input_ids.size() 9 | start_tokens = torch.tensor(start_tokens).to(input_ids) 10 | end_tokens = torch.tensor(end_tokens).to(input_ids) 11 | len_start = start_tokens.size(0) 12 | len_end = end_tokens.size(0) 13 | if c <= 512: 14 | output = model( 15 | input_ids=input_ids, 16 | attention_mask=attention_mask, 17 | output_attentions=True, 18 | ) 19 | sequence_output = output[0] 20 | attention = output[-1][-1] 21 | else: 22 | new_input_ids, new_attention_mask, num_seg = [], [], [] 23 | seq_len = attention_mask.sum(1).cpu().numpy().astype(np.int32).tolist() 24 | for i, l_i in enumerate(seq_len): 25 | if l_i <= 512: 26 | new_input_ids.append(input_ids[i, :512]) 27 | new_attention_mask.append(attention_mask[i, :512]) 28 | num_seg.append(1) 29 | else: 30 | input_ids1 = torch.cat([input_ids[i, :512 - len_end], end_tokens], dim=-1) 31 | input_ids2 = torch.cat([start_tokens, input_ids[i, (l_i - 512 + len_start): l_i]], dim=-1) 32 | attention_mask1 = attention_mask[i, :512] 33 | attention_mask2 = attention_mask[i, (l_i - 512): l_i] 34 | new_input_ids.extend([input_ids1, input_ids2]) 35 | new_attention_mask.extend([attention_mask1, attention_mask2]) 36 | num_seg.append(2) 37 | input_ids = torch.stack(new_input_ids, dim=0) 38 | attention_mask = torch.stack(new_attention_mask, dim=0) 39 | output = model( 40 | input_ids=input_ids, 41 | attention_mask=attention_mask, 42 | output_attentions=True, 43 | ) 44 | sequence_output = output[0] 45 | attention = output[-1][-1] 46 | i = 0 47 | new_output, new_attention = [], [] 48 | for (n_s, l_i) in zip(num_seg, seq_len): 49 | if n_s == 1: 50 | output = F.pad(sequence_output[i], (0, 0, 0, c - 512)) 51 | att = F.pad(attention[i], (0, c - 512, 0, c - 512)) 52 | new_output.append(output) 53 | new_attention.append(att) 54 | elif n_s == 2: 55 | output1 = sequence_output[i][:512 - len_end] 56 | mask1 = attention_mask[i][:512 - len_end] 57 | att1 = attention[i][:, :512 - len_end, :512 - len_end] 58 | output1 = F.pad(output1, (0, 0, 0, c - 512 + len_end)) 59 | mask1 = F.pad(mask1, (0, c - 512 + len_end)) 60 | att1 = F.pad(att1, (0, c - 512 + len_end, 0, c - 512 + len_end)) 61 | 62 | output2 = sequence_output[i + 1][len_start:] 63 | mask2 = attention_mask[i + 1][len_start:] 64 | att2 = attention[i + 1][:, len_start:, len_start:] 65 | output2 = F.pad(output2, (0, 0, l_i - 512 + len_start, c - l_i)) 66 | mask2 = F.pad(mask2, (l_i - 512 + len_start, c - l_i)) 67 | att2 = F.pad(att2, [l_i - 512 + len_start, c - l_i, l_i - 512 + len_start, c - l_i]) 68 | mask = mask1 + mask2 + 1e-10 69 | output = (output1 + output2) / mask.unsqueeze(-1) 70 | att = (att1 + att2) 71 | att = att / (att.sum(-1, keepdim=True) + 1e-10) 72 | new_output.append(output) 73 | new_attention.append(att) 74 | i += n_s 75 | sequence_output = torch.stack(new_output, dim=0) 76 | attention = torch.stack(new_attention, dim=0) 77 | return sequence_output, attention 78 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ATLoss(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, logits, labels): 11 | # TH label 12 | th_label = torch.zeros_like(labels, dtype=torch.float).to(labels) 13 | th_label[:, 0] = 1.0 14 | labels[:, 0] = 0.0 15 | 16 | p_mask = labels + th_label 17 | n_mask = 1 - labels 18 | 19 | # Rank positive classes to TH 20 | logit1 = logits - (1 - 
p_mask) * 1e30 21 | loss1 = -(F.log_softmax(logit1, dim=-1) * labels).sum(1) 22 | 23 | # Rank TH to negative classes 24 | logit2 = logits - (1 - n_mask) * 1e30 25 | loss2 = -(F.log_softmax(logit2, dim=-1) * th_label).sum(1) 26 | 27 | # Sum two parts 28 | loss = loss1 + loss2 29 | loss = loss.mean() 30 | return loss 31 | 32 | def get_label(self, logits, num_labels=-1): 33 | th_logit = logits[:, 0].unsqueeze(1) 34 | output = torch.zeros_like(logits).to(logits) 35 | mask = (logits > th_logit) 36 | if num_labels > 0: 37 | top_v, _ = torch.topk(logits, num_labels, dim=1) 38 | top_v = top_v[:, -1] 39 | mask = (logits >= top_v.unsqueeze(1)) & mask 40 | output[mask] = 1.0 41 | output[:, 0] = (output.sum(1) == 0.).to(logits) 42 | return output 43 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from opt_einsum import contract 4 | from long_seq import process_long_input 5 | from losses import ATLoss 6 | 7 | 8 | class DocREModel(nn.Module): 9 | def __init__(self, config, model, emb_size=768, block_size=64, num_labels=-1): 10 | super().__init__() 11 | self.config = config 12 | self.model = model 13 | self.hidden_size = config.hidden_size 14 | self.loss_fnt = ATLoss() 15 | 16 | self.head_extractor = nn.Linear(2 * config.hidden_size, emb_size) 17 | self.tail_extractor = nn.Linear(2 * config.hidden_size, emb_size) 18 | self.bilinear = nn.Linear(emb_size * block_size, config.num_labels) 19 | 20 | self.emb_size = emb_size 21 | self.block_size = block_size 22 | self.num_labels = num_labels 23 | 24 | def encode(self, input_ids, attention_mask): 25 | config = self.config 26 | if config.transformer_type == "bert": 27 | start_tokens = [config.cls_token_id] 28 | end_tokens = [config.sep_token_id] 29 | elif config.transformer_type == "roberta": 30 | start_tokens = [config.cls_token_id] 31 | end_tokens = [config.sep_token_id, config.sep_token_id] 32 | sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens) 33 | return sequence_output, attention 34 | 35 | def get_hrt(self, sequence_output, attention, entity_pos, hts): 36 | offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0 37 | n, h, _, c = attention.size() 38 | hss, tss, rss = [], [], [] 39 | for i in range(len(entity_pos)): 40 | entity_embs, entity_atts = [], [] 41 | for e in entity_pos[i]: 42 | if len(e) > 1: 43 | e_emb, e_att = [], [] 44 | for start, end in e: 45 | if start + offset < c: 46 | # In case the entity mention is truncated due to limited max seq length. 
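                        # Each surviving mention contributes the embedding of its leading '*' marker token;
                        # the mentions of an entity are pooled with logsumexp, and their attention maps are averaged.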
47 | e_emb.append(sequence_output[i, start + offset]) 48 | e_att.append(attention[i, :, start + offset]) 49 | if len(e_emb) > 0: 50 | e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0) 51 | e_att = torch.stack(e_att, dim=0).mean(0) 52 | else: 53 | e_emb = torch.zeros(self.config.hidden_size).to(sequence_output) 54 | e_att = torch.zeros(h, c).to(attention) 55 | else: 56 | start, end = e[0] 57 | if start + offset < c: 58 | e_emb = sequence_output[i, start + offset] 59 | e_att = attention[i, :, start + offset] 60 | else: 61 | e_emb = torch.zeros(self.config.hidden_size).to(sequence_output) 62 | e_att = torch.zeros(h, c).to(attention) 63 | entity_embs.append(e_emb) 64 | entity_atts.append(e_att) 65 | 66 | entity_embs = torch.stack(entity_embs, dim=0) # [n_e, d] 67 | entity_atts = torch.stack(entity_atts, dim=0) # [n_e, h, seq_len] 68 | 69 | ht_i = torch.LongTensor(hts[i]).to(sequence_output.device) 70 | hs = torch.index_select(entity_embs, 0, ht_i[:, 0]) 71 | ts = torch.index_select(entity_embs, 0, ht_i[:, 1]) 72 | 73 | h_att = torch.index_select(entity_atts, 0, ht_i[:, 0]) 74 | t_att = torch.index_select(entity_atts, 0, ht_i[:, 1]) 75 | ht_att = (h_att * t_att).mean(1) 76 | ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5) 77 | rs = contract("ld,rl->rd", sequence_output[i], ht_att) 78 | hss.append(hs) 79 | tss.append(ts) 80 | rss.append(rs) 81 | hss = torch.cat(hss, dim=0) 82 | tss = torch.cat(tss, dim=0) 83 | rss = torch.cat(rss, dim=0) 84 | return hss, rss, tss 85 | 86 | def forward(self, 87 | input_ids=None, 88 | attention_mask=None, 89 | labels=None, 90 | entity_pos=None, 91 | hts=None, 92 | instance_mask=None, 93 | ): 94 | 95 | sequence_output, attention = self.encode(input_ids, attention_mask) 96 | hs, rs, ts = self.get_hrt(sequence_output, attention, entity_pos, hts) 97 | 98 | hs = torch.tanh(self.head_extractor(torch.cat([hs, rs], dim=1))) 99 | ts = torch.tanh(self.tail_extractor(torch.cat([ts, rs], dim=1))) 100 | b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size) 101 | b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size) 102 | bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size) 103 | logits = self.bilinear(bl) 104 | 105 | output = (self.loss_fnt.get_label(logits, num_labels=self.num_labels),) 106 | if labels is not None: 107 | labels = [torch.tensor(label) for label in labels] 108 | labels = torch.cat(labels, dim=0).to(logits) 109 | loss = self.loss_fnt(logits.float(), labels.float()) 110 | output = (loss.to(sequence_output),) + output 111 | return output 112 | -------------------------------------------------------------------------------- /prepro.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import ujson as json 3 | 4 | docred_rel2id = json.load(open('meta/rel2id.json', 'r')) 5 | cdr_rel2id = {'1:NR:2': 0, '1:CID:2': 1} 6 | gda_rel2id = {'1:NR:2': 0, '1:GDA:2': 1} 7 | 8 | 9 | def chunks(l, n): 10 | res = [] 11 | for i in range(0, len(l), n): 12 | assert len(l[i:i + n]) == n 13 | res += [l[i:i + n]] 14 | return res 15 | 16 | 17 | def read_docred(file_in, tokenizer, max_seq_length=1024): 18 | i_line = 0 19 | pos_samples = 0 20 | neg_samples = 0 21 | features = [] 22 | if file_in == "": 23 | return None 24 | with open(file_in, "r") as fh: 25 | data = json.load(fh) 26 | 27 | for sample in tqdm(data, desc="Example"): 28 | sents = [] 29 | sent_map = [] 30 | 31 | entities = sample['vertexSet'] 32 | entity_start, entity_end = [], [] 33 | for entity 
in entities: 34 | for mention in entity: 35 | sent_id = mention["sent_id"] 36 | pos = mention["pos"] 37 | entity_start.append((sent_id, pos[0],)) 38 | entity_end.append((sent_id, pos[1] - 1,)) 39 | for i_s, sent in enumerate(sample['sents']): 40 | new_map = {} 41 | for i_t, token in enumerate(sent): 42 | tokens_wordpiece = tokenizer.tokenize(token) 43 | if (i_s, i_t) in entity_start: 44 | tokens_wordpiece = ["*"] + tokens_wordpiece 45 | if (i_s, i_t) in entity_end: 46 | tokens_wordpiece = tokens_wordpiece + ["*"] 47 | new_map[i_t] = len(sents) 48 | sents.extend(tokens_wordpiece) 49 | new_map[i_t + 1] = len(sents) 50 | sent_map.append(new_map) 51 | 52 | train_triple = {} 53 | if "labels" in sample: 54 | for label in sample['labels']: 55 | evidence = label['evidence'] 56 | r = int(docred_rel2id[label['r']]) 57 | if (label['h'], label['t']) not in train_triple: 58 | train_triple[(label['h'], label['t'])] = [ 59 | {'relation': r, 'evidence': evidence}] 60 | else: 61 | train_triple[(label['h'], label['t'])].append( 62 | {'relation': r, 'evidence': evidence}) 63 | 64 | entity_pos = [] 65 | for e in entities: 66 | entity_pos.append([]) 67 | for m in e: 68 | start = sent_map[m["sent_id"]][m["pos"][0]] 69 | end = sent_map[m["sent_id"]][m["pos"][1]] 70 | entity_pos[-1].append((start, end,)) 71 | 72 | relations, hts = [], [] 73 | for h, t in train_triple.keys(): 74 | relation = [0] * len(docred_rel2id) 75 | for mention in train_triple[h, t]: 76 | relation[mention["relation"]] = 1 77 | evidence = mention["evidence"] 78 | relations.append(relation) 79 | hts.append([h, t]) 80 | pos_samples += 1 81 | 82 | for h in range(len(entities)): 83 | for t in range(len(entities)): 84 | if h != t and [h, t] not in hts: 85 | relation = [1] + [0] * (len(docred_rel2id) - 1) 86 | relations.append(relation) 87 | hts.append([h, t]) 88 | neg_samples += 1 89 | 90 | assert len(relations) == len(entities) * (len(entities) - 1) 91 | 92 | sents = sents[:max_seq_length - 2] 93 | input_ids = tokenizer.convert_tokens_to_ids(sents) 94 | input_ids = tokenizer.build_inputs_with_special_tokens(input_ids) 95 | 96 | i_line += 1 97 | feature = {'input_ids': input_ids, 98 | 'entity_pos': entity_pos, 99 | 'labels': relations, 100 | 'hts': hts, 101 | 'title': sample['title'], 102 | } 103 | features.append(feature) 104 | 105 | print("# of documents {}.".format(i_line)) 106 | print("# of positive examples {}.".format(pos_samples)) 107 | print("# of negative examples {}.".format(neg_samples)) 108 | return features 109 | 110 | 111 | def read_cdr(file_in, tokenizer, max_seq_length=1024): 112 | pmids = set() 113 | features = [] 114 | maxlen = 0 115 | with open(file_in, 'r') as infile: 116 | lines = infile.readlines() 117 | for i_l, line in enumerate(tqdm(lines)): 118 | line = line.rstrip().split('\t') 119 | pmid = line[0] 120 | 121 | if pmid not in pmids: 122 | pmids.add(pmid) 123 | text = line[1] 124 | prs = chunks(line[2:], 17) 125 | 126 | ent2idx = {} 127 | train_triples = {} 128 | 129 | entity_pos = set() 130 | for p in prs: 131 | es = list(map(int, p[8].split(':'))) 132 | ed = list(map(int, p[9].split(':'))) 133 | tpy = p[7] 134 | for start, end in zip(es, ed): 135 | entity_pos.add((start, end, tpy)) 136 | 137 | es = list(map(int, p[14].split(':'))) 138 | ed = list(map(int, p[15].split(':'))) 139 | tpy = p[13] 140 | for start, end in zip(es, ed): 141 | entity_pos.add((start, end, tpy)) 142 | 143 | sents = [t.split(' ') for t in text.split('|')] 144 | new_sents = [] 145 | sent_map = {} 146 | i_t = 0 147 | for sent in sents: 148 | for token 
in sent: 149 | tokens_wordpiece = tokenizer.tokenize(token) 150 | for start, end, tpy in list(entity_pos): 151 | if i_t == start: 152 | tokens_wordpiece = ["*"] + tokens_wordpiece 153 | if i_t + 1 == end: 154 | tokens_wordpiece = tokens_wordpiece + ["*"] 155 | sent_map[i_t] = len(new_sents) 156 | new_sents.extend(tokens_wordpiece) 157 | i_t += 1 158 | sent_map[i_t] = len(new_sents) 159 | sents = new_sents 160 | 161 | entity_pos = [] 162 | 163 | for p in prs: 164 | if p[0] == "not_include": 165 | continue 166 | if p[1] == "L2R": 167 | h_id, t_id = p[5], p[11] 168 | h_start, t_start = p[8], p[14] 169 | h_end, t_end = p[9], p[15] 170 | else: 171 | t_id, h_id = p[5], p[11] 172 | t_start, h_start = p[8], p[14] 173 | t_end, h_end = p[9], p[15] 174 | h_start = map(int, h_start.split(':')) 175 | h_end = map(int, h_end.split(':')) 176 | t_start = map(int, t_start.split(':')) 177 | t_end = map(int, t_end.split(':')) 178 | h_start = [sent_map[idx] for idx in h_start] 179 | h_end = [sent_map[idx] for idx in h_end] 180 | t_start = [sent_map[idx] for idx in t_start] 181 | t_end = [sent_map[idx] for idx in t_end] 182 | if h_id not in ent2idx: 183 | ent2idx[h_id] = len(ent2idx) 184 | entity_pos.append(list(zip(h_start, h_end))) 185 | if t_id not in ent2idx: 186 | ent2idx[t_id] = len(ent2idx) 187 | entity_pos.append(list(zip(t_start, t_end))) 188 | h_id, t_id = ent2idx[h_id], ent2idx[t_id] 189 | 190 | r = cdr_rel2id[p[0]] 191 | if (h_id, t_id) not in train_triples: 192 | train_triples[(h_id, t_id)] = [{'relation': r}] 193 | else: 194 | train_triples[(h_id, t_id)].append({'relation': r}) 195 | 196 | relations, hts = [], [] 197 | for h, t in train_triples.keys(): 198 | relation = [0] * len(cdr_rel2id) 199 | for mention in train_triples[h, t]: 200 | relation[mention["relation"]] = 1 201 | relations.append(relation) 202 | hts.append([h, t]) 203 | 204 | maxlen = max(maxlen, len(sents)) 205 | sents = sents[:max_seq_length - 2] 206 | input_ids = tokenizer.convert_tokens_to_ids(sents) 207 | input_ids = tokenizer.build_inputs_with_special_tokens(input_ids) 208 | 209 | if len(hts) > 0: 210 | feature = {'input_ids': input_ids, 211 | 'entity_pos': entity_pos, 212 | 'labels': relations, 213 | 'hts': hts, 214 | 'title': pmid, 215 | } 216 | features.append(feature) 217 | print("Number of documents: {}.".format(len(features))) 218 | print("Max document length: {}.".format(maxlen)) 219 | return features 220 | 221 | 222 | def read_gda(file_in, tokenizer, max_seq_length=1024): 223 | pmids = set() 224 | features = [] 225 | maxlen = 0 226 | with open(file_in, 'r') as infile: 227 | lines = infile.readlines() 228 | for i_l, line in enumerate(tqdm(lines)): 229 | line = line.rstrip().split('\t') 230 | pmid = line[0] 231 | 232 | if pmid not in pmids: 233 | pmids.add(pmid) 234 | text = line[1] 235 | prs = chunks(line[2:], 17) 236 | 237 | ent2idx = {} 238 | train_triples = {} 239 | 240 | entity_pos = set() 241 | for p in prs: 242 | es = list(map(int, p[8].split(':'))) 243 | ed = list(map(int, p[9].split(':'))) 244 | tpy = p[7] 245 | for start, end in zip(es, ed): 246 | entity_pos.add((start, end, tpy)) 247 | 248 | es = list(map(int, p[14].split(':'))) 249 | ed = list(map(int, p[15].split(':'))) 250 | tpy = p[13] 251 | for start, end in zip(es, ed): 252 | entity_pos.add((start, end, tpy)) 253 | 254 | sents = [t.split(' ') for t in text.split('|')] 255 | new_sents = [] 256 | sent_map = {} 257 | i_t = 0 258 | for sent in sents: 259 | for token in sent: 260 | tokens_wordpiece = tokenizer.tokenize(token) 261 | for start, end, tpy in 
list(entity_pos): 262 | if i_t == start: 263 | tokens_wordpiece = ["*"] + tokens_wordpiece 264 | if i_t + 1 == end: 265 | tokens_wordpiece = tokens_wordpiece + ["*"] 266 | sent_map[i_t] = len(new_sents) 267 | new_sents.extend(tokens_wordpiece) 268 | i_t += 1 269 | sent_map[i_t] = len(new_sents) 270 | sents = new_sents 271 | 272 | entity_pos = [] 273 | 274 | for p in prs: 275 | if p[0] == "not_include": 276 | continue 277 | if p[1] == "L2R": 278 | h_id, t_id = p[5], p[11] 279 | h_start, t_start = p[8], p[14] 280 | h_end, t_end = p[9], p[15] 281 | else: 282 | t_id, h_id = p[5], p[11] 283 | t_start, h_start = p[8], p[14] 284 | t_end, h_end = p[9], p[15] 285 | h_start = map(int, h_start.split(':')) 286 | h_end = map(int, h_end.split(':')) 287 | t_start = map(int, t_start.split(':')) 288 | t_end = map(int, t_end.split(':')) 289 | h_start = [sent_map[idx] for idx in h_start] 290 | h_end = [sent_map[idx] for idx in h_end] 291 | t_start = [sent_map[idx] for idx in t_start] 292 | t_end = [sent_map[idx] for idx in t_end] 293 | if h_id not in ent2idx: 294 | ent2idx[h_id] = len(ent2idx) 295 | entity_pos.append(list(zip(h_start, h_end))) 296 | if t_id not in ent2idx: 297 | ent2idx[t_id] = len(ent2idx) 298 | entity_pos.append(list(zip(t_start, t_end))) 299 | h_id, t_id = ent2idx[h_id], ent2idx[t_id] 300 | 301 | r = gda_rel2id[p[0]] 302 | if (h_id, t_id) not in train_triples: 303 | train_triples[(h_id, t_id)] = [{'relation': r}] 304 | else: 305 | train_triples[(h_id, t_id)].append({'relation': r}) 306 | 307 | relations, hts = [], [] 308 | for h, t in train_triples.keys(): 309 | relation = [0] * len(gda_rel2id) 310 | for mention in train_triples[h, t]: 311 | relation[mention["relation"]] = 1 312 | relations.append(relation) 313 | hts.append([h, t]) 314 | 315 | maxlen = max(maxlen, len(sents)) 316 | sents = sents[:max_seq_length - 2] 317 | input_ids = tokenizer.convert_tokens_to_ids(sents) 318 | input_ids = tokenizer.build_inputs_with_special_tokens(input_ids) 319 | 320 | if len(hts) > 0: 321 | feature = {'input_ids': input_ids, 322 | 'entity_pos': entity_pos, 323 | 'labels': relations, 324 | 'hts': hts, 325 | 'title': pmid, 326 | } 327 | features.append(feature) 328 | print("Number of documents: {}.".format(len(features))) 329 | print("Max document length: {}.".format(maxlen)) 330 | return features 331 | -------------------------------------------------------------------------------- /scripts/run_bert.sh: -------------------------------------------------------------------------------- 1 | python train.py --data_dir ./dataset/docred \ 2 | --transformer_type bert \ 3 | --model_name_or_path bert-base-cased \ 4 | --train_file train_annotated.json \ 5 | --dev_file dev.json \ 6 | --test_file test.json \ 7 | --train_batch_size 4 \ 8 | --test_batch_size 8 \ 9 | --gradient_accumulation_steps 1 \ 10 | --num_labels 4 \ 11 | --learning_rate 5e-5 \ 12 | --max_grad_norm 1.0 \ 13 | --warmup_ratio 0.06 \ 14 | --num_train_epochs 30.0 \ 15 | --seed 66 \ 16 | --num_class 97 17 | -------------------------------------------------------------------------------- /scripts/run_cdr.sh: -------------------------------------------------------------------------------- 1 | python train_bio.py --data_dir ./dataset/cdr \ 2 | --transformer_type bert \ 3 | --model_name_or_path allenai/scibert_scivocab_cased \ 4 | --train_file train_filter.data \ 5 | --dev_file dev_filter.data \ 6 | --test_file test_filter.data \ 7 | --train_batch_size 4 \ 8 | --test_batch_size 4 \ 9 | --gradient_accumulation_steps 1 \ 10 | --num_labels 1 \ 11 | 
--learning_rate 2e-5 \ 12 | --max_grad_norm 1.0 \ 13 | --warmup_ratio 0.06 \ 14 | --num_train_epochs 30.0 \ 15 | --seed 66 \ 16 | --num_class 2 17 | -------------------------------------------------------------------------------- /scripts/run_gda.sh: -------------------------------------------------------------------------------- 1 | python train_bio.py --data_dir ./dataset/gda \ 2 | --transformer_type bert \ 3 | --model_name_or_path allenai/scibert_scivocab_cased \ 4 | --train_file train.data \ 5 | --dev_file dev.data \ 6 | --test_file test.data \ 7 | --train_batch_size 4 \ 8 | --test_batch_size 8 \ 9 | --gradient_accumulation_steps 4 \ 10 | --num_labels 1 \ 11 | --learning_rate 2e-5 \ 12 | --max_grad_norm 1.0 \ 13 | --warmup_ratio 0.06 \ 14 | --num_train_epochs 10.0 \ 15 | --evaluation_steps 500 \ 16 | --seed 66 \ 17 | --num_class 2 18 | -------------------------------------------------------------------------------- /scripts/run_roberta.sh: -------------------------------------------------------------------------------- 1 | python train.py --data_dir ./dataset/docred \ 2 | --transformer_type roberta \ 3 | --model_name_or_path roberta-large \ 4 | --train_file train_annotated.json \ 5 | --dev_file dev.json \ 6 | --test_file test.json \ 7 | --train_batch_size 4 \ 8 | --test_batch_size 8 \ 9 | --gradient_accumulation_steps 1 \ 10 | --num_labels 4 \ 11 | --learning_rate 3e-5 \ 12 | --max_grad_norm 1.0 \ 13 | --warmup_ratio 0.06 \ 14 | --num_train_epochs 30.0 \ 15 | --seed 66 \ 16 | --num_class 97 17 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from apex import amp 7 | import ujson as json 8 | from torch.utils.data import DataLoader 9 | from transformers import AutoConfig, AutoModel, AutoTokenizer 10 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 11 | from model import DocREModel 12 | from utils import set_seed, collate_fn 13 | from prepro import read_docred 14 | from evaluation import to_official, official_evaluate 15 | import wandb 16 | 17 | 18 | def train(args, model, train_features, dev_features, test_features): 19 | def finetune(features, optimizer, num_epoch, num_steps): 20 | best_score = -1 21 | train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True) 22 | train_iterator = range(int(num_epoch)) 23 | total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps) 24 | warmup_steps = int(total_steps * args.warmup_ratio) 25 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) 26 | print("Total steps: {}".format(total_steps)) 27 | print("Warmup steps: {}".format(warmup_steps)) 28 | for epoch in train_iterator: 29 | model.zero_grad() 30 | for step, batch in enumerate(train_dataloader): 31 | model.train() 32 | inputs = {'input_ids': batch[0].to(args.device), 33 | 'attention_mask': batch[1].to(args.device), 34 | 'labels': batch[2], 35 | 'entity_pos': batch[3], 36 | 'hts': batch[4], 37 | } 38 | outputs = model(**inputs) 39 | loss = outputs[0] / args.gradient_accumulation_steps 40 | with amp.scale_loss(loss, optimizer) as scaled_loss: 41 | scaled_loss.backward() 42 | if step % args.gradient_accumulation_steps == 0: 43 | if args.max_grad_norm > 0: 44 | 
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 45 | optimizer.step() 46 | scheduler.step() 47 | model.zero_grad() 48 | num_steps += 1 49 | wandb.log({"loss": loss.item()}, step=num_steps) 50 | if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0): 51 | dev_score, dev_output = evaluate(args, model, dev_features, tag="dev") 52 | wandb.log(dev_output, step=num_steps) 53 | print(dev_output) 54 | if dev_score > best_score: 55 | best_score = dev_score 56 | pred = report(args, model, test_features) 57 | with open("result.json", "w") as fh: 58 | json.dump(pred, fh) 59 | if args.save_path != "": 60 | torch.save(model.state_dict(), args.save_path) 61 | return num_steps 62 | 63 | new_layer = ["extractor", "bilinear"] 64 | optimizer_grouped_parameters = [ 65 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in new_layer)], }, 66 | {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in new_layer)], "lr": 1e-4}, 67 | ] 68 | 69 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 70 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0) 71 | num_steps = 0 72 | set_seed(args) 73 | model.zero_grad() 74 | finetune(train_features, optimizer, args.num_train_epochs, num_steps) 75 | 76 | 77 | def evaluate(args, model, features, tag="dev"): 78 | 79 | dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False) 80 | preds = [] 81 | for batch in dataloader: 82 | model.eval() 83 | 84 | inputs = {'input_ids': batch[0].to(args.device), 85 | 'attention_mask': batch[1].to(args.device), 86 | 'entity_pos': batch[3], 87 | 'hts': batch[4], 88 | } 89 | 90 | with torch.no_grad(): 91 | pred, *_ = model(**inputs) 92 | pred = pred.cpu().numpy() 93 | pred[np.isnan(pred)] = 0 94 | preds.append(pred) 95 | 96 | preds = np.concatenate(preds, axis=0).astype(np.float32) 97 | ans = to_official(preds, features) 98 | if len(ans) > 0: 99 | best_f1, _, best_f1_ign, _ = official_evaluate(ans, args.data_dir) 100 | output = { 101 | tag + "_F1": best_f1 * 100, 102 | tag + "_F1_ign": best_f1_ign * 100, 103 | } 104 | return best_f1, output 105 | 106 | 107 | def report(args, model, features): 108 | 109 | dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False) 110 | preds = [] 111 | for batch in dataloader: 112 | model.eval() 113 | 114 | inputs = {'input_ids': batch[0].to(args.device), 115 | 'attention_mask': batch[1].to(args.device), 116 | 'entity_pos': batch[3], 117 | 'hts': batch[4], 118 | } 119 | 120 | with torch.no_grad(): 121 | pred, *_ = model(**inputs) 122 | pred = pred.cpu().numpy() 123 | pred[np.isnan(pred)] = 0 124 | preds.append(pred) 125 | 126 | preds = np.concatenate(preds, axis=0).astype(np.float32) 127 | preds = to_official(preds, features) 128 | return preds 129 | 130 | 131 | def main(): 132 | parser = argparse.ArgumentParser() 133 | 134 | parser.add_argument("--data_dir", default="./dataset/docred", type=str) 135 | parser.add_argument("--transformer_type", default="bert", type=str) 136 | parser.add_argument("--model_name_or_path", default="bert-base-cased", type=str) 137 | 138 | parser.add_argument("--train_file", default="train_annotated.json", type=str) 139 | parser.add_argument("--dev_file", default="dev.json", type=str) 140 | 
parser.add_argument("--test_file", default="test.json", type=str) 141 | parser.add_argument("--save_path", default="", type=str) 142 | parser.add_argument("--load_path", default="", type=str) 143 | 144 | parser.add_argument("--config_name", default="", type=str, 145 | help="Pretrained config name or path if not the same as model_name") 146 | parser.add_argument("--tokenizer_name", default="", type=str, 147 | help="Pretrained tokenizer name or path if not the same as model_name") 148 | parser.add_argument("--max_seq_length", default=1024, type=int, 149 | help="The maximum total input sequence length after tokenization. Sequences longer " 150 | "than this will be truncated, sequences shorter will be padded.") 151 | 152 | parser.add_argument("--train_batch_size", default=4, type=int, 153 | help="Batch size for training.") 154 | parser.add_argument("--test_batch_size", default=8, type=int, 155 | help="Batch size for testing.") 156 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, 157 | help="Number of updates steps to accumulate before performing a backward/update pass.") 158 | parser.add_argument("--num_labels", default=4, type=int, 159 | help="Max number of labels in prediction.") 160 | parser.add_argument("--learning_rate", default=5e-5, type=float, 161 | help="The initial learning rate for Adam.") 162 | parser.add_argument("--adam_epsilon", default=1e-6, type=float, 163 | help="Epsilon for Adam optimizer.") 164 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 165 | help="Max gradient norm.") 166 | parser.add_argument("--warmup_ratio", default=0.06, type=float, 167 | help="Warm up ratio for Adam.") 168 | parser.add_argument("--num_train_epochs", default=30.0, type=float, 169 | help="Total number of training epochs to perform.") 170 | parser.add_argument("--evaluation_steps", default=-1, type=int, 171 | help="Number of training steps between evaluations.") 172 | parser.add_argument("--seed", type=int, default=66, 173 | help="random seed for initialization") 174 | parser.add_argument("--num_class", type=int, default=97, 175 | help="Number of relation types in dataset.") 176 | 177 | args = parser.parse_args() 178 | wandb.init(project="DocRED") 179 | 180 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 181 | args.n_gpu = torch.cuda.device_count() 182 | args.device = device 183 | 184 | config = AutoConfig.from_pretrained( 185 | args.config_name if args.config_name else args.model_name_or_path, 186 | num_labels=args.num_class, 187 | ) 188 | tokenizer = AutoTokenizer.from_pretrained( 189 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 190 | ) 191 | 192 | read = read_docred 193 | 194 | train_file = os.path.join(args.data_dir, args.train_file) 195 | dev_file = os.path.join(args.data_dir, args.dev_file) 196 | test_file = os.path.join(args.data_dir, args.test_file) 197 | train_features = read(train_file, tokenizer, max_seq_length=args.max_seq_length) 198 | dev_features = read(dev_file, tokenizer, max_seq_length=args.max_seq_length) 199 | test_features = read(test_file, tokenizer, max_seq_length=args.max_seq_length) 200 | 201 | model = AutoModel.from_pretrained( 202 | args.model_name_or_path, 203 | from_tf=bool(".ckpt" in args.model_name_or_path), 204 | config=config, 205 | ) 206 | 207 | config.cls_token_id = tokenizer.cls_token_id 208 | config.sep_token_id = tokenizer.sep_token_id 209 | config.transformer_type = args.transformer_type 210 | 211 | set_seed(args) 212 | model = DocREModel(config, model, 
num_labels=args.num_labels) 213 | model.to(0) 214 | 215 | if args.load_path == "": # Training 216 | train(args, model, train_features, dev_features, test_features) 217 | else: # Testing 218 | model = amp.initialize(model, opt_level="O1", verbosity=0) 219 | model.load_state_dict(torch.load(args.load_path)) 220 | dev_score, dev_output = evaluate(args, model, dev_features, tag="dev") 221 | print(dev_output) 222 | pred = report(args, model, test_features) 223 | with open("result.json", "w") as fh: 224 | json.dump(pred, fh) 225 | 226 | 227 | if __name__ == "__main__": 228 | main() 229 | -------------------------------------------------------------------------------- /train_bio.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | from apex import amp 7 | from torch.utils.data import DataLoader 8 | from transformers import AutoConfig, AutoModel, AutoTokenizer 9 | from transformers.optimization import AdamW, get_linear_schedule_with_warmup 10 | from model import DocREModel 11 | from utils import set_seed, collate_fn 12 | from prepro import read_cdr, read_gda 13 | import wandb 14 | 15 | 16 | def train(args, model, train_features, dev_features, test_features): 17 | def finetune(features, optimizer, num_epoch, num_steps): 18 | best_score = -1 19 | train_dataloader = DataLoader(features, batch_size=args.train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True) 20 | train_iterator = range(int(num_epoch)) 21 | total_steps = int(len(train_dataloader) * num_epoch // args.gradient_accumulation_steps) 22 | warmup_steps = int(total_steps * args.warmup_ratio) 23 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps) 24 | print("Total steps: {}".format(total_steps)) 25 | print("Warmup steps: {}".format(warmup_steps)) 26 | for epoch in train_iterator: 27 | model.zero_grad() 28 | for step, batch in enumerate(train_dataloader): 29 | model.train() 30 | inputs = {'input_ids': batch[0].to(args.device), 31 | 'attention_mask': batch[1].to(args.device), 32 | 'labels': batch[2], 33 | 'entity_pos': batch[3], 34 | 'hts': batch[4], 35 | } 36 | outputs = model(**inputs) 37 | loss = outputs[0] / args.gradient_accumulation_steps 38 | with amp.scale_loss(loss, optimizer) as scaled_loss: 39 | scaled_loss.backward() 40 | if step % args.gradient_accumulation_steps == 0: 41 | if args.max_grad_norm > 0: 42 | torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) 43 | optimizer.step() 44 | scheduler.step() 45 | model.zero_grad() 46 | num_steps += 1 47 | wandb.log({"loss": loss.item()}, step=num_steps) 48 | if (step + 1) == len(train_dataloader) - 1 or (args.evaluation_steps > 0 and num_steps % args.evaluation_steps == 0 and step % args.gradient_accumulation_steps == 0): 49 | dev_score, dev_output = evaluate(args, model, dev_features, tag="dev") 50 | test_score, test_output = evaluate(args, model, test_features, tag="test") 51 | print(dev_output) 52 | print(test_output) 53 | wandb.log(dev_output, step=num_steps) 54 | wandb.log(test_output, step=num_steps) 55 | if dev_score > best_score: 56 | best_score = dev_score 57 | if args.save_path != "": 58 | torch.save(model.state_dict(), args.save_path) 59 | 60 | return num_steps 61 | 62 | new_layer = ["extractor", "bilinear"] 63 | optimizer_grouped_parameters = [ 64 | {"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in new_layer)], }, 65 | {"params": 
[p for n, p in model.named_parameters() if any(nd in n for nd in new_layer)], "lr": 1e-4}, 66 | ] 67 | 68 | optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 69 | model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0) 70 | num_steps = 0 71 | set_seed(args) 72 | model.zero_grad() 73 | finetune(train_features, optimizer, args.num_train_epochs, num_steps) 74 | 75 | 76 | def evaluate(args, model, features, tag="dev"): 77 | 78 | dataloader = DataLoader(features, batch_size=args.test_batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False) 79 | preds, golds = [], [] 80 | for batch in dataloader: 81 | model.eval() 82 | 83 | inputs = {'input_ids': batch[0].to(args.device), 84 | 'attention_mask': batch[1].to(args.device), 85 | 'entity_pos': batch[3], 86 | 'hts': batch[4], 87 | } 88 | 89 | with torch.no_grad(): 90 | pred, *_ = model(**inputs) 91 | pred = pred.cpu().numpy() 92 | pred[np.isnan(pred)] = 0 93 | preds.append(pred) 94 | golds.append(np.concatenate([np.array(label, np.float32) for label in batch[2]], axis=0)) 95 | 96 | preds = np.concatenate(preds, axis=0).astype(np.float32) 97 | golds = np.concatenate(golds, axis=0).astype(np.float32) 98 | 99 | tp = ((preds[:, 1] == 1) & (golds[:, 1] == 1)).astype(np.float32).sum() 100 | tn = ((golds[:, 1] == 1) & (preds[:, 1] != 1)).astype(np.float32).sum() 101 | fp = ((preds[:, 1] == 1) & (golds[:, 1] != 1)).astype(np.float32).sum() 102 | precision = tp / (tp + fp + 1e-5) 103 | recall = tp / (tp + tn + 1e-5) 104 | f1 = 2 * precision * recall / (precision + recall + 1e-5) 105 | output = { 106 | "{}_f1".format(tag): f1 * 100, 107 | } 108 | return f1, output 109 | 110 | 111 | def main(): 112 | parser = argparse.ArgumentParser() 113 | 114 | parser.add_argument("--data_dir", default="./dataset/cdr", type=str) 115 | parser.add_argument("--transformer_type", default="bert", type=str) 116 | parser.add_argument("--model_name_or_path", default="allenai/scibert_scivocab_cased", type=str) 117 | 118 | parser.add_argument("--train_file", default="train_filter.data", type=str) 119 | parser.add_argument("--dev_file", default="dev_filter.data", type=str) 120 | parser.add_argument("--test_file", default="test_filter.data", type=str) 121 | parser.add_argument("--save_path", default="", type=str) 122 | parser.add_argument("--load_path", default="", type=str) 123 | 124 | parser.add_argument("--config_name", default="", type=str, 125 | help="Pretrained config name or path if not the same as model_name") 126 | parser.add_argument("--tokenizer_name", default="", type=str, 127 | help="Pretrained tokenizer name or path if not the same as model_name") 128 | parser.add_argument("--max_seq_length", default=1024, type=int, 129 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 130 | "than this will be truncated, sequences shorter will be padded.") 131 | 132 | parser.add_argument("--train_batch_size", default=4, type=int, 133 | help="Batch size for training.") 134 | parser.add_argument("--test_batch_size", default=8, type=int, 135 | help="Batch size for testing.") 136 | parser.add_argument("--gradient_accumulation_steps", default=1, type=int, 137 | help="Number of updates steps to accumulate before performing a backward/update pass.") 138 | parser.add_argument("--num_labels", default=1, type=int, 139 | help="Max number of labels in the prediction.") 140 | parser.add_argument("--learning_rate", default=2e-5, type=float, 141 | help="The initial learning rate for Adam.") 142 | parser.add_argument("--adam_epsilon", default=1e-6, type=float, 143 | help="Epsilon for Adam optimizer.") 144 | parser.add_argument("--max_grad_norm", default=1.0, type=float, 145 | help="Max gradient norm.") 146 | parser.add_argument("--warmup_ratio", default=0.06, type=float, 147 | help="Warm up ratio for Adam.") 148 | parser.add_argument("--num_train_epochs", default=30.0, type=float, 149 | help="Total number of training epochs to perform.") 150 | parser.add_argument("--evaluation_steps", default=-1, type=int, 151 | help="Number of training steps between evaluations.") 152 | parser.add_argument("--seed", type=int, default=66, 153 | help="random seed for initialization.") 154 | parser.add_argument("--num_class", type=int, default=2, 155 | help="Number of relation types in dataset.") 156 | 157 | args = parser.parse_args() 158 | wandb.init(project="CDR") 159 | 160 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 161 | args.n_gpu = torch.cuda.device_count() 162 | args.device = device 163 | 164 | config = AutoConfig.from_pretrained( 165 | args.config_name if args.config_name else args.model_name_or_path, 166 | num_labels=args.num_class, 167 | ) 168 | tokenizer = AutoTokenizer.from_pretrained( 169 | args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, 170 | ) 171 | 172 | read = read_cdr if "cdr" in args.data_dir else read_gda 173 | 174 | train_file = os.path.join(args.data_dir, args.train_file) 175 | dev_file = os.path.join(args.data_dir, args.dev_file) 176 | test_file = os.path.join(args.data_dir, args.test_file) 177 | train_features = read(train_file, tokenizer, max_seq_length=args.max_seq_length) 178 | dev_features = read(dev_file, tokenizer, max_seq_length=args.max_seq_length) 179 | test_features = read(test_file, tokenizer, max_seq_length=args.max_seq_length) 180 | 181 | model = AutoModel.from_pretrained( 182 | args.model_name_or_path, 183 | from_tf=bool(".ckpt" in args.model_name_or_path), 184 | config=config, 185 | ) 186 | 187 | config.cls_token_id = tokenizer.cls_token_id 188 | config.sep_token_id = tokenizer.sep_token_id 189 | config.transformer_type = args.transformer_type 190 | 191 | set_seed(args) 192 | model = DocREModel(config, model, num_labels=args.num_labels) 193 | model.to(0) 194 | 195 | if args.load_path == "": 196 | train(args, model, train_features, dev_features, test_features) 197 | else: 198 | model = amp.initialize(model, opt_level="O1", verbosity=0) 199 | model.load_state_dict(torch.load(args.load_path)) 200 | dev_score, dev_output = evaluate(args, model, dev_features, tag="dev") 201 | test_score, test_output = evaluate(args, model, test_features, tag="test") 202 | print(dev_output) 203 | print(test_output) 204 | 205 | 206 | if __name__ == "__main__": 207 | main() 208 | 
-------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | 5 | 6 | def set_seed(args): 7 | random.seed(args.seed) 8 | np.random.seed(args.seed) 9 | torch.manual_seed(args.seed) 10 | if args.n_gpu > 0 and torch.cuda.is_available(): 11 | torch.cuda.manual_seed_all(args.seed) 12 | 13 | 14 | def collate_fn(batch): 15 | max_len = max([len(f["input_ids"]) for f in batch]) 16 | input_ids = [f["input_ids"] + [0] * (max_len - len(f["input_ids"])) for f in batch] 17 | input_mask = [[1.0] * len(f["input_ids"]) + [0.0] * (max_len - len(f["input_ids"])) for f in batch] 18 | labels = [f["labels"] for f in batch] 19 | entity_pos = [f["entity_pos"] for f in batch] 20 | hts = [f["hts"] for f in batch] 21 | input_ids = torch.tensor(input_ids, dtype=torch.long) 22 | input_mask = torch.tensor(input_mask, dtype=torch.float) 23 | output = (input_ids, input_mask, labels, entity_pos, hts) 24 | return output 25 | --------------------------------------------------------------------------------
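The `collate_fn` above right-pads every document in a batch to the longest sequence and builds the matching floating-point attention mask, while labels, entity positions, and entity pairs are passed through as Python lists. Below is a minimal sketch of how it plugs into a `DataLoader`; the toy features are hypothetical placeholders, not real dataset output.

```python
from torch.utils.data import DataLoader

from utils import collate_fn

# Two hypothetical documents of different lengths (token ids, labels, and spans are placeholders).
toy_features = [
    {"input_ids": [101, 7592, 102], "labels": [[0, 1]], "entity_pos": [[(0, 1)], [(1, 2)]], "hts": [[0, 1]]},
    {"input_ids": [101, 7592, 2088, 999, 102], "labels": [[1, 0]], "entity_pos": [[(0, 1)], [(2, 3)]], "hts": [[0, 1]]},
]

loader = DataLoader(toy_features, batch_size=2, shuffle=False, collate_fn=collate_fn)
input_ids, input_mask, labels, entity_pos, hts = next(iter(loader))
print(input_ids.shape)  # torch.Size([2, 5]): padded to the longest document in the batch
print(input_mask)       # 1.0 over real tokens, 0.0 over the padding
```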