├── README.md
├── dataloader.py
├── utils.py
├── trm.py
├── run.py
└── TET.py

/README.md:
--------------------------------------------------------------------------------
# Transformer-based Entity Typing in Knowledge Graphs
#### This repo provides the source code & data of our paper: [Transformer-based Entity Typing in Knowledge Graphs (EMNLP 2022)](https://arxiv.org/pdf/2210.11151.pdf).

## Dependencies
* Python 3.7 (e.g., `conda create -n tet python=3.7 -y`)
* PyTorch 1.8.1
* transformers 4.7.0
* pytorch-pretrained-bert 0.6.2
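One way to install the packages above, assuming `pip` inside the freshly created conda environment (CUDA builds of PyTorch may require the index URL matching your CUDA version):

```bash
conda activate tet
pip install torch==1.8.1 transformers==4.7.0 pytorch-pretrained-bert==0.6.2
```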

## Running the code
### Dataset
* Download the datasets from [here](https://drive.google.com/drive/folders/120QIGxsGQXfH6Rd8wJe7i57gg8dlx7l2?usp=sharing).
* Create the root directory ./data and place the downloaded datasets inside it.

### Training the model
#### For the FB15kET dataset
```bash
export DATASET=FB15kET
export SAVE_DIR_NAME=FB15kET
export LOG_PATH=./logs/FB15kET.out
export HIDDEN_DIM=100
export TEMPERATURE=0.5
export LEARNING_RATE=0.001
export TRAIN_BATCH_SIZE=128
export MAX_EPOCH=500
export VALID_EPOCH=25
export BETA=1
export LOSS=SFNA

export PAIR_POOLING=avg
export SAMPLE_ET_SIZE=3
export SAMPLE_KG_SIZE=7
export SAMPLE_ENT2PAIR_SIZE=6
export WARM_UP_STEPS=50
export TT_ABLATION=all

CUDA_VISIBLE_DEVICES=0 python ./run.py --dataset $DATASET --save_path $SAVE_DIR_NAME --hidden_dim $HIDDEN_DIM --temperature $TEMPERATURE --lr $LEARNING_RATE \
--train_batch_size $TRAIN_BATCH_SIZE --cuda --max_epoch $MAX_EPOCH --valid_epoch $VALID_EPOCH --beta $BETA --loss $LOSS \
--pair_pooling $PAIR_POOLING --sample_et_size $SAMPLE_ET_SIZE --sample_kg_size $SAMPLE_KG_SIZE --sample_ent2pair_size $SAMPLE_ENT2PAIR_SIZE --warm_up_steps $WARM_UP_STEPS \
--tt_ablation $TT_ABLATION \
> $LOG_PATH 2>&1 &
```
#### For the YAGO43kET dataset
```bash
export DATASET=YAGO43kET
export SAVE_DIR_NAME=YAGO43kET
export LOG_PATH=./logs/YAGO43kET.out
export HIDDEN_DIM=100
export TEMPERATURE=0.5
export LEARNING_RATE=0.001
export TRAIN_BATCH_SIZE=128
export MAX_EPOCH=500
export VALID_EPOCH=25
export BETA=1
export LOSS=SFNA

export PAIR_POOLING=avg
export SAMPLE_ET_SIZE=3
export SAMPLE_KG_SIZE=8
export SAMPLE_ENT2PAIR_SIZE=6
export WARM_UP_STEPS=50
export TT_ABLATION=all

CUDA_VISIBLE_DEVICES=1 python ./run.py --dataset $DATASET --save_path $SAVE_DIR_NAME --hidden_dim $HIDDEN_DIM --temperature $TEMPERATURE --lr $LEARNING_RATE \
--train_batch_size $TRAIN_BATCH_SIZE --cuda --max_epoch $MAX_EPOCH --valid_epoch $VALID_EPOCH --beta $BETA --loss $LOSS \
--pair_pooling $PAIR_POOLING --sample_et_size $SAMPLE_ET_SIZE --sample_kg_size $SAMPLE_KG_SIZE --sample_ent2pair_size $SAMPLE_ENT2PAIR_SIZE --warm_up_steps $WARM_UP_STEPS \
--tt_ablation $TT_ABLATION \
> $LOG_PATH 2>&1 &
```

* **Note:** Create the ./logs directory before running; the commands above redirect their output there.

## Citation
If you find this code useful, please consider citing the following paper.
```
@inproceedings{hu2022transformer,
  author    = {Zhiwei Hu and Víctor Gutiérrez-Basulto and Zhiliang Xiang and Ru Li and Jeff Z. Pan},
  title     = {Transformer-based Entity Typing in Knowledge Graphs},
  booktitle = {Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
  year      = {2022}
}
```
## Acknowledgement
We refer to the code of [CET](https://github.com/CCIIPLab/CET). Thanks for their contributions.
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import pickle
import os
import numpy as np
import torch

from tqdm import tqdm
from torch.utils.data import Dataset

class ETDataset(Dataset):
    def __init__(self, args, data_name, e2id, r2id, t2id, c2id, data_flag):
        self.args = args
        self.data_name = data_name
        self.e2id = e2id
        self.r2id = r2id
        self.t2id = t2id
        self.c2id = c2id
        self.sample_et_size = args["sample_et_size"]
        self.sample_kg_size = args["sample_kg_size"]
        self.data = self.load_dataset()
        self.data_flag = data_flag

    def load_dataset(self):
        data_name_path = self.args["data_dir"] + '/' + self.args["dataset"] + '/' + self.data_name
        contents = []

        # reuse the cached pickle if one has already been built for this split
        output_pickle = data_name_path[0: data_name_path.rfind('.')] + '.pkl'
        if os.path.exists(output_pickle):
            with open(output_pickle, 'rb') as handle:
                contents = pickle.load(handle)
            return contents

        with open(data_name_path, 'r', encoding='UTF-8') as f:
            for line in tqdm(f):
                line = line.strip()
                if not line:
                    continue
                mask_ent, et_triples, kg_triples, clu_triples = [_.strip() for _ in line.split('|||')]
                et_content_list = et_triples.split(' [SEP] ')
                et_list = []

                kg_content_list = kg_triples.split(' [SEP] ')
                kg_list = []

                for et_content in et_content_list:
                    et_head, et_rel, et_type = et_content.split(' ')
                    et_head_id = self.e2id[et_head]
                    et_type_id = self.t2id[et_type] + len(self.e2id)
                    et_list.append([et_head_id, self.c2id[et_rel] + len(self.r2id), et_type_id])

                for kg_content in kg_content_list:
                    kg_head, kg_rel, kg_tail = kg_content.split(' ')
                    if kg_rel.startswith('inv-'):
                        kg_rel_id = len(self.r2id) + len(self.c2id) + self.r2id[kg_rel[4:]]
                    else:
                        kg_rel_id = self.r2id[kg_rel]
                    kg_head_id = self.e2id[kg_head]
                    kg_tail_id = self.e2id[kg_tail]
                    kg_list.append([kg_head_id, kg_rel_id, kg_tail_id])

                contents.append((et_list, kg_list, self.e2id[mask_ent]))

        with open(output_pickle, 'wb') as handle:
            pickle.dump(contents, handle)

        return contents
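
    # Expected input line format (inferred from the parsing above; the
    # placeholders are illustrative):
    #   <mask_ent> ||| <et triples> ||| <kg triples> ||| <cluster triples>
    # <et triples> is a ' [SEP] '-separated list of "entity cluster type"
    # items, and <kg triples> a ' [SEP] '-separated list of "head rel tail"
    # items, where a relation prefixed with 'inv-' marks the inverse direction.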
    def __getitem__(self, index):
        et_content = self.data[index][0]
        kg_content = self.data[index][1]
        ent = self.data[index][2]

        # sample a fixed-size set of entity-type neighbors (with replacement
        # only when there are fewer neighbors than requested)
        single_et_np_list = []
        if self.sample_et_size != 1:
            sampled_index = np.random.choice(range(0, len(et_content)), size=self.sample_et_size,
                                             replace=len(range(0, len(et_content))) < self.sample_et_size)
            for i in sampled_index:
                single_et_np_list.append(et_content[i])
        else:
            single_et_np_list.append(et_content[0])

        single_kg_np_list = []
        if self.sample_kg_size != 1:
            sampled_index = np.random.choice(range(0, len(kg_content)), size=self.sample_kg_size,
                                             replace=len(range(0, len(kg_content))) < self.sample_kg_size)
            for i in sampled_index:
                single_kg_np_list.append(kg_content[i])
        else:
            single_kg_np_list.append(kg_content[0])

        all_et = et_content
        all_kg = kg_content
        sample_et = single_et_np_list
        sample_kg = single_kg_np_list

        gt_ent = ent

        if self.data_flag == 'test':
            # for test, we need all neighbor information
            return all_et, all_kg, gt_ent
        else:
            return sample_et, sample_kg, gt_ent

    def __len__(self):
        return len(self.data)

    @staticmethod
    def collate_fn(batch):
        et_content = torch.LongTensor([_[0] for _ in batch])
        kg_content = torch.LongTensor([_[1] for _ in batch])
        gt_ent = torch.LongTensor([_[2] for _ in batch])

        return et_content, kg_content, gt_ent
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import torch
import logging
import os
import numpy as np

def set_logger(args):
    if not os.path.exists(os.path.join(args['save_dir'], args['save_path'])):
        os.makedirs(os.path.join(os.getcwd(), args['save_dir'], args['save_path']))

    log_file = os.path.join(args['save_dir'], args['save_path'], args['log_name'] + '.txt')

    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.DEBUG,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=log_file,
        filemode='w'
    )

    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

def read_id(path):
    tmp = dict()
    with open(path, encoding='utf-8') as r:
        for line in r:
            e, t = line.strip().split('\t')
            tmp[e] = int(t)
    return tmp

def load_type_labels(paths, e2id, t2id):
    labels = torch.zeros(len(e2id), len(t2id))
    for path in paths:
        with open(path, encoding='utf-8') as r:
            for line in r:
                e, t = line.strip().split('\t')
                e_id, t_id = e2id[e], t2id[t]
                labels[e_id, t_id] = 1
    return labels

def load_id(path, e2id):
    ret = set()
    with open(path, encoding='utf-8') as r:
        for line in r:
            e, t = line.strip().split('\t')
            ret.add(e2id[e])
    return list(ret)

def load_train_all_labels(data_dir, e2id, t2id):
    train_type_label = load_type_labels([
        os.path.join(data_dir, 'ET_train.txt'),
        os.path.join(data_dir, 'ET_valid.txt')
    ], e2id, t2id)
    test_type_label = load_type_labels([
        os.path.join(data_dir, 'ET_train.txt'),
        os.path.join(data_dir, 'ET_valid.txt'),
        os.path.join(data_dir, 'ET_test.txt'),
    ], e2id, t2id).half()

    return train_type_label, test_type_label

def load_entity_cluster_type_pair_context(args, r2id, e2id):
    data_name_path = args["data_dir"] + '/' + args["dataset"] + '/ent2pair.npy'
    sample_ent2pair_size = args["sample_ent2pair_size"]
    ent2pair = np.load(data_name_path, allow_pickle=True).tolist()
    sample_ent2pair = []
    for single_sample_ent2pair in ent2pair:
        single_sample_ent2pair_list = []
        if sample_ent2pair_size != 1:
            sampled_index = np.random.choice(range(0, len(single_sample_ent2pair)), size=sample_ent2pair_size,
                                             replace=len(range(0, len(single_sample_ent2pair))) < sample_ent2pair_size)
            for i in sampled_index:
                clu_info = single_sample_ent2pair[i][0] + len(r2id)
                type_info = single_sample_ent2pair[i][1] + len(e2id)
                single_sample_ent2pair_list.append([clu_info, type_info])
        else:
            clu_info = single_sample_ent2pair[0][0] + len(r2id)
            type_info = single_sample_ent2pair[0][1] + len(e2id)
            single_sample_ent2pair_list.append([clu_info, type_info])
        sample_ent2pair.append(single_sample_ent2pair_list)

    return sample_ent2pair

def evaluate(path, predict, all_true, e2id, t2id):
    logs = []
    with open('./rank.txt', 'w', encoding='utf-8') as f, \
            open(path, 'r', encoding='utf-8') as r:
        for line in r:
            e, t = line.strip().split('\t')
            e, t = e2id[e], t2id[t]
            # filtered setting: mask out the other known true types of e
            tmp = predict[e] - all_true[e]
            tmp[t] = predict[e, t]
            argsort = torch.argsort(tmp, descending=True)
            ranking = (argsort == t).nonzero()
            assert ranking.size(0) == 1
            ranking = ranking.item() + 1
            print(line.strip(), ranking, file=f)
            logs.append({
                'MRR': 1.0 / ranking,
                'MR': float(ranking),
                'HIT@1': 1.0 if ranking <= 1 else 0.0,
                'HIT@3': 1.0 if ranking <= 3 else 0.0,
                'HIT@10': 1.0 if ranking <= 10 else 0.0
            })
    MRR = 0
    for metric in logs[0]:
        tmp = sum([_[metric] for _ in logs]) / len(logs)
        if metric == 'MRR':
            MRR = tmp
        logging.debug('%s: %f' % (metric, tmp))
    return MRR

def slight_fna_loss(predict, label, beta):
    loss = torch.nn.BCELoss(reduction='none')
    output = loss(predict, label)
    positive_loss = output * label
    negative_weight = predict.detach().clone()
    small_ids = negative_weight < 0.5
    large_ids = negative_weight >= 0.5

    # piecewise down-weighting of (possibly false) negative examples
    negative_weight[small_ids] = beta * (3 * negative_weight[small_ids] - 2 * negative_weight[small_ids].pow(2))
    negative_weight[large_ids] = beta * (negative_weight[large_ids] - 2 * negative_weight[large_ids].pow(2) + 1)

    negative_weight = negative_weight * (1 - label)
    negative_loss = negative_weight * output
    return positive_loss.mean(), negative_loss.mean()
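
# Minimal sanity check for slight_fna_loss (illustrative only, not part of
# training): with beta=1, negative examples scored near 0 or 1 receive small
# weights, while uncertain scores near 0.5 receive weights close to 1.
if __name__ == '__main__':
    toy_predict = torch.tensor([[0.9, 0.1, 0.5]])
    toy_label = torch.tensor([[1.0, 0.0, 0.0]])
    pos_loss, neg_loss = slight_fna_loss(toy_predict, toy_label, beta=1.0)
    print(pos_loss.item(), neg_loss.item())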
--------------------------------------------------------------------------------
/trm.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def relative_attention_logits(query, key, relation):
    qk_matmul = torch.matmul(query, key.transpose(-2, -1))

    if relation is None:
        return qk_matmul / math.sqrt(query.shape[-1])

    # add the query/relation interaction term of relation-aware self-attention
    q_t = query.permute(0, 2, 1, 3)
    r_t = relation.transpose(-2, -1)
    q_tr_t_matmul = torch.matmul(q_t, r_t)
    q_tr_tmatmul_t = q_tr_t_matmul.permute(0, 2, 1, 3)

    return (qk_matmul + q_tr_tmatmul_t) / math.sqrt(query.shape[-1])

def relative_attention_values(weight, value, relation):
    wv_matmul = torch.matmul(weight, value)

    if relation is None:
        return wv_matmul

    w_t = weight.permute(0, 2, 1, 3)
    w_tr_matmul = torch.matmul(w_t, relation)
    w_tr_matmul_t = w_tr_matmul.permute(0, 2, 1, 3)

    return wv_matmul + w_tr_matmul_t

def clones(module_fn, N):
    return nn.ModuleList([module_fn() for _ in range(N)])

def attention_with_relations(query, key, value, relation_k, relation_v, mask=None, dropout=None):
    scores = relative_attention_logits(query, key, relation_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    p_attn_orig = F.softmax(scores, dim=-1)
    p_attn = dropout(p_attn_orig) if dropout is not None else p_attn_orig
    return relative_attention_values(p_attn, value, relation_v), p_attn_orig

class MultiHeadedAttentionWithRelations(nn.Module):
    def __init__(self, h, d_model, dropout=0.1):
        super(MultiHeadedAttentionWithRelations, self).__init__()
        assert d_model % h == 0
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(lambda: nn.Linear(d_model, d_model), 4)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, relation_k, relation_v, mask=None):
        if mask is not None:
            mask = mask.unsqueeze(1)
        nbatches = query.size(0)

        # project and split into heads: [batch, heads, seq, d_k]
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]

        x, self.attn = attention_with_relations(
            query,
            key,
            value,
            relation_k,
            relation_v,
            mask=mask,
            dropout=self.dropout)

        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)
        return self.linears[-1](x)

class Encoder(nn.Module):
    def __init__(self, layer, N, initializer_range, tie_layers=False):
        super(Encoder, self).__init__()
        if tie_layers:
            self.layer = layer()
            self.layers = [self.layer for _ in range(N)]
        else:
            self.layers = clones(layer, N)
        self.initializer_range = initializer_range
        self.apply(self.init_bert_weights)

    def forward(self, x, relation, mask):
        # return the output of every layer; callers index [-1] for the last one
        all_x = []
        for layer in self.layers:
            x = layer(x, relation, mask)
            all_x.append(x)
        return all_x

    def init_bert_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

class SublayerConnection(nn.Module):
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # post-norm residual connection
        return self.norm(self.dropout(sublayer(x)) + x)

class EncoderLayer(nn.Module):
    def __init__(self, size, self_attn, feed_forward, num_relation_kinds, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer = clones(lambda: SublayerConnection(size, dropout), 2)
        self.size = size

        if num_relation_kinds != 0:
            self.relation_k_emb = nn.Embedding(num_relation_kinds, self.self_attn.d_k)
            self.relation_v_emb = nn.Embedding(num_relation_kinds, self.self_attn.d_k)
        else:
            # no relation kinds: fall back to vanilla self-attention
            self.relation_k_emb = lambda x: None
            self.relation_v_emb = lambda x: None

    def forward(self, x, relation, mask):
        relation_k = self.relation_k_emb(relation)
        relation_v = self.relation_v_emb(relation)

        x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, relation_k, relation_v, mask))
        return self.sublayer[1](x, self.feed_forward)

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.w_2(self.dropout(F.gelu(self.w_1(x))))
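
# Illustrative usage sketch (mirrors how TET.py builds its second transformer
# layer; the sizes below are made up for this demo):
if __name__ == '__main__':
    d_model, n_head, ff_dim, n_layer = 100, 4, 480, 3
    encoder = Encoder(
        lambda: EncoderLayer(
            d_model,
            MultiHeadedAttentionWithRelations(n_head, d_model, dropout=0.1),
            PositionwiseFeedForward(d_model, ff_dim, dropout=0.1),
            num_relation_kinds=0,
            dropout=0.1),
        n_layer,
        initializer_range=0.1,
        tie_layers=False)
    tokens = torch.randn(2, 5, d_model)               # [batch, seq, dim]
    mask = torch.ones(2, 5, 5, dtype=torch.bool)      # full attention
    out = encoder(tokens, None, mask)[-1]             # last layer's output
    print(out.shape)                                  # torch.Size([2, 5, 100])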
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
import argparse
from utils import *
from TET import TET
from dataloader import ETDataset
from torch.utils.data import DataLoader

def main(args):
    use_cuda = args['cuda'] and torch.cuda.is_available()
    data_path = os.path.join(args['data_dir'], args['dataset'])
    save_path = os.path.join(args['save_dir'], args['save_path'])

    e2id = read_id(os.path.join(data_path, 'entities.tsv'))
    r2id = read_id(os.path.join(data_path, 'relations.tsv'))
    t2id = read_id(os.path.join(data_path, 'types.tsv'))
    c2id = read_id(os.path.join(data_path, 'clusters.tsv'))
    num_entities = len(e2id)
    num_rels = len(r2id)
    num_types = len(t2id)
    num_clusters = len(c2id)
    train_type_label, test_type_label = load_train_all_labels(data_path, e2id, t2id)
    sample_ent2pair = torch.LongTensor(load_entity_cluster_type_pair_context(args, r2id, e2id))
    if use_cuda:
        sample_ent2pair = sample_ent2pair.cuda()
    train_dataset = ETDataset(args, "LMET_train.txt", e2id, r2id, t2id, c2id, 'train')
    valid_dataset = ETDataset(args, "LMET_valid.txt", e2id, r2id, t2id, c2id, 'valid')
    test_dataset = ETDataset(args, "LMET_test.txt", e2id, r2id, t2id, c2id, 'test')

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args['train_batch_size'],
                                  shuffle=True,
                                  collate_fn=ETDataset.collate_fn,
                                  num_workers=6)
    valid_dataloader = DataLoader(valid_dataset,
                                  batch_size=args['train_batch_size'],
                                  shuffle=False,
                                  collate_fn=ETDataset.collate_fn,
                                  num_workers=6)
    test_dataloader = DataLoader(test_dataset,
                                 batch_size=args['test_batch_size'],
                                 shuffle=False,
                                 collate_fn=ETDataset.collate_fn,
                                 num_workers=6)

    model = TET(args, num_entities, num_rels, num_types, num_clusters)

    if use_cuda:
        model = model.to('cuda')
    for name, param in model.named_parameters():
        logging.debug('Parameter %s: %s, require_grad=%s' % (name, str(param.size()), str(param.requires_grad)))

    current_learning_rate = args['lr']
    warm_up_steps = args['warm_up_steps']
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=current_learning_rate
    )

    max_valid_mrr = 0
    model.train()
    for epoch in range(args['max_epoch']):
        log = []
        for sample_et_content, sample_kg_content, gt_ent in train_dataloader:
            type_label = train_type_label[gt_ent, :]
            if use_cuda:
                sample_et_content = sample_et_content.cuda()
                sample_kg_content = sample_kg_content.cuda()
                type_label = type_label.cuda()
            type_predict = model(sample_et_content, sample_kg_content, sample_ent2pair)

            if args['loss'] == 'BCE':
                bce_loss = torch.nn.BCELoss()
                type_loss = bce_loss(type_predict, type_label)
                type_pos_loss, type_neg_loss = type_loss, type_loss
            elif args['loss'] == 'SFNA':
                type_pos_loss, type_neg_loss = slight_fna_loss(type_predict, type_label, args['beta'])
                type_loss = type_pos_loss + type_neg_loss
            else:
                raise ValueError('loss %s is not defined' % args['loss'])

            log.append({
                "loss": type_loss.item(),
                "pos_loss": type_pos_loss.item(),
                "neg_loss": type_neg_loss.item(),
            })

            optimizer.zero_grad()
            type_loss.requires_grad_(True)
            type_loss.backward()
            optimizer.step()

        if epoch >= warm_up_steps:
            # decay the learning rate by 5x; the next decay happens after twice
            # as many epochs (e.g., at epochs 50, 100, 200, ... for warm_up_steps=50)
            current_learning_rate = current_learning_rate / 5
            optimizer = torch.optim.Adam(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=current_learning_rate
            )
            warm_up_steps = warm_up_steps * 2

        avg_type_loss = sum([_['loss'] for _ in log]) / len(log)
        avg_type_pos_loss = sum([_['pos_loss'] for _ in log]) / len(log)
        avg_type_neg_loss = sum([_['neg_loss'] for _ in log]) / len(log)
        logging.debug('epoch %d: loss: %f\tpos_loss: %f\tneg_loss: %f' %
                      (epoch, avg_type_loss, avg_type_pos_loss, avg_type_neg_loss))

        if epoch != 0 and epoch % args['valid_epoch'] == 0:
            model.eval()
            with torch.no_grad():
                logging.debug('-----------------------valid step-----------------------')
                predict = torch.zeros(num_entities, num_types, dtype=torch.half)
                for sample_et_content, sample_kg_content, gt_ent in valid_dataloader:
                    if use_cuda:
                        sample_et_content = sample_et_content.cuda()
                        sample_kg_content = sample_kg_content.cuda()
                    predict[gt_ent] = model(sample_et_content, sample_kg_content, sample_ent2pair).cpu().half()
                valid_mrr = evaluate(os.path.join(data_path, 'ET_valid.txt'), predict, test_type_label, e2id, t2id)

                logging.debug('-----------------------test step-----------------------')
                predict = torch.zeros(num_entities, num_types, dtype=torch.half)
                for sample_et_content, sample_kg_content, gt_ent in test_dataloader:
                    if use_cuda:
                        sample_et_content = sample_et_content.cuda()
                        sample_kg_content = sample_kg_content.cuda()
                    predict[gt_ent] = model(sample_et_content, sample_kg_content, sample_ent2pair).cpu().half()
                evaluate(os.path.join(data_path, 'ET_test.txt'), predict, test_type_label, e2id, t2id)

            model.train()
            if valid_mrr < max_valid_mrr:
                logging.debug('early stop')
                break
            else:
                torch.save(model.state_dict(), os.path.join(save_path, 'best_model.pkl'))
                max_valid_mrr = valid_mrr

                # save embedding
                entity_embedding = model.entity.detach().cpu().numpy()
                np.save(
                    os.path.join(save_path, 'entity_embedding'),
                    entity_embedding
                )
                relation_embedding = model.relation.detach().cpu().numpy()
                np.save(
                    os.path.join(save_path, 'relation_embedding'),
                    relation_embedding
                )

    logging.debug('-----------------------best test step-----------------------')
    with torch.no_grad():
        model.load_state_dict(torch.load(os.path.join(save_path, 'best_model.pkl')))
        model.eval()
        predict = torch.zeros(num_entities, num_types, dtype=torch.half)
        for sample_et_content, sample_kg_content, gt_ent in test_dataloader:
            if use_cuda:
                sample_et_content = sample_et_content.cuda()
                sample_kg_content = sample_kg_content.cuda()
            predict[gt_ent] = model(sample_et_content, sample_kg_content, sample_ent2pair).cpu().half()
        evaluate(os.path.join(data_path, 'ET_test.txt'), predict, test_type_label, e2id, t2id)


def get_params():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='FB15kET')
    parser.add_argument('--save_dir', type=str, default='save')
    parser.add_argument('--save_path', type=str, default='SFNA')
    parser.add_argument('--hidden_dim', type=int, default=100)
    parser.add_argument('--temperature', type=float, default=0.5)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--train_batch_size', type=int, default=128)
    parser.add_argument('--test_batch_size', type=int, default=1)
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--max_epoch', type=int, default=500)
    parser.add_argument('--valid_epoch', type=int, default=25)
    parser.add_argument('--beta', type=float, default=1.0)
    parser.add_argument('--loss', type=str, default='SFNA')

    # params for first trm layer
    parser.add_argument('--bert_nlayer', type=int, default=3)
    parser.add_argument('--bert_nhead', type=int, default=4)
    parser.add_argument('--bert_ff_dim', type=int, default=480)
    parser.add_argument('--bert_activation', type=str, default='gelu')
    parser.add_argument('--bert_hidden_dropout', type=float, default=0.2)
    parser.add_argument('--bert_attn_dropout', type=float, default=0.2)
    parser.add_argument('--local_pos_size', type=int, default=200)

    # params for pair trm layer
    parser.add_argument('--pair_layer', type=int, default=3)
    parser.add_argument('--pair_head', type=int, default=4)
    parser.add_argument('--pair_dropout', type=float, default=0.2)
    parser.add_argument('--pair_ff_dim', type=int, default=480)

    # params for second trm layer
    parser.add_argument('--trm_nlayer', type=int, default=3)
    parser.add_argument('--trm_nhead', type=int, default=4)
    parser.add_argument('--trm_hidden_dropout', type=float, default=0.2)
    parser.add_argument('--trm_attn_dropout', type=float, default=0.2)
    parser.add_argument('--trm_ff_dim', type=int, default=480)
    parser.add_argument('--global_pos_size', type=int, default=200)

    parser.add_argument('--pair_pooling', type=str, default='avg', choices=['max', 'avg', 'min'])
    parser.add_argument('--sample_et_size', type=int, default=3)
    parser.add_argument('--sample_kg_size', type=int, default=7)
    parser.add_argument('--sample_ent2pair_size', type=int, default=6)
    parser.add_argument('--warm_up_steps', default=50, type=int)
    parser.add_argument('--tt_ablation', type=str, default='all', choices=['all', 'triple', 'type'],
                        help='ablation choice')
    parser.add_argument('--log_name', type=str, default='log')

    args, _ = parser.parse_known_args()
    print(args)
    return args


if __name__ == '__main__':
    try:
        params = vars(get_params())
        set_logger(params)
        main(params)
    except Exception as e:
        logging.exception(e)
        raise
--------------------------------------------------------------------------------
/TET.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import trm

from torch.nn import TransformerEncoder, TransformerEncoderLayer
from pytorch_pretrained_bert.modeling import BertEncoder, BertConfig, BertLayerNorm

class TET(nn.Module):
    def __init__(self, args, num_entities, num_rels, num_types, num_cluster):
        super(TET, self).__init__()
        self.embedding_dim = args['hidden_dim']
        self.embedding_range = 10 / self.embedding_dim
        self.num_rels = num_rels + num_cluster
        self.use_cuda = args['cuda']
        self.dataset = args['dataset']
        self.sample_ent2pair_size = args['sample_ent2pair_size']
        self.tt_ablation = args['tt_ablation']
        self.pooling = args['pair_pooling']
        self.device = torch.device('cuda')
        self.num_nodes = num_entities + num_types

        self.layer = TETLayer(args, self.embedding_dim, num_types, args['temperature'])

        self.entity = nn.Parameter(torch.randn(self.num_nodes, self.embedding_dim))
        nn.init.uniform_(tensor=self.entity, a=-self.embedding_range, b=self.embedding_range)
        self.relation = nn.Parameter(torch.randn(self.num_rels, self.embedding_dim))
        nn.init.uniform_(tensor=self.relation, a=-self.embedding_range, b=self.embedding_range)

        self.bert_nlayer = args['bert_nlayer']
        self.bert_nhead = args['bert_nhead']
        self.bert_ff_dim = args['bert_ff_dim']
        self.bert_activation = args['bert_activation']
        self.bert_hidden_dropout = args['bert_hidden_dropout']
        self.bert_attn_dropout = args['bert_attn_dropout']
        self.local_pos_size = args['local_pos_size']
        self.bert_layer_norm = BertLayerNorm(self.embedding_dim, eps=1e-12)
        self.local_cls = nn.Parameter(torch.Tensor(1, self.embedding_dim))
        torch.nn.init.normal_(self.local_cls, std=self.embedding_range)
        self.local_pos_embeds = nn.Embedding(self.local_pos_size, self.embedding_dim)
        torch.nn.init.normal_(self.local_pos_embeds.weight, std=self.embedding_range)
        bert_config = BertConfig(0, hidden_size=self.embedding_dim,
                                 num_hidden_layers=self.bert_nlayer // 2,
                                 num_attention_heads=self.bert_nhead,
                                 intermediate_size=self.bert_ff_dim,
                                 hidden_act=self.bert_activation,
                                 hidden_dropout_prob=self.bert_hidden_dropout,
                                 attention_probs_dropout_prob=self.bert_attn_dropout,
                                 max_position_embeddings=0,
                                 type_vocab_size=0,
                                 initializer_range=self.embedding_range)
        self.bert_encoder = BertEncoder(bert_config)

        self.pair_layer = args['pair_layer']
        self.pair_head = args['pair_head']
        self.pair_dropout = args['pair_dropout']
        self.pair_ff_dim = args['pair_ff_dim']
        self.pair_pos_embeds = nn.Embedding(1 + 2 * self.sample_ent2pair_size, self.embedding_dim)
        torch.nn.init.normal_(self.pair_pos_embeds.weight, std=self.embedding_range)
        pair_encoder_layers = TransformerEncoderLayer(self.embedding_dim, self.pair_head, self.pair_ff_dim, self.pair_dropout)
        self.pair_encoder = TransformerEncoder(pair_encoder_layers, self.pair_layer)
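
    # Index layout used throughout forward() (inferred from dataloader.py):
    # self.entity stacks entity rows [0, num_entities) followed by type rows
    # [num_entities, num_entities + num_types); self.relation stacks relation
    # rows [0, num_rels) followed by cluster rows. Inverse relations carry ids
    # >= self.num_rels and are looked up modulo self.num_rels, then negated.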

    def convert_mask(self, attention_mask):
        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        attention_mask = (1.0 - attention_mask.float()) * -10000.0
        return attention_mask

    def forward(self, et_content, kg_content, sample_ent2pair):
        batch_size, et_neighbor_size = et_content[:, :, 2].size()
        et_types = torch.index_select(self.entity, 0, et_content[:, :, 2].view(-1)).view(batch_size, et_neighbor_size, -1)
        et_relations_types = et_content[:, :, 1]
        et_relations = torch.index_select(self.relation, 0, et_relations_types.view(-1) % self.num_rels).view(batch_size, et_neighbor_size, -1)
        et_relations[et_relations_types >= self.num_rels] = et_relations[et_relations_types >= self.num_rels] * -1

        if 'YAGO' in self.dataset:
            # for the YAGO dataset, represent each KG relation through its
            # cluster/type context pairs
            batch_size, kg_neighbor_size = kg_content[:, :, 2].size()
            kg_entities = torch.index_select(self.entity, 0, kg_content[:, :, 2].view(-1)).view(batch_size, kg_neighbor_size, -1)
            _, pair_neighbor_size, _ = sample_ent2pair.size()
            kg_entity2pair = torch.index_select(sample_ent2pair, 0, kg_content[:, :, 2].view(-1)).view(batch_size, kg_neighbor_size, pair_neighbor_size, -1)
            pair_cluster = kg_entity2pair[:, :, :, 0]
            pair_type = kg_entity2pair[:, :, :, 1]
            pair_cluster_embs = torch.index_select(self.relation, 0, pair_cluster.view(-1)).view(-1, pair_neighbor_size, self.embedding_dim)
            pair_type_embs = torch.index_select(self.entity, 0, pair_type.view(-1)).view(-1, pair_neighbor_size, self.embedding_dim)
            kg_relations_types = kg_content[:, :, 1]
            kg_relations = torch.index_select(self.relation, 0, kg_relations_types.view(-1) % self.num_rels).view(-1, 1, self.embedding_dim)
            pairs = torch.cat((pair_cluster_embs, pair_type_embs), 2).view(-1, 2 * pair_cluster_embs.shape[1], pair_cluster_embs.shape[2])
            ent_pairs = torch.cat([kg_relations, pairs], 1).transpose(1, 0)  # [1 + num_pairs, bs, emb_dim]
            ent_pairs_pos = torch.arange(ent_pairs.shape[0], dtype=torch.long, device=self.device).repeat(ent_pairs.shape[1], 1)
            ent_pairs_pos_embeddings = self.pair_pos_embeds(ent_pairs_pos).transpose(1, 0)
            ent_pairs_embs = ent_pairs + ent_pairs_pos_embeddings
            mask = torch.zeros((ent_pairs_embs.shape[1], ent_pairs_embs.shape[0])).bool().to(self.device)
            x = self.pair_encoder(ent_pairs_embs, src_key_padding_mask=mask)

            if self.pooling == 'max':
                x, _ = torch.max(x, dim=0)
            elif self.pooling == "avg":
                x = torch.mean(x, dim=0)
            elif self.pooling == "min":
                x, _ = torch.min(x, dim=0)
            kg_relations = x.view(batch_size, -1, self.embedding_dim)
            kg_relations[kg_relations_types >= self.num_rels] = kg_relations[kg_relations_types >= self.num_rels] * -1
        else:
            batch_size, kg_neighbor_size = kg_content[:, :, 2].size()
            kg_entities = torch.index_select(self.entity, 0, kg_content[:, :, 2].view(-1)).view(batch_size, kg_neighbor_size, -1)
            kg_relations_types = kg_content[:, :, 1]
            kg_relations = torch.index_select(self.relation, 0, kg_relations_types.view(-1) % self.num_rels).view(batch_size, kg_neighbor_size, -1)
            kg_relations[kg_relations_types >= self.num_rels] = kg_relations[kg_relations_types >= self.num_rels] * -1

        et_merge = torch.cat([et_types, et_relations], dim=1).view(-1, 2, self.embedding_dim)
        et_pos = self.local_pos_embeds(torch.arange(0, 3, device=self.device)).unsqueeze(0).repeat(et_merge.shape[0], 1, 1)
        et_merge = torch.cat([self.local_cls.expand(et_merge.size(0), 1, self.embedding_dim), et_merge], dim=1) + et_pos
        et_merge = self.bert_layer_norm(et_merge)
        et_merge = self.bert_encoder(et_merge, self.convert_mask(et_merge.new_ones(et_merge.size(0), et_merge.size(1), dtype=torch.long)),
                                     output_all_encoded_layers=False)[-1][:, 0].view(batch_size, -1, self.embedding_dim)

        kg_merge = torch.cat([kg_entities, kg_relations], dim=1).view(-1, 2, self.embedding_dim)
        kg_pos = self.local_pos_embeds(torch.arange(0, 3, device=self.device)).unsqueeze(0).repeat(kg_merge.shape[0], 1, 1)
        kg_merge = torch.cat([self.local_cls.expand(kg_merge.size(0), 1, self.embedding_dim), kg_merge], dim=1) + kg_pos
        kg_merge = self.bert_layer_norm(kg_merge)
        kg_merge = self.bert_encoder(kg_merge, self.convert_mask(kg_merge.new_ones(kg_merge.size(0), kg_merge.size(1), dtype=torch.long)),
                                    output_all_encoded_layers=False)[-1][:, 0].view(batch_size, -1, self.embedding_dim)

        if self.tt_ablation == 'all':
            et_kg_merge = torch.cat([et_types, et_relations, kg_entities, kg_relations], dim=1).view(batch_size, -1, self.embedding_dim)
        elif self.tt_ablation == 'triple':
            et_kg_merge = torch.cat([kg_entities, kg_relations], dim=1).view(batch_size, -1, self.embedding_dim)
        elif self.tt_ablation == 'type':
            et_kg_merge = torch.cat([et_types, et_relations], dim=1).view(batch_size, -1, self.embedding_dim)

        _, et_kg_size, _ = et_kg_merge.size()
        # truncate so the sequence (plus [CLS]) fits the position table
        if et_kg_size >= self.local_pos_size - 1:
            et_kg_merge = et_kg_merge[:, 0:self.local_pos_size - 1, :]
            et_kg_size = self.local_pos_size - 1
        et_kg_pos = self.local_pos_embeds(torch.arange(0, et_kg_size + 1, device=self.device)).unsqueeze(0).repeat(et_kg_merge.shape[0], 1, 1)
        et_kg_merge = torch.cat([self.local_cls.expand(et_kg_merge.size(0), 1, self.embedding_dim), et_kg_merge], dim=1) + et_kg_pos
        et_kg_merge = self.bert_layer_norm(et_kg_merge)
        et_kg_merge = self.bert_encoder(et_kg_merge, self.convert_mask(et_kg_merge.new_ones(et_kg_merge.size(0), et_kg_merge.size(1), dtype=torch.long)),
                                        output_all_encoded_layers=False)[-1][:, 0].view(batch_size, -1, self.embedding_dim)

        if self.tt_ablation == 'all':
            local_embedding = torch.cat([et_merge, kg_merge], dim=1)
        elif self.tt_ablation == 'triple':
            local_embedding = kg_merge
        elif self.tt_ablation == 'type':
            local_embedding = et_merge
        global_embedding = et_kg_merge
        output = self.layer(local_embedding, global_embedding)

        return output


class TETLayer(nn.Module):
    def __init__(self, args, embedding_dim, num_types, temperature):
        super(TETLayer, self).__init__()
        self.embedding_dim = embedding_dim
        self.num_types = num_types
        self.fc = nn.Linear(embedding_dim, num_types)
        self.temperature = temperature
        self.device = torch.device('cuda')

        self.trm_nlayer = args['trm_nlayer']
        self.trm_nhead = args['trm_nhead']
        self.trm_hidden_dropout = args['trm_hidden_dropout']
        self.trm_attn_dropout = args['trm_attn_dropout']
        self.trm_ff_dim = args['trm_ff_dim']
        self.global_pos_size = args['global_pos_size']
        self.embedding_range = 10 / self.embedding_dim

        self.global_cls = nn.Parameter(torch.Tensor(1, self.embedding_dim))
        torch.nn.init.normal_(self.global_cls, std=self.embedding_range)
        self.pos_embeds = nn.Embedding(self.global_pos_size, self.embedding_dim)
        torch.nn.init.normal_(self.pos_embeds.weight, std=self.embedding_range)
        self.layer_norm = BertLayerNorm(self.embedding_dim, eps=1e-12)

        self.transformer_encoder = trm.Encoder(
            lambda: trm.EncoderLayer(
                self.embedding_dim,
                trm.MultiHeadedAttentionWithRelations(
                    self.trm_nhead,
                    self.embedding_dim,
                    self.trm_attn_dropout),
                trm.PositionwiseFeedForward(
                    self.embedding_dim,
                    self.trm_ff_dim,
                    self.trm_hidden_dropout),
                num_relation_kinds=0,
                dropout=self.trm_hidden_dropout),
            self.trm_nlayer,
            self.embedding_range,
            tie_layers=False)

    def convert_mask_trm(self, attention_mask):
        attention_mask = attention_mask.unsqueeze(1).repeat(1, attention_mask.size(1), 1)
        return attention_mask

    def forward(self, local_embedding, global_embedding):
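        # Summary of the logic below (added for orientation): predict1 scores
        # each local neighbor representation independently, predict2 scores a
        # [CLS]-pooled summary of all local representations, and predict3
        # scores the global representation; a softmax over the temperature-
        # scaled scores then weights the concatenated predictions per type.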
        local_msg = torch.relu(local_embedding)
        predict1 = self.fc(local_msg)

        batch_size, neighbor_size, emb_size = local_embedding.size()
        attention_mask = torch.ones(batch_size, neighbor_size + 1).bool().to(self.device)
        second_local = torch.cat([self.global_cls.expand(batch_size, 1, emb_size), local_embedding], dim=1)
        pos = self.pos_embeds(torch.arange(0, 3).to(self.device))
        second_local[:, 0] = second_local[:, 0] + pos[0].unsqueeze(0)
        second_local[:, 1] = second_local[:, 1] + pos[1].unsqueeze(0)
        second_local[:, 2:] = second_local[:, 2:] + pos[2].view(1, 1, -1)
        second_local = self.layer_norm(second_local)
        second_local = self.transformer_encoder(second_local, None, self.convert_mask_trm(attention_mask))
        second_local = second_local[-1][:, 0].unsqueeze(1)  # [CLS] token of the last layer
        predict2 = self.fc(torch.relu(second_local))
        predict3 = self.fc(torch.relu(global_embedding))

        predict = torch.cat([predict1, predict2, predict3], dim=1)
        weight = torch.softmax(self.temperature * predict, dim=1)
        predict = (predict * weight.detach()).sum(1).sigmoid()

        return predict
--------------------------------------------------------------------------------