├── README.md ├── constants.py ├── dataloader.py ├── fuzzyreasoning.py ├── gumbel.py ├── investigation_helper.py ├── main.py ├── models.py ├── operations.py ├── regularizers.py ├── requirements.txt ├── run.sh ├── test-pretrained-model.py └── util.py /README.md: -------------------------------------------------------------------------------- 1 | Resources and code for paper "Fuzzy Logic based Logical Query Answering on Knowledge Graphs". 2 | 3 | 4 | ## Environment 5 | Make sure your local environment has the following installed: 6 | 7 | Python3.9 8 | torch == 1.9.0 9 | wandb == 0.9.7 10 | 11 | 12 | Install the dependency using: 13 | 14 | pip install -r requirements.txt 15 | 16 | 17 | ## Download data 18 | 19 | Download data from [here](http://snap.stanford.edu/betae/KG_data.zip) and put it under `data` folder. 20 | 21 | The directory structure should be like `[PROJECT_DIR]/data/NELL-betae/train-queries.pkl`. 22 | 23 | 24 | Only FB15k-237 and NELL995 are used in our study. 25 | 26 | 27 | ## Train 28 | Training script example: `./run.sh` 29 | 30 | It usually takes 4 days to a week to finish a run on a NVIDIA® GP102 TITAN Xp (12GB) GPU. 31 | 32 | 33 | 34 | *TODO: More training scripts for easy training will be added soon.* 35 | 36 | 37 | 38 | ## Test 39 | 40 | The trained model will be automatically stored under the folder `./trained_models`. The model name will be `[WANDB_RUN_NAME].pt`. 41 | 42 | To test a trained model, you can use the following command: 43 | 44 | python ./test-pretrained-model.py [DATA_NAME] [WANDB_RUN_NAME] 45 | 46 | By default, the test tests for product logic. You can also test for other logic systems ('godel' or 'luka') by modifying the `logic` variable in the script. 47 | 48 | 49 | ### Test the pretrained model 50 | 51 | The pretrained FuzzQE model (product logic) for NELL can be downloaded [here](https://drive.google.com/file/d/15ByNcDayg5Vw67SaIk9ZPE3Gfa9tlTmo/view?usp=sharing). You can put it under `./trained_models` and use the following command to test it: 52 | 53 | python ./test-pretrained-model.py NELL feasible-resonance-1518 54 | 55 | 56 | *TODO: More pretrained models will be uploaded soon.* 57 | 58 | 59 | 60 | ## Reference 61 | Please refer to our paper if you find the resources useful. 62 | 63 | Xuelu Chen, Ziniu Hu, Yizhou Sun. Fuzzy Logic based Logical Query Answering on Knowledge Graphs. *Proceedings of the Thirty-sixth AAAI Conference on Artificial Intelligence (AAAI), 2022.* 64 | 65 | 66 | 67 | @inproceedings{chen2021fuzzyqa, 68 | title={Fuzzy Logic based Logical Query Answering on Knowledge Graphs}, 69 | author={Chen, Xuelu and Hu, Ziniu and Sun, Yizhou} 70 | booktitle={Proceedings of the Thirty-sixth AAAI Conference on Artificial Intelligence (AAAI)}, 71 | year={2022} 72 | } 73 | 74 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | 2 | query_name_dict = { 3 | ('e',('r',)): '1p', 4 | ('e', ('r', 'r')): '2p', 5 | ('e', ('r', 'r', 'r')): '3p', 6 | (('e', ('r',)), ('e', ('r',))): '2i', 7 | (('e', ('r',)), ('e', ('r',)), ('e', ('r',))): '3i', 8 | ((('e', ('r',)), ('e', ('r',))), ('r',)): 'ip', 9 | (('e', ('r', 'r')), ('e', ('r',))): 'pi', 10 | (('e', ('r',)), ('e', ('r', 'n'))): '2in', 11 | (('e', ('r',)), ('e', ('r',)), ('e', ('r', 'n'))): '3in', 12 | ((('e', ('r',)), ('e', ('r', 'n'))), ('r',)): 'inp', 13 | (('e', ('r', 'r')), ('e', ('r', 'n'))): 'pin', 14 | (('e', ('r', 'r', 'n')), ('e', ('r',))): 'pni', 15 | (('e', ('r',)), ('e', ('r',)), ('u',)): '2u-DNF', 16 | ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)): 'up-DNF', 17 | ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n',)): '2u-DM', 18 | ((('e', ('r', 'n')), ('e', ('r', 'n'))), ('n', 'r')): 'up-DM' 19 | } 20 | query_structure_list = list(query_name_dict.keys()) # query_structure_list[0] -> query_structure of index 0 21 | query_structure2idx = {s: i for i, s in enumerate(query_structure_list)} # {('e',('r',)):0} 22 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import numpy as np 8 | import torch 9 | import time 10 | import pickle 11 | import os 12 | from torch.utils.data import DataLoader 13 | from torch.utils.data import Dataset 14 | from util import list2tuple, tuple2list, flatten, flatten_query_and_convert_structure_to_idx 15 | from collections import defaultdict 16 | from constants import query_structure2idx 17 | 18 | 19 | class TestDataset(Dataset): 20 | def __init__(self, queries, nentity, nrelation): 21 | """ 22 | :param queries: list[(query, query_structure_idx)] 23 | """ 24 | self.len = len(queries) 25 | self.queries = queries 26 | self.nentity = nentity 27 | self.nrelation = nrelation 28 | 29 | def __len__(self): 30 | return self.len 31 | 32 | def __getitem__(self, idx): 33 | query = self.queries[idx][0] 34 | query_structure_idx = self.queries[idx][1] 35 | negative_sample = torch.LongTensor(range(self.nentity)) 36 | return negative_sample, flatten(query), query, query_structure_idx 37 | 38 | @staticmethod 39 | def collate_fn(data): 40 | negative_sample = torch.stack([_[0] for _ in data], dim=0) 41 | query = np.array([_[1] for _ in data]) 42 | query_unflatten = [_[2] for _ in data] # don't make it np.array. keep it as list of tuples 43 | query_structure_idx = np.array([_[3] for _ in data]) 44 | return negative_sample, query, query_unflatten, query_structure_idx 45 | 46 | 47 | class TrainDataset(Dataset): 48 | def __init__(self, queries, nentity, nrelation, negative_sample_size, answer): 49 | """ 50 | :param queries: list[(query, query_structure_idx)] 51 | """ 52 | self.len = len(queries) 53 | self.queries = queries 54 | self.nentity = nentity 55 | self.nrelation = nrelation 56 | self.negative_sample_size = negative_sample_size 57 | self.count = self.count_frequency(queries, answer) 58 | self.answer = answer 59 | 60 | def __len__(self): 61 | return self.len 62 | 63 | def __getitem__(self, idx): 64 | query = self.queries[idx][0] 65 | query_structure_idx = self.queries[idx][1] 66 | tail = np.random.choice(list(self.answer[query])) 67 | subsampling_weight = self.count[query] 68 | subsampling_weight = torch.sqrt(1 / torch.Tensor([subsampling_weight])) 69 | 70 | # negative_sample_list = [] 71 | # negative_sample_size = 0 72 | # while negative_sample_size < self.negative_sample_size: 73 | # negative_sample = np.random.randint(self.nentity, size=self.negative_sample_size*2) 74 | # mask = np.in1d( 75 | # negative_sample, 76 | # self.answer[query], 77 | # assume_unique=True, 78 | # invert=True 79 | # ) 80 | # negative_sample = negative_sample[mask] 81 | # negative_sample_list.append(negative_sample) 82 | # negative_sample_size += negative_sample.size 83 | # negative_sample = np.concatenate(negative_sample_list)[:self.negative_sample_size] 84 | # negative_sample = torch.from_numpy(negative_sample).type(torch.LongTensor) 85 | 86 | # the above sampling is too slow but not significant performance gain 87 | # Shirley 88 | negative_sample = torch.randint(self.nentity, (self.negative_sample_size,)) 89 | positive_sample = torch.LongTensor([tail]) 90 | return positive_sample, negative_sample, subsampling_weight, flatten(query), query_structure_idx 91 | 92 | @staticmethod 93 | def collate_fn(data): 94 | positive_sample = torch.cat([_[0] for _ in data], dim=0) 95 | negative_sample = torch.stack([_[1] for _ in data], dim=0) 96 | subsample_weight = torch.cat([_[2] for _ in data], dim=0) 97 | query = np.array([_[3] for _ in data]) # can't convert to tensor due to the varying length 98 | query_structure_idx = np.array([_[4] for _ in data]) 99 | return positive_sample, negative_sample, subsample_weight, query, query_structure_idx 100 | 101 | @staticmethod 102 | def count_frequency(queries, answer, start=4): 103 | count = {} 104 | for query, qtype in queries: 105 | count[query] = start + len(answer[query]) 106 | return count 107 | 108 | 109 | class SingledirectionalOneShotIterator(object): 110 | def __init__(self, dataloader): 111 | self.iterator = self.one_shot_iterator(dataloader) 112 | self.step = 0 113 | 114 | def __next__(self): 115 | self.step += 1 116 | data = next(self.iterator) 117 | return data 118 | 119 | @staticmethod 120 | def one_shot_iterator(dataloader): 121 | while True: 122 | for data in dataloader: 123 | yield data 124 | 125 | 126 | def filter_by_tasks(queries, name_query_dict, tasks, evaluate_union): 127 | """ 128 | remove queries not in tasks. 129 | """ 130 | all_query_names = set(name_query_dict.keys()) 131 | for name in all_query_names: 132 | if 'u' in name: 133 | name, name_evaluate_union = name.split('-') 134 | else: 135 | name_evaluate_union = evaluate_union 136 | if name not in tasks or name_evaluate_union != evaluate_union: 137 | query_structure = name_query_dict[name if 'u' not in name else '-'.join([name, name_evaluate_union])] 138 | if query_structure in queries: 139 | del queries[query_structure] 140 | return queries 141 | 142 | 143 | def load_data_from_pickle(args, query_name_dict, tasks): 144 | """ 145 | Load queries/answers and remove queries not in tasks. 146 | To save time, only load corresponding queries and answers 147 | when flags like args.do_train, args.do_valid, args.do_test are True. 148 | Otherwise return None. 149 | :param query_name_dict: all possible query types across models. type dict{query_str:query_name} 150 | :param tasks: task to use 151 | """ 152 | # To save time, only load corresponding queries and answers 153 | # when flags like args.do_train, args.do_valid, args.do_test are True. 154 | # Otherwise return None. 155 | train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, \ 156 | test_queries, test_hard_answers, test_easy_answers = None, None, None, None, None, None, None, None 157 | 158 | print("loading data") 159 | time0 = time.time() 160 | 161 | name_query_dict = {value: key for key, value in query_name_dict.items()} # {'1p': ('e',('r',)), ...} 162 | 163 | if args.do_train: 164 | train_queries = pickle.load(open(os.path.join(args.data_path, "train-queries.pkl"), 'rb')) 165 | train_answers = pickle.load(open(os.path.join(args.data_path, "train-answers.pkl"), 'rb')) 166 | train_queries = filter_by_tasks(train_queries, name_query_dict, tasks, args.evaluate_union) 167 | if args.do_valid: 168 | valid_queries = pickle.load(open(os.path.join(args.data_path, "valid-queries.pkl"), 'rb')) 169 | valid_hard_answers = pickle.load(open(os.path.join(args.data_path, "valid-hard-answers.pkl"), 'rb')) 170 | valid_easy_answers = pickle.load(open(os.path.join(args.data_path, "valid-easy-answers.pkl"), 'rb')) 171 | valid_queries = filter_by_tasks(valid_queries, name_query_dict, tasks, args.evaluate_union) 172 | if args.do_test: 173 | test_queries = pickle.load(open(os.path.join(args.data_path, "test-queries.pkl"), 'rb')) 174 | test_hard_answers = pickle.load(open(os.path.join(args.data_path, "test-hard-answers.pkl"), 'rb')) 175 | test_easy_answers = pickle.load(open(os.path.join(args.data_path, "test-easy-answers.pkl"), 'rb')) 176 | test_queries = filter_by_tasks(test_queries, name_query_dict, tasks, args.evaluate_union) 177 | 178 | print(f'Loading data uses time: {time.time()-time0}') 179 | 180 | return train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, \ 181 | test_queries, test_hard_answers, test_easy_answers 182 | 183 | 184 | def load_data(args, query_name_dict, tasks): 185 | # only generate it when necessary 186 | train_path_iterator, train_other_iterator, valid_dataloader, test_dataloader = None, None, None, None 187 | 188 | train_queries, train_answers, valid_queries, valid_hard_answers, valid_easy_answers, \ 189 | test_queries, test_hard_answers, test_easy_answers = load_data_from_pickle(args, query_name_dict, tasks) 190 | 191 | if args.do_train: 192 | print('Training query info:') 193 | for query_structure in train_queries: 194 | print(query_name_dict[query_structure] + ": " + str(len(train_queries[query_structure]))) 195 | train_path_queries = defaultdict(set) 196 | train_other_queries = defaultdict(set) 197 | path_list = ['1p', '2p', '3p'] 198 | for query_structure in train_queries: 199 | if query_name_dict[query_structure] in path_list: 200 | train_path_queries[query_structure] = train_queries[query_structure] 201 | else: 202 | train_other_queries[query_structure] = train_queries[query_structure] 203 | train_path_queries = flatten_query_and_convert_structure_to_idx(train_path_queries, query_structure2idx) 204 | train_path_iterator = SingledirectionalOneShotIterator(DataLoader( 205 | TrainDataset(train_path_queries, args.nentity, args.nrelation, args.negative_sample_size, train_answers), 206 | batch_size=args.batch_size, 207 | shuffle=True, 208 | num_workers=args.cpu_num, 209 | collate_fn=TrainDataset.collate_fn 210 | )) 211 | if len(train_other_queries) > 0: 212 | train_other_queries = flatten_query_and_convert_structure_to_idx(train_other_queries, query_structure2idx) 213 | train_other_iterator = SingledirectionalOneShotIterator(DataLoader( 214 | TrainDataset(train_other_queries, args.nentity, args.nrelation, args.negative_sample_size, train_answers), 215 | batch_size=args.batch_size, 216 | shuffle=True, 217 | num_workers=args.cpu_num, 218 | collate_fn=TrainDataset.collate_fn 219 | )) 220 | else: 221 | train_other_iterator = None 222 | 223 | if args.do_valid: 224 | print('Validation query info:') 225 | for query_structure in valid_queries: 226 | print(query_name_dict[query_structure] + ": " + str(len(valid_queries[query_structure]))) 227 | valid_queries = flatten_query_and_convert_structure_to_idx(valid_queries, query_structure2idx) 228 | valid_dataloader = DataLoader( 229 | TestDataset( 230 | valid_queries, 231 | args.nentity, 232 | args.nrelation, 233 | ), 234 | batch_size=args.test_batch_size, 235 | num_workers=args.cpu_num, 236 | collate_fn=TestDataset.collate_fn 237 | ) 238 | 239 | if args.do_test: 240 | print('Test query info:') 241 | for query_structure in test_queries: 242 | print(query_name_dict[query_structure] + ": " + str(len(test_queries[query_structure]))) 243 | test_queries = flatten_query_and_convert_structure_to_idx(test_queries, query_structure2idx) 244 | test_dataloader = DataLoader( 245 | TestDataset( 246 | test_queries, 247 | args.nentity, 248 | args.nrelation, 249 | ), 250 | batch_size=args.test_batch_size, 251 | num_workers=args.cpu_num, 252 | collate_fn=TestDataset.collate_fn 253 | ) 254 | 255 | return train_path_iterator, train_other_iterator, valid_dataloader, test_dataloader, \ 256 | valid_hard_answers, valid_easy_answers, test_hard_answers, test_easy_answers # keep answers for evaluation 257 | -------------------------------------------------------------------------------- /fuzzyreasoning.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from models import * 4 | import wandb 5 | from constants import query_structure_list, query_structure2idx 6 | from util import get_regularizer 7 | from operations import Projection, Conjunction, Disjunction, Negation 8 | from gumbel import gumbel_softmax 9 | import torch.nn.functional as F 10 | 11 | class KGFuzzyReasoning(KGReasoning): 12 | def __init__( 13 | self, nentity, nrelation, hidden_dim, gamma, 14 | geo, test_batch_size=1, 15 | box_mode=None, use_cuda=False, 16 | query_name_dict=None, beta_mode=None, 17 | logic_type='product', 18 | regularizer_setting=None, 19 | gamma_coff=20, 20 | loss_type='cos', 21 | margin_type='logsigmoid', 22 | device=None, 23 | godel_gumbel_beta=0.01, 24 | gumbel_temperature=1, 25 | projection_type='mlp', 26 | args=None 27 | 28 | ): 29 | super(KGFuzzyReasoning, self).__init__(nentity, nrelation, hidden_dim, gamma, 30 | geo, test_batch_size, 31 | box_mode, use_cuda, 32 | query_name_dict, beta_mode) 33 | 34 | self.device = device 35 | 36 | # embedding 37 | self.hidden_dim = hidden_dim 38 | self.epsilon = 2.0 39 | 40 | self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size,1).to(self.device) 41 | 42 | self.entity_dim = hidden_dim 43 | 44 | self.no_anchor_reg = args.no_anchor_reg 45 | 46 | 47 | if args.load_pretrained == True: 48 | with open('./trained_models/NELL-entity-emb.pt', 'rb') as f: 49 | # use pretrained embeddings to initialize and speed up training 50 | entity_embs = pickle.load(f) 51 | self.entity_embedding = nn.Parameter(entity_embs) 52 | else: 53 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim)) 54 | if self.no_anchor_reg: 55 | nn.init.xavier_uniform_(self.entity_embedding) 56 | else: 57 | # embedding definition 58 | # embedding initialization 59 | nn.init.uniform_(tensor=self.entity_embedding, a=0, b=1) 60 | 61 | 62 | self.simplE = args.simplE 63 | if args.simplE: # use separate head and tail embeddings for entities 64 | self.entity_head_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim)) 65 | nn.init.uniform_(tensor=self.entity_embedding, a=0, b=1) 66 | 67 | 68 | 69 | 70 | # loss 71 | self.gamma_coff = gamma_coff 72 | self.loss_type = loss_type 73 | self.margin_type = margin_type 74 | if self.loss_type == 'weighted_dot': 75 | self.dim_weight = nn.Parameter(torch.ones((self.entity_dim,))) 76 | self.dim_weight_softmax = nn.Softmax(dim=-1) 77 | 78 | if margin_type == 'softmax': 79 | self.softmax_weight = torch.Tensor([10]).to(device) 80 | 81 | # regularizer: how to turn elements into 0,1 82 | self.entity_regularizer = get_regularizer(regularizer_setting, self.entity_dim, neg_input_possible=True, entity=True) 83 | 84 | wandb.log({'loss_type': loss_type}) 85 | 86 | self.godel_gumbel_beta = godel_gumbel_beta 87 | 88 | # intersection and projectizAaz<>on 89 | projection_dim, num_layers = beta_mode 90 | self.projection_net = Projection( 91 | nrelation, 92 | self.entity_dim, 93 | logic_type, 94 | regularizer_setting, 95 | self.relation_dim, 96 | projection_dim, 97 | num_layers, 98 | projection_type, 99 | num_rel_base=args.num_rel_base 100 | ) 101 | 102 | self.conjunction_net = Conjunction(self.entity_dim, logic_type, regularizer_setting, use_attention=args.use_attention, godel_gumbel_beta=godel_gumbel_beta) 103 | self.disjunction_net = Disjunction(self.entity_dim, logic_type, regularizer_setting, godel_gumbel_beta=godel_gumbel_beta) 104 | self.negation_net = Negation(self.entity_dim, logic_type, regularizer_setting) 105 | 106 | # gumbel softmax 107 | self.gumbel_temperature = gumbel_temperature # used if loss_type == 'gumbel_softmax' 108 | self.gumbel_attention = args.gumbel_attention if args.gumbel_attention != 'none' else None # None or 'plain' or 'query_dependent' 109 | if self.loss_type == 'gumbel_softmax' and args.gumbel_attention: 110 | self.n_distribution = self.entity_regularizer.get_num_distributions() 111 | self.distribution_weights = nn.Parameter(torch.ones(self.n_distribution)) 112 | if args.gumbel_attention == 'query_dependent': 113 | self.attention_layer = nn.Linear(self.entity_dim, self.n_distribution) 114 | self.gumbel_query_unnorm = args.query_unnorm 115 | 116 | self.in_batch_negative = args.in_batch_negative 117 | 118 | if self.loss_type == 'dot_layernorm_digits': 119 | self.entity_ln = nn.LayerNorm(self.hidden_dim, elementwise_affine=False) 120 | self.query_ln = nn.LayerNorm(self.hidden_dim, elementwise_affine=False) 121 | 122 | self.counter_for_neg = args.with_counter # add \neg q to negative samples 123 | 124 | self.margin_type = args.margin_type 125 | 126 | 127 | 128 | 129 | 130 | def forward( 131 | self, 132 | positive_sample, 133 | negative_sample, 134 | subsampling_weight, 135 | batch_queries_full, 136 | query_structure_idxs_full, 137 | idxs, 138 | inference=False # for discrete, use soft for training and hard for inference 139 | ): 140 | """ 141 | :param batch_queries_full: np.array[queries], e.g. array[array[8140,0], array[7269, 12, 13]] 142 | :param query_structures_idxs: np.array[query_structure_idx], e.g. array[0 3] 143 | """ 144 | 145 | # batch_queries_full is numpy and wasn't split when using multiple GPUs 146 | if len(idxs) != len(batch_queries_full): # multiple GPUs 147 | min_id, max_id = idxs[0], idxs[-1] 148 | batch_queries = batch_queries_full[min_id:max_id+1] 149 | query_structure_idxs = query_structure_idxs_full[min_id:max_id+1] 150 | else: 151 | batch_queries, query_structure_idxs = batch_queries_full, query_structure_idxs_full 152 | # print('query_structure_idxs', query_structure_idxs) 153 | 154 | # aggregate by query structure 155 | # i_qs: index for query structures 156 | sample_idx_list = [(query_structure_idxs == i_qs) for i_qs in range(len(query_structure_list))] 157 | batch_idxs_dict = { 158 | query_structure_list[i]: sample_idx.nonzero() 159 | for i, sample_idx in enumerate(sample_idx_list) 160 | if np.any(sample_idx) 161 | } 162 | 163 | batch_queries_dict = { 164 | query_structure: torch.LongTensor(np.stack(batch_queries[sample_idxs])).to(self.device) 165 | for query_structure, sample_idxs in batch_idxs_dict.items() 166 | } 167 | 168 | # all query embeddings 169 | # concatenate vectors 170 | all_idxs = np.concatenate([batch_idxs_dict[query_structure] for query_structure in batch_queries_dict], axis=None) 171 | all_embeddings = torch.cat( 172 | [ 173 | self.embed_query_fuzzy( 174 | batch_queries_dict[query_structure], 175 | query_structure, 176 | idx=0 177 | )[0] 178 | for query_structure in batch_queries_dict 179 | ], 180 | dim=0 181 | ).unsqueeze(1) 182 | 183 | all_idxs = torch.from_numpy(all_idxs).to(negative_sample.device) 184 | 185 | # 186 | # if len(all_embeddings) > 0: 187 | # all_embeddings = torch.cat(all_embeddings, dim=0).unsqueeze(1) 188 | 189 | if subsampling_weight is not None: 190 | subsampling_weight = subsampling_weight[all_idxs] 191 | 192 | if positive_sample is not None: 193 | if len(all_embeddings) > 0: 194 | # positive samples for non-union queries in this batch 195 | positive_sample_regular = positive_sample[all_idxs] 196 | if self.loss_type.startswith('discrete'): 197 | # soft discretization 198 | # use steep sigmoid to make entries closer to 0,1 199 | positive_embedding = self.entity_regularizer.soft_discretize( 200 | torch.index_select( 201 | self.entity_embedding, 202 | dim=0, 203 | index=positive_sample_regular 204 | ).unsqueeze(1) 205 | ) 206 | else: 207 | positive_embedding = self.entity_regularizer( 208 | torch.index_select( 209 | self.entity_embedding, 210 | dim=0, 211 | index=positive_sample_regular 212 | ).unsqueeze(1) 213 | ) 214 | 215 | positive_score = self.cal_logit_fuzzy(positive_embedding, all_embeddings, inference=inference) 216 | else: 217 | positive_score = torch.Tensor([]).to(self.device) 218 | 219 | else: 220 | positive_score = None 221 | 222 | if negative_sample is None: 223 | negative_score = None 224 | else: 225 | if len(all_embeddings) > 0: 226 | negative_sample_regular = negative_sample[all_idxs] 227 | 228 | batch_size, negative_size = negative_sample_regular.shape 229 | if self.loss_type.startswith('discrete'): 230 | # soft discretization 231 | # use steep sigmoid to make entries closer to 0,1 232 | negative_embedding = self.entity_regularizer.soft_discretize( 233 | torch.index_select( 234 | self.entity_embedding, 235 | dim=0, 236 | index=negative_sample_regular.view(-1) 237 | ).view( 238 | batch_size, 239 | negative_size, 240 | -1 241 | ) 242 | ) 243 | else: 244 | negative_embedding = self.entity_regularizer( 245 | torch.index_select( 246 | self.entity_embedding, 247 | dim=0, 248 | index=negative_sample_regular.view(-1) 249 | ).view( 250 | batch_size, 251 | negative_size, 252 | -1 253 | ) 254 | ) 255 | # random negative samples 256 | negative_score = self.cal_logit_fuzzy(negative_embedding, all_embeddings, inference=inference) 257 | else: 258 | negative_score = torch.Tensor([]).to(self.entity_embedding.device) 259 | 260 | if self.counter_for_neg and (not inference): 261 | # add \neg q as a negative sample into training 262 | emphasize = 16 263 | neg_q_embeddings = self.negation_net(all_embeddings) 264 | negative_score_2 = self.cal_logit_fuzzy(positive_embedding, neg_q_embeddings, inference=inference) # [batch_size, 1] 265 | negative_score_2 = negative_score_2.expand(-1, emphasize) 266 | negative_score = torch.cat((negative_score, negative_score_2), dim=1) 267 | return positive_score, negative_score, subsampling_weight, all_idxs 268 | else: 269 | return positive_score, negative_score, subsampling_weight, all_idxs 270 | 271 | def embed_query_fuzzy(self, queries, query_structure, idx): 272 | """ 273 | :param query_structure: e.g. ((('e', ('r',)), ('e', ('r',)), ('u',)), ('r',)) 274 | :param queries: Tensor. shape [batch_size, M], 275 | where M is the number of elements in query_structure (6 in the above examples) 276 | :param idx: which column to start in tensor queries 277 | """ 278 | all_relation_flag = True 279 | for ele in query_structure[-1]: 280 | # whether the current query tree has merged to one branch 281 | # and only need to do relation traversal, 282 | # e.g., path queries or conjunctive queries after the intersection 283 | if ele not in ['r', 'n']: 284 | all_relation_flag = False 285 | break 286 | if all_relation_flag: # only relation traversal 287 | if query_structure[0] == 'e': 288 | if self.simplE: 289 | # use head embeddings 290 | embedding = self.entity_regularizer( 291 | torch.index_select(self.entity_head_embedding, dim=0, index=queries[:, idx]) 292 | ) 293 | else: 294 | if self.no_anchor_reg: 295 | # entity embedding 296 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 297 | 298 | else: 299 | # entity embedding 300 | embedding = self.entity_regularizer( 301 | torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 302 | ) 303 | 304 | idx += 1 # move to next element (next column in queries) 305 | else: 306 | # recursion 307 | embedding, idx = self.embed_query_fuzzy(queries, query_structure[0], idx) 308 | 309 | for i in range(len(query_structure[-1])): # query_structure[-1]: ('r', 'n', 'r', ...'r') 310 | if query_structure[-1][i] == 'n': # negation 311 | assert (queries[:, idx] == -2).all() 312 | # embedding = self.fuzzy_logic.negation(embedding) 313 | embedding = self.negation_net(embedding) 314 | else: 315 | rel_indices = queries[:,idx] 316 | # embedding = self.fuzzy_logic.projection(embedding, r_embedding) 317 | embedding = self.projection_net(embedding, rel_indices) 318 | idx += 1 319 | else: 320 | subtree_embedding_list = [] 321 | if 'u' in query_structure[-1]: # last one is ('u') 322 | # aggregation by disjunction (union) 323 | num_subtrees = len(query_structure) - 1 # last one is 'u' 324 | # agg_net = self.fuzzy_logic.disjunction 325 | agg_net = self.disjunction_net 326 | else: 327 | # aggregation by conjunction (intersection) 328 | num_subtrees = len(query_structure) 329 | agg_net = self.conjunction_net 330 | 331 | for i in range(num_subtrees): 332 | subtree_embedding, idx = self.embed_query_fuzzy(queries, query_structure[i], idx) 333 | subtree_embedding_list.append(subtree_embedding) 334 | 335 | embedding = agg_net(torch.stack(subtree_embedding_list)) 336 | 337 | if 'u' in query_structure[-1]: # move to next 338 | idx += 1 339 | 340 | return embedding, idx 341 | 342 | def get_distribution_attention(self, query_embedding=None): 343 | # for gumbel softmax 344 | softmax = nn.Softmax(dim=-1) 345 | 346 | if self.gumbel_attention == 'plain': 347 | return softmax(self.distribution_weights) 348 | elif self.gumbel_attention == 'query_dependent': 349 | distribution_attention = softmax(self.attention_layer(query_embedding)) 350 | return distribution_attention 351 | 352 | 353 | 354 | def cal_logit_fuzzy(self, entity_embedding, query_embedding, inference=False): 355 | """ 356 | define scoring function for loss 357 | :param entity_embedding: shape [batch_size, 1, dim] (positive), [batch_size, num_neg, dim] (negative) 358 | :param query_embedding: shape [batch_size, 1, dim] 359 | :param inference: for discrete case, use soft for training and hard for inference 360 | :return score: shape [batch_size, 1] for positive, [batch_size, num_neg] for negative 361 | """ 362 | cos = nn.CosineSimilarity(dim=-1) 363 | if self.loss_type == 'gumbel_softmax': # regularizer must start with 'matrix' 364 | # entity embedding has been normalized 365 | # query embedding has been normalized as summing up to 1 if it's out of projection 366 | # not necessarily summing up to 1 if out of logic operations 367 | 368 | if self.gumbel_query_unnorm: 369 | query_normalized = query_embedding 370 | else: 371 | query_normalized = self.entity_regularizer.L1_normalize(query_embedding) # vector shape 372 | 373 | # query_normalized = query_embedding # vector shape 374 | if inference: 375 | # hard discrete 376 | entity_one_hot = self.entity_regularizer.hard_discretize(entity_embedding) # vector shape 377 | else: 378 | # convert entity to one-hot vector using gumbel 379 | entity_one_hot = self.entity_regularizer.soft_discretize(entity_embedding, self.gumbel_temperature) 380 | 381 | if self.gumbel_attention: 382 | entity_one_hot = self.entity_regularizer.reshape_to_matrix(entity_one_hot) 383 | query_normalized = self.entity_regularizer.reshape_to_matrix(query_normalized) 384 | score = cos(entity_one_hot, query_normalized) 385 | distribution_attention = self.get_distribution_attention(query_embedding) 386 | score = torch.sum(score * distribution_attention, dim=-1) 387 | else: 388 | # equivalent to torch.sum(entity_one_hot, query_normalized)/constant 389 | # since ||entity_one_hot|| is the same for all entities 390 | score = cos(entity_one_hot, query_normalized) 391 | return score 392 | 393 | if self.loss_type == 'dot': 394 | # score = torch.sum(entity_embedding * query_embedding, dim=-1) / math.sqrt(self.entity_dim) # dot product 395 | score = torch.sum(entity_embedding * query_embedding, dim=-1) # dot product 396 | elif self.loss_type == 'weighted_dot': 397 | dim_weights = self.dim_weight_softmax(self.dim_weight) 398 | score = torch.sum(entity_embedding * query_embedding * dim_weights, dim=-1) 399 | elif self.loss_type.startswith('discrete'): 400 | # entity embedding should have been discretized 401 | 402 | if self.loss_type == 'discrete_cos': 403 | cos = nn.CosineSimilarity(dim=-1) 404 | score = cos(entity_embedding, query_embedding) 405 | # inference only 406 | # thres = 0.7 407 | # entity_embedding[entity_embedding >= thres] = 1 408 | # entity_embedding[entity_embedding < thres] = 0 409 | 410 | elif self.loss_type == 'discrete_prob': 411 | # In discrete representation, entities are considered entry value 0 or 1 412 | # entity_embedding should have been discretized 413 | 414 | # For the qth query 415 | # unlike other score computation, this score is not aggregated for each sample 416 | score = entity_embedding * query_embedding + (1-entity_embedding) * (1-query_embedding) 417 | 418 | elif self.loss_type == 'entropy': 419 | query_embedding = self.entity_regularizer.L1_normalize(query_embedding) # vector shape 420 | 421 | # score = torch.mean(query_embedding * torch.log(entity_embedding+eps), dim=-1) 422 | 423 | # JSD 424 | m = torch.log2((query_embedding + entity_embedding) / 2 + 1e-9) 425 | dist = F.kl_div(m, query_embedding.expand(m.shape), reduction="none") \ 426 | + F.kl_div(m, entity_embedding, reduction="none") 427 | num_distributions = self.entity_regularizer.get_num_distributions() # entity_dim // k 428 | dist = 0.5 * torch.sum(dist, dim=-1) / num_distributions 429 | score = 1 - dist 430 | 431 | elif self.loss_type == 'fuzzy_containment': 432 | # for Godel only 433 | # use with sigmoid regularizer 434 | # L1 435 | score = entity_embedding - torch.relu(entity_embedding - query_embedding) 436 | score = torch.max(score, dim=-1) 437 | # / torch.sum(entity_embedding, dim=-1) 438 | 439 | elif self.loss_type == 'weighted_fuzzy_containment': 440 | # for Godel only, use with sigmoid regularizer 441 | entity_vals, entity_val_weights = torch.chunk(entity_embedding, 2, dim=-1) 442 | query_vals, query_val_weights = torch.chunk(query_embedding, 2, dim=-1) 443 | val_weights = F.softmax(entity_val_weights * query_val_weights, dim=-1) 444 | 445 | score = entity_vals - torch.relu(entity_vals - query_vals) # containment score 446 | score = torch.sum(score * val_weights, dim=-1) / torch.sum(entity_vals * val_weights, dim=-1) 447 | 448 | elif self.loss_type == 'cos_digits': # use with logsigmoid_bpr_digits 449 | if not inference: 450 | entity_embedding = F.normalize(entity_embedding, p=2, dim=-1) 451 | query_embedding = F.normalize(query_embedding, p=2, dim=-1) 452 | score_digits = (entity_embedding * query_embedding) * self.entity_dim 453 | # score_digits = score_digits / norm.unsqueeze(2) * self.entity_dim 454 | return score_digits # no aggregation, [batch_size, 1 or num_neg, dim] 455 | 456 | # # use cos for inference 457 | cos = nn.CosineSimilarity(dim=-1) 458 | score = cos(entity_embedding, query_embedding) 459 | 460 | elif self.loss_type == 'dot_layernorm_digits': # use with logsigmoid_bpr_digits 461 | entity_embedding = self.entity_ln(entity_embedding) 462 | query_embedding = self.query_ln(query_embedding) 463 | score_digits = (entity_embedding * query_embedding) 464 | # score_digits = score_digits / norm.unsqueeze(2) * self.entity_dim 465 | if not inference: 466 | return score_digits # no aggregation, [batch_size, 1 or num_neg, dim] 467 | # inference 468 | return torch.mean(score_digits, dim=-1) 469 | 470 | 471 | elif self.loss_type == 'L1_cos_digits': # use with logsigmoid_avg 472 | entity_embedding = F.normalize(entity_embedding, p=1, dim=-1) 473 | query_embedding = F.normalize(query_embedding, p=1, dim=-1) 474 | score_digits = (entity_embedding * query_embedding) * self.entity_dim 475 | # score_digits = score_digits / norm.unsqueeze(2) * self.entity_dim 476 | if not inference: 477 | return score_digits # no aggregation, [batch_size, 1 or num_neg, dim] 478 | # inference 479 | return torch.mean(score_digits, dim=-1) 480 | 481 | 482 | elif self.loss_type == 'soft_min_digits': 483 | # use with godel logic 484 | # entity_embedding = F.normalize(entity_embedding, p=2, dim=-1) 485 | # query_embedding = F.normalize(query_embedding, p=2, dim=-1) 486 | entity_embedding, query_embedding = torch.broadcast_tensors(entity_embedding, query_embedding) 487 | compare = torch.stack((entity_embedding, query_embedding)) 488 | # a smooth way to compute min 489 | score_digits = -self.godel_gumbel_beta * torch.logsumexp( 490 | -compare / self.godel_gumbel_beta, 0 491 | ) 492 | if not inference: 493 | return score_digits # no aggregation, [batch_size, 1 or num_neg, dim] 494 | # inference, aggregated 495 | score = torch.mean(score_digits, dim=-1) 496 | # score = torch.logsumexp(-score_digits, dim=-1) 497 | 498 | elif self.loss_type == 'entity_multinomial_dot': 499 | entity_embedding = F.normalize(entity_embedding, p=1, dim=-1) 500 | score = torch.sum(entity_embedding * query_embedding, dim=-1) 501 | 502 | elif self.loss_type == 'normalized_entity_dot': 503 | score = torch.sum(entity_embedding * query_embedding, dim=-1) 504 | 505 | else: # cos by default 506 | cos = nn.CosineSimilarity(dim=-1) 507 | score = cos(entity_embedding, query_embedding) 508 | return score 509 | 510 | @staticmethod 511 | def compute_loss(model, positive_score, negative_score, subsampling_weight): 512 | if model.margin_type == 'logsigmoid': 513 | # the loss of BetaE and RotatE 514 | positive_dist = 1-positive_score 515 | negative_dist = 1-negative_score 516 | positive_unweighted_loss = -F.logsigmoid((model.gamma - positive_dist)*model.gamma_coff).squeeze(dim=1) 517 | negative_unweighted_loss = -F.logsigmoid((negative_dist - model.gamma)*model.gamma_coff).mean(dim=1) 518 | positive_sample_loss = (subsampling_weight * positive_unweighted_loss).sum() 519 | negative_sample_loss = (subsampling_weight * negative_unweighted_loss).sum() 520 | positive_sample_loss /= subsampling_weight.sum() 521 | negative_sample_loss /= subsampling_weight.sum() 522 | loss = (positive_sample_loss + negative_sample_loss) / 2 523 | log = { 524 | 'positive_sample_loss': positive_sample_loss.item(), 525 | 'negative_sample_loss': negative_sample_loss.item(), 526 | 'loss': loss.item(), 527 | } 528 | elif model.margin_type == 'logsigmoid_avg': 529 | # use with cos_digits 530 | positive_dist = 1-positive_score 531 | negative_dist = 1-negative_score 532 | positive_unweighted_loss = -torch.mean(F.logsigmoid((model.gamma - positive_dist)*model.gamma_coff), dim=-1).squeeze(dim=1) 533 | negative_unweighted_loss = -torch.mean(F.logsigmoid((negative_dist - model.gamma)*model.gamma_coff), dim=-1).mean(dim=1) 534 | positive_sample_loss = (subsampling_weight * positive_unweighted_loss).sum() 535 | negative_sample_loss = (subsampling_weight * negative_unweighted_loss).sum() 536 | positive_sample_loss /= subsampling_weight.sum() 537 | negative_sample_loss /= subsampling_weight.sum() 538 | loss = (positive_sample_loss + negative_sample_loss) / 2 539 | log = { 540 | 'positive_sample_loss': positive_sample_loss.item(), 541 | 'negative_sample_loss': negative_sample_loss.item(), 542 | 'loss': loss.item(), 543 | } 544 | elif model.margin_type == 'softmax': 545 | # positive_score shape [batch_size, 1] 546 | criterion = nn.CrossEntropyLoss(reduction='none') # keep loss for each sample 547 | if model.loss_type != 'discrete_prob': 548 | softmax_weight = 10 549 | scores = torch.cat([positive_score, negative_score], dim=1)*softmax_weight # [batch_size, 1+negative_sample_size] 550 | else: 551 | # score: log(prob) 552 | # softmax=exp(x1)/(exp(x1)+...exp(xn))=exp(x1+exp_shift)/(exp(x1+exp_shift)+...+exp(xn+exp_shift)) 553 | # otherwise the log scores are too small and the results are all zero 554 | exp_shift, _ = torch.max(positive_score, dim=-1) 555 | exp_shift = torch.unsqueeze(exp_shift, 1) 556 | positive_score = positive_score - exp_shift # still in log scale 557 | negative_score = negative_score - exp_shift 558 | 559 | # debug only 560 | positive_score_real = torch.exp(positive_score) 561 | negative_score_real = torch.exp(negative_score) 562 | 563 | scores = torch.cat([positive_score, negative_score], dim=1) 564 | 565 | target = torch.zeros((positive_score.shape[0],), dtype=torch.long).to(device) 566 | loss = (criterion(scores, target) * subsampling_weight).sum() # CrossEntropyLoss includes softmax 567 | loss /= subsampling_weight.sum() 568 | log = {'loss': loss.item()} 569 | elif model.margin_type == 'bpr': 570 | # gamma as margin 571 | diff = torch.relu(model.gamma + negative_score -positive_score) # relu or softplus 572 | unweighted_sample_loss = torch.mean(diff, dim=-1) 573 | loss = (subsampling_weight * unweighted_sample_loss).sum() 574 | loss /= subsampling_weight.sum() 575 | log = { 576 | 'loss': loss.item(), 577 | } 578 | elif model.margin_type == 'bpr_digits': 579 | # positive_score: shape [batch_size, 1, dim] (not aggregated yet) 580 | # negative_score: shape [batch_size, neg_per_pos, dim] 581 | # gamma as margin 582 | diff = torch.mean(torch.relu(model.gamma + negative_score -positive_score), dim=-1) # relu or softplus 583 | unweighted_sample_loss = torch.mean(diff, dim=-1) 584 | loss = (subsampling_weight * unweighted_sample_loss).sum() 585 | loss /= subsampling_weight.sum() 586 | log = { 587 | 'loss': loss.item(), 588 | } 589 | elif model.margin_type == 'logsigmoid_bpr_digits': 590 | # positive_score: shape [batch_size, 1, dim] (not aggregated yet) 591 | # negative_score: shape [batch_size, neg_per_pos, dim] 592 | diff = -F.logsigmoid(model.gamma_coff*(torch.mean(positive_score - negative_score, dim=-1))) 593 | # diff = torch.mean(-F.logsigmoid(model.gamma_coff*(positive_score - negative_score)), dim=-1) 594 | unweighted_sample_loss = torch.mean(diff, dim=-1) 595 | loss = (subsampling_weight * unweighted_sample_loss).sum() 596 | loss /= subsampling_weight.sum() 597 | log = { 598 | 'loss': loss.item(), 599 | } 600 | elif model.margin_type == 'logsigmoid_bpr': 601 | # gamma as margin 602 | diff = -F.logsigmoid(model.gamma_coff*(positive_score - negative_score)) 603 | # diff = torch.mean(-F.logsigmoid(model.gamma_coff*(positive_score - negative_score)), dim=-1) 604 | unweighted_sample_loss = torch.mean(diff, dim=-1) 605 | loss = (subsampling_weight * unweighted_sample_loss).sum() 606 | loss /= subsampling_weight.sum() 607 | log = { 608 | 'loss': loss.item(), 609 | } 610 | 611 | elif model.margin_type == 'nll': # negative log likelihood. used together with discrete_prob 612 | if model.loss_type == 'discrete_prob': 613 | # positive_score: shape [batch_size, 1, dim] (not aggregated yet) 614 | # negative_score: shape [batch_size, neg_per_pos, dim] 615 | 616 | eps = 1e-4 # avoid torch.log(zero) 617 | log_positive_score = torch.log(positive_score+eps) 618 | log_negative_score = torch.log(1-negative_score+eps) # flip for negative samples 619 | 620 | # negative log likelihood 621 | # use torch.mean instead of torch.sum to divide by a constant (dim) 622 | positive_sample_loss = - torch.mean(log_positive_score, dim=-1).squeeze(dim=1) 623 | negative_sample_loss = - torch.mean(log_negative_score, dim=-1).mean(dim=1) 624 | # positive_sample_loss = -positive_score.squeeze(dim=1) 625 | # negative_sample_loss = -torch.log(1-torch.exp(negative_score)+eps).mean(dim=1) 626 | 627 | positive_sample_loss = (subsampling_weight * positive_sample_loss).sum() 628 | negative_sample_loss = (subsampling_weight * negative_sample_loss).sum() 629 | positive_sample_loss /= subsampling_weight.sum() 630 | negative_sample_loss /= subsampling_weight.sum() 631 | loss = (positive_sample_loss + negative_sample_loss) 632 | log = { 633 | 'positive_sample_loss': positive_sample_loss.item(), 634 | 'negative_sample_loss': negative_sample_loss.item(), 635 | 'loss': loss.item(), 636 | } 637 | elif model.loss_type == 'entropy': 638 | # version 1 639 | # positive_sample_loss = - positive_score.squeeze(dim=-1) 640 | # negative_sample_loss = negative_score.mean(dim=-1) 641 | # positive_sample_loss = (subsampling_weight * positive_sample_loss).sum() 642 | # negative_sample_loss = (subsampling_weight * negative_sample_loss).sum() 643 | # positive_sample_loss /= subsampling_weight.sum() 644 | # negative_sample_loss /= subsampling_weight.sum() 645 | # loss = (positive_sample_loss + negative_sample_loss) 646 | # log = { 647 | # 'positive_sample_loss': positive_sample_loss.item(), 648 | # 'negative_sample_loss': negative_sample_loss.item(), 649 | # 'loss': loss.item(), 650 | # } 651 | 652 | # # # version 2 653 | positive_score = positive_score.squeeze(dim=-1) 654 | negative_score = negative_score.mean(dim=-1) 655 | diff = torch.relu(model.gamma + negative_score-positive_score) 656 | unweighted_sample_loss = torch.mean(diff, dim=-1) 657 | loss = (subsampling_weight * unweighted_sample_loss).sum() 658 | loss /= subsampling_weight.sum() 659 | log = { 660 | 'loss': loss.item(), 661 | } 662 | 663 | return loss, log 664 | 665 | 666 | @staticmethod 667 | def train_step(model, optimizer, train_iterator, args, step): 668 | """ 669 | Adapted for multiple GPUs 670 | """ 671 | # device = model.module.device 672 | device = model.device 673 | 674 | model.train() 675 | optimizer.zero_grad() 676 | 677 | positive_sample, negative_sample, subsampling_weight, batch_queries, query_structure_idxs = next(train_iterator) 678 | 679 | if args.cuda: 680 | positive_sample = positive_sample.to(device) 681 | negative_sample = negative_sample.to(device) 682 | subsampling_weight = subsampling_weight.to(device) 683 | # no need to move query_structure_idxs to GPU 684 | 685 | # nn.DataParallel helper 686 | batch_size = len(positive_sample) 687 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 688 | 689 | positive_score, negative_score, subsampling_weight, _ = model( 690 | positive_sample, 691 | negative_sample, 692 | subsampling_weight, 693 | batch_queries, # np.array([queries]), won't be split when using multiple GPUs 694 | query_structure_idxs, # torch.LongTensor 695 | slice_idxs, # to help track batch_queries and query_structures when using multiple GPUs 696 | inference=False 697 | ) 698 | loss, log = KGFuzzyReasoning.compute_loss(model, positive_score, negative_score, subsampling_weight) 699 | 700 | loss.backward() 701 | optimizer.step() 702 | 703 | if model.loss_type == 'normalized_entity_dot': 704 | with torch.no_grad(): 705 | # normalize entity embeddings 706 | normalized = nn.Parameter(torch.clamp(model.entity_embedding, 0, 1)) 707 | # F1 normalize 708 | model.entity_embedding = nn.Parameter(F.normalize(normalized, p=1, dim=-1)) 709 | 710 | 711 | 712 | return log 713 | 714 | @staticmethod 715 | def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, save_result=False, save_str="", save_empty=False): 716 | model.eval() 717 | 718 | # device = model.module.device 719 | device = model.device 720 | 721 | step = 0 722 | total_steps = len(test_dataloader) 723 | logs = collections.defaultdict(list) 724 | 725 | with torch.no_grad(): 726 | for negative_sample, queries, queries_unflatten, query_structure_idxs in tqdm(test_dataloader, disable=not args.print_on_screen): 727 | # example: query_structures: [('e', ('r',))]. queries: [[1804,4]]. queries_unflatten: [(1804, (4,)] 728 | if args.cuda: 729 | negative_sample = negative_sample.to(device) 730 | 731 | # nn.DataParallel helper 732 | batch_size = len(negative_sample) 733 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 734 | 735 | _, negative_logit, _, idxs = model( 736 | None, 737 | negative_sample, 738 | None, 739 | queries, # np.array([queries]), won't be split when using multiple GPUs 740 | query_structure_idxs, 741 | slice_idxs, # to help track batch_queries and query_structures when using multiple GPUs 742 | inference=True 743 | ) 744 | 745 | if model.loss_type == 'discrete_prob': 746 | # negative_logit shape[batch_size, num_entity, dim], not aggregated yet 747 | eps = 1e-4 748 | negative_logit = torch.sum(torch.log(negative_logit+eps), dim=-1) 749 | 750 | idxs_np = idxs.detach().cpu().numpy() 751 | # if not converted to numpy, idxs_np will be considered scalar when test_batch_size=1 752 | # queries_unflatten = queries_unflatten[idxs_np] 753 | query_structure_idxs = query_structure_idxs[idxs_np] 754 | queries_unflatten = [queries_unflatten[i] for i in idxs] 755 | 756 | argsort = torch.argsort(negative_logit, dim=1, descending=True) 757 | ranking = argsort.clone().to(torch.float) 758 | 759 | # rank all entities 760 | # If it is the same shape with test_batch_size, reuse batch_entity_range without creating a new one 761 | if len(argsort) == args.test_batch_size: 762 | # ranking = ranking.scatter_(1, argsort, model.module.batch_entity_range) # achieve the ranking of all entities 763 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 764 | else: # otherwise, create a new torch Tensor for batch_entity_range 765 | ranking = ranking.scatter_( 766 | 1, 767 | argsort, 768 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 769 | # torch.arange(model.module.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 770 | ) 771 | 772 | for idx, (i, query, query_structure_idx) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structure_idxs)): 773 | # convert query from np.ndarray to nested tuple 774 | query_key = tuple(query) 775 | query_structure = query_structure_list[query_structure_idx] 776 | 777 | hard_answer = hard_answers[query_key] 778 | easy_answer = easy_answers[query_key] 779 | num_hard = len(hard_answer) 780 | num_easy = len(easy_answer) 781 | assert len(hard_answer.intersection(easy_answer)) == 0 782 | cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)] 783 | cur_ranking, indices = torch.sort(cur_ranking) 784 | masks = indices >= num_easy 785 | if args.cuda: 786 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 787 | else: 788 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 789 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 790 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 791 | 792 | mrr = torch.mean(1./cur_ranking).item() 793 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 794 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 795 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 796 | 797 | logs[query_structure].append({ 798 | 'MRR': mrr, 799 | 'HITS1': h1, 800 | 'HITS3': h3, 801 | 'HITS10': h10, 802 | 'num_hard_answer': num_hard, 803 | }) 804 | 805 | if step % args.test_log_steps == 0: 806 | logging.info('Evaluating the model... (%d/%d)' % (step, total_steps)) 807 | 808 | step += 1 809 | 810 | metrics = collections.defaultdict(lambda: collections.defaultdict(int)) 811 | for query_structure in logs: 812 | for metric in logs[query_structure][0].keys(): 813 | if metric in ['num_hard_answer']: 814 | continue 815 | metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure]) 816 | metrics[query_structure]['num_queries'] = len(logs[query_structure]) 817 | 818 | return metrics 819 | 820 | def JSD(p, q): 821 | m = (p + q) / 2 822 | loss = F.kl_div(p.log(), m, reduction="mean") + F.kl_div(q.log(), m, reduction="mean") 823 | return 1 - (0.5 * loss) -------------------------------------------------------------------------------- /gumbel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | 6 | # ================================ 7 | # 8 | def sample_gumbel(shape, eps=1e-20): 9 | U = torch.rand(shape).cuda() # from uniform distribution [0,1) 10 | return -Variable(torch.log(-torch.log(U + eps) + eps)) 11 | 12 | 13 | def gumbel_softmax_sample(logits, temperature): 14 | y = logits + sample_gumbel(logits.size()) 15 | return F.softmax(y / temperature, dim=-1) 16 | 17 | 18 | def gumbel_softmax(logits, temperature): 19 | """ 20 | input: [*, n_class] 21 | return: [*, n_class] an one-hot vector 22 | """ 23 | y = gumbel_softmax_sample(logits, temperature) 24 | shape = y.size() 25 | _, ind = y.max(dim=-1) 26 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 27 | y_hard.scatter_(1, ind.view(-1, 1), 1) 28 | y_hard = y_hard.view(*shape) 29 | return (y_hard - y).detach() + y 30 | 31 | 32 | 33 | # =============================== 34 | # https://github.com/shaabhishek/gumbel-softmax-pytorch/blob/master/Gumbel-softmax%20visualization.ipynb 35 | 36 | # def sample_gumbel(shape): 37 | # unif = torch.distributions.Uniform(0,1).sample(shape).cuda() 38 | # g = -torch.log(-torch.log(unif)) 39 | # return g 40 | # 41 | # 42 | # def gumbel_softmax(pi, temperature): 43 | # g = sample_gumbel(pi.size()) 44 | # h = (g + torch.log(pi))/temperature 45 | # h_max = h.max(dim=1, keepdim=True)[0] 46 | # h = h - h_max 47 | # cache = torch.exp(h) 48 | # # print(pi, torch.log(pi), intmdt) 49 | # y = cache / cache.sum(dim=-1, keepdim=True) 50 | # return y 51 | 52 | 53 | 54 | #================== 55 | # https://github.com/yandexdataschool/gumbel_dpg/blob/master/gumbel.py 56 | 57 | # 58 | # class GumbelSigmoid: 59 | # """ 60 | # A gumbel-sigmoid nonlinearity with gumbel(0,1) noize 61 | # In short, it's a function that mimics #[a>0] indicator where a is the logit 62 | # 63 | # Explaination and motivation: https://arxiv.org/abs/1611.01144 64 | # 65 | # Math: 66 | # Sigmoid is a softmax of two logits: a and 0 67 | # e^a / (e^a + e^0) = 1 / (1 + e^(0 - a)) = sigm(a) 68 | # 69 | # Gumbel-sigmoid is a gumbel-softmax for same logits: 70 | # gumbel_sigm(a) = e^([a+gumbel1]/t) / [ e^([a+gumbel1]/t) + e^(gumbel2/t)] 71 | # where t is temperature, gumbel1 and gumbel2 are two samples from gumbel noize: -log(-log(uniform(0,1))) 72 | # gumbel_sigm(a) = 1 / ( 1 + e^(gumbel2/t - [a+gumbel1]/t) = 1 / ( 1+ e^(-[a + gumbel1 - gumbel2]/t) 73 | # gumbel_sigm(a) = sigm([a+gumbel1-gumbel2]/t) 74 | # 75 | # For computation reasons: 76 | # gumbel1-gumbel2 = -log(-log(uniform1(0,1)) +log(-log(uniform2(0,1)) = -log( log(uniform2(0,1)) / log(uniform1(0,1)) ) 77 | # gumbel_sigm(a) = sigm([a-log(log(uniform2(0,1))/log(uniform1(0,1))]/t) 78 | # 79 | # 80 | # :param t: temperature of sampling. Lower means more spike-like sampling. Can be symbolic. 81 | # :param eps: a small number used for numerical stability 82 | # :returns: a callable that can (and should) be used as a nonlinearity 83 | # 84 | # """ 85 | # 86 | # def __init__(self, 87 | # t=0.1, 88 | # eps=1e-20): 89 | # assert t != 0 90 | # self.temperature = t 91 | # self.eps = eps 92 | # self._srng = RandomStreams(get_rng().randint(1, 2147462579)) 93 | # 94 | # def __call__(self, logits): 95 | # """computes a gumbel softmax sample""" 96 | # 97 | # # sample from Gumbel(0, 1) 98 | # uniform1 = self._srng.uniform(logits.shape, low=0, high=1) 99 | # uniform2 = self._srng.uniform(logits.shape, low=0, high=1) 100 | # 101 | # noise = -T.log(T.log(uniform2 + self.eps) / T.log(uniform1 + self.eps) + self.eps) 102 | # 103 | # # draw a sample from the Gumbel-Sigmoid distribution 104 | # return T.nnet.sigmoid((logits + noise) / self.temperature) 105 | # 106 | # 107 | # def hard_sigm(logits): 108 | # """computes a hard indicator function. Not differentiable""" 109 | # return T.switch(T.gt(logits, 0), 1, 0) 110 | # 111 | # 112 | # class GumbelSigmoidLayer(Layer): 113 | # """ 114 | # lasagne.layers.GumbelSigmoidLayer(incoming,**kwargs) 115 | # A layer that just applies a GumbelSigmoid nonlinearity. 116 | # In short, it's a function that mimics #[a>0] indicator where a is the logit 117 | # 118 | # Explaination and motivation: https://arxiv.org/abs/1611.01144 119 | # 120 | # Math: 121 | # Sigmoid is a softmax of two logits: a and 0 122 | # e^a / (e^a + e^0) = 1 / (1 + e^(0 - a)) = sigm(a) 123 | # 124 | # Gumbel-sigmoid is a gumbel-softmax for same logits: 125 | # gumbel_sigm(a) = e^([a+gumbel1]/t) / [ e^([a+gumbel1]/t) + e^(gumbel2/t)] 126 | # where t is temperature, gumbel1 and gumbel2 are two samples from gumbel noize: -log(-log(uniform(0,1))) 127 | # gumbel_sigm(a) = 1 / ( 1 + e^(gumbel2/t - [a+gumbel1]/t) = 1 / ( 1+ e^(-[a + gumbel1 - gumbel2]/t) 128 | # gumbel_sigm(a) = sigm([a+gumbel1-gumbel2]/t) 129 | # 130 | # For computation reasons: 131 | # gumbel1-gumbel2 = -log(-log(uniform1(0,1)) +log(-log(uniform2(0,1)) = -log( log(uniform2(0,1)) / log(uniform1(0,1)) ) 132 | # gumbel_sigm(a) = sigm([a-log(log(uniform2(0,1))/log(uniform1(0,1))]/t) 133 | # 134 | # Parameters 135 | # ---------- 136 | # incoming : a :class:`Layer` instance or a tuple 137 | # The layer feeding into this layer, or the expected input shape 138 | # t: temperature of sampling. Lower means more spike-like sampling. Can be symbolic (e.g. shared) 139 | # eps: a small number used for numerical stability 140 | # """ 141 | # 142 | # def __init__(self, incoming, t=0.1, eps=1e-20, **kwargs): 143 | # super(GumbelSigmoidLayer, self).__init__(incoming, **kwargs) 144 | # self.gumbel_sigm = GumbelSigmoid(t=t, eps=eps) 145 | # 146 | # def get_output_for(self, input, hard_max=False, **kwargs): 147 | # if hard_max: 148 | # return hard_sigm(input) 149 | # else: 150 | # return self.gumbel_sigm(input) -------------------------------------------------------------------------------- /investigation_helper.py: -------------------------------------------------------------------------------- 1 | from sklearn.preprocessing import MinMaxScaler 2 | from sklearn.decomposition import PCA 3 | import os 4 | from main import parse_args 5 | import pandas as pd 6 | from os.path import join 7 | import pickle 8 | import torch 9 | from fuzzyreasoning import KGFuzzyReasoning 10 | from dataloader import load_data_from_pickle, load_data 11 | from constants import query_name_dict, query_structure_list, query_structure2idx 12 | from collections import defaultdict 13 | from main import parse_args 14 | from util import evaluate, read_num_entity_relation_from_file, eval_tuple, wandb_initialize 15 | import collections 16 | import copy 17 | from constants import query_structure2idx, query_name_dict, query_structure_list 18 | from dataloader import TestDataset 19 | from torch.utils.data import DataLoader 20 | import pickle 21 | import pandas as pd 22 | import numpy as np 23 | import matplotlib.pyplot as plt 24 | import torch.nn as nn 25 | from util import log_metrics 26 | import torch.nn.functional as F 27 | import os 28 | from util import list2tuple, tuple2list, flatten, flatten_query_and_convert_structure_to_idx 29 | from dataloader import SingledirectionalOneShotIterator, TrainDataset 30 | 31 | def relPCA(model): 32 | rel_base0 = model.projection_net.rel_base.detach() 33 | n_base, d1, d2 = rel_base0.size() 34 | rel_base = rel_base0.view(n_base, d1*d2).cpu().numpy() 35 | scaler = MinMaxScaler() 36 | data_rescaled = scaler.fit_transform(rel_base.transpose()) 37 | 38 | 39 | pca = PCA(n_components = 0.95) 40 | pca.fit(data_rescaled) 41 | reduced = pca.transform(data_rescaled) 42 | print(reduced.shape) 43 | 44 | 45 | def set_GPU_id(gpu_id): 46 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 47 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id # specify which GPU(s) to be used 48 | 49 | 50 | def print_one_sigmoid(regularizer): 51 | print(regularizer.weight) 52 | print(regularizer.bias) 53 | 54 | 55 | def print_all_sigmoid(model): 56 | print_one_sigmoid(model.entity_regularizer) 57 | print_one_sigmoid(model.projection_net.regularizer) 58 | print_one_sigmoid(model.negation_net.regularizer) 59 | print_one_sigmoid(model.disjunction_net.regularizer) 60 | print_one_sigmoid(model.conjunction_net.regularizer) 61 | 62 | 63 | def get_args(arg_str, model): 64 | arg_str = arg_str.split() 65 | args = parse_args(arg_str) 66 | args.nentity = model.nentity 67 | model.conjunction_net.use_attention = args.use_attention 68 | print(args) 69 | return args 70 | 71 | def get_answers_from_train(data_dir): 72 | """Get answers from train.txt""" 73 | train_df = pd.read_csv(open(join(data_dir, 'train.txt')), sep='\t') 74 | hr2t = {} 75 | for index,row in train_df.iterrows(): 76 | h,r,t = row[0], row[1], row[2] 77 | if (h,r) not in hr2t: 78 | hr2t[(h,r)] = set() 79 | hr2t[(h,r)].add(t) 80 | return hr2t 81 | 82 | def get_answers_from_full(data_dir): 83 | """Get answers from train.txt""" 84 | train_df = pd.read_csv(open(join(data_dir, 'train.txt')), sep='\t') 85 | val_df = pd.read_csv(open(join(data_dir, 'valid.txt')), sep='\t') 86 | test_df = pd.read_csv(open(join(data_dir, 'test.txt')), sep='\t') 87 | hr2t = {} 88 | for index,row in train_df.iterrows(): 89 | h,r,t = row[0], row[1], row[2] 90 | if (h,r) not in hr2t: 91 | hr2t[(h,r)] = set() 92 | hr2t[(h,r)].add(t) 93 | for index,row in val_df.iterrows(): 94 | h,r,t = row[0], row[1], row[2] 95 | if (h,r) not in hr2t: 96 | hr2t[(h,r)] = set() 97 | hr2t[(h,r)].add(t) 98 | for index,row in test_df.iterrows(): 99 | h,r,t = row[0], row[1], row[2] 100 | if (h,r) not in hr2t: 101 | hr2t[(h,r)] = set() 102 | hr2t[(h,r)].add(t) 103 | return hr2t_full 104 | 105 | def get_ent_rel_labels(data_dir): 106 | rel_id2str = pickle.load(open(join(data_dir, 'id2rel.pkl'), 'rb')) 107 | 108 | if 'FB15k-237' in data_dir: 109 | # use the entity label pulled from Wikidata 110 | # since the original entity str in FB15k is FreeBase ID and not readable 111 | entity_df = pd.read_csv(open(join(data_dir, 'entity_label.tsv')), sep='\t') 112 | ent_id2str = pd.Series(entity_df['wiki'].values, index=entity_df['id']).to_dict() 113 | print(entity_df.head(5)) 114 | else: # e.g. NELL 115 | ent_id2str = pickle.load(open(join(data_dir, 'id2ent.pkl'), 'rb')) 116 | 117 | 118 | return ent_id2str, rel_id2str 119 | 120 | 121 | 122 | def get_projected_t_vec(model, h, rs): 123 | """rs is a list of relations""" 124 | 125 | h_vec = model.entity_regularizer( 126 | torch.index_select( 127 | model.entity_embedding, 128 | dim=0, 129 | index=torch.LongTensor([h]).to(model.device) 130 | ) 131 | ) 132 | rid =torch.LongTensor(rs) 133 | 134 | t_vec = model.projection_net(h_vec, rid) 135 | return t_vec 136 | 137 | def get_negated_vec(model, v): 138 | neg_v = model.negation_net(v) 139 | return neg_v 140 | 141 | 142 | def get_conjunction_vec(model, v1, v2): 143 | stack = torch.stack([v1, v2]) 144 | conj = model.conjunction_net(stack) 145 | return conj 146 | 147 | 148 | def get_disjunction_vec(v1, v2): 149 | stack = torch.stack([v1, v2]) 150 | union = model.disjunction_net(stack) 151 | return union 152 | 153 | def plot_vec(v, firstk=20): 154 | """ 155 | vec.shape: [1, dim] 156 | """ 157 | plt.figure() 158 | vec = v[0,:firstk].detach().cpu().numpy() 159 | ind = np.arange(len(vec)) 160 | plt.bar(ind, vec) 161 | 162 | 163 | def score_distribution_by_vec(model, vec): 164 | model.eval() 165 | 166 | device = model.device 167 | 168 | step = 0 169 | logs = collections.defaultdict(list) 170 | 171 | with torch.no_grad(): 172 | negative_samples = torch.arange(0, model.nentity).to(model.device) # all entities 173 | negative_embedding = model.entity_regularizer( 174 | torch.index_select( 175 | model.entity_embedding, 176 | dim=0, 177 | index=negative_samples.view(-1) 178 | ).view( 179 | 1, 180 | model.nentity, 181 | -1 182 | ) 183 | ) 184 | 185 | scores = model.cal_logit_fuzzy(negative_embedding, vec) 186 | return scores 187 | 188 | def plot_vec_wide(v, firstk=20, height=5, width=20): 189 | """ 190 | vec.shape: [1, dim] 191 | """ 192 | plt.figure(figsize=(width, height)) 193 | vec = v[0,:firstk].detach().cpu().numpy() 194 | ind = np.arange(len(vec)) 195 | plt.bar(ind, vec) 196 | 197 | def plot_vec_wide_with_color(v, highlight, flip=False, start=0, length=20, height=5, width=20): 198 | """ 199 | vec.shape: [1, dim] 200 | highlight: set of index set(3, 8, ...), otherwise blue 201 | """ 202 | 203 | colors = ['b' for i in range(v.shape[1])] 204 | for i in highlight: 205 | colors[i] = 'r' 206 | if flip: 207 | colors = ['r' for i in range(v.shape[1])] 208 | for i in highlight: 209 | colors[i] = 'b' 210 | 211 | plt.figure(figsize=(width, height)) 212 | vec = v[0,start:start+length].detach().cpu().numpy() 213 | ind = np.arange(len(vec)) 214 | plt.bar(ind, vec, color=colors[start:start+length]) 215 | 216 | 217 | def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, verbose=False): 218 | model.eval() 219 | 220 | device = model.device 221 | 222 | step = 0 223 | total_steps = len(test_dataloader) 224 | logs = collections.defaultdict(list) 225 | 226 | with torch.no_grad(): 227 | for negative_sample, queries, queries_unflatten, query_structure_idxs in test_dataloader: 228 | # example: query_structures: [('e', ('r',))]. queries: [[1804,4]]. queries_unflatten: [(1804, (4,)] 229 | if args.cuda: 230 | negative_sample = negative_sample.to(device) 231 | 232 | # nn.DataParallel helper 233 | batch_size = len(negative_sample) 234 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 235 | 236 | _, negative_logit, _, idxs = model( 237 | None, 238 | negative_sample, 239 | None, 240 | queries, # np.array([queries]), won't be split when using multiple GPUs 241 | query_structure_idxs, 242 | slice_idxs, # to help track batch_queries and query_structures when using multiple GPUs 243 | inference=True 244 | ) 245 | 246 | idxs_np = idxs.detach().cpu().numpy() 247 | # if not converted to numpy, idxs_np will be considered scalar when test_batch_size=1 248 | # queries_unflatten = queries_unflatten[idxs_np] 249 | query_structure_idxs = query_structure_idxs[idxs_np] 250 | queries_unflatten = [queries_unflatten[i] for i in idxs] 251 | 252 | # 253 | # query_structures = [query_structures[i] for i in idxs] 254 | argsort = torch.argsort(negative_logit, dim=1, descending=True) 255 | ranking = argsort.clone().to(torch.float) 256 | 257 | # rank all entities 258 | # If it is the same shape with test_batch_size, reuse batch_entity_range without creating a new one 259 | if len(argsort) == args.test_batch_size: 260 | # ranking = ranking.scatter_(1, argsort, model.module.batch_entity_range) # achieve the ranking of all entities 261 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 262 | else: # otherwise, create a new torch Tensor for batch_entity_range 263 | ranking = ranking.scatter_( 264 | 1, 265 | argsort, 266 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 267 | # torch.arange(model.module.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 268 | ) 269 | 270 | for idx, (i, query, query_structure_idx) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structure_idxs)): 271 | # convert query from np.ndarray to nested tuple 272 | query_key = tuple(query) 273 | query_structure = query_structure_list[query_structure_idx] 274 | 275 | hard_answer = hard_answers[query_key] 276 | easy_answer = easy_answers[query_key] 277 | num_hard = len(hard_answer) 278 | num_easy = len(easy_answer) 279 | assert len(hard_answer.intersection(easy_answer)) == 0 280 | cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)] 281 | cur_ranking, indices = torch.sort(cur_ranking) 282 | 283 | masks = indices >= num_easy 284 | if args.cuda: 285 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 286 | else: 287 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 288 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 289 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 290 | 291 | if verbose: 292 | print(answer_list) 293 | print('ranking', cur_ranking) 294 | 295 | mrr = torch.mean(1./cur_ranking).item() 296 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 297 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 298 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 299 | 300 | logs[query_structure].append({ 301 | 'MRR': mrr, 302 | 'HITS1': h1, 303 | 'HITS3': h3, 304 | 'HITS10': h10, 305 | 'num_hard_answer': num_hard, 306 | }) 307 | 308 | if step % args.test_log_steps == 0: 309 | print('Evaluating the model... (%d/%d)' % (step, total_steps)) 310 | 311 | step += 1 312 | 313 | metrics = collections.defaultdict(lambda: collections.defaultdict(int)) 314 | for query_structure in logs: 315 | for metric in logs[query_structure][0].keys(): 316 | if metric in ['num_hard_answer']: 317 | continue 318 | metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure]) 319 | metrics[query_structure]['num_queries'] = len(logs[query_structure]) 320 | 321 | return metrics 322 | 323 | 324 | def wandb_log_metrics(metrics, args): 325 | run = wandb_initialize(vars(args)) 326 | 327 | average_metrics = defaultdict(float) 328 | average_pos_metrics = defaultdict(float) 329 | average_neg_metrics = defaultdict(float) 330 | all_metrics = defaultdict(float) 331 | 332 | num_query_structures = 0 333 | num_pos_query_structures = 0 334 | num_neg_query_structures = 0 335 | 336 | num_queries = 0 337 | mode="Test" 338 | step=0 339 | for query_structure in metrics: 340 | log_metrics(mode + " " + query_name_dict[query_structure], step, metrics[query_structure]) 341 | for metric in metrics[query_structure]: 342 | query_name = query_name_dict[query_structure] # e.g. 1p 343 | all_metrics["_".join([query_name, metric])] = metrics[query_structure][metric] 344 | if metric != 'num_queries': 345 | average_metrics[metric] += metrics[query_structure][metric] 346 | if 'n' in query_name: 347 | average_neg_metrics[metric] += metrics[query_structure][metric] 348 | else: 349 | average_pos_metrics[metric] += metrics[query_structure][metric] 350 | num_queries += metrics[query_structure]['num_queries'] 351 | num_query_structures += 1 352 | if 'n' in query_name: 353 | num_neg_query_structures += 1 354 | else: 355 | num_pos_query_structures += 1 356 | 357 | for metric in average_pos_metrics: 358 | average_pos_metrics[metric] /= num_pos_query_structures 359 | # writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step) 360 | all_metrics["_".join(["average_pos", metric])] = average_pos_metrics[metric] 361 | 362 | for metric in average_neg_metrics: 363 | average_neg_metrics[metric] /= num_neg_query_structures 364 | # writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step) 365 | all_metrics["_".join(["average_neg", metric])] = average_neg_metrics[metric] 366 | 367 | for metric in average_metrics: 368 | average_metrics[metric] /= num_query_structures 369 | all_metrics["_".join(["average", metric])] = average_metrics[metric] 370 | log_metrics('%s average' % mode, step, average_metrics) 371 | log_metrics('%s average_pos' % mode, step, average_pos_metrics) 372 | log_metrics('%s average_neg' % mode, step, average_neg_metrics) 373 | 374 | log_metrics('%s average' % mode, step, all_metrics) 375 | 376 | def get_query_structure_by_type_name(type_name): 377 | """1p -> (e, (r,))""" 378 | for qstructure, qname in query_name_dict.items(): # query_name_dict: imported from constants 379 | if qname == type_name: 380 | return qstructure 381 | 382 | 383 | def get_type_idx_by_name(type_name): 384 | """type_name: e.g. 1p""" 385 | 386 | for qstructure, qname in query_name_dict.items(): # query_name_dict: imported from constants 387 | if qname == type_name: 388 | type_structure = qstructure 389 | 390 | return query_structure2idx[type_structure] # query_structure2idx: imported from constants 391 | 392 | def get_queries_by_type_id(all_queries, type_idx): 393 | """all queries: list[tuple(nested_query, type_idx)]""" 394 | keep = [q for q in all_queries if q[1] == type_idx] 395 | return keep 396 | 397 | 398 | def get_queries_by_type_name(all_queries, type_name): 399 | """all queries: list[tuple(nested_query, type_idx)]""" 400 | type_idx = get_type_idx_by_name(type_name) 401 | keep = get_queries_by_type_id(all_queries, type_idx) 402 | return keep 403 | 404 | 405 | 406 | def get_sub_dataloader(full_test_dataset, type_name, args): 407 | full_queries = full_test_dataset.queries 408 | sub_queries = get_queries_by_type_name(full_queries, type_name) 409 | print('number of queries:', len(sub_queries)) 410 | sub_test_dataset = TestDataset( 411 | sub_queries, 412 | full_test_dataset.nentity, 413 | full_test_dataset.nrelation 414 | ) 415 | sub_test_dataloader = sub_test_dataloader = DataLoader( 416 | sub_test_dataset, 417 | batch_size=args.test_batch_size, 418 | num_workers=args.cpu_num, 419 | collate_fn=TestDataset.collate_fn 420 | ) 421 | return sub_queries, sub_test_dataloader 422 | 423 | def test_on_a_type(model, full_test_dataset, args, test_easy_answers, test_hard_answers, type_name): 424 | sub_queries, sub_test_dataloader = get_sub_dataloader(full_test_dataset, type_name, args) 425 | metrics = test_step(model, test_easy_answers, test_hard_answers, args, sub_test_dataloader, query_name_dict) 426 | return sub_queries, metrics 427 | 428 | 429 | def get_dataloader_for_one_query(full_test_dataset, args, q): 430 | oneq = [q] 431 | sub_test_dataset = TestDataset( 432 | oneq, 433 | full_test_dataset.nentity, 434 | full_test_dataset.nrelation 435 | ) 436 | sub_test_dataloader = DataLoader( 437 | sub_test_dataset, 438 | batch_size=args.test_batch_size, 439 | num_workers=args.cpu_num, 440 | collate_fn=TestDataset.collate_fn 441 | ) 442 | return sub_test_dataloader 443 | 444 | def test_on_one_query(full_test_dataset, args, q): 445 | oneq = [q] 446 | sub_test_dataset = TestDataset( 447 | oneq, 448 | full_test_dataset.nentity, 449 | full_test_dataset.nrelation 450 | ) 451 | sub_test_dataloader = DataLoader( 452 | sub_test_dataset, 453 | batch_size=args.test_batch_size, 454 | num_workers=args.cpu_num, 455 | collate_fn=TestDataset.collate_fn 456 | ) 457 | metrics = test_step(model, test_easy_answers, test_hard_answers, args, sub_test_dataloader, query_name_dict, verbose=True) 458 | return metrics 459 | 460 | 461 | def prepare_new_attributes(model, margin_type=None): 462 | """Prepare value for the new attributes in the model""" 463 | 464 | def set_dual_for_regularizer(regularizer): 465 | if not hasattr(regularizer, 'dual'): 466 | regularizer.dual = False 467 | 468 | def set_dual_as_false(model): 469 | model.entity_regularizer.dual = False 470 | model.projection_net.dual = False 471 | model.projection_net.regularizer.dual = False 472 | 473 | set_dual_as_false(model) 474 | model.counter_for_neg = False 475 | 476 | if not hasattr(model, 'margin_type'): 477 | if margin_type is not None: 478 | model.margin_type = margin_type 479 | else: 480 | print('Model margin type not recorded. Set as "logsigmoid" by default. Please double check if it\'s consistent with the model.') 481 | model.margin_type = 'logsigmoid' 482 | 483 | if not hasattr(model, 'no_anchor_reg'): 484 | model.no_anchor_reg = False 485 | 486 | if not hasattr(model, 'simplE'): 487 | model.simplE = False 488 | 489 | set_dual_for_regularizer(model.entity_regularizer) 490 | set_dual_for_regularizer(model.projection_net.regularizer) 491 | set_dual_for_regularizer(model.conjunction_net.regularizer) 492 | set_dual_for_regularizer(model.disjunction_net.regularizer) 493 | 494 | 495 | 496 | ################## Translate ############## 497 | def translate(ent_id2str, rel_id2str, x, structure): 498 | def translate_one(x, structure): 499 | """Get into nested list and translate one by one""" 500 | if isinstance(x, int): 501 | if x >= 0 and structure in ('e', 'r'): 502 | if structure == 'e': 503 | if x in ent_id2str: 504 | return ent_id2str[x] 505 | else: 506 | return 'missed_x' 507 | elif structure == 'r': 508 | return rel_id2str[x] 509 | else: 510 | return structure 511 | else: 512 | return type(x)(map(translate_one, x, structure)) 513 | 514 | return translate_one(x, structure) 515 | 516 | 517 | def translate_queries(ent_id2str, rel_id2str, qs): 518 | results = [] 519 | for q in qs: 520 | query, structure_idx = q[0], q[1] 521 | res = translate(ent_id2str, rel_id2str, query, query_structure_list[structure_idx]) 522 | results.append(res) 523 | return results 524 | 525 | def translate_answers(ent_id2str, ents, with_id=False): 526 | labels = [] 527 | for e in ents: 528 | e0 = int(e) # incase of torch tensor 529 | if e0 in ent_id2str: 530 | if not with_id: 531 | labels.append(ent_id2str[e0]) 532 | else: 533 | labels.append((ent_id2str[e0], e0)) 534 | else: 535 | labels.append(e0) 536 | return labels 537 | 538 | def translate_queries_with_hard_answers(ent_id2str, rel_id2str, test_hard_answers, qs): 539 | query_results = [] 540 | answer_results = [] 541 | results = [] 542 | for q in qs: 543 | query, structure_idx = q[0], q[1] 544 | q_res = translate(ent_id2str, rel_id2str, query, query_structure_list[structure_idx]) 545 | 546 | answers_in_id = test_hard_answers[query] 547 | a_res = translate_answers(ent_id2str, answers_in_id) 548 | 549 | query_results.append(q_res) 550 | answer_results.append(a_res) 551 | results.append((q_res, a_res)) 552 | return results 553 | 554 | 555 | 556 | 557 | ############# Find top ranked answers 558 | def rank_all_entities_for_one_query(model, full_test_dataset, args, q): 559 | """ 560 | q: tuple(nested_query, query_structure_idx), e.g. (((967, (35,)), (8734, (351, -2))), 7) 561 | """ 562 | # easy_answers, hard_answers, args, test_dataloader, query_name_dict 563 | model.eval() 564 | device = model.device 565 | test_dataloader = get_dataloader_for_one_query(full_test_dataset, args, q) 566 | 567 | logs = collections.defaultdict(list) 568 | with torch.no_grad(): 569 | for negative_sample, queries, queries_unflatten, query_structure_idxs in test_dataloader: 570 | # example: query_structures: [('e', ('r',))]. queries: [[1804,4]]. queries_unflatten: [(1804, (4,)] 571 | negative_sample = negative_sample.to(device) 572 | batch_size = len(negative_sample) 573 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 574 | _, negative_logit, _, idxs = model( 575 | None, 576 | negative_sample, 577 | None, 578 | queries, # np.array([queries]), won't be split when using multiple GPUs 579 | query_structure_idxs, 580 | slice_idxs # to help track batch_queries and query_structures when using multiple GPUs 581 | ) 582 | idxs_np = idxs.detach().cpu().numpy() 583 | query_structure_idxs = query_structure_idxs[idxs_np] 584 | queries_unflatten = [queries_unflatten[i] for i in idxs] 585 | 586 | argsort = torch.argsort(negative_logit, dim=1, descending=True) 587 | ranking = argsort.clone().to(torch.float) 588 | 589 | # rank all entities 590 | # If it is the same shape with test_batch_size, reuse batch_entity_range without creating a new one 591 | if len(argsort) == args.test_batch_size: 592 | # ranking = ranking.scatter_(1, argsort, model.module.batch_entity_range) # achieve the ranking of all entities 593 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 594 | else: # otherwise, create a new torch Tensor for batch_entity_range 595 | ranking = ranking.scatter_( 596 | 1, 597 | argsort, 598 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 599 | # torch.arange(model.module.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 600 | ) 601 | 602 | return ranking 603 | 604 | 605 | 606 | def find_top_ranked_entities(model, full_test_dataset, args, q, topk=10): 607 | ranking = rank_all_entities_for_one_query(model, full_test_dataset, args, q) 608 | top_entities = (ranking <= topk).nonzero() # not sorted among top k 609 | return ranking, top_entities 610 | 611 | def translate_top_k(ent_id2str, topk, ranking, easy_answer, hard_answer): 612 | res = translate_answers(ent_id2str, topk[:, 1], with_id=True) 613 | df = pd.DataFrame(res, columns=['entity', 'id']) 614 | df['rank'] = df['id'].apply(lambda x: int(ranking[0,x])) 615 | df['easy_answer'] = df['id'].apply(lambda x: x in easy_answer) 616 | df['hard_answer'] = df['id'].apply(lambda x: x in hard_answer) 617 | return df.sort_values(by='rank') 618 | 619 | 620 | def translate_entity_and_ranking(ent_id2str, eids, ranking, with_id=False): 621 | res = [] 622 | for e, r in zip(eids, ranking): 623 | e0 = int(e) # incase of torch tensor 624 | if e0 in ent_id2str: 625 | if not with_id: 626 | res.append((ent_id2str[e0], int(r))) 627 | else: 628 | res.append((ent_id2str[e0], e0, int(r))) 629 | else: 630 | res.append(e0) 631 | return res 632 | 633 | def get_metrics(cur_ranking): 634 | mrr = torch.mean(1./cur_ranking).item() 635 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 636 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 637 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 638 | 639 | return mrr, h1, h3, h10 640 | 641 | 642 | 643 | ################ Find ranking for hard answers 644 | def find_ranking_for_hard_answers(test_easy_answers, test_hard_answers, args, q, ranking): 645 | """ 646 | q: tuple(nested_query, query_structure_idx), e.g. (((967, (35,)), (8734, (351, -2))), 7) 647 | """ 648 | query_key = q[0] 649 | easy_answer = test_easy_answers[query_key] 650 | hard_answer = test_hard_answers[query_key] 651 | num_easy = len(easy_answer) 652 | num_hard = len(hard_answer) 653 | 654 | answers = torch.Tensor(list(easy_answer) + list(hard_answer)).type(torch.LongTensor) 655 | cur_ranking = ranking[0, answers] # ranking of easy and hard answers 656 | 657 | unfiltered_ranking = cur_ranking.clone() 658 | 659 | cur_ranking, indices = torch.sort(cur_ranking) 660 | 661 | 662 | masks = (indices >= num_easy) 663 | if args.cuda: 664 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 665 | else: 666 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 667 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 668 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 669 | 670 | return easy_answer, hard_answer, answers, indices, cur_ranking, unfiltered_ranking 671 | 672 | def answer_ranking_df(ent_id2str, entities, ranking, easy_answer, hard_answer): 673 | res = translate_entity_and_ranking(ent_id2str, entities, ranking, with_id=True) 674 | df = pd.DataFrame(res, columns=['entity', 'id', 'ranking']) 675 | df['easy_answer'] = df['id'].apply(lambda x: x in easy_answer) 676 | df['hard_answer'] = df['id'].apply(lambda x: x in hard_answer) 677 | return df 678 | 679 | 680 | 681 | ####################### Check training loss ################################# 682 | def train_step_full(model, train_iterator, args, step): 683 | """ 684 | Adapted for multiple GPUs 685 | """ 686 | # device = model.module.device 687 | device = model.device 688 | model.eval() 689 | 690 | with torch.no_grad(): 691 | 692 | positive_sample, negative_sample, subsampling_weight, batch_queries, query_structure_idxs = next(train_iterator) 693 | 694 | if args.cuda: 695 | positive_sample = positive_sample.to(device) 696 | negative_sample = negative_sample.to(device) 697 | subsampling_weight = subsampling_weight.to(device) 698 | # no need to move query_structure_idxs to GPU 699 | 700 | # nn.DataParallel helper 701 | batch_size = len(positive_sample) 702 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 703 | 704 | positive_score, negative_score, subsampling_weight, _ = model( 705 | positive_sample, 706 | negative_sample, 707 | subsampling_weight, 708 | batch_queries, # np.array([queries]), won't be split when using multiple GPUs 709 | query_structure_idxs, # torch.LongTensor 710 | slice_idxs # to help track batch_queries and query_structures when using multiple GPUs 711 | ) 712 | if args.margin_type == 'logsigmoid': 713 | # the loss of BetaE and RotatE 714 | positive_dist = 1-positive_score 715 | negative_dist = 1-negative_score 716 | positive_unweighted_loss = -F.logsigmoid((model.gamma - positive_dist)*model.gamma_coff).squeeze(dim=1) 717 | negative_unweighted_loss = -F.logsigmoid((negative_dist - model.gamma)*model.gamma_coff).mean(dim=1) 718 | positive_sample_loss = (subsampling_weight * positive_unweighted_loss).sum() 719 | negative_sample_loss = (subsampling_weight * negative_unweighted_loss).sum() 720 | positive_sample_loss /= subsampling_weight.sum() 721 | negative_sample_loss /= subsampling_weight.sum() 722 | loss = (positive_sample_loss + negative_sample_loss) / 2 723 | log = { 724 | 'positive_sample_loss': positive_sample_loss.item(), 725 | 'negative_sample_loss': negative_sample_loss.item(), 726 | 'loss': loss.item(), 727 | } 728 | elif args.margin_type == 'logsigmoid_avg': 729 | # use with cos_digits 730 | positive_dist = 1-positive_score 731 | negative_dist = 1-negative_score 732 | positive_unweighted_loss = -torch.mean(F.logsigmoid((model.gamma - positive_dist)*model.gamma_coff), dim=-1).squeeze(dim=1) 733 | negative_unweighted_loss = -torch.mean(F.logsigmoid((negative_dist - model.gamma)*model.gamma_coff), dim=-1).mean(dim=1) 734 | positive_sample_loss = (subsampling_weight * positive_unweighted_loss).sum() 735 | negative_sample_loss = (subsampling_weight * negative_unweighted_loss).sum() 736 | positive_sample_loss /= subsampling_weight.sum() 737 | negative_sample_loss /= subsampling_weight.sum() 738 | loss = (positive_sample_loss + negative_sample_loss) / 2 739 | log = { 740 | 'positive_sample_loss': positive_sample_loss.item(), 741 | 'negative_sample_loss': negative_sample_loss.item(), 742 | 'loss': loss.item(), 743 | } 744 | elif args.margin_type == 'logsigmoid_bpr': 745 | # gamma as margin 746 | diff = -F.logsigmoid(model.gamma_coff*(positive_score - negative_score)) 747 | # diff = torch.mean(-F.logsigmoid(model.gamma_coff*(positive_score - negative_score)), dim=-1) 748 | unweighted_sample_loss = torch.mean(diff, dim=-1) 749 | loss = (subsampling_weight * unweighted_sample_loss).sum() 750 | loss /= subsampling_weight.sum() 751 | log = { 752 | 'loss': loss.item(), 753 | } 754 | elif args.margin_type == 'logsigmoid_bpr_digits': 755 | # positive_score: shape [batch_size, 1, dim] (not aggregated yet) 756 | # negative_score: shape [batch_size, neg_per_pos, dim] 757 | diff = -F.logsigmoid(model.gamma_coff*(torch.mean(positive_score - negative_score, dim=-1))) 758 | # diff = torch.mean(-F.logsigmoid(model.gamma_coff*(positive_score - negative_score)), dim=-1) 759 | unweighted_sample_loss = torch.mean(diff, dim=-1) 760 | loss = (subsampling_weight * unweighted_sample_loss).sum() 761 | loss /= subsampling_weight.sum() 762 | log = { 763 | 'loss': loss.item(), 764 | } 765 | elif args.margin_type == 'softmax': 766 | # positive_score shape [batch_size, 1] 767 | #TODO: multi positives, same negative sample for the batch 768 | criterion = nn.CrossEntropyLoss(reduction='none') # keep loss for each sample 769 | scores = torch.cat([positive_score, negative_score], dim=1)*model.softmax_weight # [batch_size, 1+negative_sample_size] 770 | target = torch.zeros((positive_score.shape[0],), dtype=torch.long).to(device) 771 | loss = (criterion(scores, target) * subsampling_weight).sum() 772 | loss /= subsampling_weight.sum() 773 | log = {'loss': loss.item()} 774 | elif args.margin_type == 'bpr': 775 | # gamma as margin 776 | diff = torch.relu(model.gamma + negative_score -positive_score) # relu or softplus 777 | unweighted_sample_loss = torch.mean(diff, dim=-1) 778 | loss = (subsampling_weight * unweighted_sample_loss).sum() 779 | loss /= subsampling_weight.sum() 780 | log = { 781 | 'loss': loss.item(), 782 | } 783 | 784 | return positive_score, negative_score, log 785 | 786 | def train_step(model, train_iterator, args, step): 787 | """ 788 | Adapted for multiple GPUs 789 | """ 790 | # device = model.module.device 791 | device = model.device 792 | model.eval() 793 | 794 | with torch.no_grad(): 795 | 796 | positive_sample, negative_sample, subsampling_weight, batch_queries, query_structure_idxs = next(train_iterator) 797 | 798 | if args.cuda: 799 | positive_sample = positive_sample.to(device) 800 | negative_sample = negative_sample.to(device) 801 | subsampling_weight = subsampling_weight.to(device) 802 | # no need to move query_structure_idxs to GPU 803 | 804 | # nn.DataParallel helper 805 | batch_size = len(positive_sample) 806 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 807 | 808 | positive_score, negative_score, subsampling_weight, _ = model( 809 | positive_sample, 810 | negative_sample, 811 | subsampling_weight, 812 | batch_queries, # np.array([queries]), won't be split when using multiple GPUs 813 | query_structure_idxs, # torch.LongTensor 814 | slice_idxs # to help track batch_queries and query_structures when using multiple GPUs 815 | ) 816 | if args.margin_type == 'logsigmoid': 817 | # the loss of BetaE and RotatE 818 | positive_dist = 1-positive_score 819 | negative_dist = 1-negative_score 820 | positive_unweighted_loss = -F.logsigmoid((model.gamma - positive_dist)*model.gamma_coff).squeeze(dim=1) 821 | negative_unweighted_loss = -F.logsigmoid((negative_dist - model.gamma)*model.gamma_coff).mean(dim=1) 822 | positive_sample_loss = (subsampling_weight * positive_unweighted_loss).sum() 823 | negative_sample_loss = (subsampling_weight * negative_unweighted_loss).sum() 824 | positive_sample_loss /= subsampling_weight.sum() 825 | negative_sample_loss /= subsampling_weight.sum() 826 | loss = (positive_sample_loss + negative_sample_loss) / 2 827 | log = { 828 | 'positive_sample_loss': positive_sample_loss.item(), 829 | 'negative_sample_loss': negative_sample_loss.item(), 830 | 'loss': loss.item(), 831 | } 832 | return positive_score, negative_score, log 833 | 834 | def get_query_structure_by_type_name(type_name): 835 | """1p -> (e, (r,))""" 836 | for qstructure, qname in query_name_dict.items(): # query_name_dict: imported from constants 837 | if qname == type_name: 838 | return qstructure 839 | 840 | def get_sub_train_iterator(full_train_queries, full_train_answers, type_name, args): 841 | sub_queries = defaultdict(set) 842 | structure = get_query_structure_by_type_name(type_name) 843 | sub_queries[structure] = full_train_queries[structure] # take queries of a type 844 | sub_queries = flatten_query_and_convert_structure_to_idx(sub_queries, query_structure2idx) 845 | n_queries = len(sub_queries) 846 | # print('type name', type_name) 847 | # print('query number', len(sub_queries)) 848 | if n_queries > 0: 849 | if type_name in ('1p', '2p', '3p'): 850 | train_iterator = SingledirectionalOneShotIterator(DataLoader( 851 | TrainDataset(sub_queries, args.nentity, args.nrelation, args.negative_sample_size, full_train_answers), 852 | batch_size=args.batch_size, 853 | shuffle=True, 854 | num_workers=args.cpu_num, 855 | collate_fn=TrainDataset.collate_fn 856 | )) 857 | else: 858 | train_iterator = SingledirectionalOneShotIterator(DataLoader( 859 | TrainDataset(sub_queries, args.nentity, args.nrelation, args.negative_sample_size, full_train_answers), 860 | batch_size=args.batch_size, 861 | shuffle=True, 862 | num_workers=args.cpu_num, 863 | collate_fn=TrainDataset.collate_fn 864 | )) 865 | return train_iterator 866 | else: 867 | return None 868 | 869 | def compute_loss_by_type(model, full_train_queries, full_train_answers, type_name, args): 870 | sub_train_iterator = get_sub_train_iterator(full_train_queries, full_train_answers, type_name=type_name, args=args) 871 | if sub_train_iterator is not None: # None if the type is the present in training queries 872 | positive_score, negative_score, log = train_step(model, sub_train_iterator, args, step=0) 873 | return positive_score, negative_score, log 874 | else: 875 | return None, None, None 876 | 877 | def loss_of_all_types(model, full_train_queries, full_train_answers, args): 878 | types = [s for s in query_name_dict.values() if 'DM' not in s] 879 | logs = {} # {type: log} 880 | for type_name in types: # e.g. '1p' 881 | positive_score, negative_score, log = compute_loss_by_type(model, full_train_queries, full_train_answers, type_name, args) 882 | logs[type_name] = log 883 | log_df = pd.DataFrame(logs) 884 | return log_df 885 | 886 | 887 | ############# entity, relation embedding ############# 888 | def get_rel(model, rid): 889 | """Get relation embedding""" 890 | therid = [rid] 891 | projection = model.projection_net # Projection() object 892 | r_trans = torch.einsum('br,rio->bio', projection.rel_att[therid], projection.rel_base) 893 | r_bias = projection.rel_bias[therid] 894 | return r_trans, r_bias 895 | 896 | def get_ent(model, eid, return_raw=False): 897 | """Get entity embedding""" 898 | raw = model.entity_embedding[eid] 899 | constrained = model.entity_regularizer(raw) 900 | if return_raw: 901 | return raw, constrained 902 | return constrained 903 | 904 | 905 | def slice_rel(model, rid): 906 | with torch.no_grad(): 907 | """slice the first 10 element from the relation embeddings and check""" 908 | r_trans, r_bias = get_rel(model, rid) 909 | return r_trans[0][0][:10], r_bias[0][:10] 910 | 911 | def slice_ent(model, eid): 912 | with torch.no_grad(): 913 | e_emb = get_ent(model, eid) 914 | return e_emb[:10] 915 | 916 | def slice_vec(vec): 917 | if len(vec) > 1: 918 | return vec[:10] 919 | elif len(vec[0]) > 1: 920 | return vec[0][:10] 921 | 922 | 923 | def loss_from_triple(model, h,r,t): 924 | """ 925 | h,r: query; t: answer 926 | """ 927 | projected_t = get_projected_t_vec(model, h=h, rs=[r]) 928 | true_t = get_ent(model, eid=t) 929 | score = model.cal_logit_fuzzy(true_t, projected_t) 930 | loss = loss_from_score(score, is_positive=True) 931 | return loss 932 | 933 | 934 | def loss_from_score(model, score, is_positive=True): 935 | """Unweighted loss. Default type: logsigmoid""" 936 | dist = 1-score 937 | if is_positive: 938 | loss = -F.logsigmoid((model.gamma - dist)*model.gamma_coff) 939 | else: 940 | loss = -F.logsigmoid((dist - model.gamma)*model.gamma_coff) 941 | return loss 942 | 943 | def get_train_queries_answers(data_dir): 944 | train_queries = pickle.load(open(os.path.join(data_dir, "train-queries.pkl"), 'rb')) 945 | train_answers = pickle.load(open(os.path.join(data_dir, "train-answers.pkl"), 'rb')) 946 | return train_queries, train_answers 947 | 948 | 949 | def get_translated_sample_score_df(model, test_queries, test_hard_answers, type_name='pi'): 950 | sub_train_iterator = get_sub_train_iterator(test_queries, test_hard_answers, type_name=type_name, args=args) 951 | positive_score, negative_score, log, batch_queries, positive_sample, negative_sample = train_step1(model, sub_train_iterator, args, step=0, return_queries=True) 952 | if type_name == 'pi': 953 | flattened_structure = ('e', 'r', 'r', 'e', 'r') 954 | 955 | translated_queries = translate_flattened_queries(batch_queries, flattened_structure) 956 | translated_answers = translate_answers(ent_id2str, positive_sample, with_id=False) 957 | query_df = pd.DataFrame(translated_queries, columns=['e','r','r','e','r']) 958 | query_df['answer'] = translated_answers 959 | query_df['score'] = list(positive_score.squeeze().cpu().numpy()) 960 | 961 | query_df_with_ids = query_df.copy() 962 | query_df_with_ids['query'] = list(batch_queries) 963 | query_df_with_ids['answer_id'] = list(positive_sample.cpu().numpy()) 964 | return query_df, query_df_with_ids 965 | 966 | 967 | def scoring_triple(model, h, r, t): 968 | """ 969 | h,r: query; t: answer 970 | """ 971 | projected_t = get_projected_t_vec(model, h=h, rs=[r]) 972 | true_t = get_ent(model, eid=t) 973 | score = model.cal_logit_fuzzy(true_t, projected_t) 974 | return score 975 | 976 | 977 | 978 | 979 | def get_all_rids(q, qstructure_idx): 980 | """ 981 | q: e.g. (((711, (139, 59)), (1295, (59,))) 982 | """ 983 | if qstructure_idx == 0: # 1p 984 | entities = [q[0]] 985 | relations = [q[1][0]] 986 | elif qstructure_idx == 1: # 2p 987 | entities = [q[0]] 988 | relations = [q[1][0], q[1][1]] 989 | elif qstructure_idx == 2: # 3p 990 | entities = [q[0]] 991 | relations = [q[1][0], q[1][1], q[1][2]] 992 | elif qstructure_idx == 3: # 2i 993 | entities = [q[0][0], q[1][0]] 994 | relations = [q[0][1][0], q[1][1][0]] 995 | elif qstructure_idx == 4: # 3i 996 | entities = [q[0][0], q[1][0], q[2][0]] 997 | relations = [q[0][1][0], q[1][1][0], q[2][1][0]] 998 | elif qstructure_idx == 5: # ip 999 | entities = [q[0][0][0], q[0][1][0]] 1000 | relations = [q[0][0][1][0], q[0][1][1][0], q[1][0]] 1001 | elif qstructure_idx == 6: # pi 1002 | entities = [q[0][0], q[1][0]] 1003 | relations = [q[0][1][0], q[0][1][1], q[1][1][0]] 1004 | else: 1005 | print(f'type {qstructure_idx} not defined relation set!') 1006 | return entities, relations 1007 | 1008 | def filter_queries(queries, badrelations, badentities): 1009 | """ 1010 | :param queries: list[tuple(q, q_structure_idx)] 1011 | """ 1012 | print(badrelations) 1013 | def q_is_valid(badrelations, q, qstructure_idx): 1014 | entities, relations = get_all_rids(q, qstructure_idx) 1015 | for r in relations: 1016 | if r in badrelations: 1017 | return False 1018 | for e in entities: 1019 | if e in badentities: 1020 | return False 1021 | return True 1022 | 1023 | ok_queries = [] 1024 | for qtuple in queries: 1025 | q, qstructure_idx = qtuple[0], qtuple[1] 1026 | if q_is_valid(badrelations, q, qstructure_idx): 1027 | ok_queries.append(qtuple) 1028 | return ok_queries 1029 | 1030 | 1031 | 1032 | def make_test_dataset_and_dataloader(queries, args): 1033 | sub_test_dataset = TestDataset( 1034 | queries, 1035 | args.nentity, 1036 | args.nrelation 1037 | ) 1038 | 1039 | sub_test_dataloader = DataLoader( 1040 | sub_test_dataset, 1041 | batch_size=args.test_batch_size, 1042 | num_workers=args.cpu_num, 1043 | collate_fn=TestDataset.collate_fn 1044 | ) 1045 | return sub_test_dataset, sub_test_dataloader 1046 | 1047 | 1048 | # obsolete 1049 | def add2p(): 1050 | q1p = train_queries[('e',('r',))] 1051 | print(len(q1p)) 1052 | new_2p_queries = set() 1053 | new_2p_ans = {} # {query: set(ans)} 1054 | for q in q1p: 1055 | h = q[0] 1056 | r = q[1][0] 1057 | if r % 2 == 0: # r is even 1058 | q2p = (h, (r, r+1)) 1059 | else: 1060 | q2p = (h, (r, r-1)) 1061 | ans = h 1062 | 1063 | new_2p_queries.add(q2p) 1064 | if q2p in new_2p_ans: 1065 | new_2p_ans[q2p].add(ans) 1066 | else: 1067 | new_2p_ans[q2p] = set([ans]) 1068 | train_queries[('e', ('r', 'r'))] = set.union(train_queries[('e', ('r', 'r'))], new_2p_queries) 1069 | train_answers.update(new_2p_ans) 1070 | with open(join(data_dir, 'train-queries-new.pkl'), 'wb') as f: 1071 | pickle.dump(train_queries, f) 1072 | with open(join(data_dir, 'train-answers-new.pkl'), 'wb') as f: 1073 | pickle.dump(train_answers, f) 1074 | 1075 | 1076 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/python3 3 | 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import argparse 9 | import json 10 | import logging 11 | import os 12 | import wandb 13 | import random 14 | import torch.nn as nn 15 | from os.path import join 16 | import math 17 | import numpy as np 18 | import torch 19 | from torch.utils.data import DataLoader 20 | from models import KGReasoning 21 | from fuzzyreasoning import KGFuzzyReasoning 22 | from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator 23 | # from tensorboardX import SummaryWriter 24 | import time 25 | import pickle 26 | from util import * 27 | from dataloader import load_data 28 | from constants import * 29 | 30 | 31 | def parse_args(args=None): 32 | parser = argparse.ArgumentParser( 33 | description='Training and Testing Knowledge Graph Embedding Models', 34 | usage='train.py [] [-h | --help]' 35 | ) 36 | 37 | parser.add_argument('--cuda', action='store_true', help='use GPU') 38 | 39 | parser.add_argument('--do_train', action='store_true', help="do train") 40 | parser.add_argument('--do_valid', action='store_true', help="do valid") 41 | parser.add_argument('--do_test', action='store_true', help="do test") 42 | 43 | parser.add_argument('--data_path', type=str, default=None, help="KG data path") 44 | parser.add_argument('-n', '--negative_sample_size', default=128, type=int, help="negative entities sampled per query") 45 | parser.add_argument('-d', '--hidden_dim', default=500, type=int, help="embedding dimension") 46 | parser.add_argument('-g', '--gamma', default=0.5, type=float, help="margin in the loss") 47 | parser.add_argument('-b', '--batch_size', default=1024, type=int, help="batch size of queries") 48 | parser.add_argument('--test_batch_size', default=1, type=int, help='valid/test batch size') 49 | parser.add_argument('-lr', '--learning_rate', default=0.0001, type=float) 50 | parser.add_argument('-cpu', '--cpu_num', default=10, type=int, help="used to speed up torch.dataloader") 51 | parser.add_argument('-save', '--save_path', default='./trained_models', type=str, help="no need to set manually, will configure automatically") 52 | parser.add_argument('--max_steps', default=100000, type=int, help="maximum iterations to train") 53 | parser.add_argument('--warm_up_steps', default=None, type=int, help="no need to set manually, will configure automatically") 54 | parser.add_argument('--valid_steps', default=10000, type=int, help="evaluate validation queries every xx steps") 55 | parser.add_argument('--log_steps', default=100, type=int, help='train log every xx steps') 56 | parser.add_argument('--test_log_steps', default=1000, type=int, help='valid/test log every xx steps') 57 | 58 | parser.add_argument('--nentity', type=int, default=0, help='DO NOT MANUALLY SET') 59 | parser.add_argument('--nrelation', type=int, default=0, help='DO NOT MANUALLY SET') 60 | 61 | parser.add_argument('--geo', default='fuzzy', type=str, choices=['vec', 'box', 'beta', 'fuzzy'], help='the reasoning model, vec for GQE, box for Query2box, beta for BetaE') 62 | parser.add_argument('--print_on_screen', action='store_true') 63 | 64 | parser.add_argument('--tasks', default='1p.2p.3p.2i.3i.ip.pi.2in.3in.inp.pin.pni.2u.up', type=str, help="tasks connected by dot, refer to the BetaE paper for detailed meaning and structure of each task") 65 | parser.add_argument('--seed', default=0, type=int, help="random seed") 66 | parser.add_argument('-betam', '--beta_mode', default="(1600,2)", type=str, help='(hidden_dim,num_layer) for BetaE relational projection') 67 | parser.add_argument('-boxm', '--box_mode', default="(none,0.02)", type=str, help='(offset activation,center_reg) for Query2box, center_reg balances the in_box dist and out_box dist') 68 | parser.add_argument('--prefix', default=None, type=str, help='prefix of the log path') 69 | parser.add_argument('--checkpoint_path', default=None, type=str, help='path for loading the checkpoints') 70 | parser.add_argument('-evu', '--evaluate_union', default="DNF", type=str, choices=['DNF', 'DM'], help='the way to evaluate union queries, transform it to disjunctive normal form (DNF) or use the De Morgan\'s laws (DM)') 71 | 72 | # fuzzy logic 73 | parser.add_argument('--logic', default='godel', type=str, choices=['luka', 'godel', 'product', 'godel_gumbel'], 74 | help='fuzzy logic type') 75 | 76 | 77 | # regularizer 78 | parser.add_argument('--regularizer', default='sigmoid', type=str, 79 | choices=['01', 'vector_softmax', 'matrix_softmax', 'matrix_L1', 'matrix_sigmoid_L1','sigmoid', 'vector_sigmoid_L1'], 80 | help='ways to regularize parameters') # By default, this regularizer applies to both entities and queries 81 | 82 | parser.add_argument('--e_regularizer', default='same', type=str, 83 | choices=['same', '01', 'vector_softmax', 'matrix_softmax', 'matrix_L1', 'matrix_sigmoid_L1','sigmoid', 'vector_sigmoid_L1'], 84 | help='set regularizer for entities, different from queries') # if 'same' (default), just use args.regularizer 85 | parser.add_argument('--entity_ln_before_reg', action="store_true", help='apply layer normalization before applying regularizer to entities') 86 | 87 | 88 | parser.add_argument('--gamma_coff', default=20, type=float, help='coefficient for gamma') 89 | parser.add_argument('-k', '--prob_dim', default=8, type=int, help="for matrix_softmax and matrix_L1. dims per prob vector") 90 | parser.add_argument('--godel_gumbel_beta', default=0.01, type=int, help="Gumbel beta for min/max computation when logic=godel_gumbel") 91 | parser.add_argument('--loss_type', default='cos', type=str, 92 | choices=['cos', 93 | 'cos_digits', 'L1_cos_digits', 'dot_layernorm_digits', 94 | 'dot', 'weighted_dot', 95 | 'soft_min_digits', 96 | 'kl', 'entropy', 97 | 'discrete_cos', 'discrete_prob', 'discrete_gumbel', 'gumbel_softmax', 98 | 'fuzzy_containment', 'weighted_fuzzy_containment', 99 | 'entity_multinomial_dot', # use with sigmoid regularizer. L1 noramlize entity before computing score 100 | 'normalized_entity_dot' # normalize entity when no grad. use with 0/1 regularizer for entity, and sigmoid for query 101 | ], help="loss type") 102 | parser.add_argument( 103 | '--margin_type', default='logsigmoid_bpr', type=str, 104 | choices=[ 105 | 'logsigmoid', 'logsigmoid_bpr', 'logsigmoid_bpr_digits', 'bpr_digits', 'logsigmoid_avg', 'bpr', 'softmax', 'nll' 106 | ], 107 | help='ways to implement margin' 108 | ) 109 | parser.add_argument( 110 | '--with_counter', action="store_true", help="add neg q into negative samples" 111 | ) 112 | parser.add_argument('--gpu_ids', default='0', type=str) 113 | parser.add_argument('--continue_train', default=None, type=str, help='run name to load and continue training') 114 | 115 | # gumbel softmax 116 | parser.add_argument('--gumbel_temperature', default=1, type=float, 117 | help="Gumbel temperature for gumbel softmax") 118 | parser.add_argument('--gumbel_attention', default='none', type=str, choices=['none', 'plain', 'query_dependent'], help="Add distribution-wise attention") 119 | parser.add_argument('--query_unnorm', action="store_true") 120 | parser.add_argument('--simplE', action="store_true", help="Use different head and tail embeddings for entities") 121 | 122 | # conjunction 123 | parser.add_argument('--use_attention', action='store_true', help='use attention for conjunction') 124 | 125 | # relation as a transformation 126 | parser.add_argument('--projection_type', default='rtransform', type=str, choices=['mlp', 'rtransform', 'transe']) 127 | parser.add_argument('--num_rel_base', default=50, type=int) 128 | 129 | # lr scheduler 130 | # original is BetaE original 131 | parser.add_argument('--lr_scheduler', default='annealing', type=str, choices=['none', 'original', 'step', 'annealing', 'plateau', 'onecycle']) 132 | parser.add_argument('--optimizer', default='Adam', type=str, choices=['Adam', 'AdamW']) 133 | parser.add_argument('--L2_reg', default=0, type=float) 134 | parser.add_argument('--N3_regularization', action='store_true', help='nuclear 3-norm regularization. L2_reg as coefficient. not using weight decay.') 135 | 136 | parser.add_argument('--in_batch_negative', action='store_true', help='use in-batch negatives') 137 | 138 | parser.add_argument('--load_pretrained', action='store_true', help='load pretrained embeddings. dimension=1000. only for NELL') 139 | 140 | parser.add_argument('--no_anchor_reg', action='store_true', help='no anchor entity regularizer') 141 | parser.add_argument('--share_relation_bias', action='store_true', help='share relation bias') 142 | 143 | 144 | return parser.parse_args(args) 145 | 146 | 147 | def main(args): 148 | run = wandb_initialize(vars(args)) 149 | run.save() 150 | print(f'Wandb run name: {run.name}') 151 | 152 | # cuda settings 153 | if args.cuda: 154 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 155 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_ids # specify which GPU(s) to be used 156 | args.batch_size = args.batch_size * torch.cuda.device_count() # adjust batch size 157 | print(f'Cuda device count:{torch.cuda.device_count()}') 158 | wandb.log({'num_gpu': torch.cuda.device_count()}) 159 | device = torch.device('cuda' if args.cuda else 'cpu') 160 | 161 | set_global_seed(args.seed) 162 | tasks = args.tasks.split('.') 163 | for task in tasks: 164 | if 'n' in task and args.geo in ['box', 'vec']: 165 | assert False, "Q2B and GQE cannot handle queries with negation" 166 | if args.evaluate_union == 'DM': 167 | assert args.geo == 'beta', "only BetaE supports modeling union using De Morgan's Laws" 168 | 169 | # Model save path 170 | args.save_path = join('./trained_models') 171 | print(f'Overwrite model save path. Save model and log to folder {args.save_path}') 172 | if not os.path.exists(args.save_path): 173 | os.makedirs(args.save_path) 174 | 175 | # logger 176 | set_logger(args) 177 | 178 | nentity, nrelation = read_num_entity_relation_from_file(args.data_path) 179 | args.nentity, args.nrelation = nentity, nrelation 180 | wandb.log({'nentity': nentity, 'nrelation': nrelation}) 181 | 182 | train_path_iterator, train_other_iterator, valid_dataloader, test_dataloader,\ 183 | valid_hard_answers, valid_easy_answers, \ 184 | test_hard_answers, test_easy_answers = load_data(args, query_name_dict, tasks) 185 | 186 | if len(tasks) == 1: # 1p only 187 | # load full test data for testing 188 | full_tasks = '1p.2p.2i.2in'.split('.') 189 | _, _, _, test_dataloader,\ 190 | _, _, \ 191 | test_hard_answers, test_easy_answers = load_data(args, query_name_dict, full_tasks) 192 | 193 | 194 | # Fuzzy only. This repo does not support other geo like BetaE, box anymore 195 | if args.geo == 'fuzzy': 196 | model = KGFuzzyReasoning( 197 | nentity=nentity, 198 | nrelation=nrelation, 199 | hidden_dim=args.hidden_dim, 200 | gamma=args.gamma, 201 | geo=args.geo, 202 | use_cuda=args.cuda, 203 | box_mode=eval_tuple(args.box_mode), 204 | beta_mode=eval_tuple(args.beta_mode), 205 | test_batch_size=args.test_batch_size, 206 | query_name_dict=query_name_dict, 207 | logic_type=args.logic, 208 | gamma_coff=args.gamma_coff, 209 | regularizer_setting={ 210 | 'type': args.regularizer, # for query 211 | 'e_reg_type': args.regularizer if args.e_regularizer == 'same' else args.e_regularizer, 212 | 'prob_dim': args.prob_dim, # for matrix softmax 213 | 'dual': True if args.loss_type == 'weighted_fuzzy_containment' else False, 214 | 'e_layernorm': args.entity_ln_before_reg # apply Layer Norm before next step's regularizer 215 | }, 216 | loss_type=args.loss_type, 217 | margin_type=args.margin_type, 218 | device=device, 219 | godel_gumbel_beta=args.godel_gumbel_beta, 220 | gumbel_temperature=args.gumbel_temperature, 221 | projection_type=args.projection_type, 222 | args=args 223 | ) 224 | else: 225 | model = KGReasoning( 226 | nentity=nentity, 227 | nrelation=nrelation, 228 | hidden_dim=args.hidden_dim, 229 | gamma=args.gamma, 230 | geo=args.geo, 231 | use_cuda=args.cuda, 232 | box_mode=eval_tuple(args.box_mode), 233 | beta_mode=eval_tuple(args.beta_mode), 234 | test_batch_size=args.test_batch_size, 235 | query_name_dict=query_name_dict 236 | ) 237 | model = model.to(device) 238 | # model = nn.DataParallel(model) # make it parallel 239 | wandb.watch(model) 240 | print_parameters(model) 241 | 242 | # set lr and optimizer 243 | if args.do_train: 244 | current_learning_rate = args.learning_rate 245 | if args.optimizer == 'AdamW': # use together with lr_scheduler none 246 | if args.L2_reg > 0: 247 | weight_decay = args.L2_reg 248 | else: 249 | weight_decay = 1e-2 250 | print(f'AdamW weight decay: {weight_decay}') 251 | optimizer = torch.optim.AdamW( 252 | filter(lambda p: p.requires_grad, list(model.parameters())), 253 | lr=args.learning_rate, 254 | eps=1e-06, 255 | weight_decay=weight_decay 256 | ) 257 | else: 258 | optimizer = torch.optim.Adam( 259 | filter(lambda p: p.requires_grad, list(model.parameters())), 260 | lr=current_learning_rate, 261 | weight_decay=args.L2_reg # L2 regularization 262 | ) 263 | 264 | if args.lr_scheduler == 'original': 265 | warm_up_steps = args.max_steps // 2 # reduce lr when reaching warm up steps 266 | elif args.lr_scheduler == 'step': 267 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50000, gamma=0.2) 268 | elif args.lr_scheduler == 'annealing': 269 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.max_steps, eta_min=0, last_epoch=-1) 270 | elif args.lr_scheduler == 'plateau': 271 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 272 | optimizer, mode='min', factor=0.5, patience=args.valid_steps*2, 273 | verbose=False, threshold=0.0001, threshold_mode='rel', cooldown=0, 274 | min_lr=0.0001, eps=1e-07 275 | ) 276 | elif args.lr_scheduler == 'onecycle': 277 | scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, pct_start=0.05, anneal_strategy='linear', final_div_factor=10,\ 278 | max_lr = 5e-4, total_steps = args.batch_size * args.max_steps + 1) 279 | 280 | if args.continue_train is not None: 281 | saved_run_name = args.continue_train 282 | model_path = join(args.save_path, saved_run_name+'.pt') 283 | model = torch.load(model_path) 284 | 285 | 286 | init_step = 0 287 | step = init_step 288 | 289 | if args.do_train: 290 | print('=== Start training ===') 291 | time0 = time.time() 292 | 293 | training_logs = [] 294 | # #Training Loop 295 | 296 | last_best_metric = None 297 | last_best_step = 0 298 | early_stop_metric = 'average_MRR' 299 | patience = args.valid_steps * 5 300 | 301 | for step in range(init_step, args.max_steps): 302 | if step == 2*args.max_steps//3: 303 | args.valid_steps *= 4 304 | 305 | # if args.loss_type == 'gumbel_softmax' and step % 5000 == 0: 306 | # # gumbel temperature annealing; temperature = max(0.5, exp(-rt)), r={1e-5, 1e-4}, t=step 307 | # model.gumbel_temperature = max(0.5, math.exp(-1e-4*step)) 308 | 309 | # log = model.module.train_step(model, optimizer, train_path_iterator, args, step) 310 | log = model.train_step(model, optimizer, train_path_iterator, args, step) 311 | if train_other_iterator is not None: 312 | log = model.train_step(model, optimizer, train_other_iterator, args, step) 313 | # log = model.train_step(model, optimizer, train_path_iterator, args, step) 314 | 315 | training_logs.append(log) 316 | 317 | # update learning rate 318 | if args.lr_scheduler != 'none': # do not change lr if 'none' 319 | if args.lr_scheduler == 'original': # BetaE original 320 | if step >= warm_up_steps: 321 | current_learning_rate = current_learning_rate / 5 322 | warm_up_steps = warm_up_steps * 1.5 323 | optimizer = torch.optim.Adam( 324 | filter(lambda p: p.requires_grad, model.parameters()), 325 | lr=current_learning_rate 326 | ) # new optimizer 327 | 328 | elif args.lr_scheduler in ('step', 'annealing', 'plateau', 'onecycle'): 329 | if args.lr_scheduler == 'plateau': 330 | scheduler.step(log['loss']) 331 | else: 332 | scheduler.step() 333 | 334 | 335 | if step % args.valid_steps == 0 and step > 0: 336 | if args.do_valid: 337 | print('Evaluating on Valid Dataset...') 338 | valid_all_metrics = evaluate(model, valid_easy_answers, valid_hard_answers, args, valid_dataloader, query_name_dict, 'Valid', step) 339 | 340 | if args.do_test: 341 | print('Evaluating on Test Dataset...') 342 | time1 = time.time() 343 | test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step) 344 | print(test_all_metrics) 345 | print(f'Finished testing. Testing used time {time.time()-time1:.2f}') 346 | 347 | # if last_best_metric is None: 348 | # last_best_metric = test_all_metrics.copy() 349 | 350 | # early stop 351 | #TODO: change it to valid_all_metrics 352 | if last_best_metric is None or test_all_metrics[early_stop_metric] > last_best_metric[early_stop_metric]: 353 | last_best_metric = test_all_metrics.copy() 354 | last_best_step = step 355 | # save 356 | if args.geo == 'fuzzy': 357 | save_path = os.path.join(args.save_path, f'{run.name}.pt') 358 | torch.save(model, save_path) 359 | else: # baseline models. can only save model.state_dict 360 | save_dir = join(args.save_path, 'baselines', args.geo, str(run.name)) 361 | if not os.path.exists(save_dir): 362 | os.makedirs(save_dir) 363 | save_variable_list = { 364 | 'step': step, 365 | 'current_learning_rate': current_learning_rate, 366 | 'warm_up_steps': warm_up_steps 367 | } 368 | save_model(model, optimizer, save_variable_list, save_dir, args) 369 | elif step > last_best_step + patience: 370 | # early stop 371 | break 372 | 373 | 374 | if step % args.log_steps == 0: 375 | metrics = {} 376 | for metric in training_logs[0].keys(): 377 | metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs) 378 | 379 | log_metrics('Training average', step, metrics) 380 | training_logs = [] 381 | 382 | print(f'Time to train {args.log_steps} step: {time.time() - time0:.2f}') 383 | 384 | # # debug parameter change 385 | # if args.projection_type == 'mlp': 386 | # wandb.log({ 387 | # 'projection_layer00': model.projection_net.layer0.weight[0,0] 388 | # }) 389 | # if args.regularizer == 'sigmoid': 390 | # wandb.log({ 391 | # 'conjunction_regularizer': model.conjunction_net.regularizer.weight[0] 392 | # }) 393 | 394 | cur_lr = optimizer.param_groups[0]['lr'] 395 | # print('Change learning_rate to %f at step %d' % (current_learning_rate, step)) 396 | wandb.log({'current_lr': cur_lr}) 397 | 398 | 399 | 400 | 401 | 402 | try: 403 | print(step) 404 | except: 405 | step = 0 406 | 407 | if args.do_test: 408 | logging.info('Evaluating on Test Dataset...') 409 | test_all_metrics = evaluate(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict, 'Test', step) 410 | 411 | logging.info("Training finished!!") 412 | 413 | if __name__ == '__main__': 414 | main(parse_args()) 415 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import logging 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader 13 | from dataloader import TestDataset, TrainDataset, SingledirectionalOneShotIterator 14 | import random 15 | import pickle 16 | import math 17 | import collections 18 | import itertools 19 | import time 20 | from tqdm import tqdm 21 | import os 22 | from regularizers import Regularizer 23 | from constants import query_structure_list 24 | 25 | def Identity(x): 26 | return x 27 | 28 | class BoxOffsetIntersection(nn.Module): 29 | 30 | def __init__(self, dim): 31 | super(BoxOffsetIntersection, self).__init__() 32 | self.dim = dim 33 | self.layer1 = nn.Linear(self.dim, self.dim) 34 | self.layer2 = nn.Linear(self.dim, self.dim) 35 | 36 | nn.init.xavier_uniform_(self.layer1.weight) 37 | nn.init.xavier_uniform_(self.layer2.weight) 38 | 39 | def forward(self, embeddings): 40 | layer1_act = F.relu(self.layer1(embeddings)) 41 | layer1_mean = torch.mean(layer1_act, dim=0) 42 | gate = torch.sigmoid(self.layer2(layer1_mean)) 43 | offset, _ = torch.min(embeddings, dim=0) 44 | 45 | return offset * gate 46 | 47 | class CenterIntersection(nn.Module): 48 | 49 | def __init__(self, dim): 50 | super(CenterIntersection, self).__init__() 51 | self.dim = dim 52 | self.layer1 = nn.Linear(self.dim, self.dim) 53 | self.layer2 = nn.Linear(self.dim, self.dim) 54 | 55 | nn.init.xavier_uniform_(self.layer1.weight) 56 | nn.init.xavier_uniform_(self.layer2.weight) 57 | 58 | def forward(self, embeddings): 59 | layer1_act = F.relu(self.layer1(embeddings)) # (num_conj, dim) 60 | attention = F.softmax(self.layer2(layer1_act), dim=0) # (num_conj, dim) 61 | embedding = torch.sum(attention * embeddings, dim=0) 62 | 63 | return embedding 64 | 65 | class BetaIntersection(nn.Module): 66 | 67 | def __init__(self, dim): 68 | super(BetaIntersection, self).__init__() 69 | self.dim = dim 70 | self.layer1 = nn.Linear(2 * self.dim, 2 * self.dim) 71 | self.layer2 = nn.Linear(2 * self.dim, self.dim) 72 | 73 | nn.init.xavier_uniform_(self.layer1.weight) 74 | nn.init.xavier_uniform_(self.layer2.weight) 75 | 76 | def forward(self, alpha_embeddings, beta_embeddings): 77 | all_embeddings = torch.cat([alpha_embeddings, beta_embeddings], dim=-1) 78 | layer1_act = F.relu(self.layer1(all_embeddings)) # (num_conj, batch_size, 2 * dim) 79 | attention = F.softmax(self.layer2(layer1_act), dim=0) # (num_conj, batch_size, dim) 80 | 81 | alpha_embedding = torch.sum(attention * alpha_embeddings, dim=0) 82 | beta_embedding = torch.sum(attention * beta_embeddings, dim=0) 83 | 84 | return alpha_embedding, beta_embedding 85 | 86 | class BetaProjection(nn.Module): 87 | def __init__(self, entity_dim, relation_dim, hidden_dim, projection_regularizer, num_layers): 88 | super(BetaProjection, self).__init__() 89 | self.entity_dim = entity_dim 90 | self.relation_dim = relation_dim 91 | self.hidden_dim = hidden_dim 92 | self.num_layers = num_layers 93 | self.layer1 = nn.Linear(self.entity_dim + self.relation_dim, self.hidden_dim) # 1st layer 94 | self.layer0 = nn.Linear(self.hidden_dim, self.entity_dim) # final layer 95 | for nl in range(2, num_layers + 1): 96 | setattr(self, "layer{}".format(nl), nn.Linear(self.hidden_dim, self.hidden_dim)) 97 | for nl in range(num_layers + 1): 98 | nn.init.xavier_uniform_(getattr(self, "layer{}".format(nl)).weight) 99 | self.projection_regularizer = projection_regularizer 100 | 101 | def forward(self, e_embedding, r_embedding): 102 | x = torch.cat([e_embedding, r_embedding], dim=-1) 103 | for nl in range(1, self.num_layers + 1): 104 | x = F.relu(getattr(self, "layer{}".format(nl))(x)) 105 | x = self.layer0(x) 106 | x = self.projection_regularizer(x) 107 | 108 | return x 109 | 110 | 111 | class KGReasoning(nn.Module): 112 | def __init__(self, nentity, nrelation, hidden_dim, gamma, 113 | geo, test_batch_size=1, 114 | box_mode=None, use_cuda=False, 115 | query_name_dict=None, beta_mode=None): 116 | super(KGReasoning, self).__init__() 117 | self.nentity = nentity 118 | self.nrelation = nrelation 119 | self.hidden_dim = hidden_dim 120 | self.epsilon = 2.0 121 | self.geo = geo 122 | self.use_cuda = use_cuda 123 | self.batch_entity_range = torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1).cuda() if self.use_cuda else torch.arange(nentity).to(torch.float).repeat(test_batch_size, 1) # used in test_step 124 | self.query_name_dict = query_name_dict 125 | 126 | self.gamma = nn.Parameter( 127 | torch.Tensor([gamma]), 128 | requires_grad=False 129 | ) 130 | 131 | self.embedding_range = nn.Parameter( 132 | torch.Tensor([(self.gamma.item() + self.epsilon) / hidden_dim]), 133 | requires_grad=False 134 | ) 135 | 136 | self.entity_dim = hidden_dim 137 | self.relation_dim = hidden_dim 138 | 139 | if self.geo == 'box': 140 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim)) # centor for entities 141 | activation, cen = box_mode 142 | self.cen = cen # hyperparameter that balances the in-box distance and the out-box distance 143 | if activation == 'none': 144 | self.func = Identity 145 | elif activation == 'relu': 146 | self.func = F.relu 147 | elif activation == 'softplus': 148 | self.func = F.softplus 149 | elif self.geo == 'vec': 150 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim)) # center for entities 151 | elif self.geo == 'beta': 152 | self.entity_embedding = nn.Parameter(torch.zeros(nentity, self.entity_dim * 2)) # alpha and beta 153 | self.entity_regularizer = Regularizer(1, 0.05, 1e9) # make sure the parameters of beta embeddings are positive 154 | self.projection_regularizer = Regularizer(1, 0.05, 1e9) # make sure the parameters of beta embeddings after relation projection are positive 155 | 156 | if self.geo != 'fuzzy': 157 | nn.init.uniform_( 158 | tensor=self.entity_embedding, 159 | a=-self.embedding_range.item(), 160 | b=self.embedding_range.item() 161 | ) 162 | 163 | self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.relation_dim)) 164 | nn.init.uniform_( 165 | tensor=self.relation_embedding, 166 | a=-self.embedding_range.item(), 167 | b=self.embedding_range.item() 168 | ) 169 | 170 | if self.geo == 'box': 171 | self.offset_embedding = nn.Parameter(torch.zeros(nrelation, self.entity_dim)) 172 | nn.init.uniform_( 173 | tensor=self.offset_embedding, 174 | a=0., 175 | b=self.embedding_range.item() 176 | ) 177 | self.center_net = CenterIntersection(self.entity_dim) 178 | self.offset_net = BoxOffsetIntersection(self.entity_dim) 179 | elif self.geo == 'vec': 180 | self.center_net = CenterIntersection(self.entity_dim) 181 | elif self.geo == 'beta': 182 | hidden_dim, num_layers = beta_mode 183 | self.center_net = BetaIntersection(self.entity_dim) 184 | self.projection_net = BetaProjection(self.entity_dim * 2, 185 | self.relation_dim, 186 | hidden_dim, 187 | self.projection_regularizer, 188 | num_layers) 189 | 190 | def forward(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict, idxs=None): 191 | if self.geo == 'box': 192 | return self.forward_box(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 193 | elif self.geo == 'vec': 194 | return self.forward_vec(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 195 | elif self.geo == 'beta': 196 | return self.forward_beta(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 197 | 198 | def embed_query_box(self, queries, query_structure, idx): 199 | ''' 200 | Iterative embed a batch of queries with same structure using Query2box 201 | queries: a flattened batch of queries 202 | ''' 203 | # query_structure = query_structure_list[query_structure_idx] 204 | all_relation_flag = True 205 | for ele in query_structure[-1]: 206 | # whether the current query tree has merged to one branch 207 | # and only need to do relation traversal, 208 | # e.g., path queries or conjunctive queries after the intersection 209 | if ele not in ['r', 'n']: 210 | all_relation_flag = False 211 | break 212 | if all_relation_flag: 213 | if query_structure[0] == 'e': 214 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 215 | if self.use_cuda: 216 | offset_embedding = torch.zeros_like(embedding).cuda() 217 | else: 218 | offset_embedding = torch.zeros_like(embedding) 219 | idx += 1 220 | else: 221 | embedding, offset_embedding, idx = self.embed_query_box(queries, query_structure[0], idx) 222 | for i in range(len(query_structure[-1])): 223 | if query_structure[-1][i] == 'n': 224 | assert False, "box cannot handle queries with negation" 225 | else: 226 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 227 | r_offset_embedding = torch.index_select(self.offset_embedding, dim=0, index=queries[:, idx]) 228 | embedding += r_embedding 229 | offset_embedding += self.func(r_offset_embedding) 230 | idx += 1 231 | else: 232 | embedding_list = [] 233 | offset_embedding_list = [] 234 | for i in range(len(query_structure)): 235 | embedding, offset_embedding, idx = self.embed_query_box(queries, query_structure[i], idx) 236 | embedding_list.append(embedding) 237 | offset_embedding_list.append(offset_embedding) 238 | embedding = self.center_net(torch.stack(embedding_list)) 239 | offset_embedding = self.offset_net(torch.stack(offset_embedding_list)) 240 | 241 | return embedding, offset_embedding, idx 242 | 243 | def embed_query_vec(self, queries, query_structure, idx): 244 | ''' 245 | Iterative embed a batch of queries with same structure using GQE 246 | queries: a flattened batch of queries 247 | ''' 248 | # query_structure = query_structure_list[query_structure_idx] 249 | all_relation_flag = True 250 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 251 | if ele not in ['r', 'n']: 252 | all_relation_flag = False 253 | break 254 | if all_relation_flag: 255 | if query_structure[0] == 'e': 256 | embedding = torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx]) 257 | idx += 1 258 | else: 259 | embedding, idx = self.embed_query_vec(queries, query_structure[0], idx) 260 | for i in range(len(query_structure[-1])): 261 | if query_structure[-1][i] == 'n': 262 | assert False, "vec cannot handle queries with negation" 263 | else: 264 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 265 | embedding += r_embedding 266 | idx += 1 267 | else: 268 | embedding_list = [] 269 | for i in range(len(query_structure)): 270 | embedding, idx = self.embed_query_vec(queries, query_structure[i], idx) 271 | embedding_list.append(embedding) 272 | embedding = self.center_net(torch.stack(embedding_list)) 273 | 274 | return embedding, idx 275 | 276 | def embed_query_beta(self, queries, query_structure, idx): 277 | ''' 278 | Iterative embed a batch of queries with same structure using BetaE 279 | queries: a flattened batch of queries 280 | ''' 281 | all_relation_flag = True 282 | for ele in query_structure[-1]: # whether the current query tree has merged to one branch and only need to do relation traversal, e.g., path queries or conjunctive queries after the intersection 283 | if ele not in ['r', 'n']: 284 | all_relation_flag = False 285 | break 286 | if all_relation_flag: 287 | if query_structure[0] == 'e': 288 | embedding = self.entity_regularizer(torch.index_select(self.entity_embedding, dim=0, index=queries[:, idx])) 289 | idx += 1 290 | else: 291 | alpha_embedding, beta_embedding, idx = self.embed_query_beta(queries, query_structure[0], idx) 292 | embedding = torch.cat([alpha_embedding, beta_embedding], dim=-1) 293 | for i in range(len(query_structure[-1])): 294 | if query_structure[-1][i] == 'n': 295 | assert (queries[:, idx] == -2).all() 296 | embedding = 1./embedding 297 | else: 298 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=queries[:, idx]) 299 | embedding = self.projection_net(embedding, r_embedding) 300 | idx += 1 301 | alpha_embedding, beta_embedding = torch.chunk(embedding, 2, dim=-1) 302 | else: 303 | alpha_embedding_list = [] 304 | beta_embedding_list = [] 305 | for i in range(len(query_structure)): 306 | alpha_embedding, beta_embedding, idx = self.embed_query_beta(queries, query_structure[i], idx) 307 | alpha_embedding_list.append(alpha_embedding) 308 | beta_embedding_list.append(beta_embedding) 309 | alpha_embedding, beta_embedding = self.center_net(torch.stack(alpha_embedding_list), torch.stack(beta_embedding_list)) 310 | 311 | return alpha_embedding, beta_embedding, idx 312 | 313 | def cal_logit_beta(self, entity_embedding, query_dist): 314 | alpha_embedding, beta_embedding = torch.chunk(entity_embedding, 2, dim=-1) 315 | entity_dist = torch.distributions.beta.Beta(alpha_embedding, beta_embedding) 316 | logit = self.gamma - torch.norm(torch.distributions.kl.kl_divergence(entity_dist, query_dist), p=1, dim=-1) 317 | return logit 318 | 319 | def forward_beta(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 320 | all_idxs, all_alpha_embeddings, all_beta_embeddings = [], [], [] 321 | all_union_idxs, all_union_alpha_embeddings, all_union_beta_embeddings = [], [], [] 322 | for query_structure_idx in batch_queries_dict: 323 | if 'u' in self.query_name_dict[query_structure_list[query_structure_idx]] \ 324 | and 'DNF' in self.query_name_dict[query_structure_list[query_structure_idx]]: 325 | alpha_embedding, beta_embedding, _ = \ 326 | self.embed_query_beta(self.transform_union_query(batch_queries_dict[query_structure_idx], 327 | query_structure_idx), 328 | self.transform_union_structure(query_structure_idx), 329 | 0) 330 | all_union_idxs.extend(batch_idxs_dict[query_structure_idx]) 331 | all_union_alpha_embeddings.append(alpha_embedding) 332 | all_union_beta_embeddings.append(beta_embedding) 333 | else: 334 | alpha_embedding, beta_embedding, _ = self.embed_query_beta(batch_queries_dict[query_structure_idx], 335 | query_structure_list[query_structure_idx], 336 | 0) 337 | all_idxs.extend(batch_idxs_dict[query_structure_idx]) 338 | all_alpha_embeddings.append(alpha_embedding) 339 | all_beta_embeddings.append(beta_embedding) 340 | 341 | if len(all_alpha_embeddings) > 0: 342 | all_alpha_embeddings = torch.cat(all_alpha_embeddings, dim=0).unsqueeze(1) 343 | all_beta_embeddings = torch.cat(all_beta_embeddings, dim=0).unsqueeze(1) 344 | all_dists = torch.distributions.beta.Beta(all_alpha_embeddings, all_beta_embeddings) 345 | if len(all_union_alpha_embeddings) > 0: 346 | all_union_alpha_embeddings = torch.cat(all_union_alpha_embeddings, dim=0).unsqueeze(1) 347 | all_union_beta_embeddings = torch.cat(all_union_beta_embeddings, dim=0).unsqueeze(1) 348 | all_union_alpha_embeddings = all_union_alpha_embeddings.view(all_union_alpha_embeddings.shape[0]//2, 2, 1, -1) 349 | all_union_beta_embeddings = all_union_beta_embeddings.view(all_union_beta_embeddings.shape[0]//2, 2, 1, -1) 350 | all_union_dists = torch.distributions.beta.Beta(all_union_alpha_embeddings, all_union_beta_embeddings) 351 | 352 | if type(subsampling_weight) != type(None): 353 | subsampling_weight = subsampling_weight[all_idxs+all_union_idxs] 354 | 355 | if type(positive_sample) != type(None): 356 | if len(all_alpha_embeddings) > 0: 357 | positive_sample_regular = positive_sample[all_idxs] # positive samples for non-union queries in this batch 358 | positive_embedding = self.entity_regularizer(torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1)) 359 | positive_logit = self.cal_logit_beta(positive_embedding, all_dists) 360 | else: 361 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 362 | 363 | if len(all_union_alpha_embeddings) > 0: 364 | positive_sample_union = positive_sample[all_union_idxs] # positive samples for union queries in this batch 365 | positive_embedding = self.entity_regularizer(torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1)) 366 | positive_union_logit = self.cal_logit_beta(positive_embedding, all_union_dists) 367 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 368 | else: 369 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 370 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 371 | else: 372 | positive_logit = None 373 | 374 | if type(negative_sample) != type(None): 375 | if len(all_alpha_embeddings) > 0: 376 | negative_sample_regular = negative_sample[all_idxs] 377 | batch_size, negative_size = negative_sample_regular.shape 378 | negative_embedding = self.entity_regularizer(torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1)) 379 | negative_logit = self.cal_logit_beta(negative_embedding, all_dists) 380 | else: 381 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 382 | 383 | if len(all_union_alpha_embeddings) > 0: 384 | negative_sample_union = negative_sample[all_union_idxs] 385 | batch_size, negative_size = negative_sample_union.shape 386 | negative_embedding = self.entity_regularizer(torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1)) 387 | negative_union_logit = self.cal_logit_beta(negative_embedding, all_union_dists) 388 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 389 | else: 390 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 391 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 392 | else: 393 | negative_logit = None 394 | 395 | return positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs 396 | 397 | def transform_union_query(self, queries, query_structure_idx): 398 | ''' 399 | transform 2u queries to two 1p queries 400 | transform up queries to two 2p queries 401 | ''' 402 | if self.query_name_dict[query_structure_list[query_structure_idx]] == '2u-DNF': 403 | queries = queries[:, :-1] # remove union -1 404 | elif self.query_name_dict[query_structure_list[query_structure_idx]] == 'up-DNF': 405 | queries = torch.cat([torch.cat([queries[:, :2], queries[:, 5:6]], dim=1), torch.cat([queries[:, 2:4], queries[:, 5:6]], dim=1)], dim=1) 406 | queries = torch.reshape(queries, [queries.shape[0]*2, -1]) 407 | return queries 408 | 409 | def transform_union_structure(self, query_structure_idx): 410 | query_structure = query_structure_list[query_structure_idx] 411 | if self.query_name_dict[query_structure] == '2u-DNF': 412 | return ('e', ('r',)) 413 | elif self.query_name_dict[query_structure] == 'up-DNF': 414 | return ('e', ('r', 'r')) 415 | 416 | def cal_logit_box(self, entity_embedding, query_center_embedding, query_offset_embedding): 417 | delta = (entity_embedding - query_center_embedding).abs() 418 | distance_out = F.relu(delta - query_offset_embedding) 419 | distance_in = torch.min(delta, query_offset_embedding) 420 | logit = self.gamma - torch.norm(distance_out, p=1, dim=-1) - self.cen * torch.norm(distance_in, p=1, dim=-1) 421 | return logit 422 | 423 | def forward_box(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 424 | all_center_embeddings, all_offset_embeddings, all_idxs = [], [], [] 425 | all_union_center_embeddings, all_union_offset_embeddings, all_union_idxs = [], [], [] 426 | for query_structure_idx in batch_queries_dict: 427 | if 'u' in self.query_name_dict[query_structure_list[query_structure_idx]]: 428 | center_embedding, offset_embedding, _ = \ 429 | self.embed_query_box(self.transform_union_query(batch_queries_dict[query_structure_idx], 430 | query_structure_idx), 431 | self.transform_union_structure(query_structure_idx), 432 | 0) 433 | all_union_center_embeddings.append(center_embedding) 434 | all_union_offset_embeddings.append(offset_embedding) 435 | all_union_idxs.extend(batch_idxs_dict[query_structure_idx]) 436 | else: 437 | center_embedding, offset_embedding, _ = self.embed_query_box(batch_queries_dict[query_structure_idx], 438 | query_structure_list[query_structure_idx], 439 | 0) 440 | all_center_embeddings.append(center_embedding) 441 | all_offset_embeddings.append(offset_embedding) 442 | all_idxs.extend(batch_idxs_dict[query_structure_idx]) 443 | 444 | if len(all_center_embeddings) > 0 and len(all_offset_embeddings) > 0: 445 | all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1) 446 | all_offset_embeddings = torch.cat(all_offset_embeddings, dim=0).unsqueeze(1) 447 | if len(all_union_center_embeddings) > 0 and len(all_union_offset_embeddings) > 0: 448 | all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1) 449 | all_union_offset_embeddings = torch.cat(all_union_offset_embeddings, dim=0).unsqueeze(1) 450 | all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0]//2, 2, 1, -1) 451 | all_union_offset_embeddings = all_union_offset_embeddings.view(all_union_offset_embeddings.shape[0]//2, 2, 1, -1) 452 | 453 | if type(subsampling_weight) != type(None): 454 | subsampling_weight = subsampling_weight[all_idxs+all_union_idxs] 455 | 456 | if type(positive_sample) != type(None): 457 | if len(all_center_embeddings) > 0: 458 | positive_sample_regular = positive_sample[all_idxs] 459 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1) 460 | positive_logit = self.cal_logit_box(positive_embedding, all_center_embeddings, all_offset_embeddings) 461 | else: 462 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 463 | 464 | if len(all_union_center_embeddings) > 0: 465 | positive_sample_union = positive_sample[all_union_idxs] 466 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1) 467 | positive_union_logit = self.cal_logit_box(positive_embedding, all_union_center_embeddings, all_union_offset_embeddings) 468 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 469 | else: 470 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 471 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 472 | else: 473 | positive_logit = None 474 | 475 | if type(negative_sample) != type(None): 476 | if len(all_center_embeddings) > 0: 477 | negative_sample_regular = negative_sample[all_idxs] 478 | batch_size, negative_size = negative_sample_regular.shape 479 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1) 480 | negative_logit = self.cal_logit_box(negative_embedding, all_center_embeddings, all_offset_embeddings) 481 | else: 482 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 483 | 484 | if len(all_union_center_embeddings) > 0: 485 | negative_sample_union = negative_sample[all_union_idxs] 486 | batch_size, negative_size = negative_sample_union.shape 487 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1) 488 | negative_union_logit = self.cal_logit_box(negative_embedding, all_union_center_embeddings, all_union_offset_embeddings) 489 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 490 | else: 491 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 492 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 493 | else: 494 | negative_logit = None 495 | 496 | return positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs 497 | 498 | def cal_logit_vec(self, entity_embedding, query_embedding): 499 | distance = entity_embedding - query_embedding 500 | logit = self.gamma - torch.norm(distance, p=1, dim=-1) 501 | return logit 502 | 503 | def forward_vec(self, positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict): 504 | all_center_embeddings, all_idxs = [], [] 505 | all_union_center_embeddings, all_union_idxs = [], [] 506 | for query_structure_idx in batch_queries_dict: 507 | if 'u' in self.query_name_dict[query_structure_list[query_structure_idx]] and 'DNF' in self.query_name_dict[query_structure_list[query_structure_idx]]: 508 | center_embedding, _ = self.embed_query_vec(self.transform_union_query(batch_queries_dict[query_structure_idx], 509 | query_structure_idx), 510 | self.transform_union_structure(query_structure_idx), 0) 511 | all_union_center_embeddings.append(center_embedding) 512 | all_union_idxs.extend(batch_idxs_dict[query_structure_idx]) 513 | else: 514 | center_embedding, _ = self.embed_query_vec(batch_queries_dict[query_structure_idx], query_structure_list[query_structure_idx], 0) 515 | all_center_embeddings.append(center_embedding) 516 | all_idxs.extend(batch_idxs_dict[query_structure_idx]) 517 | 518 | if len(all_center_embeddings) > 0: 519 | all_center_embeddings = torch.cat(all_center_embeddings, dim=0).unsqueeze(1) 520 | if len(all_union_center_embeddings) > 0: 521 | all_union_center_embeddings = torch.cat(all_union_center_embeddings, dim=0).unsqueeze(1) 522 | all_union_center_embeddings = all_union_center_embeddings.view(all_union_center_embeddings.shape[0]//2, 2, 1, -1) 523 | 524 | if type(subsampling_weight) != type(None): 525 | subsampling_weight = subsampling_weight[all_idxs+all_union_idxs] 526 | 527 | if type(positive_sample) != type(None): 528 | if len(all_center_embeddings) > 0: 529 | positive_sample_regular = positive_sample[all_idxs] 530 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_regular).unsqueeze(1) 531 | positive_logit = self.cal_logit_vec(positive_embedding, all_center_embeddings) 532 | else: 533 | positive_logit = torch.Tensor([]).to(self.entity_embedding.device) 534 | 535 | if len(all_union_center_embeddings) > 0: 536 | positive_sample_union = positive_sample[all_union_idxs] 537 | positive_embedding = torch.index_select(self.entity_embedding, dim=0, index=positive_sample_union).unsqueeze(1).unsqueeze(1) 538 | positive_union_logit = self.cal_logit_vec(positive_embedding, all_union_center_embeddings) 539 | positive_union_logit = torch.max(positive_union_logit, dim=1)[0] 540 | else: 541 | positive_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 542 | positive_logit = torch.cat([positive_logit, positive_union_logit], dim=0) 543 | else: 544 | positive_logit = None 545 | 546 | if type(negative_sample) != type(None): 547 | if len(all_center_embeddings) > 0: 548 | negative_sample_regular = negative_sample[all_idxs] 549 | batch_size, negative_size = negative_sample_regular.shape 550 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_regular.view(-1)).view(batch_size, negative_size, -1) 551 | negative_logit = self.cal_logit_vec(negative_embedding, all_center_embeddings) 552 | else: 553 | negative_logit = torch.Tensor([]).to(self.entity_embedding.device) 554 | 555 | if len(all_union_center_embeddings) > 0: 556 | negative_sample_union = negative_sample[all_union_idxs] 557 | batch_size, negative_size = negative_sample_union.shape 558 | negative_embedding = torch.index_select(self.entity_embedding, dim=0, index=negative_sample_union.view(-1)).view(batch_size, 1, negative_size, -1) 559 | negative_union_logit = self.cal_logit_vec(negative_embedding, all_union_center_embeddings) 560 | negative_union_logit = torch.max(negative_union_logit, dim=1)[0] 561 | else: 562 | negative_union_logit = torch.Tensor([]).to(self.entity_embedding.device) 563 | negative_logit = torch.cat([negative_logit, negative_union_logit], dim=0) 564 | else: 565 | negative_logit = None 566 | 567 | return positive_logit, negative_logit, subsampling_weight, all_idxs+all_union_idxs 568 | 569 | @staticmethod 570 | def train_step(model, optimizer, train_iterator, args, step): 571 | model.train() 572 | optimizer.zero_grad() 573 | 574 | positive_sample, negative_sample, subsampling_weight, batch_queries, query_structures = next(train_iterator) 575 | batch_queries_dict = collections.defaultdict(list) 576 | batch_idxs_dict = collections.defaultdict(list) 577 | for i, query in enumerate(batch_queries): # group queries with same structure 578 | batch_queries_dict[query_structures[i]].append(query) 579 | batch_idxs_dict[query_structures[i]].append(i) 580 | for query_structure in batch_queries_dict: 581 | if args.cuda: 582 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]).cuda() 583 | else: 584 | batch_queries_dict[query_structure] = torch.LongTensor(batch_queries_dict[query_structure]) 585 | if args.cuda: 586 | positive_sample = positive_sample.cuda() 587 | negative_sample = negative_sample.cuda() 588 | subsampling_weight = subsampling_weight.cuda() 589 | 590 | positive_logit, negative_logit, subsampling_weight, _ = model(positive_sample, negative_sample, subsampling_weight, batch_queries_dict, batch_idxs_dict) 591 | negative_score = F.logsigmoid(-negative_logit).mean(dim=1) 592 | positive_score = F.logsigmoid(positive_logit).squeeze(dim=1) 593 | positive_sample_loss = - (subsampling_weight * positive_score).sum() 594 | negative_sample_loss = - (subsampling_weight * negative_score).sum() 595 | positive_sample_loss /= subsampling_weight.sum() 596 | negative_sample_loss /= subsampling_weight.sum() 597 | 598 | loss = (positive_sample_loss + negative_sample_loss)/2 599 | loss.backward() 600 | optimizer.step() 601 | log = { 602 | 'positive_sample_loss': positive_sample_loss.item(), 603 | 'negative_sample_loss': negative_sample_loss.item(), 604 | 'loss': loss.item(), 605 | } 606 | return log 607 | 608 | @staticmethod 609 | def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, save_result=False, save_str="", save_empty=False): 610 | model.eval() 611 | 612 | step = 0 613 | total_steps = len(test_dataloader) 614 | logs = collections.defaultdict(list) 615 | 616 | with torch.no_grad(): 617 | for negative_sample, queries, queries_unflatten, query_structure_idxs in tqdm(test_dataloader, disable=not args.print_on_screen): 618 | batch_queries_dict = collections.defaultdict(list) 619 | batch_idxs_dict = collections.defaultdict(list) 620 | for i, query in enumerate(queries): 621 | batch_queries_dict[query_structure_idxs[i]].append(query) 622 | batch_idxs_dict[query_structure_idxs[i]].append(i) 623 | for query_structure_idx in batch_queries_dict: 624 | if args.cuda: 625 | batch_queries_dict[query_structure_idx] = torch.LongTensor(batch_queries_dict[query_structure_idx]).cuda() 626 | else: 627 | batch_queries_dict[query_structure_idx] = torch.LongTensor(batch_queries_dict[query_structure_idx]) 628 | if args.cuda: 629 | negative_sample = negative_sample.cuda() 630 | 631 | _, negative_scores, _, idxs = model(None, negative_sample, None, batch_queries_dict, batch_idxs_dict) 632 | queries_unflatten = [queries_unflatten[i] for i in idxs] 633 | query_structure_idxs = [query_structure_idxs[i] for i in idxs] 634 | argsort = torch.argsort(negative_scores, dim=1, descending=True) 635 | ranking = argsort.clone().to(torch.float) 636 | if len(argsort) == args.test_batch_size: # if it is the same shape with test_batch_size, we can reuse batch_entity_range without creating a new one 637 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 638 | else: # otherwise, create a new torch Tensor for batch_entity_range 639 | if args.cuda: 640 | ranking = ranking.scatter_(1, 641 | argsort, 642 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 643 | 1).cuda() 644 | ) # achieve the ranking of all entities 645 | else: 646 | ranking = ranking.scatter_(1, 647 | argsort, 648 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 649 | 1) 650 | ) # achieve the ranking of all entities 651 | for idx, (i, query, query_structure_idx) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structure_idxs)): 652 | query_structure = query_structure_list[query_structure_idx] 653 | 654 | hard_answer = hard_answers[query] 655 | easy_answer = easy_answers[query] 656 | num_hard = len(hard_answer) 657 | num_easy = len(easy_answer) 658 | assert len(hard_answer.intersection(easy_answer)) == 0 659 | cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)] 660 | cur_ranking, indices = torch.sort(cur_ranking) 661 | masks = indices >= num_easy 662 | if args.cuda: 663 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 664 | else: 665 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 666 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 667 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 668 | 669 | mrr = torch.mean(1./cur_ranking).item() 670 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 671 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 672 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 673 | 674 | logs[query_structure].append({ 675 | 'MRR': mrr, 676 | 'HITS1': h1, 677 | 'HITS3': h3, 678 | 'HITS10': h10, 679 | 'num_hard_answer': num_hard, 680 | }) 681 | 682 | if step % args.test_log_steps == 0: 683 | logging.info('Evaluating the model... (%d/%d)' % (step, total_steps)) 684 | 685 | step += 1 686 | 687 | metrics = collections.defaultdict(lambda: collections.defaultdict(int)) 688 | for query_structure_idx in logs: 689 | for metric in logs[query_structure_idx][0].keys(): 690 | if metric in ['num_hard_answer']: 691 | continue 692 | metrics[query_structure_idx][metric] = sum([log[metric] for log in logs[query_structure_idx]])/len(logs[query_structure_idx]) 693 | metrics[query_structure_idx]['num_queries'] = len(logs[query_structure_idx]) 694 | 695 | return metrics -------------------------------------------------------------------------------- /operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from util import get_regularizer 5 | 6 | class Projection(nn.Module): 7 | # def __init__(self, entity_dim, logic_type, regularizer_setting): 8 | def __init__( 9 | self, 10 | nrelation, 11 | entity_dim, 12 | logic_type, 13 | regularizer_setting, 14 | relation_dim, 15 | projection_dim, 16 | num_layers, 17 | projection_type, 18 | num_rel_base, # for 'rtransform' 19 | ): 20 | super(Projection, self).__init__() 21 | self.logic = logic_type 22 | 23 | # # temporary testing 24 | # regularizer_setting = { 25 | # 'type': 'sigmoid', 26 | # } 27 | 28 | self.regularizer = get_regularizer(regularizer_setting, entity_dim, neg_input_possible=True) 29 | # for projection 30 | self.entity_dim = entity_dim 31 | self.relation_dim = relation_dim 32 | self.projection_type = projection_type 33 | 34 | self.dual = regularizer_setting['dual'] 35 | 36 | 37 | # mlp 38 | if projection_type == 'mlp': 39 | self.relation_embedding = nn.Parameter(torch.zeros(nrelation, self.entity_dim)) # same dim 40 | nn.init.uniform_(tensor=self.relation_embedding, a=0, b=1) 41 | 42 | # mlp 43 | self.hidden_dim = projection_dim 44 | self.num_layers = num_layers 45 | self.layer1 = nn.Linear(self.entity_dim + self.relation_dim, self.hidden_dim) # 1st layer 46 | self.layer0 = nn.Linear(self.hidden_dim, self.entity_dim) # final layer 47 | for nl in range(2, num_layers + 1): 48 | setattr(self, "layer{}".format(nl), nn.Linear(self.hidden_dim, self.hidden_dim)) 49 | for nl in range(num_layers + 1): 50 | nn.init.xavier_uniform_(getattr(self, "layer{}".format(nl)).weight) 51 | elif projection_type == 'rtransform': 52 | n_base = num_rel_base 53 | if not self.dual: 54 | self.hidden_dim = entity_dim 55 | self.rel_base = nn.Parameter(torch.zeros(n_base, self.hidden_dim, self.hidden_dim)) 56 | # nn.init.uniform_(self.rel_base, a=0, b=1e-2) 57 | self.rel_bias = nn.Parameter(torch.zeros(n_base, self.hidden_dim)) 58 | self.rel_att = nn.Parameter(torch.zeros(nrelation, n_base)) 59 | self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False) 60 | 61 | # new initialization 62 | torch.nn.init.orthogonal_(self.rel_base) 63 | torch.nn.init.xavier_normal_(self.rel_bias) 64 | torch.nn.init.xavier_normal_(self.rel_att) 65 | 66 | else: 67 | self.hidden_dim = entity_dim//2 68 | 69 | # for property vals 70 | self.rel_base1 = nn.Parameter(torch.randn(n_base, self.hidden_dim, self.hidden_dim)) 71 | nn.init.uniform_(self.rel_base1, a=0, b=1e-2) 72 | self.rel_bias1 = nn.Parameter(torch.zeros(nrelation, self.hidden_dim)) 73 | self.rel_att1 = nn.Parameter(torch.randn(nrelation, n_base)) 74 | self.norm1 = nn.LayerNorm(self.hidden_dim, elementwise_affine=False) 75 | 76 | # for property weights 77 | self.rel_base2 = nn.Parameter(torch.randn(n_base, self.hidden_dim, self.hidden_dim)) 78 | nn.init.uniform_(self.rel_base2, a=0, b=1e-2) 79 | self.rel_bias2 = nn.Parameter(torch.zeros(nrelation, self.hidden_dim)) 80 | self.rel_att2 = nn.Parameter(torch.randn(nrelation, n_base)) 81 | self.norm2 = nn.LayerNorm(self.hidden_dim, elementwise_affine=False) 82 | elif projection_type == 'transe': 83 | self.hidden_dim = entity_dim 84 | self.rel_trans = nn.Parameter(torch.zeros(nrelation, self.hidden_dim)) 85 | self.rel_bias = nn.Parameter(torch.zeros(nrelation, self.hidden_dim)) 86 | torch.nn.init.xavier_normal_(self.rel_trans) 87 | torch.nn.init.xavier_normal_(self.rel_bias) 88 | self.norm = nn.LayerNorm(self.hidden_dim, elementwise_affine=False) 89 | 90 | 91 | 92 | def forward(self, e_embedding, rid): 93 | if self.projection_type == 'mlp': 94 | r_embedding = torch.index_select(self.relation_embedding, dim=0, index=rid) 95 | x = torch.cat([e_embedding, r_embedding], dim=-1) 96 | for nl in range(1, self.num_layers + 1): 97 | x = F.relu(getattr(self, "layer{}".format(nl))(x)) 98 | x = self.layer0(x) 99 | x = self.regularizer(x) 100 | return x 101 | 102 | if self.projection_type == 'rtransform': 103 | if not self.dual: 104 | project_r = torch.einsum('br,rio->bio', self.rel_att[rid], self.rel_base) 105 | if self.rel_bias.shape[0] == self.rel_base.shape[0]: 106 | bias = torch.einsum('br,ri->bi', self.rel_att[rid], self.rel_bias) 107 | else: 108 | bias = self.rel_bias[rid] 109 | output = torch.einsum('bio,bi->bo', project_r, e_embedding) + bias 110 | output = self.norm(output) 111 | else: 112 | e_embedding1, e_embedding2 = torch.chunk(e_embedding, 2, dim=-1) 113 | project_r1 = torch.einsum('br,rio->bio', self.rel_att1[rid], self.rel_base1) 114 | bias1 = self.rel_bias1[rid] 115 | output1 = torch.einsum('bio,bi->bo', project_r1, e_embedding1) + bias1 116 | output1 = self.norm1(output1) 117 | 118 | project_r2 = torch.einsum('br,rio->bio', self.rel_att2[rid], self.rel_base2) 119 | bias2 = self.rel_bias2[rid] 120 | output2 = torch.einsum('bio,bi->bo', project_r2, e_embedding2) + bias2 121 | output2 = self.norm1(output2) 122 | 123 | output = torch.cat((output1, output2), dim=-1) 124 | 125 | output = self.regularizer(output) 126 | return output 127 | 128 | if self.projection_type == 'transe': 129 | r_trans = torch.index_select(self.rel_trans, dim=0, index=rid) 130 | r_bias = torch.index_select(self.rel_bias, dim=0, index=rid) 131 | output = e_embedding * r_trans + r_bias 132 | 133 | output = self.norm(output) 134 | output = self.regularizer(output) 135 | return output 136 | 137 | 138 | 139 | class Conjunction(nn.Module): 140 | def __init__(self, entity_dim, logic_type, regularizer_setting, use_attention='False', godel_gumbel_beta=0.01): 141 | super(Conjunction, self).__init__() 142 | self.logic = logic_type 143 | self.regularizer = get_regularizer(regularizer_setting, entity_dim) 144 | self.use_attention = use_attention 145 | self.entity_dim = entity_dim 146 | 147 | if logic_type == 'godel_gumbel': 148 | self.godel_gumbel_beta = godel_gumbel_beta 149 | if use_attention: 150 | self.conjunction_layer1 = nn.Linear(self.entity_dim, self.entity_dim) 151 | # self.conjunction_layer2 = nn.Linear(self.entity_dim, self.entity_dim) 152 | self.conjunction_layer2 = nn.Linear(self.entity_dim, 1) # no dimension-wise attention 153 | nn.init.xavier_uniform_(self.conjunction_layer1.weight) 154 | nn.init.xavier_uniform_(self.conjunction_layer2.weight) 155 | self.norm = nn.LayerNorm(entity_dim, elementwise_affine=False) 156 | 157 | def forward(self, embeddings): 158 | """ 159 | :param embeddings: shape (# of sets, batch, dim). 160 | :return embeddings: shape (batch, dim) 161 | """ 162 | if self.logic == 'godel': 163 | if self.logic == 'godel': 164 | # conjunction(x,y) = min{x,y} 165 | embeddings, _ = torch.min(embeddings, dim=0) 166 | elif self.logic == 'godel_gumbel': 167 | # soft way to compute min 168 | embeddings = -self.godel_gumbel_beta * torch.logsumexp( 169 | -embeddings / self.godel_gumbel_beta, 170 | 0 171 | ) 172 | return embeddings 173 | else: # logic == product 174 | if self.logic == 'luka': 175 | # conjunction(x,y) = max{0, x+y-1} 176 | embeddings = torch.sum(embeddings, dim=0) - embeddings.shape[0] + 1 177 | elif self.logic == 'product': 178 | if not self.use_attention: 179 | # conjunction(x,y) = xy 180 | embeddings = torch.prod(embeddings, dim=0) 181 | else: 182 | attention = self.get_conjunction_attention(embeddings) 183 | # attention conjunction(x,y) = (x^p)*(y^q), p+q=1 184 | # compute in log scale 185 | epsilon = 1e-7 # avoid torch.log(0) 186 | embeddings = torch.log(embeddings+epsilon) 187 | embeddings = torch.exp(torch.sum(embeddings * attention, dim=0)) 188 | embeddings = self.norm(embeddings) 189 | return self.regularizer(embeddings) 190 | 191 | def get_conjunction_attention(self, embeddings): 192 | layer1_act = F.relu(self.conjunction_layer1(embeddings)) # (num_conj, batch_size, 2 * dim) 193 | attention = F.softmax(self.conjunction_layer2(layer1_act)/torch.sqrt(self.entity_dim), dim=0) # (num_conj, batch_size, 1) 194 | return attention 195 | 196 | 197 | class Disjunction(nn.Module): 198 | def __init__(self, entity_dim, logic_type, regularizer_setting, godel_gumbel_beta=0.01): 199 | super(Disjunction, self).__init__() 200 | self.logic = logic_type 201 | self.regularizer = get_regularizer(regularizer_setting, entity_dim) 202 | 203 | if logic_type == 'godel_gumbel': 204 | self.godel_gumbel_beta = godel_gumbel_beta 205 | 206 | self.norm = nn.LayerNorm(entity_dim, elementwise_affine=False) 207 | 208 | 209 | def forward(self, embeddings): 210 | """ 211 | :param embeddings: shape (# of sets, batch, dim). 212 | :return embeddings: shape (batch, dim) 213 | """ 214 | if self.logic == 'godel': 215 | if self.logic == 'godel': 216 | # disjunction(x,y) = max{x,y} 217 | embeddings, _ = torch.max(embeddings, dim=0) 218 | return embeddings 219 | elif self.logic == 'godel_gumbel': 220 | # soft way to compute max 221 | embeddings = self.godel_gumbel_beta * torch.logsumexp( 222 | embeddings / self.godel_gumbel_beta, 223 | 0 224 | ) 225 | return embeddings 226 | else: 227 | if self.logic == 'luka': 228 | # disjunction(x,y) = min{1, x+y} 229 | embeddings = torch.sum(embeddings, dim=0) 230 | else: # self.logic == 'product' 231 | # disjunction(x,y) = x+y-xy 232 | embeddings = torch.sum(embeddings, dim=0) - torch.prod(embeddings, dim=0) 233 | return self.regularizer(embeddings) 234 | 235 | 236 | class Negation(nn.Module): 237 | def __init__(self, entity_dim, logic_type, regularizer_setting): 238 | super(Negation, self).__init__() 239 | self.logic = logic_type 240 | self.regularizer = get_regularizer(regularizer_setting, entity_dim) 241 | 242 | def forward(self, embeddings): 243 | """ 244 | :param embeddings: shape (# of sets, batch, dim). 245 | :return embeddings: shape (batch, dim) 246 | """ 247 | # negation(x) = 1-x 248 | return 1 - embeddings 249 | 250 | 251 | -------------------------------------------------------------------------------- /regularizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # BetaE 6 | class Regularizer(): 7 | def __init__(self, base_add, min_val, max_val): 8 | self.base_add = base_add 9 | self.min_val = min_val 10 | self.max_val = max_val 11 | 12 | def __call__(self, entity_embedding): 13 | return torch.clamp(entity_embedding + self.base_add, self.min_val, self.max_val) 14 | 15 | 16 | class SigmoidRegularizer(nn.Module): 17 | def __init__(self, vector_dim, dual=False): 18 | """ 19 | :param dual: Split each embedding into 2 chunks. 20 | The first chunk is property values and the second is property weight. 21 | Do NOT sigmoid the second chunk. 22 | """ 23 | super(SigmoidRegularizer, self).__init__() 24 | self.vector_dim = vector_dim 25 | # initialize weight as 8 and bias as -4, so that 0~1 input still mostly falls in 0~1 26 | self.weight = nn.Parameter(torch.Tensor([8])) 27 | self.bias = nn.Parameter(torch.Tensor([-4])) 28 | 29 | self.dual = dual 30 | 31 | def __call__(self, entity_embedding): 32 | if not self.dual: 33 | return torch.sigmoid(entity_embedding * self.weight + self.bias) 34 | else: 35 | # The first half is property values and the second is property weight. 36 | # Do NOT sigmoid the second chunk. The second chunk will be free parameters 37 | entity_vals, entity_val_weights = torch.chunk(entity_embedding, 2, dim=-1) 38 | entity_vals = torch.sigmoid(entity_vals * self.weight + self.bias) 39 | return torch.cat((entity_vals, entity_val_weights), dim=-1) 40 | 41 | 42 | def soft_discretize(self, entity_embedding, temperature=10): 43 | return torch.sigmoid((entity_embedding * self.weight + self.bias)*temperature) # soft 44 | 45 | def hard_discretize(self, entity_embedding, temperature=10, thres=0.5): 46 | discrete = self.soft_discretize(entity_embedding, temperature) 47 | discrete[discrete>=thres] = 1 48 | discrete[discrete=0] = 0 # if min_per_row is positive, no need to shift 182 | # reshaped -= min_per_row # shift by the minimum negative value 183 | 184 | # L1 normalize 185 | reshaped = F.normalize(reshaped, p=1, dim=-1) # L1 normalize along the last dimension 186 | reshaped = reshaped.view(*dims, last_dim) # change to original shape 187 | return reshaped 188 | 189 | def reshape_to_matrix(self, embeddings): 190 | # reshape the last dimension into a matrix 191 | dims, last_dim = embeddings.size()[:-1], embeddings.size()[-1] 192 | n_row = last_dim//self.k 193 | n_col = self.k 194 | 195 | reshaped = embeddings.view(*dims, n_row, n_col) 196 | return reshaped 197 | 198 | def reshape_to_vector(self, embeddings_matrix): 199 | dims, n_row, n_col = embeddings_matrix.size()[:-2], embeddings_matrix.size()[-2], embeddings_matrix.size()[-1] 200 | last_dim = n_row*n_col 201 | return embeddings_matrix.view(*dims, last_dim) 202 | 203 | def hard_discretize(self, embeddings): 204 | """ 205 | Discretize as a matrix. k entries per row => one '1' per row. 206 | No gradient. 207 | No normalization added. (not needed) 208 | :param embeddings: shape [batch_size, 1 or num_neg, entity_dim], 0<=embeddings[i]<=1 209 | :return y_hard: [batch_size, 1 or num_neg, entity_dim] 210 | """ 211 | y = self.reshape_to_matrix(embeddings) 212 | shape = y.size() 213 | _, ind = y.max(dim=-1) 214 | y_hard = torch.zeros_like(y).view(-1, shape[-1]) 215 | y_hard.scatter_(1, ind.view(-1, 1), 1) 216 | y_hard = y_hard.view(*shape) # shape [*dims, entity_dim//k, k] 217 | y_hard = self.reshape_to_vector(y_hard) 218 | return y_hard 219 | 220 | def soft_discretize(self, embeddings, gumbel_temperature): 221 | """ 222 | Discretize as a matrix. k entries per row => one '1' per row. 223 | Soft discretize using Gumbel softmax. 224 | No normalization added. (not needed) 225 | :param embeddings: shape [batch_size, 1 or num_neg, entity_dim], 0<=embeddings[i]<=1 226 | :param gumbel_temperature: max(0.5, exp(-rt)), r={1e-4, 1e-5} 227 | :return y_hard: [batch_size, 1 or num_neg, entity_dim] 228 | """ 229 | y = self.reshape_to_matrix(embeddings) 230 | eps = 1e-5 231 | log_y = torch.log(y+eps) 232 | y_soft = F.gumbel_softmax(log_y, tau=gumbel_temperature, hard=False) 233 | y_soft = self.reshape_to_vector(y_soft) 234 | return y_soft 235 | 236 | def L1_normalize(self, embeddings): 237 | """ 238 | :param embeddings: shape [batch_size, dim] 239 | :return: shape [batch_size, dim] 240 | """ 241 | k = self.k 242 | # reshape the last dimension into a matrix 243 | dims, last_dim = embeddings.size()[:-1], embeddings.size()[-1] 244 | n_row = last_dim // k 245 | n_col = k 246 | 247 | reshaped = embeddings.view(*dims, n_row, n_col) 248 | 249 | # L1 normalize 250 | reshaped = F.normalize(reshaped, p=1, dim=-1) # L1 normalize along the last dimension 251 | reshaped = reshaped.view(*dims, last_dim) # change to original shape 252 | return reshaped 253 | 254 | def get_num_distributions(self): 255 | return self.vector_dim // self.k 256 | 257 | 258 | 259 | class MatrixSigmoidSumRegularizer(MatrixSumRegularizer): 260 | def __init__(self, vector_dim, k, neg_input_possible=False): 261 | super(MatrixSigmoidSumRegularizer, self).__init__(vector_dim, k, neg_input_possible) 262 | # initialize weight as 8 and bias as -4, so that 0~1 input still mostly falls in 0~1 263 | self.weight = nn.Parameter(torch.Tensor([1])) 264 | self.bias = nn.Parameter(torch.Tensor([0])) 265 | 266 | def forward(self, embeddings): 267 | """ 268 | :param embeddings: shape [batch_size, dim] 269 | """ 270 | # reshape the last dimension into a matrix 271 | dims, last_dim = embeddings.size()[:-1], embeddings.size()[-1] 272 | n_row = last_dim//self.k 273 | n_col = self.k 274 | 275 | reshaped = embeddings.view(*dims, n_row, n_col) 276 | 277 | if self.neg_input_possible: # for entity free parameters 278 | # shift to non-negative 279 | reshaped = torch.sigmoid(reshaped * self.weight + self.bias) 280 | 281 | # L1 normalize 282 | reshaped = F.normalize(reshaped, p=1, dim=-1) # L1 normalize along the last dimension 283 | reshaped = reshaped.view(*dims, last_dim) # change to original shape 284 | return reshaped 285 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | wandb==0.9.7 -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python main.py --gpu_ids 1 --cuda --cpu_num 2 --test_batch_size 2 --data_path data/NELL-betae --do_train --do_valid --do_test --max_steps 450001 --valid_steps 5000 --margin_type logsigmoid_bpr --load_pretrained --regularizer 01 -d 1000 -b 512 -n 128 -lr 5e-4 --lr_scheduler annealing --optimizer AdamW --L2_reg 5e-2 --gamma_coff 20 -g 0.5 --projection_type rtransform --num_rel_base 30 -------------------------------------------------------------------------------- /test-pretrained-model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from os.path import join 3 | from fuzzyreasoning import KGFuzzyReasoning 4 | from dataloader import load_data_from_pickle, load_data 5 | from constants import query_name_dict, query_structure_list, query_structure2idx 6 | from collections import defaultdict 7 | from main import parse_args 8 | from util import evaluate, read_num_entity_relation_from_file, eval_tuple, wandb_initialize 9 | import collections 10 | import copy 11 | from constants import query_structure2idx, query_name_dict, query_structure_list 12 | from dataloader import TestDataset 13 | from torch.utils.data import DataLoader 14 | import pickle 15 | import pandas as pd 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import torch.nn as nn 19 | from util import log_metrics 20 | import torch.nn.functional as F 21 | import os 22 | from investigation_helper import wandb_log_metrics, prepare_new_attributes 23 | import sys 24 | 25 | 26 | 27 | 28 | def test_step(model, easy_answers, hard_answers, args, test_dataloader, query_name_dict, verbose=False): 29 | model.eval() 30 | 31 | device = model.device 32 | 33 | step = 0 34 | total_steps = len(test_dataloader) 35 | logs = collections.defaultdict(list) 36 | 37 | with torch.no_grad(): 38 | for negative_sample, queries, queries_unflatten, query_structure_idxs in test_dataloader: 39 | # example: query_structures: [('e', ('r',))]. queries: [[1804,4]]. queries_unflatten: [(1804, (4,)] 40 | if args.cuda: 41 | negative_sample = negative_sample.to(device) 42 | 43 | # nn.DataParallel helper 44 | batch_size = len(negative_sample) 45 | slice_idxs = torch.arange(0, batch_size).view((batch_size, 1)) 46 | 47 | _, negative_logit, _, idxs = model( 48 | None, 49 | negative_sample, 50 | None, 51 | queries, # np.array([queries]), won't be split when using multiple GPUs 52 | query_structure_idxs, 53 | slice_idxs, # to help track batch_queries and query_structures when using multiple GPUs 54 | inference=True 55 | ) 56 | 57 | idxs_np = idxs.detach().cpu().numpy() 58 | # if not converted to numpy, idxs_np will be considered scalar when test_batch_size=1 59 | # queries_unflatten = queries_unflatten[idxs_np] 60 | query_structure_idxs = query_structure_idxs[idxs_np] 61 | queries_unflatten = [queries_unflatten[i] for i in idxs] 62 | 63 | # 64 | # query_structures = [query_structures[i] for i in idxs] 65 | argsort = torch.argsort(negative_logit, dim=1, descending=True) 66 | ranking = argsort.clone().to(torch.float) 67 | 68 | # rank all entities 69 | # If it is the same shape with test_batch_size, reuse batch_entity_range without creating a new one 70 | if len(argsort) == args.test_batch_size: 71 | # ranking = ranking.scatter_(1, argsort, model.module.batch_entity_range) # achieve the ranking of all entities 72 | ranking = ranking.scatter_(1, argsort, model.batch_entity_range) # achieve the ranking of all entities 73 | else: # otherwise, create a new torch Tensor for batch_entity_range 74 | ranking = ranking.scatter_( 75 | 1, 76 | argsort, 77 | torch.arange(model.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 78 | # torch.arange(model.module.nentity).to(torch.float).repeat(argsort.shape[0], 1).to(device) 79 | ) 80 | 81 | for idx, (i, query, query_structure_idx) in enumerate(zip(argsort[:, 0], queries_unflatten, query_structure_idxs)): 82 | # convert query from np.ndarray to nested tuple 83 | query_key = tuple(query) 84 | query_structure = query_structure_list[query_structure_idx] 85 | 86 | hard_answer = hard_answers[query_key] 87 | easy_answer = easy_answers[query_key] 88 | num_hard = len(hard_answer) 89 | num_easy = len(easy_answer) 90 | assert len(hard_answer.intersection(easy_answer)) == 0 91 | cur_ranking = ranking[idx, list(easy_answer) + list(hard_answer)] 92 | cur_ranking, indices = torch.sort(cur_ranking) 93 | 94 | masks = indices >= num_easy 95 | if args.cuda: 96 | answer_list = torch.arange(num_hard + num_easy).to(torch.float).cuda() 97 | else: 98 | answer_list = torch.arange(num_hard + num_easy).to(torch.float) 99 | cur_ranking = cur_ranking - answer_list + 1 # filtered setting 100 | cur_ranking = cur_ranking[masks] # only take indices that belong to the hard answers 101 | 102 | if verbose: 103 | print(answer_list) 104 | print('ranking', cur_ranking) 105 | 106 | mrr = torch.mean(1./cur_ranking).item() 107 | h1 = torch.mean((cur_ranking <= 1).to(torch.float)).item() 108 | h3 = torch.mean((cur_ranking <= 3).to(torch.float)).item() 109 | h10 = torch.mean((cur_ranking <= 10).to(torch.float)).item() 110 | 111 | logs[query_structure].append({ 112 | 'MRR': mrr, 113 | 'HITS1': h1, 114 | 'HITS3': h3, 115 | 'HITS10': h10, 116 | 'num_hard_answer': num_hard, 117 | }) 118 | 119 | if step % args.test_log_steps == 0: 120 | print('Evaluating the model... (%d/%d)' % (step, total_steps)) 121 | 122 | step += 1 123 | 124 | metrics = collections.defaultdict(lambda: collections.defaultdict(int)) 125 | for query_structure in logs: 126 | for metric in logs[query_structure][0].keys(): 127 | if metric in ['num_hard_answer']: 128 | continue 129 | metrics[query_structure][metric] = sum([log[metric] for log in logs[query_structure]])/len(logs[query_structure]) 130 | metrics[query_structure]['num_queries'] = len(logs[query_structure]) 131 | 132 | return metrics 133 | 134 | 135 | 136 | if __name__=="__main__": 137 | 138 | model_dir = './trained_models' 139 | 140 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 141 | os.environ["CUDA_VISIBLE_DEVICES"] = '7' 142 | 143 | data = sys.argv[1] 144 | run = sys.argv[2] 145 | 146 | logic = 'product' 147 | model_path = join(model_dir, f'{run}.pt') 148 | model = torch.load(model_path) 149 | 150 | 151 | 152 | arg_str = f'--do_test --cuda --data_path data/{data}-betae --test_batch_size 2 --logic {logic}' 153 | 154 | arg_str = arg_str.split() 155 | args = parse_args(arg_str) 156 | args.nentity = model.nentity 157 | model.conjunction_net.use_attention = args.use_attention 158 | prepare_new_attributes(model) 159 | 160 | 161 | 162 | train_path_iterator, train_other_iterator, valid_dataloader, test_dataloader,\ 163 | valid_hard_answers, valid_easy_answers, \ 164 | test_hard_answers, test_easy_answers = load_data(args, query_name_dict, args.tasks) 165 | 166 | 167 | data_dir = args.data_path 168 | rel_id2str = pickle.load(open(join(data_dir, 'id2rel.pkl'), 'rb')) 169 | 170 | 171 | 172 | metrics = test_step(model, test_easy_answers, test_hard_answers, args, test_dataloader, query_name_dict) 173 | print(metrics) 174 | 175 | a = wandb_log_metrics(metrics, args) 176 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | import time 5 | import os 6 | import json 7 | import logging 8 | import pickle 9 | import wandb 10 | from collections import defaultdict 11 | import time 12 | from regularizers import * 13 | 14 | 15 | def get_regularizer(regularizer_setting, entity_dim, neg_input_possible=True, entity=False): 16 | """ 17 | :param neg_input_possible: for matrix_L1 (class MatrixSumRegularizer) 18 | :param dual: only apply regularizer to the first half embeddings (after chunk dim=-1) (for sigmoid only) 19 | """ 20 | if entity: 21 | key = 'e_reg_type' 22 | else: 23 | key = 'type' 24 | 25 | add_layernorm = regularizer_setting['e_layernorm'] 26 | if regularizer_setting[key] == '01': 27 | regularizer = Regularizer(base_add=0, min_val=0, max_val=1) 28 | elif regularizer_setting[key] == 'matrix_softmax': 29 | prob_dim = regularizer_setting['prob_dim'] 30 | regularizer = MatrixSoftmaxRegularizer(entity_dim, prob_dim) 31 | elif regularizer_setting[key] == 'vector_softmax': 32 | regularizer = VectorSoftmaxRegularizer(entity_dim) 33 | elif regularizer_setting[key] == 'sigmoid': 34 | regularizer = SigmoidRegularizer(entity_dim, dual=regularizer_setting['dual']) 35 | elif regularizer_setting[key] == 'matrix_L1': 36 | prob_dim = regularizer_setting['prob_dim'] 37 | regularizer = MatrixSumRegularizer(entity_dim, prob_dim, neg_input_possible) 38 | elif regularizer_setting[key] == 'matrix_sigmoid_L1': 39 | prob_dim = regularizer_setting['prob_dim'] 40 | regularizer = MatrixSigmoidSumRegularizer(entity_dim, prob_dim, neg_input_possible) 41 | elif regularizer_setting[key] == 'vector_sigmoid_L1': 42 | regularizer = VectorSigmoidSumRegularizer(entity_dim, neg_input_possible, add_layernorm) 43 | return regularizer 44 | 45 | def print_parameters(model): 46 | print('Model parameters:') 47 | num_params = 0 48 | for name, param in model.named_parameters(): 49 | print('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad))) 50 | if param.requires_grad: 51 | num_params += np.prod(param.size()) 52 | print('Parameter Number: %d' % num_params) 53 | 54 | def read_num_entity_relation_from_file(data_path): 55 | with open('%s/stats.txt'%data_path) as f: 56 | entrel = f.readlines() 57 | nentity = int(entrel[0].split(' ')[-1]) 58 | nrelation = int(entrel[1].split(' ')[-1]) 59 | return nentity, nrelation 60 | 61 | 62 | def wandb_initialize(config_dict): 63 | return wandb.init( 64 | project="kgfolreasoning", 65 | entity='kgfol', 66 | config=config_dict 67 | ) 68 | 69 | 70 | def save_model(model, optimizer, save_variable_list, save_dir, args): 71 | ''' 72 | Save the parameters of the model and the optimizer, 73 | as well as some other variables such as step and learning_rate 74 | ''' 75 | 76 | argparse_dict = vars(args) 77 | with open(os.path.join(save_dir, 'config.json'), 'w') as fjson: 78 | json.dump(argparse_dict, fjson) 79 | 80 | torch.save({ 81 | **save_variable_list, 82 | 'model_state_dict': model.state_dict(), 83 | 'optimizer_state_dict': optimizer.state_dict()}, 84 | os.path.join(save_dir, 'checkpoint') 85 | ) 86 | 87 | 88 | def set_logger(args): 89 | """ 90 | Write logs to console and log file 91 | """ 92 | if args.do_train: 93 | log_file = os.path.join(args.save_path, 'train.log') 94 | else: 95 | log_file = os.path.join(args.save_path, 'test.log') 96 | 97 | logging.basicConfig( 98 | format='%(asctime)s %(levelname)-8s %(message)s', 99 | level=logging.INFO, 100 | datefmt='%Y-%m-%d %H:%M:%S', 101 | filename=log_file, 102 | filemode='a+' 103 | ) 104 | if args.print_on_screen: 105 | console = logging.StreamHandler() 106 | console.setLevel(logging.INFO) 107 | formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 108 | console.setFormatter(formatter) 109 | logging.getLogger('').addHandler(console) 110 | 111 | 112 | def log_metrics(mode, step, metrics): 113 | ''' 114 | Print the evaluation logs 115 | ''' 116 | for metric in metrics: 117 | logging.info('%s %s at step %d: %f' % (mode, metric, step, metrics[metric])) 118 | print('%s %s at step %d: %f' % (mode, metric, step, metrics[metric])) 119 | wandb.log({f'{mode}_{metric}': metrics[metric], 'current_step': step}) 120 | 121 | 122 | def evaluate(model, tp_answers, fn_answers, args, dataloader, query_name_dict, mode, step): 123 | ''' 124 | Evaluate queries in dataloader 125 | ''' 126 | average_metrics = defaultdict(float) 127 | average_pos_metrics = defaultdict(float) 128 | average_neg_metrics = defaultdict(float) 129 | all_metrics = defaultdict(float) 130 | 131 | metrics = model.test_step(model, tp_answers, fn_answers, args, dataloader, query_name_dict) 132 | num_query_structures = 0 133 | 134 | num_pos_query_structures = 0 135 | num_neg_query_structures = 0 136 | 137 | num_queries = 0 138 | for query_structure in metrics: 139 | log_metrics(mode + " " + query_name_dict[query_structure], step, metrics[query_structure]) 140 | for metric in metrics[query_structure]: 141 | query_name = query_name_dict[query_structure] # e.g. 1p 142 | all_metrics["_".join([query_name, metric])] = metrics[query_structure][metric] 143 | if metric != 'num_queries': 144 | average_metrics[metric] += metrics[query_structure][metric] 145 | if 'n' in query_name: 146 | average_neg_metrics[metric] += metrics[query_structure][metric] 147 | else: 148 | average_pos_metrics[metric] += metrics[query_structure][metric] 149 | num_queries += metrics[query_structure]['num_queries'] 150 | num_query_structures += 1 151 | if 'n' in query_name: 152 | num_neg_query_structures += 1 153 | else: 154 | num_pos_query_structures += 1 155 | 156 | for metric in average_pos_metrics: 157 | average_pos_metrics[metric] /= num_pos_query_structures 158 | # writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step) 159 | all_metrics["_".join(["average_pos", metric])] = average_pos_metrics[metric] 160 | 161 | for metric in average_neg_metrics: 162 | average_neg_metrics[metric] /= num_neg_query_structures 163 | # writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step) 164 | all_metrics["_".join(["average_neg", metric])] = average_neg_metrics[metric] 165 | 166 | 167 | for metric in average_metrics: 168 | average_metrics[metric] /= num_query_structures 169 | # writer.add_scalar("_".join([mode, 'average', metric]), average_metrics[metric], step) 170 | all_metrics["_".join(["average", metric])] = average_metrics[metric] 171 | 172 | log_metrics('%s average' % mode, step, average_metrics) 173 | log_metrics('%s average_pos' % mode, step, average_pos_metrics) 174 | log_metrics('%s average_neg' % mode, step, average_neg_metrics) 175 | 176 | 177 | return all_metrics 178 | 179 | 180 | 181 | def list2tuple(l): 182 | return tuple(list2tuple(x) if type(x)==list else x for x in l) 183 | 184 | def tuple2list(t): 185 | return list(tuple2list(x) if type(x)==tuple else x for x in t) 186 | 187 | flatten=lambda l: sum(map(flatten, l),[]) if isinstance(l,tuple) else [l] 188 | 189 | def parse_time(): 190 | return time.strftime("%Y.%m.%d-%H:%M:%S", time.localtime()) 191 | 192 | def set_global_seed(seed): 193 | torch.manual_seed(seed) 194 | torch.cuda.manual_seed(seed) 195 | np.random.seed(seed) 196 | random.seed(seed) 197 | torch.backends.cudnn.deterministic=True 198 | 199 | def eval_tuple(arg_return): 200 | """Evaluate a tuple string into a tuple.""" 201 | if type(arg_return) == tuple: 202 | return arg_return 203 | if arg_return[0] not in ["(", "["]: 204 | arg_return = eval(arg_return) 205 | else: 206 | splitted = arg_return[1:-1].split(",") 207 | List = [] 208 | for item in splitted: 209 | try: 210 | item = eval(item) 211 | except: 212 | pass 213 | if item == "": 214 | continue 215 | List.append(item) 216 | arg_return = tuple(List) 217 | return arg_return 218 | 219 | 220 | def flatten_query_and_convert_structure_to_idx(query_structure2queries, query_structure2idx): 221 | """ 222 | :param query_structure2queries: type dict{query_structure: list[query_info(with entity and relation id)]} 223 | e.g. {('e', ('r',)): [(8410, (11,)), (7983, (12,))} 224 | :param query_structure2idx: type dict{query_structure: structure_idx} 225 | e.g. {('e', ('r',)): 0} 226 | :return all_queries: type list[(query_info, query_structure)] 227 | """ 228 | all_query_list = [(q, query_structure2idx[query_structure]) 229 | for query_structure, query_list in query_structure2queries.items() 230 | for q in query_list] 231 | return all_query_list --------------------------------------------------------------------------------