├── README.md ├── loss.py ├── utils.py ├── data.py ├── train.py └── model.py

/README.md:
--------------------------------------------------------------------------------
# RAGA
[RAGA: Relation-aware Graph Attention Networks for Global Entity Alignment (PAKDD 2021)](https://arxiv.org/abs/2103.00791)

## Datasets
Please download the datasets [here](https://drive.google.com/file/d/1uJ2omzIs0NCtJsGQsyFCBHCXUhoK1mkO/view?usp=sharing) and extract them into root directory.

## Environment

```
apex
pytorch
torch_geometric
```

## Running

For local alignment, use:
```
CUDA_VISIBLE_DEVICES=0 python train.py --data data/DBP15K --lang zh_en
```

For local and global alignment, use:
```
CUDA_VISIBLE_DEVICES=0 python train.py --data data/DBP15K --lang zh_en --stable_test
```
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class L1_Loss(nn.Module):
    """Margin-based (hinge) alignment loss over L1 distances.

    For every aligned training pair the positive distance is pushed below
    the distance to each sampled negative by at least a margin of ``gamma``.
    Four negative directions are averaged: KG1 entities against negatives
    drawn from KG1 and from KG2, and symmetrically for KG2 entities.
    """

    def __init__(self, gamma=3):
        # gamma: margin separating positive pair distance from negative distances.
        super(L1_Loss, self).__init__()
        self.gamma = gamma

    def dis(self, x, y):
        """L1 (Manhattan) distance along the last dimension; broadcasts over leading dims."""
        return torch.sum(torch.abs(x-y), dim=-1)

    def forward(self, x1, x2, train_set, train_batch):
        """Compute the averaged hinge loss.

        x1, x2: entity embedding matrices of KG1 / KG2.
        train_set: (N, 2) aligned index pairs (column 0 indexes x1, column 1 indexes x2).
        train_batch: (4, k, N) negative-sample indices produced by utils.get_train_batch,
            in the order [e1-in-KG1, e1-in-KG2, e2-in-KG2, e2-in-KG1].
        """
        x1_train, x2_train = x1[train_set[:, 0]], x2[train_set[:, 1]]
        # Each negative tensor is reshaped to (k, N, dim) so it broadcasts
        # against the (N, dim) positives inside self.dis.
        x1_neg1 = x1[train_batch[0].view(-1)].reshape(-1, train_set.size(0), x1.size(1))
        x1_neg2 = x2[train_batch[1].view(-1)].reshape(-1, train_set.size(0), x2.size(1))
        x2_neg1 = x2[train_batch[2].view(-1)].reshape(-1, train_set.size(0), x2.size(1))
        x2_neg2 = x1[train_batch[3].view(-1)].reshape(-1, train_set.size(0), x1.size(1))

        # Positive distance for each aligned pair, shared by all four hinge terms.
        dis_x1_x2 = self.dis(x1_train, x2_train)
        loss11 = torch.mean(F.relu(self.gamma+dis_x1_x2-self.dis(x1_train, x1_neg1)))
        loss12 = torch.mean(F.relu(self.gamma+dis_x1_x2-self.dis(x1_train, x1_neg2)))
        loss21 = torch.mean(F.relu(self.gamma+dis_x1_x2-self.dis(x2_train, x2_neg1)))
        loss22 = torch.mean(F.relu(self.gamma+dis_x1_x2-self.dis(x2_train, x2_neg2)))
        loss = (loss11+loss12+loss21+loss22)/4
        return loss
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops


def add_inverse_rels(edge_index, rel):
    """Append the reversed edge for every triple.

    Inverse relations receive fresh ids shifted by rel.max()+1, so the
    returned rel_all indexes 2 * (#relations) distinct relation types.
    """
    edge_index_all = torch.cat([edge_index, edge_index[[1,0]]], dim=1)
    rel_all = torch.cat([rel, rel+rel.max()+1])
    return edge_index_all, rel_all


def get_train_batch(x1, x2, train_set, k=5):
    """Hard-negative sampling: the k nearest entities by L1 distance.

    topk(k+1) is transposed and sliced with [1:] to drop the single nearest
    hit (distance-0 match, i.e. the entity itself when searching its own KG).
    Returns a (4, k, N) index tensor consumed by L1_Loss.forward.
    """
    e1_neg1 = torch.cdist(x1[train_set[:, 0]], x1, p=1).topk(k+1, largest=False)[1].t()[1:]
    e1_neg2 = torch.cdist(x1[train_set[:, 0]], x2, p=1).topk(k+1, largest=False)[1].t()[1:]
    e2_neg1 = torch.cdist(x2[train_set[:, 1]], x2, p=1).topk(k+1, largest=False)[1].t()[1:]
    e2_neg2 = torch.cdist(x2[train_set[:, 1]], x1, p=1).topk(k+1, largest=False)[1].t()[1:]
    train_batch = torch.stack([e1_neg1, e1_neg2, e2_neg1, e2_neg2], dim=0)
    return train_batch


def get_hits(x1, x2, pair, dist='L1', Hn_nums=(1, 10)):
    """Print Hits@k and MRR for local (greedy, per-entity) alignment.

    S is the (N, N) pairwise L1 distance matrix over the given gold pairs;
    by construction the true counterpart of row i is column i, so a
    prediction is a hit when index i appears in the row's top-k.
    Both directions (KG1->KG2 "Left" and KG2->KG1 "Right") are reported.
    NOTE(review): the `dist` parameter is unused — cdist is hard-wired to p=1.
    """
    pair_num = pair.size(0)
    S = torch.cdist(x1[pair[:, 0]], x2[pair[:, 1]], p=1)
    print('Left:\t',end='')
    for k in Hn_nums:
        pred_topk= S.topk(k, largest=False)[1]
        # Count rows whose gold column index i appears among the k smallest distances.
        Hk = (pred_topk == torch.arange(pair_num, device=S.device).view(-1, 1)).sum().item()/pair_num
        print('Hits@%d: %.2f%% ' % (k, Hk*100),end='')
    # Rank of the gold counterpart in each sorted row -> mean reciprocal rank.
    rank = torch.where(S.sort()[1] == torch.arange(pair_num, device=S.device).view(-1, 1))[1].float()
    MRR = (1/(rank+1)).mean().item()
    print('MRR: %.3f' % MRR)
    print('Right:\t',end='')
    for k in Hn_nums:
        pred_topk= S.t().topk(k, largest=False)[1]
        Hk = (pred_topk == torch.arange(pair_num, device=S.device).view(-1, 1)).sum().item()/pair_num
        print('Hits@%d: %.2f%% ' % (k, Hk*100),end='')
    rank = torch.where(S.t().sort()[1] == torch.arange(pair_num, device=S.device).view(-1, 1))[1].float()
    MRR = (1/(rank+1)).mean().item()
    print('MRR: %.3f' % MRR)


def get_hits_stable(x1, x2, pair):
    """Global (one-to-one) alignment accuracy via greedy matching.

    Candidate pairs are scored by the sum of row-wise and column-wise
    softmax over the similarity matrix (S is negated distance, so larger
    is better), then swept in descending score; a pair is accepted only
    while both of its entities are still unmatched. An accepted pair is
    correct when it lies on the diagonal (row index == column index).
    """
    pair_num = pair.size(0)
    S = -torch.cdist(x1[pair[:, 0]], x2[pair[:, 1]], p=1).cpu()
    #index = S.flatten().argsort(descending=True)
    index = (S.softmax(1)+S.softmax(0)).flatten().argsort(descending=True)
    # Flat index -> (row, column) of the score matrix.
    index_e1 = index//pair_num
    index_e2 = index%pair_num
    aligned_e1 = torch.zeros(pair_num, dtype=torch.bool)
    aligned_e2 = torch.zeros(pair_num, dtype=torch.bool)
    true_aligned = 0
    # NOTE(review): only the top pair_num*100 candidates are scanned; this
    # presumes a complete matching is found within them — confirm for very
    # flat score matrices, otherwise some entities stay unmatched.
    for _ in range(pair_num*100):
        if aligned_e1[index_e1[_]] or aligned_e2[index_e2[_]]:
            continue
        if index_e1[_] == index_e2[_]:
            true_aligned += 1
        aligned_e1[index_e1[_]] = True
        aligned_e2[index_e2[_]] = True
    print('Both:\tHits@Stable: %.2f%% ' % (true_aligned/pair_num*100))
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import os
import json
import torch
from torch_geometric.io import read_txt_array
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import sort_edge_index
from torch_geometric.utils import sort_edge_index


class DBP15K(InMemoryDataset):
    """InMemoryDataset wrapper around one DBP15K language pair.

    pair: language-pair folder name, e.g. 'zh_en'.
    KG_num: 1 keeps the two KGs separate (x1/x2, edge_index1/2, rel1/2);
        any other value merges them into a single graph with offset ids.
    rate: fraction of the reference alignment used as the training split.
    seed: seeds torch's RNG so the random train/test split is reproducible
        (and is baked into the processed-file name below).
    """

    def __init__(self, root, pair, KG_num=1, rate=0.3, seed=1):
        self.pair = pair
        self.KG_num = KG_num
        self.rate = rate
        self.seed = seed
        # Seed before super().__init__ triggers process(), so the
        # torch.randperm split there is deterministic.
        torch.manual_seed(seed)
        super(DBP15K, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['zh_en', 'fr_en', 'ja_en']

    @property
    def processed_file_names(self):
        # Cache file name encodes every parameter that changes the split.
        return '%s_%d_%.1f_%d.pt' % (self.pair, self.KG_num, self.rate, self.seed)

    def process(self):
        """Read both KGs, remap ids to dense ranges, split the alignment, and cache."""
        x1_path = os.path.join(self.root, self.pair, 'ent_ids_1')
        x2_path = os.path.join(self.root, self.pair, 'ent_ids_2')
        g1_path = os.path.join(self.root, self.pair, 'triples_1')
        g2_path = os.path.join(self.root, self.pair, 'triples_2')
        emb_path = os.path.join(self.root, self.pair, self.pair[:2]+'_vectorList.json')
        x1, edge_index1, rel1, assoc1 = self.process_graph(g1_path, x1_path, emb_path)
        x2, edge_index2, rel2, assoc2 = self.process_graph(g2_path, x2_path, emb_path)

        pair_path = os.path.join(self.root, self.pair, 'ref_ent_ids')
        pair_set = self.process_pair(pair_path, assoc1, assoc2)
        # Shuffle, then take the first `rate` fraction as the training split.
        pair_set = pair_set[:, torch.randperm(pair_set.size(1))]
        train_set = pair_set[:, :int(self.rate*pair_set.size(1))]
        test_set = pair_set[:, int(self.rate*pair_set.size(1)):]

        if self.KG_num == 1:
            data = Data(x1=x1, edge_index1=edge_index1, rel1=rel1,
                        x2=x2, edge_index2=edge_index2, rel2=rel2,
                        train_set=train_set.t(), test_set=test_set.t())
        else:
            # Merged-graph variant: KG2 entity/relation ids are offset past KG1's.
            x = torch.cat([x1, x2], dim=0)
            edge_index = torch.cat([edge_index1, edge_index2+x1.size(0)], dim=1)
            rel = torch.cat([rel1, rel2+rel1.max()+1], dim=0)
            data = Data(x=x, edge_index=edge_index, rel=rel,train_set=train_set.t(), test_set=test_set.t())
        torch.save(self.collate([data]), self.processed_paths[0])

    def process_graph(self, triple_path, ent_path, emb_path):
        """Load one KG: triples, entity id list, and pretrained embeddings.

        Returns (x, edge_index, rel, assoc) where `assoc` maps raw entity
        ids to dense local ids (used later to remap the alignment file).
        """
        g = read_txt_array(triple_path, sep='\t', dtype=torch.long)
        subj, rel, obj = g.t()

        # Remap relation ids to a dense 0..R-1 range.
        assoc = torch.full((rel.max().item()+1,), -1, dtype=torch.long)
        assoc[rel.unique()] = torch.arange(rel.unique().size(0))
        rel = assoc[rel]

        idx = []
        with open(ent_path, 'r') as f:
            for line in f:
                info = line.strip().split('\t')
                idx.append(int(info[0]))
        idx = torch.tensor(idx)
        # Pretrained entity vectors are indexed by the raw ids from ent_ids_*.
        with open(emb_path, 'r', encoding='utf-8') as f:
            embedding_list = torch.tensor(json.load(f))
        x = embedding_list[idx]

        # Remap entity ids to a dense 0..E-1 range and rebuild the edge index.
        assoc = torch.full((idx.max().item()+1, ), -1, dtype=torch.long)
        assoc[idx] = torch.arange(idx.size(0))
        subj, obj = assoc[subj], assoc[obj]
        edge_index = torch.stack([subj, obj], dim=0)
        edge_index, rel = sort_edge_index(edge_index, rel)
        return x, edge_index, rel, assoc

    def process_pair(self, path, assoc1, assoc2):
        """Read the gold alignment file and remap raw ids to dense local ids."""
        e1, e2 = read_txt_array(path, sep='\t', dtype=torch.long).t()
        return torch.stack([assoc1[e1], assoc2[e2]], dim=0)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import argparse
import itertools

import apex
import torch
import torch.nn as nn
import torch.nn.functional as F

from model import RAGA
from data import DBP15K
from loss import L1_Loss
from utils import add_inverse_rels, get_train_batch, get_hits, get_hits_stable


def parse_args():
    """Command-line options for RAGA training on DBP15K."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda", action="store_true", default=True)
    parser.add_argument("--data", default="data/DBP15K")
    parser.add_argument("--lang", default="zh_en")
    parser.add_argument("--rate", type=float, default=0.3)

    parser.add_argument("--r_hidden", type=int, default=100)

    parser.add_argument("--k", type=int, default=5)
    parser.add_argument("--gamma", type=float, default=3)

    parser.add_argument("--epoch", type=int, default=80)
    parser.add_argument("--neg_epoch", type=int, default=10)
    parser.add_argument("--test_epoch", type=int, default=5)
    parser.add_argument("--stable_test", action="store_true", default=False)
    args = parser.parse_args()
    return args


def init_data(args, device):
    """Load the dataset and prepare trainable, normalized entity embeddings.

    The pretrained embeddings are L2-normalized and marked requires_grad:
    they are fine-tuned alongside the model (see the optimizer in main()).
    Inverse relations are added once here for the GCN/GAT layers.
    """
    data = DBP15K(args.data, args.lang, rate=args.rate)[0]
    data.x1 = F.normalize(data.x1, dim=1, p=2).to(device).requires_grad_()
    data.x2 = F.normalize(data.x2, dim=1, p=2).to(device).requires_grad_()
    data.edge_index_all1, data.rel_all1 = add_inverse_rels(data.edge_index1, data.rel1)
    data.edge_index_all2, data.rel_all2 = add_inverse_rels(data.edge_index2, data.rel2)
    return data


def get_emb(model, data):
    """Forward both KGs in eval mode without gradients; returns (x1, x2)."""
    model.eval()
    with torch.no_grad():
        x1 = model(data.x1, data.edge_index1, data.rel1, data.edge_index_all1, data.rel_all1)
        x2 = model(data.x2, data.edge_index2, data.rel2, data.edge_index_all2, data.rel_all2)
    return x1, x2


def train(model, criterion, optimizer, data, train_batch):
    """One optimization step with apex AMP loss scaling; returns the loss tensor."""
    model.train()
    x1 = model(data.x1, data.edge_index1, data.rel1, data.edge_index_all1, data.rel_all1)
    x2 = model(data.x2, data.edge_index2, data.rel2, data.edge_index_all2, data.rel_all2)
    loss = criterion(x1, x2, data.train_set, train_batch)
    optimizer.zero_grad()
    # apex.amp scales the loss to keep fp16 gradients in range.
    with apex.amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    return loss


def test(model, data, stable=False):
    """Report Hits@k/MRR on train and test splits; optionally the global matching score."""
    x1, x2 = get_emb(model, data)
    print('-'*16+'Train_set'+'-'*16)
    get_hits(x1, x2, data.train_set)
    print('-'*16+'Test_set'+'-'*17)
    get_hits(x1, x2, data.test_set)
    if stable:
        get_hits_stable(x1, x2, data.test_set)
    print()
    return x1, x2


def main(args):
    """Training loop: refresh hard negatives every neg_epoch epochs, test every test_epoch."""
    device = 'cuda' if args.cuda and torch.cuda.is_available() else 'cpu'
    data = init_data(args, device).to(device)
    model = RAGA(data.x1.size(1), args.r_hidden).to(device)
    # The input embeddings are leaf tensors with requires_grad, so they are
    # optimized together with the model parameters.
    optimizer = torch.optim.Adam(itertools.chain(model.parameters(), iter([data.x1, data.x2])))
    model, optimizer = apex.amp.initialize(model, optimizer)
    criterion = L1_Loss(args.gamma)
    for epoch in range(args.epoch):
        if epoch%args.neg_epoch == 0:
            # Re-sample hard negatives from the current embeddings.
            x1, x2 = get_emb(model, data)
            train_batch = get_train_batch(x1, x2, data.train_set, args.k)
        loss = train(model, criterion, optimizer, data, train_batch)
        print('Epoch:', epoch+1, '/', args.epoch, '\tLoss: %.3f'%loss, '\r', end='')
        if (epoch+1)%args.test_epoch == 0:
            print()
            test(model, data, args.stable_test)
    #x1, x2 = get_emb(model, data)
    #torch.save([x1[data.test_set[:, 0]].cpu(), x2[data.test_set[:, 1]].cpu()], 'x.pt')


if __name__ == '__main__':
    args = parse_args()
    main(args)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import spmm
from torch_geometric.utils import softmax, degree


class GCN(nn.Module):
    """Parameter-free graph convolution: symmetric-normalized neighbor sum + ReLU.

    NOTE(review): `hidden` is unused and the layer has no learnable weight —
    it only aggregates; feature transformation happens elsewhere (Highway/GAT).
    """

    def __init__(self, hidden):
        super(GCN, self).__init__()

    def forward(self, x, edge_index):
        # edge_index row 0 = source j, row 1 = target i.
        edge_index_j, edge_index_i = edge_index
        deg = degree(edge_index_i, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # Symmetric normalization 1/sqrt(deg_j * deg_i) per edge.
        norm = deg_inv_sqrt[edge_index_j]*deg_inv_sqrt[edge_index_i]
        # spmm expects (row=target, col=source), hence the [[1, 0]] swap.
        x = F.relu(spmm(edge_index[[1, 0]], norm, x.size(0), x.size(0), x))
        return x


class Highway(nn.Module):
    """Highway gate: sigmoid gate (from x1) interpolates between x2 and x1."""

    def __init__(self, x_hidden):
        super(Highway, self).__init__()
        self.lin = nn.Linear(x_hidden, x_hidden)

    def forward(self, x1, x2):
        gate = torch.sigmoid(self.lin(x1))
        x = torch.mul(gate, x2)+torch.mul(1-gate, x1)
        return x


class GAT_E_to_R(nn.Module):
    """Build relation embeddings by attending over each triple's head/tail entities.

    Entity features are projected to r_hidden via w_h (head view) and w_t
    (tail view); per-triple attention logits are softmax-normalized within
    each relation id, and spmm scatters the weighted entity projections
    into rel.max()+1 relation slots. The head and tail views are summed.
    """

    def __init__(self, e_hidden, r_hidden):
        super(GAT_E_to_R, self).__init__()
        self.a_h1 = nn.Linear(r_hidden, 1, bias=False)
        self.a_h2 = nn.Linear(r_hidden, 1, bias=False)
        self.a_t1 = nn.Linear(r_hidden, 1, bias=False)
        self.a_t2 = nn.Linear(r_hidden, 1, bias=False)
        self.w_h = nn.Linear(e_hidden, r_hidden, bias=False)
        self.w_t = nn.Linear(e_hidden, r_hidden, bias=False)

    def forward(self, x_e, edge_index, rel):
        edge_index_h, edge_index_t = edge_index
        x_r_h = self.w_h(x_e)
        x_r_t = self.w_t(x_e)

        # Per-triple attention logits combining head and tail projections.
        e1 = self.a_h1(x_r_h).squeeze()[edge_index_h]+self.a_h2(x_r_t).squeeze()[edge_index_t]
        e2 = self.a_t1(x_r_h).squeeze()[edge_index_h]+self.a_t2(x_r_t).squeeze()[edge_index_t]

        # Softmax grouped by relation id: triples of the same relation compete.
        alpha = softmax(F.leaky_relu(e1).float(), rel)
        x_r_h = spmm(torch.cat([rel.view(1, -1), edge_index_h.view(1, -1)], dim=0), alpha, rel.max()+1, x_e.size(0), x_r_h)

        alpha = softmax(F.leaky_relu(e2).float(), rel)
        x_r_t = spmm(torch.cat([rel.view(1, -1), edge_index_t.view(1, -1)], dim=0), alpha, rel.max()+1, x_e.size(0), x_r_t)
        x_r = x_r_h+x_r_t
        return x_r


class GAT_R_to_E(nn.Module):
    """Aggregate relation embeddings back into entities.

    For every triple, an attention weight over the adjacent relation is
    computed per head entity and per tail entity (softmax grouped by the
    entity index); the two r_hidden-dim aggregates are concatenated, so
    the output is 2*r_hidden per entity.
    """

    def __init__(self, e_hidden, r_hidden):
        super(GAT_R_to_E, self).__init__()
        self.a_h = nn.Linear(e_hidden, 1, bias=False)
        self.a_t = nn.Linear(e_hidden, 1, bias=False)
        self.a_r = nn.Linear(r_hidden, 1, bias=False)

    def forward(self, x_e, x_r, edge_index, rel):
        edge_index_h, edge_index_t = edge_index
        e_h = self.a_h(x_e).squeeze()[edge_index_h]
        e_t = self.a_t(x_e).squeeze()[edge_index_t]
        e_r = self.a_r(x_r).squeeze()[rel]
        # Head-side aggregation: softmax over the triples sharing a head entity.
        alpha = softmax(F.leaky_relu(e_h+e_r).float(), edge_index_h)
        x_e_h = spmm(torch.cat([edge_index_h.view(1, -1), rel.view(1, -1)], dim=0), alpha, x_e.size(0), x_r.size(0), x_r)
        # Tail-side aggregation, symmetric to the above.
        alpha = softmax(F.leaky_relu(e_t+e_r).float(), edge_index_t)
        x_e_t = spmm(torch.cat([edge_index_t.view(1, -1), rel.view(1, -1)], dim=0), alpha, x_e.size(0), x_r.size(0), x_r)
        x = torch.cat([x_e_h, x_e_t], dim=1)
        return x


class GAT(nn.Module):
    """Single-head graph attention over the entity graph (no feature transform)."""

    def __init__(self, hidden):
        super(GAT, self).__init__()
        self.a_i = nn.Linear(hidden, 1, bias=False)
        self.a_j = nn.Linear(hidden, 1, bias=False)
        self.a_r = nn.Linear(hidden, 1, bias=False)

    def forward(self, x, edge_index):
        edge_index_j, edge_index_i = edge_index
        e_i = self.a_i(x).squeeze()[edge_index_i]
        e_j = self.a_j(x).squeeze()[edge_index_j]
        e = e_i+e_j
        # Softmax over the incoming edges of each target entity i.
        alpha = softmax(F.leaky_relu(e).float(), edge_index_i)
        x = F.relu(spmm(edge_index[[1, 0]], alpha, x.size(0), x.size(0), x))
        return x


class RAGA(nn.Module):
    """Relation-aware Graph Attention network for entity alignment.

    Pipeline: two parameter-free GCN hops with Highway gates, then
    entity->relation attention (GAT_E_to_R), relation->entity attention
    (GAT_R_to_E, concatenated onto the entity features), and a final GAT
    whose output is concatenated again — so the returned embedding has
    dimension 2*(e_hidden + 2*r_hidden).
    """

    def __init__(self, e_hidden=300, r_hidden=100):
        super(RAGA, self).__init__()
        self.gcn1 = GCN(e_hidden)
        self.highway1 = Highway(e_hidden)
        self.gcn2 = GCN(e_hidden)
        self.highway2 = Highway(e_hidden)
        self.gat_e_to_r = GAT_E_to_R(e_hidden, r_hidden)
        self.gat_r_to_e = GAT_R_to_E(e_hidden, r_hidden)
        self.gat = GAT(e_hidden+2*r_hidden)

    def forward(self, x_e, edge_index, rel, edge_index_all, rel_all):
        # Structure aggregation on the graph with inverse edges added.
        x_e = self.highway1(x_e, self.gcn1(x_e, edge_index_all))
        x_e = self.highway2(x_e, self.gcn2(x_e, edge_index_all))
        # Relation-aware enhancement uses the original (forward-only) triples.
        x_r = self.gat_e_to_r(x_e, edge_index, rel)
        x_e = torch.cat([x_e, self.gat_r_to_e(x_e, x_r, edge_index, rel)], dim=1)
        x_e = torch.cat([x_e, self.gat(x_e, edge_index_all)], dim=1)
        return x_e
--------------------------------------------------------------------------------