├── .gitignore
├── CompWebQ
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   └── train.py
├── MetaQA-KB
│   ├── Knowledge_graph.py
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   └── train.py
├── MetaQA-Text
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   ├── preprocess.py
│   └── train.py
├── README.md
├── WebQSP
│   ├── data.py
│   ├── model.py
│   ├── predict.py
│   └── train.py
├── example.png
├── pickle_glove.py
└── utils
    ├── BiGRU.py
    ├── __init__.py
    ├── lr_scheduler.py
    └── misc.py

/.gitignore:
--------------------------------------------------------------------------------
1 | *.cache
2 | *.sublime-workspace
3 | *.sublime-project
4 | 
5 | 
--------------------------------------------------------------------------------
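Aside (not part of the repo): Dataset.toOneHot in CompWebQ/data.py below turns a list of entity ids into a multi-hot vector with scatter_. A minimal sketch of that idiom, with toy ids standing in for real entity ids:

    import torch

    ent_ids = torch.LongTensor([2, 5])  # toy ids, e.g. the answer entities of one question
    vec_len = 8                         # stands in for len(ent2id)
    one_hot = torch.zeros(vec_len)
    one_hot.scatter_(0, ent_ids, 1)     # set positions 2 and 5 to 1
    print(one_hot)                      # tensor([0., 0., 1., 0., 0., 1., 0., 0.])
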
/CompWebQ/data.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import json
4 | import pickle
5 | from collections import defaultdict
6 | from transformers import AutoTokenizer
7 | from utils.misc import invert_dict
8 | 
9 | def collate(batch):
10 |     batch = list(zip(*batch))
11 |     topic_entity, question, answer, triples, entity_range = batch
12 |     topic_entity = torch.stack(topic_entity)
13 |     question = {k:torch.cat([q[k] for q in question], dim=0) for k in question[0]}
14 |     answer = torch.stack(answer)
15 |     entity_range = torch.stack(entity_range)
16 |     return topic_entity, question, answer, triples, entity_range
17 | 
18 | 
19 | class Dataset(torch.utils.data.Dataset):
20 |     def __init__(self, questions, ent2id):
21 |         self.questions = questions
22 |         self.ent2id = ent2id
23 | 
24 |     def __getitem__(self, index):
25 |         topic_entity, question, answer, triples, entity_range = self.questions[index]
26 |         topic_entity = self.toOneHot(topic_entity)
27 |         answer = self.toOneHot(answer)
28 |         triples = torch.LongTensor(triples)
29 |         if triples.dim() == 1:
30 |             triples = triples.unsqueeze(0)
31 |         entity_range = self.toOneHot(entity_range)
32 |         return topic_entity, question, answer, triples, entity_range
33 | 
34 |     def __len__(self):
35 |         return len(self.questions)
36 | 
37 |     def toOneHot(self, indices):
38 |         indices = torch.LongTensor(indices)
39 |         vec_len = len(self.ent2id)
40 |         one_hot = torch.FloatTensor(vec_len)
41 |         one_hot.zero_()
42 |         one_hot.scatter_(0, indices, 1)
43 |         return one_hot
44 | 
45 | 
46 | class DataLoader(torch.utils.data.DataLoader):
47 |     def __init__(self, fn, bert_name, ent2id, rel2id, batch_size, add_rev=False, training=False):
48 |         print('Reading questions from {} {}'.format(fn, '(add reverse)' if add_rev else ''))
49 |         self.tokenizer = AutoTokenizer.from_pretrained(bert_name)
50 |         self.ent2id = ent2id
51 |         self.rel2id = rel2id
52 |         self.id2ent = invert_dict(ent2id)
53 |         self.id2rel = invert_dict(rel2id)
54 | 
55 |         data = []
56 |         cnt_bad = 0
57 |         for line in open(fn):
58 |             instance = json.loads(line.strip())
59 | 
60 |             question = self.tokenizer(instance['question'].strip(), max_length=64, padding='max_length', return_tensors="pt")
61 |             head = instance['entities']
62 |             ans = [ent2id[a['kb_id']] for a in instance['answers']]
63 |             triples = instance['subgraph']['tuples']
64 | 
65 |             if len(triples) == 0:
66 |                 continue
67 | 
68 |             sub_ents = set(t[0] for t in triples)
69 |             obj_ents = set(t[2] for t in triples)
70 |             entity_range = sub_ents | obj_ents
71 | 
72 |             is_bad = False
73 |             if all(e not in entity_range for e in head):
74 |                 is_bad = True
75 |             if all(e not in entity_range for e in ans):
76 |                 is_bad = True
77 | 
78 |             if is_bad:
79 |                 cnt_bad += 1
80 | 
81 |             if training and is_bad: # skip bad examples during training
82 |                 continue
83 | 
84 |             entity_range = list(entity_range)
85 | 
86 |             if add_rev:
87 |                 supply_triples = []
88 |                 # add self relation
89 |                 # for e in entity_range:
90 |                 #     supply_triples.append([e, self.rel2id[''], e])
91 |                 # add reverse relation
92 |                 for s, r, o in triples:
93 |                     rev_r = self.rel2id[self.id2rel[r]+'_rev']
94 |                     supply_triples.append([o, rev_r, s])
95 |                 triples += supply_triples
96 | 
97 |             data.append([head, question, ans, triples, entity_range])
98 | 
99 |         print('data number: {}, bad number: {} (excluded in training)'.format(len(data), cnt_bad))
100 | 
101 |         dataset = Dataset(data, ent2id)
102 | 
103 |         super().__init__(
104 |             dataset,
105 |             batch_size=batch_size,
106 |             shuffle=training,
107 |             collate_fn=collate,
108 |         )
109 | 
110 | # need to download the data from https://github.com/RichardHGL/WSDM2021_NSM
111 | def load_data(input_dir, bert_name, batch_size, add_rev=False):
112 |     cache_fn = os.path.join(input_dir, 'cache{}.pt'.format('_rev' if add_rev else ''))
113 |     if os.path.exists(cache_fn):
114 |         print('Read from cache file: {} (NOTE: delete it if you modified data loading process)'.format(cache_fn))
115 |         with open(cache_fn, 'rb') as fp:
116 |             ent2id, rel2id, train_data, dev_data, test_data = pickle.load(fp)
117 |         print('Train number: {}, dev number: {}, test number: {}'.format(
118 |             len(train_data.dataset), len(dev_data.dataset), len(test_data.dataset)))
119 |     else:
120 |         print('Read data...')
121 |         ent2id = {}
122 |         for line in open(os.path.join(input_dir, 'entities.txt')):
123 |             ent2id[line.strip()] = len(ent2id)
124 |         print(len(ent2id))
125 |         rel2id = {}
126 |         for line in open(os.path.join(input_dir, 'relations.txt')):
127 |             rel2id[line.strip()] = len(rel2id)
128 |         # add self relation and reverse relation
129 |         # rel2id[''] = len(rel2id)
130 |         if add_rev:
131 |             for line in open(os.path.join(input_dir, 'relations.txt')):
132 |                 rel2id[line.strip()+'_rev'] = len(rel2id)
133 |         print(len(rel2id))
134 | 
135 |         train_data = DataLoader(os.path.join(input_dir, 'train_simple.json'), bert_name, ent2id, rel2id, batch_size, add_rev=add_rev, training=True)
136 |         dev_data = DataLoader(os.path.join(input_dir, 'dev_simple.json'), bert_name, ent2id, rel2id, batch_size, add_rev=add_rev)
137 |         test_data = DataLoader(os.path.join(input_dir, 'test_simple.json'), bert_name, ent2id, rel2id, batch_size, add_rev=add_rev)
138 | 
139 |         with open(cache_fn, 'wb') as fp:
140 |             pickle.dump((ent2id, rel2id, train_data, dev_data, test_data), fp)
141 | 
142 |     return ent2id, rel2id, train_data, dev_data, test_data
143 | 
144 | 
145 | 
146 | def cnt_hops(input_dir):
147 |     def bfs(triples, start, end):
148 |         if len(start)==0 or len(end)==0:
149 |             return 1000, 1000
150 | 
151 |         hops = {i:0 for i in start}
152 |         cur_set = set(start)
153 |         next_set = set()
154 |         for h in range(5):
155 |             for s,r,o in triples:
156 |                 if s in cur_set and o not in hops:
157 |                     hops[o] = h + 1 # entities reached in round h are h+1 hops away
158 |                     next_set.add(o)
159 |             cur_set = next_set
160 |             next_set = set()
161 |         only_forward_res = min(hops.get(i,1000) for i in end)
162 | 
163 |         hops = {i:0 for i in start}
164 |         cur_set = set(start)
165 |         next_set = set()
166 |         for h in range(5):
167 |             for s,r,o in triples:
168 |                 if s in cur_set and o not in hops:
169 |                     hops[o] = h + 1
170 |                     next_set.add(o)
171 |                 if o in cur_set and s not in hops:
172 |                     hops[s] = h + 1
173 |                     next_set.add(s)
174 |             cur_set = next_set
175 |             next_set = set()
176 |         add_reverse_res = min(hops.get(i,1000) for i in end)
177 | 
178 |         return only_forward_res, add_reverse_res
179 | 
180 | 
181 |     ent2id = {}
182 |     for line in open(os.path.join(input_dir, 'entities.txt')): 
183 | ent2id[line.strip()] = len(ent2id) 184 | 185 | 186 | for fn in ['train_simple.json', 'test_simple.json']: 187 | only_forward_res = defaultdict(int) 188 | add_reverse_res = defaultdict(int) 189 | for line in open(os.path.join(input_dir, fn)): 190 | instance = json.loads(line.strip()) 191 | triples = instance['subgraph']['tuples'] 192 | head = instance['entities'] 193 | ans = [ent2id[a['kb_id']] for a in instance['answers']] 194 | i, j = bfs(triples, head, ans) 195 | only_forward_res[i] += 1 196 | add_reverse_res[j] += 1 197 | 198 | print(fn) 199 | print(only_forward_res) 200 | print(add_reverse_res) # increase the ratio of 1-hop 201 | 202 | if __name__ == '__main__': 203 | cnt_hops('/data/sjx/dataset/WSDM_processed/CWQ') 204 | -------------------------------------------------------------------------------- /CompWebQ/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from transformers import AutoModel 5 | from utils.BiGRU import GRU, BiGRU 6 | 7 | class TransferNet(nn.Module): 8 | def __init__(self, args, ent2id, rel2id): 9 | super().__init__() 10 | num_relations = len(rel2id) 11 | self.num_ents = len(ent2id) 12 | self.num_steps = args.num_steps 13 | self.num_ways = args.num_ways 14 | 15 | self.bert_encoder = AutoModel.from_pretrained(args.bert_name, return_dict=True) 16 | dim_hidden = self.bert_encoder.config.hidden_size 17 | 18 | self.step_encoders = {} 19 | self.hop_selectors = {} 20 | self.rel_classifiers = {} 21 | for i in range(self.num_ways): 22 | for j in range(self.num_steps): 23 | m = nn.Sequential( 24 | nn.Linear(dim_hidden*2, dim_hidden), 25 | nn.Tanh() 26 | ) 27 | name = 'way_{}_step_{}'.format(i, j) 28 | self.step_encoders[name] = m 29 | self.add_module(name, m) 30 | 31 | m = nn.Linear(dim_hidden, self.num_steps) 32 | self.hop_selectors['way_{}'.format(i)] = m 33 | self.add_module('hop-way_{}'.format(i), m) 34 | 35 | m = nn.Linear(dim_hidden, num_relations) 36 | self.rel_classifiers['way_{}'.format(i)] = m 37 | self.add_module('rel-way_{}'.format(i), m) 38 | 39 | 40 | 41 | def forward(self, heads, questions, answers=None, triples=None, entity_range=None): 42 | q = self.bert_encoder(**questions) 43 | q_embeddings, q_word_h = q.pooler_output, q.last_hidden_state # (bsz, dim_h), (bsz, len, dim_h) 44 | bsz = len(heads) 45 | device = heads.device 46 | 47 | e_score = [] 48 | last_h = torch.zeros_like(q_embeddings) 49 | for w in range(self.num_ways): 50 | last_e = heads 51 | word_attns = [] 52 | rel_probs = [] 53 | ent_probs = [] 54 | for t in range(self.num_steps): 55 | cq_t = self.step_encoders['way_{}_step_{}'.format(w, t)]( 56 | torch.cat((q_embeddings, last_h), dim=1) # consider history 57 | ) # [bsz, dim_h] 58 | q_logits = torch.sum(cq_t.unsqueeze(1) * q_word_h, dim=2) # [bsz, max_q] 59 | q_dist = torch.softmax(q_logits, 1) # [bsz, max_q] 60 | q_dist = q_dist * questions['attention_mask'].float() 61 | q_dist = q_dist / (torch.sum(q_dist, dim=1, keepdim=True) + 1e-6) # [bsz, max_q] 62 | word_attns.append(q_dist) 63 | ctx_h = (q_dist.unsqueeze(1) @ q_word_h).squeeze(1) # [bsz, dim_h] 64 | ctx_h = ctx_h + cq_t 65 | last_h = ctx_h 66 | 67 | rel_logit = self.rel_classifiers['way_{}'.format(w)](ctx_h) # [bsz, num_relations] 68 | # rel_dist = torch.softmax(rel_logit, 1) # bad 69 | rel_dist = torch.sigmoid(rel_logit) 70 | rel_probs.append(rel_dist) 71 | 72 | new_e = [] 73 | for b in range(bsz): 74 | sub, rel, obj = triples[b][:,0], triples[b][:,1], triples[b][:,2] 75 | sub_p = 
last_e[b:b+1, sub] # [1, #tri]
76 |                     rel_p = rel_dist[b:b+1, rel] # [1, #tri]
77 |                     obj_p = sub_p * rel_p
78 |                     new_e.append(
79 |                         torch.index_add(torch.zeros(1, self.num_ents).to(device), 1, obj, obj_p))
80 |                 last_e = torch.cat(new_e, dim=0)
81 | 
82 |                 # reshape >1 scores to 1 in a differentiable way
83 |                 m = last_e.gt(1).float()
84 |                 z = (m * last_e + (1-m)).detach()
85 |                 last_e = last_e / z
86 | 
87 |                 ent_probs.append(last_e)
88 | 
89 |             hop_res = torch.stack(ent_probs, dim=1) # [bsz, num_hop, num_ent]
90 |             hop_logit = self.hop_selectors['way_{}'.format(w)](q_embeddings)
91 |             hop_attn = torch.softmax(hop_logit, dim=1).unsqueeze(2) # [bsz, num_hop, 1]
92 |             last_e = torch.sum(hop_res * hop_attn, dim=1) # [bsz, num_ent]
93 | 
94 |             e_score.append(last_e)
95 | 
96 |         e_score = torch.prod(torch.stack(e_score), dim=0)
97 | 
98 |         if not self.training:
99 |             return {
100 |                 'e_score': e_score,
101 |                 'word_attns': word_attns,
102 |                 'rel_probs': rel_probs,
103 |                 'ent_probs': ent_probs,
104 |                 # 'hop_attn': hop_attn.squeeze(2)
105 |             }
106 |         else:
107 |             weight = answers * 9 + 1
108 |             loss = torch.sum(entity_range * weight * torch.pow(e_score - answers, 2)) / torch.sum(entity_range * weight) # train on the combined score so that every way receives gradient
109 | 
110 |             return {'loss': loss}
111 | 
--------------------------------------------------------------------------------
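Aside (not part of the repo): the m/z rescaling in model.py above, which also appears in the MetaQA models, clips scores that exceed 1 back to exactly 1 without cutting the gradient graph, since the divisor z is detached. A standalone sketch with made-up scores:

    import torch

    x = torch.tensor([0.3, 1.7, 2.5], requires_grad=True)
    m = x.gt(1).float()              # 1 where the score overflows
    z = (m * x + (1 - m)).detach()   # divisor: x itself where x > 1, else 1
    y = x / z                        # tensor([0.3, 1.0, 1.0])
    y.sum().backward()
    print(x.grad)                    # tensor([1.0000, 0.5882, 0.4000]) -- still differentiable
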
/CompWebQ/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | from tqdm import tqdm
6 | from collections import defaultdict
7 | from utils.misc import batch_device
8 | from .data import load_data
9 | from .model import TransferNet
10 | 
11 | from IPython import embed
12 | 
13 | 
14 | def validate(args, model, data, device, verbose = False):
15 |     model.eval()
16 |     count = 0
17 |     correct = 0
18 |     hop_count = defaultdict(list)
19 |     with torch.no_grad():
20 |         for batch in tqdm(data, total=len(data)):
21 |             outputs = model(*batch_device(batch, device)) # [bsz, Esize]
22 |             e_score = outputs['e_score'].cpu()
23 |             scores, idx = torch.max(e_score, dim = 1) # [bsz], [bsz]
24 |             match_score = torch.gather(batch[2], 1, idx.unsqueeze(-1)).squeeze(1).tolist() # keep dim 1 so a final batch of size 1 still yields a list
25 |             count += len(match_score)
26 |             correct += sum(match_score)
27 |             # for i in range(len(match_score)):
28 |             #     h = outputs['hop_attn'][i].argmax().item()
29 |             #     hop_count[h].append(match_score[i])
30 | 
31 |             if verbose:
32 |                 answers = batch[2]
33 |                 for i in range(len(match_score)):
34 |                     if match_score[i] == 0:
35 |                         print('================================================================')
36 |                         question_ids = batch[1]['input_ids'][i].tolist()
37 |                         question_tokens = data.tokenizer.convert_ids_to_tokens(question_ids)
38 |                         print(' '.join(question_tokens))
39 |                         topic_id = batch[0][i].argmax(0).item()
40 |                         print('> topic entity: {}'.format(data.id2ent[topic_id]))
41 |                         for t in range(args.num_steps):
42 |                             print('>>>>>>> step {}'.format(t))
43 |                             tmp = ' '.join(['{}: {:.3f}'.format(x, y) for x,y in
44 |                                 zip(question_tokens, outputs['word_attns'][t][i].tolist())])
45 |                             print('> Attention: ' + tmp)
46 |                             print('> Relation:')
47 |                             rel_idx = outputs['rel_probs'][t][i].gt(0.9).nonzero().squeeze(1).tolist()
48 |                             for x in rel_idx:
49 |                                 print('  {}: {:.3f}'.format(data.id2rel[x], outputs['rel_probs'][t][i][x].item()))
50 | 
51 |                             print('> Entity: {}'.format('; '.join([data.id2ent[_] for _ in outputs['ent_probs'][t][i].gt(0.8).nonzero().squeeze(1).tolist()])))
52 |                         print('----')
53 |                         print('> max is {}'.format(data.id2ent[idx[i].item()]))
54 |                         print('> golden: {}'.format('; '.join([data.id2ent[_] for _ in answers[i].gt(0.9).nonzero().squeeze(1).tolist()])))
55 |                         print('> prediction: {}'.format('; '.join([data.id2ent[_] for _ in e_score[i].gt(0.9).nonzero().squeeze(1).tolist()])))
56 |                         print(' '.join(question_tokens))
57 |                         # print(outputs['hop_attn'][i].tolist()) # 'hop_attn' is commented out in model.forward
58 |                         embed()
59 |     acc = correct / count
60 |     print(acc)
61 |     # print('pred hop accuracy: 1-hop {} (total {}), 2-hop {} (total {})'.format(
62 |     #     sum(hop_count[0])/(len(hop_count[0])+0.1),
63 |     #     len(hop_count[0]),
64 |     #     sum(hop_count[1])/(len(hop_count[1])+0.1),
65 |     #     len(hop_count[1]),
66 |     #     ))
67 |     return acc
68 | 
69 | 
70 | def main():
71 |     parser = argparse.ArgumentParser()
72 |     # input and output
73 |     parser.add_argument('--input_dir', default = './input')
74 |     parser.add_argument('--ckpt', required = True)
75 |     parser.add_argument('--mode', default='val', choices=['val', 'vis', 'test'])
76 |     # model and data parameters, kept consistent with train.py
77 |     parser.add_argument('--batch_size', default=64, type=int)
78 |     parser.add_argument('--rev', action='store_true', help='whether add reversed relations')
79 |     parser.add_argument('--num_ways', default=1, type=int)
80 |     parser.add_argument('--num_steps', default=2, type=int)
81 |     parser.add_argument('--bert_name', default='bert-base-cased', choices=['roberta-base', 'bert-base-cased', 'bert-base-uncased'])
82 |     args = parser.parse_args()
83 | 
84 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
85 |     ent2id, rel2id, train_loader, val_loader, test_loader = load_data(args.input_dir, args.bert_name, args.batch_size, args.rev)
86 | 
87 |     model = TransferNet(args, ent2id, rel2id)
88 |     missing, unexpected = model.load_state_dict(torch.load(args.ckpt), strict=False)
89 |     if missing:
90 |         print("Missing keys: {}".format("; ".join(missing)))
91 |     if unexpected:
92 |         print("Unexpected keys: {}".format("; ".join(unexpected)))
93 |     model = model.to(device)
94 | 
95 |     if args.mode == 'vis':
96 |         validate(args, model, val_loader, device, True)
97 |     elif args.mode == 'val':
98 |         validate(args, model, val_loader, device, False)
99 |     elif args.mode == 'test':
100 |         validate(args, model, test_loader, device, False)
101 | 
102 | if __name__ == '__main__':
103 |     main()
--------------------------------------------------------------------------------
/CompWebQ/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import argparse
5 | from tqdm import tqdm
6 | import numpy as np
7 | import time
8 | from utils.misc import MetricLogger, batch_device
9 | from .data import load_data
10 | from .model import TransferNet
11 | from .predict import validate
12 | from transformers import AdamW, get_linear_schedule_with_warmup
13 | import logging
14 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
15 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
16 | rootLogger = logging.getLogger()
17 | 
18 | torch.set_num_threads(1) # avoid using multiple cpus
19 | 
20 | 
21 | def train(args):
22 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
23 | 
24 |     ent2id, rel2id, train_loader, val_loader, test_loader = load_data(args.input_dir, args.bert_name, args.batch_size, args.rev)
25 |     logging.info("Create model.........")
26 |     model = TransferNet(args, ent2id, rel2id)
27 |     if args.ckpt is not None:
28 |         model.load_state_dict(torch.load(args.ckpt))
29 |     model = model.to(device)
30 |     logging.info(model)
31 | 
32 | 
33 |     t_total = len(train_loader) * args.num_epoch
34 |     no_decay = ["bias", "LayerNorm.weight"]
35 |     bert_param = [(n,p) for n,p in model.named_parameters() if n.startswith('bert_encoder')]
36 |     other_param = [(n,p) for n,p in model.named_parameters() if not n.startswith('bert_encoder')]
37 |     print('number of bert param: {}'.format(len(bert_param)))
38 |     optimizer_grouped_parameters = [
39 |         {'params': [p for n, p in bert_param 
if not any(nd in n for nd in no_decay)], 40 | 'weight_decay': args.weight_decay, 'lr': args.bert_lr}, 41 | {'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)], 42 | 'weight_decay': 0.0, 'lr': args.bert_lr}, 43 | {'params': [p for n, p in other_param if not any(nd in n for nd in no_decay)], 44 | 'weight_decay': args.weight_decay, 'lr': args.lr}, 45 | {'params': [p for n, p in other_param if any(nd in n for nd in no_decay)], 46 | 'weight_decay': 0.0, 'lr': args.lr}, 47 | ] 48 | 49 | optimizer = AdamW(optimizer_grouped_parameters) 50 | args.warmup_steps = int(t_total * args.warmup_proportion) 51 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 52 | meters = MetricLogger(delimiter=" ") 53 | validate(args, model, val_loader, device) 54 | logging.info("Start training........") 55 | 56 | for epoch in range(args.num_epoch): 57 | model.train() 58 | for iteration, batch in enumerate(train_loader): 59 | iteration = iteration + 1 60 | loss = model(*batch_device(batch, device)) 61 | optimizer.zero_grad() 62 | if isinstance(loss, dict): 63 | if len(loss) > 1: 64 | total_loss = sum(loss.values()) 65 | else: 66 | total_loss = loss[list(loss.keys())[0]] 67 | meters.update(**{k:v.item() for k,v in loss.items()}) 68 | else: 69 | total_loss = loss 70 | meters.update(loss=loss.item()) 71 | total_loss.backward() 72 | nn.utils.clip_grad_value_(model.parameters(), 0.5) 73 | nn.utils.clip_grad_norm_(model.parameters(), 2) 74 | optimizer.step() 75 | scheduler.step() 76 | 77 | if iteration % (len(train_loader) // 10) == 0: 78 | # if True: 79 | 80 | logging.info( 81 | meters.delimiter.join( 82 | [ 83 | "progress: {progress:.3f}", 84 | "{meters}", 85 | "lr: {lr:.6f}", 86 | ] 87 | ).format( 88 | progress=epoch + iteration / len(train_loader), 89 | meters=str(meters), 90 | lr=optimizer.param_groups[0]["lr"], 91 | ) 92 | ) 93 | if (epoch+1)%5 == 0: 94 | val_acc = validate(args, model, val_loader, device) 95 | test_acc = validate(args, model, test_loader, device) 96 | logging.info('val acc: {:.4f}, test acc: {:.4f}'.format(val_acc, test_acc)) 97 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model-{}-{:.4f}.pt'.format(epoch, test_acc))) 98 | 99 | def main(): 100 | parser = argparse.ArgumentParser() 101 | # input and output 102 | parser.add_argument('--input_dir', required=True, help='path to the data') 103 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 104 | parser.add_argument('--ckpt', default = None) 105 | # training parameters 106 | parser.add_argument('--bert_lr', default=3e-5, type=float) 107 | parser.add_argument('--lr', default=0.001, type=float) 108 | parser.add_argument('--weight_decay', default=1e-5, type=float) 109 | parser.add_argument('--num_epoch', default=30, type=int) 110 | parser.add_argument('--batch_size', default=64, type=int) 111 | parser.add_argument('--seed', type=int, default=666, help='random seed') 112 | parser.add_argument('--warmup_proportion', default=0.1, type = float) 113 | # model parameters 114 | parser.add_argument('--rev', action='store_true', help='whether add reversed relations') 115 | parser.add_argument('--num_ways', default=1, type=int) 116 | parser.add_argument('--num_steps', default=2, type=int) 117 | parser.add_argument('--bert_name', default='bert-base-cased', choices=['roberta-base', 'bert-base-cased', 'bert-base-uncased']) 118 | args = parser.parse_args() 119 | 120 | # make logging.info display into both shell and file 121 | 
if not os.path.exists(args.save_dir): 122 | os.makedirs(args.save_dir) 123 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'log.txt')) 124 | fileHandler.setFormatter(logFormatter) 125 | rootLogger.addHandler(fileHandler) 126 | # args display 127 | for k, v in vars(args).items(): 128 | logging.info(k+':'+str(v)) 129 | 130 | torch.backends.cudnn.deterministic = True 131 | torch.backends.cudnn.benchmark = False 132 | # set random seed 133 | torch.manual_seed(args.seed) 134 | np.random.seed(args.seed) 135 | 136 | train(args) 137 | 138 | 139 | if __name__ == '__main__': 140 | main() 141 | -------------------------------------------------------------------------------- /MetaQA-KB/Knowledge_graph.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | import pickle 4 | from collections import defaultdict 5 | import torch 6 | import torch.nn as nn 7 | from utils.misc import * 8 | import numpy as np 9 | 10 | class KnowledgeGraph(nn.Module): 11 | def __init__(self, args, vocab): 12 | super(KnowledgeGraph, self).__init__() 13 | self.args = args 14 | self.entity2id, self.id2entity = vocab['entity2id'], vocab['id2entity'] 15 | self.relation2id, self.id2relation = vocab['relation2id'], vocab['id2relation'] 16 | Msubj = torch.from_numpy(np.load(os.path.join(args.input_dir, 'Msubj.npy'))).long() 17 | Mobj = torch.from_numpy(np.load(os.path.join(args.input_dir, 'Mobj.npy'))).long() 18 | Mrel = torch.from_numpy(np.load(os.path.join(args.input_dir, 'Mrel.npy'))).long() 19 | Tsize = Msubj.size()[0] 20 | Esize = len(self.entity2id) 21 | Rsize = len(self.relation2id) 22 | self.Msubj = torch.sparse.FloatTensor(Msubj.t(), torch.FloatTensor([1] * Tsize), torch.Size([Tsize, Esize])) 23 | self.Mobj = torch.sparse.FloatTensor(Mobj.t(), torch.FloatTensor([1] * Tsize), torch.Size([Tsize, Esize])) 24 | self.Mrel = torch.sparse.FloatTensor(Mrel.t(), torch.FloatTensor([1] * Tsize), torch.Size([Tsize, Rsize])) 25 | self.num_entities = len(self.entity2id) 26 | 27 | -------------------------------------------------------------------------------- /MetaQA-KB/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | import numpy as np 5 | from utils.misc import invert_dict 6 | 7 | 8 | def load_vocab(path): 9 | vocab = json.load(open(path)) 10 | vocab['id2word'] = invert_dict(vocab['word2id']) 11 | vocab['id2entity'] = invert_dict(vocab['entity2id']) 12 | vocab['id2relation'] = invert_dict(vocab['relation2id']) 13 | return vocab 14 | 15 | def collate(batch): 16 | batch = list(zip(*batch)) 17 | question, topic_entity, answer = list(map(torch.stack, batch[:3])) 18 | hop = torch.LongTensor(batch[3]) 19 | return question, topic_entity, answer, hop 20 | 21 | 22 | class Dataset(torch.utils.data.Dataset): 23 | def __init__(self, inputs): 24 | self.questions, self.topic_entities, self.answers, self.hops = inputs 25 | # print(self.questions.shape) 26 | # print(self.topic_entities.shape) 27 | # print(self.answers.shape) 28 | 29 | def __getitem__(self, index): 30 | question = torch.LongTensor(self.questions[index]) 31 | topic_entity = torch.LongTensor(self.topic_entities[index]) 32 | answer = torch.LongTensor(self.answers[index]) 33 | hop = self.hops[index] 34 | return question, topic_entity, answer, hop 35 | 36 | 37 | def __len__(self): 38 | return len(self.questions) 39 | 40 | 41 | class DataLoader(torch.utils.data.DataLoader): 42 | def __init__(self, vocab_json, 
question_pt, batch_size, ratio=1, training=False):
43 |         vocab = load_vocab(vocab_json)
44 | 
45 |         inputs = []
46 |         with open(question_pt, 'rb') as f:
47 |             for _ in range(4):
48 |                 inputs.append(pickle.load(f))
49 | 
50 |         if ratio < 1:
51 |             total = len(inputs[0])
52 |             num = int(total * ratio)
53 |             index = np.random.choice(total, num, replace=False) # sample distinct examples
54 |             print('random select {} of {} (ratio={})'.format(num, total, ratio))
55 |             inputs = [i[index] for i in inputs]
56 | 
57 |         dataset = Dataset(inputs)
58 | 
59 |         super().__init__(
60 |             dataset,
61 |             batch_size=batch_size,
62 |             shuffle=training,
63 |             collate_fn=collate,
64 |         )
65 |         self.vocab = vocab
66 | 
67 | # if __name__ == '__main__':
68 | #     vocab_json = '/data/csl/exp/AI_project/SRN/input/vocab.json'
69 | #     question_pt = '/data/csl/exp/AI_project/SRN/input/train.pt'
70 | #     inputs = []
71 | #     with open(question_pt, 'rb') as f:
72 | #         for _ in range(3):
73 | #             inputs.append(pickle.load(f))
74 | #     dataset = Dataset(inputs)
75 | #     # print(dataset[0])
76 | #     print(len(dataset))
77 | #     question, topic_entity, answer = dataset[0]
78 | #     print(question.size())
79 | #     print(topic_entity.size())
80 | #     print(answer.size())
81 | 
--------------------------------------------------------------------------------
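Aside (not part of the repo): the Msubj/Mobj/Mrel matrices built by preprocess.py and wrapped by Knowledge_graph.py let the model's follow step run as a pair of sparse matrix products. A toy sketch of that operation, with 3 entities, 2 relations, and invented triples:

    import torch

    # triples: (0, r0, 1) and (1, r1, 2)
    triples = [(0, 0, 1), (1, 1, 2)]
    Tsize, Esize, Rsize = len(triples), 3, 2
    def indicator(pairs, cols):
        idx = torch.LongTensor(pairs).t()
        return torch.sparse_coo_tensor(idx, torch.ones(len(pairs)), (Tsize, cols))
    Msubj = indicator([[i, s] for i, (s, r, o) in enumerate(triples)], Esize)
    Mobj  = indicator([[i, o] for i, (s, r, o) in enumerate(triples)], Esize)
    Mrel  = indicator([[i, r] for i, (s, r, o) in enumerate(triples)], Rsize)

    def follow(e, r):
        # per-triple score = subject score * relation score, scattered onto objects
        x = torch.sparse.mm(Msubj, e.t()) * torch.sparse.mm(Mrel, r.t())
        return torch.sparse.mm(Mobj.t(), x).t()

    e = torch.tensor([[1., 0., 0.]])  # start at entity 0
    r = torch.tensor([[1., 0.]])      # follow relation r0
    print(follow(e, r))               # tensor([[0., 1., 0.]]) -> entity 1
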
/MetaQA-KB/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 | 
5 | from utils.BiGRU import GRU, BiGRU
6 | from .Knowledge_graph import KnowledgeGraph
7 | 
8 | class TransferNet(nn.Module):
9 |     def __init__(self, args, dim_word, dim_hidden, vocab):
10 |         super().__init__()
11 |         self.args = args
12 |         self.vocab = vocab
13 |         self.kg = KnowledgeGraph(args, vocab)
14 |         num_words = len(vocab['word2id'])
15 |         num_entities = len(vocab['entity2id'])
16 |         num_relations = len(vocab['relation2id'])
17 |         self.num_steps = args.num_steps
18 |         self.aux_hop = args.aux_hop
19 | 
20 |         self.question_encoder = BiGRU(dim_word, dim_hidden, num_layers=1, dropout=0.2)
21 | 
22 |         self.word_embeddings = nn.Embedding(num_words, dim_word)
23 |         self.word_dropout = nn.Dropout(0.2)
24 |         self.step_encoders = []
25 |         for i in range(self.num_steps):
26 |             m = nn.Sequential(
27 |                 nn.Linear(dim_hidden, dim_hidden),
28 |                 nn.Tanh()
29 |             )
30 |             self.step_encoders.append(m)
31 |             self.add_module('step_encoders_{}'.format(i), m)
32 |         self.rel_classifier = nn.Linear(dim_hidden, num_relations)
33 |         # self.q_classifier = nn.Linear(dim_hidden, num_entities)
34 | 
35 |         self.hop_selector = nn.Linear(dim_hidden, self.num_steps)
36 | 
37 | 
38 |     def follow(self, e, r):
39 |         x = torch.sparse.mm(self.kg.Msubj, e.t()) * torch.sparse.mm(self.kg.Mrel, r.t())
40 |         return torch.sparse.mm(self.kg.Mobj.t(), x).t() # [bsz, Esize]
41 | 
42 | 
43 |     def forward(self, questions, e_s, answers=None, hop=None):
44 |         question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means <PAD>
45 |         q_word_emb = self.word_dropout(self.word_embeddings(questions)) # [bsz, max_q, dim_hidden]
46 |         q_word_h, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens) # [bsz, max_q, dim_h], [bsz, dim_h], [num_layers, bsz, dim_h]
47 | 
48 |         device = q_word_h.device
49 |         bsz = q_word_h.size(0)
50 |         dim_h = q_word_h.size(-1)
51 |         last_e = e_s
52 |         word_attns = []
53 |         rel_probs = []
54 |         ent_probs = []
55 |         for t in range(self.num_steps):
56 |             cq_t = self.step_encoders[t](q_embeddings) # [bsz, dim_h]
57 |             q_logits = torch.sum(cq_t.unsqueeze(1) * q_word_h, dim=2) # [bsz, max_q]
58 |             q_dist = torch.softmax(q_logits, 1).unsqueeze(1) # [bsz, 1, max_q]
59 |             word_attns.append(q_dist.squeeze(1))
60 |             ctx_h = (q_dist @ q_word_h).squeeze(1) # [bsz, dim_h]
61 |             rel_dist = torch.softmax(self.rel_classifier(ctx_h), 1) # [bsz, num_relations]
62 |             rel_probs.append(rel_dist)
63 | 
64 |             last_e = self.follow(last_e, rel_dist)
65 | 
66 |             # reshape >1 scores to 1 in a differentiable way
67 |             m = last_e.gt(1).float()
68 |             z = (m * last_e + (1-m)).detach()
69 |             last_e = last_e / z
70 | 
71 |             # Specifically for MetaQA: reshape cycle entities to 0, because A-r->B-r_inv->A is not allowed
72 |             if t > 0:
73 |                 prev_rel = torch.argmax(rel_probs[-2], dim=1)
74 |                 curr_rel = torch.argmax(rel_probs[-1], dim=1)
75 |                 prev_prev_ent_prob = ent_probs[-2] if len(ent_probs)>=2 else e_s
76 |                 # in our vocabulary, indices of inverse relations are adjacent. e.g., director:0, director_inv:1
77 |                 m = torch.zeros((bsz,1)).to(device)
78 |                 m[(torch.abs(prev_rel-curr_rel)==1) & (torch.remainder(torch.min(prev_rel,curr_rel),2)==0)] = 1
79 |                 ent_m = m.float() * prev_prev_ent_prob.gt(0.9).float()
80 |                 last_e = (1-ent_m) * last_e
81 | 
82 |             ent_probs.append(last_e)
83 | 
84 |         hop_res = torch.stack(ent_probs, dim=1) # [bsz, num_hop, num_ent]
85 |         hop_logit = self.hop_selector(q_embeddings)
86 |         hop_attn = torch.softmax(hop_logit, dim=1) # [bsz, num_hop]
87 |         last_e = torch.sum(hop_res * hop_attn.unsqueeze(2), dim=1) # [bsz, num_ent]
88 | 
89 |         # Specifically for MetaQA: for 2-hop questions, topic entity is excluded from answer
90 |         m = hop_attn.argmax(dim=1).eq(1).float().unsqueeze(1) * e_s
91 |         last_e = (1-m) * last_e
92 | 
93 |         # question mask, incorporate language bias
94 |         # q_mask = torch.sigmoid(self.q_classifier(q_embeddings))
95 |         # last_e = last_e * q_mask
96 | 
97 |         if not self.training:
98 |             return {
99 |                 'e_score': last_e,
100 |                 'word_attns': word_attns,
101 |                 'rel_probs': rel_probs,
102 |                 'ent_probs': ent_probs
103 |             }
104 |         else:
105 |             # Distance loss
106 |             weight = answers * 9 + 1
107 |             loss_score = torch.mean(weight * torch.pow(last_e - answers, 2))
108 |             loss = {'loss_score': loss_score}
109 | 
110 |             if self.aux_hop:
111 |                 loss_hop = nn.CrossEntropyLoss()(hop_logit, hop-1)
112 |                 loss['loss_hop'] = 0.01 * loss_hop
113 | 
114 |             return loss
115 | 
--------------------------------------------------------------------------------
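Aside (not part of the repo): the cycle-blocking logic above leans on the vocabulary convention from preprocess.py that a relation and its inverse receive adjacent ids (even, then odd). A tiny check of that test with invented ids:

    import torch

    # toy ids: director = 0, director_inv = 1, genre = 2, genre_inv = 3
    prev_rel = torch.tensor([0, 2, 0])
    curr_rel = torch.tensor([1, 3, 2])
    is_inverse_pair = (torch.abs(prev_rel - curr_rel) == 1) & \
                      (torch.remainder(torch.min(prev_rel, curr_rel), 2) == 0)
    print(is_inverse_pair)  # tensor([ True,  True, False]) -> the first two steps walk straight back
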
/MetaQA-KB/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | import argparse
6 | from tqdm import tqdm
7 | from collections import defaultdict
8 | from utils.misc import MetricLogger, load_glove, idx_to_one_hot
9 | from .data import DataLoader
10 | from .model import TransferNet
11 | 
12 | from IPython import embed
13 | 
14 | 
15 | def validate(args, model, data, device, verbose = False):
16 |     vocab = data.vocab
17 |     model.eval()
18 |     count = defaultdict(int)
19 |     correct = defaultdict(int)
20 |     with torch.no_grad():
21 |         for batch in tqdm(data, total=len(data)):
22 |             questions, topic_entities, answers, hops = batch
23 |             topic_entities = idx_to_one_hot(topic_entities, len(vocab['entity2id']))
24 |             answers = idx_to_one_hot(answers, len(vocab['entity2id']))
25 |             answers[:, 0] = 0
26 |             questions = questions.to(device)
27 |             topic_entities = topic_entities.to(device)
28 |             hops = hops.tolist()
29 |             outputs = model(questions, topic_entities) # [bsz, Esize]
30 |             e_score = outputs['e_score'].cpu()
31 |             scores, idx = torch.max(e_score, dim = 1) # [bsz], [bsz]
32 |             match_score = torch.gather(answers, 1, idx.unsqueeze(-1)).squeeze(1).tolist() # keep dim 1 so a final batch of size 1 still yields a list
33 |             for h, m in zip(hops, match_score):
34 |                 count['all'] += 1
35 |                 count['{}-hop'.format(h)] += 1
36 |                 correct['all'] += m
37 |                 correct['{}-hop'.format(h)] += m
38 |             if verbose:
39 |                 for i in range(len(answers)):
40 |                     # if answers[i][idx[i]].item() == 0:
41 |                     if hops[i] != 3:
42 |                         continue
43 |                     print('================================================================')
44 |                     question = ' '.join([vocab['id2word'][_] for _ in questions.tolist()[i] if _ > 0])
45 |                     print(question)
46 |                     print('hop: {}'.format(hops[i]))
47 |                     print('> topic entity: {}'.format(vocab['id2entity'][topic_entities[i].max(0)[1].item()]))
48 |                     for t in range(args.num_steps):
49 |                         print('> > > step {}'.format(t))
50 |                         tmp = ' '.join(['{}: {:.3f}'.format(vocab['id2word'][x], y) for x,y in
51 |                             zip(questions.tolist()[i], outputs['word_attns'][t].tolist()[i])
52 |                             if x > 0])
53 |                         print('> ' + tmp)
54 |                         tmp = ' '.join(['{}: {:.3f}'.format(vocab['id2relation'][x], y) for x,y in
55 |                             enumerate(outputs['rel_probs'][t].tolist()[i])])
56 |                         print('> ' + tmp)
57 |                         print('> entity: {}'.format('; '.join([vocab['id2entity'][_] for _ in range(len(answers[i])) if outputs['ent_probs'][t][i][_].item() > 0.9])))
58 |                     print('----')
59 |                     print('> max is {}'.format(vocab['id2entity'][idx[i].item()]))
60 |                     print('> golden: {}'.format('; '.join([vocab['id2entity'][_] for _ in range(len(answers[i])) if answers[i][_].item() == 1])))
61 |                     print('> prediction: {}'.format('; '.join([vocab['id2entity'][_] for _ in range(len(answers[i])) if e_score[i][_].item() > 0.9])))
62 |                     embed()
63 |     acc = {k:correct[k]/count[k] for k in count}
64 |     result = ' | '.join(['%s:%.4f'%(key, value) for key, value in acc.items()])
65 |     print(result)
66 |     return acc
67 | 
68 | 
69 | def main():
70 |     parser = argparse.ArgumentParser()
71 |     # input and output
72 |     parser.add_argument('--input_dir', default = './input')
73 |     parser.add_argument('--ckpt', required = True)
74 |     parser.add_argument('--mode', default='val', choices=['val', 'vis', 'test'])
75 |     # model hyperparameters
76 |     parser.add_argument('--num_steps', default=3, type=int)
77 |     parser.add_argument('--dim_word', default=300, type=int)
78 |     parser.add_argument('--dim_hidden', default=1024, type=int)
79 |     parser.add_argument('--aux_hop', type=int, default=1, choices=[0, 1], help='utilize question hop to constrain the probability of self relation')
80 |     args = parser.parse_args()
81 | 
82 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
83 |     vocab_json = os.path.join(args.input_dir, 'vocab.json')
84 |     val_pt = os.path.join(args.input_dir, 'val.pt')
85 |     test_pt = os.path.join(args.input_dir, 'test.pt')
86 |     val_loader = DataLoader(vocab_json, val_pt, 64, training=True) # pass training as a keyword; a bare True would fill the ratio slot
87 |     test_loader = DataLoader(vocab_json, test_pt, 64)
88 |     vocab = val_loader.vocab
89 | 
90 |     model = TransferNet(args, args.dim_word, args.dim_hidden, vocab)
91 |     missing, unexpected = model.load_state_dict(torch.load(args.ckpt), strict=False)
92 |     if missing:
93 |         print("Missing keys: {}".format("; ".join(missing)))
94 |     if unexpected:
95 |         print("Unexpected keys: {}".format("; ".join(unexpected)))
96 |     model = model.to(device)
97 |     model.kg.Msubj = model.kg.Msubj.to(device)
98 |     model.kg.Mobj = model.kg.Mobj.to(device)
99 |     model.kg.Mrel = model.kg.Mrel.to(device)
100 | 
101 |     num_params = sum(np.prod(p.size()) for p in model.parameters())
102 |     print('number of parameters: {}'.format(num_params))
103 | 
104 |     if args.mode == 'vis':
105 |         validate(args, model, val_loader, device, True)
106 |     elif args.mode == 'val':
107 |         validate(args, model, 
val_loader, device, False) 108 | elif args.mode == 'test': 109 | validate(args, model, test_loader, device, False) 110 | 111 | if __name__ == '__main__': 112 | main() 113 | -------------------------------------------------------------------------------- /MetaQA-KB/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from nltk import word_tokenize 7 | import collections 8 | from collections import Counter 9 | from itertools import chain 10 | from tqdm import tqdm 11 | from utils.misc import * 12 | import re 13 | 14 | 15 | 16 | def encode_kb(args, vocab): 17 | with open(os.path.join(args.input_dir, 'kb/kb.txt')) as f: 18 | kb = f.readlines() 19 | 20 | Msubj = [] 21 | Mobj = [] 22 | Mrel = [] 23 | idx = 0 24 | for line in tqdm(kb): 25 | s, r, o = line.strip().split('|') 26 | r_inv = r + '_inv' 27 | add_item_to_x2id(s, vocab['entity2id']) 28 | add_item_to_x2id(o, vocab['entity2id']) 29 | add_item_to_x2id(r, vocab['relation2id']) 30 | add_item_to_x2id(r_inv, vocab['relation2id']) 31 | s_id, r_id, o_id, r_inv_id = vocab['entity2id'][s], vocab['relation2id'][r], vocab['entity2id'][o], vocab['relation2id'][r_inv] 32 | Msubj.append([idx, s_id]) 33 | Mobj.append([idx, o_id]) 34 | Mrel.append([idx, r_id]) 35 | idx += 1 36 | Msubj.append([idx, o_id]) 37 | Mobj.append([idx, s_id]) 38 | Mrel.append([idx, r_inv_id]) 39 | idx += 1 40 | 41 | # self relation 42 | # r = '' 43 | # add_item_to_x2id(r, vocab['relation2id']) 44 | # r_id = vocab['relation2id'][r] 45 | # for i in vocab['entity2id'].values(): 46 | # Msubj.append([idx, i]) 47 | # Mobj.append([idx, i]) 48 | # Mrel.append([idx, r_id]) 49 | # idx += 1 50 | 51 | 52 | Tsize = len(Msubj) 53 | Esize = len(vocab['entity2id']) 54 | Rsize = len(vocab['relation2id']) 55 | Msubj = np.array(Msubj, dtype = np.int32) 56 | Mobj = np.array(Mobj, dtype = np.int32) 57 | Mrel = np.array(Mrel, dtype = np.int32) 58 | assert len(Msubj) == Tsize 59 | assert len(Mobj) == Tsize 60 | assert len(Mrel) == Tsize 61 | np.save(os.path.join(args.output_dir, 'Msubj.npy'), Msubj) 62 | np.save(os.path.join(args.output_dir, 'Mobj.npy'), Mobj) 63 | np.save(os.path.join(args.output_dir, 'Mrel.npy'), Mrel) 64 | 65 | 66 | # Sanity check 67 | print('Sanity check: {} entities'.format(len(vocab['entity2id']))) 68 | print('Sanity check: {} relations'.format(len(vocab['relation2id']))) 69 | print('Sanity check: {} triples'.format(len(kb))) 70 | 71 | def encode_qa(args, vocab): 72 | pattern = re.compile(r'\[(.*?)\]') 73 | hops = ['%d-hop'%((int)(num)) for num in args.num_hop.split(',')] 74 | datasets = [] 75 | for dataset in ['train', 'test', 'dev']: 76 | data = [] 77 | for hop in hops: 78 | with open(os.path.join(args.input_dir, (hop + '/vanilla/qa_%s.txt'%(dataset)))) as f: 79 | qas = f.readlines() 80 | for qa in qas: 81 | question, answers = qa.strip().split('\t') 82 | topic_entity = re.search(pattern, question).group(1) 83 | if args.replace_es: 84 | question = re.sub(r'\[.*\]', 'E_S', question) 85 | else: 86 | question = question.replace('[', '').replace(']', '') 87 | answers = answers.split('|') 88 | assert topic_entity in vocab['entity2id'] 89 | for answer in answers: 90 | assert answer in vocab['entity2id'] 91 | data.append({'question':question, 'topic_entity':topic_entity, 'answers':answers, 'hop':int(hop[0])}) 92 | datasets.append(data) 93 | json.dump(data, open(os.path.join(args.output_dir, '%s.json'%(dataset)), 'w')) 94 | 95 | train_set, test_set, 
val_set = datasets[0], datasets[1], datasets[2]
96 |     print('size of training data: {}'.format(len(train_set)))
97 |     print('size of test data: {}'.format(len(test_set)))
98 |     print('size of valid data: {}'.format(len(val_set)))
99 |     print('Build question vocabulary')
100 |     word_counter = Counter()
101 |     for qa in tqdm(train_set):
102 |         tokens = word_tokenize(qa['question'].lower())
103 |         word_counter.update(tokens)
104 |     stopwords = set()
105 |     for w, c in word_counter.items():
106 |         if w and c >= args.min_cnt:
107 |             add_item_to_x2id(w, vocab['word2id'])
108 |         if w and c >= args.stop_thresh:
109 |             stopwords.add(w)
110 |     print('number of stop words (>={}): {}'.format(args.stop_thresh, len(stopwords)))
111 |     print('number of word in dict: {}'.format(len(vocab['word2id'])))
112 |     with open(os.path.join(args.output_dir, 'vocab.json'), 'w') as f:
113 |         json.dump(vocab, f, indent=2)
114 | 
115 |     for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)):
116 |         print('Encode {} set'.format(name))
117 |         outputs = encode_dataset(vocab, dataset)
118 |         print('shape of questions, topic_entities, answers, hops:')
119 |         with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f:
120 |             for o in outputs:
121 |                 print(o.shape)
122 |                 pickle.dump(o, f)
123 | 
124 | def encode_dataset(vocab, dataset):
125 |     questions = []
126 |     topic_entities = []
127 |     answers = []
128 |     hops = []
129 |     for qa in tqdm(dataset):
130 |         assert len(qa['topic_entity']) > 0
131 |         questions.append([vocab['word2id'].get(w, vocab['word2id']['<UNK>']) for w in word_tokenize(qa['question'].lower())])
132 |         topic_entities.append([vocab['entity2id'][qa['topic_entity']]])
133 |         answers.append([vocab['entity2id'][answer] for answer in qa['answers']])
134 |         hops.append(qa['hop'])
135 | 
136 |     # question padding
137 |     max_len = max(len(q) for q in questions)
138 |     print('max question length:{}'.format(max_len))
139 |     for q in questions:
140 |         while len(q) < max_len:
141 |             q.append(vocab['word2id']['<PAD>'])
142 |     questions = np.asarray(questions, dtype=np.int32)
143 |     topic_entities = np.asarray(topic_entities, dtype=np.int32)
144 |     max_len = max(len(a) for a in answers)
145 |     print('max answer length:{}'.format(max_len))
146 |     for a in answers:
147 |         while len(a) < max_len:
148 |             a.append(DUMMY_ENTITY_ID)
149 |     answers = np.asarray(answers, dtype=np.int32)
150 |     hops = np.asarray(hops, dtype=np.int8)
151 |     return questions, topic_entities, answers, hops
152 | 
153 | def main():
154 |     parser = argparse.ArgumentParser()
155 |     parser.add_argument('--input_dir', default = '/data/csl/resources/KBQA_datasets/MetaQA', type = str)
156 |     parser.add_argument('--output_dir', default = '/data/csl/exp/TransferNet/input', type = str)
157 |     parser.add_argument('--min_cnt', type=int, default=1)
158 |     parser.add_argument('--stop_thresh', type=int, default=1000)
159 |     parser.add_argument('--num_hop', type = str, default = '1, 2, 3')
160 |     parser.add_argument('--replace_es', type = int, default = 1)
161 |     args = parser.parse_args()
162 |     print(args)
163 |     if not os.path.isdir(args.output_dir):
164 |         os.makedirs(args.output_dir)
165 | 
166 |     print('Init vocabulary')
167 |     vocab = {
168 |         'word2id': init_word2id(),
169 |         'entity2id': init_entity2id(),
170 |         'relation2id': {},
171 |         'topic_entity': {}
172 |     }
173 | 
174 |     print('Encode kb')
175 |     encode_kb(args, vocab)
176 | 
177 |     print('Encode qa')
178 |     encode_qa(args, vocab)
179 | 
180 | if __name__ == '__main__':
181 |     main()
182 | 
--------------------------------------------------------------------------------
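Aside (not part of the repo): the train/predict scripts below report hits@1 by taking the argmax entity per question and gathering its gold label. A self-contained illustration with made-up scores:

    import torch

    e_score = torch.tensor([[0.1, 0.9, 0.0],   # argmax -> entity 1
                            [0.8, 0.1, 0.1]])  # argmax -> entity 0
    answers = torch.tensor([[0., 1., 0.],      # gold: entity 1 (hit)
                            [0., 0., 1.]])     # gold: entity 2 (miss)
    scores, idx = torch.max(e_score, dim=1)
    match = torch.gather(answers, 1, idx.unsqueeze(-1)).squeeze(1)
    print(match, match.mean().item())          # tensor([1., 0.]) 0.5
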
/MetaQA-KB/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.optim as optim
4 | import torch.nn as nn
5 | import numpy as np
6 | import argparse
7 | import shutil
8 | from tqdm import tqdm
9 | import time
10 | from utils.misc import MetricLogger, load_glove, idx_to_one_hot, RAdam
11 | from .data import DataLoader
12 | from .model import TransferNet
13 | from .predict import validate
14 | import logging
15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
16 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
17 | rootLogger = logging.getLogger()
18 | 
19 | torch.set_num_threads(1) # avoid using multiple cpus
20 | 
21 | 
22 | def train(args):
23 |     device = 'cuda' if torch.cuda.is_available() else 'cpu'
24 | 
25 |     logging.info("Create train_loader, val_loader and test_loader.........")
26 |     vocab_json = os.path.join(args.input_dir, 'vocab.json')
27 |     train_pt = os.path.join(args.input_dir, 'train.pt')
28 |     val_pt = os.path.join(args.input_dir, 'val.pt')
29 |     test_pt = os.path.join(args.input_dir, 'test.pt')
30 |     train_loader = DataLoader(vocab_json, train_pt, args.batch_size, args.ratio, training=True)
31 |     val_loader = DataLoader(vocab_json, val_pt, args.batch_size)
32 |     test_loader = DataLoader(vocab_json, test_pt, args.batch_size)
33 |     vocab = train_loader.vocab
34 | 
35 |     logging.info("Create model.........")
36 |     pretrained = load_glove(args.glove_pt, vocab['id2word'])
37 |     model = TransferNet(args, args.dim_word, args.dim_hidden, vocab)
38 |     model.word_embeddings.weight.data = torch.Tensor(pretrained)
39 |     if args.ckpt is not None:
40 |         missing, unexpected = model.load_state_dict(torch.load(args.ckpt), strict=False)
41 |         if missing:
42 |             logging.info("Missing keys: {}".format("; ".join(missing)))
43 |         if unexpected:
44 |             logging.info("Unexpected keys: {}".format("; ".join(unexpected)))
45 |     model = model.to(device)
46 |     model.kg.Msubj = model.kg.Msubj.to(device)
47 |     model.kg.Mobj = model.kg.Mobj.to(device)
48 |     model.kg.Mrel = model.kg.Mrel.to(device)
49 | 
50 |     logging.info(model)
51 |     if args.opt == 'adam':
52 |         optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
53 |     elif args.opt == 'radam':
54 |         optimizer = RAdam(model.parameters(), args.lr, weight_decay=args.weight_decay)
55 |     elif args.opt == 'sgd':
56 |         optimizer = optim.SGD(model.parameters(), args.lr, weight_decay=args.weight_decay)
57 |     elif args.opt == 'adagrad':
58 |         optimizer = optim.Adagrad(model.parameters(), args.lr, weight_decay=args.weight_decay)
59 |     else:
60 |         raise NotImplementedError
61 |     scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[3], gamma=0.1)
62 | 
63 |     meters = MetricLogger(delimiter=" ")
64 |     # validate(args, model, val_loader, device)
65 |     logging.info("Start training........")
66 | 
67 |     for epoch in range(args.num_epoch):
68 |         model.train()
69 |         for iteration, batch in enumerate(train_loader):
70 |             iteration = iteration + 1
71 | 
72 |             question, topic_entity, answer, hop = batch
73 |             question = question.to(device)
74 |             topic_entity = idx_to_one_hot(topic_entity, len(vocab['entity2id'])).to(device)
75 |             answer = idx_to_one_hot(answer, len(vocab['entity2id'])).to(device)
76 |             answer[:, 0] = 0
77 |             hop = hop.to(device)
78 |             loss = model(question, topic_entity, answer, hop)
79 |             optimizer.zero_grad()
80 |             if 
isinstance(loss, dict): 81 | total_loss = sum(loss.values()) 82 | meters.update(**{k:v.item() for k,v in loss.items()}) 83 | else: 84 | total_loss = loss 85 | meters.update(loss=loss.item()) 86 | total_loss.backward() 87 | nn.utils.clip_grad_value_(model.parameters(), 0.5) 88 | nn.utils.clip_grad_norm_(model.parameters(), 2) 89 | optimizer.step() 90 | 91 | if iteration % (len(train_loader) // 100) == 0: 92 | logging.info( 93 | meters.delimiter.join( 94 | [ 95 | "progress: {progress:.3f}", 96 | "{meters}", 97 | "lr: {lr:.6f}", 98 | ] 99 | ).format( 100 | progress=epoch + iteration / len(train_loader), 101 | meters=str(meters), 102 | lr=optimizer.param_groups[0]["lr"], 103 | ) 104 | ) 105 | 106 | if (epoch + 1) % int(1 / args.ratio) == 0: 107 | acc = validate(args, model, val_loader, device) 108 | logging.info(acc) 109 | scheduler.step() 110 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model_epoch-{}_acc-{:.4f}.pt'.format(epoch, acc['all']))) 111 | 112 | 113 | def main(): 114 | parser = argparse.ArgumentParser() 115 | # input and output 116 | parser.add_argument('--input_dir', required=True) 117 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 118 | parser.add_argument('--glove_pt', required=True) 119 | parser.add_argument('--ckpt', default = None) 120 | # training parameters 121 | parser.add_argument('--lr', default=0.001, type=float) 122 | parser.add_argument('--weight_decay', default=1e-5, type=float) 123 | parser.add_argument('--num_epoch', default=10, type=int) 124 | parser.add_argument('--batch_size', default=128, type=int) 125 | parser.add_argument('--seed', type=int, default=666, help='random seed') 126 | parser.add_argument('--opt', default='radam', type = str) 127 | parser.add_argument('--ratio', default=1.0, type=float) 128 | # model hyperparameters 129 | parser.add_argument('--num_steps', default=3, type=int) 130 | parser.add_argument('--dim_word', default=300, type=int) 131 | parser.add_argument('--dim_hidden', default=1024, type=int) 132 | parser.add_argument('--aux_hop', type=int, default=1, choices=[0, 1], help='utilize question hop to constrain the probability of self relation') 133 | args = parser.parse_args() 134 | 135 | # make logging.info display into both shell and file 136 | if not os.path.exists(args.save_dir): 137 | os.makedirs(args.save_dir) 138 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 139 | args.log_name = time_ + '_{}_{}_{}.log'.format(args.opt, args.lr, args.batch_size) 140 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, args.log_name)) 141 | fileHandler.setFormatter(logFormatter) 142 | rootLogger.addHandler(fileHandler) 143 | # args display 144 | for k, v in vars(args).items(): 145 | logging.info(k+':'+str(v)) 146 | 147 | if args.ratio < 1: 148 | args.num_epoch = int(args.num_epoch / args.ratio) 149 | logging.info('Due to partial training examples, the actual num_epoch is set to {}'.format(args.num_epoch)) 150 | 151 | torch.backends.cudnn.deterministic = True 152 | torch.backends.cudnn.benchmark = False 153 | # set random seed 154 | torch.manual_seed(args.seed) 155 | np.random.seed(args.seed) 156 | 157 | train(args) 158 | 159 | 160 | if __name__ == '__main__': 161 | main() 162 | -------------------------------------------------------------------------------- /MetaQA-Text/data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import pickle 3 | import torch 4 | import numpy as np 5 | from utils.misc import 
invert_dict 6 | 7 | 8 | def load_vocab(path): 9 | vocab = json.load(open(path)) 10 | vocab['id2word'] = invert_dict(vocab['word2id']) 11 | vocab['id2entity'] = invert_dict(vocab['entity2id']) 12 | return vocab 13 | 14 | def collate(batch): 15 | batch = list(zip(*batch)) 16 | question, topic_entity, answer = list(map(torch.stack, batch[:3])) 17 | hop = torch.LongTensor(batch[3]) 18 | return question, topic_entity, answer, hop 19 | 20 | 21 | class Dataset(torch.utils.data.Dataset): 22 | def __init__(self, inputs): 23 | self.questions, self.topic_entities, self.answers, self.hops = inputs 24 | 25 | def __getitem__(self, index): 26 | question = torch.LongTensor(self.questions[index]) 27 | topic_entity = torch.LongTensor(self.topic_entities[index]) 28 | answer = torch.LongTensor(self.answers[index]) 29 | hop = self.hops[index] 30 | return question, topic_entity, answer, hop 31 | 32 | 33 | def __len__(self): 34 | return len(self.questions) 35 | 36 | 37 | class DataLoader(torch.utils.data.DataLoader): 38 | def __init__(self, vocab_json, question_pt, batch_size, limit_hop=-1, training=False, curriculum=False): 39 | vocab = load_vocab(vocab_json) 40 | 41 | inputs = [] 42 | with open(question_pt, 'rb') as f: 43 | for _ in range(4): 44 | inputs.append(pickle.load(f)) 45 | 46 | if limit_hop > 0: 47 | print('only keep questions of hop {}'.format(limit_hop)) 48 | mask = inputs[3] == limit_hop 49 | inputs = [i[mask] for i in inputs] 50 | curriculum = False 51 | 52 | if curriculum: 53 | print('curriculum') 54 | hops = inputs[3] 55 | idxs = [] 56 | for h in [1, 2, 3]: 57 | idx = np.nonzero(hops==h)[0] 58 | np.random.shuffle(idx) 59 | idxs.append(idx) 60 | idxs = np.concatenate(idxs) 61 | inputs = [i[idxs] for i in inputs] 62 | 63 | print('data number: {}'.format(len(inputs[0]))) 64 | 65 | dataset = Dataset(inputs) 66 | 67 | shuffle = training 68 | if curriculum: 69 | shuffle = False 70 | super().__init__( 71 | dataset, 72 | batch_size=batch_size, 73 | shuffle=shuffle, 74 | collate_fn=collate, 75 | ) 76 | self.vocab = vocab 77 | -------------------------------------------------------------------------------- /MetaQA-Text/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import pickle 5 | import math 6 | import random 7 | 8 | from utils.BiGRU import GRU, BiGRU 9 | 10 | class TransferNet(nn.Module): 11 | def __init__(self, args, vocab): 12 | super().__init__() 13 | self.args = args 14 | self.vocab = vocab 15 | self.max_active = args.max_active 16 | self.ent_act_thres = args.ent_act_thres 17 | self.aux_hop = args.aux_hop 18 | dim_word = args.dim_word 19 | dim_hidden = args.dim_hidden 20 | 21 | with open(os.path.join(args.input_dir, 'wiki.pt'), 'rb') as f: 22 | self.kb_pair = torch.LongTensor(pickle.load(f)) 23 | self.kb_range = torch.LongTensor(pickle.load(f)) 24 | self.kb_desc = torch.LongTensor(pickle.load(f)) 25 | 26 | print('number of triples: {}'.format(len(self.kb_pair))) 27 | 28 | num_words = len(vocab['word2id']) 29 | num_entities = len(vocab['entity2id']) 30 | self.num_steps = args.num_steps 31 | 32 | self.desc_encoder = BiGRU(dim_word, dim_hidden, num_layers=1, dropout=0.2) 33 | self.question_encoder = BiGRU(dim_word, dim_hidden, num_layers=1, dropout=0.2) 34 | 35 | self.word_embeddings = nn.Embedding(num_words, dim_word) 36 | self.word_dropout = nn.Dropout(0.2) 37 | self.step_encoders = [] 38 | for i in range(self.num_steps): 39 | m = nn.Sequential( 40 | nn.Linear(dim_hidden, dim_hidden), 41 | 
nn.Tanh(), 42 | ) 43 | self.step_encoders.append(m) 44 | self.add_module('step_encoders_{}'.format(i), m) 45 | self.rel_classifier = nn.Linear(dim_hidden, 1) 46 | 47 | self.q_classifier = nn.Linear(dim_hidden, num_entities) 48 | self.hop_selector = nn.Linear(dim_hidden, self.num_steps) 49 | 50 | 51 | def follow(self, e, pair, p): 52 | """ 53 | Args: 54 | e [num_ent]: entity scores 55 | pair [rsz, 2]: pairs that are taken into consider 56 | p [rsz]: transfer probabilities of each pair 57 | """ 58 | sub, obj = pair[:, 0], pair[:, 1] 59 | obj_p = e[sub] * p 60 | out = torch.index_add(torch.zeros_like(e), 0, obj, obj_p) 61 | return out 62 | 63 | 64 | def forward(self, questions, e_s, answers=None, hop=None): 65 | question_lens = questions.size(1) - questions.eq(0).long().sum(dim=1) # 0 means 66 | q_word_emb = self.word_dropout(self.word_embeddings(questions)) # [bsz, max_q, dim_hidden] 67 | q_word_h, q_embeddings, q_hn = self.question_encoder(q_word_emb, question_lens) # [bsz, max_q, dim_h], [bsz, dim_h], [num_layers, bsz, dim_h] 68 | 69 | 70 | device = q_word_h.device 71 | bsz, dim_h = q_embeddings.size() 72 | last_e = e_s 73 | word_attns = [] 74 | ent_probs = [] 75 | 76 | path_infos = [] # [bsz, num_steps] 77 | for i in range(bsz): 78 | path_infos.append([]) 79 | for j in range(self.num_steps): 80 | path_infos[i].append(None) 81 | 82 | for t in range(self.num_steps): 83 | cq_t = self.step_encoders[t](q_embeddings) # [bsz, dim_h] 84 | q_logits = torch.sum(cq_t.unsqueeze(1) * q_word_h, dim=2) # [bsz, max_q] 85 | q_dist = torch.softmax(q_logits, 1).unsqueeze(1) # [bsz, 1, max_q] 86 | q_dist = q_dist * questions.ne(0).float().unsqueeze(1) 87 | q_dist = q_dist / (torch.sum(q_dist, dim=2, keepdim=True) + 1e-6) # [bsz, 1, max_q] 88 | word_attns.append(q_dist.squeeze(1)) 89 | ctx_h = (q_dist @ q_word_h).squeeze(1) # [bsz, dim_h] 90 | ctx_h = ctx_h + cq_t 91 | 92 | e_stack = [] 93 | cnt_trunc = 0 94 | for i in range(bsz): 95 | # e_idx = torch.topk(last_e[i], k=1, dim=0)[1].tolist() + \ 96 | # last_e[i].gt(self.ent_act_thres).nonzero().squeeze(1).tolist() 97 | # TRY 98 | # if self.training and t > 0 and random.random() < 0.005: 99 | # e_idx = last_e[i].gt(0).nonzero().squeeze(1).tolist() 100 | # random.shuffle(e_idx) 101 | # else: 102 | sort_score, sort_idx = torch.sort(last_e[i], dim=0, descending=True) 103 | e_idx = sort_idx[sort_score.gt(self.ent_act_thres)].tolist() 104 | e_idx = set(e_idx) - set([0]) 105 | if len(e_idx) == 0: 106 | # print('no active entity at step {}'.format(t)) 107 | pad = sort_idx[0].item() 108 | if pad == 0: 109 | pad = sort_idx[1].item() 110 | e_idx = set([pad]) 111 | 112 | rg = [] 113 | for j in e_idx: 114 | rg.append(torch.arange(self.kb_range[j,0], self.kb_range[j,1]).long().to(device)) 115 | rg = torch.cat(rg, dim=0) # [rsz,] 116 | # print(len(e_idx), len(rg)) 117 | if len(rg) > self.max_active: # limit the number of next-hop 118 | rg = rg[:self.max_active] 119 | # TRY 120 | # rg = rg[torch.randperm(len(rg))[:self.max_active]] 121 | cnt_trunc += 1 122 | # print('trunc: {}'.format(cnt_trunc)) 123 | 124 | # print('step {}, desc number {}'.format(t, len(rg))) 125 | pair = self.kb_pair[rg] # [rsz, 2] 126 | desc = self.kb_desc[rg] # [rsz, max_desc] 127 | desc_lens = desc.size(1) - desc.eq(0).long().sum(dim=1) 128 | desc_word_emb = self.word_dropout(self.word_embeddings(desc)) 129 | desc_word_h, desc_embeddings, _ = self.desc_encoder(desc_word_emb, desc_lens) # [rsz, dim_h] 130 | d_logit = self.rel_classifier(ctx_h[i:i+1] * desc_embeddings).squeeze(1) # [rsz,] 131 | d_prob = 
torch.sigmoid(d_logit) # [rsz,] 132 | # transfer probability 133 | e_stack.append(self.follow(last_e[i], pair, d_prob)) 134 | 135 | # collect path 136 | act_idx = d_prob.gt(0.9) 137 | act_pair = pair[act_idx].tolist() 138 | act_desc = [' '.join([self.vocab['id2word'][w] for w in d if w > 0]) for d in desc[act_idx].tolist()] 139 | path_infos[i][t] = [(act_pair[_][0], act_desc[_], act_pair[_][1]) for _ in range(len(act_pair))] 140 | 141 | last_e = torch.stack(e_stack, dim=0) 142 | 143 | # reshape >1 scores to 1 in a differentiable way 144 | m = last_e.gt(1).float() 145 | z = (m * last_e + (1-m)).detach() 146 | last_e = last_e / z 147 | 148 | # Specifically for MetaQA: reshape cycle entities to 0, because A-r->B-r_inv->A is not allowed 149 | if t > 0: 150 | ent_m = torch.zeros_like(last_e) 151 | for i in range(bsz): 152 | prev_inv = set() 153 | for (s, r, o) in path_infos[i][t-1]: 154 | prev_inv.add((o, r.replace('__subject__', 'obj').replace('__object__', 'sub'), s)) 155 | for (s, r, o) in path_infos[i][t]: 156 | element = (s, r.replace('__subject__', 'sub').replace('__object__', 'obj'), o) 157 | if r != '__self_rel__' and element in prev_inv: 158 | ent_m[i, o] = 1 159 | # print('block cycle: {}'.format(' ---> '.join(list(map(str, element))))) 160 | last_e = (1-ent_m) * last_e 161 | 162 | ent_probs.append(last_e) 163 | 164 | hop_res = torch.stack(ent_probs, dim=1) # [bsz, num_hop, num_ent] 165 | hop_logit = self.hop_selector(q_embeddings) 166 | hop_attn = torch.softmax(hop_logit, dim=1) # [bsz, num_hop] 167 | last_e = torch.sum(hop_res * hop_attn.unsqueeze(2), dim=1) # [bsz, num_ent] 168 | 169 | # Specifically for MetaQA: for 2-hop questions, topic entity is excluded from answer 170 | m = hop_attn.argmax(dim=1).eq(1).float().unsqueeze(1) * e_s 171 | last_e = (1-m) * last_e 172 | 173 | # question mask, incorporate language bias 174 | q_mask = torch.sigmoid(self.q_classifier(q_embeddings)) 175 | last_e = last_e * q_mask 176 | 177 | if not self.training: 178 | return { 179 | 'e_score': last_e, 180 | 'word_attns': word_attns, 181 | 'ent_probs': ent_probs, 182 | 'path_infos': path_infos 183 | } 184 | else: 185 | weight = answers * 9 + 1 186 | loss_score = torch.mean(weight * torch.pow(last_e - answers, 2)) 187 | 188 | loss = {'loss_score': loss_score} 189 | 190 | if self.aux_hop: 191 | loss_hop = nn.CrossEntropyLoss()(hop_logit, hop-1) 192 | loss['loss_hop'] = 0.01 * loss_hop 193 | 194 | return loss 195 | -------------------------------------------------------------------------------- /MetaQA-Text/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | import argparse 6 | from tqdm import tqdm 7 | from collections import defaultdict 8 | from utils.misc import MetricLogger, load_glove, idx_to_one_hot 9 | from .data import DataLoader 10 | from .model import TransferNet 11 | 12 | from IPython import embed 13 | 14 | 15 | def validate(args, model, data, device, verbose = False): 16 | vocab = data.vocab 17 | model.eval() 18 | count = defaultdict(int) 19 | correct = defaultdict(int) 20 | with torch.no_grad(): 21 | for batch in tqdm(data, total=len(data)): 22 | questions, topic_entities, answers, hops = batch 23 | topic_entities = idx_to_one_hot(topic_entities, len(vocab['entity2id'])) 24 | answers = idx_to_one_hot(answers, len(vocab['entity2id'])) 25 | answers[:, 0] = 0 26 | questions = questions.to(device) 27 | topic_entities = topic_entities.to(device) 28 | hops = hops.tolist() 29 | 
outputs = model(questions, topic_entities) # [bsz, Esize] 30 | e_score = outputs['e_score'].cpu() 31 | scores, idx = torch.max(e_score, dim = 1) # [bsz], [bsz] 32 | match_score = torch.gather(answers, 1, idx.unsqueeze(-1)).squeeze(1).tolist() 33 | for h, m in zip(hops, match_score): 34 | count['all'] += 1 35 | count['{}-hop'.format(h)] += 1 36 | correct['all'] += m 37 | correct['{}-hop'.format(h)] += m 38 | if verbose: 39 | for i in range(len(answers)): 40 | # if answers[i][idx[i]].item() == 0: 41 | if hops[i] != 3: 42 | continue 43 | print('================================================================') 44 | question = ' '.join([vocab['id2word'][_] for _ in questions.tolist()[i] if _ > 0]) 45 | print(question) 46 | print('hop: {}'.format(hops[i])) 47 | print('> topic entity: {}'.format(vocab['id2entity'][topic_entities[i].max(0)[1].item()])) 48 | 49 | for t in range(args.num_steps): 50 | print('>>>>>>>>>> step {} <<<<<<<<<<'.format(t)) 51 | tmp = ' '.join(['{}: {:.3f}'.format(vocab['id2word'][x], y) for x,y in 52 | zip(questions.tolist()[i], outputs['word_attns'][t].tolist()[i]) 53 | if x >= 0]) 54 | print('> ' + tmp) 55 | print('--- transfer path ---') 56 | for (ps, rd, pt) in outputs['path_infos'][i][t]: 57 | print('{} ---> {} ---> {}'.format( 58 | vocab['id2entity'][ps], rd, vocab['id2entity'][pt] 59 | )) 60 | print('> entity: {}'.format('; '.join([vocab['id2entity'][_] for _ in range(len(answers[i])) if outputs['ent_probs'][t][i][_].item() > 0.9]))) 61 | print('-----------') 62 | print('> max is {}'.format(vocab['id2entity'][idx[i].item()])) 63 | print('> golden: {}'.format('; '.join([vocab['id2entity'][_] for _ in range(len(answers[i])) if answers[i][_].item() == 1]))) 64 | print('> prediction: {}'.format('; '.join([vocab['id2entity'][_] for _ in range(len(answers[i])) if e_score[i][_].item() > 0.9]))) 65 | embed() 66 | acc = {k:(correct[k]/count[k] if count[k]>0 else -1) for k in count} 67 | result = ' | '.join(['%s:%.4f'%(key, value) for key, value in acc.items()]) 68 | print(result) 69 | return acc 70 | 71 | 72 | def main(): 73 | parser = argparse.ArgumentParser() 74 | # input and output 75 | parser.add_argument('--input_dir', default = './input') 76 | parser.add_argument('--ckpt', required = True) 77 | parser.add_argument('--mode', default='val', choices=['val', 'vis', 'test']) 78 | # model hyperparameters 79 | parser.add_argument('--aux_hop', type=int, default=1, choices=[0, 1], help='utilize question hop to constrain the probability of self relation') 80 | parser.add_argument('--num_steps', default=3, type=int) 81 | parser.add_argument('--dim_word', default=300, type=int) 82 | parser.add_argument('--dim_hidden', default=768, type=int) 83 | parser.add_argument('--ent_act_thres', default=0.7, type=float, help='activate an entity when its score exceeds this value') 84 | parser.add_argument('--max_active', default=400, type=int, help='max number of active entities at each step') 85 | parser.add_argument('--limit_hop', default=-1, type=int) 86 | args = parser.parse_args() 87 | 88 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 89 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 90 | val_pt = os.path.join(args.input_dir, 'val.pt') 91 | test_pt = os.path.join(args.input_dir, 'test.pt') 92 | val_loader = DataLoader(vocab_json, val_pt, 64, args.limit_hop, True) 93 | test_loader = DataLoader(vocab_json, test_pt, 64, args.limit_hop) 94 | vocab = val_loader.vocab 95 | 96 | model = TransferNet(args, vocab) 97 | model.load_state_dict(torch.load(args.ckpt)) 98 | model = 
model.to(device) 99 | model.kb_pair = model.kb_pair.to(device) 100 | model.kb_range = model.kb_range.to(device) 101 | model.kb_desc = model.kb_desc.to(device) 102 | 103 | num_params = sum(np.prod(p.size()) for p in model.parameters()) 104 | print('number of parameters: {}'.format(num_params)) 105 | 106 | if args.mode == 'vis': 107 | validate(args, model, val_loader, device, True) 108 | elif args.mode == 'val': 109 | validate(args, model, val_loader, device, False) 110 | elif args.mode == 'test': 111 | validate(args, model, test_loader, device, False) 112 | 113 | if __name__ == '__main__': 114 | main() 115 | -------------------------------------------------------------------------------- /MetaQA-Text/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | from nltk import word_tokenize 7 | import collections 8 | from collections import Counter, defaultdict 9 | from itertools import chain 10 | from tqdm import tqdm 11 | from utils.misc import * 12 | import re 13 | import random 14 | 15 | 16 | SUB_PH = '__subject__' 17 | OBJ_PH = '__object__' 18 | ENT_PH = '__entity__' 19 | SELF_PH = '__self_rel__' 20 | 21 | 22 | def encode_kb(args, vocab): 23 | def tokenPat(s): 24 | # avoid that s is a substring of another token 25 | return r'(^|(?<=\W))' + s + r'((?=\W)|$)' 26 | 27 | kb = defaultdict(list) 28 | for line in open(os.path.join(args.input_dir, 'kb/kb.txt')): 29 | s, r, o = line.strip().split('|') 30 | kb[s].append((r, o)) 31 | 32 | # read from wiki 33 | triples = [] 34 | cache = [] 35 | for line in tqdm(chain(open(os.path.join(args.input_dir, 'kb/wiki.txt')), ['\n'])): 36 | line = line.strip() 37 | if line == '': 38 | if len(cache) == 0: 39 | continue 40 | subject = re.sub(r'\(.*\)', '', cache[0]).strip() 41 | for line in cache[1:]: 42 | # Note: force the match is in lower case, but keep subject and object in original case 43 | line = line.lower() 44 | # first replace subject with placeholder 45 | line = re.sub(tokenPat(subject.lower()), SUB_PH, line) 46 | used_objs = set([o for _, o in kb[subject] if re.search(tokenPat(o.lower()), line)]) 47 | for obj in used_objs: 48 | desc = line 49 | # Note: some objects share the same name, so we must guarantee the OBJ_PH is placed before ENT_PH 50 | desc = re.sub(tokenPat(obj.lower()), OBJ_PH, desc) 51 | for other in used_objs-{obj}: 52 | desc = re.sub(tokenPat(other.lower()), ENT_PH, desc) 53 | # Note: operations to desc must be after placeholder 54 | desc = desc.replace('/', ' / ').replace('–', ' - ').replace('-', ' - ') 55 | tokens = word_tokenize(desc) 56 | if OBJ_PH not in tokens: 57 | # print() 58 | # print(line) 59 | # print(desc) 60 | # print(obj) 61 | # print(tokens) 62 | # from IPython import embed; embed() 63 | if len(obj) > 3: 64 | tokens = word_tokenize(' '.join(tokens).replace(OBJ_PH, ' '+OBJ_PH+' ')) 65 | else: 66 | continue 67 | 68 | # filter out useless tokens 69 | tokens = list(filter(lambda x: x not in {',', '.', 'and', ENT_PH}, tokens)) 70 | # truncate to max_desc 71 | c = tokens.index(OBJ_PH) 72 | if len(tokens) > args.max_desc: 73 | tokens = tokens[max(c-args.max_desc//2, 0): c+args.max_desc//2] 74 | triples.append((subject, obj, tokens)) 75 | 76 | backward_tokens = [] 77 | for t in tokens: 78 | if t == SUB_PH: 79 | backward_tokens.append(OBJ_PH) 80 | elif t == OBJ_PH: 81 | backward_tokens.append(SUB_PH) 82 | else: 83 | backward_tokens.append(t) 84 | triples.append((obj, subject, backward_tokens)) 85 
| cache = [] 86 | else: 87 | line = ' '.join(line.split()[1:]) 88 | cache.append(line) 89 | 90 | # add structured knowledge based on required ratio 91 | if args.kb_ratio > 0: 92 | assert args.kb_ratio <= 1 93 | cnt = 0 94 | for s in kb: 95 | for (r, o) in kb[s]: 96 | if random.random() < args.kb_ratio: 97 | triples.append((s, o, [r])) 98 | cnt += 2 99 | triples.append((o, s, [r+'_inv'])) 100 | print('add {} ({}%) structured triples'.format(cnt, args.kb_ratio*100)) 101 | 102 | # add self relation 103 | if args.add_self == 1: 104 | print('add self relations') 105 | entities = set() 106 | for sub, obj, desc in triples: 107 | entities.add(sub) 108 | for e in entities: 109 | triples.append((e, e, [SELF_PH])) 110 | else: 111 | print('do NOT add self relations') 112 | 113 | for tri in triples[:100]: 114 | print(tri) 115 | print('===') 116 | print('number of triples: {}'.format(len(triples))) 117 | 118 | 119 | triples = sorted(triples) 120 | # for tri in triples[:50]: 121 | # print(tri) 122 | 123 | # build vocabulary 124 | word_counter = Counter() 125 | for sub, obj, desc in triples: 126 | add_item_to_x2id(sub, vocab['entity2id']) 127 | add_item_to_x2id(obj, vocab['entity2id']) 128 | word_counter.update(desc) 129 | cnt = 0 130 | for w, c in word_counter.items(): 131 | if w and c >= args.min_cnt: 132 | add_item_to_x2id(w, vocab['word2id']) 133 | else: 134 | cnt += 1 135 | print('remove {} words whose frequency < {}'.format(cnt, args.min_cnt)) 136 | print('vocabulary size: {} entities, {} words'.format(len(vocab['entity2id']), len(vocab['word2id']))) 137 | 138 | # [start, end) of each entity 139 | knowledge_range = np.full((len(vocab['entity2id']), 2), -1) 140 | start = 0 141 | for i in range(len(triples)): 142 | if i > 0 and triples[i][0] != triples[i-1][0]: 143 | idx = vocab['entity2id'][triples[i-1][0]] 144 | knowledge_range[idx] = (start, i) 145 | start = i 146 | idx = vocab['entity2id'][triples[-1][0]] 147 | knowledge_range[idx] = (start, len(triples)) 148 | 149 | # Encode 150 | so_pair = [[vocab['entity2id'][s], vocab['entity2id'][o]] for s,o,_ in triples] 151 | descs = [[vocab['word2id'].get(w, vocab['word2id']['<UNK>']) for w in d] for _,_,d in triples] 152 | for d in descs: 153 | while len(d) < args.max_desc: 154 | d.append(vocab['word2id']['<PAD>']) 155 | 156 | so_pair = np.asarray(so_pair, dtype=np.int64) 157 | knowledge_range = np.asarray(knowledge_range, dtype=np.int64) 158 | descs = np.asarray(descs, dtype=np.int64) 159 | print(so_pair.shape, knowledge_range.shape, descs.shape) 160 | 161 | with open(os.path.join(args.output_dir, 'wiki.pt'), 'wb') as f: 162 | pickle.dump(so_pair, f) 163 | pickle.dump(knowledge_range, f) 164 | pickle.dump(descs, f) 165 | 166 | print('finish wiki process\n=====') 167 | 168 | 169 | def encode_qa(args, vocab): 170 | pattern = re.compile(r'\[(.*)\]') 171 | hops = ['%d-hop'%(int(num)) for num in args.num_hop.split(',')] 172 | drop_cnt = 0 173 | datasets = [] 174 | for dataset in ['train', 'test', 'dev']: 175 | data = [] 176 | for hop in hops: 177 | with open(os.path.join(args.input_dir, (hop + '/vanilla/qa_%s.txt'%(dataset)))) as f: 178 | qas = f.readlines() 179 | for qa in qas: 180 | question, answers = qa.strip().split('\t') 181 | topic_entity = re.search(pattern, question).group(1) 182 | if args.replace_es: 183 | question = re.sub(r'\[.*\]', 'E_S', question) 184 | else: 185 | question = question.replace('[', '').replace(']', '') 186 | answers = answers.split('|') 187 | 188 | # Note: some entities are not included in wiki 189 | # assert topic_entity in vocab['entity2id']
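 # (questions whose topic entity or answers are missing from the vocabulary are counted in drop_cnt and skipped below)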
190 | # for answer in answers: 191 | # assert answer in vocab['entity2id'] 192 | answers = [a for a in answers if a in vocab['entity2id']] 193 | if topic_entity not in vocab['entity2id'] or len(answers) == 0: 194 | drop_cnt += 1 195 | continue 196 | 197 | data.append({'question':question, 'topic_entity':topic_entity, 'answers':answers, 'hop':int(hop[0])}) 198 | datasets.append(data) 199 | json.dump(data, open(os.path.join(args.output_dir, '%s.json'%(dataset)), 'w')) 200 | 201 | train_set, test_set, val_set = datasets[0], datasets[1], datasets[2] 202 | print('size of training data: {}'.format(len(train_set))) 203 | print('size of test data: {}'.format(len(test_set))) 204 | print('size of valid data: {}'.format(len(val_set))) 205 | print('drop number: {}'.format(drop_cnt)) 206 | print('=====') 207 | print('Build question vocabulary') 208 | word_counter = Counter() 209 | for qa in tqdm(train_set): 210 | tokens = word_tokenize(qa['question'].lower()) 211 | word_counter.update(tokens) 212 | for w, c in word_counter.items(): 213 | if w and c >= args.min_cnt: 214 | add_item_to_x2id(w, vocab['word2id']) 215 | print('number of words in dict: {}'.format(len(vocab['word2id']))) 216 | with open(os.path.join(args.output_dir, 'vocab.json'), 'w') as f: 217 | json.dump(vocab, f, indent=2) 218 | 219 | for name, dataset in zip(('train', 'val', 'test'), (train_set, val_set, test_set)): 220 | print('Encode {} set'.format(name)) 221 | outputs = encode_dataset(vocab, dataset) 222 | print('shape of questions, topic_entities, answers, hops:') 223 | with open(os.path.join(args.output_dir, '{}.pt'.format(name)), 'wb') as f: 224 | for o in outputs: 225 | print(o.shape) 226 | pickle.dump(o, f) 227 | 228 | def encode_dataset(vocab, dataset): 229 | questions = [] 230 | topic_entities = [] 231 | answers = [] 232 | hops = [] 233 | for qa in tqdm(dataset): 234 | assert len(qa['topic_entity']) > 0 235 | questions.append([vocab['word2id'].get(w, vocab['word2id']['<UNK>']) for w in word_tokenize(qa['question'].lower())]) 236 | topic_entities.append([vocab['entity2id'][qa['topic_entity']]]) 237 | answers.append([vocab['entity2id'][answer] for answer in qa['answers']]) 238 | hops.append(qa['hop']) 239 | 240 | # question padding 241 | max_len = max(len(q) for q in questions) 242 | print('max question length:{}'.format(max_len)) 243 | for q in questions: 244 | while len(q) < max_len: 245 | q.append(vocab['word2id']['<PAD>']) 246 | questions = np.asarray(questions, dtype=np.int32) 247 | topic_entities = np.asarray(topic_entities, dtype=np.int32) 248 | max_len = max(len(a) for a in answers) 249 | print('max answer length:{}'.format(max_len)) 250 | for a in answers: 251 | while len(a) < max_len: 252 | a.append(DUMMY_ENTITY_ID) 253 | answers = np.asarray(answers, dtype=np.int32) 254 | hops = np.asarray(hops, dtype=np.int8) 255 | return questions, topic_entities, answers, hops 256 | 257 | def main(): 258 | parser = argparse.ArgumentParser() 259 | parser.add_argument('--input_dir', required=True, type = str) 260 | parser.add_argument('--output_dir', required=True, type = str) 261 | parser.add_argument('--kb_ratio', type=float, default=0, 262 | help='How much structured knowledge will be incorporated into the textual knowledge. 
Note that the triples are randomly selected.') 263 | parser.add_argument('--add_self', type = int, default = 0, help='whether to add self relations; 0 means no') 264 | 265 | parser.add_argument('--min_cnt', type=int, default=5) 266 | parser.add_argument('--max_desc', type=int, default=16) 267 | parser.add_argument('--num_hop', type = str, default = '1, 2, 3') 268 | parser.add_argument('--replace_es', type = int, default = 1) 269 | args = parser.parse_args() 270 | print(args) 271 | if not os.path.isdir(args.output_dir): 272 | os.makedirs(args.output_dir) 273 | 274 | print('Init vocabulary') 275 | vocab = { 276 | 'word2id': init_word2id(), 277 | 'entity2id': init_entity2id(), 278 | } 279 | 280 | print('Encode kb') 281 | encode_kb(args, vocab) 282 | 283 | print('Encode qa') 284 | encode_qa(args, vocab) 285 | 286 | if __name__ == '__main__': 287 | main() 288 | -------------------------------------------------------------------------------- /MetaQA-Text/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | import numpy as np 8 | from tqdm import tqdm 9 | import time 10 | from utils.misc import MetricLogger, load_glove, idx_to_one_hot, UseStyle, RAdam 11 | from .data import DataLoader 12 | from .model import TransferNet 13 | from .predict import validate 14 | import logging 15 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 16 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 17 | rootLogger = logging.getLogger() 18 | 19 | torch.set_num_threads(1) # avoid using multiple cpus 20 | 21 | 22 | def train(args): 23 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 24 | 25 | logging.info("Create train_loader, val_loader and test_loader.........") 26 | vocab_json = os.path.join(args.input_dir, 'vocab.json') 27 | train_pt = os.path.join(args.input_dir, 'train.pt') 28 | val_pt = os.path.join(args.input_dir, 'val.pt') 29 | test_pt = os.path.join(args.input_dir, 'test.pt') 30 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, args.limit_hop, training=True) 31 | val_loader = DataLoader(vocab_json, val_pt, args.batch_size, args.limit_hop) 32 | test_loader = DataLoader(vocab_json, test_pt, args.batch_size, args.limit_hop) 33 | vocab = train_loader.vocab 34 | 35 | logging.info("Create model.........") 36 | pretrained = load_glove(args.glove_pt, vocab['id2word']) 37 | model = TransferNet(args, vocab) 38 | model.word_embeddings.weight.data = torch.Tensor(pretrained) 39 | if args.ckpt is not None: 40 | logging.info("Load ckpt from {}".format(args.ckpt)) 41 | model.load_state_dict(torch.load(args.ckpt)) 42 | model = model.to(device) 43 | model.kb_pair = model.kb_pair.to(device) 44 | model.kb_range = model.kb_range.to(device) 45 | model.kb_desc = model.kb_desc.to(device) 46 | 47 | logging.info(model) 48 | if args.opt == 'adam': 49 | optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay) 50 | elif args.opt == 'radam': 51 | optimizer = RAdam(model.parameters(), args.lr, weight_decay=args.weight_decay) 52 | elif args.opt == 'sgd': 53 | optimizer = optim.SGD(model.parameters(), args.lr, weight_decay=args.weight_decay) 54 | elif args.opt == 'adagrad': 55 | optimizer = optim.Adagrad(model.parameters(), args.lr, weight_decay=args.weight_decay) 56 | else: 57 | raise NotImplementedError 58 | # scheduler = 
optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[4], gamma=0.1) 59 | scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.1, patience=5) 60 | 61 | meters = MetricLogger(delimiter=" ") 62 | # validate(args, model, val_loader, device) 63 | logging.info("Start training........") 64 | 65 | for epoch in range(args.num_epoch): 66 | model.train() 67 | if args.curriculum==1: 68 | if epoch < args.stop_curri_epo: 69 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, args.limit_hop, training=True, curriculum=True) 70 | elif epoch == args.stop_curri_epo: 71 | train_loader = DataLoader(vocab_json, train_pt, args.batch_size, args.limit_hop, training=True) 72 | 73 | for iteration, batch in enumerate(train_loader): 74 | iteration = iteration + 1 75 | 76 | question, topic_entity, answer, hop = batch 77 | question = question.to(device) 78 | topic_entity = idx_to_one_hot(topic_entity, len(vocab['entity2id'])).to(device) 79 | answer = idx_to_one_hot(answer, len(vocab['entity2id'])).to(device) 80 | answer[:, 0] = 0 81 | hop = hop.to(device) 82 | loss = model(question, topic_entity, answer, hop) 83 | optimizer.zero_grad() 84 | if isinstance(loss, dict): 85 | total_loss = sum(loss.values()) 86 | meters.update(**{k:v.item() for k,v in loss.items()}) 87 | else: 88 | total_loss = loss 89 | meters.update(loss=loss.item()) 90 | total_loss.backward() 91 | nn.utils.clip_grad_value_(model.parameters(), 0.5) 92 | nn.utils.clip_grad_norm_(model.parameters(), 2) 93 | optimizer.step() 94 | 95 | if iteration % (len(train_loader) // 100) == 0: 96 | logging.info( 97 | meters.delimiter.join( 98 | [ 99 | "progress: {progress:.3f}", 100 | "{meters}", 101 | "lr: {lr:.6f}", 102 | ] 103 | ).format( 104 | progress=epoch + iteration / len(train_loader), 105 | meters=str(meters), 106 | lr=optimizer.param_groups[0]["lr"], 107 | ) 108 | ) 109 | 110 | acc = validate(args, model, val_loader, device) 111 | logging.info(acc) 112 | scheduler.step(acc['all']) 113 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model_epoch-{}_acc-{:.4f}.pt'.format(epoch, acc['all']))) 114 | 115 | 116 | def main(): 117 | parser = argparse.ArgumentParser() 118 | # input and output 119 | parser.add_argument('--input_dir', required=True) 120 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 121 | parser.add_argument('--glove_pt', default='/data/sjx/glove.840B.300d.py36.pt') 122 | parser.add_argument('--ckpt', default = None) 123 | # training parameters 124 | parser.add_argument('--lr', default=0.001, type=float) 125 | parser.add_argument('--weight_decay', default=1e-5, type=float) 126 | parser.add_argument('--num_epoch', default=20, type=int) 127 | parser.add_argument('--batch_size', default=64, type=int) 128 | parser.add_argument('--seed', type=int, default=666, help='random seed') 129 | parser.add_argument('--opt', default='radam', type = str) 130 | parser.add_argument('--curriculum', default=0, type=int, help='whether to use curriculum learning; 0 means no') 131 | parser.add_argument('--stop_curri_epo', default=3, type=int, help='the epoch at which curriculum learning stops') 132 | # model hyperparameters 133 | parser.add_argument('--aux_hop', type=int, default=1, choices=[0, 1], help='utilize question hop to constrain the probability of self relation') 134 | parser.add_argument('--num_steps', default=3, type=int) 135 | parser.add_argument('--dim_word', default=300, type=int) 136 | parser.add_argument('--dim_hidden', default=768, type=int) 137 | 
parser.add_argument('--ent_act_thres', default=0.7, type=float, help='activate an entity when its score exceeds this value') # 0.9 may cause convergence issues 138 | parser.add_argument('--max_active', default=400, type=int, help='max number of active paths at each step') 139 | parser.add_argument('--limit_hop', default=-1, type=int, help='only keep questions of a certain hop; -1 means all questions') 140 | args = parser.parse_args() 141 | 142 | # make logging.info write to both the shell and a file 143 | if not os.path.exists(args.save_dir): 144 | os.makedirs(args.save_dir) 145 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 146 | args.log_name = time_ + '_{}_{}_{}.log'.format(args.opt, args.lr, args.batch_size) 147 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, args.log_name)) 148 | fileHandler.setFormatter(logFormatter) 149 | rootLogger.addHandler(fileHandler) 150 | # args display 151 | for k, v in vars(args).items(): 152 | logging.info(k+':'+str(v)) 153 | 154 | torch.backends.cudnn.deterministic = True 155 | torch.backends.cudnn.benchmark = False 156 | # set random seed 157 | torch.manual_seed(args.seed) 158 | np.random.seed(args.seed) 159 | 160 | train(args) 161 | 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TransferNet 2 | 3 | PyTorch implementation of the EMNLP 2021 paper 4 | 5 | **[TransferNet: An Effective and Transparent Framework for Multi-hop Question Answering over Relation Graph](https://arxiv.org/abs/2104.07302)** 6 |
7 | [Jiaxin Shi](https://shijx12.github.io), Shulin Cao, Lei Hou, [Juanzi Li](http://keg.cs.tsinghua.edu.cn/persons/ljz/), [Hanwang Zhang](http://www.ntu.edu.sg/home/hanwangzhang/#aboutme) 8 | 9 | We perform transparent multi-hop reasoning over relation graphs in label form (i.e., a knowledge graph) and in text form. Here is an example: 10 | 11 |
12 | ![example](example.png) 13 |
14 | 15 | If you find this code useful in your research, please cite 16 | ``` tex 17 | @inproceedings{shi2021transfernet, 18 | title={TransferNet: An Effective and Transparent Framework for Multi-hop Question Answering over Relation Graph}, 19 | author={Jiaxin Shi and Shulin Cao and Lei Hou and Juanzi Li and Hanwang Zhang}, 20 | booktitle={EMNLP}, 21 | year={2021} 22 | } 23 | ``` 24 | 25 | ## Dependencies 26 | - pytorch>=1.2.0 27 | - [transformers](https://github.com/huggingface/transformers) 28 | - tqdm 29 | - nltk 30 | - shutil (part of the Python standard library; no installation needed) 31 | 32 | ## Prepare Datasets 33 | - [MetaQA](https://goo.gl/f3AmcY), we only use its vanilla version. 34 | - [MovieQA](http://www.thespermwhale.com/jaseweston/babi/movieqa.tar.gz), we need its `knowledge_source/wiki.txt` as the text corpus for our MetaQA-Text experiments. Copy the file into the MetaQA folder, and put it together with `kb.txt`. The MetaQA files should then be organized like 35 | ```shell 36 | MetaQA 37 | +-- kb 38 | | +-- kb.txt 39 | | +-- wiki.txt 40 | +-- 1-hop 41 | | +-- vanilla 42 | | | +-- qa_train.txt 43 | | | +-- qa_dev.txt 44 | | | +-- qa_test.txt 45 | +-- 2-hop 46 | +-- 3-hop 47 | ``` 48 | - [WebQSP](https://drive.google.com/drive/folders/1RlqGBMo45lTmWz9MUPTq-0KcjSd3ujxc?usp=sharing), which has been processed by [EmbedKGQA](https://github.com/malllabiisc/EmbedKGQA). 49 | - [ComplexWebQuestions](https://drive.google.com/file/d/1ua7h88kJ6dECih6uumLeOIV9a3QNdP-g/view?usp=sharing), which has been processed by [NSM](https://github.com/RichardHGL/WSDM2021_NSM). 50 | - [GloVe 300d pretrained vector](http://nlp.stanford.edu/data/glove.840B.300d.zip), which is used in the BiGRU model. After unzipping it, you need to convert the txt file to a pickle file by 51 | ``` shell 52 | python pickle_glove.py --txt <glove_txt> --pt <output_pt> 53 | ``` 54 | 55 | 56 | ## Experiments 57 | 58 | ### MetaQA-KB 59 | 60 | 1. Preprocess 61 | ```shell 62 | python -m MetaQA-KB.preprocess --input_dir <metaqa_dir> --output_dir <processed_dir> 63 | ``` 64 | 65 | 2. Train 66 | ```shell 67 | python -m MetaQA-KB.train --glove_pt <glove_pt> --input_dir <processed_dir> --save_dir <save_dir> 68 | ``` 69 | 70 | 3. Predict on the test set 71 | ```shell 72 | python -m MetaQA-KB.predict --input_dir <processed_dir> --ckpt <checkpoint> --mode test 73 | ``` 74 | 75 | 4. Visualize the reasoning process. It will enter an IPython environment after showing the information of each sample. You can print more variables that you are interested in. To stop the process, quit IPython with `Ctrl+D` and then immediately kill the loop with `Ctrl+C`. 76 | ```shell 77 | python -m MetaQA-KB.predict --input_dir <processed_dir> --ckpt <checkpoint> --mode vis 78 | ``` 79 | 80 | ### MetaQA-Text 81 | 82 | 1. Preprocess 83 | ```shell 84 | python -m MetaQA-Text.preprocess --input_dir <metaqa_dir> --output_dir <processed_dir> 85 | ``` 86 | 87 | 2. Train 88 | ```shell 89 | python -m MetaQA-Text.train --glove_pt <glove_pt> --input_dir <processed_dir> --save_dir <save_dir> 90 | ``` 91 | 92 | The scripts for inference and visualization are the same as **MetaQA-KB**. Just change the python module to `MetaQA-Text.predict`. 93 | 94 | 95 | ### MetaQA-Text + 50% KB 96 | 97 | 1. Preprocess 98 | ```shell 99 | python -m MetaQA-Text.preprocess --input_dir <metaqa_dir> --output_dir <processed_dir> --kb_ratio 0.5 100 | ``` 101 | 102 | 2. Train; it needs more active paths than MetaQA-Text 103 | ```shell 104 | python -m MetaQA-Text.train --input_dir <processed_dir> --save_dir <save_dir> --max_active 800 --batch_size 32 105 | ``` 106 | 107 | The scripts for inference and visualization are the same as **MetaQA-Text**. 108 | 109 | 110 | ### WebQSP 111 | WebQSP does not need preprocessing. 
We can directly start the training: 112 | 113 | ```shell 114 | python -m WebQSP.train --input_dir <input_dir> --save_dir <save_dir> 115 | ``` 116 | 117 | 118 | ### ComplexWebQuestions 119 | Similar to WebQSP, CWQ does not need preprocessing. We can directly start the training: 120 | 121 | ```shell 122 | python -m CompWebQ.train --input_dir <input_dir> --save_dir <save_dir> 123 | ``` 124 | -------------------------------------------------------------------------------- /WebQSP/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import pickle 4 | from collections import defaultdict 5 | from transformers import AutoTokenizer 6 | from utils.misc import invert_dict 7 | 8 | def collate(batch): 9 | batch = list(zip(*batch)) 10 | topic_entity, question, answer, entity_range = batch 11 | topic_entity = torch.stack(topic_entity) 12 | question = {k:torch.cat([q[k] for q in question], dim=0) for k in question[0]} 13 | answer = torch.stack(answer) 14 | entity_range = torch.stack(entity_range) 15 | return topic_entity, question, answer, entity_range 16 | 17 | 18 | class Dataset(torch.utils.data.Dataset): 19 | def __init__(self, questions, ent2id): 20 | self.questions = questions 21 | self.ent2id = ent2id 22 | 23 | def __getitem__(self, index): 24 | topic_entity, question, answer, entity_range = self.questions[index] 25 | topic_entity = self.toOneHot(topic_entity) 26 | answer = self.toOneHot(answer) 27 | entity_range = self.toOneHot(entity_range) 28 | return topic_entity, question, answer, entity_range 29 | 30 | def __len__(self): 31 | return len(self.questions) 32 | 33 | def toOneHot(self, indices): 34 | indices = torch.LongTensor(indices) 35 | vec_len = len(self.ent2id) 36 | one_hot = torch.FloatTensor(vec_len) 37 | one_hot.zero_() 38 | one_hot.scatter_(0, indices, 1) 39 | return one_hot 40 | 41 | 42 | class DataLoader(torch.utils.data.DataLoader): 43 | def __init__(self, input_dir, fn, bert_name, ent2id, rel2id, batch_size, training=False): 44 | print('Reading questions from {}'.format(fn)) 45 | self.tokenizer = AutoTokenizer.from_pretrained(bert_name) 46 | self.ent2id = ent2id 47 | self.rel2id = rel2id 48 | self.id2ent = invert_dict(ent2id) 49 | self.id2rel = invert_dict(rel2id) 50 | 51 | 52 | 53 | sub_map = defaultdict(list) 54 | so_map = defaultdict(list) 55 | for line in open(os.path.join(input_dir, 'fbwq_full/train.txt')): 56 | l = line.strip().split('\t') 57 | s = l[0].strip() 58 | p = l[1].strip() 59 | o = l[2].strip() 60 | sub_map[s].append((p, o)) 61 | so_map[(s, o)].append(p) 62 | 63 | 64 | data = [] 65 | for line in open(fn): 66 | line = line.strip() 67 | if line == '': 68 | continue 69 | line = line.split('\t') 70 | # if no answer 71 | if len(line) != 2: 72 | continue 73 | question = line[0].split('[') 74 | question_1 = question[0] 75 | question_2 = question[1].split(']') 76 | head = question_2[0].strip() 77 | question_2 = question_2[1] 78 | # question = question_1 + 'NE' + question_2 79 | question = question_1.strip() # only the text before the topic entity mention is kept 80 | ans = line[1].split('|') 81 | 82 | 83 | # if (head, ans[0]) not in so_map: 84 | # continue 85 | 86 | entity_range = set() 87 | for p, o in sub_map[head]: 88 | entity_range.add(o) 89 | for p2, o2 in sub_map[o]: 90 | entity_range.add(o2) 91 | entity_range = [ent2id[o] for o in entity_range] 92 | 93 | head = [ent2id[head]] 94 | question = self.tokenizer(question.strip(), max_length=64, padding='max_length', return_tensors="pt") 95 | ans = [ent2id[a] for a in ans] 96 | data.append([head, question, ans, entity_range]) 97 | 98 | print('data number: 
{}'.format(len(data))) 99 | 100 | dataset = Dataset(data, ent2id) 101 | 102 | super().__init__( 103 | dataset, 104 | batch_size=batch_size, 105 | shuffle=training, 106 | collate_fn=collate, 107 | ) 108 | 109 | 110 | def load_data(input_dir, bert_name, batch_size): 111 | cache_fn = os.path.join(input_dir, 'processed.pt') 112 | if os.path.exists(cache_fn): 113 | print('Read from cache file: {} (NOTE: delete it if you modified data loading process)'.format(cache_fn)) 114 | with open(cache_fn, 'rb') as fp: 115 | ent2id, rel2id, triples, train_data, test_data = pickle.load(fp) 116 | print('Train number: {}, test number: {}'.format(len(train_data.dataset), len(test_data.dataset))) 117 | else: 118 | print('Read data...') 119 | ent2id = {} 120 | for line in open(os.path.join(input_dir, 'fbwq_full/entities.dict')): 121 | l = line.strip().split('\t') 122 | ent2id[l[0].strip()] = len(ent2id) 123 | # print(len(ent2id)) 124 | # print(max(ent2id.values())) 125 | rel2id = {} 126 | for line in open(os.path.join(input_dir, 'fbwq_full/relations.dict')): 127 | l = line.strip().split('\t') 128 | rel2id[l[0].strip()] = int(l[1]) 129 | 130 | triples = [] 131 | for line in open(os.path.join(input_dir, 'fbwq_full/train.txt')): 132 | l = line.strip().split('\t') 133 | s = ent2id[l[0].strip()] 134 | p = rel2id[l[1].strip()] 135 | o = ent2id[l[2].strip()] 136 | triples.append((s, p, o)) 137 | p_rev = rel2id[l[1].strip()+'_reverse'] 138 | triples.append((o, p_rev, s)) 139 | triples = torch.LongTensor(triples) 140 | 141 | train_data = DataLoader(input_dir, os.path.join(input_dir, 'QA_data/WebQuestionsSP/qa_train_webqsp.txt'), bert_name, ent2id, rel2id, batch_size, training=True) 142 | test_data = DataLoader(input_dir, os.path.join(input_dir, 'QA_data/WebQuestionsSP/qa_test_webqsp.txt'), bert_name, ent2id, rel2id, batch_size) 143 | 144 | with open(cache_fn, 'wb') as fp: 145 | pickle.dump((ent2id, rel2id, triples, train_data, test_data), fp) 146 | 147 | return ent2id, rel2id, triples, train_data, test_data 148 | -------------------------------------------------------------------------------- /WebQSP/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | from transformers import AutoModel 5 | from utils.BiGRU import GRU, BiGRU 6 | 7 | class TransferNet(nn.Module): 8 | def __init__(self, args, ent2id, rel2id, triples): 9 | super().__init__() 10 | self.args = args 11 | self.num_steps = 2 12 | num_relations = len(rel2id) 13 | # self.triples = triples 14 | 15 | Tsize = len(triples) 16 | Esize = len(ent2id) 17 | idx = torch.LongTensor([i for i in range(Tsize)]) 18 | self.Msubj = torch.sparse.FloatTensor( 19 | torch.stack((idx, triples[:,0])), torch.FloatTensor([1] * Tsize), torch.Size([Tsize, Esize])) 20 | self.Mobj = torch.sparse.FloatTensor( 21 | torch.stack((idx, triples[:,2])), torch.FloatTensor([1] * Tsize), torch.Size([Tsize, Esize])) 22 | self.Mrel = torch.sparse.FloatTensor( 23 | torch.stack((idx, triples[:,1])), torch.FloatTensor([1] * Tsize), torch.Size([Tsize, num_relations])) 24 | print('triple size: {}'.format(Tsize)) 25 | 26 | self.bert_encoder = AutoModel.from_pretrained(args.bert_name, return_dict=True) 27 | dim_hidden = self.bert_encoder.config.hidden_size 28 | 29 | self.step_encoders = [] 30 | for i in range(self.num_steps): 31 | m = nn.Sequential( 32 | nn.Linear(dim_hidden, dim_hidden), 33 | nn.Tanh() 34 | ) 35 | self.step_encoders.append(m) 36 | self.add_module('step_encoders_{}'.format(i), m) 37 | 38 | 
self.rel_classifier = nn.Linear(dim_hidden, num_relations) 39 | 40 | self.hop_selector = nn.Linear(dim_hidden, self.num_steps) 41 | 42 | 43 | def follow(self, e, r): 44 | x = torch.sparse.mm(self.Msubj, e.t()) * torch.sparse.mm(self.Mrel, r.t()) 45 | return torch.sparse.mm(self.Mobj.t(), x).t() # [bsz, Esize] 46 | 47 | def forward(self, heads, questions, answers=None, entity_range=None): 48 | q = self.bert_encoder(**questions) 49 | q_embeddings, q_word_h = q.pooler_output, q.last_hidden_state # (bsz, dim_h), (bsz, len, dim_h) 50 | 51 | device = heads.device 52 | last_e = heads 53 | word_attns = [] 54 | rel_probs = [] 55 | ent_probs = [] 56 | for t in range(self.num_steps): 57 | cq_t = self.step_encoders[t](q_embeddings) # [bsz, dim_h] 58 | q_logits = torch.sum(cq_t.unsqueeze(1) * q_word_h, dim=2) # [bsz, max_q] 59 | q_dist = torch.softmax(q_logits, 1) # [bsz, max_q] 60 | q_dist = q_dist * questions['attention_mask'].float() 61 | q_dist = q_dist / (torch.sum(q_dist, dim=1, keepdim=True) + 1e-6) # [bsz, max_q] 62 | word_attns.append(q_dist) 63 | ctx_h = (q_dist.unsqueeze(1) @ q_word_h).squeeze(1) # [bsz, dim_h] 64 | 65 | rel_logit = self.rel_classifier(ctx_h) # [bsz, num_relations] 66 | # rel_dist = torch.softmax(rel_logit, 1) # bad 67 | rel_dist = torch.sigmoid(rel_logit) 68 | rel_probs.append(rel_dist) 69 | 70 | # sub, rel, obj = self.triples[:,0], self.triples[:,1], self.triples[:,2] 71 | # sub_p = last_e[:, sub] # [bsz, #tri] 72 | # rel_p = rel_dist[:, rel] # [bsz, #tri] 73 | # obj_p = sub_p * rel_p 74 | # last_e = torch.index_add(torch.zeros_like(last_e), 1, obj, obj_p) 75 | 76 | last_e = self.follow(last_e, rel_dist) # faster than index_add 77 | 78 | # reshape >1 scores to 1 in a differentiable way 79 | m = last_e.gt(1).float() 80 | z = (m * last_e + (1-m)).detach() 81 | last_e = last_e / z 82 | 83 | ent_probs.append(last_e) 84 | 85 | hop_res = torch.stack(ent_probs, dim=1) # [bsz, num_hop, num_ent] 86 | hop_attn = torch.softmax(self.hop_selector(q_embeddings), dim=1).unsqueeze(2) # [bsz, num_hop, 1] 87 | last_e = torch.sum(hop_res * hop_attn, dim=1) # [bsz, num_ent] 88 | 89 | if not self.training: 90 | return { 91 | 'e_score': last_e, 92 | 'word_attns': word_attns, 93 | 'rel_probs': rel_probs, 94 | 'ent_probs': ent_probs, 95 | 'hop_attn': hop_attn.squeeze(2) 96 | } 97 | else: 98 | weight = answers * 99 + 1 99 | loss = torch.sum(entity_range * weight * torch.pow(last_e - answers, 2)) / torch.sum(entity_range * weight) 100 | 101 | return {'loss': loss} 102 | -------------------------------------------------------------------------------- /WebQSP/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import argparse 5 | from tqdm import tqdm 6 | from collections import defaultdict 7 | from utils.misc import batch_device 8 | from .data import load_data 9 | from .model import TransferNet 10 | 11 | from IPython import embed 12 | 13 | 14 | def validate(args, model, data, device, verbose = False): 15 | model.eval() 16 | count = 0 17 | correct = 0 18 | hop_count = defaultdict(list) 19 | with torch.no_grad(): 20 | for batch in tqdm(data, total=len(data)): 21 | outputs = model(*batch_device(batch, device)) # [bsz, Esize] 22 | e_score = outputs['e_score'].cpu() 23 | scores, idx = torch.max(e_score, dim = 1) # [bsz], [bsz] 24 | match_score = torch.gather(batch[2], 1, idx.unsqueeze(-1)).squeeze().tolist() 25 | count += len(match_score) 26 | correct += sum(match_score) 27 | for i in 
range(len(match_score)): 28 | h = outputs['hop_attn'][i].argmax().item() 29 | hop_count[h].append(match_score[i]) 30 | 31 | if verbose: 32 | answers = batch[2] 33 | for i in range(len(match_score)): 34 | if match_score[i] == 0: 35 | print('================================================================') 36 | question_ids = batch[1]['input_ids'][i].tolist() 37 | question_tokens = data.tokenizer.convert_ids_to_tokens(question_ids) 38 | print(' '.join(question_tokens)) 39 | topic_id = batch[0][i].argmax(0).item() 40 | print('> topic entity: {}'.format(data.id2ent[topic_id])) 41 | for t in range(2): 42 | print('>>>>>>> step {}'.format(t)) 43 | tmp = ' '.join(['{}: {:.3f}'.format(x, y) for x,y in 44 | zip(question_tokens, outputs['word_attns'][t][i].tolist())]) 45 | print('> Attention: ' + tmp) 46 | print('> Relation:') 47 | rel_idx = outputs['rel_probs'][t][i].gt(0.9).nonzero().squeeze(1).tolist() 48 | for x in rel_idx: 49 | print(' {}: {:.3f}'.format(data.id2rel[x], outputs['rel_probs'][t][i][x].item())) 50 | 51 | print('> Entity: {}'.format('; '.join([data.id2ent[_] for _ in outputs['ent_probs'][t][i].gt(0.8).nonzero().squeeze(1).tolist()]))) 52 | print('----') 53 | print('> max is {}'.format(data.id2ent[idx[i].item()])) 54 | print('> golden: {}'.format('; '.join([data.id2ent[_] for _ in answers[i].gt(0.9).nonzero().squeeze(1).tolist()]))) 55 | print('> prediction: {}'.format('; '.join([data.id2ent[_] for _ in e_score[i].gt(0.9).nonzero().squeeze(1).tolist()]))) 56 | print(' '.join(question_tokens)) 57 | print(outputs['hop_attn'][i].tolist()) 58 | embed() 59 | acc = correct / count 60 | print(acc) 61 | print('pred hop accuracy: 1-hop {} (total {}), 2-hop {} (total {})'.format( 62 | sum(hop_count[0])/(len(hop_count[0])+0.1), 63 | len(hop_count[0]), 64 | sum(hop_count[1])/(len(hop_count[1])+0.1), 65 | len(hop_count[1]), 66 | )) 67 | return acc 68 | 69 | 70 | def main(): 71 | parser = argparse.ArgumentParser() 72 | # input and output 73 | parser.add_argument('--input_dir', default = './input') 74 | parser.add_argument('--ckpt', required = True) 75 | parser.add_argument('--mode', default='val', choices=['val', 'vis', 'test']) parser.add_argument('--bert_name', default='bert-base-uncased', choices=['roberta-base', 'bert-base-uncased']) # needed by TransferNet and load_data; mirrors WebQSP/train.py 76 | args = parser.parse_args() 77 | 78 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 79 | ent2id, rel2id, triples, train_loader, val_loader = load_data(args.input_dir, args.bert_name, 16) # load_data requires the bert name; the second loader is actually the test split 80 | 81 | model = TransferNet(args, ent2id, rel2id, triples) 82 | missing, unexpected = model.load_state_dict(torch.load(args.ckpt), strict=False) 83 | if missing: 84 | print("Missing keys: {}".format("; ".join(missing))) 85 | if unexpected: 86 | print("Unexpected keys: {}".format("; ".join(unexpected))) 87 | model = model.to(device) 88 | # model.triples = model.triples.to(device) 89 | model.Msubj = model.Msubj.to(device) 90 | model.Mobj = model.Mobj.to(device) 91 | model.Mrel = model.Mrel.to(device) 92 | 93 | if args.mode == 'vis': 94 | validate(args, model, val_loader, device, True) 95 | elif args.mode in ('val', 'test'): 96 | validate(args, model, val_loader, device, False) 97 | 98 | if __name__ == '__main__': 99 | main() 100 | -------------------------------------------------------------------------------- /WebQSP/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | import torch.nn as nn 5 | import argparse 6 | import shutil 7 | from tqdm import tqdm 8 | import numpy as np 9 | import time 10 | from utils.misc import MetricLogger, batch_device, RAdam 11 | from utils.lr_scheduler import 
get_linear_schedule_with_warmup 12 | from .data import load_data 13 | from .model import TransferNet 14 | from .predict import validate 15 | from transformers import AdamW 16 | import logging 17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 18 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 19 | rootLogger = logging.getLogger() 20 | 21 | torch.set_num_threads(1) # avoid using multiple cpus 22 | 23 | 24 | def train(args): 25 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | ent2id, rel2id, triples, train_loader, val_loader = load_data(args.input_dir, args.bert_name, args.batch_size) 28 | logging.info("Create model.........") 29 | model = TransferNet(args, ent2id, rel2id, triples) 30 | if not args.ckpt == None: 31 | model.load_state_dict(torch.load(args.ckpt)) 32 | model = model.to(device) 33 | # model.triples = model.triples.to(device) 34 | model.Msubj = model.Msubj.to(device) 35 | model.Mobj = model.Mobj.to(device) 36 | model.Mrel = model.Mrel.to(device) 37 | logging.info(model) 38 | 39 | 40 | t_total = len(train_loader) * args.num_epoch 41 | no_decay = ["bias", "LayerNorm.weight"] 42 | bert_param = [(n,p) for n,p in model.named_parameters() if n.startswith('bert_encoder')] 43 | other_param = [(n,p) for n,p in model.named_parameters() if not n.startswith('bert_encoder')] 44 | print('number of bert param: {}'.format(len(bert_param))) 45 | optimizer_grouped_parameters = [ 46 | {'params': [p for n, p in bert_param if not any(nd in n for nd in no_decay)], 47 | 'weight_decay': args.weight_decay, 'lr': args.bert_lr}, 48 | {'params': [p for n, p in bert_param if any(nd in n for nd in no_decay)], 49 | 'weight_decay': 0.0, 'lr': args.bert_lr}, 50 | {'params': [p for n, p in other_param if not any(nd in n for nd in no_decay)], 51 | 'weight_decay': args.weight_decay, 'lr': args.lr}, 52 | {'params': [p for n, p in other_param if any(nd in n for nd in no_decay)], 53 | 'weight_decay': 0.0, 'lr': args.lr}, 54 | ] 55 | # optimizer_grouped_parameters = [{'params':model.parameters(), 'weight_decay': args.weight_decay, 'lr': args.lr}] 56 | if args.opt == 'adam': 57 | optimizer = optim.Adam(optimizer_grouped_parameters) 58 | elif args.opt == 'radam': 59 | optimizer = RAdam(optimizer_grouped_parameters) 60 | elif args.opt == 'sgd': 61 | optimizer = optim.SGD(optimizer_grouped_parameters) 62 | else: 63 | raise NotImplementedError 64 | args.warmup_steps = int(t_total * args.warmup_proportion) 65 | scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) 66 | meters = MetricLogger(delimiter=" ") 67 | # validate(args, model, val_loader, device) 68 | logging.info("Start training........") 69 | 70 | for epoch in range(args.num_epoch): 71 | model.train() 72 | for iteration, batch in enumerate(train_loader): 73 | iteration = iteration + 1 74 | loss = model(*batch_device(batch, device)) 75 | optimizer.zero_grad() 76 | if isinstance(loss, dict): 77 | if len(loss) > 1: 78 | total_loss = sum(loss.values()) 79 | else: 80 | total_loss = loss[list(loss.keys())[0]] 81 | meters.update(**{k:v.item() for k,v in loss.items()}) 82 | else: 83 | total_loss = loss 84 | meters.update(loss=loss.item()) 85 | total_loss.backward() 86 | nn.utils.clip_grad_value_(model.parameters(), 0.5) 87 | nn.utils.clip_grad_norm_(model.parameters(), 2) 88 | optimizer.step() 89 | scheduler.step() 90 | 91 | if iteration % (len(train_loader) // 10) == 0: 92 | # if True: 93 | 94 | logging.info( 95 | 
meters.delimiter.join( 96 | [ 97 | "progress: {progress:.3f}", 98 | "{meters}", 99 | "lr: {lr:.6f}", 100 | ] 101 | ).format( 102 | progress=epoch + iteration / len(train_loader), 103 | meters=str(meters), 104 | lr=optimizer.param_groups[0]["lr"], 105 | ) 106 | ) 107 | if (epoch+1)%5 == 0: 108 | acc = validate(args, model, val_loader, device) 109 | logging.info(acc) 110 | torch.save(model.state_dict(), os.path.join(args.save_dir, 'model-{}-{:.4f}.pt'.format(epoch, acc))) 111 | 112 | def main(): 113 | parser = argparse.ArgumentParser() 114 | # input and output 115 | parser.add_argument('--input_dir', required=True, help='path to the data') 116 | parser.add_argument('--save_dir', required=True, help='path to save checkpoints and logs') 117 | parser.add_argument('--ckpt', default = None) 118 | # training parameters 119 | parser.add_argument('--bert_lr', default=3e-5, type=float) 120 | parser.add_argument('--lr', default=0.001, type=float) 121 | parser.add_argument('--weight_decay', default=1e-5, type=float) 122 | parser.add_argument('--num_epoch', default=30, type=int) 123 | parser.add_argument('--batch_size', default=16, type=int) 124 | parser.add_argument('--seed', type=int, default=666, help='random seed') 125 | parser.add_argument('--opt', default='radam', type = str) 126 | parser.add_argument('--warmup_proportion', default=0.1, type = float) 127 | # model parameters 128 | parser.add_argument('--bert_name', default='bert-base-uncased', choices=['roberta-base', 'bert-base-uncased']) 129 | args = parser.parse_args() 130 | 131 | # make logging.info display into both shell and file 132 | if not os.path.exists(args.save_dir): 133 | os.makedirs(args.save_dir) 134 | time_ = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime()) 135 | args.log_name = time_ + '_{}_{}_{}.log'.format(args.opt, args.lr, args.batch_size) 136 | fileHandler = logging.FileHandler(os.path.join(args.save_dir, args.log_name)) 137 | fileHandler.setFormatter(logFormatter) 138 | rootLogger.addHandler(fileHandler) 139 | # args display 140 | for k, v in vars(args).items(): 141 | logging.info(k+':'+str(v)) 142 | 143 | torch.backends.cudnn.deterministic = True 144 | torch.backends.cudnn.benchmark = False 145 | # set random seed 146 | torch.manual_seed(args.seed) 147 | np.random.seed(args.seed) 148 | 149 | train(args) 150 | 151 | 152 | if __name__ == '__main__': 153 | main() 154 | -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shijx12/TransferNet/60bc2416438370b3036cfe33f7a11dc421cbc7b0/example.png -------------------------------------------------------------------------------- /pickle_glove.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import argparse 3 | import numpy as np 4 | 5 | def main(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--txt', required=True) 8 | parser.add_argument('--pt', required=True) 9 | args = parser.parse_args() 10 | 11 | glove = {} 12 | for line in open(args.txt, encoding='latin-1'): 13 | w, *vector = line.strip().split(' ') 14 | vector = list(map(float, vector)) 15 | vector = np.asarray(vector) 16 | glove[w] = vector 17 | 18 | with open(args.pt, 'wb') as f: 19 | pickle.dump(glove, f) 20 | 21 | if __name__ == '__main__': 22 | main() 23 | -------------------------------------------------------------------------------- /utils/BiGRU.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class GRU(nn.Module): 5 | 6 | def __init__(self, dim_word, dim_h, num_layers, dropout = 0.0): 7 | super().__init__() 8 | self.encoder = nn.GRU(input_size=dim_word, 9 | hidden_size=dim_h, 10 | num_layers=num_layers, 11 | dropout=dropout, 12 | batch_first=True, 13 | bidirectional=False) 14 | 15 | def forward_one_step(self, input, last_h): 16 | """ 17 | Args: 18 | - input (bsz, 1, w_dim) 19 | - last_h (num_layers, bsz, h_dim) 20 | """ 21 | hidden, new_h = self.encoder(input, last_h) 22 | return hidden, new_h # (bsz, 1, h_dim), (num_layers, bsz, h_dim) 23 | 24 | 25 | def generate_sequence(self, word_lookup_func, h_0, classifier, vocab, max_step, early_stop=True): 26 | bsz = h_0.size(1) 27 | device = h_0.device 28 | start_id, end_id, pad_id = vocab['<START>'], vocab['<END>'], vocab['<PAD>'] 29 | 30 | latest = torch.LongTensor([start_id]*bsz).to(device) # [bsz, ] 31 | results = [latest] 32 | last_h = h_0 33 | finished = torch.zeros((bsz,)).bool().to(device) # record whether <END> is produced 34 | for i in range(max_step-1): # exclude <START> 35 | word_emb = word_lookup_func(latest).unsqueeze(1) # [bsz, 1, dim_w] 36 | word_h, last_h = self.forward_one_step(word_emb, last_h) # [bsz, 1, dim_h] 37 | 38 | logit = classifier(word_h).squeeze(1) # [bsz, num_func] 39 | latest = torch.argmax(logit, dim=1).long() # [bsz, ] 40 | latest[finished] = pad_id # set to <PAD> after <END> 41 | results.append(latest) 42 | 43 | finished = finished | latest.eq(end_id).bool() 44 | if early_stop and finished.sum().item() == bsz: 45 | # print('finished at step {}'.format(i)) 46 | break 47 | results = torch.stack(results, dim=1) # [bsz, max_len'] 48 | return results 49 | 50 | 51 | def forward(self, input, length, h_0=None): 52 | """ 53 | Args: 54 | - input (bsz, len, w_dim) 55 | - length (bsz, ) 56 | - h_0 (num_layers, bsz, h_dim) 57 | Return: 58 | - hidden (bsz, len, dim) : hidden state of each word 59 | - output (bsz, dim) : sentence embedding 60 | """ 61 | bsz, max_len = input.size(0), input.size(1) 62 | sorted_seq_lengths, indices = torch.sort(length, descending=True) 63 | _, desorted_indices = torch.sort(indices, descending=False) 64 | input = input[indices] 65 | packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths, batch_first=True) 66 | if h_0 is None: 67 | hidden, h_n = self.encoder(packed_input) 68 | else: 69 | h_0 = h_0[:, indices] 70 | hidden, h_n = self.encoder(packed_input, h_0) 71 | # h_n is (num_layers, bsz, h_dim) 72 | hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0] # (bsz, max_len, h_dim) 73 | 74 | output = h_n[-1, :, :] # (bsz, h_dim), take the last layer's state 75 | 76 | # recover order 77 | hidden = hidden[desorted_indices] 78 | output = output[desorted_indices] 79 | h_n = h_n[:, desorted_indices] 80 | return hidden, output, h_n 81 | 82 | 83 | 84 | class BiGRU(nn.Module): 85 | 86 | def __init__(self, dim_word, dim_h, num_layers, dropout): 87 | super().__init__() 88 | self.encoder = nn.GRU(input_size=dim_word, 89 | hidden_size=dim_h//2, 90 | num_layers=num_layers, 91 | dropout=dropout, 92 | batch_first=True, 93 | bidirectional=True) 94 | 95 | def forward(self, input, length): 96 | """ 97 | Args: 98 | - input (bsz, len, w_dim) 99 | - length (bsz, ) 100 | Return: 101 | - hidden (bsz, len, dim) : hidden state of each word 102 | - output (bsz, dim) : sentence embedding 103 | - h_n (num_layers * 2, bsz, dim//2) 104 | """ 105 | bsz, max_len = 
input.size(0), input.size(1) 106 | sorted_seq_lengths, indices = torch.sort(length, descending=True) 107 | _, desorted_indices = torch.sort(indices, descending=False) 108 | input = input[indices] 109 | packed_input = nn.utils.rnn.pack_padded_sequence(input, sorted_seq_lengths, batch_first=True) 110 | hidden, h_n = self.encoder(packed_input) 111 | # h_n is (num_layers * num_directions, bsz, h_dim//2) 112 | hidden = nn.utils.rnn.pad_packed_sequence(hidden, batch_first=True, total_length=max_len)[0] # (bsz, max_len, h_dim) 113 | 114 | output = h_n[-2:, :, :] # (2, bsz, h_dim//2), take the last layer's state 115 | output = output.permute(1, 0, 2).contiguous().view(bsz, -1) # (bsz, h_dim), merge forward and backward h_n 116 | 117 | # recover order 118 | hidden = hidden[desorted_indices] 119 | output = output[desorted_indices] 120 | h_n = h_n[:, desorted_indices] 121 | return hidden, output, h_n 122 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shijx12/TransferNet/60bc2416438370b3036cfe33f7a11dc421cbc7b0/utils/__init__.py -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import warnings 4 | from torch.optim.optimizer import Optimizer 5 | from torch.optim.lr_scheduler import LambdaLR 6 | 7 | def get_constant_schedule(optimizer, last_epoch=-1): 8 | """ Create a schedule with a constant learning rate. 9 | """ 10 | return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch) 11 | 12 | 13 | def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1): 14 | """ Create a schedule with a constant learning rate preceded by a warmup 15 | period during which the learning rate increases linearly between 0 and 1. 16 | """ 17 | def lr_lambda(current_step): 18 | if current_step < num_warmup_steps: 19 | return float(current_step) / float(max(1.0, num_warmup_steps)) 20 | return 1. 21 | 22 | return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch) 23 | 24 | 25 | def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1): 26 | """ Create a schedule with a learning rate that decreases linearly after 27 | linearly increasing during a warmup period. 28 | """ 29 | def lr_lambda(current_step): 30 | if current_step < num_warmup_steps: 31 | return float(current_step) / float(max(1, num_warmup_steps)) 32 | return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) 33 | 34 | return LambdaLR(optimizer, lr_lambda, last_epoch) 35 | 36 | 37 | def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=.5, last_epoch=-1): 38 | """ Create a schedule with a learning rate that decreases following the 39 | values of the cosine function between 0 and `pi * cycles` after a warmup 40 | period during which it increases linearly between 0 and 1. 41 | """ 42 | def lr_lambda(current_step): 43 | if current_step < num_warmup_steps: 44 | return float(current_step) / float(max(1, num_warmup_steps)) 45 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 46 | return max(0., 0.5 * (1. + math.cos(math.pi * float(num_cycles) * 2. 
* progress))) 47 | 48 | return LambdaLR(optimizer, lr_lambda, last_epoch) 49 | 50 | 51 | def get_cosine_with_hard_restarts_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=1., last_epoch=-1): 52 | """ Create a schedule with a learning rate that decreases following the 53 | values of the cosine function with several hard restarts, after a warmup 54 | period during which it increases linearly between 0 and 1. 55 | """ 56 | def lr_lambda(current_step): 57 | if current_step < num_warmup_steps: 58 | return float(current_step) / float(max(1, num_warmup_steps)) 59 | progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) 60 | if progress >= 1.: 61 | return 0. 62 | return max(0., 0.5 * (1. + math.cos(math.pi * ((float(num_cycles) * progress) % 1.)))) 63 | 64 | return LambdaLR(optimizer, lr_lambda, last_epoch) 65 | 66 | 67 | class CustomDecayLR(object): 68 | ''' 69 | Example: 70 | >>> scheduler = CustomDecayLR(optimizer) 71 | >>> for epoch in range(100): 72 | >>> scheduler.epoch_step() 73 | >>> train(...) 74 | >>> ... 75 | >>> optimizer.zero_grad() 76 | >>> loss.backward() 77 | >>> optimizer.step() 78 | >>> validate(...) 79 | ''' 80 | def __init__(self,optimizer,lr): 81 | self.optimizer = optimizer 82 | self.lr = lr 83 | 84 | def epoch_step(self,epoch): 85 | lr = self.lr 86 | if epoch > 12: 87 | lr = lr / 1000 88 | elif epoch > 8: 89 | lr = lr / 100 90 | elif epoch > 4: 91 | lr = lr / 10 92 | for param_group in self.optimizer.param_groups: 93 | param_group['lr'] = lr 94 | 95 | class BertLR(object): 96 | ''' 97 | Example: 98 | >>> scheduler = BertLR(optimizer) 99 | >>> for epoch in range(100): 100 | >>> scheduler.step() 101 | >>> train(...) 102 | >>> ... 103 | >>> optimizer.zero_grad() 104 | >>> loss.backward() 105 | >>> optimizer.step() 106 | >>> scheduler.batch_step() 107 | >>> validate(...) 108 | ''' 109 | def __init__(self,optimizer,learning_rate,t_total,warmup): 110 | self.learning_rate = learning_rate 111 | self.optimizer = optimizer 112 | self.t_total = t_total 113 | self.warmup = warmup 114 | 115 | def warmup_linear(self,x, warmup=0.002): 116 | if x < warmup: 117 | return x / warmup 118 | return 1.0 - x 119 | 120 | def batch_step(self,training_step): 121 | lr_this_step = self.learning_rate * self.warmup_linear(training_step / self.t_total,self.warmup) 122 | for param_group in self.optimizer.param_groups: 123 | param_group['lr'] = lr_this_step 124 | 125 | class CyclicLR(object): 126 | ''' 127 | Example: 128 | >>> scheduler = CyclicLR(optimizer) 129 | >>> for epoch in range(100): 130 | >>> scheduler.step() 131 | >>> train(...) 132 | >>> ... 133 | >>> optimizer.zero_grad() 134 | >>> loss.backward() 135 | >>> optimizer.step() 136 | >>> scheduler.batch_step() 137 | >>> validate(...) 
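 Note: this is a triangular cyclical learning-rate policy (cf. Smith, 2017); the lr sweeps between base_lr and max_lr with a full period of 2 * step_size batch iterations.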
66 | 
67 | class CustomDecayLR(object):
68 |     '''
69 |     Example:
70 |     >>> scheduler = CustomDecayLR(optimizer, lr)
71 |     >>> for epoch in range(100):
72 |     >>>     scheduler.epoch_step(epoch)
73 |     >>>     train(...)
74 |     >>>         ...
75 |     >>>         optimizer.zero_grad()
76 |     >>>         loss.backward()
77 |     >>>         optimizer.step()
78 |     >>>     validate(...)
79 |     '''
80 |     def __init__(self, optimizer, lr):
81 |         self.optimizer = optimizer
82 |         self.lr = lr
83 | 
84 |     def epoch_step(self, epoch):
85 |         lr = self.lr
86 |         if epoch > 12:
87 |             lr = lr / 1000
88 |         elif epoch > 8:
89 |             lr = lr / 100
90 |         elif epoch > 4:
91 |             lr = lr / 10
92 |         for param_group in self.optimizer.param_groups:
93 |             param_group['lr'] = lr
94 | 
95 | class BertLR(object):
96 |     '''
97 |     Example:
98 |     >>> scheduler = BertLR(optimizer, learning_rate, t_total, warmup)
99 |     >>> global_step = 0
100 |     >>> for epoch in range(100):
101 |     >>>     train(...)
102 |     >>>         ...
103 |     >>>         optimizer.zero_grad()
104 |     >>>         loss.backward()
105 |     >>>         optimizer.step()
106 |     >>>         scheduler.batch_step(global_step); global_step += 1
107 |     >>>     validate(...)
108 |     '''
109 |     def __init__(self, optimizer, learning_rate, t_total, warmup):
110 |         self.learning_rate = learning_rate
111 |         self.optimizer = optimizer
112 |         self.t_total = t_total
113 |         self.warmup = warmup
114 | 
115 |     def warmup_linear(self, x, warmup=0.002):
116 |         if x < warmup:
117 |             return x / warmup
118 |         return 1.0 - x
119 | 
120 |     def batch_step(self, training_step):
121 |         lr_this_step = self.learning_rate * self.warmup_linear(training_step / self.t_total, self.warmup)
122 |         for param_group in self.optimizer.param_groups:
123 |             param_group['lr'] = lr_this_step
124 | 
125 | class CyclicLR(object):
126 |     '''
127 |     Example:
128 |     >>> scheduler = CyclicLR(optimizer)
129 |     >>> for epoch in range(100):
130 |     >>>     # batch_step() updates the LR once per mini-batch
131 |     >>>     train(...)
132 |     >>>         ...
133 |     >>>         optimizer.zero_grad()
134 |     >>>         loss.backward()
135 |     >>>         optimizer.step()
136 |     >>>         scheduler.batch_step()
137 |     >>>     validate(...)
138 |     '''
139 |     def __init__(self, optimizer, base_lr=1e-3, max_lr=6e-3,
140 |                  step_size=2000, mode='triangular', gamma=1.,
141 |                  scale_fn=None, scale_mode='cycle', last_batch_iteration=-1):
142 | 
143 |         if not isinstance(optimizer, Optimizer):
144 |             raise TypeError('{} is not an Optimizer'.format(
145 |                 type(optimizer).__name__))
146 | 
147 |         self.optimizer = optimizer
148 | 
149 |         if isinstance(base_lr, list) or isinstance(base_lr, tuple):
150 |             if len(base_lr) != len(optimizer.param_groups):
151 |                 raise ValueError("expected {} base_lr, got {}".format(
152 |                     len(optimizer.param_groups), len(base_lr)))
153 |             self.base_lrs = list(base_lr)
154 |         else:
155 |             self.base_lrs = [base_lr] * len(optimizer.param_groups)
156 | 
157 |         if isinstance(max_lr, list) or isinstance(max_lr, tuple):
158 |             if len(max_lr) != len(optimizer.param_groups):
159 |                 raise ValueError("expected {} max_lr, got {}".format(
160 |                     len(optimizer.param_groups), len(max_lr)))
161 |             self.max_lrs = list(max_lr)
162 |         else:
163 |             self.max_lrs = [max_lr] * len(optimizer.param_groups)
164 | 
165 |         self.step_size = step_size
166 | 
167 |         if mode not in ['triangular', 'triangular2', 'exp_range'] \
168 |                 and scale_fn is None:
169 |             raise ValueError('mode is invalid and scale_fn is None')
170 | 
171 |         self.mode = mode
172 |         self.gamma = gamma
173 | 
174 |         if scale_fn is None:
175 |             if self.mode == 'triangular':
176 |                 self.scale_fn = self._triangular_scale_fn
177 |                 self.scale_mode = 'cycle'
178 |             elif self.mode == 'triangular2':
179 |                 self.scale_fn = self._triangular2_scale_fn
180 |                 self.scale_mode = 'cycle'
181 |             elif self.mode == 'exp_range':
182 |                 self.scale_fn = self._exp_range_scale_fn
183 |                 self.scale_mode = 'iterations'
184 |         else:
185 |             self.scale_fn = scale_fn
186 |             self.scale_mode = scale_mode
187 | 
188 |         self.batch_step(last_batch_iteration + 1)
189 |         self.last_batch_iteration = last_batch_iteration
190 | 
191 |     def _triangular_scale_fn(self, x):
192 |         return 1.
193 | 
194 |     def _triangular2_scale_fn(self, x):
195 |         return 1 / (2. ** (x - 1))
196 | 
197 |     def _exp_range_scale_fn(self, x):
198 |         return self.gamma**(x)
199 | 
200 |     def get_lr(self):
201 |         step_size = float(self.step_size)
202 |         cycle = np.floor(1 + self.last_batch_iteration / (2 * step_size))
203 |         x = np.abs(self.last_batch_iteration / step_size - 2 * cycle + 1)
204 | 
205 |         lrs = []
206 |         param_lrs = zip(self.optimizer.param_groups, self.base_lrs, self.max_lrs)
207 |         for param_group, base_lr, max_lr in param_lrs:
208 |             base_height = (max_lr - base_lr) * np.maximum(0, (1 - x))
209 |             if self.scale_mode == 'cycle':
210 |                 lr = base_lr + base_height * self.scale_fn(cycle)
211 |             else:
212 |                 lr = base_lr + base_height * self.scale_fn(self.last_batch_iteration)
213 |             lrs.append(lr)
214 |         return lrs
215 | 
216 |     def batch_step(self, batch_iteration=None):
217 |         if batch_iteration is None:
218 |             batch_iteration = self.last_batch_iteration + 1
219 |         self.last_batch_iteration = batch_iteration
220 |         for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
221 |             param_group['lr'] = lr
222 | 
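# Illustrative sketch (not part of the original file): the triangular policy
# implemented by get_lr above oscillates between base_lr and max_lr with a
# period of 2 * step_size batch iterations. Toy values throughout.
import torch
p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=1e-3)
clr = CyclicLR(opt, base_lr=1e-3, max_lr=6e-3, step_size=4)
for it in range(16):
    clr.batch_step()
    print(it, opt.param_groups[0]['lr'])  # 1e-3 -> 6e-3 over 4 steps, then back down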
223 | class ReduceLROnPlateau(object):
224 |     """Reduce learning rate when a metric has stopped improving.
225 |     Models often benefit from reducing the learning rate by a factor
226 |     of 2-10 once learning stagnates. This scheduler reads a metric
227 |     quantity and if no improvement is seen for a 'patience' number
228 |     of epochs, the learning rate is reduced.
229 | 
230 |     Args:
231 |         factor: factor by which the learning rate will
232 |             be reduced. new_lr = lr * factor
233 |         patience: number of epochs with no improvement
234 |             after which learning rate will be reduced.
235 |         verbose: int. 0: quiet, 1: update messages.
236 |         mode: one of {min, max}. In `min` mode,
237 |             lr will be reduced when the quantity
238 |             monitored has stopped decreasing; in `max`
239 |             mode it will be reduced when the quantity
240 |             monitored has stopped increasing.
241 |         epsilon: threshold for measuring the new optimum,
242 |             to only focus on significant changes.
243 |         cooldown: number of epochs to wait before resuming
244 |             normal operation after lr has been reduced.
245 |         min_lr: lower bound on the learning rate.
246 | 
247 | 
248 |     Example:
249 |         >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
250 |         >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
251 |         >>> for epoch in range(10):
252 |         >>>     train(...)
253 |         >>>     val_acc, val_loss = validate(...)
254 |         >>>     scheduler.epoch_step(val_loss, epoch)
255 |     """
256 | 
257 |     def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
258 |                  verbose=0, epsilon=1e-4, cooldown=0, min_lr=0, eps=1e-8):
259 | 
260 |         super(ReduceLROnPlateau, self).__init__()
261 |         assert isinstance(optimizer, Optimizer)
262 |         if factor >= 1.0:
263 |             raise ValueError('ReduceLROnPlateau '
264 |                              'does not support a factor >= 1.0.')
265 |         self.factor = factor
266 |         self.min_lr = min_lr
267 |         self.epsilon = epsilon
268 |         self.patience = patience
269 |         self.verbose = verbose
270 |         self.cooldown = cooldown
271 |         self.cooldown_counter = 0  # Cooldown counter.
272 |         self.monitor_op = None
273 |         self.wait = 0
274 |         self.best = 0
275 |         self.mode = mode
276 |         self.optimizer = optimizer
277 |         self.eps = eps
278 |         self._reset()
279 | 
280 |     def _reset(self):
281 |         """Resets wait counter and cooldown counter.
282 |         """
283 |         if self.mode not in ['min', 'max']:
284 |             raise RuntimeError('Learning Rate Plateau Reducing mode %s is unknown!' % self.mode)
285 |         if self.mode == 'min':
286 |             self.monitor_op = lambda a, b: np.less(a, b - self.epsilon)
287 |             self.best = np.Inf
288 |         else:
289 |             self.monitor_op = lambda a, b: np.greater(a, b + self.epsilon)
290 |             self.best = -np.Inf
291 |         self.cooldown_counter = 0
292 |         self.wait = 0
293 | 
294 |     def reset(self):
295 |         self._reset()
296 | 
297 |     def epoch_step(self, metrics, epoch):
298 |         current = metrics
299 |         if current is None:
300 |             warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
301 |         else:
302 |             if self.in_cooldown():
303 |                 self.cooldown_counter -= 1
304 |                 self.wait = 0
305 | 
306 |             if self.monitor_op(current, self.best):
307 |                 self.best = current
308 |                 self.wait = 0
309 |             elif not self.in_cooldown():
310 |                 if self.wait >= self.patience:
311 |                     for param_group in self.optimizer.param_groups:
312 |                         old_lr = float(param_group['lr'])
313 |                         if old_lr > self.min_lr + self.eps:
314 |                             new_lr = old_lr * self.factor
315 |                             new_lr = max(new_lr, self.min_lr)
316 |                             param_group['lr'] = new_lr
317 |                             if self.verbose > 0:
318 |                                 print('\nEpoch %05d: reducing learning rate to %s.' % (epoch, new_lr))
319 |                     self.cooldown_counter = self.cooldown
320 |                     self.wait = 0
321 |                 self.wait += 1
322 | 
323 |     def in_cooldown(self):
324 |         return self.cooldown_counter > 0
325 | 
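# Illustrative sketch (not part of the original file): with patience=2, the
# LR is multiplied by factor on the third consecutive non-improving epoch.
import torch
p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.1)
plateau = ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=2)
for epoch, val_loss in enumerate([1.0, 0.9, 0.9, 0.9, 0.9]):
    plateau.epoch_step(val_loss, epoch)
    print(epoch, opt.param_groups[0]['lr'])  # stays 0.1, drops to 0.05 at the last epoch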
326 | class ReduceLRWDOnPlateau(ReduceLROnPlateau):
327 |     """Reduce learning rate and weight decay when a metric has stopped
328 |     improving. Models often benefit from reducing the learning rate by
329 |     a factor of 2-10 once learning stagnates. This scheduler reads a metric
330 |     quantity and if no improvement is seen for a 'patience' number
331 |     of epochs, the learning rate and weight decay factor are reduced for
332 |     optimizers that implement the weight decay method from the paper
333 |     `Fixing Weight Decay Regularization in Adam`_.
334 | 
335 |     .. _Fixing Weight Decay Regularization in Adam:
336 |         https://arxiv.org/abs/1711.05101
337 |     for AdamW or SGDW
338 |     Example:
339 |         >>> optimizer = AdamW(model.parameters(), lr=0.1, weight_decay=1e-3)
340 |         >>> scheduler = ReduceLRWDOnPlateau(optimizer, 'min')
341 |         >>> for epoch in range(10):
342 |         >>>     train(...)
343 |         >>>     val_loss = validate(...)
344 |         >>>     # Note that epoch_step should be called after validate()
345 |         >>>     scheduler.epoch_step(val_loss, epoch)
346 |     """
347 |     def epoch_step(self, metrics, epoch):
348 |         current = metrics
349 |         if current is None:
350 |             warnings.warn('Learning Rate Plateau Reducing requires metrics available!', RuntimeWarning)
351 |         else:
352 |             if self.in_cooldown():
353 |                 self.cooldown_counter -= 1
354 |                 self.wait = 0
355 | 
356 |             if self.monitor_op(current, self.best):
357 |                 self.best = current
358 |                 self.wait = 0
359 |             elif not self.in_cooldown():
360 |                 if self.wait >= self.patience:
361 |                     for param_group in self.optimizer.param_groups:
362 |                         old_lr = float(param_group['lr'])
363 |                         if old_lr > self.min_lr + self.eps:
364 |                             new_lr = old_lr * self.factor
365 |                             new_lr = max(new_lr, self.min_lr)
366 |                             param_group['lr'] = new_lr
367 |                             if self.verbose > 0:
368 |                                 print('\nEpoch %d: reducing learning rate to %s.' % (epoch, new_lr))
369 |                         if param_group['weight_decay'] != 0:
370 |                             old_weight_decay = float(param_group['weight_decay'])
371 |                             new_weight_decay = max(old_weight_decay * self.factor, self.min_lr)
372 |                             if old_weight_decay > new_weight_decay + self.eps:
373 |                                 param_group['weight_decay'] = new_weight_decay
374 |                                 if self.verbose:
375 |                                     print('\nEpoch {}: reducing weight decay factor to {:.4e}.'.format(epoch, new_weight_decay))
376 |                     self.cooldown_counter = self.cooldown
377 |                     self.wait = 0
378 |                 self.wait += 1
379 | 
380 | class CosineLRWithRestarts(object):
381 |     """Decays learning rate with cosine annealing, normalizes weight decay
382 |     hyperparameter value, implements restarts.
383 |     https://arxiv.org/abs/1711.05101
384 | 
385 |     Args:
386 |         optimizer (Optimizer): Wrapped optimizer.
387 |         batch_size: minibatch size
388 |         epoch_size: training samples per epoch
389 |         restart_period: epoch count in the first restart period
390 |         t_mult: multiplication factor by which the next restart period will extend/shrink
391 | 
392 |     Example:
393 |         >>> scheduler = CosineLRWithRestarts(optimizer, 32, 1024, restart_period=5, t_mult=1.2)
394 |         >>> for epoch in range(100):
395 |         >>>     # batch_step() is called once per mini-batch
396 |         >>>     train(...)
397 |         >>>         ...
398 |         >>>         optimizer.zero_grad()
399 |         >>>         loss.backward()
400 |         >>>         optimizer.step()
401 |         >>>         scheduler.batch_step()
402 |         >>>     validate(...)
403 |     """
404 | 
405 |     def __init__(self, optimizer, batch_size, epoch_size, restart_period=100,
406 |                  t_mult=2, last_epoch=-1, eta_threshold=1000, verbose=False):
407 |         if not isinstance(optimizer, Optimizer):
408 |             raise TypeError('{} is not an Optimizer'.format(
409 |                 type(optimizer).__name__))
410 |         self.optimizer = optimizer
411 |         if last_epoch == -1:
412 |             for group in optimizer.param_groups:
413 |                 group.setdefault('initial_lr', group['lr'])
414 |         else:
415 |             for i, group in enumerate(optimizer.param_groups):
416 |                 if 'initial_lr' not in group:
417 |                     raise KeyError("param 'initial_lr' is not specified "
418 |                                    "in param_groups[{}] when resuming an"
419 |                                    " optimizer".format(i))
420 |         self.base_lrs = list(map(lambda group: group['initial_lr'],
421 |                                  optimizer.param_groups))
422 | 
423 |         self.last_epoch = last_epoch
424 |         self.batch_size = batch_size
425 |         self.iteration = 0
426 |         self.epoch_size = epoch_size
427 |         self.eta_threshold = eta_threshold
428 |         self.t_mult = t_mult
429 |         self.verbose = verbose
430 |         self.base_weight_decays = list(map(lambda group: group['weight_decay'],
431 |                                            optimizer.param_groups))
432 |         self.restart_period = restart_period
433 |         self.restarts = 0
434 |         self.t_epoch = -1
435 |         self.batch_increments = []
436 |         self._set_batch_increment()
437 | 
438 |     def _schedule_eta(self):
439 |         """
440 |         Threshold value could be adjusted to shrink eta_min and eta_max values.
441 |         """
442 |         eta_min = 0
443 |         eta_max = 1
444 |         if self.restarts <= self.eta_threshold:
445 |             return eta_min, eta_max
446 |         else:
447 |             d = self.restarts - self.eta_threshold
448 |             k = d * 0.09
449 |             return (eta_min + k, eta_max - k)
450 | 
451 |     def get_lr(self, t_cur):
452 |         eta_min, eta_max = self._schedule_eta()
453 | 
454 |         eta_t = (eta_min + 0.5 * (eta_max - eta_min)
455 |                  * (1. + math.cos(math.pi *
456 |                                   (t_cur / self.restart_period))))
457 | 
458 |         weight_decay_norm_multi = math.sqrt(self.batch_size /
459 |                                             (self.epoch_size *
460 |                                              self.restart_period))
461 |         lrs = [base_lr * eta_t for base_lr in self.base_lrs]
462 |         weight_decays = [base_weight_decay * eta_t * weight_decay_norm_multi
463 |                          for base_weight_decay in self.base_weight_decays]
464 | 
465 |         if self.t_epoch % self.restart_period < self.t_epoch:
466 |             if self.verbose:
467 |                 print("Restart at epoch {}".format(self.last_epoch))
468 |             self.restart_period *= self.t_mult
469 |             self.restarts += 1
470 |             self.t_epoch = 0
471 | 
472 |         return zip(lrs, weight_decays)
473 | 
474 |     def _set_batch_increment(self):
475 |         d, r = divmod(self.epoch_size, self.batch_size)
476 |         batches_in_epoch = d + 2 if r > 0 else d + 1
477 |         self.iteration = 0
478 |         self.batch_increments = list(np.linspace(0, 1, batches_in_epoch))
479 | 
480 |     def batch_step(self):
481 |         self.last_epoch += 1
482 |         self.t_epoch += 1
483 |         self._set_batch_increment()
484 |         try:
485 |             t_cur = self.t_epoch + self.batch_increments[self.iteration]
486 |             self.iteration += 1
487 |         except IndexError:
488 |             raise RuntimeError("Epoch size and batch size used in the "
489 |                                "training loop and while initializing "
490 |                                "scheduler should be the same.")
491 | 
492 |         for param_group, (lr, weight_decay) in zip(self.optimizer.param_groups, self.get_lr(t_cur)):
493 |             param_group['lr'] = lr
494 |             param_group['weight_decay'] = weight_decay
495 | 
496 | 
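# Illustrative sketch (not part of the original file): as written, batch_step
# advances t_epoch once per call, so restart_period is measured in calls; when
# a period completes, it is multiplied by t_mult and the cosine starts over.
import torch
p = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([p], lr=0.1, weight_decay=1e-4)
cos = CosineLRWithRestarts(opt, batch_size=32, epoch_size=64, restart_period=4, t_mult=2)
for it in range(12):
    cos.batch_step()
    print(it, opt.param_groups[0]['lr'])  # cosine decay toward 0, restart, longer period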
497 | class NoamLR(object):
498 |     '''
499 |     Example:
500 |     >>> scheduler = NoamLR(d_model, factor, warm_up, optimizer)
501 |     >>> global_step = 0
502 |     >>> for epoch in range(100):
503 |     >>>     train(...)
504 |     >>>         ...
505 |     >>>         global_step += 1
506 |     >>>         optimizer.zero_grad()
507 |     >>>         loss.backward()
508 |     >>>         optimizer.step()
509 |     >>>         scheduler.batch_step(global_step)
510 |     >>>     validate(...)
511 |     '''
512 |     def __init__(self, d_model, factor, warm_up, optimizer):
513 |         self.optimizer = optimizer
514 |         self.warm_up = warm_up
515 |         self.factor = factor
516 |         self.d_model = d_model
517 |         self._lr = 0
518 | 
519 |     def get_lr(self, step):
520 |         lr = self.factor * (self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warm_up ** (-1.5)))
521 |         return lr
522 | 
523 |     def batch_step(self, step):
524 |         '''
525 |         update parameters and rate
526 |         :return:
527 |         '''
528 |         lr = self.get_lr(step)
529 |         for p in self.optimizer.param_groups:
530 |             p['lr'] = lr
531 |         self._lr = lr
532 | 
--------------------------------------------------------------------------------
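As a sanity check of the Noam schedule implemented by NoamLR.get_lr above, the same formula can be evaluated standalone; the d_model, factor, and warm_up values below are common Transformer defaults, used here only for illustration:

import math

def noam_lr(step, d_model=512, factor=1.0, warm_up=4000):
    # identical to NoamLR.get_lr: rises linearly until step == warm_up,
    # then decays proportionally to step ** -0.5
    return factor * (d_model ** (-0.5) * min(step ** (-0.5), step * warm_up ** (-1.5)))

for step in [1, 1000, 4000, 8000, 16000]:
    print(step, round(noam_lr(step), 6))
--------------------------------------------------------------------------------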
/utils/misc.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict, Counter, deque
2 | import torch
3 | import json
4 | import pickle
5 | import numpy as np
6 | import torch.nn as nn
7 | import math
8 | from torch.optim.optimizer import Optimizer
9 | import transformers
10 | 
11 | DUMMY_RELATION = 'DUMMY_RELATION'
12 | DUMMY_ENTITY = 'DUMMY_ENTITY'
13 | 
14 | DUMMY_ENTITY_ID = 0
15 | 
16 | def batch_device(batch, device):
17 |     res = []
18 |     for x in batch:
19 |         if isinstance(x, torch.Tensor):
20 |             x = x.to(device)
21 |         elif isinstance(x, (dict, transformers.tokenization_utils_base.BatchEncoding)):
22 |             for k in x:
23 |                 if isinstance(x[k], torch.Tensor):
24 |                     x[k] = x[k].to(device)
25 |         elif isinstance(x, (list, tuple)) and isinstance(x[0], torch.Tensor):
26 |             x = list(map(lambda i: i.to(device), x))
27 |         res.append(x)
28 |     return res
29 | 
30 | def idx_to_one_hot(idx, size):
31 |     """
32 |     Args:
33 |         idx [bsz, 1] or int or list
34 |     Return:
35 |         one_hot [bsz, size]
36 |     """
37 |     if isinstance(idx, int):
38 |         one_hot = torch.zeros((size,))
39 |         one_hot[idx] = 1
40 |     elif isinstance(idx, list):
41 |         one_hot = torch.zeros((size,))
42 |         for i in idx:
43 |             one_hot[i] = 1
44 |     else:
45 |         one_hot = torch.FloatTensor(len(idx), size)
46 |         one_hot.zero_()
47 |         one_hot.scatter_(1, idx, 1)
48 |     return one_hot
49 | 
50 | 
51 | def init_word2id():
52 |     return {
53 |         '<PAD>': 0,
54 |         '<UNK>': 1,
55 |         'E_S': 2,
56 |     }
57 | def init_entity2id():
58 |     return {
59 |         DUMMY_ENTITY: DUMMY_ENTITY_ID
60 |     }
61 | 
62 | def add_item_to_x2id(item, x2id):
63 |     if item not in x2id:
64 |         x2id[item] = len(x2id)
65 | 
66 | def invert_dict(d):
67 |     return {v: k for k, v in d.items()}
68 | 
69 | def load_glove(glove_pt, idx_to_token):
70 |     glove = pickle.load(open(glove_pt, 'rb'))
71 |     dim = len(glove['the'])
72 |     matrix = []
73 |     for i in range(len(idx_to_token)):
74 |         token = idx_to_token[i]
75 |         tokens = token.split()
76 |         if len(tokens) > 1:
77 |             v = np.zeros((dim,))
78 |             for token in tokens:
79 |                 v = v + glove.get(token, glove['the'])
80 |             v = v / len(tokens)
81 |         else:
82 |             v = glove.get(token, glove['the'])
83 |         matrix.append(v)
84 |     matrix = np.asarray(matrix)
85 |     return matrix
86 | 
87 | 
88 | class SmoothedValue(object):
89 |     """Track a series of values and provide access to smoothed values over a
90 |     window or the global series average.
91 |     """
92 | 
93 |     def __init__(self, window_size=20):
94 |         self.deque = deque(maxlen=window_size)
95 |         self.series = []
96 |         self.total = 0.0
97 |         self.count = 0
98 | 
99 |     def update(self, value):
100 |         self.deque.append(value)
101 |         self.series.append(value)
102 |         self.count += 1
103 |         self.total += value
104 | 
105 |     @property
106 |     def median(self):
107 |         d = torch.tensor(list(self.deque))
108 |         return d.median().item()
109 | 
110 |     @property
111 |     def avg(self):
112 |         d = torch.tensor(list(self.deque))
113 |         return d.mean().item()
114 | 
115 |     @property
116 |     def global_avg(self):
117 |         return self.total / self.count
118 | 
119 | 
120 | class MetricLogger(object):
121 |     def __init__(self, delimiter="\t"):
122 |         self.meters = defaultdict(SmoothedValue)
123 |         self.delimiter = delimiter
124 | 
125 |     def update(self, **kwargs):
126 |         for k, v in kwargs.items():
127 |             if isinstance(v, torch.Tensor):
128 |                 v = v.item()
129 |             assert isinstance(v, (float, int))
130 |             self.meters[k].update(v)
131 | 
132 |     def __getattr__(self, attr):
133 |         if attr in self.meters:
134 |             return self.meters[attr]
135 |         if attr in self.__dict__:
136 |             return self.__dict__[attr]
137 |         raise AttributeError("'{}' object has no attribute '{}'".format(
138 |             type(self).__name__, attr))
139 | 
140 |     def __str__(self):
141 |         loss_str = []
142 |         for name, meter in self.meters.items():
143 |             loss_str.append(
144 |                 "{}: {:.4f} ({:.4f})".format(name, meter.median, meter.global_avg)
145 |             )
146 |         return self.delimiter.join(loss_str)
147 | 
148 | 
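# Illustrative sketch (not part of the original file): MetricLogger keeps one
# SmoothedValue per keyword and formats each meter as "median (global average)".
logger = MetricLogger(delimiter='  ')
for i in range(30):
    logger.update(loss=1.0 / (i + 1), acc=i / 30.0)
print(str(logger))             # e.g. loss: <median> (<global avg>)  acc: ...
print(logger.loss.global_avg)  # meters are also reachable as attributes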
149 | class RAdam(Optimizer):
150 | 
151 |     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
152 |         if not 0.0 <= lr:
153 |             raise ValueError("Invalid learning rate: {}".format(lr))
154 |         if not 0.0 <= eps:
155 |             raise ValueError("Invalid epsilon value: {}".format(eps))
156 |         if not 0.0 <= betas[0] < 1.0:
157 |             raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
158 |         if not 0.0 <= betas[1] < 1.0:
159 |             raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
160 | 
161 |         self.degenerated_to_sgd = degenerated_to_sgd
162 |         defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
163 |         self.buffer = [[None, None, None] for ind in range(10)]
164 |         super(RAdam, self).__init__(params, defaults)
165 | 
166 |     def __setstate__(self, state):
167 |         super(RAdam, self).__setstate__(state)
168 | 
169 |     def step(self, closure=None):
170 | 
171 |         loss = None
172 |         if closure is not None:
173 |             loss = closure()
174 | 
175 |         for group in self.param_groups:
176 | 
177 |             for p in group['params']:
178 |                 if p.grad is None:
179 |                     continue
180 |                 grad = p.grad.data.float()
181 |                 if grad.is_sparse:
182 |                     raise RuntimeError('RAdam does not support sparse gradients')
183 | 
184 |                 p_data_fp32 = p.data.float()
185 | 
186 |                 state = self.state[p]
187 | 
188 |                 if len(state) == 0:
189 |                     state['step'] = 0
190 |                     state['exp_avg'] = torch.zeros_like(p_data_fp32)
191 |                     state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
192 |                 else:
193 |                     state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
194 |                     state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
195 | 
196 |                 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
197 |                 beta1, beta2 = group['betas']
198 | 
199 |                 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)  # EMA of squared gradients
200 |                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)  # EMA of gradients
201 | 
202 |                 state['step'] += 1
203 |                 buffered = self.buffer[int(state['step'] % 10)]
204 |                 if state['step'] == buffered[0]:
205 |                     N_sma, step_size = buffered[1], buffered[2]
206 |                 else:
207 |                     buffered[0] = state['step']
208 |                     beta2_t = beta2 ** state['step']
209 |                     N_sma_max = 2 / (1 - beta2) - 1
210 |                     N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
211 |                     buffered[1] = N_sma
212 | 
213 |                     # more conservative since it's an approximated value
214 |                     if N_sma >= 5:
215 |                         step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
216 |                     elif self.degenerated_to_sgd:
217 |                         step_size = 1.0 / (1 - beta1 ** state['step'])
218 |                     else:
219 |                         step_size = -1
220 |                     buffered[2] = step_size
221 | 
222 |                 # more conservative since it's an approximated value
223 |                 if N_sma >= 5:
224 |                     if group['weight_decay'] != 0:
225 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
226 |                     denom = exp_avg_sq.sqrt().add_(group['eps'])
227 |                     p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
228 |                     p.data.copy_(p_data_fp32)
229 |                 elif step_size > 0:
230 |                     if group['weight_decay'] != 0:
231 |                         p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
232 |                     p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
233 |                     p.data.copy_(p_data_fp32)
234 | 
235 |         return loss
236 | 
--------------------------------------------------------------------------------
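A short smoke test for the RAdam optimizer above, on a throwaway least-squares problem (all data random, purely illustrative):

import torch

w = torch.nn.Parameter(torch.randn(10))
opt = RAdam([w], lr=1e-2, weight_decay=1e-4)
x, y = torch.randn(64, 10), torch.randn(64)
for _ in range(200):
    loss = ((x @ w - y) ** 2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
print(float(loss))  # should decrease steadily from its initial value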