├── README.md
├── fb15k
│   ├── README
│   ├── freebase_mtr100_mte100-test.txt
│   ├── freebase_mtr100_mte100-train.txt
│   └── freebase_mtr100_mte100-valid.txt
├── main.py
├── model.py
└── prepare_data.py
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PyTorch Implementation of TransE

PyTorch version: 1.1.0

**Paper:**
- [Translating Embeddings for Modeling Multi-relational Data](https://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data)

**Dataset:**
- [FB15k](https://everest.hds.utc.fr/lib/exe/fetch.php?media=en:fb15k.tgz)

To evaluate, we do tail prediction on the test set; this TransE model reaches hits@10 of **34.5%**, which is close to the raw-setting result reported in the paper.
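
To reproduce, extract the FB15k archive to `./fb15k/` (the paths are hard-coded in `prepare_data.py`) and run `python main.py`; note that `main.py` hard-codes `torch.device('cuda')`, so a GPU is assumed.

For reference, hits@10 here means the fraction of test triples whose true tail is among the 10 nearest candidate entities under the learned distance. A minimal sketch of the metric (illustrative only; the function name `hits_at_k` is ours, the repo's own version is `tail_predict` in `model.py`):

```python
import torch

def hits_at_k(distances, true_tail, k=10):
    # distances: [batch, num_entities], ||h + r - t'|| for every candidate tail t'
    # true_tail: [batch] gold tail indices; returns the number of hits
    _, topk = torch.topk(distances, k, dim=1, largest=False)  # smaller = better
    return (topk == true_tail.view(-1, 1)).sum().item()
```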
--------------------------------------------------------------------------------
/fb15k/README:
--------------------------------------------------------------------------------
----------------------------------
-- Freebase FB15k data -- 2013 --
----------------------------------

------------------
OUTLINE:
1. Introduction
2. Content
3. Data Format
4. Data Statistics
5. How to Cite
6. License
7. Contact
-------------------


1. INTRODUCTION:

This FREEBASE FB15k DATA consists of a collection of triplets (head, relation_type,
tail) extracted from Freebase (http://www.freebase.com). This data set can
be seen as a 3-mode tensor depicting ternary relationships between entities.

2. CONTENT:

The data archive contains 4 files:
- README                            3K
- freebase_mtr100_mte100-train.txt  36M
- freebase_mtr100_mte100-valid.txt  3.7M
- freebase_mtr100_mte100-test.txt   4.4M

The 3 files freebase_mtr100_mte100-*.txt contain the triplets (training, validation
and test sets).

3. DATA FORMAT

All freebase_mtr100_mte100-*.txt files contain one triplet per line, with two mids
(unique Freebase entity identifiers) and a relation type identifier in tab-separated
format. The first element is the mid of the left-hand side (or head) of the relation
triple, the third one is the mid of the right-hand side (or tail), and the second
element is the name of the relationship between them.

4. DATA STATISTICS

There are 14,951 mids and 1,345 relation types among them. The training set contains
483,142 triplets, the validation set 50,000 and the test set 59,071.

All triplets are unique and we made sure that all entities appearing in
the validation or test sets also occur in the training set.

5. HOW TO CITE

When using this data, one should cite the original paper:
@incollection{bordes-nips13,
  title = {Translating Embeddings for Modeling Multi-relational Data},
  author = {Antoine Bordes and Nicolas Usunier and Alberto Garcia-Dur\'an and Jason Weston and Oksana Yakhnenko},
  booktitle = {Advances in Neural Information Processing Systems (NIPS 26)},
  year = {2013}
}

One should also point at the project page with either the long URL:
https://www.hds.utc.fr/everest/doku.php?id=en:transe , or the short
one: http://goo.gl/0PpKQe .

6. LICENSE:

FB15k data follows the Freebase license, i.e. Creative Commons Attribution (CC-BY)
(http://creativecommons.org/licenses/by/2.5/).

7. CONTACT

For all remarks or questions please contact Antoine Bordes: antoine
(dot) bordes (at) utc (dot) fr .
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch
from torch import optim
from torch.utils.data import DataLoader

from model import TransE
from prepare_data import TrainSet, TestSet

device = torch.device('cuda')
embed_dim = 50
num_epochs = 50
train_batch_size = 32
test_batch_size = 256
lr = 1e-2
momentum = 0
gamma = 1
d_norm = 2
top_k = 10


def main():
    train_dataset = TrainSet()
    test_dataset = TestSet()
    test_dataset.convert_word_to_index(train_dataset.entity_to_index, train_dataset.relation_to_index,
                                       test_dataset.raw_data)
    train_loader = DataLoader(train_dataset, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True)
    transe = TransE(train_dataset.entity_num, train_dataset.relation_num, device, dim=embed_dim, d_norm=d_norm,
                    gamma=gamma).to(device)
    optimizer = optim.SGD(transe.parameters(), lr=lr, momentum=momentum)
    for epoch in range(num_epochs):
        # e <- e / ||e||: renormalize entity embeddings once per epoch
        entity_norm = torch.norm(transe.entity_embedding.weight.data, dim=1, keepdim=True)
        transe.entity_embedding.weight.data = transe.entity_embedding.weight.data / entity_norm
        total_loss = 0
        for batch_idx, (pos, neg) in enumerate(train_loader):
            pos, neg = pos.to(device), neg.to(device)
            # pos: [batch_size, 3] => [3, batch_size]
            pos = torch.transpose(pos, 0, 1)
            # pos_head, pos_relation, pos_tail: [batch_size]
            pos_head, pos_relation, pos_tail = pos[0], pos[1], pos[2]
            neg = torch.transpose(neg, 0, 1)
            # neg_head, neg_relation, neg_tail: [batch_size]
            neg_head, neg_relation, neg_tail = neg[0], neg[1], neg[2]
            loss = transe(pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"epoch {epoch + 1}, loss = {total_loss / len(train_dataset)}")
        correct_test = 0
        for batch_idx, data in enumerate(test_loader):
            data = data.to(device)
            # data: [batch_size, 3] => [3, batch_size]
            data = torch.transpose(data, 0, 1)
            correct_test += transe.tail_predict(data[0], data[1], data[2], k=top_k)
        print(f"===> epoch {epoch + 1}, test hits@{top_k} = {correct_test / len(test_dataset)}")


if __name__ == '__main__':
    main()
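# A side note on the renormalization at the top of the epoch loop: dividing by
# the row-wise L2 norm is equivalent to torch.nn.functional.normalize
# (a sketch, not what runs above):
#
#     import torch.nn.functional as F
#     transe.entity_embedding.weight.data = F.normalize(
#         transe.entity_embedding.weight.data, p=2, dim=1)
#
# The paper's Algorithm 1 renormalizes entity embeddings before every
# minibatch rather than once per epoch; moving these lines inside the batch
# loop would match it exactly.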
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from prepare_data import TrainSet, TestSet


class TransE(nn.Module):
    def __init__(self, entity_num, relation_num, device, dim=50, d_norm=2, gamma=1):
        """
        :param entity_num: number of entities
        :param relation_num: number of relations
        :param device: torch device the margin tensor lives on
        :param dim: embedding dimension
        :param d_norm: p-norm used for d(h + l, t), either 1 (L1) or 2 (L2)
        :param gamma: margin hyperparameter
        """
        super(TransE, self).__init__()
        self.dim = dim
        self.d_norm = d_norm
        self.device = device
        self.gamma = torch.FloatTensor([gamma]).to(self.device)
        self.entity_num = entity_num
        self.relation_num = relation_num
        # initialize uniformly in [-6/sqrt(dim), 6/sqrt(dim)], as in the paper
        self.entity_embedding = nn.Embedding.from_pretrained(
            torch.empty(entity_num, self.dim).uniform_(-6 / math.sqrt(self.dim), 6 / math.sqrt(self.dim)),
            freeze=False)
        self.relation_embedding = nn.Embedding.from_pretrained(
            torch.empty(relation_num, self.dim).uniform_(-6 / math.sqrt(self.dim), 6 / math.sqrt(self.dim)),
            freeze=False)
        # l <- l / ||l||: relation embeddings are normalized once, at init
        relation_norm = torch.norm(self.relation_embedding.weight.data, dim=1, keepdim=True)
        self.relation_embedding.weight.data = self.relation_embedding.weight.data / relation_norm

    def forward(self, pos_head, pos_relation, pos_tail, neg_head, neg_relation, neg_tail):
        """
        :param pos_head: [batch_size]
        :param pos_relation: [batch_size]
        :param pos_tail: [batch_size]
        :param neg_head: [batch_size]
        :param neg_relation: [batch_size]
        :param neg_tail: [batch_size]
        :return: margin loss summed over the batch
        """
        pos_dis = self.entity_embedding(pos_head) + self.relation_embedding(pos_relation) - self.entity_embedding(
            pos_tail)
        neg_dis = self.entity_embedding(neg_head) + self.relation_embedding(neg_relation) - self.entity_embedding(
            neg_tail)
        # the loss already tracks gradients through the embeddings
        return self.calculate_loss(pos_dis, neg_dis)

    def calculate_loss(self, pos_dis, neg_dis):
        """
        :param pos_dis: [batch_size, embed_dim]
        :param neg_dis: [batch_size, embed_dim]
        :return: margin loss summed over the batch
        """
        distance_diff = self.gamma + torch.norm(pos_dis, p=self.d_norm, dim=1) - torch.norm(neg_dis, p=self.d_norm,
                                                                                            dim=1)
        return torch.sum(F.relu(distance_diff))

    def tail_predict(self, head, relation, tail, k=10):
        """
        tail prediction for hits@k
        :param head: [batch_size]
        :param relation: [batch_size]
        :param tail: [batch_size]
        :param k: hits@k
        :return: number of triples whose true tail is among the k nearest candidates
        """
        # h_and_r: [batch_size, embed_dim] => [batch_size, 1, embed_dim] => [batch_size, N, embed_dim]
        h_and_r = self.entity_embedding(head) + self.relation_embedding(relation)
        h_and_r = torch.unsqueeze(h_and_r, dim=1)
        h_and_r = h_and_r.expand(h_and_r.shape[0], self.entity_num, self.dim)
        # embed_tail: [batch_size, N, embed_dim]
        embed_tail = self.entity_embedding.weight.data.expand(h_and_r.shape[0], self.entity_num, self.dim)
        # indices: [batch_size, k], the k entities closest to h + r
        values, indices = torch.topk(torch.norm(h_and_r - embed_tail, dim=2), k, dim=1, largest=False)
        # tail: [batch_size] => [batch_size, 1]
        tail = tail.view(-1, 1)
        return torch.sum(torch.eq(indices, tail)).item()


if __name__ == '__main__':
    train_data_set = TrainSet()
    test_data_set = TestSet()
    test_data_set.convert_word_to_index(train_data_set.entity_to_index, train_data_set.relation_to_index,
                                        test_data_set.raw_data)
    train_loader = DataLoader(train_data_set, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data_set, batch_size=32, shuffle=True)
    for batch_idx, data in enumerate(test_loader):
        print(data.shape)
        break
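    # Sanity check (a sketch): calculate_loss is the standard margin ranking
    # loss. nn.MarginRankingLoss computes max(0, -y * (x1 - x2) + margin), so
    # with x1 = positive distance, x2 = negative distance and y = -1 it gives
    # max(0, pos - neg + margin), the same quantity summed above.
    criterion = nn.MarginRankingLoss(margin=1.0, reduction='sum')
    pos_score, neg_score = torch.rand(4), torch.rand(4)
    assert torch.isclose(criterion(pos_score, neg_score, -torch.ones(4)),
                         torch.sum(F.relu(1.0 + pos_score - neg_score)))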
--------------------------------------------------------------------------------
/prepare_data.py:
--------------------------------------------------------------------------------
import random
from collections import Counter

import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader


class TrainSet(Dataset):
    def __init__(self):
        super(TrainSet, self).__init__()
        self.raw_data, self.entity_to_index, self.relation_to_index = self.load_text()
        self.entity_num, self.relation_num = len(self.entity_to_index), len(self.relation_to_index)
        self.triple_num = self.raw_data.shape[0]
        print(f'Train set: {self.entity_num} entities, {self.relation_num} relations, {self.triple_num} triplets.')
        self.pos_data = self.convert_word_to_index(self.raw_data)
        self.related_dic = self.get_related_entity()
        self.neg_data = self.generate_neg()

    def __len__(self):
        return self.triple_num

    def __getitem__(self, item):
        return [self.pos_data[item], self.neg_data[item]]

    def load_text(self):
        raw_data = pd.read_csv('./fb15k/freebase_mtr100_mte100-train.txt', sep='\t', header=None,
                               names=['head', 'relation', 'tail'],
                               keep_default_na=False, encoding='utf-8')
        raw_data = raw_data.applymap(lambda x: x.strip())
        head_count = Counter(raw_data['head'])
        tail_count = Counter(raw_data['tail'])
        relation_count = Counter(raw_data['relation'])
        entity_list = list((head_count + tail_count).keys())
        relation_list = list(relation_count.keys())
        entity_dic = {word: idx for idx, word in enumerate(entity_list)}
        relation_dic = {word: idx for idx, word in enumerate(relation_list)}
        return raw_data.values, entity_dic, relation_dic

    def convert_word_to_index(self, data):
        index_list = np.array([
            [self.entity_to_index[triple[0]], self.relation_to_index[triple[1]], self.entity_to_index[triple[2]]]
            for triple in data])
        return index_list

    def generate_neg(self):
        """
        negative sampling: corrupt either the head or the tail of each positive
        triple with a random entity
        :return: same shape as the positive samples
        """
        neg_candidates, i = [], 0
        neg_data = []
        population = list(range(self.entity_num))
        for idx, triple in enumerate(self.pos_data):
            while True:
                # draw candidate entities in blocks of 10,000 for speed
                if i == len(neg_candidates):
                    i = 0
                    neg_candidates = random.choices(population=population, k=int(1e4))
                neg, i = neg_candidates[i], i + 1
                if random.randint(0, 1) == 0:
                    # replace head; skip entities already related to the tail
                    if neg not in self.related_dic[triple[2]]:
                        neg_data.append([neg, triple[1], triple[2]])
                        break
                else:
                    # replace tail; skip entities already related to the head
                    if neg not in self.related_dic[triple[0]]:
                        neg_data.append([triple[0], triple[1], neg])
                        break
        return np.array(neg_data)

    def get_related_entity(self):
        """
        map each entity to the set of entities it co-occurs with in a triple
        :return: {entity_id: {related_entity_id_1, related_entity_id_2, ...}}
        """
        related_dic = dict()
        for triple in self.pos_data:
            related_dic.setdefault(triple[0], set()).add(triple[2])
            related_dic.setdefault(triple[2], set()).add(triple[0])
        return related_dic
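
# Note: neg_data is drawn once in __init__, so every epoch trains against the
# same corrupted triples, whereas the paper samples fresh negatives for each
# minibatch. A minimal way to approximate that with this class (a sketch; not
# called anywhere in this repo) is to resample at the top of each epoch:
#
#     train_dataset.neg_data = train_dataset.generate_neg()
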
class TestSet(Dataset):
    def __init__(self):
        super(TestSet, self).__init__()
        self.raw_data = self.load_text()
        # self.data holds raw strings until convert_word_to_index() is called
        # with the index maps built by TrainSet
        self.data = self.raw_data
        print(f"Test set: {self.raw_data.shape[0]} triplets.")

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return self.data.shape[0]

    def load_text(self):
        raw_data = pd.read_csv('./fb15k/freebase_mtr100_mte100-test.txt', sep='\t', header=None,
                               names=['head', 'relation', 'tail'],
                               keep_default_na=False, encoding='utf-8')
        raw_data = raw_data.applymap(lambda x: x.strip())
        return raw_data.values

    def convert_word_to_index(self, entity_to_index, relation_to_index, data):
        index_list = np.array(
            [[entity_to_index[triple[0]], relation_to_index[triple[1]], entity_to_index[triple[2]]]
             for triple in data])
        self.data = index_list


if __name__ == '__main__':
    train_data_set = TrainSet()
    test_data_set = TestSet()
    test_data_set.convert_word_to_index(train_data_set.entity_to_index, train_data_set.relation_to_index,
                                        test_data_set.raw_data)
    train_loader = DataLoader(train_data_set, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_data_set, batch_size=32, shuffle=True)
    for batch_idx, data in enumerate(train_loader):
        break
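    # Shape check (a sketch): with batch_size=32, each train batch collates to
    # [pos, neg], two LongTensors of shape [32, 3] holding (head, relation,
    # tail) index triples; each test batch is a single [32, 3] tensor.
    pos, neg = data
    print(pos.shape, neg.shape)  # torch.Size([32, 3]) torch.Size([32, 3])
--------------------------------------------------------------------------------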