├── .gitignore
├── README.md
├── datasets
│   ├── __init__.py
│   ├── download.sh
│   ├── kg_dataset.py
│   └── process.py
├── examples
│   ├── train_SEA_NELLh100_32.sh
│   └── train_SEPA_NELLh100_500.sh
├── models
│   ├── __init__.py
│   ├── base.py
│   ├── complex.py
│   ├── euclidean.py
│   └── hyperbolic.py
├── optimizers
│   ├── __init__.py
│   ├── kg_optimizer.py
│   └── regularizers.py
├── requirements.txt
├── run.py
├── set_env.sh
├── test.py
└── utils
    ├── __init__.py
    ├── euclidean.py
    ├── hyperbolic.py
    └── train.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
data/
logs/
*/__pycache__
hyp_kg_env/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Link Prediction with Attention Applied on Multiple Knowledge Graph Embedding Models

This code is the official PyTorch implementation of [Link Prediction with Attention Applied on Multiple Knowledge Graph Embedding Models](https://dl.acm.org/doi/pdf/10.1145/3543507.3583358) [1].
This implementation builds on the KGEmb framework developed by [2].

## Datasets

Download and pre-process the datasets:

```bash
source datasets/download.sh
python datasets/process.py
```
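Processing stores each split as a pickled numpy array of (head, relation, tail) ids, plus a `to_skip.pickle` filter dictionary. A minimal sketch of inspecting the output, assuming `set_env.sh` has been sourced (so `DATA_PATH` is set) and using WN18RR as an example:

```python
import os
import pickle

import numpy as np

# Sketch only: inspect one processed dataset (paths assume set_env.sh was sourced).
data_dir = os.path.join(os.environ["DATA_PATH"], "WN18RR")
with open(os.path.join(data_dir, "train.pickle"), "rb") as f:
    train = pickle.load(f)  # numpy array of shape (n_triples, 3) with (head, rel, tail) ids

print(train.shape)
print(np.max(train, axis=0))  # largest entity/relation ids seen in training
```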
## Installation

First, create a Python 3.7 environment and install the dependencies:

```bash
virtualenv -p python3.7 hyp_kg_env
source hyp_kg_env/bin/activate
pip install -r requirements.txt
```

Then, set environment variables and activate your environment:

```bash
source set_env.sh
```

## Usage

To train and evaluate a KG embedding model for the link prediction task, use the run.py script:

```bash
usage: run.py [-h] [--dataset {FB15K,WN,WN18RR,FB237,YAGO3-10,NELL-995-h100,NELL-995-h75,NELL-995-h50,NELL-995-h25}]
              [--model {TransE,Distmult,CP,MurE,RotE,RefE,AttE,SEA,RotH,RefH,AttH,SEPA,ComplEx,RotatE}]
              [--regularizer {N3,F2}] [--reg REG]
              [--optimizer {Adagrad,Adam,SparseAdam}]
              [--max_epochs MAX_EPOCHS] [--patience PATIENCE] [--valid VALID]
              [--rank RANK] [--batch_size BATCH_SIZE]
              [--neg_sample_size NEG_SAMPLE_SIZE] [--dropout DROPOUT]
              [--init_size INIT_SIZE] [--learning_rate LEARNING_RATE]
              [--gamma GAMMA] [--bias {constant,learn,none}]
              [--dtype {single,double}] [--double_neg] [--debug] [--multi_c]
              [--cuda_n CUDA_N]

Knowledge Graph Embedding

optional arguments:
  -h, --help            show this help message and exit
  --dataset {FB15K,WN,WN18RR,FB237,YAGO3-10,NELL-995-h100,NELL-995-h75,NELL-995-h50,NELL-995-h25}
                        Knowledge Graph dataset
  --model {TransE,Distmult,CP,MurE,RotE,RefE,AttE,SEA,RotH,RefH,AttH,SEPA,ComplEx,RotatE}
                        Knowledge Graph embedding model
  --regularizer {N3,F2}
                        Regularizer
  --reg REG             Regularization weight
  --optimizer {Adagrad,Adam,SparseAdam}
                        Optimizer
  --max_epochs MAX_EPOCHS
                        Maximum number of epochs to train for
  --patience PATIENCE   Number of epochs before early stopping
  --valid VALID         Number of epochs before validation
  --rank RANK           Embedding dimension
  --batch_size BATCH_SIZE
                        Batch size
  --neg_sample_size NEG_SAMPLE_SIZE
                        Negative sample size, -1 to not use negative sampling
  --dropout DROPOUT     Dropout rate
  --init_size INIT_SIZE
                        Initial embeddings' scale
  --learning_rate LEARNING_RATE
                        Learning rate
  --gamma GAMMA         Margin for distance-based losses
  --bias {constant,learn,none}
                        Bias type (none for no bias)
  --dtype {single,double}
                        Machine precision
  --double_neg          Whether to negative sample both head and tail entities
  --debug               Only use 1000 examples for debugging
  --multi_c             Multiple curvatures per relation
  --cuda_n CUDA_N       Cuda core number
```
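Ready-made training scripts for the SEA and SEPA models on NELL-995-h100 are provided in `examples/`. The same run can also be launched programmatically; a minimal sketch (flag values are illustrative, and `set_env.sh` must have been sourced first):

```python
# Sketch only: mirrors examples/train_SEA_NELLh100_32.sh from Python.
from run import parser, train

args = parser.parse_args([
    "--dataset", "NELL-995-h100",
    "--model", "SEA",
    "--rank", "32",
    "--optimizer", "Adam",
    "--learning_rate", "0.001",
    "--neg_sample_size", "250",
    "--batch_size", "500",
    "--bias", "learn",
    "--dtype", "single",
    "--double_neg",
])
train(args)  # trains with early stopping on validation MRR, then logs test metrics
```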
20 | """ 21 | self.data_path = data_path 22 | self.debug = debug 23 | self.data = {} 24 | for split in ["train", "valid", "test"]: 25 | file_path = os.path.join(self.data_path, split + ".pickle") 26 | with open(file_path, "rb") as in_file: 27 | self.data[split] = pkl.load(in_file) 28 | filters_file = open(os.path.join(self.data_path, "to_skip.pickle"), "rb") 29 | self.to_skip = pkl.load(filters_file) 30 | filters_file.close() 31 | max_axis = np.max(self.data["train"], axis=0) 32 | self.n_entities = int(max(max_axis[0], max_axis[2]) + 1) 33 | self.n_predicates = int(max_axis[1] + 1) * 2 34 | 35 | def get_examples(self, split, rel_idx=-1): 36 | """Get examples in a split. 37 | 38 | Args: 39 | split: String indicating the split to use (train/valid/test) 40 | rel_idx: integer for relation index to keep (-1 to keep all relation) 41 | 42 | Returns: 43 | examples: torch.LongTensor containing KG triples in a split 44 | """ 45 | examples = self.data[split] 46 | if split == "train": 47 | copy = np.copy(examples) 48 | tmp = np.copy(copy[:, 0]) 49 | copy[:, 0] = copy[:, 2] 50 | copy[:, 2] = tmp 51 | copy[:, 1] += self.n_predicates // 2 52 | examples = np.vstack((examples, copy)) 53 | if rel_idx >= 0: 54 | examples = examples[examples[:, 1] == rel_idx] 55 | if self.debug: 56 | examples = examples[:1000] 57 | if type(examples) is str: 58 | rows = examples.split("\n") 59 | res = np.zeros((len(rows)-1,3), dtype=np.int64) 60 | for i in range(0,len(rows)-1): 61 | np_row = np.array(rows[i].split("\t"), dtype=np.int64) 62 | res[i] = np_row 63 | examples = res 64 | return torch.from_numpy(examples.astype("int64")) 65 | 66 | def get_filters(self, ): 67 | """Return filter dict to compute ranking metrics in the filtered setting.""" 68 | return self.to_skip 69 | 70 | def get_shape(self): 71 | """Returns KG dataset shape.""" 72 | return self.n_entities, self.n_predicates, self.n_entities 73 | -------------------------------------------------------------------------------- /datasets/process.py: -------------------------------------------------------------------------------- 1 | """Knowledge Graph dataset pre-processing functions.""" 2 | 3 | import collections 4 | import os 5 | import pickle 6 | 7 | import numpy as np 8 | 9 | 10 | def get_idx(path): 11 | """Map entities and relations to unique ids. 12 | 13 | Args: 14 | path: path to directory with raw dataset files (tab-separated train/valid/test triples) 15 | 16 | Returns: 17 | ent2idx: Dictionary mapping raw entities to unique ids 18 | rel2idx: Dictionary mapping raw relations to unique ids 19 | """ 20 | entities, relations = set(), set() 21 | for split in ["train", "valid", "test"]: 22 | with open(os.path.join(path, split), "r") as lines: 23 | for line in lines: 24 | lhs, rel, rhs = line.strip().split("\t") 25 | entities.add(lhs) 26 | entities.add(rhs) 27 | relations.add(rel) 28 | ent2idx = {x: i for (i, x) in enumerate(sorted(entities))} 29 | rel2idx = {x: i for (i, x) in enumerate(sorted(relations))} 30 | return ent2idx, rel2idx 31 | 32 | 33 | def to_np_array(dataset_file, ent2idx, rel2idx): 34 | """Map raw dataset file to numpy array with unique ids. 
--------------------------------------------------------------------------------
/datasets/process.py:
--------------------------------------------------------------------------------
"""Knowledge Graph dataset pre-processing functions."""

import collections
import os
import pickle

import numpy as np


def get_idx(path):
    """Map entities and relations to unique ids.

    Args:
        path: path to directory with raw dataset files (tab-separated train/valid/test triples)

    Returns:
        ent2idx: Dictionary mapping raw entities to unique ids
        rel2idx: Dictionary mapping raw relations to unique ids
    """
    entities, relations = set(), set()
    for split in ["train", "valid", "test"]:
        with open(os.path.join(path, split), "r") as lines:
            for line in lines:
                lhs, rel, rhs = line.strip().split("\t")
                entities.add(lhs)
                entities.add(rhs)
                relations.add(rel)
    ent2idx = {x: i for (i, x) in enumerate(sorted(entities))}
    rel2idx = {x: i for (i, x) in enumerate(sorted(relations))}
    return ent2idx, rel2idx


def to_np_array(dataset_file, ent2idx, rel2idx):
    """Map raw dataset file to numpy array with unique ids.

    Args:
        dataset_file: Path to file containing raw triples in a split
        ent2idx: Dictionary mapping raw entities to unique ids
        rel2idx: Dictionary mapping raw relations to unique ids

    Returns:
        Numpy array of size n_examples x 3 mapping the raw dataset file to ids
    """
    examples = []
    with open(dataset_file, "r") as lines:
        for line in lines:
            lhs, rel, rhs = line.strip().split("\t")
            try:
                examples.append([ent2idx[lhs], rel2idx[rel], ent2idx[rhs]])
            except KeyError:  # dict lookups raise KeyError (not ValueError) for unknown entities/relations
                continue
    return np.array(examples).astype("int64")


def get_filters(examples, n_relations):
    """Create filtering lists for evaluation.

    Args:
        examples: Numpy array of size n_examples x 3 containing KG triples
        n_relations: Int indicating the total number of relations in the KG

    Returns:
        lhs_final: Dictionary mapping queries (entity, relation) to filtered entities for left-hand-side prediction
        rhs_final: Dictionary mapping queries (entity, relation) to filtered entities for right-hand-side prediction
    """
    lhs_filters = collections.defaultdict(set)
    rhs_filters = collections.defaultdict(set)
    for lhs, rel, rhs in examples:
        rhs_filters[(lhs, rel)].add(rhs)
        lhs_filters[(rhs, rel + n_relations)].add(lhs)
    lhs_final = {}
    rhs_final = {}
    for k, v in lhs_filters.items():
        lhs_final[k] = sorted(list(v))
    for k, v in rhs_filters.items():
        rhs_final[k] = sorted(list(v))
    return lhs_final, rhs_final


def process_dataset(path):
    """Map entities and relations to ids and save corresponding pickle arrays.

    Args:
        path: Path to dataset directory

    Returns:
        examples: Dictionary mapping splits to Numpy arrays containing the corresponding KG triples.
        filters: Dictionary containing filters for lhs and rhs predictions.
    """
    ent2idx, rel2idx = get_idx(path)  # was get_idx(dataset_path): a global that only exists when run as a script
    examples = {}
    splits = ["train", "valid", "test"]
    for split in splits:
        dataset_file = os.path.join(path, split)
        examples[split] = to_np_array(dataset_file, ent2idx, rel2idx)
    all_examples = np.concatenate([examples[split] for split in splits], axis=0)
    lhs_skip, rhs_skip = get_filters(all_examples, len(rel2idx))
    filters = {"lhs": lhs_skip, "rhs": rhs_skip}
    return examples, filters


if __name__ == "__main__":
    data_path = os.environ["DATA_PATH"]
    for dataset_name in os.listdir(data_path):
        dataset_path = os.path.join(data_path, dataset_name)
        dataset_examples, dataset_filters = process_dataset(dataset_path)
        for dataset_split in ["train", "valid", "test"]:
            save_path = os.path.join(dataset_path, dataset_split + ".pickle")
            with open(save_path, "wb") as save_file:
                pickle.dump(dataset_examples[dataset_split], save_file)
        with open(os.path.join(dataset_path, "to_skip.pickle"), "wb") as save_file:
            pickle.dump(dataset_filters, save_file)
--------------------------------------------------------------------------------
/examples/train_SEA_NELLh100_32.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source set_env.sh
source hyp_kg_env/bin/activate
export CUDA_VISIBLE_DEVICES=3
python run.py \
            --dataset NELL-995-h100 \
            --model SEA \
            --rank 32 \
            --regularizer N3 \
            --reg 0.0 \
            --optimizer Adam \
            --max_epochs 500 \
            --patience 15 \
            --valid 5 \
            --neg_sample_size 250 \
            --init_size 0.001 \
            --learning_rate 0.001 \
            --gamma 0.0 \
            --bias learn \
            --batch_size 500 \
            --dtype single \
            --double_neg

--------------------------------------------------------------------------------
/examples/train_SEPA_NELLh100_500.sh:
--------------------------------------------------------------------------------
#!/bin/bash
source set_env.sh
source hyp_kg_env/bin/activate
export CUDA_VISIBLE_DEVICES=3
python run.py \
            --dataset NELL-995-h100 \
            --model SEPA \
            --rank 500 \
            --regularizer N3 \
            --reg 0.0 \
            --optimizer Adam \
            --max_epochs 500 \
            --patience 15 \
            --valid 5 \
            --neg_sample_size 250 \
            --init_size 0.001 \
            --learning_rate 0.001 \
            --gamma 0.0 \
            --bias learn \
            --batch_size 500 \
            --dtype single \
            --multi_c \
            --double_neg

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
from .complex import *
from .euclidean import *
from .hyperbolic import *

all_models = EUC_MODELS + HYP_MODELS + COMPLEX_MODELS

--------------------------------------------------------------------------------
/models/base.py:
--------------------------------------------------------------------------------
"""Base Knowledge Graph embedding model."""
from abc import ABC, abstractmethod

import torch
from torch import nn


class KGModel(nn.Module, ABC):
    """Base Knowledge Graph Embedding model class.

    Attributes:
        sizes: Tuple[int, int, int] with (n_entities, n_relations, n_entities)
        rank: integer for embedding dimension
        dropout: float for dropout rate
        gamma: torch.nn.Parameter for margin in ranking-based loss
        data_type: torch.dtype for machine precision (single or double)
        bias: string for whether to learn or fix bias (none for no bias)
        init_size: float for embeddings' initialization scale
        entity: torch.nn.Embedding with entity embeddings
        rel: torch.nn.Embedding with relation embeddings
        bh: torch.nn.Embedding with head entity bias embeddings
        bt: torch.nn.Embedding with tail entity bias embeddings
    """

    def __init__(self, sizes, rank, dropout, gamma, data_type, bias, init_size):
        """Initialize KGModel."""
        super(KGModel, self).__init__()
        if data_type == 'double':
            self.data_type = torch.double
        else:
            self.data_type = torch.float
        self.sizes = sizes
        self.rank = rank
        self.dropout = dropout
        self.bias = bias
        self.init_size = init_size
        self.gamma = nn.Parameter(torch.Tensor([gamma]), requires_grad=False)
        self.entity = nn.Embedding(sizes[0], rank)
        self.rel = nn.Embedding(sizes[1], rank)
        self.bh = nn.Embedding(sizes[0], 1)
        self.bh.weight.data = torch.zeros((sizes[0], 1), dtype=self.data_type)
        self.bt = nn.Embedding(sizes[0], 1)
        self.bt.weight.data = torch.zeros((sizes[0], 1), dtype=self.data_type)
        self.att = []

    @abstractmethod
    def get_queries(self, queries):
        """Compute embedding and biases of queries.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
        Returns:
            lhs_e: torch.Tensor with queries' embeddings (embedding of head entities and relations)
            lhs_biases: torch.Tensor with head entities' biases
        """
        pass

    @abstractmethod
    def get_rhs(self, queries, eval_mode):
        """Get embeddings and biases of target entities.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
            eval_mode: boolean, true for evaluation, false for training
        Returns:
            rhs_e: torch.Tensor with targets' embeddings
                   if eval_mode=False returns embedding of tail entities (n_queries x rank)
                   else returns embedding of all possible entities in the KG dataset (n_entities x rank)
            rhs_biases: torch.Tensor with targets' biases
                        if eval_mode=False returns biases of tail entities (n_queries x 1)
                        else returns biases of all possible entities in the KG dataset (n_entities x 1)
        """
        pass

    @abstractmethod
    def similarity_score(self, lhs_e, rhs_e, eval_mode):
        """Compute similarity scores of queries against targets in embedding space.

        Args:
            lhs_e: torch.Tensor with queries' embeddings
            rhs_e: torch.Tensor with targets' embeddings
            eval_mode: boolean, true for evaluation, false for training
        Returns:
            scores: torch.Tensor with similarity scores of queries against targets
        """
        pass

    def score(self, lhs, rhs, eval_mode):
        """Scores queries against targets.

        Args:
            lhs: Tuple[torch.Tensor, torch.Tensor] with queries' embeddings and head biases
                 returned by get_queries(queries)
            rhs: Tuple[torch.Tensor, torch.Tensor] with targets' embeddings and tail biases
                 returned by get_rhs(queries, eval_mode)
            eval_mode: boolean, true for evaluation, false for training
        Returns:
            score: torch.Tensor with scores of queries against targets
                   if eval_mode=True, returns scores against all possible tail entities, shape (n_queries x n_entities)
                   else returns scores for triples in batch (shape n_queries x 1)
        """
        lhs_e, lhs_biases = lhs
        rhs_e, rhs_biases = rhs
        score = self.similarity_score(lhs_e, rhs_e, eval_mode)
        if self.bias == 'constant':
            return self.gamma.item() + score
        elif self.bias == 'learn':
            if eval_mode:
                return lhs_biases + rhs_biases.t() + score
            else:
                return lhs_biases + rhs_biases + score
        else:
            return score

    def get_factors(self, queries):
        """Computes factors for embeddings' regularization.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor] with embeddings to regularize
        """
        head_e = self.entity(queries[:, 0])
        rel_e = self.rel(queries[:, 1])
        rhs_e = self.entity(queries[:, 2])
        return head_e, rel_e, rhs_e

    def forward(self, queries, eval_mode=False):
        """KGModel forward pass.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
            eval_mode: boolean, true for evaluation, false for training
        Returns:
            predictions: torch.Tensor with triples' scores
                         shape is (n_queries x 1) if eval_mode is false
                         else (n_queries x n_entities)
            factors: embeddings to regularize
        """
        # get embeddings and similarity scores
        lhs_e, lhs_biases = self.get_queries(queries)
        # queries = F.dropout(queries, self.dropout, training=self.training)
        rhs_e, rhs_biases = self.get_rhs(queries, eval_mode)
        # candidates = F.dropout(candidates, self.dropout, training=self.training)
        predictions = self.score((lhs_e, lhs_biases), (rhs_e, rhs_biases), eval_mode)

        # get factors for regularization
        factors = self.get_factors(queries)
        return predictions, factors

    def get_ranking(self, queries, filters, batch_size=500):
        """Compute filtered ranking of correct entity for evaluation.

        Args:
            queries: torch.LongTensor with query triples (head, relation, tail)
            filters: filters[(head, relation)] gives entities to ignore (filtered setting)
            batch_size: int for evaluation batch size

        Returns:
            ranks: torch.Tensor with ranks of correct entities
        """
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            b_begin = 0
            candidates = self.get_rhs(queries, eval_mode=True)
            while b_begin < len(queries):
                these_queries = queries[b_begin:b_begin + batch_size].cuda()

                q = self.get_queries(these_queries)
                rhs = self.get_rhs(these_queries, eval_mode=False)

                scores = self.score(q, candidates, eval_mode=True)
                targets = self.score(q, rhs, eval_mode=False)

                # set filtered and true scores to -1e6 to be ignored
                for i, query in enumerate(these_queries):
                    filter_out = filters[(query[0].item(), query[1].item())]
                    filter_out += [queries[b_begin + i, 2].item()]
                    if query[0].item() != query[2].item():
                        filter_out += [queries[b_begin + i, 0].item()]

                    scores[i, torch.LongTensor(filter_out)] = -1e6
                ranks[b_begin:b_begin + batch_size] += torch.sum(
                    (scores >= targets).float(), dim=1
                ).cpu()
                b_begin += batch_size
        return ranks

    def compute_metrics(self, examples, filters, batch_size=500):
        """Compute ranking-based evaluation metrics.

        Args:
            examples: torch.LongTensor of size n_examples x 3 containing triples' indices
            filters: Dict with entities to skip per query for evaluation in the filtered setting
            batch_size: integer for batch size to use to compute scores

        Returns:
            Evaluation metrics (mean rank, mean reciprocal rank and hits)
        """
        mean_rank = {}
        mean_reciprocal_rank = {}
        hits_at = {}

        for m in ["rhs", "lhs"]:
            q = examples.clone()
            if m == "lhs":
                tmp = torch.clone(q[:, 0])
                q[:, 0] = q[:, 2]
                q[:, 2] = tmp
                q[:, 1] += self.sizes[1] // 2
            ranks = self.get_ranking(q, filters[m], batch_size=batch_size)
            mean_rank[m] = torch.mean(ranks).item()
            mean_reciprocal_rank[m] = torch.mean(1. / ranks).item()
            hits_at[m] = torch.FloatTensor(list(map(
                lambda x: torch.mean((ranks <= x).float()).item(),
                (1, 3, 10)
            )))

        return mean_rank, mean_reciprocal_rank, hits_at
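The rank computed in `get_ranking` above is 1 plus the number of candidates scoring at least as high as the target, after known true triples are masked to -1e6. A self-contained toy sketch of that computation:

```python
import torch

# Toy scores for one query over 5 candidate entities; entity 2 is the target,
# entity 4 is another known true answer and must be filtered out.
scores = torch.tensor([[0.1, 0.5, 0.9, 0.3, 2.0]])
target = scores[:, 2:3]                      # score of the correct entity
scores[0, torch.LongTensor([2, 4])] = -1e6   # mask target and filtered entities

rank = 1 + torch.sum((scores >= target).float(), dim=1)
print(rank)  # tensor([1.]) -> best possible rank once the other true triple is ignored
```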
--------------------------------------------------------------------------------
/models/complex.py:
--------------------------------------------------------------------------------
"""Knowledge Graph embedding models where embeddings are in complex space."""
import torch
from torch import nn

from models.base import KGModel

COMPLEX_MODELS = ["ComplEx", "RotatE"]


class BaseC(KGModel):
    """Complex Knowledge Graph Embedding models.

    Attributes:
        embeddings: complex embeddings for entities and relations
    """

    def __init__(self, args):
        """Initialize a Complex KGModel."""
        super(BaseC, self).__init__(args.sizes, args.rank, args.dropout, args.gamma, args.dtype, args.bias,
                                    args.init_size)
        assert self.rank % 2 == 0, "Complex models require even embedding dimension"
        self.rank = self.rank // 2
        self.embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * self.rank, sparse=True)
            for s in self.sizes[:2]
        ])
        self.embeddings[0].weight.data = self.init_size * self.embeddings[0].weight.to(self.data_type)
        self.embeddings[1].weight.data = self.init_size * self.embeddings[1].weight.to(self.data_type)

    def get_rhs(self, queries, eval_mode):
        """Get embeddings and biases of target entities."""
        if eval_mode:
            return self.embeddings[0].weight, self.bt.weight
        else:
            return self.embeddings[0](queries[:, 2]), self.bt(queries[:, 2])

    def similarity_score(self, lhs_e, rhs_e, eval_mode):
        """Compute similarity scores of queries against targets in embedding space."""
        lhs_e = lhs_e[:, :self.rank], lhs_e[:, self.rank:]
        rhs_e = rhs_e[:, :self.rank], rhs_e[:, self.rank:]
        if eval_mode:
            return lhs_e[0] @ rhs_e[0].transpose(0, 1) + lhs_e[1] @ rhs_e[1].transpose(0, 1)
        else:
            return torch.sum(
                lhs_e[0] * rhs_e[0] + lhs_e[1] * rhs_e[1],
                1, keepdim=True
            )

    def get_complex_embeddings(self, queries):
        """Get complex embeddings of queries."""
        head_e = self.embeddings[0](queries[:, 0])
        rel_e = self.embeddings[1](queries[:, 1])
        rhs_e = self.embeddings[0](queries[:, 2])
        head_e = head_e[:, :self.rank], head_e[:, self.rank:]
        rel_e = rel_e[:, :self.rank], rel_e[:, self.rank:]
        rhs_e = rhs_e[:, :self.rank], rhs_e[:, self.rank:]
        return head_e, rel_e, rhs_e

    def get_factors(self, queries):
        """Compute factors for embeddings' regularization."""
        head_e, rel_e, rhs_e = self.get_complex_embeddings(queries)
        head_f = torch.sqrt(head_e[0] ** 2 + head_e[1] ** 2)
        rel_f = torch.sqrt(rel_e[0] ** 2 + rel_e[1] ** 2)
        rhs_f = torch.sqrt(rhs_e[0] ** 2 + rhs_e[1] ** 2)
        return head_f, rel_f, rhs_f


class ComplEx(BaseC):
    """Simple complex model http://proceedings.mlr.press/v48/trouillon16.pdf"""

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        head_e, rel_e, _ = self.get_complex_embeddings(queries)
        lhs_e = torch.cat([
            head_e[0] * rel_e[0] - head_e[1] * rel_e[1],
            head_e[0] * rel_e[1] + head_e[1] * rel_e[0]
        ], 1)
        return lhs_e, self.bh(queries[:, 0])


class RotatE(BaseC):
    """Rotations in complex space https://openreview.net/pdf?id=HkgEQnRqYQ"""

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        head_e, rel_e, _ = self.get_complex_embeddings(queries)
        rel_norm = torch.sqrt(rel_e[0] ** 2 + rel_e[1] ** 2)
        cos = rel_e[0] / rel_norm
        sin = rel_e[1] / rel_norm
        lhs_e = torch.cat([
            head_e[0] * cos - head_e[1] * sin,
            head_e[0] * sin + head_e[1] * cos
        ], 1)
        return lhs_e, self.bh(queries[:, 0])
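Both models above store a rank-2k real vector as (real, imaginary) halves; `ComplEx.get_queries` multiplies head and relation as complex numbers, and RotatE first normalizes the relation to unit modulus so only a rotation (phase) remains. A numeric sketch of the complex product:

```python
import torch

# Complex multiplication on (real, imaginary) halves, as in ComplEx.get_queries.
head_re, head_im = torch.tensor([1.0]), torch.tensor([2.0])  # head = 1 + 2i
rel_re, rel_im = torch.tensor([0.0]), torch.tensor([1.0])    # rel  = i (a 90-degree rotation)

lhs_re = head_re * rel_re - head_im * rel_im  # -> -2
lhs_im = head_re * rel_im + head_im * rel_re  # ->  1
print(lhs_re, lhs_im)  # (1 + 2i) * i = -2 + i

# RotatE additionally divides rel by its modulus, keeping only the rotation.
```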
space.""" 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | 6 | from models.base import KGModel 7 | from utils.euclidean import euc_sqdistance, givens_rotations, givens_reflection 8 | 9 | EUC_MODELS = ["Distmult", "TransE", "CP", "MurE", "RotE", "RefE", "AttE", "SEA"] 10 | 11 | 12 | class BaseE(KGModel): 13 | """Euclidean Knowledge Graph Embedding models. 14 | 15 | Attributes: 16 | sim: similarity metric to use (dist for distance and dot for dot product) 17 | """ 18 | 19 | def __init__(self, args): 20 | super(BaseE, self).__init__(args.sizes, args.rank, args.dropout, args.gamma, args.dtype, args.bias, 21 | args.init_size) 22 | self.entity.weight.data = self.init_size * torch.randn((self.sizes[0], self.rank), dtype=self.data_type) 23 | self.rel.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type) 24 | self.att = [] 25 | def get_rhs(self, queries, eval_mode): 26 | """Get embeddings and biases of target entities.""" 27 | if eval_mode: 28 | return self.entity.weight, self.bt.weight 29 | else: 30 | return self.entity(queries[:, 2]), self.bt(queries[:, 2]) 31 | 32 | def similarity_score(self, lhs_e, rhs_e, eval_mode): 33 | """Compute similarity scores or queries against targets in embedding space.""" 34 | if self.sim == "dot": 35 | if eval_mode: 36 | score = lhs_e @ rhs_e.transpose(0, 1) 37 | else: 38 | score = torch.sum(lhs_e * rhs_e, dim=-1, keepdim=True) 39 | else: 40 | score = - euc_sqdistance(lhs_e, rhs_e, eval_mode) 41 | return score 42 | 43 | 44 | class TransE(BaseE): 45 | """Euclidean translations https://www.utc.fr/~bordesan/dokuwiki/_media/en/transe_nips13.pdf""" 46 | 47 | def __init__(self, args): 48 | super(TransE, self).__init__(args) 49 | self.sim = "dist" 50 | 51 | def get_queries(self, queries): 52 | head_e = self.entity(queries[:, 0]) 53 | rel_e = self.rel(queries[:, 1]) 54 | lhs_e = head_e + rel_e 55 | lhs_biases = self.bh(queries[:, 0]) 56 | return lhs_e, lhs_biases 57 | 58 | class Distmult(BaseE): 59 | 60 | def __init__(self, args): 61 | super(Distmult, self).__init__(args) 62 | self.sim = "dist" 63 | 64 | def get_queries(self, queries): 65 | head_e = self.entity(queries[:, 0]) 66 | rel_e = self.rel(queries[:, 1]) 67 | lhs_e = head_e * rel_e 68 | lhs_biases = self.bh(queries[:, 0]) 69 | return lhs_e, lhs_biases 70 | 71 | class CP(BaseE): 72 | """Canonical tensor decomposition https://arxiv.org/pdf/1806.07297.pdf""" 73 | 74 | def __init__(self, args): 75 | super(CP, self).__init__(args) 76 | self.sim = "dot" 77 | 78 | def get_queries(self, queries: torch.Tensor): 79 | """Compute embedding and biases of queries.""" 80 | return self.entity(queries[:, 0]) * self.rel(queries[:, 1]), self.bh(queries[:, 0]) 81 | 82 | 83 | class MurE(BaseE): 84 | """Diagonal scaling https://arxiv.org/pdf/1905.09791.pdf""" 85 | 86 | def __init__(self, args): 87 | super(MurE, self).__init__(args) 88 | self.rel_diag = nn.Embedding(self.sizes[1], self.rank) 89 | self.rel_diag.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 90 | self.sim = "dist" 91 | 92 | def get_queries(self, queries: torch.Tensor): 93 | """Compute embedding and biases of queries.""" 94 | lhs_e = self.rel_diag(queries[:, 1]) * self.entity(queries[:, 0]) + self.rel(queries[:, 1]) 95 | lhs_biases = self.bh(queries[:, 0]) 96 | return lhs_e, lhs_biases 97 | 98 | 99 | class RotE(BaseE): 100 | """Euclidean 2x2 Givens rotations""" 101 | 102 | def __init__(self, args): 103 | super(RotE, self).__init__(args) 104 | self.rel_diag = 
nn.Embedding(self.sizes[1], self.rank) 105 | self.rel_diag.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 106 | self.sim = "dist" 107 | 108 | def get_queries(self, queries: torch.Tensor): 109 | """Compute embedding and biases of queries.""" 110 | lhs_e = givens_rotations(self.rel_diag(queries[:, 1]), self.entity(queries[:, 0])) + self.rel(queries[:, 1]) 111 | lhs_biases = self.bh(queries[:, 0]) 112 | return lhs_e, lhs_biases 113 | 114 | 115 | class RefE(BaseE): 116 | """Euclidean 2x2 Givens reflections""" 117 | 118 | def __init__(self, args): 119 | super(RefE, self).__init__(args) 120 | self.rel_diag = nn.Embedding(self.sizes[1], self.rank) 121 | self.rel_diag.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 122 | self.sim = "dist" 123 | 124 | def get_queries(self, queries): 125 | """Compute embedding and biases of queries.""" 126 | lhs = givens_reflection(self.rel_diag(queries[:, 1]), self.entity(queries[:, 0])) 127 | rel = self.rel(queries[:, 1]) 128 | lhs_biases = self.bh(queries[:, 0]) 129 | return lhs + rel, lhs_biases 130 | 131 | 132 | class AttE(BaseE): 133 | """Euclidean attention model combining translations, reflections and rotations""" 134 | 135 | def __init__(self, args): 136 | super(AttE, self).__init__(args) 137 | self.sim = "dist" 138 | 139 | # reflection 140 | self.ref = nn.Embedding(self.sizes[1], self.rank) 141 | self.ref.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 142 | 143 | # rotation 144 | self.rot = nn.Embedding(self.sizes[1], self.rank) 145 | self.rot.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 146 | 147 | # attention 148 | self.context_vec = nn.Embedding(self.sizes[1], self.rank) 149 | self.act = nn.Softmax(dim=1) 150 | self.scale = torch.Tensor([1. 
/ np.sqrt(self.rank)]).cuda() 151 | 152 | def get_reflection_queries(self, queries): 153 | lhs_ref_e = givens_reflection( 154 | self.ref(queries[:, 1]), self.entity(queries[:, 0]) 155 | ) 156 | return lhs_ref_e 157 | 158 | def get_rotation_queries(self, queries): 159 | lhs_rot_e = givens_rotations( 160 | self.rot(queries[:, 1]), self.entity(queries[:, 0]) 161 | ) 162 | return lhs_rot_e 163 | 164 | def get_queries(self, queries): 165 | """Compute embedding and biases of queries.""" 166 | lhs_ref_e = self.get_reflection_queries(queries).view((-1, 1, self.rank)) 167 | lhs_rot_e = self.get_rotation_queries(queries).view((-1, 1, self.rank)) 168 | 169 | # self-attention mechanism 170 | cands = torch.cat([lhs_ref_e, lhs_rot_e], dim=1) 171 | context_vec = self.context_vec(queries[:, 1]).view((-1, 1, self.rank)) 172 | att_weights = torch.sum(context_vec * cands * self.scale, dim=-1, keepdim=True) 173 | att_weights = self.act(att_weights) 174 | lhs_e = torch.sum(att_weights * cands, dim=1) + self.rel(queries[:, 1]) 175 | return lhs_e, self.bh(queries[:, 0]) 176 | 177 | class SEA(BaseE): 178 | """Euclidean attention model combining several query representations""" 179 | def __init__(self, args): 180 | super(SEA, self).__init__(args) 181 | self.sim = "dist" 182 | 183 | # reflection 184 | self.ref = nn.Embedding(self.sizes[1], self.rank) 185 | self.ref.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 186 | 187 | # rotation 188 | self.rot = nn.Embedding(self.sizes[1], self.rank) 189 | self.rot.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 190 | 191 | # translation 192 | self.tr = nn.Embedding(self.sizes[1], self.rank) 193 | self.tr.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 194 | 195 | # distmult 196 | self.dm = nn.Embedding(self.sizes[1], self.rank) 197 | self.dm.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 198 | 199 | # complex 200 | self.cp = nn.Embedding(self.sizes[1], self.rank) 201 | self.cp.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0 202 | 203 | # attention 204 | self.context_vec = nn.Embedding(self.sizes[1], self.rank) 205 | self.act = nn.Softmax(dim=1) 206 | self.scale = torch.Tensor([1. 
/ np.sqrt(self.rank)]).cuda() 207 | 208 | def get_reflection_queries(self, queries): 209 | lhs_ref_e = givens_reflection( 210 | self.ref(queries[:, 1]), self.entity(queries[:, 0]) 211 | ) 212 | return lhs_ref_e 213 | 214 | def get_rotation_queries(self, queries): 215 | lhs_rot_e = givens_rotations( 216 | self.rot(queries[:, 1]), self.entity(queries[:, 0]) 217 | ) 218 | return lhs_rot_e 219 | def get_transe_queries(self, queries): 220 | tr = self.tr(queries[:, 1]) 221 | h = self.entity(queries[:, 0]) 222 | lhs_tr_e = h + tr 223 | return lhs_tr_e 224 | 225 | def get_complex_queries(self, queries): 226 | cp = self.cp(queries[:, 1]) 227 | cp = cp[:,:self.rank//2], cp[:,self.rank//2:] 228 | h = self.entity(queries[:, 0]) 229 | h = h[:,:self.rank//2], h[:,self.rank//2:] 230 | lhse_cp_e = h[0] * cp[0] - h[1] * cp[1], h[0] * cp[1] + h[1] * cp[0] 231 | lhs_cp_e = torch.cat((lhse_cp_e[0], lhse_cp_e[1]), dim = 1) 232 | return lhs_cp_e 233 | 234 | def get_distmult_queries(self, queries): 235 | dm = self.dm(queries[:, 1]) 236 | h = self.entity(queries[:, 0]) 237 | lhs_dm_e = h * dm 238 | return lhs_dm_e 239 | 240 | 241 | def get_queries(self, queries): 242 | """Compute embedding and biases of queries.""" 243 | #lhs_ref_e = self.get_reflection_queries(queries).view((-1, 1, self.rank)) 244 | #lhs_rot_e = self.get_rotation_queries(queries).view((-1, 1, self.rank)) 245 | lhs_tr_e = self.get_transe_queries(queries).view((-1, 1, self.rank)) 246 | lhs_cp_e = self.get_complex_queries(queries).view((-1, 1, self.rank)) 247 | lhs_dm_e = self.get_distmult_queries(queries).view((-1, 1, self.rank)) 248 | 249 | # self-attention mechanism 250 | # Add here all the KGE query representations (lhs_kge_e) you want to combine 251 | #cands = torch.cat([lhs_ref_e, lhs_rot_e, lhs_tr_e, lhs_cp_e, lhs_dm_e], dim=1) 252 | cands = torch.cat([lhs_tr_e, lhs_cp_e, lhs_dm_e], dim=1) 253 | # cands = torch.cat([lhs_tr_e, lhs_cp_e, lhs_dm_e, lhs_rot_e, lhs_ref_e], dim=1) 254 | 255 | context_vec = self.context_vec(queries[:, 1]).view((-1, 1, self.rank)) 256 | att_weights = torch.sum(context_vec * cands * self.scale, dim=-1, keepdim=True) 257 | att_weights = self.act(att_weights) 258 | 259 | 260 | # regularization 261 | #reg_att_weights = torch.mul(att_weights,att_weights) 262 | #att_sum = torch.sum(att_weights,dim=1) 263 | #att_normalizer = torch.div(1,att_sum) 264 | #norm_att_weights = torch.mul(att_weights,att_normalizer.unsqueeze(-1)) 265 | 266 | # save alphas 267 | self.att = att_weights 268 | 269 | lhs_e = torch.sum(att_weights * cands, dim=1) + self.rel(queries[:, 1]) 270 | 271 | return lhs_e, self.bh(queries[:, 0]) 272 | -------------------------------------------------------------------------------- /models/hyperbolic.py: -------------------------------------------------------------------------------- 1 | """Hyperbolic Knowledge Graph embedding models where all parameters are defined in tangent spaces.""" 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from numpy import asarray 7 | from numpy import savetxt 8 | from models.base import KGModel 9 | from utils.euclidean import givens_rotations, givens_reflection 10 | from utils.hyperbolic import mobius_add, expmap0, project, hyp_distance_multi_c 11 | 12 | HYP_MODELS = ["RotH", "RefH", "AttH", "SEPA"] 13 | 14 | 15 | class BaseH(KGModel): 16 | """Trainable curvature for each relationship.""" 17 | 18 | def __init__(self, args): 19 | super(BaseH, self).__init__(args.sizes, args.rank, args.dropout, args.gamma, args.dtype, args.bias, 
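SEA (like its hyperbolic counterpart SEPA below) scores a query by softmax-combining several candidate query representations under a per-relation context vector. A standalone sketch of that attention step, with illustrative shapes:

```python
import torch

batch, n_models, rank = 2, 3, 4  # e.g. TransE, ComplEx, DistMult query embeddings

# Candidate query representations stacked along dim=1, as in SEA.get_queries.
cands = torch.randn(batch, n_models, rank)
context_vec = torch.randn(batch, 1, rank)  # relation-specific context vector
scale = 1.0 / (rank ** 0.5)

att_weights = torch.softmax(
    torch.sum(context_vec * cands * scale, dim=-1, keepdim=True), dim=1
)
lhs_e = torch.sum(att_weights * cands, dim=1)  # attention-weighted mixture of queries
print(att_weights.sum(dim=1).squeeze())        # each query's weights sum to 1
```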
--------------------------------------------------------------------------------
/models/hyperbolic.py:
--------------------------------------------------------------------------------
"""Hyperbolic Knowledge Graph embedding models where all parameters are defined in tangent spaces."""
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from models.base import KGModel
from utils.euclidean import givens_rotations, givens_reflection
from utils.hyperbolic import mobius_add, expmap0, project, hyp_distance_multi_c

HYP_MODELS = ["RotH", "RefH", "AttH", "SEPA"]


class BaseH(KGModel):
    """Trainable curvature for each relationship."""

    def __init__(self, args):
        super(BaseH, self).__init__(args.sizes, args.rank, args.dropout, args.gamma, args.dtype, args.bias,
                                    args.init_size)
        self.entity.weight.data = self.init_size * torch.randn((self.sizes[0], self.rank), dtype=self.data_type)
        self.rel.weight.data = self.init_size * torch.randn((self.sizes[1], 2 * self.rank), dtype=self.data_type)
        self.rel_diag = nn.Embedding(self.sizes[1], self.rank)
        self.rel_diag.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0
        self.multi_c = args.multi_c
        self.att = []
        if self.multi_c:
            c_init = torch.ones((self.sizes[1], 1), dtype=self.data_type)
        else:
            c_init = torch.ones((1, 1), dtype=self.data_type)
        self.c = nn.Parameter(c_init, requires_grad=True)

    def get_rhs(self, queries, eval_mode):
        """Get embeddings and biases of target entities."""
        if eval_mode:
            return self.entity.weight, self.bt.weight
        else:
            return self.entity(queries[:, 2]), self.bt(queries[:, 2])

    def similarity_score(self, lhs_e, rhs_e, eval_mode):
        """Compute similarity scores of queries against targets in embedding space."""
        lhs_e, c = lhs_e
        return - hyp_distance_multi_c(lhs_e, rhs_e, c, eval_mode) ** 2


class RotH(BaseH):
    """Hyperbolic 2x2 Givens rotations"""

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        c = F.softplus(self.c[queries[:, 1]])
        head = expmap0(self.entity(queries[:, 0]), c)
        rel1, rel2 = torch.chunk(self.rel(queries[:, 1]), 2, dim=1)
        rel1 = expmap0(rel1, c)
        rel2 = expmap0(rel2, c)
        lhs = project(mobius_add(head, rel1, c), c)
        res1 = givens_rotations(self.rel_diag(queries[:, 1]), lhs)
        res2 = mobius_add(res1, rel2, c)
        return (res2, c), self.bh(queries[:, 0])


class RefH(BaseH):
    """Hyperbolic 2x2 Givens reflections"""

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        c = F.softplus(self.c[queries[:, 1]])
        rel, _ = torch.chunk(self.rel(queries[:, 1]), 2, dim=1)
        rel = expmap0(rel, c)
        lhs = givens_reflection(self.rel_diag(queries[:, 1]), self.entity(queries[:, 0]))
        lhs = expmap0(lhs, c)
        res = project(mobius_add(lhs, rel, c), c)
        return (res, c), self.bh(queries[:, 0])


class AttH(BaseH):
    """Hyperbolic attention model combining translations, reflections and rotations"""

    def __init__(self, args):
        super(AttH, self).__init__(args)
        self.rel_diag = nn.Embedding(self.sizes[1], 2 * self.rank)
        self.rel_diag.weight.data = 2 * torch.rand((self.sizes[1], 2 * self.rank), dtype=self.data_type) - 1.0
        # attention context vector per relation
        self.context_vec = nn.Embedding(self.sizes[1], self.rank)
        self.context_vec.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
        self.act = nn.Softmax(dim=1)
        if args.dtype == "double":
            self.scale = torch.Tensor([1. / np.sqrt(self.rank)]).double().cuda()
        else:
            self.scale = torch.Tensor([1. / np.sqrt(self.rank)]).cuda()

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        c = F.softplus(self.c[queries[:, 1]])
        head = self.entity(queries[:, 0])
        rot_mat, ref_mat = torch.chunk(self.rel_diag(queries[:, 1]), 2, dim=1)
        rot_q = givens_rotations(rot_mat, head).view((-1, 1, self.rank))
        ref_q = givens_reflection(ref_mat, head).view((-1, 1, self.rank))
        cands = torch.cat([ref_q, rot_q], dim=1)
        context_vec = self.context_vec(queries[:, 1]).view((-1, 1, self.rank))
        att_weights = torch.sum(context_vec * cands * self.scale, dim=-1, keepdim=True)
        att_weights = self.act(att_weights)
        att_q = torch.sum(att_weights * cands, dim=1)
        lhs = expmap0(att_q, c)
        rel, _ = torch.chunk(self.rel(queries[:, 1]), 2, dim=1)
        rel = expmap0(rel, c)
        res = project(mobius_add(lhs, rel, c), c)
        return (res, c), self.bh(queries[:, 0])


class SEPA(BaseH):
    """Hyperbolic attention model combining several query representations"""

    def __init__(self, args):
        super(SEPA, self).__init__(args)

        # reflection
        self.ref = nn.Embedding(self.sizes[1], self.rank)
        self.ref.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # rotation
        self.rot = nn.Embedding(self.sizes[1], self.rank)
        self.rot.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # translation
        self.tr = nn.Embedding(self.sizes[1], self.rank)
        self.tr.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # distmult
        self.dm = nn.Embedding(self.sizes[1], self.rank)
        self.dm.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # complex
        self.cp = nn.Embedding(self.sizes[1], self.rank)
        self.cp.weight.data = 2 * torch.rand((self.sizes[1], self.rank), dtype=self.data_type) - 1.0

        # attention
        self.context_vec = nn.Embedding(self.sizes[1], self.rank)
        self.context_vec.weight.data = self.init_size * torch.randn((self.sizes[1], self.rank), dtype=self.data_type)
        self.act = nn.Softmax(dim=1)
        self.args = args
        if args.dtype == "double":
            self.scale = torch.Tensor([1. / np.sqrt(self.rank)]).double().cuda()
        else:
            self.scale = torch.Tensor([1. / np.sqrt(self.rank)]).cuda()

    def get_reflection_queries(self, queries):
        lhs_ref_e = givens_reflection(
            self.ref(queries[:, 1]), self.entity(queries[:, 0])
        )
        return lhs_ref_e

    def get_rotation_queries(self, queries):
        lhs_rot_e = givens_rotations(
            self.rot(queries[:, 1]), self.entity(queries[:, 0])
        )
        return lhs_rot_e

    def get_transe_queries(self, queries):
        tr = self.tr(queries[:, 1])
        h = self.entity(queries[:, 0])
        lhs_tr_e = h + tr
        return lhs_tr_e

    def get_complex_queries(self, queries):
        cp = self.cp(queries[:, 1])
        cp = cp[:, :self.rank // 2], cp[:, self.rank // 2:]
        h = self.entity(queries[:, 0])
        h = h[:, :self.rank // 2], h[:, self.rank // 2:]
        lhse_cp_e = h[0] * cp[0] - h[1] * cp[1], h[0] * cp[1] + h[1] * cp[0]
        lhs_cp_e = torch.cat((lhse_cp_e[0], lhse_cp_e[1]), dim=1)
        return lhs_cp_e

    def get_distmult_queries(self, queries):
        dm = self.dm(queries[:, 1])
        h = self.entity(queries[:, 0])
        lhs_dm_e = h * dm
        return lhs_dm_e

    def get_queries(self, queries):
        """Compute embedding and biases of queries."""
        c = F.softplus(self.c[queries[:, 1]])
        # lhs_ref_e = self.get_reflection_queries(queries).view((-1, 1, self.rank))
        # lhs_rot_e = self.get_rotation_queries(queries).view((-1, 1, self.rank))
        lhs_tr_e = self.get_transe_queries(queries).view((-1, 1, self.rank))
        lhs_cp_e = self.get_complex_queries(queries).view((-1, 1, self.rank))
        lhs_dm_e = self.get_distmult_queries(queries).view((-1, 1, self.rank))

        # self-attention mechanism
        # Add here all the KGE query representations (lhs_kge_e) you want to combine
        # cands = torch.cat([lhs_ref_e, lhs_rot_e, lhs_tr_e, lhs_cp_e, lhs_dm_e], dim=1)
        cands = torch.cat([lhs_tr_e, lhs_cp_e, lhs_dm_e], dim=1)
        # cands = torch.cat([lhs_ref_e, lhs_rot_e, lhs_tr_e, lhs_dm_e], dim=1)  # mulde

        context_vec = self.context_vec(queries[:, 1]).view((-1, 1, self.rank))
        att_weights = torch.sum(context_vec * cands * self.scale, dim=-1, keepdim=True)
        att_weights = self.act(att_weights)

        # regularization
        # reg_att_weights = torch.mul(att_weights, att_weights)
        # att_sum = torch.sum(reg_att_weights, dim=1)
        # att_normalizer = torch.div(1, att_sum)
        # norm_att_weights = torch.mul(reg_att_weights, att_normalizer.unsqueeze(-1))

        # save alphas
        self.att = att_weights

        att_q = torch.sum(att_weights * cands, dim=1)
        lhs = expmap0(att_q, c)
        rel, _ = torch.chunk(self.rel(queries[:, 1]), 2, dim=1)
        rel = expmap0(rel, c)
        res = project(mobius_add(lhs, rel, c), c)
        return (res, c), self.bh(queries[:, 0])

--------------------------------------------------------------------------------
/optimizers/__init__.py:
--------------------------------------------------------------------------------
from .kg_optimizer import KGOptimizer
from .regularizers import *

--------------------------------------------------------------------------------
/optimizers/kg_optimizer.py:
--------------------------------------------------------------------------------
"""Knowledge Graph embedding model optimizer."""
import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from torch import nn


class KGOptimizer(object):
    """Knowledge Graph embedding model optimizer.

    KGOptimizer performs loss computations with negative sampling and gradient descent steps.

    Attributes:
        model: models.base.KGModel
        regularizer: regularizers.Regularizer
        optimizer: torch.optim.Optimizer
        batch_size: An integer for the training batch size
        neg_sample_size: An integer for the number of negative samples
        double_neg: A boolean (True to sample both head and tail entities)
    """

    def __init__(
            self, model, regularizer, optimizer, batch_size, neg_sample_size, double_neg, verbose=True):
        """Inits KGOptimizer."""
        self.model = model
        self.regularizer = regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose
        self.double_neg = double_neg
        self.loss_fn = nn.CrossEntropyLoss(reduction='mean')
        # self.loss_fn = nn.BCEWithLogitsLoss()
        self.neg_sample_size = neg_sample_size
        self.n_entities = model.sizes[0]

    def reduce_lr(self, factor=0.8):
        """Reduce learning rate.

        Args:
            factor: float for the learning rate decay
        """
        for param_group in self.optimizer.param_groups:
            param_group['lr'] *= factor

    def get_neg_samples(self, input_batch):
        """Sample negative examples.

        Args:
            input_batch: torch.LongTensor of shape (batch_size x 3) with ground truth training triples

        Returns:
            negative_batch: torch.Tensor of shape ((batch_size * neg_sample_size) x 3) with negative examples
        """
        negative_batch = input_batch.repeat(self.neg_sample_size, 1)
        batch_size = input_batch.shape[0]
        negsamples = torch.Tensor(np.random.randint(
            self.n_entities,
            size=batch_size * self.neg_sample_size)
        ).to(input_batch.dtype)
        negative_batch[:, 2] = negsamples
        if self.double_neg:
            negsamples = torch.Tensor(np.random.randint(
                self.n_entities,
                size=batch_size * self.neg_sample_size)
            ).to(input_batch.dtype)
            negative_batch[:, 0] = negsamples
        return negative_batch

    def neg_sampling_loss(self, input_batch):
        """Compute KG embedding loss with negative sampling.

        Args:
            input_batch: torch.LongTensor of shape (batch_size x 3) with ground truth training triples.

        Returns:
            loss: torch.Tensor with negative sampling embedding loss
            factors: torch.Tensor with embeddings weights to regularize
        """
        # positive samples
        positive_score, factors = self.model(input_batch)
        positive_score = F.logsigmoid(positive_score)

        # negative samples
        neg_samples = self.get_neg_samples(input_batch)
        negative_score, _ = self.model(neg_samples)
        negative_score = F.logsigmoid(-negative_score)
        loss = - torch.cat([positive_score, negative_score], dim=0).mean()
        return loss, factors

    def no_neg_sampling_loss(self, input_batch):
        """Compute KG embedding loss without negative sampling.

        Args:
            input_batch: torch.LongTensor of shape (batch_size x 3) with ground truth training triples

        Returns:
            loss: torch.Tensor with embedding loss
            factors: torch.Tensor with embeddings weights to regularize
        """
        predictions, factors = self.model(input_batch, eval_mode=True)
        truth = input_batch[:, 2]
        log_prob = F.logsigmoid(-predictions)
        idx = torch.arange(0, truth.shape[0], dtype=truth.dtype)
        pos_scores = F.logsigmoid(predictions[idx, truth]) - F.logsigmoid(-predictions[idx, truth])
        log_prob[idx, truth] += pos_scores
        loss = - log_prob.mean()
        loss += self.regularizer.forward(factors)
        return loss, factors

    def calculate_loss(self, input_batch):
        """Compute KG embedding loss and regularization loss.

        Args:
            input_batch: torch.LongTensor of shape (batch_size x 3) with ground truth training triples

        Returns:
            loss: torch.Tensor with embedding loss and regularization loss
        """
        if self.neg_sample_size > 0:
            loss, factors = self.neg_sampling_loss(input_batch)
        else:
            predictions, factors = self.model(input_batch, eval_mode=True)
            truth = input_batch[:, 2]
            loss = self.loss_fn(predictions, truth)
            # loss, factors = self.no_neg_sampling_loss(input_batch)

        # regularization loss
        loss += self.regularizer.forward(factors)

        # add epsilon contribution to the loss function

        return loss

    def calculate_valid_loss(self, examples):
        """Compute KG embedding loss over validation examples.

        Args:
            examples: torch.LongTensor of shape (N_valid x 3) with validation triples

        Returns:
            loss: torch.Tensor with loss averaged over all validation examples
        """
        b_begin = 0
        loss = 0.0
        counter = 0
        with torch.no_grad():
            while b_begin < examples.shape[0]:
                input_batch = examples[b_begin:b_begin + self.batch_size].cuda()
                b_begin += self.batch_size
                loss += self.calculate_loss(input_batch)
                counter += 1
        loss /= counter
        return loss

    def epoch(self, examples):
        """Runs one epoch of training KG embedding model.

        Args:
            examples: torch.LongTensor of shape (N_train x 3) with training triples

        Returns:
            loss: torch.Tensor with loss averaged over all training examples
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            b_begin = 0
            total_loss = 0.0
            counter = 0
            while b_begin < examples.shape[0]:
                input_batch = actual_examples[b_begin:b_begin + self.batch_size].cuda()

                # gradient step
                l = self.calculate_loss(input_batch)
                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()

                b_begin += self.batch_size
                total_loss += l
                counter += 1
                bar.update(input_batch.shape[0])
                bar.set_postfix(loss=f'{l.item():.4f}')
        total_loss /= counter
        return total_loss
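`neg_sampling_loss` above maximizes logsigmoid(score) on true triples and logsigmoid(-score) on corrupted ones. A toy sketch of the loss for one positive/negative pair of raw scores:

```python
import torch
import torch.nn.functional as F

# Toy version of KGOptimizer.neg_sampling_loss on raw scores.
positive_score = torch.tensor([4.0])   # model score for a true triple
negative_score = torch.tensor([-3.0])  # score for a corrupted triple

loss = -torch.cat([
    F.logsigmoid(positive_score),      # -> near 0 when the true triple scores high
    F.logsigmoid(-negative_score),     # -> near 0 when the corrupted triple scores low
]).mean()
print(loss)  # small loss: both samples are already classified correctly
```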
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | # 7 | from abc import ABC, abstractmethod 8 | from typing import Tuple 9 | 10 | import torch 11 | from torch import nn 12 | 13 | 14 | class Regularizer(nn.Module, ABC): 15 | @abstractmethod 16 | def forward(self, factors: Tuple[torch.Tensor]): 17 | pass 18 | 19 | 20 | class F2(Regularizer): 21 | def __init__(self, weight: float): 22 | super(F2, self).__init__() 23 | self.weight = weight 24 | 25 | def forward(self, factors): 26 | norm = 0 27 | for f in factors: 28 | norm += self.weight * torch.sum(f ** 2) 29 | return norm / factors[0].shape[0] 30 | 31 | 32 | class N3(Regularizer): 33 | def __init__(self, weight: float): 34 | super(N3, self).__init__() 35 | self.weight = weight 36 | 37 | def forward(self, factors): 38 | """Regularized complex embeddings https://arxiv.org/pdf/1806.07297.pdf""" 39 | norm = 0 40 | for f in factors: 41 | norm += self.weight * torch.sum( 42 | torch.abs(f) ** 3 43 | ) 44 | return norm / factors[0].shape[0] 45 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | torch 3 | tqdm 4 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | """Train Knowledge Graph embeddings for link prediction.""" 2 | 3 | import argparse 4 | import json 5 | import logging 6 | import os 7 | import io 8 | 9 | import torch 10 | import torch.optim 11 | 12 | import models 13 | import optimizers.regularizers as regularizers 14 | from datasets.kg_dataset import KGDataset 15 | from models import all_models 16 | from optimizers.kg_optimizer import KGOptimizer 17 | from utils.train import get_savedir, avg_both, format_metrics, count_params 18 | from itertools import combinations 19 | parser = argparse.ArgumentParser( 20 | description="Knowledge Graph Embedding" 21 | ) 22 | parser.add_argument( 23 | "--dataset", default="WN18RR", choices=["FB15K", "WN", "WN18RR", "FB237", "YAGO3-10", "NELL-995-h100","NELL-995-h75", "NELL-995-h50", "NELL-995-h25"], 24 | help="Knowledge Graph dataset" 25 | ) 26 | parser.add_argument( 27 | "--model", default="RotE", choices=all_models, help="Knowledge Graph embedding model" 28 | ) 29 | parser.add_argument( 30 | "--regularizer", choices=["N3", "F2"], default="N3", help="Regularizer" 31 | ) 32 | parser.add_argument( 33 | "--reg", default=0, type=float, help="Regularization weight" 34 | ) 35 | parser.add_argument( 36 | "--optimizer", choices=["Adagrad", "Adam", "SparseAdam"], default="Adagrad", 37 | help="Optimizer" 38 | ) 39 | parser.add_argument( 40 | "--max_epochs", default=50, type=int, help="Maximum number of epochs to train for" 41 | ) 42 | parser.add_argument( 43 | "--patience", default=10, type=int, help="Number of epochs before early stopping" 44 | ) 45 | parser.add_argument( 46 | "--valid", default=3, type=float, help="Number of epochs before validation" 47 | ) 48 | parser.add_argument( 49 | "--rank", default=1000, type=int, help="Embedding dimension" 50 | ) 51 | parser.add_argument( 52 | "--batch_size", default=1000, type=int, help="Batch size" 53 | ) 54 | parser.add_argument( 55 | "--neg_sample_size", default=50, type=int, help="Negative sample size, -1 to not use negative sampling" 56 | ) 57 | parser.add_argument( 58 | "--dropout", default=0, type=float, 
help="Dropout rate" 59 | ) 60 | parser.add_argument( 61 | "--init_size", default=1e-3, type=float, help="Initial embeddings' scale" 62 | ) 63 | parser.add_argument( 64 | "--learning_rate", default=1e-1, type=float, help="Learning rate" 65 | ) 66 | parser.add_argument( 67 | "--gamma", default=0, type=float, help="Margin for distance-based losses" 68 | ) 69 | parser.add_argument( 70 | "--bias", default="constant", type=str, choices=["constant", "learn", "none"], help="Bias type (none for no bias)" 71 | ) 72 | parser.add_argument( 73 | "--dtype", default="double", type=str, choices=["single", "double"], help="Machine precision" 74 | ) 75 | parser.add_argument( 76 | "--double_neg", action="store_true", 77 | help="Whether to negative sample both head and tail entities" 78 | ) 79 | parser.add_argument( 80 | "--debug", action="store_true", 81 | help="Only use 1000 examples for debugging" 82 | ) 83 | parser.add_argument( 84 | "--multi_c", action="store_true", help="Multiple curvatures per relation" 85 | ) 86 | 87 | parser.add_argument( 88 | "--cuda_n", default=3, type=int, help="Cuda core number" 89 | ) 90 | 91 | 92 | 93 | def train(args): 94 | save_dir = get_savedir(args.model, args.dataset) 95 | 96 | # file logger 97 | logging.basicConfig( 98 | format="%(asctime)s %(levelname)-8s %(message)s", 99 | level=logging.INFO, 100 | datefmt="%Y-%m-%d %H:%M:%S", 101 | filename=os.path.join(save_dir, "train.log") 102 | ) 103 | 104 | # stdout logger 105 | console = logging.StreamHandler() 106 | console.setLevel(logging.INFO) 107 | formatter = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s") 108 | console.setFormatter(formatter) 109 | logging.getLogger("").addHandler(console) 110 | logging.info("Saving logs in: {}".format(save_dir)) 111 | 112 | # create dataset 113 | dataset_path = os.path.join(os.environ["DATA_PATH"], args.dataset) 114 | dataset = KGDataset(dataset_path, args.debug) 115 | args.sizes = dataset.get_shape() 116 | 117 | # load data 118 | logging.info("\t " + str(dataset.get_shape())) 119 | train_examples = dataset.get_examples("train") 120 | valid_examples = dataset.get_examples("valid") 121 | test_examples = dataset.get_examples("test") 122 | filters = dataset.get_filters() 123 | 124 | # save config 125 | with open(os.path.join(save_dir, "config.json"), "w") as fjson: 126 | json.dump(vars(args), fjson) 127 | 128 | # create model 129 | model = getattr(models, args.model)(args) 130 | total = count_params(model) 131 | logging.info("Total number of parameters {}".format(total)) 132 | device = "cuda" 133 | model.to(device) 134 | 135 | # get optimizer 136 | regularizer = getattr(regularizers, args.regularizer)(args.reg) 137 | optim_method = getattr(torch.optim, args.optimizer)(model.parameters(), lr=args.learning_rate) 138 | optimizer = KGOptimizer(model, regularizer, optim_method, args.batch_size, args.neg_sample_size, 139 | bool(args.double_neg)) 140 | counter = 0 141 | best_mrr = None 142 | best_epoch = None 143 | logging.info("\t Start training") 144 | for step in range(args.max_epochs): 145 | 146 | # Train step 147 | model.train() 148 | train_loss = optimizer.epoch(train_examples) 149 | logging.info("\t Epoch {} | average train loss: {:.4f}".format(step, train_loss)) 150 | 151 | # Valid step 152 | model.eval() 153 | valid_loss = optimizer.calculate_valid_loss(valid_examples) 154 | logging.info("\t Epoch {} | average valid loss: {:.4f}".format(step, valid_loss)) 155 | 156 | if (step + 1) % args.valid == 0: 157 | valid_metrics = avg_both(*model.compute_metrics(valid_examples, 
--------------------------------------------------------------------------------
/set_env.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | KGHOME=$(pwd)
3 | export PYTHONPATH="$KGHOME:$PYTHONPATH"
4 | export LOG_DIR="$KGHOME/logs"
5 | export DATA_PATH="$KGHOME/data"
6 | source hyp_kg_env/bin/activate
7 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | """Evaluation script."""
2 | 
3 | import argparse
4 | import json
5 | import os
6 | 
7 | import torch
8 | 
9 | import models
10 | from datasets.kg_dataset import KGDataset
11 | from utils.train import avg_both, format_metrics
12 | 
13 | parser = argparse.ArgumentParser(description="Test")
14 | parser.add_argument(
15 |     '--model_dir',
16 |     help="Model path"
17 | )
18 | 
19 | 
20 | def test(model_dir):
21 |     # load config
22 |     with open(os.path.join(model_dir, "config.json"), "r") as f:
23 |         config = json.load(f)
24 |     args = argparse.Namespace(**config)
25 | 
26 |     # create dataset
27 |     dataset_path = os.path.join(os.environ["DATA_PATH"], args.dataset)
28 |     dataset = KGDataset(dataset_path, False)
29 |     test_examples = dataset.get_examples("test")
30 |     filters = dataset.get_filters()
31 | 
32 |     # load pretrained model weights
33 |     model = getattr(models, args.model)(args)
34 |     device = 'cuda'
35 |     model.to(device)
36 |     model.load_state_dict(torch.load(os.path.join(model_dir, 'model.pt')))
37 | 
38 |     # eval
39 |     test_metrics = avg_both(*model.compute_metrics(test_examples, filters))
40 |     return test_metrics
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     args = parser.parse_args()
45 |     test_metrics = test(args.model_dir)
46 |     print(format_metrics(test_metrics, split='test'))
47 | 
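test.py reads config.json and model.pt from the run directory that train() created under $LOG_DIR, so it only needs --model_dir. For example (the path below is hypothetical; get_savedir names directories by date, dataset, model, and time):

```bash
python test.py --model_dir logs/05_17/WN18RR/RotE_10_30_00
```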
--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cgregucci/Link-Prediction-with-Attention-applied-on-multiple-knowledge-graph-embedding-models/188cef319905f72b47da90b19c0cd78d36fa0e78/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cgregucci/Link-Prediction-with-Attention-applied-on-multiple-knowledge-graph-embedding-models/188cef319905f72b47da90b19c0cd78d36fa0e78/utils/__init__.py
--------------------------------------------------------------------------------
/utils/euclidean.py:
--------------------------------------------------------------------------------
1 | """Euclidean operations utils functions."""
2 | 
3 | import torch
4 | 
5 | 
6 | def euc_sqdistance(x, y, eval_mode=False):
7 |     """Compute euclidean squared distance between tensors.
8 | 
9 |     Args:
10 |         x: torch.Tensor of shape (N1 x d)
11 |         y: torch.Tensor of shape (N2 x d)
12 |         eval_mode: boolean
13 | 
14 |     Returns:
15 |         torch.Tensor of shape N1 x 1 with pairwise squared distances if eval_mode is false
16 |         else torch.Tensor of shape N1 x N2 with all-pairs squared distances
17 | 
18 |     """
19 |     x2 = torch.sum(x * x, dim=-1, keepdim=True)
20 |     y2 = torch.sum(y * y, dim=-1, keepdim=True)
21 |     if eval_mode:
22 |         y2 = y2.t()
23 |         xy = x @ y.t()
24 |     else:
25 |         assert x.shape[0] == y.shape[0]
26 |         xy = torch.sum(x * y, dim=-1, keepdim=True)
27 |     return x2 + y2 - 2 * xy
28 | 
29 | 
30 | def givens_rotations(r, x):
31 |     """Givens rotations.
32 | 
33 |     Args:
34 |         r: torch.Tensor of shape (N x d), rotation parameters
35 |         x: torch.Tensor of shape (N x d), points to rotate
36 | 
37 |     Returns:
38 |         torch.Tensor of shape (N x d) representing rotation of x by r
39 |     """
40 |     givens = r.view((r.shape[0], -1, 2))
41 |     givens = givens / torch.norm(givens, p=2, dim=-1, keepdim=True).clamp_min(1e-15)
42 |     x = x.view((r.shape[0], -1, 2))
43 |     x_rot = givens[:, :, 0:1] * x + givens[:, :, 1:] * torch.cat((-x[:, :, 1:], x[:, :, 0:1]), dim=-1)
44 |     return x_rot.view((r.shape[0], -1))
45 | 
46 | 
47 | def givens_reflection(r, x):
48 |     """Givens reflections.
49 | 
50 |     Args:
51 |         r: torch.Tensor of shape (N x d), reflection parameters
52 |         x: torch.Tensor of shape (N x d), points to reflect
53 | 
54 |     Returns:
55 |         torch.Tensor of shape (N x d) representing reflection of x by r
56 |     """
57 |     givens = r.view((r.shape[0], -1, 2))
58 |     givens = givens / torch.norm(givens, p=2, dim=-1, keepdim=True).clamp_min(1e-15)
59 |     x = x.view((r.shape[0], -1, 2))
60 |     x_ref = givens[:, :, 0:1] * torch.cat((x[:, :, 0:1], -x[:, :, 1:]), dim=-1) + givens[:, :, 1:] * torch.cat(
61 |         (x[:, :, 1:], x[:, :, 0:1]), dim=-1)
62 |     return x_ref.view((r.shape[0], -1))
63 | 
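Both givens_rotations and givens_reflection act on consecutive 2D blocks of the embedding and normalize the parameters per block, so they preserve norms. A minimal standalone sketch (not part of the file; run with set_env.sh sourced so utils is importable), rotating a single 2D point by 90 degrees:

```python
import torch
from utils.euclidean import givens_rotations

r = torch.tensor([[0.0, 1.0]])  # (cos t, sin t) for t = 90 degrees; normalized internally
x = torch.tensor([[1.0, 0.0]])  # one point, d = 2
x_rot = givens_rotations(r, x)
print(x_rot)                    # ~ tensor([[0., 1.]])
print(x.norm(), x_rot.norm())   # both 1: the rotation is an isometry
```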
--------------------------------------------------------------------------------
/utils/hyperbolic.py:
--------------------------------------------------------------------------------
1 | """Hyperbolic operations utils functions."""
2 | 
3 | import torch
4 | 
5 | MIN_NORM = 1e-15
6 | BALL_EPS = {torch.float32: 4e-3, torch.float64: 1e-5}
7 | 
8 | 
9 | # ################# MATH FUNCTIONS ########################
10 | 
11 | class Artanh(torch.autograd.Function):
12 |     @staticmethod
13 |     def forward(ctx, x):
14 |         x = x.clamp(-1 + 1e-5, 1 - 1e-5)
15 |         ctx.save_for_backward(x)
16 |         dtype = x.dtype
17 |         x = x.double()
18 |         return (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5).to(dtype)
19 | 
20 |     @staticmethod
21 |     def backward(ctx, grad_output):
22 |         input, = ctx.saved_tensors
23 |         return grad_output / (1 - input ** 2)
24 | 
25 | 
26 | def artanh(x):
27 |     return Artanh.apply(x)
28 | 
29 | 
30 | def tanh(x):
31 |     return x.clamp(-15, 15).tanh()
32 | 
33 | 
34 | # ################# HYP OPS ########################
35 | 
36 | def expmap0(u, c):
37 |     """Exponential map taken at the origin of the Poincare ball with curvature c.
38 | 
39 |     Args:
40 |         u: torch.Tensor of size B x d with tangent points
41 |         c: torch.Tensor of size 1 or B x 1 with absolute hyperbolic curvatures
42 | 
43 |     Returns:
44 |         torch.Tensor with hyperbolic points.
45 |     """
46 |     sqrt_c = c ** 0.5
47 |     u_norm = u.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
48 |     gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
49 |     return project(gamma_1, c)
50 | 
51 | 
52 | def logmap0(y, c):
53 |     """Logarithmic map taken at the origin of the Poincare ball with curvature c.
54 | 
55 |     Args:
56 |         y: torch.Tensor of size B x d with hyperbolic points
57 |         c: torch.Tensor of size 1 or B x 1 with absolute hyperbolic curvatures
58 | 
59 |     Returns:
60 |         torch.Tensor with tangent points.
61 |     """
62 |     sqrt_c = c ** 0.5
63 |     y_norm = y.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
64 |     return y / y_norm / sqrt_c * artanh(sqrt_c * y_norm)
65 | 
66 | 
67 | def project(x, c):
68 |     """Project points to Poincare ball with curvature c.
69 | 
70 |     Args:
71 |         x: torch.Tensor of size B x d with hyperbolic points
72 |         c: torch.Tensor of size 1 or B x 1 with absolute hyperbolic curvatures
73 | 
74 |     Returns:
75 |         torch.Tensor with projected hyperbolic points.
76 |     """
77 |     norm = x.norm(dim=-1, p=2, keepdim=True).clamp_min(MIN_NORM)
78 |     eps = BALL_EPS[x.dtype]
79 |     maxnorm = (1 - eps) / (c ** 0.5)
80 |     cond = norm > maxnorm
81 |     projected = x / norm * maxnorm
82 |     return torch.where(cond, projected, x)
83 | 
84 | 
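Since expmap0 maps tangent vectors at the origin onto the ball and logmap0 maps ball points back, one should invert the other, up to the clamping in project and artanh. A minimal standalone check (not part of the file):

```python
import torch
from utils.hyperbolic import expmap0, logmap0

c = torch.tensor([1.0])          # absolute curvature
u = torch.tensor([[0.3, -0.2]])  # tangent vector at the origin
x = expmap0(u, c)                # point on the Poincare ball
print(logmap0(x, c))             # ~ tensor([[ 0.3000, -0.2000]])
```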
85 | def mobius_add(x, y, c):
86 |     """Mobius addition of points in the Poincare ball with curvature c.
87 | 
88 |     Args:
89 |         x: torch.Tensor of size B x d with hyperbolic points
90 |         y: torch.Tensor of size B x d with hyperbolic points
91 |         c: torch.Tensor of size 1 or B x 1 with absolute hyperbolic curvatures
92 | 
93 |     Returns:
94 |         Tensor of shape B x d representing the element-wise Mobius addition of x and y.
95 |     """
96 |     x2 = torch.sum(x * x, dim=-1, keepdim=True)
97 |     y2 = torch.sum(y * y, dim=-1, keepdim=True)
98 |     xy = torch.sum(x * y, dim=-1, keepdim=True)
99 |     num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
100 |     denom = 1 + 2 * c * xy + c ** 2 * x2 * y2
101 |     return num / denom.clamp_min(MIN_NORM)
102 | 
103 | 
104 | # ################# HYP DISTANCES ########################
105 | 
106 | def hyp_distance(x, y, c, eval_mode=False):
107 |     """Hyperbolic distance on the Poincare ball with curvature c.
108 | 
109 |     Args:
110 |         x: torch.Tensor of size B x d with hyperbolic queries
111 |         y: torch.Tensor with hyperbolic points, shape n_entities x d if eval_mode is true else (B x d)
112 |         c: torch.Tensor of size 1 with absolute hyperbolic curvature
113 | 
114 |     Returns: torch.Tensor with hyperbolic distances, size B x 1 if eval_mode is False
115 |         else B x n_entities matrix with all pairs distances
116 |     """
117 |     sqrt_c = c ** 0.5
118 |     x2 = torch.sum(x * x, dim=-1, keepdim=True)
119 |     if eval_mode:
120 |         y2 = torch.sum(y * y, dim=-1, keepdim=True).transpose(0, 1)
121 |         xy = x @ y.transpose(0, 1)
122 |     else:
123 |         y2 = torch.sum(y * y, dim=-1, keepdim=True)
124 |         xy = torch.sum(x * y, dim=-1, keepdim=True)
125 |     c1 = 1 - 2 * c * xy + c * y2
126 |     c2 = 1 - c * x2
127 |     num = torch.sqrt((c1 ** 2) * x2 + (c2 ** 2) * y2 - (2 * c1 * c2) * xy)
128 |     denom = 1 - 2 * c * xy + c ** 2 * x2 * y2
129 |     pairwise_norm = num / denom.clamp_min(MIN_NORM)
130 |     dist = artanh(sqrt_c * pairwise_norm)
131 |     return 2 * dist / sqrt_c
132 | 
133 | 
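The closed form in hyp_distance is algebraically the textbook Poincare distance d_c(x, y) = (2 / sqrt(c)) * artanh(sqrt(c) * ||mobius_add(-x, y, c)||); expanding the Mobius addition of -x and y gives exactly the c1/c2 numerator and denominator above. A minimal standalone check (not part of the file; toy values are our own):

```python
import torch
from utils.hyperbolic import artanh, hyp_distance, mobius_add

c = torch.tensor([1.0])
x = torch.tensor([[0.1, 0.2]])
y = torch.tensor([[-0.3, 0.05]])

d_closed = hyp_distance(x, y, c)
d_def = 2 * artanh((c ** 0.5) * mobius_add(-x, y, c).norm(dim=-1, keepdim=True)) / c ** 0.5
print(d_closed, d_def)  # both ~ 0.89
```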
134 | def hyp_distance_multi_c(x, v, c, eval_mode=False):
135 |     """Hyperbolic distance on Poincare balls with varying curvatures c.
136 | 
137 |     Args:
138 |         x: torch.Tensor of size B x d with hyperbolic queries
139 |         v: torch.Tensor with tangent points (rescaled onto the ball inside the function), shape n_entities x d if eval_mode is true else (B x d)
140 |         c: torch.Tensor of size B x 1 with absolute hyperbolic curvatures
141 | 
142 |     Returns: torch.Tensor with hyperbolic distances, size B x 1 if eval_mode is False
143 |         else B x n_entities matrix with all pairs distances
144 |     """
145 |     sqrt_c = c ** 0.5
146 |     if eval_mode:
147 |         vnorm = torch.norm(v, p=2, dim=-1, keepdim=True).transpose(0, 1)
148 |         xv = x @ v.transpose(0, 1) / vnorm
149 |     else:
150 |         vnorm = torch.norm(v, p=2, dim=-1, keepdim=True)
151 |         xv = torch.sum(x * v / vnorm, dim=-1, keepdim=True)
152 |     gamma = tanh(sqrt_c * vnorm) / sqrt_c
153 |     x2 = torch.sum(x * x, dim=-1, keepdim=True)
154 |     c1 = 1 - 2 * c * gamma * xv + c * gamma ** 2
155 |     c2 = 1 - c * x2
156 |     num = torch.sqrt((c1 ** 2) * x2 + (c2 ** 2) * (gamma ** 2) - (2 * c1 * c2) * gamma * xv)
157 |     denom = 1 - 2 * c * gamma * xv + (c ** 2) * (gamma ** 2) * x2
158 |     pairwise_norm = num / denom.clamp_min(MIN_NORM)
159 |     dist = artanh(sqrt_c * pairwise_norm)
160 |     return 2 * dist / sqrt_c
161 | 
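hyp_distance_multi_c never receives a ball point for v: it rescales v by gamma = tanh(sqrt(c) * ||v||) / sqrt(c), which is exactly what expmap0 does. It should therefore agree with hyp_distance applied to expmap0(v, c), up to the boundary clamping in project. A minimal standalone check (not part of the repo; toy values are our own):

```python
import torch
from utils.hyperbolic import expmap0, hyp_distance, hyp_distance_multi_c

c = torch.tensor([[0.8]])        # per-example curvature, size B x 1
x = torch.tensor([[0.1, 0.2]])
v = torch.tensor([[0.5, -0.4]])  # tangent vector, mapped onto the ball internally

print(hyp_distance_multi_c(x, v, c))
print(hyp_distance(x, expmap0(v, c), c))  # same value
```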
--------------------------------------------------------------------------------
/utils/train.py:
--------------------------------------------------------------------------------
1 | """Training utils."""
2 | import datetime
3 | import os
4 | 
5 | 
6 | def get_savedir(model, dataset):
7 |     """Get unique saving directory name."""
8 |     dt = datetime.datetime.now()
9 |     date = dt.strftime("%m_%d")
10 |     save_dir = os.path.join(
11 |         os.environ["LOG_DIR"], date, dataset,
12 |         model + dt.strftime('_%H_%M_%S')
13 |     )
14 |     os.makedirs(save_dir)
15 |     return save_dir
16 | 
17 | 
18 | def avg_both(mrs, mrrs, hits):
19 |     """Average metrics over missing lhs and rhs predictions.
20 | 
21 |     Args:
22 |         mrs: Dict[str, float]
23 |         mrrs: Dict[str, float]
24 |         hits: Dict[str, torch.FloatTensor]
25 | 
26 |     Returns:
27 |         Dict[str, torch.FloatTensor] mapping metric name to averaged score
28 |     """
29 |     mr = (mrs['lhs'] + mrs['rhs']) / 2.
30 |     mrr = (mrrs['lhs'] + mrrs['rhs']) / 2.
31 |     h = (hits['lhs'] + hits['rhs']) / 2.
32 |     return {'MR': mr, 'MRR': mrr, 'hits@[1,3,10]': h}
33 | 
34 | 
35 | def format_metrics(metrics, split):
36 |     """Format metrics for logging."""
37 |     result = "\t {} MR: {:.2f} | ".format(split, metrics['MR'])
38 |     result += "MRR: {:.3f} | ".format(metrics['MRR'])
39 |     result += "H@1: {:.3f} | ".format(metrics['hits@[1,3,10]'][0])
40 |     result += "H@3: {:.3f} | ".format(metrics['hits@[1,3,10]'][1])
41 |     result += "H@10: {:.3f}".format(metrics['hits@[1,3,10]'][2])
42 |     return result
43 | 
44 | 
45 | def write_metrics(writer, step, metrics, split):
46 |     """Write metrics to tensorboard logs."""
47 |     writer.add_scalar('{}_MR'.format(split), metrics['MR'], global_step=step)
48 |     writer.add_scalar('{}_MRR'.format(split), metrics['MRR'], global_step=step)
49 |     writer.add_scalar('{}_H1'.format(split), metrics['hits@[1,3,10]'][0], global_step=step)
50 |     writer.add_scalar('{}_H3'.format(split), metrics['hits@[1,3,10]'][1], global_step=step)
51 |     writer.add_scalar('{}_H10'.format(split), metrics['hits@[1,3,10]'][2], global_step=step)
52 | 
53 | 
54 | def count_params(model):
55 |     """Count the total number of trainable parameters in the model."""
56 |     total = 0
57 |     for x in model.parameters():
58 |         if x.requires_grad:
59 |             res = 1
60 |             for y in x.shape:
61 |                 res *= y
62 |             total += res
63 |     return total
64 | 
--------------------------------------------------------------------------------
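Downstream, compute_metrics returns separate dictionaries for head (lhs) and tail (rhs) prediction, which avg_both merges before logging. A minimal standalone illustration with toy numbers (not part of the repo):

```python
import torch
from utils.train import avg_both, format_metrics

mrs = {'lhs': 120.0, 'rhs': 80.0}
mrrs = {'lhs': 0.40, 'rhs': 0.50}
hits = {'lhs': torch.tensor([0.30, 0.45, 0.60]),
        'rhs': torch.tensor([0.40, 0.55, 0.70])}

metrics = avg_both(mrs, mrrs, hits)
print(format_metrics(metrics, split='test'))
# -> "\t test MR: 100.00 | MRR: 0.450 | H@1: 0.350 | H@3: 0.500 | H@10: 0.650"
```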