├── dat
│   ├── fb15k.tgz
│   └── wordnet-mlj12.tar.gz
├── scripts
│   └── preprocess.sh
├── src
│   ├── utils
│   │   ├── graph.py
│   │   ├── math_utils.py
│   │   └── dataset.py
│   ├── models
│   │   ├── base_model.py
│   │   ├── param.py
│   │   ├── hole.py
│   │   ├── distmult.py
│   │   ├── rescal.py
│   │   ├── transe.py
│   │   ├── analogy.py
│   │   └── complex.py
│   ├── test.py
│   ├── processors
│   │   ├── optimizer.py
│   │   ├── trainer.py
│   │   └── evaluator.py
│   └── train.py
├── .gitignore
├── README.md
└── LICENSE
/dat/fb15k.tgz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mana-ysh/knowledge-graph-embeddings/HEAD/dat/fb15k.tgz -------------------------------------------------------------------------------- /dat/wordnet-mlj12.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mana-ysh/knowledge-graph-embeddings/HEAD/dat/wordnet-mlj12.tar.gz -------------------------------------------------------------------------------- /scripts/preprocess.sh: -------------------------------------------------------------------------------- 1 | 2 | cd ../dat 3 | 4 | # Freebase 5 | tar zxvf fb15k.tgz 6 | cd FB15k 7 | cat freebase_mtr100_mte100-train.txt | cut -f 2 | sort | uniq > train.rellist 8 | cat freebase_mtr100_mte100-train.txt | cut -f 1,3 | perl -pe 's/\t/\n/g' | sort | uniq > train.entlist 9 | cat freebase_mtr100_mte100-train.txt freebase_mtr100_mte100-valid.txt freebase_mtr100_mte100-test.txt > whole.txt 10 | 11 | # WordNet 12 | cd .. 13 | tar zxvf wordnet-mlj12.tar.gz 14 | cd wordnet-mlj12 15 | cat wordnet-mlj12-train.txt | cut -f 2 | perl -pe 's/\t/\n/g' | sort | uniq > train.rellist 16 | cat wordnet-mlj12-train.txt | cut -f 1,3 | perl -pe 's/\t/\n/g' | sort | uniq > train.entlist 17 | cat wordnet-mlj12-train.txt wordnet-mlj12-valid.txt wordnet-mlj12-test.txt > whole.txt 18 | -------------------------------------------------------------------------------- /src/utils/graph.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | from scipy.sparse import lil_matrix 4 | 5 | from utils.dataset import * 6 | 7 | 8 | class TensorTypeGraph(object): 9 | def __init__(self, triple_dat, n_ent, n_rel): 10 | self.rel2mat = [lil_matrix((n_ent, n_ent)) for _ in range(n_rel)] 11 | for triple in triple_dat.batch_iter(1, rand_flg=False): 12 | sub, rel, obj = triple[0] 13 | self.rel2mat[rel][sub, obj] = 1.0 14 | 15 | def search_obj_id(self, sub, rel): 16 | return np.where(self.rel2mat[rel][sub].todense() == 1.0)[1] 17 | 18 | def search_sub_id(self, rel, obj): 19 | return np.where(self.rel2mat[rel][:, obj].todense() == 1.0)[0] 20 | 21 | @classmethod 22 | def load_from_raw(cls, data_path, ent_v, rel_v): 23 | triples = TripletDataset.load(data_path, ent_v, rel_v) 24 | return TensorTypeGraph(triples, len(ent_v), len(rel_v)) 25 | -------------------------------------------------------------------------------- /src/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import numpy as np 4 | 5 | 6 | def random_xavier(size): 7 | assert len(size) < 4 8 | # for RESCAL 9 | if len(size) == 3: 10 | assert size[1] == size[2] 11 | dim = size[1] 12 | bound = math.sqrt(6) / math.sqrt(2*dim) 13 | return np.random.uniform(-bound, bound, size=size) 14 | 15 | 16 | def max_margin(pos_scores, neg_scores): 17 | return np.maximum(0, 1 - (pos_scores - neg_scores)) 18 | 19 | 20 | def sigmoid(x): 21 | return np.tanh(x * 0.5) * 0.5 + 0.5 22 | 23 | 24 | def 
softplus(x): 25 | return np.maximum(0, x)+np.log(1+np.exp(-np.abs(-x))) 26 | 27 | 28 | def circular_convolution(v1, v2): 29 | freq_v1 = np.fft.fft(v1) 30 | freq_v2 = np.fft.fft(v2) 31 | return np.fft.ifft(np.multiply(freq_v1, freq_v2)).real 32 | 33 | 34 | def circular_correlation(v1, v2): 35 | freq_v1 = np.fft.fft(v1) 36 | freq_v2 = np.fft.fft(v2) 37 | return np.fft.ifft(np.multiply(freq_v1.conj(), freq_v2)).real 38 | -------------------------------------------------------------------------------- /src/models/base_model.py: -------------------------------------------------------------------------------- 1 | 2 | import dill 3 | 4 | 5 | class BaseModel(object): 6 | def __init__(self, **kwargs): 7 | raise NotImplementedError 8 | 9 | def cal_rank(self, **kwargs): 10 | raise NotImplementedError 11 | 12 | # For max-margin loss 13 | def _pairwisegrads(self, **kwargs): 14 | raise NotImplementedError 15 | 16 | # For log-likelihood 17 | def _singlegrads(self, **kwargs): 18 | raise NotImplementedError 19 | 20 | def _composite(self, **kwargs): 21 | raise NotImplementedError 22 | 23 | def _cal_similarity(self, **kwargs): 24 | raise NotImplementedError 25 | 26 | def pick_ent(self, **kwargs): 27 | raise NotImplementedError 28 | 29 | def pick_rel(self, **kwargs): 30 | raise NotImplementedError 31 | 32 | def cal_scores(self, **kwargs): 33 | raise NotImplementedError 34 | 35 | def cal_scores_inv(self, **kwargs): 36 | raise NotImplementedError 37 | 38 | def cal_triplet_scores(self, **kwargs): 39 | raise NotImplementedError 40 | 41 | def zerograds(self): 42 | for param in self.params.values(): 43 | param.clear() 44 | 45 | def prepare(self): 46 | self.zerograds() 47 | 48 | def save_model(self, model_path): 49 | with open(model_path, 'wb') as fw: 50 | dill.dump(self, fw) 51 | 52 | @classmethod 53 | def load_model(cls, model_path): 54 | with open(model_path, 'rb') as f: 55 | model = dill.load(f) 56 | return model 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | -------------------------------------------------------------------------------- /src/models/param.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import defaultdict 3 | import numpy as np 4 | 5 | from utils.math_utils import random_xavier 6 | 7 | 8 | # TODO: Abstract class 9 | class Parameter(object): 10 | def __init__(self, name, shape, init_method): 11 | self.name = name 12 | self.shape = shape 13 | if init_method == 'xavier': 14 | self.data = random_xavier(self.shape) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | class LookupParameter(Parameter): 20 | def __init__(self, name, shape, init_method='xavier'): 21 | super(LookupParameter, self).__init__(name, shape, init_method) 22 | self.grad_indices = None 23 | self.part_grads = None 24 | self.dim = shape[1] 25 | if len(self.shape) == 2: 26 | self.idx2grad = defaultdict(lambda: np.zeros(self.dim)) 27 | elif len(self.shape) == 3: # for RESCAL 28 | assert self.shape[1] == self.shape[2] 29 | self.idx2grad = defaultdict(lambda: np.zeros((self.dim, self.dim))) 30 | else: 31 | raise NotImplementedError 32 | 33 | def add_grad(self, idx, grad): 34 | self.idx2grad[idx] += grad 35 | 36 | def add_all_grads(self, indices, grads): 37 | for i, g in zip(indices, grads): 38 | self.add_grad(i, g) 39 | 40 | def clear(self): 41 | if len(self.shape) == 2: 42 | self.idx2grad = defaultdict(lambda: np.zeros(self.dim)) 43 | else: 44 | self.idx2grad = defaultdict(lambda: np.zeros((self.dim, self.dim))) 45 | 46 | def finalize(self): 47 | self.grad_indices = list(self.idx2grad.keys()) 48 | self.part_grads = list(self.idx2grad.values()) 49 | if not self.grad_indices: 50 | self.grad_indices = [] 51 | self.part_grads = [] 52 | -------------------------------------------------------------------------------- /src/utils/dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | TODO 3 | - writing Item class for supporting fancy indexing in PathQueryDataset 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | # TODO: Abstract class 10 | class Dataset(object): 11 | def __init__(self, samples): 12 | assert type(samples) == list or type(samples) == np.ndarray 13 | self._samples = samples if type(samples) == np.ndarray else np.array(samples) 14 | 15 | def __getitem__(self, item): 16 | return self._samples[item] 17 | 18 | def __len__(self): 19 | return len(self._samples) 20 | 21 | def batch_iter(self, batchsize, 
rand_flg=True): 22 | indices = np.random.permutation(len(self)) if rand_flg else np.arange(len(self)) 23 | for start in range(0, len(self), batchsize): 24 | yield self[indices[start: start+batchsize]] 25 | 26 | @classmethod 27 | def load(cls, data_path, ent_vocab, rel_vocab): 28 | raise NotImplementedError 29 | 30 | 31 | class TripletDataset(Dataset): 32 | def __init__(self, samples): 33 | super(TripletDataset, self).__init__(samples) 34 | 35 | @classmethod 36 | def load(cls, data_path, ent_vocab, rel_vocab): 37 | samples = [] 38 | with open(data_path) as f: 39 | for line in f: 40 | sub, rel, obj = line.strip().split('\t') 41 | samples.append((ent_vocab[sub], rel_vocab[rel], ent_vocab[obj])) 42 | return TripletDataset(samples) 43 | 44 | 45 | class Vocab(object): 46 | def __init__(self): 47 | self.id2word = [] 48 | self.word2id = {} 49 | 50 | def add(self, word): 51 | if word not in self.id2word: 52 | self.word2id[word] = len(self.id2word) 53 | self.id2word.append(word) 54 | 55 | def __len__(self): 56 | return len(self.id2word) 57 | 58 | def __getitem__(self, word): 59 | return self.word2id[word] 60 | 61 | @classmethod 62 | def load(cls, vocab_path): 63 | v = Vocab() 64 | with open(vocab_path) as f: 65 | for word in f: 66 | v.add(word.strip()) 67 | return v 68 | -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | 4 | from processors.evaluator import Evaluator 5 | from utils.dataset import TripletDataset, Vocab 6 | from utils.graph import * 7 | 8 | 9 | def test(args): 10 | ent_vocab = Vocab.load(args.ent) 11 | rel_vocab = Vocab.load(args.rel) 12 | 13 | # preparing data 14 | test_dat = TripletDataset.load(args.data, ent_vocab, rel_vocab) 15 | 16 | print('loading model...') 17 | if args.method == 'complex': 18 | from models.complex import ComplEx as Model 19 | elif args.method == 'distmult': 20 | from models.distmult import DistMult as Model 21 | elif args.method == 'transe': 22 | from models.transe import TransE as Model 23 | elif args.method == 'hole': 24 | from models.hole import HolE as Model 25 | elif args.method == 'rescal': 26 | from models.rescal import RESCAL as Model 27 | elif args.method == 'analogy': 28 | from models.analogy import ANALOGY as Model 29 | else: 30 | raise NotImplementedError 31 | 32 | if args.filtered: 33 | print('loading whole graph...') 34 | from utils.graph import TensorTypeGraph 35 | whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab) 36 | else: 37 | whole_graph = None 38 | evaluator = Evaluator('all', None, args.filtered, whole_graph) 39 | if args.filtered: 40 | evaluator.prepare_valid(test_dat) 41 | model = Model.load_model(args.model) 42 | 43 | all_res = evaluator.run_all_matric(model, test_dat) 44 | for metric in sorted(all_res.keys()): 45 | print('{:20s}: {}'.format(metric, all_res[metric])) 46 | 47 | 48 | if __name__ == '__main__': 49 | p = argparse.ArgumentParser('Link prediction models') 50 | 51 | # dataset 52 | p.add_argument('--ent', type=str, help='entity list') 53 | p.add_argument('--rel', type=str, help='relation list') 54 | p.add_argument('--data', type=str, help='test data') 55 | p.add_argument('--filtered', action='store_true', help='use filtered metric') 56 | p.add_argument('--graphall', type=str, help='all graph file for filtered evaluation') 57 | 58 | # model 59 | p.add_argument('--method', default=None, type=str, help='method ["complex", "distmult", "transe", 
"hole", "rescal", "analogy"]') 60 | p.add_argument('--model', type=str, help='trained model path') 61 | 62 | args = p.parse_args() 63 | test(args) 64 | -------------------------------------------------------------------------------- /src/models/hole.py: -------------------------------------------------------------------------------- 1 | 2 | from models.base_model import BaseModel 3 | from models.param import LookupParameter 4 | from utils.math_utils import * 5 | 6 | 7 | class HolE(BaseModel): 8 | def __init__(self, **kwargs): 9 | self.n_entity = kwargs.pop('n_entity') 10 | self.n_relation = kwargs.pop('n_relation') 11 | self.dim = kwargs.pop('dim') 12 | self.margin = kwargs.pop('margin') 13 | mode = kwargs.pop('mode', 'pairwise') 14 | if mode == 'pairwise': 15 | self.compute_gradients = self._pairwisegrads 16 | elif mode == 'single': 17 | self.compute_gradients = self._singlegrads 18 | else: 19 | raise NotImplementedError 20 | 21 | self.params = {'e': LookupParameter(name='e', shape=(self.n_entity, self.dim)), 22 | 'r': LookupParameter(name='r', shape=(self.n_relation, self.dim))} 23 | 24 | def _pairwisegrads(self, pos_samples, neg_samples): 25 | raise NotImplementedError 26 | 27 | def _singlegrads(self, samples, ys): 28 | self.prepare() 29 | scores = self.cal_triplet_scores(samples) 30 | loss = softplus(-ys*scores) 31 | 32 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 33 | s_embs = self.pick_ent(subs) 34 | r_embs = self.pick_rel(rels) 35 | o_embs = self.pick_ent(objs) 36 | 37 | # compute gradients 38 | df = np.expand_dims(-ys * (1 - sigmoid(ys*scores)), axis=1) 39 | s_grads = circular_correlation(r_embs, o_embs) * df 40 | r_grads = circular_correlation(s_embs, o_embs) * df 41 | o_grads = circular_convolution(s_embs, r_embs) * df 42 | 43 | # TODO: unify how to passing the gradients 44 | ents = np.r_[subs, objs] 45 | self.params['e'].add_all_grads(ents, np.r_[s_grads, o_grads]) 46 | self.params['r'].add_all_grads(rels, r_grads) 47 | 48 | self.params['e'].finalize() 49 | self.params['r'].finalize() 50 | 51 | return loss.mean() 52 | 53 | def _composite(self, sub_emb, rel_emb): 54 | return circular_convolution(sub_emb, rel_emb) 55 | 56 | def _cal_similarity(self, query, obj_emb): 57 | return np.sum(query * obj_emb, axis=1) 58 | 59 | def cal_scores(self, subs, rels): 60 | sub_emb = self.pick_ent(subs) 61 | rel_emb = self.pick_rel(rels) 62 | qs = self._composite(sub_emb, rel_emb) 63 | score_mat = qs.dot(self.params['e'].data.T) 64 | return score_mat 65 | 66 | def cal_scores_inv(self, rels, objs): 67 | obj_emb = self.pick_ent(objs) 68 | rel_emb = self.pick_rel(rels) 69 | qs_inv = circular_correlation(rel_emb, obj_emb) 70 | score_mat = qs_inv.dot(self.params['e'].data.T) 71 | return score_mat 72 | 73 | def cal_triplet_scores(self, samples): 74 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 75 | sub_emb = self.pick_ent(subs) 76 | rel_emb = self.pick_rel(rels) 77 | obj_emb = self.pick_ent(objs) 78 | qs = self._composite(sub_emb, rel_emb) 79 | return self._cal_similarity(qs, obj_emb) 80 | 81 | def pick_ent(self, ents): 82 | return self.params['e'].data[ents] 83 | 84 | def pick_rel(self, rels): 85 | return self.params['r'].data[rels] 86 | -------------------------------------------------------------------------------- /src/processors/optimizer.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import pickle 4 | 5 | from models.param import LookupParameter 6 | from utils.math_utils import * 7 | 8 | 9 | 
class Optimizer(object): 10 | def update(self): 11 | if hasattr(self, 'l2_coeff'): 12 | self._l2_addhook() 13 | if hasattr(self, 'gc_norm'): 14 | self._gradclip_addhook() 15 | self._update() 16 | 17 | def _update(self, **kwargs): 18 | raise NotImplementedError 19 | 20 | def _prepare(self): 21 | pass 22 | 23 | def _l2_addhook(self): 24 | for p_name in self.params.keys(): 25 | param = self.params[p_name] 26 | if type(param) == LookupParameter: 27 | for i, idx in enumerate(param.grad_indices): 28 | param.part_grads[i] += 2 * self.l2_coeff * param.data[idx] 29 | else: 30 | raise NotImplementedError 31 | 32 | def _gradclip_addhook(self): 33 | for p_name in self.params.keys(): 34 | param = self.params[p_name] 35 | if type(param) == LookupParameter: 36 | for i, idx in enumerate(param.grad_indices): 37 | norm = np.linalg.norm(param.part_grads[i]) 38 | if norm > self.gc_norm: 39 | param.part_grads[i] *= self.gc_norm / norm 40 | else: 41 | raise NotImplementedError 42 | 43 | def regist_params(self, params): 44 | self.params = params 45 | self._prepare() 46 | 47 | def set_l2_reg(self, coeff): 48 | self.l2_coeff = coeff 49 | 50 | def set_gradclip(self, norm): 51 | self.gc_norm = norm 52 | 53 | def save_opt(self, opt_path): 54 | with open(opt_path, 'wb') as fw: 55 | pickle.dump(self, fw) 56 | 57 | @classmethod 58 | def load_opt(cls, opt_path): 59 | with open(opt_path, 'rb') as f: 60 | opt = pickle.load(f) 61 | return opt 62 | 63 | 64 | class SGD(Optimizer): 65 | def __init__(self, lr): 66 | self.lr = lr 67 | 68 | def _update(self): 69 | for param in self.params.values(): 70 | if type(param) == LookupParameter: 71 | idxs = param.grad_indices 72 | if len(idxs) != 0: 73 | param.data[idxs] -= self.lr * np.array(param.part_grads) 74 | else: 75 | param.data -= self.lr * param.grad 76 | 77 | class Adagrad(Optimizer): 78 | def __init__(self, lr, eps=1e-8): 79 | self.lr = lr 80 | self.eps = eps 81 | self.grad_history = {} 82 | 83 | def _update(self): 84 | for p_name in self.params.keys(): 85 | param = self.params[p_name] 86 | if type(param) == LookupParameter: 87 | idxs = param.grad_indices 88 | if len(idxs) != 0: 89 | self.grad_history[p_name][idxs] += np.power(param.part_grads, 2) 90 | param.data[idxs] -= self.lr * np.array(param.part_grads) / (np.sqrt(self.grad_history[p_name][idxs]) + self.eps) 91 | else: 92 | self.grad_history[p_name] += np.power(param.grad, 2) 93 | param.data -= self.lr * param.grad / (np.sqrt(self.grad_history[p_name]) + self.eps) 94 | 95 | def _init_grad_history(self): 96 | for p_name in self.params.keys(): 97 | param = self.params[p_name] 98 | self.grad_history[p_name] = np.zeros_like(param.data) 99 | 100 | def _prepare(self): 101 | self._init_grad_history() 102 | -------------------------------------------------------------------------------- /src/models/distmult.py: -------------------------------------------------------------------------------- 1 | 2 | from models.base_model import BaseModel 3 | from models.param import LookupParameter 4 | from utils.math_utils import * 5 | 6 | 7 | class DistMult(BaseModel): 8 | def __init__(self, **kwargs): 9 | self.n_entity = kwargs.pop('n_entity') 10 | self.n_relation = kwargs.pop('n_relation') 11 | self.dim = kwargs.pop('dim') 12 | self.margin = kwargs.pop('margin') 13 | mode = kwargs.pop('mode', 'pairwise') 14 | if mode == 'pairwise': 15 | self.compute_gradients = self._pairwisegrads 16 | elif mode == 'single': 17 | self.compute_gradients = self._singlegrads 18 | else: 19 | raise NotImplementedError 20 | 21 | self.params = {'e': 
LookupParameter(name='e', shape=(self.n_entity, self.dim)), 22 | 'r': LookupParameter(name='r', shape=(self.n_relation, self.dim))} 23 | 24 | def _pairwisegrads(self, pos_samples, neg_samples): 25 | assert pos_samples.shape == neg_samples.shape 26 | self.prepare() 27 | p_scores = self.cal_triplet_scores(pos_samples) 28 | n_scores = self.cal_triplet_scores(neg_samples) 29 | 30 | loss = max_margin(p_scores, n_scores) 31 | idxs = np.where(loss > 0)[0] 32 | if len(idxs) != 0: 33 | # TODO: inefficient calculation 34 | pos_subs, pos_rels, pos_objs = pos_samples[idxs, 0], pos_samples[idxs, 1], pos_samples[idxs, 2] 35 | neg_subs, neg_rels, neg_objs = neg_samples[idxs, 0], neg_samples[idxs, 1], neg_samples[idxs, 2] 36 | 37 | p_s_embs = self.pick_ent(pos_subs) 38 | p_r_embs = self.pick_rel(pos_rels) 39 | p_o_embs = self.pick_ent(pos_objs) 40 | n_s_embs = self.pick_ent(neg_subs) 41 | n_r_embs = self.pick_rel(neg_rels) 42 | n_o_embs = self.pick_ent(neg_objs) 43 | 44 | _batchsize = len(pos_subs) 45 | 46 | p_s_grads = - (p_r_embs * p_o_embs) 47 | p_r_grads = - (p_s_embs * p_o_embs) 48 | p_o_grads = - (p_s_embs * p_r_embs) 49 | n_s_grads = n_r_embs * n_o_embs 50 | n_r_grads = n_s_embs * n_o_embs 51 | n_o_grads = n_s_embs * n_r_embs 52 | 53 | for idx in range(_batchsize): 54 | self.params['e'].add_grad(pos_subs[idx], p_s_grads[idx]) 55 | self.params['r'].add_grad(pos_rels[idx], p_r_grads[idx]) 56 | self.params['e'].add_grad(pos_objs[idx], p_o_grads[idx]) 57 | self.params['e'].add_grad(neg_subs[idx], n_s_grads[idx]) 58 | self.params['r'].add_grad(neg_rels[idx], n_r_grads[idx]) 59 | self.params['e'].add_grad(neg_objs[idx], n_o_grads[idx]) 60 | 61 | else: 62 | pass 63 | 64 | self.params['e'].finalize() 65 | self.params['r'].finalize() 66 | 67 | return loss.mean() 68 | 69 | def _singlegrads(self, samples, ys): 70 | raise NotImplementedError('Only pairwise setting is available') 71 | 72 | def _composite(self, sub_emb, rel_emb): 73 | return sub_emb * rel_emb 74 | 75 | def _cal_similarity(self, query, obj_emb): 76 | return np.sum(query * obj_emb, axis=1) 77 | 78 | def cal_scores(self, subs, rels): 79 | sub_emb = self.pick_ent(subs) 80 | rel_emb = self.pick_rel(rels) 81 | qs = self._composite(sub_emb, rel_emb) 82 | score_mat = qs.dot(self.params['e'].data.T) 83 | return score_mat 84 | 85 | # TODO: this procedure is the same as cal_scores 86 | def cal_scores_inv(self, rels, objs): 87 | obj_emb = self.pick_ent(objs) 88 | rel_emb = self.pick_rel(rels) 89 | qs_inv = obj_emb * rel_emb 90 | score_mat = qs_inv.dot(self.params['e'].data.T) 91 | return score_mat 92 | 93 | def cal_triplet_scores(self, samples): 94 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 95 | sub_emb = self.pick_ent(subs) 96 | rel_emb = self.pick_rel(rels) 97 | obj_emb = self.pick_ent(objs) 98 | qs = self._composite(sub_emb, rel_emb) 99 | return self._cal_similarity(qs, obj_emb) 100 | 101 | def pick_ent(self, ents): 102 | return self.params['e'].data[ents] 103 | 104 | def pick_rel(self, rels): 105 | return self.params['r'].data[rels] 106 | -------------------------------------------------------------------------------- /src/models/rescal.py: -------------------------------------------------------------------------------- 1 | 2 | from models.base_model import BaseModel 3 | from models.param import LookupParameter 4 | from utils.math_utils import * 5 | 6 | 7 | class RESCAL(BaseModel): 8 | def __init__(self, **kwargs): 9 | self.n_entity = kwargs.pop('n_entity') 10 | self.n_relation = kwargs.pop('n_relation') 11 | self.dim = 
kwargs.pop('dim') 12 | self.margin = kwargs.pop('margin') 13 | mode = kwargs.pop('mode', 'pairwise') 14 | if mode == 'pairwise': 15 | self.compute_gradients = self._pairwisegrads 16 | elif mode == 'single': 17 | self.compute_gradients = self._singlegrads 18 | else: 19 | raise NotImplementedError 20 | 21 | self.params = {'e': LookupParameter(name='e', shape=(self.n_entity, self.dim)), 22 | 'r_mat': LookupParameter(name='r_mat', shape=(self.n_relation, self.dim, self.dim))} 23 | 24 | def _pairwisegrads(self, pos_samples, neg_samples): 25 | assert pos_samples.shape == neg_samples.shape 26 | self.prepare() 27 | p_scores = self.cal_triplet_scores(pos_samples) 28 | n_scores = self.cal_triplet_scores(neg_samples) 29 | 30 | loss = max_margin(p_scores, n_scores) 31 | idxs = np.where(loss > 0)[0] 32 | if len(idxs) != 0: 33 | # TODO: inefficient calculation 34 | pos_subs, pos_rels, pos_objs = pos_samples[idxs, 0], pos_samples[idxs, 1], pos_samples[idxs, 2] 35 | neg_subs, neg_rels, neg_objs = neg_samples[idxs, 0], neg_samples[idxs, 1], neg_samples[idxs, 2] 36 | 37 | p_s_embs = self.pick_ent(pos_subs) 38 | p_r_mats = self.pick_rel(pos_rels) 39 | p_o_embs = self.pick_ent(pos_objs) 40 | n_s_embs = self.pick_ent(neg_subs) 41 | n_r_mats = self.pick_rel(neg_rels) 42 | n_o_embs = self.pick_ent(neg_objs) 43 | 44 | _batchsize = len(pos_subs) 45 | 46 | p_s_grads = - np.matmul(p_r_mats, np.expand_dims(p_o_embs, axis=2)).reshape(_batchsize, self.dim) 47 | p_r_grads = - np.matmul(np.expand_dims(p_s_embs, axis=2), np.expand_dims(p_o_embs, axis=1)) 48 | p_o_grads = - self._composite(p_s_embs, p_r_mats) 49 | n_s_grads = np.matmul(n_r_mats, np.expand_dims(n_o_embs, axis=2)).reshape(_batchsize, self.dim) 50 | n_r_grads = np.matmul(np.expand_dims(n_s_embs, axis=2), np.expand_dims(n_o_embs, axis=1)) 51 | n_o_grads = self._composite(n_s_embs, n_r_mats) 52 | 53 | for idx in range(_batchsize): 54 | self.params['e'].add_grad(pos_subs[idx], p_s_grads[idx]) 55 | self.params['r_mat'].add_grad(pos_rels[idx], p_r_grads[idx]) 56 | self.params['e'].add_grad(pos_objs[idx], p_o_grads[idx]) 57 | self.params['e'].add_grad(neg_subs[idx], n_s_grads[idx]) 58 | self.params['r_mat'].add_grad(neg_rels[idx], n_r_grads[idx]) 59 | self.params['e'].add_grad(neg_objs[idx], n_o_grads[idx]) 60 | 61 | else: 62 | pass 63 | 64 | self.params['e'].finalize() 65 | self.params['r_mat'].finalize() 66 | 67 | return loss.mean() 68 | 69 | def _singlegrads(self, samples, ys): 70 | raise NotImplementedError('Only pairwise setting is available') 71 | 72 | def _composite(self, sub_emb, rel_mat): 73 | _batchsize = len(sub_emb) 74 | return np.matmul(np.expand_dims(sub_emb, axis=1), rel_mat).reshape(_batchsize, self.dim) 75 | 76 | def _cal_similarity(self, query, obj_emb): 77 | return np.sum(query * obj_emb, axis=1) 78 | 79 | def cal_scores(self, subs, rels): 80 | sub_emb = self.pick_ent(subs) 81 | rel_mat = self.pick_rel(rels) 82 | qs = self._composite(sub_emb, rel_mat) 83 | score_mat = qs.dot(self.params['e'].data.T) 84 | return score_mat 85 | 86 | def cal_scores_inv(self, rels, objs): 87 | _batchsize = len(rels) 88 | obj_emb = self.pick_ent(objs) 89 | rel_mat = self.pick_rel(rels) 90 | qs_inv = np.matmul(rel_mat, np.expand_dims(obj_emb, axis=2)).reshape(_batchsize, self.dim) 91 | score_mat = qs_inv.dot(self.params['e'].data.T) 92 | return score_mat 93 | 94 | def cal_triplet_scores(self, samples): 95 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 96 | sub_emb = self.pick_ent(subs) 97 | rel_emb = self.pick_rel(rels) 98 | obj_emb = 
self.pick_ent(objs) 99 | qs = self._composite(sub_emb, rel_emb) 100 | return self._cal_similarity(qs, obj_emb) 101 | 102 | def pick_ent(self, ents): 103 | return self.params['e'].data[ents] 104 | 105 | def pick_rel(self, rels): 106 | return self.params['r_mat'].data[rels] 107 | -------------------------------------------------------------------------------- /src/models/transe.py: -------------------------------------------------------------------------------- 1 | 2 | from models.base_model import BaseModel 3 | from models.param import LookupParameter 4 | from utils.math_utils import * 5 | 6 | 7 | class TransE(BaseModel): 8 | def __init__(self, **kwargs): 9 | self.n_entity = kwargs.pop('n_entity') 10 | self.n_relation = kwargs.pop('n_relation') 11 | self.dim = kwargs.pop('dim') 12 | self.margin = kwargs.pop('margin') 13 | mode = kwargs.pop('mode', 'pairwise') 14 | if mode == 'pairwise': 15 | self.compute_gradients = self._pairwisegrads 16 | elif mode == 'single': 17 | self.compute_gradients = self._singlegrads 18 | else: 19 | raise NotImplementedError 20 | 21 | self.params = {'e': LookupParameter(name='e', shape=(self.n_entity, self.dim)), 22 | 'r': LookupParameter(name='r', shape=(self.n_relation, self.dim))} 23 | 24 | def _pairwisegrads(self, pos_samples, neg_samples): 25 | assert pos_samples.shape == neg_samples.shape 26 | self.prepare() 27 | p_scores = self.cal_triplet_scores(pos_samples) 28 | n_scores = self.cal_triplet_scores(neg_samples) 29 | 30 | loss = max_margin(p_scores, n_scores) 31 | idxs = np.where(loss > 0)[0] 32 | if len(idxs) != 0: 33 | # TODO: inefficient calculation 34 | pos_subs, pos_rels, pos_objs = pos_samples[idxs, 0], pos_samples[idxs, 1], pos_samples[idxs, 2] 35 | neg_subs, neg_rels, neg_objs = neg_samples[idxs, 0], neg_samples[idxs, 1], neg_samples[idxs, 2] 36 | 37 | p_s_embs = self.pick_ent(pos_subs) 38 | p_r_embs = self.pick_rel(pos_rels) 39 | p_o_embs = self.pick_ent(pos_objs) 40 | n_s_embs = self.pick_ent(neg_subs) 41 | n_r_embs = self.pick_rel(neg_rels) 42 | n_o_embs = self.pick_ent(neg_objs) 43 | 44 | p_qs = self._composite(p_s_embs, p_r_embs) 45 | n_qs = self._composite(n_s_embs, n_r_embs) 46 | 47 | p_s_grads = 2 * (p_qs - p_o_embs) 48 | p_r_grads = 2 * (p_qs - p_o_embs) 49 | p_o_grads = -2 * (p_qs - p_o_embs) 50 | 51 | n_s_grads = -2 * (n_qs - n_o_embs) 52 | n_r_grads = -2 * (n_qs - n_o_embs) 53 | n_o_grads = 2 * (n_qs - n_o_embs) 54 | 55 | _batchsize = len(pos_subs) 56 | 57 | for idx in range(_batchsize): 58 | self.params['e'].add_grad(pos_subs[idx], p_s_grads[idx]) 59 | self.params['r'].add_grad(pos_rels[idx], p_r_grads[idx]) 60 | self.params['e'].add_grad(pos_objs[idx], p_o_grads[idx]) 61 | self.params['e'].add_grad(neg_subs[idx], n_s_grads[idx]) 62 | self.params['r'].add_grad(neg_rels[idx], n_r_grads[idx]) 63 | self.params['e'].add_grad(neg_objs[idx], n_o_grads[idx]) 64 | 65 | else: 66 | pass 67 | 68 | self.params['e'].finalize() 69 | self.params['r'].finalize() 70 | 71 | return loss.mean() 72 | 73 | def _singlegrads(self, samples, ys): 74 | raise NotImplementedError('Only pairwise setting is available') 75 | 76 | def _composite(self, sub_emb, rel_emb): 77 | return sub_emb + rel_emb 78 | 79 | def _cal_similarity(self, query, obj_emb): 80 | return - np.sum((query - obj_emb)**2, axis=1) 81 | 82 | def cal_scores(self, subs, rels): 83 | _batchsize = len(subs) 84 | sub_emb = self.pick_ent(subs) 85 | rel_emb = self.pick_rel(rels) 86 | qs = self._composite(sub_emb, rel_emb) 87 | 88 | # TODO: maybe inefficient. 
use matrix operation 89 | score_mat = np.empty((_batchsize, self.n_entity)) 90 | for i in range(_batchsize): 91 | score_mat[i] = - np.linalg.norm(qs[i] - self.pick_ent(np.arange(self.n_entity)), axis=1) ** 2 92 | return score_mat 93 | 94 | def cal_scores_inv(self, rels, objs): 95 | _batchsize = len(objs) 96 | obj_emb = self.pick_ent(objs) 97 | rel_emb = self.pick_rel(rels) 98 | qs_inv = rel_emb - obj_emb 99 | 100 | # TODO: maybe inefficient. use matrix operation 101 | score_mat = np.empty((_batchsize, self.n_entity)) 102 | for i in range(_batchsize): 103 | score_mat[i] = - np.linalg.norm(self.pick_ent(np.arange(self.n_entity)) + qs_inv[i], axis=1) ** 2 104 | return score_mat 105 | 106 | def cal_triplet_scores(self, samples): 107 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 108 | sub_emb = self.pick_ent(subs) 109 | rel_emb = self.pick_rel(rels) 110 | obj_emb = self.pick_ent(objs) 111 | qs = self._composite(sub_emb, rel_emb) 112 | return self._cal_similarity(qs, obj_emb) 113 | 114 | def pick_ent(self, ents): 115 | return self.params['e'].data[ents] 116 | 117 | def pick_rel(self, rels): 118 | return self.params['r'].data[rels] 119 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # knowledge-graph-embeddings 2 | 3 | Python Implementations of Embedding-based methods for Knowledge Base Completion tasks, mainly inspired by [scikit-kge](https://github.com/mnick/scikit-kge) and [complex](https://github.com/ttrouill/complex). 4 | 5 | ## List of methods 6 | - RESCAL [Nickel+. 2011] 7 | - TransE [Bordes+. 2013] 8 | - DistMult [Yang+. 2015] 9 | - HolE [Nickel+. 2016] 10 | - This model is equivalent to ComplEx [Hayashi and Shimbo. 2017], and the computation cost of ComplEx is lower than that of HolE. 11 | - ComplEx [Trouillon+. 2016] 12 | - ANALOGY [Liu+. 2017] 13 | - This model can be regarded as a hybrid between DistMult and ComplEx. 14 | 15 | 16 | ## Run to train and test 17 | 18 | For training... 
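A typical invocation for FB15k, using the ComplEx settings from the experiments below (paths are illustrative and assume the files produced by `scripts/preprocess.sh`):

```
▶ python train.py --mode single --method complex \
    --ent ../dat/FB15k/train.entlist --rel ../dat/FB15k/train.rellist \
    --train ../dat/FB15k/freebase_mtr100_mte100-train.txt \
    --valid ../dat/FB15k/freebase_mtr100_mte100-valid.txt \
    --epoch 500 --batch 128 --lr 0.05 --dim 200 --negative 10 \
    --opt adagrad --l2_reg 0.0001 --gradclip 5 \
    --filtered --graphall ../dat/FB15k/whole.txt
```

The full list of options: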
19 | 20 | ``` 21 | ▶ python train.py -h 22 | usage: Link prediction models [-h] [--mode MODE] [--ent ENT] [--rel REL] 23 | [--train TRAIN] [--valid VALID] 24 | [--method METHOD] [--epoch EPOCH] 25 | [--batch BATCH] [--lr LR] [--dim DIM] 26 | [--margin MARGIN] [--negative NEGATIVE] 27 | [--opt OPT] [--l2_reg L2_REG] 28 | [--gradclip GRADCLIP] [--save_step SAVE_STEP] 29 | [--cp_ratio CP_RATIO] [--metric METRIC] 30 | [--nbest NBEST] [--filtered] 31 | [--graphall GRAPHALL] [--log LOG] 32 | 33 | optional arguments: 34 | -h, --help show this help message and exit 35 | --mode MODE training mode ["pairwise", "single"] 36 | --ent ENT entity list 37 | --rel REL relation list 38 | --train TRAIN training data 39 | --valid VALID validation data 40 | --method METHOD method ["complex", "distmult", "transe", "hole", 41 | "rescal", "analogy"] 42 | --epoch EPOCH number of epochs 43 | --batch BATCH batch size 44 | --lr LR learning rate 45 | --dim DIM dimension of embeddings 46 | --margin MARGIN margin in max-margin loss for pairwise training 47 | --negative NEGATIVE number of negative samples for pairwise training 48 | --opt OPT optimizer ["sgd", "adagrad"] 49 | --l2_reg L2_REG L2 regularization 50 | --gradclip GRADCLIP gradient clipping 51 | --save_step SAVE_STEP 52 | epoch step for saving model 53 | --cp_ratio CP_RATIO ratio of complex's dimension in ANALOGY 54 | --metric METRIC evaluation metrics ["mrr", "hits"] 55 | --nbest NBEST n-best for hits metric 56 | --filtered use filtered metric 57 | --graphall GRAPHALL all graph file for filtered evaluation 58 | --log LOG output log dir 59 | ``` 60 | 61 | 62 | For testing... 63 | 64 | ``` 65 | ▶ python test.py -h 66 | usage: Link prediction models [-h] [--ent ENT] [--rel REL] [--data DATA] 67 | [--filtered] [--graphall GRAPHALL] 68 | [--method METHOD] [--model MODEL] 69 | 70 | optional arguments: 71 | -h, --help show this help message and exit 72 | --ent ENT entity list 73 | --rel REL relation list 74 | --data DATA test data 75 | --filtered use filtered metric 76 | --graphall GRAPHALL all graph file for filtered evaluation 77 | --method METHOD method ["complex", "distmult", "transe", "hole", 78 | "rescal", "analogy"] 79 | --model MODEL trained model path 80 | ``` 81 | 82 | ## Experiments 83 | 84 | ### WordNet (WN18) 85 | 86 | | Models | MRR (flt) | MRR (raw) | Hits@1 (flt) | Hits@3 (flt) | Hits@10 (flt) | 87 | |:-----------:|:------------:|:------------:|:------------:|:------------:|:------------:| 88 | | ComplEx* | 94.1 | 58.7 | 93.6 | 94.5 | 94.7 | 89 | | ComplEx | 94.3 | 58.2 | 94.0 | 94.6 | 94.8 | 90 | 91 | Hyperparameters 92 | 93 | * mode : single 94 | * epoch : 500 95 | * batch : 128 96 | * lr : 0.05 97 | * dim : 200 98 | * negative : 5 99 | * opt : adagrad 100 | * l2_reg : 0.001 101 | * gradclip : 5 102 | 103 | 104 | 105 | ### FreeBase (FB15k) 106 | | Models | MRR (flt) | MRR (raw) | Hits@1 (flt) | Hits@3 (flt) | Hits@10 (flt) | 107 | |:-----------:|:------------:|:------------:|:------------:|:------------:|:------------:| 108 | | ComplEx* | 69.2 | 24.2 | 59.9 | 75.9 | 84.0 | 109 | | ComplEx | 69.5 | 24.2 | 59.8 | 76.9 | 85.0 | 110 | 111 | Hyperparameters 112 | 113 | * mode : single 114 | * epoch : 500 115 | * batch : 128 116 | * lr : 0.05 117 | * dim : 200 118 | * negative : 10 119 | * opt : adagrad 120 | * l2_reg : 0.0001 121 | * gradclip : 5 122 | 123 | 124 | \* denotes results reported in the original papers 125 | 126 | ## Dependencies 127 | * numpy 128 | * scipy 129 | * dill 130 | 131 | ## References 132 | 133 | * Bordes, A.; Usunier, N.; 
Garcia-Duran, A.; Weston, J.; and Yakhnenko, O. 2013. Translating embeddings for modeling multi-relational data. In Advances in Neural Information Processing Systems (NIPS). 134 | 135 | * Liu, H.; Wu, Y.; and Yang, Y. 2017. Analogical inference for multi-relational embeddings. In Proceedings of the 34th International Conference on Machine Learning (ICML). 136 | 137 | * Nickel, M.; Rosasco, L.; and Poggio, T. 2016. Holographic embeddings of knowledge graphs. In Proceedings of the Thirtieth AAAI Conference on Artificial Intelligence, AAAI’16. 138 | 139 | * Nickel, M.; Tresp, V.; and Kriegel, H.-P. 2011. A threeway model for collective learning on multi-relational data. In International Conference on Machine Learning (ICML-11), ICML ’11, 140 | 141 | * Trouillon, T.; Welbl, J.; Riedel, S.; Gaussier, E.; and Bouchard, G. 2016. Complex embeddings for simple link prediction. In International Conference on Machine Learning (ICML). 142 | 143 | * Yang, B.; Yih, W.; He, X.; Gao, J.; and Deng, L. 2015. Embedding entities and relations for learning and inference in knowledge bases. International Conference on Learning Representations 2015. 144 | -------------------------------------------------------------------------------- /src/models/analogy.py: -------------------------------------------------------------------------------- 1 | 2 | from models.base_model import BaseModel 3 | from models.param import LookupParameter 4 | from utils.math_utils import * 5 | 6 | 7 | class ANALOGY(BaseModel): 8 | def __init__(self, **kwargs): 9 | self.n_entity = kwargs.pop('n_entity') 10 | self.n_relation = kwargs.pop('n_relation') 11 | self.dim = kwargs.pop('dim') 12 | self.margin = kwargs.pop('margin') 13 | self.complex_ratio = kwargs.pop('cp_ratio') 14 | assert self.complex_ratio >= 0 and self.complex_ratio <= 1 15 | mode = kwargs.pop('mode', 'pairwise') 16 | if mode == 'pairwise': 17 | self.compute_gradients = self._pairwisegrads 18 | elif mode == 'single': 19 | self.compute_gradients = self._singlegrads 20 | else: 21 | raise NotImplementedError 22 | comp_dim = int(self.dim * self.complex_ratio) 23 | dist_dim = self.dim - comp_dim 24 | 25 | self.params = {'e_re': LookupParameter(name='e_re', shape=(self.n_entity, comp_dim)), 26 | 'e_im': LookupParameter(name='e_im', shape=(self.n_entity, comp_dim)), 27 | 'r_re': LookupParameter(name='r_re', shape=(self.n_relation, comp_dim)), 28 | 'r_im': LookupParameter(name='r_im', shape=(self.n_relation, comp_dim)), 29 | 'e': LookupParameter(name='e', shape=(self.n_entity, dist_dim)), 30 | 'r': LookupParameter(name='r', shape=(self.n_relation, dist_dim))} 31 | 32 | def _singlegrads(self, samples, ys): 33 | """ 34 | each element in ys must be \{-1, 1 \} 35 | """ 36 | self.prepare() 37 | scores = self.cal_triplet_scores(samples) 38 | loss = softplus(-ys*scores) 39 | 40 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 41 | s_re_embs, s_im_embs, s_embs = self.pick_ent(subs) 42 | r_re_embs, r_im_embs, r_embs = self.pick_rel(rels) 43 | o_re_embs, o_im_embs, o_embs = self.pick_ent(objs) 44 | 45 | # compute gradient 46 | df = np.expand_dims(-ys * (1 - sigmoid(ys*scores)), axis=1) 47 | s_re_grads = (r_re_embs * o_re_embs + r_im_embs * o_im_embs) * df 48 | s_im_grads = (r_re_embs * o_im_embs - r_im_embs * o_re_embs) * df 49 | r_re_grads = (s_re_embs * o_re_embs + s_im_embs * o_im_embs) * df 50 | r_im_grads = (s_re_embs * o_im_embs - s_im_embs * o_re_embs) * df 51 | o_re_grads = (s_re_embs * r_re_embs - s_im_embs * r_im_embs) * df 52 | o_im_grads = (s_re_embs * r_im_embs + 
s_im_embs * r_re_embs) * df 53 | s_grads = (r_embs * o_embs) * df 54 | r_grads = (s_embs * o_embs) * df 55 | o_grads = (s_embs * r_embs) * df 56 | 57 | ents = np.r_[subs, objs] 58 | self.params['e_re'].add_all_grads(ents, np.r_[s_re_grads, o_re_grads]) 59 | self.params['e_im'].add_all_grads(ents, np.r_[s_im_grads, o_im_grads]) 60 | self.params['r_re'].add_all_grads(rels, r_re_grads) 61 | self.params['r_im'].add_all_grads(rels, r_im_grads) 62 | self.params['e'].add_all_grads(ents, np.r_[s_grads, o_grads]) 63 | self.params['r'].add_all_grads(rels, r_grads) 64 | 65 | self.params['e_re'].finalize() 66 | self.params['e_im'].finalize() 67 | self.params['r_re'].finalize() 68 | self.params['r_im'].finalize() 69 | self.params['e'].finalize() 70 | self.params['r'].finalize() 71 | 72 | return loss.mean() 73 | 74 | def _comp_composite(self, sub_re_emb, sub_im_emb, rel_re_emb, rel_im_emb): 75 | re_qs = sub_re_emb * rel_re_emb - sub_im_emb * rel_im_emb 76 | im_qs = sub_re_emb * rel_im_emb + sub_im_emb * rel_re_emb 77 | return re_qs, im_qs 78 | 79 | def _dist_composite(self, sub_emb, rel_emb): 80 | return np.multiply(sub_emb, rel_emb) 81 | 82 | def cal_scores(self, subs, rels): 83 | _batchsize = len(subs) 84 | sub_re_emb, sub_im_emb, sub_emb = self.pick_ent(subs) 85 | rel_re_emb, rel_im_emb, rel_emb = self.pick_rel(rels) 86 | re_qs, im_qs = self._comp_composite(sub_re_emb, sub_im_emb, rel_re_emb, rel_im_emb) 87 | comp_score_mat = re_qs.dot(self.params['e_re'].data.T) + im_qs.dot(self.params['e_im'].data.T) 88 | 89 | qs = self._dist_composite(sub_emb, rel_emb) 90 | dist_score_mat = qs.dot(self.params['e'].data.T) 91 | return comp_score_mat + dist_score_mat 92 | 93 | def cal_scores_inv(self, rels, objs): 94 | _batchsize = len(objs) 95 | obj_re_emb, obj_im_emb, obj_emb = self.pick_ent(objs) 96 | rel_re_emb, rel_im_emb, rel_emb = self.pick_rel(rels) 97 | re_qs_inv = obj_re_emb * rel_re_emb + obj_im_emb * rel_im_emb 98 | im_qs_inv = obj_im_emb * rel_re_emb - obj_re_emb * rel_im_emb 99 | comp_score_mat = re_qs_inv.dot(self.params['e_re'].data.T) + im_qs_inv.dot(self.params['e_im'].data.T) 100 | 101 | qs = self._dist_composite(obj_emb, rel_emb) 102 | dist_score_mat = qs.dot(self.params['e'].data.T) 103 | return comp_score_mat + dist_score_mat 104 | 105 | def cal_triplet_scores(self, samples): 106 | subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2] 107 | sub_re_emb, sub_im_emb, sub_emb = self.pick_ent(subs) 108 | rel_re_emb, rel_im_emb, rel_emb = self.pick_rel(rels) 109 | obj_re_emb, obj_im_emb, obj_emb = self.pick_ent(objs) 110 | 111 | # complex 112 | re_qs, im_qs = self._comp_composite(sub_re_emb, sub_im_emb, rel_re_emb, rel_im_emb) 113 | comp_score = np.sum(re_qs * obj_re_emb, axis=1) + np.sum(im_qs * obj_im_emb, axis=1) 114 | 115 | # distmult 116 | qs = self._dist_composite(sub_emb, rel_emb) 117 | dist_score = np.sum(qs * obj_emb, axis=1) 118 | 119 | return comp_score + dist_score 120 | 121 | def pick_ent(self, ents): 122 | return self.params['e_re'].data[ents], self.params['e_im'].data[ents], self.params['e'].data[ents] 123 | 124 | def pick_rel(self, rels): 125 | return self.params['r_re'].data[rels], self.params['r_im'].data[rels], self.params['r'].data[rels] 126 | -------------------------------------------------------------------------------- /src/processors/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | import os 4 | import time 5 | 6 | from utils.dataset import * 7 | 8 | 9 | class Trainer(object): 10 | def __init__(self, 
**kwargs): 11 | self.model = kwargs.pop('model') 12 | self.opt = kwargs.pop('opt') 13 | self.n_epoch = kwargs.pop('epoch') 14 | self.batchsize = kwargs.pop('batchsize') 15 | self.logger = kwargs.pop('logger') 16 | self.log_dir = kwargs.pop('model_dir') 17 | self.evaluator = kwargs.pop('evaluator') 18 | self.valid_dat = kwargs.pop('valid_dat') 19 | self.save_step = kwargs.pop('save_step') 20 | self.cur_epoch = 0 21 | self.model_path = os.path.join(self.log_dir, self.model.__class__.__name__) 22 | 23 | def _setup(self): 24 | self.logger.info('setup trainer...') 25 | self.opt.regist_params(self.model.params) 26 | self.best_model = None 27 | 28 | def _finalize(self): 29 | assert self.n_epoch == self.cur_epoch 30 | if self.valid_dat: 31 | self.best_model.save_model(self.model_path + '.best') 32 | best_epoch, best_val = self.evaluator.get_best_info() 33 | self.logger.info('===== Best metric: {} ({} epoch) ====='.format(best_val, best_epoch)) 34 | else: 35 | self.model.save_model(self.model_path + '.epoch{}'.format(self.n_epoch)) 36 | 37 | def _validation(self): 38 | valid_start = time.time() 39 | res = self.evaluator.run(self.model, self.valid_dat) 40 | self.logger.info('evaluation metric in {} epoch: {}'.format(self.cur_epoch, res)) 41 | self.logger.info('evaluation time in {} epoch: {}'.format(self.cur_epoch, time.time() - valid_start)) 42 | 43 | cur_best_epoch, cur_best_val = self.evaluator.get_best_info() 44 | self.logger.info('< Current Best metric: {} ({} epoch) >'.format(cur_best_val, cur_best_epoch)) 45 | if cur_best_epoch == self.cur_epoch: 46 | self.best_model = copy.deepcopy(self.model) 47 | return cur_best_epoch 48 | 49 | def fit(self, **kwargs): 50 | raise NotImplementedError 51 | 52 | 53 | class PairwiseTrainer(Trainer): 54 | def __init__(self, **kwargs): 55 | super(PairwiseTrainer, self).__init__(**kwargs) 56 | self.n_negative = kwargs.pop('n_negative') 57 | self.neg_generator = UniformNegativeGenerator(self.model.n_entity, self.n_negative) 58 | 59 | def fit(self, train_dat): 60 | assert type(train_dat) == TripletDataset 61 | self._setup() 62 | for epoch in range(self.n_epoch): 63 | start = time.time() 64 | self.cur_epoch += 1 65 | 66 | sum_loss = 0. 
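# Each epoch below draws n_negative corrupted triplets per positive and
# tiles the positive batch to match, so that positives and negatives line
# up one-to-one in the pairwise max-margin loss.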
67 | self.logger.info('start {} epoch'.format(epoch+1)) 68 | for pos_triplets in train_dat.batch_iter(self.batchsize): 69 | neg_triplets = self.neg_generator.generate(pos_triplets) 70 | loss = self.model.compute_gradients(np.tile(pos_triplets, (self.n_negative, 1)), neg_triplets) 71 | self.opt.update() 72 | sum_loss += loss 73 | 74 | if self.valid_dat: # run validation 75 | cur_best_epoch = self._validation() 76 | 77 | if (epoch+1) % self.save_step == 0: 78 | if self.valid_dat: 79 | self.best_model.save_model(self.model_path+'.epoch{}'.format(cur_best_epoch)) 80 | else: 81 | self.model.save_model(self.model_path + '.epoch{}'.format(epoch+1)) 82 | 83 | self.logger.info('training loss in {} epoch: {}'.format(epoch+1, sum_loss)) 84 | self.logger.info('training time in {} epoch: {}'.format(epoch+1, time.time()-start)) 85 | 86 | self._finalize() 87 | 88 | 89 | class SingleTrainer(Trainer): 90 | def __init__(self, **kwargs): 91 | super(SingleTrainer, self).__init__(**kwargs) 92 | self.n_negative = kwargs.pop('n_negative') 93 | self.neg_generator = UniformNegativeGenerator(self.model.n_entity, self.n_negative) 94 | 95 | def fit(self, train_dat): 96 | self._setup() 97 | if type(train_dat) == TripletDataset: 98 | self._fit_negative_sample(train_dat) 99 | else: 100 | raise NotImplementedError 101 | 102 | def _fit_negative_sample(self, train_dat): 103 | assert type(train_dat) == TripletDataset 104 | for epoch in range(self.n_epoch): 105 | start = time.time() 106 | self.cur_epoch += 1 107 | sum_loss = 0. 108 | self.logger.info('start {} epoch'.format(epoch+1)) 109 | for pos_triplets in train_dat.batch_iter(self.batchsize): 110 | neg_triplets = self.neg_generator.generate(pos_triplets) 111 | ys = np.concatenate((np.ones(len(pos_triplets)), -np.ones(len(neg_triplets)))) 112 | loss = self.model.compute_gradients(np.r_[pos_triplets, neg_triplets], ys) 113 | self.opt.update() 114 | sum_loss += loss 115 | 116 | if self.valid_dat: # run validation 117 | cur_best_epoch = self._validation() 118 | 119 | if (epoch+1) % self.save_step == 0: 120 | if self.valid_dat: 121 | self.best_model.save_model(self.model_path+'.epoch{}'.format(cur_best_epoch)) 122 | else: 123 | self.model.save_model(self.model_path + '.epoch{}'.format(epoch+1)) 124 | 125 | self.logger.info('training loss in {} epoch: {}'.format(epoch+1, sum_loss)) 126 | self.logger.info('training time in {} epoch: {}'.format(epoch+1, time.time()-start)) 127 | 128 | self._finalize() 129 | 130 | 131 | class NegativeGenerator(object): 132 | def __init__(self, n_ent, n_negative, train_graph=None): 133 | self.n_ent = n_ent 134 | self.n_negative = n_negative 135 | if train_graph: 136 | raise NotImplementedError 137 | self.graph = train_graph # for preventing from including positive triplets as negative ones 138 | 139 | def generate(self, pos_triplets): 140 | """ 141 | :return: neg_triplets, whose size is (length of positives \times n_sample , 3) 142 | """ 143 | raise NotImplementedError 144 | 145 | 146 | class UniformNegativeGenerator(NegativeGenerator): 147 | def __init__(self, n_ent, n_negative, train_graph=None): 148 | super(UniformNegativeGenerator, self).__init__(n_ent, n_negative, train_graph) 149 | 150 | def generate(self, pos_triplets): 151 | _batchsize = len(pos_triplets) 152 | sample_size = _batchsize * self.n_negative 153 | neg_ents = np.random.randint(0, self.n_ent, size=sample_size) 154 | neg_triplets = np.tile(pos_triplets, (self.n_negative, 1)) 155 | head_or_tail = 2 * np.random.randint(0, 2, sample_size) 156 | neg_triplets[np.arange(sample_size), 
head_or_tail] = neg_ents 157 | return neg_triplets 158 | 159 | 160 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | from datetime import datetime 4 | import logging 5 | import numpy as np 6 | import os 7 | 8 | from processors.trainer import PairwiseTrainer, SingleTrainer 9 | from processors.evaluator import Evaluator 10 | from processors.optimizer import SGD, Adagrad 11 | from utils.dataset import TripletDataset, Vocab 12 | 13 | 14 | np.random.seed(46) 15 | DEFAULT_LOG_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), 16 | '{}'.format(datetime.now().strftime('%Y%m%d_%H:%M'))) 17 | 18 | 19 | def train(args): 20 | # setting for logging 21 | if not os.path.exists(args.log): 22 | os.mkdir(args.log) 23 | logger = logging.getLogger() 24 | logging.basicConfig(level=logging.INFO) 25 | log_path = os.path.join(args.log, 'log') 26 | file_handler = logging.FileHandler(log_path) 27 | fmt = logging.Formatter('%(asctime)s %(levelname)s %(message)s') 28 | file_handler.setFormatter(fmt) 29 | logger.addHandler(file_handler) 30 | 31 | # TODO: develop the recording of arguments in logging 32 | logger.info('Arguments...') 33 | for arg, val in sorted(vars(args).items()): 34 | logger.info('{:>10} -----> {}'.format(arg, val)) 35 | 36 | ent_vocab = Vocab.load(args.ent) 37 | rel_vocab = Vocab.load(args.rel) 38 | n_entity, n_relation = len(ent_vocab), len(rel_vocab) 39 | 40 | # preparing data 41 | logger.info('preparing data...') 42 | train_dat = TripletDataset.load(args.train, ent_vocab, rel_vocab) 43 | valid_dat = TripletDataset.load(args.valid, ent_vocab, rel_vocab) if args.valid else None 44 | 45 | if args.filtered: 46 | logger.info('loading whole graph...') 47 | from utils.graph import TensorTypeGraph 48 | whole_graph = TensorTypeGraph.load_from_raw(args.graphall, ent_vocab, rel_vocab) 49 | else: 50 | whole_graph = None 51 | 52 | if args.opt == 'sgd': 53 | opt = SGD(args.lr) 54 | elif args.opt == 'adagrad': 55 | opt = Adagrad(args.lr) 56 | else: 57 | raise NotImplementedError 58 | 59 | if args.l2_reg > 0: 60 | opt.set_l2_reg(args.l2_reg) 61 | if args.gradclip > 0: 62 | opt.set_gradclip(args.gradclip) 63 | 64 | logger.info('building model...') 65 | if args.method == 'complex': 66 | from models.complex import ComplEx 67 | model = ComplEx(n_entity=n_entity, 68 | n_relation=n_relation, 69 | margin=args.margin, 70 | dim=args.dim, 71 | mode=args.mode) 72 | elif args.method == 'distmult': 73 | from models.distmult import DistMult 74 | model = DistMult(n_entity=n_entity, 75 | n_relation=n_relation, 76 | margin=args.margin, 77 | dim=args.dim, 78 | mode=args.mode) 79 | elif args.method == 'transe': 80 | from models.transe import TransE 81 | model = TransE(n_entity=n_entity, 82 | n_relation=n_relation, 83 | margin=args.margin, 84 | dim=args.dim, 85 | mode=args.mode) 86 | elif args.method == 'hole': 87 | from models.hole import HolE 88 | model = HolE(n_entity=n_entity, 89 | n_relation=n_relation, 90 | margin=args.margin, 91 | dim=args.dim, 92 | mode=args.mode) 93 | elif args.method == 'rescal': 94 | from models.rescal import RESCAL 95 | model = RESCAL(n_entity=n_entity, 96 | n_relation=n_relation, 97 | margin=args.margin, 98 | dim=args.dim, 99 | mode=args.mode) 100 | elif args.method == 'analogy': 101 | from models.analogy import ANALOGY 102 | model = ANALOGY(n_entity=n_entity, 103 | n_relation=n_relation, 104 | margin=args.margin, 105 | dim=args.dim, 106 
| cp_ratio=args.cp_ratio, 107 | mode=args.mode) 108 | else: 109 | raise NotImplementedError 110 | 111 | evaluator = Evaluator(args.metric, args.nbest, args.filtered, whole_graph) if args.valid else None 112 | if args.filtered and args.valid: 113 | evaluator.prepare_valid(valid_dat) 114 | if args.mode == 'pairwise': 115 | trainer = PairwiseTrainer(model=model, opt=opt, save_step=args.save_step, 116 | batchsize=args.batch, logger=logger, 117 | evaluator=evaluator, valid_dat=valid_dat, 118 | n_negative=args.negative, epoch=args.epoch, 119 | model_dir=args.log) 120 | elif args.mode == 'single': 121 | trainer = SingleTrainer(model=model, opt=opt, save_step=args.save_step, 122 | batchsize=args.batch, logger=logger, 123 | evaluator=evaluator, valid_dat=valid_dat, 124 | n_negative=args.negative, epoch=args.epoch, 125 | model_dir=args.log) 126 | else: 127 | raise NotImplementedError 128 | 129 | trainer.fit(train_dat) 130 | 131 | logger.info('done all') 132 | 133 | 134 | if __name__ == '__main__': 135 | p = argparse.ArgumentParser('Link prediction models') 136 | p.add_argument('--mode', default='single', type=str, help='training mode ["pairwise", "single"]') 137 | 138 | # dataset 139 | p.add_argument('--ent', type=str, help='entity list') 140 | p.add_argument('--rel', type=str, help='relation list') 141 | p.add_argument('--train', type=str, help='training data') 142 | p.add_argument('--valid', type=str, help='validation data') 143 | 144 | # model 145 | p.add_argument('--method', default='complex', type=str, help='method ["complex", "distmult", "transe", "hole", "rescal", "analogy"]') 146 | p.add_argument('--epoch', default=300, type=int, help='number of epochs') 147 | p.add_argument('--batch', default=128, type=int, help='batch size') 148 | p.add_argument('--lr', default=0.05, type=float, help='learning rate') 149 | p.add_argument('--dim', default=200, type=int, help='dimension of embeddings') 150 | p.add_argument('--margin', default=1., type=float, help='margin in max-margin loss for pairwise training') 151 | p.add_argument('--negative', default=10, type=int, help='number of negative samples for pairwise training') 152 | p.add_argument('--opt', default='adagrad', type=str, help='optimizer ["sgd", "adagrad"]') 153 | p.add_argument('--l2_reg', default=0.0001, type=float, help='L2 regularization') 154 | p.add_argument('--gradclip', default=5, type=float, help='gradient clipping') 155 | p.add_argument('--save_step', default=100, type=int, help='epoch step for saving model') 156 | 157 | # model specific arguments 158 | p.add_argument('--cp_ratio', default=0.5, type=float, help="ratio of complex's dimension in ANALOGY") 159 | 160 | # evaluation 161 | p.add_argument('--metric', default='mrr', type=str, help='evaluation metrics ["mrr", "hits"]') 162 | p.add_argument('--nbest', default=None, type=int, help='n-best for hits metric') 163 | p.add_argument('--filtered', action='store_true', help='use filtered metric') 164 | p.add_argument('--graphall', type=str, help='all graph file for filtered evaluation') 165 | 166 | # others 167 | p.add_argument('--log', default=DEFAULT_LOG_DIR, type=str, help='output log dir') 168 | 169 | args = p.parse_args() 170 | 171 | train(args) 172 | -------------------------------------------------------------------------------- /src/models/complex.py: -------------------------------------------------------------------------------- 1 | 2 | from models.base_model import BaseModel 3 | from models.param import LookupParameter 4 | from utils.math_utils import * 5 | 6 | 
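# ComplEx stores each entity/relation as a complex vector, kept here as
# separate real and imaginary lookup tables (e_re/e_im, r_re/r_im). A triplet
# (s, r, o) is scored as Re(<e_s, w_r, conj(e_o)>), i.e.
#   sum_k [ Re(s)Re(r)Re(o) + Im(s)Re(r)Im(o) + Re(s)Im(r)Im(o) - Im(s)Im(r)Re(o) ],
# which _composite and _cal_similarity below compute in two steps.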
7 | class ComplEx(BaseModel): 8 | def __init__(self, **kwargs): 9 | self.n_entity = kwargs.pop('n_entity') 10 | self.n_relation = kwargs.pop('n_relation') 11 | self.dim = kwargs.pop('dim') 12 | self.margin = kwargs.pop('margin') 13 | mode = kwargs.pop('mode', 'pairwise') 14 | if mode == 'pairwise': 15 | self.compute_gradients = self._pairwisegrads 16 | elif mode == 'single': 17 | self.compute_gradients = self._singlegrads 18 | else: 19 | raise NotImplementedError 20 | 21 | self.params = {'e_re': LookupParameter(name='e_re', shape=(self.n_entity, self.dim)), 22 | 'e_im': LookupParameter(name='e_im', shape=(self.n_entity, self.dim)), 23 | 'r_re': LookupParameter(name='r_re', shape=(self.n_relation, self.dim)), 24 | 'r_im': LookupParameter(name='r_im', shape=(self.n_relation, self.dim))} 25 | 26 | def _pairwisegrads(self, pos_samples, neg_samples): 27 | assert pos_samples.shape == neg_samples.shape 28 | self.prepare() 29 | p_scores = self.cal_triplet_scores(pos_samples) 30 | n_scores = self.cal_triplet_scores(neg_samples) 31 | 32 | loss = max_margin(p_scores, n_scores) 33 | idxs = np.where(loss > 0)[0] 34 | if len(idxs) != 0: 35 | # TODO: inefficient calculation 36 | pos_subs, pos_rels, pos_objs = pos_samples[idxs, 0], pos_samples[idxs, 1], pos_samples[idxs, 2] 37 | neg_subs, neg_rels, neg_objs = neg_samples[idxs, 0], neg_samples[idxs, 1], neg_samples[idxs, 2] 38 | 39 | p_s_re_embs, p_s_im_embs = self.pick_ent(pos_subs) 40 | p_r_re_embs, p_r_im_embs = self.pick_rel(pos_rels) 41 | p_o_re_embs, p_o_im_embs = self.pick_ent(pos_objs) 42 | n_s_re_embs, n_s_im_embs = self.pick_ent(neg_subs) 43 | n_r_re_embs, n_r_im_embs = self.pick_rel(neg_rels) 44 | n_o_re_embs, n_o_im_embs = self.pick_ent(neg_objs) 45 | 46 | _batchsize = len(pos_subs) 47 | 48 | p_s_re_grads = - (p_r_re_embs * p_o_re_embs + p_r_im_embs * p_o_im_embs) 49 | p_s_im_grads = - (p_r_re_embs * p_o_im_embs - p_r_im_embs * p_o_re_embs) 50 | p_r_re_grads = - (p_s_re_embs * p_o_re_embs + p_s_im_embs * p_o_im_embs) 51 | p_r_im_grads = - (p_s_re_embs * p_o_im_embs - p_s_im_embs * p_o_re_embs) 52 | p_o_re_grads = - (p_s_re_embs * p_r_re_embs - p_s_im_embs * p_r_im_embs) 53 | p_o_im_grads = - (p_s_re_embs * p_r_im_embs + p_s_im_embs * p_r_re_embs) 54 | 55 | n_s_re_grads = n_r_re_embs * n_o_re_embs + n_r_im_embs * n_o_im_embs 56 | n_s_im_grads = n_r_re_embs * n_o_im_embs - n_r_im_embs * n_o_re_embs 57 | n_r_re_grads = n_s_re_embs * n_o_re_embs + n_s_im_embs * n_o_im_embs 58 | n_r_im_grads = n_s_re_embs * n_o_im_embs - n_s_im_embs * n_o_re_embs 59 | n_o_re_grads = n_s_re_embs * n_r_re_embs - n_s_im_embs * n_r_im_embs 60 | n_o_im_grads = n_s_re_embs * n_r_im_embs + n_s_im_embs * n_r_re_embs 61 | 62 | # TODO: unify how to passing the gradients 63 | for idx in range(_batchsize): 64 | self.params['e_re'].add_grad(pos_subs[idx], p_s_re_grads[idx]) 65 | self.params['e_im'].add_grad(pos_subs[idx], p_s_im_grads[idx]) 66 | self.params['r_re'].add_grad(pos_rels[idx], p_r_re_grads[idx]) 67 | self.params['r_im'].add_grad(pos_rels[idx], p_r_im_grads[idx]) 68 | self.params['e_re'].add_grad(pos_objs[idx], p_o_re_grads[idx]) 69 | self.params['e_im'].add_grad(pos_objs[idx], p_o_im_grads[idx]) 70 | 71 | self.params['e_re'].add_grad(neg_subs[idx], n_s_re_grads[idx]) 72 | self.params['e_im'].add_grad(neg_subs[idx], n_s_im_grads[idx]) 73 | self.params['r_re'].add_grad(neg_rels[idx], n_r_re_grads[idx]) 74 | self.params['r_im'].add_grad(neg_rels[idx], n_r_im_grads[idx]) 75 | self.params['e_re'].add_grad(neg_objs[idx], n_o_re_grads[idx]) 76 | 
77 | 
78 |         else:
79 |             pass
80 | 
81 |         self.params['e_re'].finalize()
82 |         self.params['e_im'].finalize()
83 |         self.params['r_re'].finalize()
84 |         self.params['r_im'].finalize()
85 | 
86 |         return loss.mean()
87 | 
88 |     def _singlegrads(self, samples, ys):
89 |         """
90 |         each element in ys must be in {-1, 1}
91 |         """
92 |         self.prepare()
93 |         scores = self.cal_triplet_scores(samples)
94 |         loss = softplus(-ys*scores)
95 | 
96 |         subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2]
97 |         s_re_embs, s_im_embs = self.pick_ent(subs)
98 |         r_re_embs, r_im_embs = self.pick_rel(rels)
99 |         o_re_embs, o_im_embs = self.pick_ent(objs)
100 | 
101 |         # compute gradients: d/df softplus(-y*f) = -y * (1 - sigmoid(y*f))
102 |         df = np.expand_dims(-ys * (1 - sigmoid(ys*scores)), axis=1)
103 |         s_re_grads = (r_re_embs * o_re_embs + r_im_embs * o_im_embs) * df
104 |         s_im_grads = (r_re_embs * o_im_embs - r_im_embs * o_re_embs) * df
105 |         r_re_grads = (s_re_embs * o_re_embs + s_im_embs * o_im_embs) * df
106 |         r_im_grads = (s_re_embs * o_im_embs - s_im_embs * o_re_embs) * df
107 |         o_re_grads = (s_re_embs * r_re_embs - s_im_embs * r_im_embs) * df
108 |         o_im_grads = (s_re_embs * r_im_embs + s_im_embs * r_re_embs) * df
109 | 
110 |         # TODO: unify how to pass the gradients
111 |         ents = np.r_[subs, objs]
112 |         self.params['e_re'].add_all_grads(ents, np.r_[s_re_grads, o_re_grads])
113 |         self.params['e_im'].add_all_grads(ents, np.r_[s_im_grads, o_im_grads])
114 |         self.params['r_re'].add_all_grads(rels, r_re_grads)
115 |         self.params['r_im'].add_all_grads(rels, r_im_grads)
116 | 
117 |         self.params['e_re'].finalize()
118 |         self.params['e_im'].finalize()
119 |         self.params['r_re'].finalize()
120 |         self.params['r_im'].finalize()
121 | 
122 |         return loss.mean()
123 | 
124 |     def _composite(self, sub_re_emb, sub_im_emb, rel_re_emb, rel_im_emb, prefix=''):
125 |         re_qs = sub_re_emb * rel_re_emb - sub_im_emb * rel_im_emb
126 |         im_qs = sub_re_emb * rel_im_emb + sub_im_emb * rel_re_emb
127 |         return re_qs, im_qs
128 | 
129 |     def _cal_similarity(self, re_query, im_query, obj_re_emb, obj_im_emb):
130 |         return np.sum(re_query * obj_re_emb, axis=1) + np.sum(im_query * obj_im_emb, axis=1)
131 | 
132 |     def cal_scores(self, subs, rels):
133 |         sub_re_emb, sub_im_emb = self.pick_ent(subs)
134 |         rel_re_emb, rel_im_emb = self.pick_rel(rels)
135 |         re_qs, im_qs = self._composite(sub_re_emb, sub_im_emb, rel_re_emb, rel_im_emb)
136 |         score_mat = re_qs.dot(self.params['e_re'].data.T) + im_qs.dot(self.params['e_im'].data.T)
137 |         return score_mat
138 | 
139 |     def cal_scores_inv(self, rels, objs):
140 |         obj_re_emb, obj_im_emb = self.pick_ent(objs)
141 |         rel_re_emb, rel_im_emb = self.pick_rel(rels)
142 |         re_qs_inv = obj_re_emb * rel_re_emb + obj_im_emb * rel_im_emb
143 |         im_qs_inv = obj_im_emb * rel_re_emb - obj_re_emb * rel_im_emb
144 |         score_mat = re_qs_inv.dot(self.params['e_re'].data.T) + im_qs_inv.dot(self.params['e_im'].data.T)
145 |         return score_mat
146 | 
147 |     def cal_triplet_scores(self, samples):
148 |         subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2]
149 |         sub_re_emb, sub_im_emb = self.pick_ent(subs)
150 |         rel_re_emb, rel_im_emb = self.pick_rel(rels)
151 |         obj_re_emb, obj_im_emb = self.pick_ent(objs)
152 |         re_qs, im_qs = self._composite(sub_re_emb, sub_im_emb, rel_re_emb, rel_im_emb)
153 |         return self._cal_similarity(re_qs, im_qs, obj_re_emb, obj_im_emb)
154 | 
155 |     def pick_ent(self, ents):
156 |         return self.params['e_re'].data[ents], self.params['e_im'].data[ents]
157 | 
158 |     def pick_rel(self, rels):
159 |         return self.params['r_re'].data[rels], self.params['r_im'].data[rels]
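
# A minimal reference sketch (illustrative names, not part of the upstream
# API) of the ComplEx score for a single triple, useful for sanity-checking
# cal_triplet_scores against the decomposed arithmetic above. `np` is in
# scope here via the star import of utils.math_utils.
def _complex_score_reference(s_re, s_im, r_re, r_im, o_re, o_im):
    # Re(<s, r, conj(o)>) expanded over the separately stored re/im parts.
    return np.sum(s_re * r_re * o_re
                  + s_im * r_re * o_im
                  + s_re * r_im * o_im
                  - s_im * r_im * o_re)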
160 | 
--------------------------------------------------------------------------------
/src/processors/evaluator.py:
--------------------------------------------------------------------------------
1 | 
2 | import numpy as np
3 | 
4 | 
5 | BATCHSIZE = 1000
6 | 
7 | 
8 | class Evaluator(object):
9 |     def __init__(self, metric, nbest=None, filtered=False, whole_graph=None):
10 |         assert metric in ['mrr', 'hits', 'all'], 'Invalid metric: {}'.format(metric)
11 |         if metric == 'hits':
12 |             assert nbest, 'Please indicate n-best when using hits'
13 |         if filtered:
14 |             assert whole_graph, 'When using the filtered metric, please indicate the whole graph'
15 |             self.all_graph = whole_graph
16 |         self.metric = metric
17 |         self.nbest = nbest
18 |         self.filtered = filtered
19 |         self.batchsize = BATCHSIZE
20 |         self.ress = []
21 |         self.id2sub_list = []
22 |         self.id2obj_list = []
23 |         self.sr2o = {}
24 |         self.ro2s = {}
25 | 
26 |     def run(self, model, dataset):
27 |         if self.metric == 'mrr':
28 |             res = self.cal_mrr(model, dataset)
29 |         elif self.metric == 'hits':
30 |             res = self.cal_hits(model, dataset, self.nbest)
31 |         else:
32 |             raise NotImplementedError
33 |         self.ress.append(res)
34 |         return res
35 | 
36 |     def run_all_matric(self, model, dataset):
37 |         """
38 |         calculate MRR and Hits@1,3,10 (raw and filtered)
39 |         """
40 |         n_sample = len(dataset)
41 |         sum_rr_raw = 0.
42 |         sum_rr_flt = 0.
43 |         n_corr_h1_raw = 0
44 |         n_corr_h1_flt = 0
45 |         n_corr_h3_raw = 0
46 |         n_corr_h3_flt = 0
47 |         n_corr_h10_raw = 0
48 |         n_corr_h10_flt = 0
49 |         start_id = 0
50 |         for samples in dataset.batch_iter(self.batchsize, rand_flg=False):
51 |             subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2]
52 |             ids = np.arange(start_id, start_id+len(samples))
53 | 
54 |             # TODO: partitioned calculation
55 |             # search objects
56 |             raw_scores = model.cal_scores(subs, rels)
57 |             raw_ranks = self.cal_rank(raw_scores, objs)
58 |             sum_rr_raw += sum(float(1/rank) for rank in raw_ranks)
59 |             n_corr_h1_raw += sum(1 for rank in raw_ranks if rank <= 1)
60 |             n_corr_h3_raw += sum(1 for rank in raw_ranks if rank <= 3)
61 |             n_corr_h10_raw += sum(1 for rank in raw_ranks if rank <= 10)
62 |             # filtered
63 |             if self.filtered:
64 |                 flt_scores = self.cal_filtered_score_fast(subs, rels, objs, ids, raw_scores)
65 |                 flt_ranks = self.cal_rank(flt_scores, objs)
66 |                 sum_rr_flt += sum(float(1/rank) for rank in flt_ranks)
67 |                 n_corr_h1_flt += sum(1 for rank in flt_ranks if rank <= 1)
68 |                 n_corr_h3_flt += sum(1 for rank in flt_ranks if rank <= 3)
69 |                 n_corr_h10_flt += sum(1 for rank in flt_ranks if rank <= 10)
70 | 
71 |             # search subjects
72 |             raw_scores_inv = model.cal_scores_inv(rels, objs)
73 |             raw_ranks_inv = self.cal_rank(raw_scores_inv, subs)
74 |             sum_rr_raw += sum(float(1/rank) for rank in raw_ranks_inv)
75 |             n_corr_h1_raw += sum(1 for rank in raw_ranks_inv if rank <= 1)
76 |             n_corr_h3_raw += sum(1 for rank in raw_ranks_inv if rank <= 3)
77 |             n_corr_h10_raw += sum(1 for rank in raw_ranks_inv if rank <= 10)
78 |             # filtered
79 |             if self.filtered:
80 |                 flt_scores_inv = self.cal_filtered_score_inv_fast(subs, rels, objs, ids, raw_scores_inv)
81 |                 flt_ranks_inv = self.cal_rank(flt_scores_inv, subs)
82 |                 sum_rr_flt += sum(float(1/rank) for rank in flt_ranks_inv)
83 |                 n_corr_h1_flt += sum(1 for rank in flt_ranks_inv if rank <= 1)
84 |                 n_corr_h3_flt += sum(1 for rank in flt_ranks_inv if rank <= 3)
85 |                 n_corr_h10_flt += sum(1 for rank in flt_ranks_inv if rank <= 10)
86 | 
87 |             start_id += len(samples)
88 | 
89 |         return {'MRR': sum_rr_raw/n_sample/2,
90 |                 'Hits@1': n_corr_h1_raw/n_sample/2,
91 |                 'Hits@3': n_corr_h3_raw/n_sample/2,
92 |                 'Hits@10': n_corr_h10_raw/n_sample/2,
93 |                 'MRR(filter)': sum_rr_flt/n_sample/2,
94 |                 'Hits@1(filter)': n_corr_h1_flt/n_sample/2,
95 |                 'Hits@3(filter)': n_corr_h3_flt/n_sample/2,
96 |                 'Hits@10(filter)': n_corr_h10_flt/n_sample/2}
97 | 
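    # How the ranking works (worked example): cal_rank counts, per query, how
    # many candidate scores are >= the gold entity's score, so ties rank
    # pessimistically. E.g. with scores [0.9, 0.2, 0.9] and gold index 0, the
    # rank is 2 because two scores satisfy >= 0.9. The filtered variants first
    # set the scores of all *other* known-true entities to -inf so that other
    # correct answers cannot push the gold entity down the ranking. Each triple
    # is queried in both directions, hence the division by 2*n_sample above.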
98 |     def cal_mrr(self, model, dataset):
99 |         n_sample = len(dataset)
100 |         sum_rr = 0.
101 |         start_id = 0
102 |         for samples in dataset.batch_iter(self.batchsize, rand_flg=False):
103 |             subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2]
104 |             ids = np.arange(start_id, start_id+len(samples))
105 |             scores = model.cal_scores(subs, rels)
106 |             if self.filtered:
107 |                 scores = self.cal_filtered_score_fast(subs, rels, objs, ids, scores)
108 |             ranks1 = self.cal_rank(scores, objs)
109 | 
110 |             scores = model.cal_scores_inv(rels, objs)
111 |             if self.filtered:
112 |                 scores = self.cal_filtered_score_inv_fast(subs, rels, objs, ids, scores)
113 |             ranks2 = self.cal_rank(scores, subs)
114 |             sum_rr += sum(float(1/rank) for rank in ranks1 + ranks2)
115 |             start_id += len(samples)
116 |         return float(sum_rr/n_sample/2)
117 | 
118 |     def cal_hits(self, model, dataset, nbest):
119 |         n_sample = len(dataset)
120 |         n_corr = 0
121 |         start_id = 0
122 |         for samples in dataset.batch_iter(self.batchsize, rand_flg=False):
123 |             subs, rels, objs = samples[:, 0], samples[:, 1], samples[:, 2]
124 |             ids = np.arange(start_id, start_id+len(samples))
125 |             scores = model.cal_scores(subs, rels)
126 |             if self.filtered:
127 |                 scores = self.cal_filtered_score_fast(subs, rels, objs, ids, scores)
128 |             res = np.flip(np.argsort(scores), 1)[:, :nbest]
129 |             n_corr += sum(1 for i in range(len(objs)) if objs[i] in res[i])
130 | 
131 |             scores = model.cal_scores_inv(rels, objs)
132 |             if self.filtered:
133 |                 scores = self.cal_filtered_score_inv_fast(subs, rels, objs, ids, scores)
134 |             res = np.flip(np.argsort(scores), 1)[:, :nbest]  # truncate to n-best, as in the forward direction
135 |             n_corr += sum(1 for i in range(len(subs)) if subs[i] in res[i])
136 |             start_id += len(samples)
137 |         return float(n_corr/n_sample/2)
138 | 
139 |     def cal_filtered_score_fast(self, subs, rels, objs, ids, raw_scores, metric='sim'):
140 |         assert metric in ['sim', 'dist']
141 |         new_scores = []
142 |         for s, r, o, i, score in zip(subs, rels, objs, ids, raw_scores):
143 |             true_os = self.id2obj_list[i]
144 |             true_os_rm_o = np.delete(true_os, np.where(true_os == o))
145 |             if metric == 'sim':
146 |                 score[true_os_rm_o] = -np.inf
147 |             else:
148 |                 score[true_os_rm_o] = np.inf
149 |             new_scores.append(score)
150 |         return new_scores
151 | 
152 |     def cal_filtered_score_inv_fast(self, subs, rels, objs, ids, raw_scores, metric='sim'):
153 |         assert metric in ['sim', 'dist']
154 |         new_scores = []
155 |         for s, r, o, i, score in zip(subs, rels, objs, ids, raw_scores):
156 |             true_ss = self.id2sub_list[i]
157 |             true_ss_rm_s = np.delete(true_ss, np.where(true_ss == s))
158 |             if metric == 'sim':
159 |                 score[true_ss_rm_s] = -np.inf
160 |             else:
161 |                 score[true_ss_rm_s] = np.inf
162 |             new_scores.append(score)
163 |         return new_scores
164 | 
165 |     def cal_rank(self, score_mat, ents):
166 |         return [np.sum(score >= score[e]) for score, e in zip(score_mat, ents)]
167 | 
168 |     def get_best_info(self):
169 |         if self.metric in ['mrr', 'hits', 'acc']:  # higher value is better
170 |             best_val = max(self.ress)
171 |         elif self.metric == 'mr':  # lower value is better
172 |             best_val = min(self.ress)
173 |         else:
174 |             raise ValueError('Invalid metric: {}'.format(self.metric))
175 |         best_epoch = self.ress.index(best_val) + 1
176 |         return best_epoch, best_val
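    # prepare_valid (below) caches, for every validation triple (s, r, o), the
    # complete set of true objects for (s, r) and true subjects for (r, o)
    # taken from the whole graph, keyed by the triple's dataset index; the
    # filtered scoring above can then mask known positives with a single
    # array assignment per query.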
177 | 
178 |     def prepare_valid(self, dataset):
179 |         for i in range(len(dataset)):
180 |             s, r, o = dataset[i]
181 |             os = self.all_graph.search_obj_id(s, r)
182 |             ss = self.all_graph.search_sub_id(r, o)
183 |             self.id2obj_list.append(os)
184 |             self.id2sub_list.append(ss)
185 |             self.sr2o[(s, r)] = os
186 |             self.ro2s[(r, o)] = ss
187 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship. For the purposes
44 |       of this License, Derivative Works shall not include works that remain
45 |       separable from, or merely link (or bind by name) to the interfaces of,
46 |       the Work and Derivative Works thereof.
47 | 
48 |       "Contribution" shall mean any work of authorship, including
49 |       the original version of the Work and any modifications or additions
50 |       to that Work or Derivative Works thereof, that is intentionally
51 |       submitted to Licensor for inclusion in the Work by the copyright owner
52 |       or by an individual or Legal Entity authorized to submit on behalf of
53 |       the copyright owner. For the purposes of this definition, "submitted"
54 |       means any form of electronic, verbal, or written communication sent
55 |       to the Licensor or its representatives, including but not limited to
56 |       communication on electronic mailing lists, source code control systems,
57 |       and issue tracking systems that are managed by, or on behalf of, the
58 |       Licensor for the purpose of discussing and improving the Work, but
59 |       excluding communication that is conspicuously marked or otherwise
60 |       designated in writing by the copyright owner as "Not a Contribution."
61 | 
62 |       "Contributor" shall mean Licensor and any individual or Legal Entity
63 |       on behalf of whom a Contribution has been received by Licensor and
64 |       subsequently incorporated within the Work.
65 | 
66 |    2. Grant of Copyright License. Subject to the terms and conditions of
67 |       this License, each Contributor hereby grants to You a perpetual,
68 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 |       copyright license to reproduce, prepare Derivative Works of,
70 |       publicly display, publicly perform, sublicense, and distribute the
71 |       Work and such Derivative Works in Source or Object form.
72 | 
73 |    3. Grant of Patent License. Subject to the terms and conditions of
74 |       this License, each Contributor hereby grants to You a perpetual,
75 |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 |       (except as stated in this section) patent license to make, have made,
77 |       use, offer to sell, sell, import, and otherwise transfer the Work,
78 |       where such license applies only to those patent claims licensable
79 |       by such Contributor that are necessarily infringed by their
80 |       Contribution(s) alone or by combination of their Contribution(s)
81 |       with the Work to which such Contribution(s) was submitted. If You
82 |       institute patent litigation against any entity (including a
83 |       cross-claim or counterclaim in a lawsuit) alleging that the Work
84 |       or a Contribution incorporated within the Work constitutes direct
85 |       or contributory patent infringement, then any patent licenses
86 |       granted to You under this License for that Work shall terminate
87 |       as of the date such litigation is filed.
88 | 
89 |    4. Redistribution. You may reproduce and distribute copies of the
90 |       Work or Derivative Works thereof in any medium, with or without
91 |       modifications, and in Source or Object form, provided that You
92 |       meet the following conditions:
93 | 
94 |       (a) You must give any other recipients of the Work or
95 |           Derivative Works a copy of this License; and
96 | 
97 |       (b) You must cause any modified files to carry prominent notices
98 |           stating that You changed the files; and
99 | 
100 |       (c) You must retain, in the Source form of any Derivative Works
101 |           that You distribute, all copyright, patent, trademark, and
102 |           attribution notices from the Source form of the Work,
103 |           excluding those notices that do not pertain to any part of
104 |           the Derivative Works; and
105 | 
106 |       (d) If the Work includes a "NOTICE" text file as part of its
107 |           distribution, then any Derivative Works that You distribute must
108 |           include a readable copy of the attribution notices contained
109 |           within such NOTICE file, excluding those notices that do not
110 |           pertain to any part of the Derivative Works, in at least one
111 |           of the following places: within a NOTICE text file distributed
112 |           as part of the Derivative Works; within the Source form or
113 |           documentation, if provided along with the Derivative Works; or,
114 |           within a display generated by the Derivative Works, if and
115 |           wherever such third-party notices normally appear. The contents
116 |           of the NOTICE file are for informational purposes only and
117 |           do not modify the License. You may add Your own attribution
118 |           notices within Derivative Works that You distribute, alongside
119 |           or as an addendum to the NOTICE text from the Work, provided
120 |           that such additional attribution notices cannot be construed
121 |           as modifying the License.
122 | 
123 |       You may add Your own copyright statement to Your modifications and
124 |       may provide additional or different license terms and conditions
125 |       for use, reproduction, or distribution of Your modifications, or
126 |       for any such Derivative Works as a whole, provided Your use,
127 |       reproduction, and distribution of the Work otherwise complies with
128 |       the conditions stated in this License.
129 | 
130 |    5. Submission of Contributions. Unless You explicitly state otherwise,
131 |       any Contribution intentionally submitted for inclusion in the Work
132 |       by You to the Licensor shall be under the terms and conditions of
133 |       this License, without any additional terms or conditions.
134 |       Notwithstanding the above, nothing herein shall supersede or modify
135 |       the terms of any separate license agreement you may have executed
136 |       with Licensor regarding such Contributions.
137 | 
138 |    6. Trademarks. This License does not grant permission to use the trade
139 |       names, trademarks, service marks, or product names of the Licensor,
140 |       except as required for reasonable and customary use in describing the
141 |       origin of the Work and reproducing the content of the NOTICE file.
142 | 
143 |    7. Disclaimer of Warranty. Unless required by applicable law or
144 |       agreed to in writing, Licensor provides the Work (and each
145 |       Contributor provides its Contributions) on an "AS IS" BASIS,
146 |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 |       implied, including, without limitation, any warranties or conditions
148 |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 |       PARTICULAR PURPOSE. You are solely responsible for determining the
150 |       appropriateness of using or redistributing the Work and assume any
151 |       risks associated with Your exercise of permissions under this License.
152 | 
153 |    8. Limitation of Liability. In no event and under no legal theory,
154 |       whether in tort (including negligence), contract, or otherwise,
155 |       unless required by applicable law (such as deliberate and grossly
156 |       negligent acts) or agreed to in writing, shall any Contributor be
157 |       liable to You for damages, including any direct, indirect, special,
158 |       incidental, or consequential damages of any character arising as a
159 |       result of this License or out of the use or inability to use the
160 |       Work (including but not limited to damages for loss of goodwill,
161 |       work stoppage, computer failure or malfunction, or any and all
162 |       other commercial damages or losses), even if such Contributor
163 |       has been advised of the possibility of such damages.
164 | 
165 |    9. Accepting Warranty or Additional Liability. While redistributing
166 |       the Work or Derivative Works thereof, You may choose to offer,
167 |       and charge a fee for, acceptance of support, warranty, indemnity,
168 |       or other liability obligations and/or rights consistent with this
169 |       License. However, in accepting such obligations, You may act only
170 |       on Your own behalf and on Your sole responsibility, not on behalf
171 |       of any other Contributor, and only if You agree to indemnify,
172 |       defend, and hold each Contributor harmless for any liability
173 |       incurred by, or claims asserted against, such Contributor by reason
174 |       of your accepting any such warranty or additional liability.
175 | 
176 |    END OF TERMS AND CONDITIONS
177 | 
178 |    APPENDIX: How to apply the Apache License to your work.
179 | 
180 |       To apply the Apache License to your work, attach the following
181 |       boilerplate notice, with the fields enclosed by brackets "[]"
182 |       replaced with your own identifying information. (Don't include
183 |       the brackets!) The text should be enclosed in the appropriate
184 |       comment syntax for the file format. We also recommend that a
185 |       file or class name and description of purpose be included on the
186 |       same "printed page" as the copyright notice for easier
187 |       identification within third-party archives.
188 | 
189 |    Copyright [yyyy] [name of copyright owner]
190 | 
191 |    Licensed under the Apache License, Version 2.0 (the "License");
192 |    you may not use this file except in compliance with the License.
193 |    You may obtain a copy of the License at
194 | 
195 |        http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 |    Unless required by applicable law or agreed to in writing, software
198 |    distributed under the License is distributed on an "AS IS" BASIS,
199 |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 |    See the License for the specific language governing permissions and
201 |    limitations under the License.
202 | 
--------------------------------------------------------------------------------