├── README.md └── codes ├── dataloader.py ├── gmm.sh ├── main.py ├── models.py ├── operators.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | # NMP-QEM 2 | This is the PyTorch implementation of NMP-QEM. 3 | 4 | <**Neural-based Mixture Probabilistic Query Embedding for Answering FOL queries on Knowledge Graphs**>. EMNLP2022 5 | 6 | ## Requirements 7 | - Python 3.7 8 | - PyTorch 1.7 9 | - tqdm 10 | 11 | 12 | ## Acknowledgement 13 | We refer to the code of [KGReasoning](https://github.com/snap-stanford/KGReasoning). Thanks for their contributions. -------------------------------------------------------------------------------- /codes/dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import torch 9 | 10 | from torch.utils.data import Dataset 11 | from util import list2tuple, tuple2list, flatten 12 | 13 | class TestDataset(Dataset): 14 | def __init__(self, queries, nentity, nrelation): 15 | # queries is a list of (query, query_structure) pairs 16 | self.len = len(queries) 17 | self.queries = queries 18 | self.nentity = nentity 19 | self.nrelation = nrelation 20 | 21 | def __len__(self): 22 | return self.len 23 | 24 | def __getitem__(self, idx): 25 | query = self.queries[idx][0] 26 | query_structure = self.queries[idx][1] 27 | negative_sample = torch.LongTensor(range(self.nentity)) 28 | return negative_sample, flatten(query), query, query_structure 29 | 30 | @staticmethod 31 | def collate_fn(data): 32 | negative_sample = torch.stack([_[0] for _ in data], dim=0) 33 | query = [_[1] for _ in data] 34 | query_unflatten = [_[2] for _ in data] 35 | query_structure = [_[3] for _ in data] 36 | return negative_sample, query, query_unflatten, query_structure 37 | 38 | class TrainDataset(Dataset): 
39 | def __init__(self, queries, nentity, nrelation, negative_sample_size, answer): 40 | # queries is a list of (query, query_structure) pairs 41 | self.len = len(queries) 42 | self.queries = queries 43 | self.nentity = nentity 44 | self.nrelation = nrelation 45 | self.negative_sample_size = negative_sample_size 46 | self.count = self.count_frequency(queries, answer) 47 | self.answer = answer 48 | 49 | def __len__(self): 50 | return self.len 51 | 52 | def __getitem__(self, idx): 53 | query = self.queries[idx][0] 54 | query_structure = self.queries[idx][1] 55 | tail = np.random.choice(list(self.answer[query])) 56 | subsampling_weight = self.count[query] 57 | subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight])) 58 | negative_sample_list = [] 59 | negative_sample_size = 0 60 | while negative_sample_size < self.negative_sample_size: 61 | negative_sample = np.random.randint(self.nentity, size=self.negative_sample_size*2) 62 | mask = np.in1d( 63 | negative_sample, 64 | self.answer[query], 65 | assume_unique=True, 66 | invert=True 67 | ) 68 | negative_sample = negative_sample[mask] 69 | negative_sample_list.append(negative_sample) 70 | negative_sample_size += negative_sample.size 71 | negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size] 72 | negative_sample = torch.from_numpy(negative_sample) 73 | positive_sample = torch.LongTensor([tail]) 74 | return positive_sample, negative_sample, subsampling_weight, flatten(query), query_structure 75 | 76 | @staticmethod 77 | def collate_fn(data): 78 | positive_sample = torch.cat([_[0] for _ in data], dim=0) 79 | negative_sample = torch.stack([_[1] for _ in data], dim=0) 80 | subsample_weight = torch.cat([_[2] for _ in data], dim=0) 81 | query = [_[3] for _ in data] 82 | query_structure = [_[4] for _ in data] 83 | return positive_sample, negative_sample, subsample_weight, query, query_structure 84 | 85 | @staticmethod 86 | def count_frequency(queries, answer, start=4): 87 | count = {} 
88 | for query, qtype in queries: 89 | count[query] = start + len(answer[query]) 90 | return count 91 | 92 | class SingledirectionalOneShotIterator(object): 93 | def __init__(self, dataloader): 94 | self.iterator = self.one_shot_iterator(dataloader) 95 | self.step = 0 96 | 97 | def __next__(self): 98 | self.step += 1 99 | data = next(self.iterator) 100 | return data 101 | 102 | @staticmethod 103 | def one_shot_iterator(dataloader): 104 | while True: 105 | for data in dataloader: 106 | yield data -------------------------------------------------------------------------------- /codes/gmm.sh: -------------------------------------------------------------------------------- 1 | # CUDA_VISIBLE_DEVICES=3 python main.py --cuda --do_train --do_test \ 2 | # --data_path data/FB15k-237-betae -n 128 -b 512 -d 400 -g 42 \ 3 | # -lr 0.00007 --max_steps 800001 --cpu_num 4 --geo gmm --valid_steps 15000 \ 4 | # --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up.2in.3in.inp.pin.pni" --exp_info "FB15k-237" 5 | 6 | 7 | # CUDA_VISIBLE_DEVICES=1 python main.py --cuda --do_train --do_test \ 8 | # --data_path data/NELL-betae -n 128 -b 512 -d 400 -g 30 \ 9 | # -lr 0.00005 --max_steps 500001 --cpu_num 4 --geo gmm --valid_steps 15000 \ 10 | # --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up.2in.3in.inp.pin.pni" --exp_info "NELL995" 11 | 12 | 13 | # CUDA_VISIBLE_DEVICES=1 python main.py --cuda --do_train --do_test \ 14 | # --data_path data/wn18rr -n 128 -b 512 -d 400 -g 24 \ 15 | # -lr 0.00005 --max_steps 500001 --cpu_num 4 --geo gmm --valid_steps 15000 \ 16 | # --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up.2in.3in.inp.pin.pni" --exp_info "wn18rr" 17 | 18 | 19 | ################################ evaluation ########################################## 20 | # CUDA_VISIBLE_DEVICES=7 python main.py --cuda --do_test \ 21 | # --data_path data/NELL-betae -n 128 -b 512 -d 400 -g 42 \ 22 | # -lr 0.00005 --max_steps 600001 --cpu_num 4 --geo gmm --valid_steps 15000 \ 23 | # --tasks "1p.2p.3p.2i.3i.ip.pi.2u.up.2in.3in.inp.pin.pni" 
# ---------------------------------------------------------------------
# codes/gmm.sh, continuation of the commented evaluation command above:
#   --checkpoint_path ......
# #####################################################################
# ---------------------------------------------------------------------
# codes/main.py
# =====================================================================
#!/usr/bin/python3

import argparse
import json
import logging
import os
import random
import pdb
import pickle
import time
from collections import defaultdict

import numpy as np
import torch
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from tqdm import tqdm

from models import NMP_QEModel
from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator
from util import flatten_query, list2tuple, parse_time, set_global_seed, eval_tuple

# Mapping from a query-structure tuple (BetaE convention: 'e' entity,
# 'r' relation, 'n' negation, 'u' union) to its short task name.
query_name_dict = {('e', ('r',)): '1p',
                   ('e', ('r', 'r')): '2p',
                   ('e', ('r', 'r', 'r')): '3p',
                   (('e', ('r',)), ('e', ('r',))): '2i',
                   (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i',
                   ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip',
                   (('e', ('r', 'r')), ('e', ('r',))): 'pi',
                   (('e', ('r',)), ('e', ('r', 'n'))): '2in',
                   (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in',
                   ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp',
                   (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin',
                   (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni',
                   (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF',
                   ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF',
                   ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM',
                   ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM'
                   }
name_query_dict = {value: key for key, value in query_name_dict.items()}
all_tasks = list(name_query_dict.keys())


def parse_args(args=None):
    """Build and parse the command-line arguments for train/eval runs."""
    parser = argparse.ArgumentParser(
        description='Training and Testing Knowledge Graph Embedding Models',
        usage='train.py [] [-h | --help]'
    )
    parser.add_argument('--cuda', action='store_true', help='use GPU')
    parser.add_argument('--do_train', action='store_true', help="do train")
    parser.add_argument('--do_valid', action='store_true', help="do valid")
    parser.add_argument('--do_test', action='store_true', help="do test")

    parser.add_argument('--data_path', type=str, default=None, help="KG data path")
    parser.add_argument('-n', '--negative_sample_size', default=128, type=int, help="negative entities sampled per query")
    parser.add_argument('-d', '--hidden_dim', default=500, type=int, help="embedding dimension")
    parser.add_argument('-g', '--gamma', default=12.0, type=float, help="margin in the loss")
    parser.add_argument('-b', '--batch_size', default=1024, type=int, help="batch size of queries")
    parser.add_argument('--test_batch_size', default=1, type=int, help='valid/test batch size')
    parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float)
    parser.add_argument('-cpu', '--cpu_num', default=10, type=int, help="used to speed up torch.dataloader")
    parser.add_argument('-save', '--save_path', default=None, type=str, help="no need to set manually, will configure automatically")
    parser.add_argument('--max_steps', default=100000, type=int, help="maximum iterations to train")

    parser.add_argument('--warm_up_steps', default=None, type=int, help="no need to set manually, will configure automatically")
    parser.add_argument('--save_checkpoint_steps', default=50000, type=int, help="save checkpoints every xx steps")
    parser.add_argument('--valid_steps', default=10000, type=int, help="evaluate validation queries every xx steps")
    parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps')
    parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps')
    parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET')
    parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET')

    parser.add_argument('--geo', default='gmm', type=str, choices=['vec', 'gaussian', 'beta', 'gmm'], help='the reasoning model, vec for GQE, gaussian for PERM, beta for BetaE, gmm for NMPQEM')
    parser.add_argument('--print_on_screen', action='store_true', default=True)
    parser.add_argument('--tasks', default='1p.2p.3p.2i.3i.ip.pi.2u.up', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task")
    parser.add_argument('--seed', default=0, type=int, help="random seed")
    parser.add_argument('-gmm', '--gmm_mode', default="(none, 10, 1)", type=str, help='(activation, gmm_num, layers_num) for GMM')
    # BUGFIX: main() reads args.beta_mode for --geo beta, but the option
    # was never declared, raising AttributeError.  Default mirrors the
    # KGReasoning convention.
    parser.add_argument('-betam', '--beta_mode', default="(1600, 2)", type=str, help='(hidden_dim, num_layer) for BetaE relational projection')

    parser.add_argument('--prefix', default=None, type=str, help='prefix of the log path')
    parser.add_argument('--checkpoint_path', default=None, type=str, help='path for loading the checkpoints')
    parser.add_argument('-evu', '--evaluate_union', default="DNF", type=str, choices=['DNF', 'DM'], help='the way to evaluate union queries, transform it to disjunctive normal form (DNF) or use the De Morgan\'s laws (DM)')
    parser.add_argument('--exp_info', default=None, type=str, help='prefix of the log path')

    return parser.parse_args(args)


def save_model(model, optimizer, save_variable_list, args, ckpt_name='checkpoint'):
    '''
    Save the parameters of the model and the optimizer,
    as well as some other variables such as step and learning_rate.
    ``ckpt_name`` selects the checkpoint file name (default 'checkpoint').
    '''
    argparse_dict = vars(args)
    with open(os.path.join(args.save_path, 'config.json'), 'w') as fjson:
        json.dump(argparse_dict, fjson)

    state_dict = {
        **save_variable_list,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    torch.save(state_dict, os.path.join(args.save_path, ckpt_name))


def save_best_model(model, optimizer, save_variable_list, args):
    '''
    Same as save_model() but writes to 'best_checkpoint'.  Kept as a
    separate name for existing callers; it now delegates instead of
    duplicating the save logic.
    '''
    save_model(model, optimizer, save_variable_list, args, ckpt_name='best_checkpoint')


def set_logger(args):
    '''
    Write logs to console and log file
    '''
    if args.do_train:
        log_file = os.path.join(args.save_path, 'train.log')
    else:
        log_file = os.path.join(args.save_path, 'test.log')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='a+'
    )
    if args.print_on_screen:
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)


def log_metrics(mode, step, metrics):
    '''
    Print the evaluation logs
    '''
    for metric in metrics:
        logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric]))


def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer):
    '''
    Evaluate queries in dataloader; log and tensorboard every metric per
    query structure plus the macro-average over structures.
    Returns (per-structure metrics, averaged metrics).
    '''
    average_metrics = defaultdict(float)
    all_metrics = defaultdict(float)
    metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict)
    num_query_structures = 0
    num_queries = 0
    for query_structure in metrics:
        log_metrics(mode + " " + query_name_dict[query_structure], step, metrics[query_structure])
        for metric in metrics[query_structure]:
            writer.add_scalar("_".join([mode, query_name_dict[query_structure], metric]), metrics[query_structure][metric], step)
            all_metrics["_".join([query_name_dict[query_structure], metric])] = metrics[query_structure][metric]
            # 'num_queries' is a count, not a quality metric: exclude it
            # from the average.
            if metric != 'num_queries':
                average_metrics[metric] += metrics[query_structure][metric]
        num_queries += metrics[query_structure]['num_queries']
        num_query_structures += 1

    for metric in average_metrics:
        average_metrics[metric] /= num_query_structures
        writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step)
        all_metrics["_".join(["average", metric])] = average_metrics[metric]
    log_metrics('%s average' % mode, step, average_metrics)

    return metrics, average_metrics


def evaluate_best_model(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer):
    '''
    Identical to evaluate(); kept as a separate entry point for existing
    callers and now delegating instead of duplicating ~30 lines.
    '''
    return evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step, writer)


def load_data(args, tasks):
    '''
    Load train/valid/test queries and answers, then drop every query
    structure that is not in ``tasks`` (or whose union form differs from
    ``args.evaluate_union``).
    '''
    logging.info("loading data")

    def _load(name):
        # Context manager so pickle file handles are closed promptly
        # (the original left 8 files open).  NOTE(review): pickle is
        # only safe for trusted local data files.
        with open(os.path.join(args.data_path, name), 'rb') as f:
            return pickle.load(f)

    train_queries = _load("train-queries.pkl")
    train_answers = _load("train-answers.pkl")
    valid_queries = _load("valid-queries.pkl")
    valid_hard_answers = _load("valid-hard-answers.pkl")
    valid_easy_answers = _load("valid-easy-answers.pkl")
    test_queries = _load("test-queries.pkl")
    test_hard_answers = _load("test-hard-answers.pkl")
    test_easy_answers = _load("test-easy-answers.pkl")

    # remove tasks not in args.tasks
    for name in all_tasks:
        if 'u' in name:
            name, evaluate_union = name.split('-')
        else:
            evaluate_union = args.evaluate_union
        if name not in tasks or evaluate_union != args.evaluate_union:
            query_structure = name_query_dict[name if 'u' not in name else '-'.join([name, evaluate_union])]
            if query_structure in train_queries:
                del train_queries[query_structure]
            if query_structure in valid_queries:
                del valid_queries[query_structure]
            if query_structure in test_queries:
                del test_queries[query_structure]

    return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers


def main(args):
    set_global_seed(args.seed)
    tasks = args.tasks.split('.')
    for task in tasks:
        if 'n' in task and args.geo in ['vec']:
            # BUGFIX: the message used to say "Gmm", but the condition
            # rejects the point-vector model (GQE).
            assert False, "vec (GQE) cannot handle queries with negation"
    if args.evaluate_union == 'DM':
        assert args.geo == 'beta', "only BetaE supports modeling union using De Morgan's Laws"

    cur_time = parse_time()
    prefix = 'logs' if args.prefix is None else args.prefix

    print("overwriting args.save_path")  # typo "overwritting" fixed
    args.save_path = os.path.join(prefix, args.data_path.split('/')[-1], args.tasks, args.geo)
    if args.geo in ['gmm']:
        tmp_str = "g-{}-mode-{}".format(args.gamma, args.gmm_mode)
    elif args.geo in ['vec']:
        tmp_str = "g-{}".format(args.gamma)
    elif args.geo == 'beta':
        tmp_str = "g-{}-mode-{}".format(args.gamma, args.beta_mode)

    if args.checkpoint_path is not None:
        args.save_path = args.checkpoint_path
    else:
        args.save_path = os.path.join(args.save_path, tmp_str, cur_time)

    if not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    print("logging to", args.save_path)
    if not args.do_train:
        # not training: create tensorboard files in a tmp location
        writer = SummaryWriter('./logs-debug/unused-tb')
    else:
        writer = SummaryWriter(args.save_path)
    set_logger(args)

    with open('%s/stats.txt' % args.data_path) as f:
        entrel = f.readlines()
        nentity = int(entrel[0].split(' ')[-1])
        nrelation = int(entrel[1].split(' ')[-1])

    args.nentity = nentity
    args.nrelation = nrelation
    # Dataset tag, e.g. "data/NELL-betae" -> "NELL".  Use [-1] instead
    # of the original hard-coded [1] so absolute paths also work; for
    # the expected "data/<name>" layout the result is identical.
    dataset = args.data_path.split('/')[-1].split('-')[0]

    logging.info('-------------------------------' * 3)
    logging.info('Geo: %s' % args.geo)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    logging.info('#max steps: %d' % args.max_steps)
    logging.info('Evaluate unions using: %s' % args.evaluate_union)  # typo "unoins" fixed
    logging.info('experiment information: %s' % args.exp_info)

    train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, test_queries, test_hard_answers, test_easy_answers = load_data(args, tasks)

    logging.info("Training info:")
    if args.do_train:
        for query_structure in train_queries:
            logging.info(query_name_dict[query_structure] + ": " + str(len(train_queries[query_structure])))
        # Path queries (1p/2p/3p) and the remaining structures are
        # trained through two alternating iterators.
        train_path_queries = defaultdict(set)
        train_other_queries = defaultdict(set)
        path_list = ['1p', '2p', '3p']
        for query_structure in train_queries:
            if query_name_dict[query_structure] in path_list:
                train_path_queries[query_structure] = train_queries[query_structure]
            else:
                train_other_queries[query_structure] = train_queries[query_structure]
        train_path_queries = flatten_query(train_path_queries)
        train_path_iterator = SingledirectionalOneShotIterator(DataLoader(
            TrainDataset(train_path_queries, nentity, nrelation, args.negative_sample_size, train_answers),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.cpu_num,
            collate_fn=TrainDataset.collate_fn
        ))
        if len(train_other_queries) > 0:
            train_other_queries = flatten_query(train_other_queries)
            train_other_iterator = SingledirectionalOneShotIterator(DataLoader(
                TrainDataset(train_other_queries, nentity, nrelation, args.negative_sample_size, train_answers),
                batch_size=args.batch_size,
                shuffle=True,
                num_workers=args.cpu_num,
                collate_fn=TrainDataset.collate_fn
            ))
        else:
            train_other_iterator = None

    logging.info("Validation info:")
    if args.do_valid:
        for query_structure in valid_queries:
            logging.info(query_name_dict[query_structure] + ": " + str(len(valid_queries[query_structure])))
        valid_queries = flatten_query(valid_queries)
        valid_dataloader = DataLoader(
            TestDataset(valid_queries, args.nentity, args.nrelation),
            batch_size=args.test_batch_size,
            num_workers=args.cpu_num,
            collate_fn=TestDataset.collate_fn
        )
    logging.info("Test info:")
    if args.do_test:
        for query_structure in test_queries:
            logging.info(query_name_dict[query_structure] + ": " + str(len(test_queries[query_structure])))
        test_queries = flatten_query(test_queries)
        test_dataloader = DataLoader(
            TestDataset(test_queries, args.nentity, args.nrelation),
            batch_size=args.test_batch_size,
            num_workers=args.cpu_num,
            collate_fn=TestDataset.collate_fn
        )

    model = NMP_QEModel(
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        geo=args.geo,
        use_cuda=args.cuda,
        gmm_mode=eval_tuple(args.gmm_mode),
        test_batch_size=args.test_batch_size,
        query_name_dict=query_name_dict,
        dataset=dataset
    )

    logging.info('Model Parameter Configuration:')
    num_params = 0
    for name, param in model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s, %s' % (name, str(param.size()), str(param.requires_grad), str(param.is_leaf)))
        if param.requires_grad:
            num_params += np.prod(param.size())
    logging.info('Parameter Number: %d' % num_params)

    if args.cuda:
        model = model.cuda()
    if args.do_train:
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=current_learning_rate
        )
        warm_up_steps = args.max_steps // 2

    if args.checkpoint_path is not None:
        logging.info('Loading checkpoint %s...' % args.checkpoint_path)
        checkpoint = torch.load(os.path.join(args.checkpoint_path, 'best_checkpoint'))
        init_step = checkpoint['step']
        model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.geo)  # typo "Ramdomly" fixed
        init_step = 0

    step = init_step
    if args.geo == 'gmm':
        logging.info('gaussian mode = %s' % args.gmm_mode)
    elif args.geo == 'beta':
        logging.info('beta mode = %s' % args.beta_mode)
    logging.info('tasks = %s' % args.tasks)
    logging.info('init_step = %d' % init_step)
    if args.do_train:
        logging.info('Start Training...')
        # BUGFIX: %d logged every fractional learning rate as 0.
        logging.info('learning_rate = %f' % current_learning_rate)
        logging.info('batch_size = %d' % args.batch_size)
        logging.info('hidden_dim = %d' % args.hidden_dim)
        logging.info('gamma = %f' % args.gamma)

    if args.do_train:
        best_mrr = 0.
        best_metrics = defaultdict(lambda: defaultdict(int))

        training_logs = []
        # Training Loop
        for step in range(init_step, args.max_steps):
            if step == 2 * args.max_steps // 3:
                # Validate less often during the final third of training.
                args.valid_steps *= 4

            log = model.train_step(model, optimizer, train_path_iterator, args, step)
            for metric in log:
                writer.add_scalar('path_' + metric, log[metric], step)
            if train_other_iterator is not None:
                log = model.train_step(model, optimizer, train_other_iterator, args, step)
                for metric in log:
                    writer.add_scalar('other_' + metric, log[metric], step)
                # Deliberate: a second path-iterator step so path queries
                # are trained twice as often as the other structures.
                log = model.train_step(model, optimizer, train_path_iterator, args, step)

            training_logs.append(log)

            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 5
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 1.5

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(model, optimizer, save_variable_list, args)

            if step % args.valid_steps == 0 and step > 0:
                if args.do_valid:
                    logging.info('Evaluating on Valid Dataset...')
                    valid_all_metrics, average_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict, 'Valid', step, writer)

                if args.do_test:
                    logging.info('Evaluating on Test Dataset...')
                    test_all_metrics, average_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step, writer)
                    avg_mrr = average_metrics['MRR']

                    if avg_mrr > best_mrr:
                        # BUGFIX: save_variable_list could be unbound on a
                        # resumed run (init_step > 0) before any periodic
                        # checkpoint had been written; build it here.
                        best_variables = {
                            'step': step,
                            'current_learning_rate': current_learning_rate,
                            'warm_up_steps': warm_up_steps
                        }
                        save_best_model(model, optimizer, best_variables, args)
                        best_mrr = avg_mrr

                    log_metrics('%s average' % 'Best_Test', step, average_metrics)

                    # Track the best value ever seen for every metric of
                    # every structure (excluding the query counter).
                    for query_structure in test_all_metrics.keys():
                        for metric in test_all_metrics[query_structure].keys():
                            if metric == 'num_queries':
                                continue
                            if test_all_metrics[query_structure][metric] > best_metrics[query_structure][metric]:
                                best_metrics[query_structure][metric] = test_all_metrics[query_structure][metric]

                    for query_structure in best_metrics.keys():
                        log_metrics('Best Test' + " " + query_name_dict[query_structure], step, best_metrics[query_structure])

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(model, optimizer, save_variable_list, args)
        # step is always bound here; the original bare try/except around
        # this print silently reset it to 0 and is removed.
        print(step)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        test_all_metrics, test_average_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step, writer)

    logging.info("Training finished!!")


if __name__ == '__main__':
    main(parse_args())

# ---------------------------------------------------------------------
# codes/models.py begins here in the original dump:
#   #!/usr/bin/python3
#   import logging
#   import numpy as np
#   import torch
#   import           <- this statement continues as "torch.nn as nn" on
#                       the next raw dump line, which is left untouched.
# ---------------------------------------------------------------------
torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch.utils.data import DataLoader 9 | from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator 10 | import collections 11 | from operators import * 12 | from tqdm import tqdm 13 | from util import * 14 | import pdb 15 | 16 | 17 | 18 | class NMP_QEModel(nn.Module): 19 | def __init__(self, 20 | nentity, 21 | nrelation, 22 | hidden_dim, 23 | gamma, 24 | geo, 25 | test_batch_size=1, 26 | gmm_mode=None, 27 | use_cuda=True, 28 | query_name_dict=None, 29 | dataset=None): 30 | super(NMP_QEModel, self).__init__() 31 | self.nentity = nentity 32 | self.nrelation = nrelation 33 | self.hidden_dim = hidden_dim 34 | self.epsilon = 2.0 35 | self.geo = geo 36 | self.use_cuda = use_cuda 37 | self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1).cuda() if self.use_cuda else torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) # used in test_step 38 | self.query_name_dict = query_name_dict 39 | 40 | self.gamma = nn.Parameter( 41 | torch.Tensor([gamma]), 42 | requires_grad=False 43 | ) 44 | 45 | self.embedding_range = nn.Parameter( 46 | torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), 47 | requires_grad=False 48 | ) 49 | 50 | self.entity_dim = hidden_dim 51 | self.relation_dim = hidden_dim 52 | self.projection_regularizer = Regularizer(1, 0.05, 1e9) 53 | 54 | 55 | if dataset == 'NELL': 56 | self.entity_embedding = nn.Parameter(torch.from_numpy(np.load('path of pretrained embedding')).float()) 57 | if dataset == 'wn18rr': 58 | self.entity_embedding = nn.Parameter(torch.from_numpy(np.load('path of pretrained embedding')).float()) 59 | if dataset == 'FB15k': 60 | self.entity_embedding = nn.Parameter(torch.from_numpy(np.load('path of pretrained embedding')).float()) 61 | activation, gmm_num, layers_num = gmm_mode 62 | 63 | self.gmm_num = gmm_num 64 | 65 | self.input_entity_embedding = nn.Parameter(torch.zeros(nentity, 2*gmm_num, 
hidden_dim+1)) 66 | self.init_input(gmm_num, nentity) 67 | 68 | self.projection_regularizer = Regularizer(1, 0.05, 1e9) 69 | 70 | self.projectionNN = RelationProjectionLayer(nrelation, input_dim=self.relation_dim + 1, 71 | output_dim=self.relation_dim, ngauss=gmm_num, 72 | projection_regularizer=self.projection_regularizer) 73 | 74 | self.AndNN = AndMLP(nguass=gmm_num, 75 | hidden_dim=hidden_dim, 76 | and_regularizer=self.projection_regularizer) 77 | 78 | self.notNN = NotMLP(n_layers=1, entity_dim=hidden_dim+1) 79 | # self.OrNN = OrMLP() 80 | 81 | def init_input(self, gmm_num, nentity): 82 | for i in range(nentity): 83 | for j in range(gmm_num*2): 84 | self.input_entity_embedding.data[i, j, :-1] = self.entity_embedding.data[i] 85 | if j < gmm_num: 86 | self.input_entity_embedding.data[i, j, -1] = 1 / gmm_num 87 | 88 | 89 | def forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 90 | # pdb.set_trace() 91 | all_center_embeddings, all_idxs = [], [] 92 | all_union_center_embeddings, all_union_idxs = [], [] 93 | for query_structure in batch_queries_dict: 94 | if 'u' in self.query_name_dict[query_structure]: 95 | center_embedding, _ = self.embed_query_gmm(self.transform_union_query(batch_queries_dict[query_structure], query_structure), 96 | self.transform_union_structure(query_structure), 0) 97 | all_union_center_embeddings.append(center_embedding) 98 | all_union_idxs.extend(batch_idxs_dict[query_structure]) 99 | else: 100 | center_embedding, _ = self.embed_query_gmm(batch_queries_dict[query_structure], query_structure, 0) 101 | all_center_embeddings.append(center_embedding) 102 | all_idxs.extend(batch_idxs_dict[query_structure]) 103 | 104 | if len(all_center_embeddings) > 0: 105 | all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1) 106 | if len(all_union_center_embeddings) > 0: 107 | all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1) 108 | 
all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0] // 2, 2, 1, self.gmm_num*2, -1) 109 | 110 | if type(subsampling_weight) != type(None): 111 | subsampling_weight = subsampling_weight[all_idxs + all_union_idxs] 112 | 113 | if type(positive_sample) != type(None): 114 | if len(all_center_embeddings) > 0: 115 | positive_sample_regular = positive_sample[all_idxs] 116 | positive_embedding = torch.index_select(self.input_entity_embedding, dim=0, 117 | index=positive_sample_regular).unsqueeze(1) 118 | positive_logit = self.cal_logit_gmm(positive_embedding, all_center_embeddings) 119 | else: 120 | positive_logit = torch.Tensor([]).to(self.input_entity_embedding.device) 121 | 122 | if len(all_union_center_embeddings) > 0: 123 | positive_sample_union = positive_sample[all_union_idxs] 124 | positive_embedding = torch.index_select(self.input_entity_embedding, dim=0, 125 | index=positive_sample_union).unsqueeze(1).unsqueeze(1) 126 | positive_union_logit = self.cal_logit_gmm(positive_embedding, all_union_center_embeddings) 127 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 128 | else: 129 | positive_union_logit = torch.Tensor([]).to(self.input_entity_embedding.device) 130 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 131 | else: 132 | positive_logit = None 133 | 134 | if type(negative_sample) != type(None): 135 | # pdb.set_trace() 136 | if len(all_center_embeddings) > 0: 137 | negative_sample_regular = negative_sample[all_idxs] 138 | batch_size, negative_size = negative_sample_regular.shape 139 | negative_embedding = torch.index_select(self.input_entity_embedding, dim=0, 140 | index=negative_sample_regular.view(-1)).view(batch_size, negative_size, self.gmm_num * 2, -1) 141 | # pdb.set_trace() 142 | negative_logit = self.cal_logit_gmm(negative_embedding, all_center_embeddings) 143 | else: 144 | negative_logit = torch.Tensor([]).to(self.input_entity_embedding.device) 145 | 146 | if 
len(all_union_center_embeddings) > 0: 147 | negative_sample_union = negative_sample[all_union_idxs] 148 | batch_size, negative_size = negative_sample_union.shape 149 | # pdb.set_trace() 150 | negative_embedding = torch.index_select(self.input_entity_embedding, dim=0, 151 | index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, self.gmm_num * 2, -1) # B, 1, neg_size, 2N, dim+1 152 | negative_embedding = negative_embedding.squeeze(0) 153 | all_union_center_embeddings = all_union_center_embeddings.squeeze(0) 154 | negative_union_logit = self.cal_logit_gmm(negative_embedding, all_union_center_embeddings) 155 | negative_union_logit = negative_union_logit.unsqueeze(0) 156 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 157 | else: 158 | negative_union_logit = torch.Tensor([]).to(self.input_entity_embedding.device) 159 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 160 | else: 161 | negative_logit = None 162 | 163 | return positive_logit, negative_logit, subsampling_weight, all_idxs + all_union_idxs 164 | 165 | 166 | ########################## for visual ######################################################## 167 | # def embed_query_gmm(self, queries, query_structure, idx): 168 | # ''' 169 | # Iterative embed a batch of queries with same structure using GMM 170 | # queries: a flattened batch of queries 171 | # ''' 172 | # # pdb.set_trace() 173 | # all_relation_flag = True 174 | # for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 175 | # if ele not in ['r', 'n']: 176 | # all_relation_flag = False 177 | # break 178 | # # pdb.set_trace() 179 | # if all_relation_flag: 180 | # if query_structure[0] == 'e': 181 | # embedding = torch.index_select(self.input_entity_embedding, dim=0, index=queries[:, idx]) 182 | # torch.save(embedding, './visual_data/entity_'+ str(idx) + 
'.pth') 183 | # idx += 1 184 | # else: 185 | # embedding, idx = self.embed_query_gmm(queries, query_structure[0], idx) 186 | # for i in range(len(query_structure[-1])): 187 | # if query_structure[-1][i] == 'n': 188 | # # assert False, "gaussian cannot handle queries with negation" 189 | # assert (queries[:, idx] == -2).all() 190 | # embedding = self.notNN(embedding) 191 | # else: 192 | # relation_id = queries[:, idx] 193 | # embedding = self.projectionNN(embedding, relation_id) 194 | # torch.save(embedding, './visual_data/p_entity_'+ str(idx) + '.pth') 195 | # idx += 1 196 | # else: 197 | # # queries: 5 * 6, query_structure: (('e', ('r',)), ('e', ('r',)), ('e', ('r',))) embedding_list&offset_embedding_list len: 3 内的元素: 5 * 6400 198 | # embedding_list = [] 199 | # for i in range(len(query_structure)): 200 | # embedding, idx = self.embed_query_gmm(queries, query_structure[i], idx) 201 | # embedding_list.append(embedding) 202 | 203 | # # pdb.set_trace() 204 | # vector = embedding_list[0] 205 | # for i in range(1, len(embedding_list)): 206 | # vector = self.AndNN(vector, embedding_list[i]) 207 | # torch.save(vector, './visual_data/insert_entity.pth') 208 | # embedding = vector 209 | 210 | # return embedding, idx 211 | 212 | #################################################################################### 213 | 214 | 215 | def embed_query_gmm(self, queries, query_structure, idx): 216 | ''' 217 | Iterative embed a batch of queries with same structure using GMM 218 | queries: a flattened batch of queries 219 | ''' 220 | all_relation_flag = True 221 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 222 | if ele not in ['r', 'n']: 223 | all_relation_flag = False 224 | break 225 | # pdb.set_trace() 226 | if all_relation_flag: 227 | if query_structure[0] == 'e': 228 | embedding = torch.index_select(self.input_entity_embedding, 
dim=0, index=queries[:, idx]) 229 | idx += 1 230 | else: 231 | embedding, idx = self.embed_query_gmm(queries, query_structure[0], idx) 232 | for i in range(len(query_structure[-1])): 233 | if query_structure[-1][i] == 'n': 234 | # assert False, "gaussian cannot handle queries with negation" 235 | assert (queries[:, idx] == -2).all() 236 | embedding = self.notNN(embedding) 237 | else: 238 | relation_id = queries[:, idx] 239 | embedding = self.projectionNN(embedding, relation_id) 240 | idx += 1 241 | else: 242 | # queries: 5 * 6, query_structure: (('e', ('r',)), ('e', ('r',)), ('e', ('r',))) embedding_list&offset_embedding_list len 243 | embedding_list = [] 244 | for i in range(len(query_structure)): 245 | embedding, idx = self.embed_query_gmm(queries, query_structure[i], idx) 246 | embedding_list.append(embedding) 247 | 248 | # pdb.set_trace() 249 | vector = embedding_list[0] 250 | for i in range(1, len(embedding_list)): 251 | vector = self.AndNN(vector, embedding_list[i]) 252 | embedding = vector 253 | 254 | return embedding, idx 255 | 256 | 257 | def transform_union_query(self, queries, query_structure): 258 | ''' 259 | transform 2u queries to two 1p queries 260 | transform up queries to two 2p queries 261 | ''' 262 | if self.query_name_dict[query_structure] == '2u-DNF': 263 | queries = queries[:, :-1] # remove union -1 264 | elif self.query_name_dict[query_structure] == 'up-DNF': 265 | queries = torch.cat([torch.cat([queries[:, :2], queries[:, 5:6]], dim=1), torch.cat([queries[:, 2:4], queries[:, 5:6]], dim=1)], dim=1) 266 | queries = torch.reshape(queries, [queries.shape[0]*2, -1]) 267 | return queries 268 | 269 | def transform_union_structure(self, query_structure): 270 | if self.query_name_dict[query_structure] == '2u-DNF': 271 | return ('e', ('r',)) 272 | elif self.query_name_dict[query_structure] == 'up-DNF': 273 | return ('e', ('r', 'r')) 274 | 275 | 276 | def cal_logit_gmm(self, entity_embedding, query_embedding): 277 | one_entity_embedding = 
entity_embedding[:, :, 0, :-1].unsqueeze(-2) 278 | query_embedding_gauss_prob = query_embedding[:, :, :self.gmm_num, -1]# B, 1(neg_size), N 279 | query_embedding_mu = query_embedding[:, :, :self.gmm_num, :-1] 280 | query_embedding_sigma = query_embedding[:, :, self.gmm_num:, :-1] 281 | weighted_sigma = torch.matmul(query_embedding_gauss_prob.unsqueeze(-2), query_embedding_sigma) 282 | weighted_mu = torch.matmul(query_embedding_gauss_prob.unsqueeze(-2), query_embedding_mu) # B, 1(neg_size), 1, N * B, 1(neg_size), N, dim => B, 1(neg_size), 1, dim 283 | distance = one_entity_embedding - weighted_mu # B, 1(neg_size), 1, dim 284 | logit = self.gamma - torch.norm(distance, p=1, dim=-1).squeeze(-1) 285 | return logit 286 | 287 | 288 | @staticmethod 289 | def train_step(model, optimizer, train_iterator, args, step): 290 | model.train() 291 | optimizer.zero_grad() 292 | 293 | positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator) # torch.Size([512]) batch_queries:list [[8057, 81, 96, 30], ..] query_structures:list [('e', ('r', 'r', 'r')), ...] 
294 | # pdb.set_trace() 295 | batch_queries_dict = collections.defaultdict(list) 296 | batch_idxs_dict = collections.defaultdict(list) 297 | for i, query in enumerate(batch_queries): # group queries with same structure 298 | batch_queries_dict[query_structures[i]].append(query) 299 | batch_idxs_dict[query_structures[i]].append(i) 300 | for query_structure in batch_queries_dict: 301 | if args.cuda: 302 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda() 303 | else: 304 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]) 305 | if args.cuda: 306 | positive_sample = positive_sample.cuda() 307 | negative_sample = negative_sample.cuda() 308 | subsampling_weight = subsampling_weight.cuda() 309 | 310 | #with autograd.detect_anomaly(): 311 | if 1==1: 312 | positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 313 | # pdb.set_trace() 314 | asd = F.logsigmoid 315 | negative_score = asd(-negative_logit).mean(dim=1) 316 | positive_score = asd(positive_logit).squeeze(dim=1) 317 | positive_sample_loss = - (subsampling_weight * positive_score).sum() 318 | negative_sample_loss = - (subsampling_weight * negative_score).sum() 319 | positive_sample_loss /= subsampling_weight.sum() 320 | negative_sample_loss /= subsampling_weight.sum() 321 | 322 | loss = (positive_sample_loss + negative_sample_loss)/2 323 | loss.backward() 324 | pdb.set_trace() 325 | # for name, paras in model.named_parameters(): 326 | # print('-->name', name, '-->grad_required', paras.requires_grad, '-->grad_value', paras.grad) 327 | # pdb.set_trace() 328 | optimizer.step() 329 | log = { 330 | 'positive_sample_loss': positive_sample_loss.item(), 331 | 'negative_sample_loss': negative_sample_loss.item(), 332 | 'loss': loss.item(), 333 | } 334 | return log 335 | 336 | @staticmethod 337 | def test_step(model, easy_answers, 
hard_answers, args, test_dataloader, query_name_dict, save_result=False, save_str="", save_empty=False): 338 | model.eval() 339 | 340 | step = 0 341 | total_steps = len(test_dataloader) 342 | logs = collections.defaultdict(list) 343 | 344 | with torch.no_grad(): 345 | for negative_sample, queries, queries_unflatten, query_structures in tqdm(test_dataloader, disable=not args.print_on_screen): 346 | batch_queries_dict = collections.defaultdict(list) 347 | batch_idxs_dict = collections.defaultdict(list) 348 | for i, query in enumerate(queries): 349 | batch_queries_dict[query_structures[i]].append(query) 350 | batch_idxs_dict[query_structures[i]].append(i) 351 | for query_structure in batch_queries_dict: 352 | if args.cuda: 353 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda() 354 | else: 355 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]) 356 | if args.cuda: 357 | negative_sample = negative_sample.cuda() 358 | 359 | # pdb.set_trace() 360 | _, negative_logit, _, idxs = model(None, negative_sample, None, batch_queries_dict, batch_idxs_dict) 361 | # pdb.set_trace() 362 | queries_unflatten = [queries_unflatten[i] for i in idxs] 363 | query_structures = [query_structures[i] for i in idxs] 364 | argsort = torch.argsort(negative_logit, dim=1, descending=True) 365 | ranking = argsort.clone().to(torch.float) # [14505,] i.e.[[ 6398., 268., 3127., ..., 14216., 14504., 14215.]] 366 | if len(argsort) == args.test_batch_size: 367 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 368 | else: 369 | if args.cuda: 370 | ranking = ranking.scatter_(1, 371 | argsort, 372 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 373 | 1).cuda() 374 | ) 375 | else: 376 | ranking = ranking.scatter_(1, 377 | argsort, 378 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 379 | 1) 380 | ) # achieve the ranking 
of all entities 381 | for idx, (i, query, query_structure) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structures)): 382 | hard_answer = hard_answers[query] # set 383 | easy_answer = easy_answers[query] # set 384 | num_hard = len(hard_answer) 385 | num_easy = len(easy_answer) 386 | assert len(hard_answer.intersection(easy_answer)) == 0 387 | cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)] # 388 | cur_ranking, indices = torch.sort(cur_ranking) 389 | masks = indices >= num_easy 390 | if args.cuda: 391 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 392 | else: 393 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 394 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 395 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 396 | 397 | mrr = torch.mean(1./cur_ranking).item() 398 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 399 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 400 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 401 | 402 | logs[query_structure].append({ 403 | 'MRR': mrr, 404 | 'HITS1': h1, 405 | 'HITS3': h3, 406 | 'HITS10': h10, 407 | 'num_hard_answer': num_hard, 408 | }) 409 | 410 | if step % args.test_log_steps == 0: 411 | logging.info('Evaluating the model... 
(%d/%d)' % (step, total_steps)) 412 | 413 | step += 1 414 | 415 | metrics = collections.defaultdict(lambda: collections.defaultdict(int)) 416 | for query_structure in logs: 417 | for metric in logs[query_structure][0].keys(): 418 | if metric in ['num_hard_answer']: 419 | continue 420 | metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure]) 421 | metrics[query_structure]['num_queries'] = len(logs[query_structure]) 422 | 423 | return metrics 424 | -------------------------------------------------------------------------------- /codes/operators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | import math 8 | from util import Regularizer 9 | import pdb 10 | 11 | 12 | class RelationProjectionMLP(nn.Module): 13 | """ 14 | define the multi-layer MLP by RelationProjectionLayer 15 | """ 16 | def __init__(self): 17 | super().__init__() 18 | 19 | 20 | class RelationProjectionLayer(nn.Module): 21 | """ 22 | define three NN for w_i mu_i sigma_i 23 | """ 24 | def __init__(self, nrelation, input_dim, output_dim, ngauss, projection_regularizer, bias=True): 25 | super().__init__() 26 | self.relation_num = nrelation 27 | self.ngauss = ngauss 28 | self.input_dim = input_dim 29 | self.output_dim = output_dim 30 | self.projection_regularizer = projection_regularizer 31 | self.W_mlp = RelationMlp(self.relation_num, input_dim, output_dim=1, ngauss=ngauss*2) 32 | self.Mu_mlp = RelationMlp(self.relation_num, input_dim, output_dim, ngauss=ngauss*2) 33 | self.Sigma_mlp = RelationMlp(self.relation_num, input_dim, output_dim, ngauss=ngauss*2) 34 | 35 | def forward(self, input_embedding, project_relation): 36 | x = input_embedding 37 | W_mlp_weight, W_mlp_bias = self.W_mlp(project_relation) 38 | Mu_mlp_weight, 
Mu_mlp_bias = self.Mu_mlp(project_relation) 39 | Sigma_mlp_weight, Sigma_mlp_bias = self.Sigma_mlp(project_relation) 40 | # B, 2N, dim+1 * B, dim+1, 1 => B, 2N, 1 => B, N=>B, N, 1 41 | output = F.relu(torch.matmul(x, W_mlp_weight) + W_mlp_bias).squeeze(-1) 42 | W_output = output[:, :self.ngauss] 43 | W_output = F.softmax(W_output, dim=1) 44 | W_output = W_output.unsqueeze(dim=-1) 45 | 46 | output = F.relu(torch.matmul(x, Mu_mlp_weight) + Mu_mlp_bias) 47 | Mu_output = output[:, :self.ngauss, :] 48 | 49 | output = F.relu(torch.matmul(x, Sigma_mlp_weight) + Sigma_mlp_bias) 50 | Sigma_output = output[:, self.ngauss:, :] 51 | 52 | mask_embedding = torch.zeros_like(W_output, requires_grad=True) 53 | output_embedding1 = torch.cat((Mu_output, W_output), dim=-1) 54 | output_embedding2 = torch.cat((Sigma_output, mask_embedding), dim=-1) 55 | output_embedding = torch.cat((output_embedding1, output_embedding2), dim=1) 56 | 57 | return output_embedding 58 | 59 | 60 | class RelationMlp(nn.Module): 61 | def __init__(self, nrelation, input_dim, output_dim, ngauss, bias=True): 62 | super().__init__() 63 | self.mlp_weight = Parameter(torch.Tensor(nrelation, input_dim, output_dim), requires_grad=True) 64 | if bias: 65 | self.mlp_bias = Parameter(torch.Tensor(nrelation, ngauss, output_dim), requires_grad=True) 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | init.kaiming_uniform_(self.mlp_weight, a=math.sqrt(5)) 70 | if self.mlp_bias is not None: 71 | fan_in, _ = init._calculate_fan_in_and_fan_out(self.mlp_weight) 72 | bound = 1 / math.sqrt(fan_in) 73 | init.uniform_(self.mlp_bias, -bound, bound) 74 | 75 | def forward(self, relation): 76 | return self.mlp_weight[relation], self.mlp_bias[relation] 77 | 78 | 79 | class AndMLP(nn.Module): 80 | def __init__(self, nguass, hidden_dim, and_regularizer): 81 | # hidden_dim: 2(dim+1) * 2(dim+1) 82 | super(AndMLP, self).__init__() 83 | self.nguass = nguass 84 | self.liner_dim = (hidden_dim + 1) * 2 85 | self.query = 
nn.Linear(self.liner_dim, self.liner_dim) 86 | self.key = nn.Linear(self.liner_dim, self.liner_dim) 87 | self.value = nn.Linear(self.liner_dim, self.liner_dim) 88 | 89 | self.mlp = nn.Linear(self.liner_dim, self.liner_dim) 90 | self.and_regularizer = and_regularizer 91 | 92 | def forward(self, embedding1, embedding2): 93 | # pdb.set_trace() 94 | trans_embedding1 = torch.cat((embedding1[:, :self.nguass, :], embedding1[:, self.nguass:, :]), dim=-1) 95 | trans_embedding2 = torch.cat((embedding2[:, :self.nguass, :], embedding2[:, self.nguass:, :]), dim=-1) 96 | q1 = self.query(trans_embedding1) 97 | k1 = self.key(trans_embedding1) 98 | v1 = self.value(trans_embedding1) 99 | 100 | q2 = self.query(trans_embedding2) 101 | k2 = self.key(trans_embedding2) 102 | v2 = self.value(trans_embedding2) 103 | 104 | d = trans_embedding1.shape[-1] 105 | attention_scores_1to2 = torch.matmul(q1, k2.transpose(-2, -1)) 106 | attention_scores_1to2 = attention_scores_1to2 / math.sqrt(d) 107 | attention_probs_1to2 = F.softmax(attention_scores_1to2, dim=-1) 108 | 109 | attention_scores_2to1 = torch.matmul(q2, k1.transpose(-2, -1)) 110 | attention_scores_2to1 = attention_scores_2to1 / math.sqrt(d) 111 | attention_probs_2to1 = F.softmax(attention_scores_2to1, dim=-1) 112 | 113 | out_1to2 = torch.matmul(attention_probs_1to2, v2) 114 | out_2to1 = torch.matmul(attention_probs_2to1, v1) 115 | 116 | out = out_1to2 + out_2to1 117 | out = F.relu(self.mlp(out)) 118 | 119 | W_out = out[:, :, int(d/2-1)] 120 | W_out = F.softmax(W_out, dim=1) 121 | W_out = W_out.unsqueeze(dim=-1) 122 | 123 | Mu_out = out[:, :, :int(d/2-1)] 124 | Sigma_out = out[:, :, int(d/2): -1] 125 | 126 | mask_embedding = torch.zeros_like(W_out, requires_grad=True) 127 | embedding_temp1 = torch.cat((Mu_out, W_out), dim=-1) 128 | embedding_temp2 = torch.cat((Sigma_out, mask_embedding), dim=-1) 129 | output_embedding = torch.cat((embedding_temp1, embedding_temp2), dim=1) 130 | 131 | return output_embedding 132 | 133 | 134 | class 
OrMLP(nn.Module): 135 | def __init__(self, n_layers, entity_dim): 136 | super(OrMLP, self).__init__() 137 | self.n_layers = n_layers 138 | self.layers = [] 139 | for i in range(1, self.n_layers + 1): 140 | setattr(self, "or_layer_{}".format(i), nn.Linear(2 * entity_dim, 2 * entity_dim)) 141 | self.last_layer = nn.Linear(2 * entity_dim, entity_dim) 142 | 143 | def forward(self, x1, x2): 144 | x = torch.cat((x1, x2), dim=-1) 145 | for i in range(1, self.n_layers + 1): 146 | x = F.relu(getattr(self, "or_layer_{}".format(i))(x)) 147 | x = self.last_layer(x) 148 | return x 149 | 150 | 151 | class NotMLP(nn.Module): 152 | def __init__(self, n_layers, entity_dim): 153 | super(NotMLP, self).__init__() 154 | self.n_layers = n_layers 155 | self.layers = [] 156 | for i in range(1, self.n_layers + 1): 157 | setattr(self, "not_layer_{}".format(i), nn.Linear(entity_dim, entity_dim)) 158 | self.last_layer = nn.Linear(entity_dim, entity_dim) 159 | 160 | def forward(self, x): 161 | for i in range(1, self.n_layers + 1): 162 | x = F.relu(getattr(self, "not_layer_{}".format(i))(x)) 163 | x = self.last_layer(x) 164 | return x -------------------------------------------------------------------------------- /codes/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import time 5 | 6 | def list2tuple(l): 7 | return tuple(list2tuple(x) if type(x)==list else x for x in l) 8 | 9 | def tuple2list(t): 10 | return list(tuple2list(x) if type(x)==tuple else x for x in t) 11 | 12 | flatten=lambda l: sum(map(flatten, l),[]) if isinstance(l,tuple) else [l] 13 | 14 | def parse_time(): 15 | return time.strftime("%Y.%m.%d-%H:%M:%S", time.localtime()) 16 | 17 | def set_global_seed(seed): 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | np.random.seed(seed) 21 | random.seed(seed) 22 | torch.backends.cudnn.deterministic=True 23 | 24 | def eval_tuple(arg_return): 25 | """Evaluate a tuple 
string into a tuple.""" 26 | if type(arg_return) == tuple: 27 | return arg_return 28 | if arg_return[0] not in ["(", "["]: 29 | arg_return = eval(arg_return) 30 | else: 31 | splitted = arg_return[1:-1].split(",") 32 | List = [] 33 | for item in splitted: 34 | try: 35 | item = eval(item) 36 | except: 37 | pass 38 | if item == "": 39 | continue 40 | List.append(item) 41 | arg_return = tuple(List) 42 | return arg_return 43 | 44 | def flatten_query(queries): 45 | all_queries = [] 46 | for query_structure in queries: 47 | tmp_queries = list(queries[query_structure]) 48 | all_queries.extend([(query, query_structure) for query in tmp_queries]) 49 | return all_queries 50 | 51 | 52 | class Regularizer(): 53 | def __init__(self, base_add, min_val, max_val): 54 | self.base_add = base_add 55 | self.min_val = min_val 56 | self.max_val = max_val 57 | 58 | def __call__(self, entity_embedding): 59 | return torch.clamp(entity_embedding + self.base_add, self.min_val, self.max_val) --------------------------------------------------------------------------------